Browse Source

Fix lm column clobbering in calculate_interaction_scores

Bryan Roessler 6 tháng trước
mục cha
commit
bee9aea866

+ 166 - 123
qhtcp-workflow/apps/r/calculate_interaction_zscores.R

@@ -300,37 +300,37 @@ calculate_interaction_scores <- function(df, df_bg, type, overlap_threshold = 2)
     ungroup() %>%  # Ungroup before group_modify
     ungroup() %>%  # Ungroup before group_modify
     group_by(across(all_of(group_vars))) %>%
     group_by(across(all_of(group_vars))) %>%
     group_modify(~ {
     group_modify(~ {
-      # Perform linear models only if there are enough unique conc_num_factor levels
+      # Check if there are enough unique conc_num_factor levels to perform lm
       if (length(unique(.x$conc_num_factor)) > 1) {
       if (length(unique(.x$conc_num_factor)) > 1) {
+        
+        # Perform linear modeling
+        lm_L <- lm(Delta_L ~ conc_num_factor, data = .x)
+        lm_K <- lm(Delta_K ~ conc_num_factor, data = .x)
+        lm_r <- lm(Delta_r ~ conc_num_factor, data = .x)
+        lm_AUC <- lm(Delta_AUC ~ conc_num_factor, data = .x)
 
 
-        # Filter and calculate each lm() separately with individual checks for NAs
-        lm_L <- if (!all(is.na(.x$Delta_L))) tryCatch(lm(Delta_L ~ conc_num_factor, data = .x), error = function(e) NULL) else NULL
-        lm_K <- if (!all(is.na(.x$Delta_K))) tryCatch(lm(Delta_K ~ conc_num_factor, data = .x), error = function(e) NULL) else NULL
-        lm_r <- if (!all(is.na(.x$Delta_r))) tryCatch(lm(Delta_r ~ conc_num_factor, data = .x), error = function(e) NULL) else NULL
-        lm_AUC <- if (!all(is.na(.x$Delta_AUC))) tryCatch(lm(Delta_AUC ~ conc_num_factor, data = .x), error = function(e) NULL) else NULL
-
-        # Mutate results for each lm if it was successfully calculated, suppress warnings for perfect fits
+        # If the model fails, set model-related values to NA
         .x %>%
         .x %>%
           mutate(
           mutate(
-            lm_intercept_L = if (!is.null(lm_L)) coef(lm_L)[1] else NA,
-            lm_slope_L = if (!is.null(lm_L)) coef(lm_L)[2] else NA,
-            R_Squared_L = if (!is.null(lm_L)) suppressWarnings(summary(lm_L)$r.squared) else NA,
-            lm_Score_L = if (!is.null(lm_L)) max_conc * coef(lm_L)[2] + coef(lm_L)[1] else NA,
-
-            lm_intercept_K = if (!is.null(lm_K)) coef(lm_K)[1] else NA,
-            lm_slope_K = if (!is.null(lm_K)) coef(lm_K)[2] else NA,
-            R_Squared_K = if (!is.null(lm_K)) suppressWarnings(summary(lm_K)$r.squared) else NA,
-            lm_Score_K = if (!is.null(lm_K)) max_conc * coef(lm_K)[2] + coef(lm_K)[1] else NA,
-
-            lm_intercept_r = if (!is.null(lm_r)) coef(lm_r)[1] else NA,
-            lm_slope_r = if (!is.null(lm_r)) coef(lm_r)[2] else NA,
-            R_Squared_r = if (!is.null(lm_r)) suppressWarnings(summary(lm_r)$r.squared) else NA,
-            lm_Score_r = if (!is.null(lm_r)) max_conc * coef(lm_r)[2] + coef(lm_r)[1] else NA,
-
-            lm_intercept_AUC = if (!is.null(lm_AUC)) coef(lm_AUC)[1] else NA,
-            lm_slope_AUC = if (!is.null(lm_AUC)) coef(lm_AUC)[2] else NA,
-            R_Squared_AUC = if (!is.null(lm_AUC)) suppressWarnings(summary(lm_AUC)$r.squared) else NA,
-            lm_Score_AUC = if (!is.null(lm_AUC)) max_conc * coef(lm_AUC)[2] + coef(lm_AUC)[1] else NA
+            lm_intercept_L = ifelse(!is.null(lm_L), coef(lm_L)[1], NA),
+            lm_slope_L = ifelse(!is.null(lm_L), coef(lm_L)[2], NA),
+            R_Squared_L = ifelse(!is.null(lm_L), summary(lm_L)$r.squared, NA),
+            lm_Score_L = ifelse(!is.null(lm_L), max_conc * coef(lm_L)[2] + coef(lm_L)[1], NA),
+            
+            lm_intercept_K = ifelse(!is.null(lm_K), coef(lm_K)[1], NA),
+            lm_slope_K = ifelse(!is.null(lm_K), coef(lm_K)[2], NA),
+            R_Squared_K = ifelse(!is.null(lm_K), summary(lm_K)$r.squared, NA),
+            lm_Score_K = ifelse(!is.null(lm_K), max_conc * coef(lm_K)[2] + coef(lm_K)[1], NA),
+            
+            lm_intercept_r = ifelse(!is.null(lm_r), coef(lm_r)[1], NA),
+            lm_slope_r = ifelse(!is.null(lm_r), coef(lm_r)[2], NA),
+            R_Squared_r = ifelse(!is.null(lm_r), summary(lm_r)$r.squared, NA),
+            lm_Score_r = ifelse(!is.null(lm_r), max_conc * coef(lm_r)[2] + coef(lm_r)[1], NA),
+            
+            lm_intercept_AUC = ifelse(!is.null(lm_AUC), coef(lm_AUC)[1], NA),
+            lm_slope_AUC = ifelse(!is.null(lm_AUC), coef(lm_AUC)[2], NA),
+            R_Squared_AUC = ifelse(!is.null(lm_AUC), summary(lm_AUC)$r.squared, NA),
+            lm_Score_AUC = ifelse(!is.null(lm_AUC), max_conc * coef(lm_AUC)[2] + coef(lm_AUC)[1], NA)
           )
           )
       } else {
       } else {
         # If not enough conc_num_factor levels, set lm-related values to NA
         # If not enough conc_num_factor levels, set lm-related values to NA
@@ -345,6 +345,7 @@ calculate_interaction_scores <- function(df, df_bg, type, overlap_threshold = 2)
     }) %>%
     }) %>%
     ungroup()
     ungroup()
 
 
+
   # For interaction plot error bars
   # For interaction plot error bars
   delta_means_sds <- calculations %>%
   delta_means_sds <- calculations %>%
     group_by(across(all_of(group_vars))) %>%
     group_by(across(all_of(group_vars))) %>%
@@ -452,14 +453,14 @@ calculate_interaction_scores <- function(df, df_bg, type, overlap_threshold = 2)
       ),
       ),
       
       
       # For correlation plots
       # For correlation plots
-      lm_R_squared_L = if (!all(is.na(Z_lm_L)) && !all(is.na(Avg_Zscore_L))) summary(lm(Z_lm_L ~ Avg_Zscore_L))$r.squared else NA,
-      lm_R_squared_K = if (!all(is.na(Z_lm_K)) && !all(is.na(Avg_Zscore_K))) summary(lm(Z_lm_K ~ Avg_Zscore_K))$r.squared else NA,
-      lm_R_squared_r = if (!all(is.na(Z_lm_r)) && !all(is.na(Avg_Zscore_r))) summary(lm(Z_lm_r ~ Avg_Zscore_r))$r.squared else NA,
-      lm_R_squared_AUC = if (!all(is.na(Z_lm_AUC)) && !all(is.na(Avg_Zscore_AUC))) summary(lm(Z_lm_AUC ~ Avg_Zscore_AUC))$r.squared else NA
+      lm_R_squared_L = summary(lm(Z_lm_L ~ Avg_Zscore_L))$r.squared,
+      lm_R_squared_K = summary(lm(Z_lm_K ~ Avg_Zscore_K))$r.squared,
+      lm_R_squared_r = summary(lm(Z_lm_r ~ Avg_Zscore_r))$r.squared,
+      lm_R_squared_AUC = summary(lm(Z_lm_AUC ~ Avg_Zscore_AUC))$r.squared
     )
     )
 
 
   # Creating the final calculations and interactions dataframes with only required columns for csv output
   # Creating the final calculations and interactions dataframes with only required columns for csv output
-  calculations_df <- calculations %>%
+  df_calculations <- calculations %>%
     select(
     select(
       all_of(group_vars),
       all_of(group_vars),
       conc_num, conc_num_factor, conc_num_factor_factor, N,
       conc_num, conc_num_factor, conc_num_factor_factor, N,
@@ -477,7 +478,7 @@ calculate_interaction_scores <- function(df, df_bg, type, overlap_threshold = 2)
       Zscore_L, Zscore_K, Zscore_r, Zscore_AUC
       Zscore_L, Zscore_K, Zscore_r, Zscore_AUC
     )
     )
 
 
-  interactions_df <- interactions %>%
+  df_interactions <- interactions %>%
     select(
     select(
       all_of(group_vars),
       all_of(group_vars),
       NG, DB, SM,
       NG, DB, SM,
@@ -486,7 +487,8 @@ calculate_interaction_scores <- function(df, df_bg, type, overlap_threshold = 2)
       Raw_Shift_L, Raw_Shift_K, Raw_Shift_r, Raw_Shift_AUC,
       Raw_Shift_L, Raw_Shift_K, Raw_Shift_r, Raw_Shift_AUC,
       Z_Shift_L, Z_Shift_K, Z_Shift_r, Z_Shift_AUC,
       Z_Shift_L, Z_Shift_K, Z_Shift_r, Z_Shift_AUC,
       lm_R_squared_L, lm_R_squared_K, lm_R_squared_r, lm_R_squared_AUC,
       lm_R_squared_L, lm_R_squared_K, lm_R_squared_r, lm_R_squared_AUC,
-      Overlap
+      lm_intercept_L, lm_intercept_K, lm_intercept_r, lm_intercept_AUC,
+      lm_slope_L, lm_slope_K, lm_slope_r, lm_slope_AUC, Overlap
     )
     )
 
 
   # Join calculations and interactions to avoid dimension mismatch
   # Join calculations and interactions to avoid dimension mismatch
@@ -494,15 +496,19 @@ calculate_interaction_scores <- function(df, df_bg, type, overlap_threshold = 2)
     select(-any_of(c("DB", "NG", "SM",
     select(-any_of(c("DB", "NG", "SM",
       "Raw_Shift_L", "Raw_Shift_K", "Raw_Shift_r", "Raw_Shift_AUC",
       "Raw_Shift_L", "Raw_Shift_K", "Raw_Shift_r", "Raw_Shift_AUC",
       "Z_Shift_L", "Z_Shift_K", "Z_Shift_r", "Z_Shift_AUC",
       "Z_Shift_L", "Z_Shift_K", "Z_Shift_r", "Z_Shift_AUC",
-      "Z_lm_L", "Z_lm_K", "Z_lm_r", "Z_lm_AUC")))
+      "Z_lm_L", "Z_lm_K", "Z_lm_r", "Z_lm_AUC",
+      "lm_R_squared_L", "lm_R_squared_K", "lm_R_squared_r", "lm_R_squared_AUC",
+      "lm_intercept_L", "lm_intercept_K", "lm_intercept_r", "lm_intercept_AUC",
+      "lm_slope_L", "lm_slope_K", "lm_slope_r", "lm_slope_AUC"
+    )))
 
 
   full_data <- calculations_no_overlap %>%
   full_data <- calculations_no_overlap %>%
-    left_join(interactions_df, by = group_vars)
+    left_join(df_interactions, by = group_vars)
 
 
   # Return final dataframes
   # Return final dataframes
   return(list(
   return(list(
-    calculations = calculations_df,
-    interactions = interactions_df,
+    calculations = df_calculations,
+    interactions = df_interactions,
     full_data = full_data
     full_data = full_data
   ))
   ))
 }
 }
@@ -535,7 +541,7 @@ generate_and_save_plots <- function(out_dir, filename, plot_configs, page_width
 
 
       # Filter points outside of y-limits if specified
       # Filter points outside of y-limits if specified
       if (!is.null(config$ylim_vals)) {
       if (!is.null(config$ylim_vals)) {
-        out_of_bounds_df <- df %>%
+        out_of_bounds <- df %>%
           filter(
           filter(
             is.na(.data[[config$y_var]]) |
             is.na(.data[[config$y_var]]) |
             .data[[config$y_var]] < config$ylim_vals[1] |
             .data[[config$y_var]] < config$ylim_vals[1] |
@@ -543,10 +549,10 @@ generate_and_save_plots <- function(out_dir, filename, plot_configs, page_width
           )
           )
         
         
         # Print rows being filtered out
         # Print rows being filtered out
-        if (nrow(out_of_bounds_df) > 0) {
-          message("Filtered: ", config$title, " using y-limits: [", config$ylim_vals[1], ", ", config$ylim_vals[2], "]")
-          message("# of filtered rows outside y-limits (for plotting): ", nrow(out_of_bounds_df))
-          print(out_of_bounds_df)
+        if (nrow(out_of_bounds) > 0) {
+          message("Filtered ", nrow(out_of_bounds), " row(s) from '", config$title, "' because ", config$y_var,
+            " is outside of y-limits: [", config$ylim_vals[1], ", ", config$ylim_vals[2], "]:")
+          print(out_of_bounds %>% select(OrfRep, Gene, num, Drug, scan, Plate, Row, Col, conc_num, all_of(config$y_var)), width = 1000)
         }
         }
 
 
         df <- df %>%
         df <- df %>%
@@ -558,9 +564,8 @@ generate_and_save_plots <- function(out_dir, filename, plot_configs, page_width
       }
       }
 
 
       # Filter NAs if specified
       # Filter NAs if specified
-      if (!is.null(config$na_rm) && config$na_rm) {
+      if (!is.null(config$filter_na) && config$filter_na) {
         df <- df %>%
         df <- df %>%
-          filter(!is.na(.data[[config$x_var]])) %>%
           filter(!is.na(.data[[config$y_var]]))
           filter(!is.na(.data[[config$y_var]]))
       }
       }
 
 
@@ -648,20 +653,18 @@ generate_and_save_plots <- function(out_dir, filename, plot_configs, page_width
 
 
           plot <- plot + geom_errorbar(
           plot <- plot + geom_errorbar(
             aes(
             aes(
-              x = .data[[config$x_var]],
               ymin = !!custom_ymin_expr,
               ymin = !!custom_ymin_expr,
               ymax = !!custom_ymax_expr
               ymax = !!custom_ymax_expr
             ),
             ),
             color = config$error_bar_params$color,
             color = config$error_bar_params$color,
             linewidth = ifelse(is.null(config$error_bar_params$linewidth), 0.1, config$error_bar_params$linewidth)
             linewidth = ifelse(is.null(config$error_bar_params$linewidth), 0.1, config$error_bar_params$linewidth)
           )
           )
-
         } else {
         } else {
           # If no custom error bar formula, use the default or dynamic ones
           # If no custom error bar formula, use the default or dynamic ones
-          if (!is.null(config$color_var) && is.null(config$error_bar_params$color)) {
+          if (!is.null(config$color_var) && config$color_var %in% colnames(config$df)) {
+            # Only use color_var if it's present in the dataframe
             plot <- plot + geom_errorbar(
             plot <- plot + geom_errorbar(
               aes(
               aes(
-                x = .data[[config$x_var]],
                 ymin = .data[[y_mean_col]] - .data[[y_sd_col]],
                 ymin = .data[[y_mean_col]] - .data[[y_sd_col]],
                 ymax = .data[[y_mean_col]] + .data[[y_sd_col]],
                 ymax = .data[[y_mean_col]] + .data[[y_sd_col]],
                 color = .data[[config$color_var]]
                 color = .data[[config$color_var]]
@@ -669,13 +672,13 @@ generate_and_save_plots <- function(out_dir, filename, plot_configs, page_width
               linewidth = 0.1
               linewidth = 0.1
             )
             )
           } else {
           } else {
+            # If color_var is missing, fall back to a default color or none
             plot <- plot + geom_errorbar(
             plot <- plot + geom_errorbar(
               aes(
               aes(
-                x = .data[[config$x_var]],
                 ymin = .data[[y_mean_col]] - .data[[y_sd_col]],
                 ymin = .data[[y_mean_col]] - .data[[y_sd_col]],
                 ymax = .data[[y_mean_col]] + .data[[y_sd_col]]
                 ymax = .data[[y_mean_col]] + .data[[y_sd_col]]
               ),
               ),
-              color = config$error_bar_params$color,
+              color = config$error_bar_params$color, # use the provided color or default
               linewidth = ifelse(is.null(config$error_bar_params$linewidth), 0.1, config$error_bar_params$linewidth)
               linewidth = ifelse(is.null(config$error_bar_params$linewidth), 0.1, config$error_bar_params$linewidth)
             )
             )
           }
           }
@@ -683,23 +686,18 @@ generate_and_save_plots <- function(out_dir, filename, plot_configs, page_width
 
 
         # Add the center point if the option is provided
         # Add the center point if the option is provided
         if (!is.null(config$error_bar_params$mean_point) && config$error_bar_params$mean_point) {
         if (!is.null(config$error_bar_params$mean_point) && config$error_bar_params$mean_point) {
-          if (!is.null(config$color_var) && is.null(config$error_bar_params$color)) {
+          if (!is.null(config$error_bar_params$color)) {
             plot <- plot + geom_point(
             plot <- plot + geom_point(
-              aes(
-                x = .data[[config$x_var]],
-                y = .data[[y_mean_col]],
-                color = .data[[config$color_var]]
-              ),
-              shape = 16
+              mapping = aes(x = .data[[config$x_var]], y = .data[[y_mean_col]]),  # Include both x and y mappings
+              color = config$error_bar_params$color,
+              shape = 16,
+              inherit.aes = FALSE  # Prevent overriding global aesthetics
             )
             )
           } else {
           } else {
             plot <- plot + geom_point(
             plot <- plot + geom_point(
-              aes(
-                x = .data[[config$x_var]],
-                y = .data[[y_mean_col]]
-              ),
-              color = config$error_bar_params$color,
-              shape = 16
+              mapping = aes(x = .data[[config$x_var]], y = .data[[y_mean_col]]),  # Include both x and y mappings
+              shape = 16,
+              inherit.aes = FALSE  # Prevent overriding global aesthetics
             )
             )
           }
           }
         }
         }
@@ -779,33 +777,72 @@ generate_scatter_plot <- function(plot, config) {
     position = position
     position = position
   )
   )
 
 
+  # Add a cyan point for the reference data for correlation plots
   if (!is.null(config$cyan_points) && config$cyan_points) {
   if (!is.null(config$cyan_points) && config$cyan_points) {
     plot <- plot + geom_point(
     plot <- plot + geom_point(
-      aes(x = .data[[config$x_var]], y = .data[[config$y_var]]),
+      data = config$df_reference,
+      mapping = aes(x = .data[[config$x_var]], y = .data[[config$y_var]]),
       color = "cyan",
       color = "cyan",
       shape = 3,
       shape = 3,
-      size = 0.5
+      size = 0.5,
+      inherit.aes = FALSE
     )
     )
   }
   }
-
-  if (!is.null(config$gray_points) && config$gray_points) {
-    plot <- plot + geom_point(shape = 3, color = "gray70", size = 1)
-  }
   
   
   # Add linear regression line if specified
   # Add linear regression line if specified
   if (!is.null(config$lm_line)) {
   if (!is.null(config$lm_line)) {
-    plot <- plot +
-      annotate(
-        "segment",
-        x = config$lm_line$x_min,
-        xend = config$lm_line$x_max,
-        y = config$lm_line$intercept + config$lm_line$slope * config$lm_line$x_min,  # Calculate y for x_min
-        yend = config$lm_line$intercept + config$lm_line$slope * config$lm_line$x_max, # Calculate y for x_max
-        color = ifelse(!is.null(config$lm_line$color), config$lm_line$color, "blue"),
-        linewidth = ifelse(!is.null(config$lm_line$linewidth), config$lm_line$linewidth, 1)
-      )
+    # Extract necessary values
+    x_min <- config$lm_line$x_min
+    x_max <- config$lm_line$x_max
+    intercept <- config$lm_line$intercept
+    slope <- config$lm_line$slope
+    color <- ifelse(!is.null(config$lm_line$color), config$lm_line$color, "blue")
+    linewidth <- ifelse(!is.null(config$lm_line$linewidth), config$lm_line$linewidth, 1)
+
+    # Ensure none of the values are NA and calculate y-values
+    if (!is.na(x_min) && !is.na(x_max) && !is.na(intercept) && !is.na(slope)) {
+      y_min <- intercept + slope * x_min
+      y_max <- intercept + slope * x_max
+      
+      # Ensure y-values are within y-limits (if any)
+      if (!is.null(config$ylim_vals)) {
+        y_min_within_limits <- y_min >= config$ylim_vals[1] && y_min <= config$ylim_vals[2]
+        y_max_within_limits <- y_max >= config$ylim_vals[1] && y_max <= config$ylim_vals[2]
+
+        # Adjust or skip based on whether the values fall within limits
+        if (y_min_within_limits && y_max_within_limits) {
+          # Ensure x-values are also valid
+          if (!is.na(x_min) && !is.na(x_max)) {
+            plot <- plot + annotate(
+              "segment",
+              x = x_min,
+              xend = x_max,
+              y = y_min,
+              yend = y_max,
+              color = color,
+              linewidth = linewidth
+            )
+          }
+        } else {
+          message("Skipping linear modeling line due to y-values outside of limits.")
+        }
+      } else {
+        # If no y-limits are provided, proceed with the annotation
+        plot <- plot + annotate(
+          "segment",
+          x = x_min,
+          xend = x_max,
+          y = y_min,
+          yend = y_max,
+          color = color,
+          linewidth = linewidth
+        )
+      }
+    } else {
+      message("Skipping linear modeling line due to missing or invalid values.")
+    }
   }
   }
-  
+
   # Add SD Bands if specified
   # Add SD Bands if specified
   if (!is.null(config$sd_band)) {
   if (!is.null(config$sd_band)) {
     plot <- plot +
     plot <- plot +
@@ -829,7 +866,7 @@ generate_scatter_plot <- function(plot, config) {
       )
       )
   }
   }
 
 
-  # Add Rectangles if specified
+  # Add rectangles if specified
   if (!is.null(config$rectangles)) {
   if (!is.null(config$rectangles)) {
     for (rect in config$rectangles) {
     for (rect in config$rectangles) {
       plot <- plot + annotate(
       plot <- plot + annotate(
@@ -909,11 +946,11 @@ generate_plate_analysis_plot_configs <- function(variables, df_before = NULL, df
       df_plot <- if (stage == "before") df_before else df_after
       df_plot <- if (stage == "before") df_before else df_after
 
 
       # Check for non-finite values in the y-variable
       # Check for non-finite values in the y-variable
-      df_plot_filtered <- df_plot %>% filter(is.finite(!!sym(var)))
+      # df_plot_filtered <- df_plot %>% filter(is.finite(.data[[var]]))
 
 
       # Adjust settings based on plot_type
       # Adjust settings based on plot_type
       plot_config <- list(
       plot_config <- list(
-        df = df_plot_filtered,
+        df = df_plot,
         x_var = "scan",
         x_var = "scan",
         y_var = var,
         y_var = var,
         plot_type = plot_type,
         plot_type = plot_type,
@@ -921,7 +958,8 @@ generate_plate_analysis_plot_configs <- function(variables, df_before = NULL, df
         color_var = "conc_num_factor_factor",
         color_var = "conc_num_factor_factor",
         size = 0.2,
         size = 0.2,
         error_bar = (plot_type == "scatter"),
         error_bar = (plot_type == "scatter"),
-        legend_position = "bottom"
+        legend_position = "bottom",
+        filter_na = TRUE
       )
       )
 
 
       # Add config to plots list
       # Add config to plots list
@@ -1086,7 +1124,7 @@ generate_interaction_plot_configs <- function(df_summary, df_interactions, type)
         x_breaks = unique(group_data$conc_num_factor_factor),
         x_breaks = unique(group_data$conc_num_factor_factor),
         x_labels = as.character(unique(group_data$conc_num)),
         x_labels = as.character(unique(group_data$conc_num)),
         ylim_vals = y_limits,
         ylim_vals = y_limits,
-        y_filter = FALSE,
+        filter_na = TRUE,
         lm_line = list(
         lm_line = list(
           intercept = lm_intercept_value,
           intercept = lm_intercept_value,
           slope = lm_slope_value,
           slope = lm_slope_value,
@@ -1111,7 +1149,7 @@ generate_interaction_plot_configs <- function(df_summary, df_interactions, type)
   ))
   ))
 }
 }
 
 
-generate_rank_plot_configs <- function(df, is_lm = FALSE, adjust = FALSE, na_rm = FALSE, overlap_color = FALSE) {
+generate_rank_plot_configs <- function(df, is_lm = FALSE, adjust = FALSE, filter_na = FALSE, overlap_color = FALSE) {
   sd_bands <- c(1, 2, 3)
   sd_bands <- c(1, 2, 3)
   plot_configs <- list()
   plot_configs <- list()
   
   
@@ -1128,7 +1166,7 @@ generate_rank_plot_configs <- function(df, is_lm = FALSE, adjust = FALSE, na_rm
   }
   }
 
 
   # Helper function to create a rank plot configuration
   # Helper function to create a rank plot configuration
-  create_plot_config <- function(variable, rank_var, zscore_var, y_label, sd_band, na_rm, with_annotations = TRUE) {
+  create_plot_config <- function(variable, rank_var, zscore_var, y_label, sd_band, filter_na, with_annotations = TRUE) {
     num_enhancers <- sum(df[[zscore_var]] >= sd_band, na.rm = TRUE)
     num_enhancers <- sum(df[[zscore_var]] >= sd_band, na.rm = TRUE)
     num_suppressors <- sum(df[[zscore_var]] <= -sd_band, na.rm = TRUE)
     num_suppressors <- sum(df[[zscore_var]] <= -sd_band, na.rm = TRUE)
 
 
@@ -1148,7 +1186,7 @@ generate_rank_plot_configs <- function(df, is_lm = FALSE, adjust = FALSE, na_rm
       alpha_negative = 0.3,
       alpha_negative = 0.3,
       shape = 3,
       shape = 3,
       size = 0.1,
       size = 0.1,
-      na_rm = na_rm,
+      filter_na = filter_na,
       legend_position = "none"
       legend_position = "none"
     )
     )
     
     
@@ -1181,11 +1219,11 @@ generate_rank_plot_configs <- function(df, is_lm = FALSE, adjust = FALSE, na_rm
     for (sd_band in sd_bands) {
     for (sd_band in sd_bands) {
       # Create plot with annotations
       # Create plot with annotations
       plot_configs[[length(plot_configs) + 1]] <-
       plot_configs[[length(plot_configs) + 1]] <-
-        create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, na_rm, with_annotations = TRUE)
+        create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, filter_na, with_annotations = TRUE)
       
       
       # Create plot without annotations
       # Create plot without annotations
       plot_configs[[length(plot_configs) + 1]] <-
       plot_configs[[length(plot_configs) + 1]] <-
-        create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, na_rm, with_annotations = FALSE)
+        create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, filter_na, with_annotations = FALSE)
     }
     }
   }
   }
 
 
@@ -1198,7 +1236,7 @@ generate_rank_plot_configs <- function(df, is_lm = FALSE, adjust = FALSE, na_rm
   ))
   ))
 }
 }
 
 
-generate_correlation_plot_configs <- function(df) {
+generate_correlation_plot_configs <- function(df, df_reference) {
   # Define relationships for different-variable correlations
   # Define relationships for different-variable correlations
   relationships <- list(
   relationships <- list(
     list(x = "L", y = "K"),
     list(x = "L", y = "K"),
@@ -1209,6 +1247,10 @@ generate_correlation_plot_configs <- function(df) {
     list(x = "r", y = "AUC")
     list(x = "r", y = "AUC")
   )
   )
 
 
+  # This filtering was in the original script
+  # df_reference <- df_reference %>%
+  #   filter(!is.na(Z_lm_L))
+
   plot_configs <- list()
   plot_configs <- list()
 
 
   # Iterate over the option to highlight cyan points (TRUE/FALSE)
   # Iterate over the option to highlight cyan points (TRUE/FALSE)
@@ -1221,10 +1263,10 @@ generate_correlation_plot_configs <- function(df) {
       y_var <- paste0("Z_lm_", rel$y)
       y_var <- paste0("Z_lm_", rel$y)
 
 
       # Extract the R-squared, intercept, and slope from the df
       # Extract the R-squared, intercept, and slope from the df
-      relationship_name <- paste0(rel$x, "_vs_", rel$y)  # Example: L_vs_K
-      intercept <- mean(df[[paste0("lm_intercept_", rel$x)]], na.rm = TRUE)
-      slope <- mean(df[[paste0("lm_slope_", rel$x)]], na.rm = TRUE)
-      r_squared <- mean(df[[paste0("lm_R_squared_", rel$x)]], na.rm = TRUE)
+      relationship_name <- paste0(rel$x, "_vs_", rel$y)
+      intercept <- df[[paste0("lm_intercept_", rel$x)]]
+      slope <- df[[paste0("lm_slope_", rel$x)]]
+      r_squared <- df[[paste0("lm_R_squared_", rel$x)]]
 
 
       # Generate the label for the plot
       # Generate the label for the plot
       plot_label <- paste("Interaction", rel$x, "vs.", rel$y)
       plot_label <- paste("Interaction", rel$x, "vs.", rel$y)
@@ -1232,6 +1274,7 @@ generate_correlation_plot_configs <- function(df) {
       # Construct plot config
       # Construct plot config
       plot_config <- list(
       plot_config <- list(
         df = df,
         df = df,
+        df_reference = df_reference,
         x_var = x_var,
         x_var = x_var,
         y_var = y_var,
         y_var = y_var,
         plot_type = "scatter",
         plot_type = "scatter",
@@ -1248,11 +1291,9 @@ generate_correlation_plot_configs <- function(df) {
           slope = slope,
           slope = slope,
           color = "tomato3"
           color = "tomato3"
         ),
         ),
-        shape = 3,
-        size = 0.5,
-        color_var = "Overlap",
-        cyan_points = highlight_cyan, # include cyan points or not based on the loop
-        gray_points = TRUE
+        color = "gray70",
+        filter_na = TRUE,
+        cyan_points = highlight_cyan # include cyan points or not based on the loop
       )
       )
 
 
       plot_configs <- append(plot_configs, list(plot_config))
       plot_configs <- append(plot_configs, list(plot_config))
@@ -1434,7 +1475,7 @@ main <- function() {
           x_var = "L",
           x_var = "L",
           y_var = "K",
           y_var = "K",
           plot_type = "scatter",
           plot_type = "scatter",
-          title = "Raw L vs K for strains falling outside 2SD of the K mean at each Conc",
+          title = "Raw L vs K for strains falling outside 2 SD of the K mean at each Conc",
           color_var = "conc_num_factor_factor",
           color_var = "conc_num_factor_factor",
           position = "jitter",
           position = "jitter",
           tooltip_vars = c("OrfRep", "Gene", "delta_bg"),
           tooltip_vars = c("OrfRep", "Gene", "delta_bg"),
@@ -1459,7 +1500,7 @@ main <- function() {
           x_label = "Delta Background",
           x_label = "Delta Background",
           y_var = "K",
           y_var = "K",
           plot_type = "scatter",
           plot_type = "scatter",
-          title = "Delta Background vs K for strains falling outside 2SD of the K mean at each Conc",
+          title = "Delta Background vs K for strains falling outside 2 SD of K",
           color_var = "conc_num_factor_factor",
           color_var = "conc_num_factor_factor",
           position = "jitter",
           position = "jitter",
           tooltip_vars = c("OrfRep", "Gene", "delta_bg"),
           tooltip_vars = c("OrfRep", "Gene", "delta_bg"),
@@ -1573,22 +1614,23 @@ main <- function() {
           .groups = "drop"
           .groups = "drop"
         )
         )
 
 
-      # message("Calculating reference strain interaction summary statistics") # formerly X_stats_interaction
-      # df_reference_interaction_stats <- calculate_summary_stats(
-      #   df = df_reference,
-      #   variables = c("L", "K", "r", "AUC"),
-      #   group_vars = c("OrfRep", "Gene", "num", "Drug", "conc_num", "conc_num_factor_factor")
-      #   )$df_with_stats
+      message("Calculating reference strain interaction summary statistics") # formerly X_stats_interaction
+      df_reference_interaction_stats <- calculate_summary_stats(
+        df = df_reference,
+        variables = c("L", "K", "r", "AUC"),
+        group_vars = c("OrfRep", "Gene", "num", "Drug", "conc_num", "conc_num_factor_factor")
+        )$df_with_stats
       
       
-      # # message("Calculating reference strain interaction scores")
-      # reference_results <- calculate_interaction_scores(df_reference_interaction_stats, df_bg_stats, "reference")
-      # df_reference_interactions_joined <- reference_results$full_data
-      # write.csv(reference_results$calculations, file = file.path(out_dir, "zscore_calculations_reference.csv"), row.names = FALSE)
-      # write.csv(reference_results$interactions, file = file.path(out_dir, "zscore_interactions_reference.csv"), row.names = FALSE)
+      message("Calculating reference strain interaction scores")
+      reference_results <- calculate_interaction_scores(df_reference_interaction_stats, df_bg_stats, "reference")
+      df_reference_interactions_joined <- reference_results$full_data
+      df_reference_interactions <- reference_results$interactions
+      write.csv(reference_results$calculations, file = file.path(out_dir, "zscore_calculations_reference.csv"), row.names = FALSE)
+      write.csv(df_reference_interactions, file = file.path(out_dir, "zscore_interactions_reference.csv"), row.names = FALSE)
 
 
-      # # message("Generating reference interaction plots")
-      # reference_plot_configs <- generate_interaction_plot_configs(df_reference_summary_stats, df_reference_interactions_joined, "reference")
-      # generate_and_save_plots(out_dir, "interaction_plots_reference", reference_plot_configs, page_width = 16, page_height = 16)
+      message("Generating reference interaction plots")
+      reference_plot_configs <- generate_interaction_plot_configs(df_reference_summary_stats, df_reference_interactions_joined, "reference")
+      generate_and_save_plots(out_dir, "interaction_plots_reference", reference_plot_configs, page_width = 16, page_height = 16)
 
 
       message("Setting missing deletion values to the highest theoretical value at each drug conc for L")
       message("Setting missing deletion values to the highest theoretical value at each drug conc for L")
       df_deletion <- df_na_stats %>% # formerly X2
       df_deletion <- df_na_stats %>% # formerly X2
@@ -1616,9 +1658,9 @@ main <- function() {
       write.csv(deletion_results$calculations, file = file.path(out_dir, "zscore_calculations.csv"), row.names = FALSE)
       write.csv(deletion_results$calculations, file = file.path(out_dir, "zscore_calculations.csv"), row.names = FALSE)
       write.csv(df_interactions, file = file.path(out_dir, "zscore_interactions.csv"), row.names = FALSE)
       write.csv(df_interactions, file = file.path(out_dir, "zscore_interactions.csv"), row.names = FALSE)
 
 
-      # message("Generating deletion interaction plots")
-      # deletion_plot_configs <- generate_interaction_plot_configs(df_reference_summary_stats, df_interactions_joined, "deletion")
-      # generate_and_save_plots(out_dir, "interaction_plots", deletion_plot_configs, page_width = 16, page_height = 16)
+      message("Generating deletion interaction plots")
+      deletion_plot_configs <- generate_interaction_plot_configs(df_reference_summary_stats, df_interactions_joined, "deletion")
+      generate_and_save_plots(out_dir, "interaction_plots", deletion_plot_configs, page_width = 16, page_height = 16)
 
 
       message("Writing enhancer/suppressor csv files")
       message("Writing enhancer/suppressor csv files")
       interaction_threshold <- 2  # TODO add to study config?
       interaction_threshold <- 2  # TODO add to study config?
@@ -1675,7 +1717,7 @@ main <- function() {
         df_interactions,
         df_interactions,
         is_lm = FALSE,
         is_lm = FALSE,
         adjust = FALSE,
         adjust = FALSE,
-        na_rm = TRUE,
+        filter_na = TRUE,
         overlap_color = TRUE
         overlap_color = TRUE
       )
       )
       generate_and_save_plots(out_dir, "rank_plots_na_rm", rank_plot_filtered_configs,
       generate_and_save_plots(out_dir, "rank_plots_na_rm", rank_plot_filtered_configs,
@@ -1686,7 +1728,7 @@ main <- function() {
         df_interactions,
         df_interactions,
         is_lm = TRUE,
         is_lm = TRUE,
         adjust = FALSE,
         adjust = FALSE,
-        na_rm = TRUE,
+        filter_na = TRUE,
         overlap_color = TRUE
         overlap_color = TRUE
       )
       )
       generate_and_save_plots(out_dir, "rank_plots_lm_na_rm", rank_plot_lm_filtered_configs,
       generate_and_save_plots(out_dir, "rank_plots_lm_na_rm", rank_plot_lm_filtered_configs,
@@ -1694,7 +1736,8 @@ main <- function() {
 
 
       message("Generating correlation curve parameter pair plots")
       message("Generating correlation curve parameter pair plots")
       correlation_plot_configs <- generate_correlation_plot_configs(
       correlation_plot_configs <- generate_correlation_plot_configs(
-        df_interactions
+        df_interactions,
+        df_reference_interactions
       )
       )
       generate_and_save_plots(out_dir, "correlation_cpps", correlation_plot_configs,
       generate_and_save_plots(out_dir, "correlation_cpps", correlation_plot_configs,
         page_width = 10, page_height = 7)
         page_width = 10, page_height = 7)