瀏覽代碼

Fix lm column clobbering in calculate_interaction_scores

Bryan Roessler 6 月之前
父節點
當前提交
bee9aea866
共有 1 個文件被更改,包括 166 次插入123 次删除
  1. 166 123
      qhtcp-workflow/apps/r/calculate_interaction_zscores.R

+ 166 - 123
qhtcp-workflow/apps/r/calculate_interaction_zscores.R

@@ -300,37 +300,37 @@ calculate_interaction_scores <- function(df, df_bg, type, overlap_threshold = 2)
     ungroup() %>%  # Ungroup before group_modify
     group_by(across(all_of(group_vars))) %>%
     group_modify(~ {
-      # Perform linear models only if there are enough unique conc_num_factor levels
+      # Check if there are enough unique conc_num_factor levels to perform lm
       if (length(unique(.x$conc_num_factor)) > 1) {
+        
+        # Perform linear modeling
+        lm_L <- lm(Delta_L ~ conc_num_factor, data = .x)
+        lm_K <- lm(Delta_K ~ conc_num_factor, data = .x)
+        lm_r <- lm(Delta_r ~ conc_num_factor, data = .x)
+        lm_AUC <- lm(Delta_AUC ~ conc_num_factor, data = .x)
 
-        # Filter and calculate each lm() separately with individual checks for NAs
-        lm_L <- if (!all(is.na(.x$Delta_L))) tryCatch(lm(Delta_L ~ conc_num_factor, data = .x), error = function(e) NULL) else NULL
-        lm_K <- if (!all(is.na(.x$Delta_K))) tryCatch(lm(Delta_K ~ conc_num_factor, data = .x), error = function(e) NULL) else NULL
-        lm_r <- if (!all(is.na(.x$Delta_r))) tryCatch(lm(Delta_r ~ conc_num_factor, data = .x), error = function(e) NULL) else NULL
-        lm_AUC <- if (!all(is.na(.x$Delta_AUC))) tryCatch(lm(Delta_AUC ~ conc_num_factor, data = .x), error = function(e) NULL) else NULL
-
-        # Mutate results for each lm if it was successfully calculated, suppress warnings for perfect fits
+        # If the model fails, set model-related values to NA
         .x %>%
           mutate(
-            lm_intercept_L = if (!is.null(lm_L)) coef(lm_L)[1] else NA,
-            lm_slope_L = if (!is.null(lm_L)) coef(lm_L)[2] else NA,
-            R_Squared_L = if (!is.null(lm_L)) suppressWarnings(summary(lm_L)$r.squared) else NA,
-            lm_Score_L = if (!is.null(lm_L)) max_conc * coef(lm_L)[2] + coef(lm_L)[1] else NA,
-
-            lm_intercept_K = if (!is.null(lm_K)) coef(lm_K)[1] else NA,
-            lm_slope_K = if (!is.null(lm_K)) coef(lm_K)[2] else NA,
-            R_Squared_K = if (!is.null(lm_K)) suppressWarnings(summary(lm_K)$r.squared) else NA,
-            lm_Score_K = if (!is.null(lm_K)) max_conc * coef(lm_K)[2] + coef(lm_K)[1] else NA,
-
-            lm_intercept_r = if (!is.null(lm_r)) coef(lm_r)[1] else NA,
-            lm_slope_r = if (!is.null(lm_r)) coef(lm_r)[2] else NA,
-            R_Squared_r = if (!is.null(lm_r)) suppressWarnings(summary(lm_r)$r.squared) else NA,
-            lm_Score_r = if (!is.null(lm_r)) max_conc * coef(lm_r)[2] + coef(lm_r)[1] else NA,
-
-            lm_intercept_AUC = if (!is.null(lm_AUC)) coef(lm_AUC)[1] else NA,
-            lm_slope_AUC = if (!is.null(lm_AUC)) coef(lm_AUC)[2] else NA,
-            R_Squared_AUC = if (!is.null(lm_AUC)) suppressWarnings(summary(lm_AUC)$r.squared) else NA,
-            lm_Score_AUC = if (!is.null(lm_AUC)) max_conc * coef(lm_AUC)[2] + coef(lm_AUC)[1] else NA
+            lm_intercept_L = ifelse(!is.null(lm_L), coef(lm_L)[1], NA),
+            lm_slope_L = ifelse(!is.null(lm_L), coef(lm_L)[2], NA),
+            R_Squared_L = ifelse(!is.null(lm_L), summary(lm_L)$r.squared, NA),
+            lm_Score_L = ifelse(!is.null(lm_L), max_conc * coef(lm_L)[2] + coef(lm_L)[1], NA),
+            
+            lm_intercept_K = ifelse(!is.null(lm_K), coef(lm_K)[1], NA),
+            lm_slope_K = ifelse(!is.null(lm_K), coef(lm_K)[2], NA),
+            R_Squared_K = ifelse(!is.null(lm_K), summary(lm_K)$r.squared, NA),
+            lm_Score_K = ifelse(!is.null(lm_K), max_conc * coef(lm_K)[2] + coef(lm_K)[1], NA),
+            
+            lm_intercept_r = ifelse(!is.null(lm_r), coef(lm_r)[1], NA),
+            lm_slope_r = ifelse(!is.null(lm_r), coef(lm_r)[2], NA),
+            R_Squared_r = ifelse(!is.null(lm_r), summary(lm_r)$r.squared, NA),
+            lm_Score_r = ifelse(!is.null(lm_r), max_conc * coef(lm_r)[2] + coef(lm_r)[1], NA),
+            
+            lm_intercept_AUC = ifelse(!is.null(lm_AUC), coef(lm_AUC)[1], NA),
+            lm_slope_AUC = ifelse(!is.null(lm_AUC), coef(lm_AUC)[2], NA),
+            R_Squared_AUC = ifelse(!is.null(lm_AUC), summary(lm_AUC)$r.squared, NA),
+            lm_Score_AUC = ifelse(!is.null(lm_AUC), max_conc * coef(lm_AUC)[2] + coef(lm_AUC)[1], NA)
           )
       } else {
         # If not enough conc_num_factor levels, set lm-related values to NA
@@ -345,6 +345,7 @@ calculate_interaction_scores <- function(df, df_bg, type, overlap_threshold = 2)
     }) %>%
     ungroup()
 
+
   # For interaction plot error bars
   delta_means_sds <- calculations %>%
     group_by(across(all_of(group_vars))) %>%
@@ -452,14 +453,14 @@ calculate_interaction_scores <- function(df, df_bg, type, overlap_threshold = 2)
       ),
       
       # For correlation plots
-      lm_R_squared_L = if (!all(is.na(Z_lm_L)) && !all(is.na(Avg_Zscore_L))) summary(lm(Z_lm_L ~ Avg_Zscore_L))$r.squared else NA,
-      lm_R_squared_K = if (!all(is.na(Z_lm_K)) && !all(is.na(Avg_Zscore_K))) summary(lm(Z_lm_K ~ Avg_Zscore_K))$r.squared else NA,
-      lm_R_squared_r = if (!all(is.na(Z_lm_r)) && !all(is.na(Avg_Zscore_r))) summary(lm(Z_lm_r ~ Avg_Zscore_r))$r.squared else NA,
-      lm_R_squared_AUC = if (!all(is.na(Z_lm_AUC)) && !all(is.na(Avg_Zscore_AUC))) summary(lm(Z_lm_AUC ~ Avg_Zscore_AUC))$r.squared else NA
+      lm_R_squared_L = summary(lm(Z_lm_L ~ Avg_Zscore_L))$r.squared,
+      lm_R_squared_K = summary(lm(Z_lm_K ~ Avg_Zscore_K))$r.squared,
+      lm_R_squared_r = summary(lm(Z_lm_r ~ Avg_Zscore_r))$r.squared,
+      lm_R_squared_AUC = summary(lm(Z_lm_AUC ~ Avg_Zscore_AUC))$r.squared
     )
 
   # Creating the final calculations and interactions dataframes with only required columns for csv output
-  calculations_df <- calculations %>%
+  df_calculations <- calculations %>%
     select(
       all_of(group_vars),
       conc_num, conc_num_factor, conc_num_factor_factor, N,
@@ -477,7 +478,7 @@ calculate_interaction_scores <- function(df, df_bg, type, overlap_threshold = 2)
       Zscore_L, Zscore_K, Zscore_r, Zscore_AUC
     )
 
-  interactions_df <- interactions %>%
+  df_interactions <- interactions %>%
     select(
       all_of(group_vars),
       NG, DB, SM,
@@ -486,7 +487,8 @@ calculate_interaction_scores <- function(df, df_bg, type, overlap_threshold = 2)
       Raw_Shift_L, Raw_Shift_K, Raw_Shift_r, Raw_Shift_AUC,
       Z_Shift_L, Z_Shift_K, Z_Shift_r, Z_Shift_AUC,
       lm_R_squared_L, lm_R_squared_K, lm_R_squared_r, lm_R_squared_AUC,
-      Overlap
+      lm_intercept_L, lm_intercept_K, lm_intercept_r, lm_intercept_AUC,
+      lm_slope_L, lm_slope_K, lm_slope_r, lm_slope_AUC, Overlap
     )
 
   # Join calculations and interactions to avoid dimension mismatch
@@ -494,15 +496,19 @@ calculate_interaction_scores <- function(df, df_bg, type, overlap_threshold = 2)
     select(-any_of(c("DB", "NG", "SM",
       "Raw_Shift_L", "Raw_Shift_K", "Raw_Shift_r", "Raw_Shift_AUC",
       "Z_Shift_L", "Z_Shift_K", "Z_Shift_r", "Z_Shift_AUC",
-      "Z_lm_L", "Z_lm_K", "Z_lm_r", "Z_lm_AUC")))
+      "Z_lm_L", "Z_lm_K", "Z_lm_r", "Z_lm_AUC",
+      "lm_R_squared_L", "lm_R_squared_K", "lm_R_squared_r", "lm_R_squared_AUC",
+      "lm_intercept_L", "lm_intercept_K", "lm_intercept_r", "lm_intercept_AUC",
+      "lm_slope_L", "lm_slope_K", "lm_slope_r", "lm_slope_AUC"
+    )))
 
   full_data <- calculations_no_overlap %>%
-    left_join(interactions_df, by = group_vars)
+    left_join(df_interactions, by = group_vars)
 
   # Return final dataframes
   return(list(
-    calculations = calculations_df,
-    interactions = interactions_df,
+    calculations = df_calculations,
+    interactions = df_interactions,
     full_data = full_data
   ))
 }
@@ -535,7 +541,7 @@ generate_and_save_plots <- function(out_dir, filename, plot_configs, page_width
 
       # Filter points outside of y-limits if specified
       if (!is.null(config$ylim_vals)) {
-        out_of_bounds_df <- df %>%
+        out_of_bounds <- df %>%
           filter(
             is.na(.data[[config$y_var]]) |
             .data[[config$y_var]] < config$ylim_vals[1] |
@@ -543,10 +549,10 @@ generate_and_save_plots <- function(out_dir, filename, plot_configs, page_width
           )
         
         # Print rows being filtered out
-        if (nrow(out_of_bounds_df) > 0) {
-          message("Filtered: ", config$title, " using y-limits: [", config$ylim_vals[1], ", ", config$ylim_vals[2], "]")
-          message("# of filtered rows outside y-limits (for plotting): ", nrow(out_of_bounds_df))
-          print(out_of_bounds_df)
+        if (nrow(out_of_bounds) > 0) {
+          message("Filtered ", nrow(out_of_bounds), " row(s) from '", config$title, "' because ", config$y_var,
+            " is outside of y-limits: [", config$ylim_vals[1], ", ", config$ylim_vals[2], "]:")
+          print(out_of_bounds %>% select(OrfRep, Gene, num, Drug, scan, Plate, Row, Col, conc_num, all_of(config$y_var)), width = 1000)
         }
 
         df <- df %>%
@@ -558,9 +564,8 @@ generate_and_save_plots <- function(out_dir, filename, plot_configs, page_width
       }
 
       # Filter NAs if specified
-      if (!is.null(config$na_rm) && config$na_rm) {
+      if (!is.null(config$filter_na) && config$filter_na) {
         df <- df %>%
-          filter(!is.na(.data[[config$x_var]])) %>%
           filter(!is.na(.data[[config$y_var]]))
       }
 
@@ -648,20 +653,18 @@ generate_and_save_plots <- function(out_dir, filename, plot_configs, page_width
 
           plot <- plot + geom_errorbar(
             aes(
-              x = .data[[config$x_var]],
               ymin = !!custom_ymin_expr,
               ymax = !!custom_ymax_expr
             ),
             color = config$error_bar_params$color,
             linewidth = ifelse(is.null(config$error_bar_params$linewidth), 0.1, config$error_bar_params$linewidth)
           )
-
         } else {
           # If no custom error bar formula, use the default or dynamic ones
-          if (!is.null(config$color_var) && is.null(config$error_bar_params$color)) {
+          if (!is.null(config$color_var) && config$color_var %in% colnames(config$df)) {
+            # Only use color_var if it's present in the dataframe
             plot <- plot + geom_errorbar(
               aes(
-                x = .data[[config$x_var]],
                 ymin = .data[[y_mean_col]] - .data[[y_sd_col]],
                 ymax = .data[[y_mean_col]] + .data[[y_sd_col]],
                 color = .data[[config$color_var]]
@@ -669,13 +672,13 @@ generate_and_save_plots <- function(out_dir, filename, plot_configs, page_width
               linewidth = 0.1
             )
           } else {
+            # If color_var is missing, fall back to a default color or none
             plot <- plot + geom_errorbar(
               aes(
-                x = .data[[config$x_var]],
                 ymin = .data[[y_mean_col]] - .data[[y_sd_col]],
                 ymax = .data[[y_mean_col]] + .data[[y_sd_col]]
               ),
-              color = config$error_bar_params$color,
+              color = config$error_bar_params$color, # use the provided color or default
               linewidth = ifelse(is.null(config$error_bar_params$linewidth), 0.1, config$error_bar_params$linewidth)
             )
           }
@@ -683,23 +686,18 @@ generate_and_save_plots <- function(out_dir, filename, plot_configs, page_width
 
         # Add the center point if the option is provided
         if (!is.null(config$error_bar_params$mean_point) && config$error_bar_params$mean_point) {
-          if (!is.null(config$color_var) && is.null(config$error_bar_params$color)) {
+          if (!is.null(config$error_bar_params$color)) {
             plot <- plot + geom_point(
-              aes(
-                x = .data[[config$x_var]],
-                y = .data[[y_mean_col]],
-                color = .data[[config$color_var]]
-              ),
-              shape = 16
+              mapping = aes(x = .data[[config$x_var]], y = .data[[y_mean_col]]),  # Include both x and y mappings
+              color = config$error_bar_params$color,
+              shape = 16,
+              inherit.aes = FALSE  # Prevent overriding global aesthetics
             )
           } else {
             plot <- plot + geom_point(
-              aes(
-                x = .data[[config$x_var]],
-                y = .data[[y_mean_col]]
-              ),
-              color = config$error_bar_params$color,
-              shape = 16
+              mapping = aes(x = .data[[config$x_var]], y = .data[[y_mean_col]]),  # Include both x and y mappings
+              shape = 16,
+              inherit.aes = FALSE  # Prevent overriding global aesthetics
             )
           }
         }
@@ -779,33 +777,72 @@ generate_scatter_plot <- function(plot, config) {
     position = position
   )
 
+  # Add a cyan point for the reference data for correlation plots
   if (!is.null(config$cyan_points) && config$cyan_points) {
     plot <- plot + geom_point(
-      aes(x = .data[[config$x_var]], y = .data[[config$y_var]]),
+      data = config$df_reference,
+      mapping = aes(x = .data[[config$x_var]], y = .data[[config$y_var]]),
       color = "cyan",
       shape = 3,
-      size = 0.5
+      size = 0.5,
+      inherit.aes = FALSE
     )
   }
-
-  if (!is.null(config$gray_points) && config$gray_points) {
-    plot <- plot + geom_point(shape = 3, color = "gray70", size = 1)
-  }
   
   # Add linear regression line if specified
   if (!is.null(config$lm_line)) {
-    plot <- plot +
-      annotate(
-        "segment",
-        x = config$lm_line$x_min,
-        xend = config$lm_line$x_max,
-        y = config$lm_line$intercept + config$lm_line$slope * config$lm_line$x_min,  # Calculate y for x_min
-        yend = config$lm_line$intercept + config$lm_line$slope * config$lm_line$x_max, # Calculate y for x_max
-        color = ifelse(!is.null(config$lm_line$color), config$lm_line$color, "blue"),
-        linewidth = ifelse(!is.null(config$lm_line$linewidth), config$lm_line$linewidth, 1)
-      )
+    # Extract necessary values
+    x_min <- config$lm_line$x_min
+    x_max <- config$lm_line$x_max
+    intercept <- config$lm_line$intercept
+    slope <- config$lm_line$slope
+    color <- ifelse(!is.null(config$lm_line$color), config$lm_line$color, "blue")
+    linewidth <- ifelse(!is.null(config$lm_line$linewidth), config$lm_line$linewidth, 1)
+
+    # Ensure none of the values are NA and calculate y-values
+    if (!is.na(x_min) && !is.na(x_max) && !is.na(intercept) && !is.na(slope)) {
+      y_min <- intercept + slope * x_min
+      y_max <- intercept + slope * x_max
+      
+      # Ensure y-values are within y-limits (if any)
+      if (!is.null(config$ylim_vals)) {
+        y_min_within_limits <- y_min >= config$ylim_vals[1] && y_min <= config$ylim_vals[2]
+        y_max_within_limits <- y_max >= config$ylim_vals[1] && y_max <= config$ylim_vals[2]
+
+        # Adjust or skip based on whether the values fall within limits
+        if (y_min_within_limits && y_max_within_limits) {
+          # Ensure x-values are also valid
+          if (!is.na(x_min) && !is.na(x_max)) {
+            plot <- plot + annotate(
+              "segment",
+              x = x_min,
+              xend = x_max,
+              y = y_min,
+              yend = y_max,
+              color = color,
+              linewidth = linewidth
+            )
+          }
+        } else {
+          message("Skipping linear modeling line due to y-values outside of limits.")
+        }
+      } else {
+        # If no y-limits are provided, proceed with the annotation
+        plot <- plot + annotate(
+          "segment",
+          x = x_min,
+          xend = x_max,
+          y = y_min,
+          yend = y_max,
+          color = color,
+          linewidth = linewidth
+        )
+      }
+    } else {
+      message("Skipping linear modeling line due to missing or invalid values.")
+    }
   }
-  
+
   # Add SD Bands if specified
   if (!is.null(config$sd_band)) {
     plot <- plot +
@@ -829,7 +866,7 @@ generate_scatter_plot <- function(plot, config) {
       )
   }
 
-  # Add Rectangles if specified
+  # Add rectangles if specified
   if (!is.null(config$rectangles)) {
     for (rect in config$rectangles) {
       plot <- plot + annotate(
@@ -909,11 +946,11 @@ generate_plate_analysis_plot_configs <- function(variables, df_before = NULL, df
       df_plot <- if (stage == "before") df_before else df_after
 
       # Check for non-finite values in the y-variable
-      df_plot_filtered <- df_plot %>% filter(is.finite(!!sym(var)))
+      # df_plot_filtered <- df_plot %>% filter(is.finite(.data[[var]]))
 
       # Adjust settings based on plot_type
       plot_config <- list(
-        df = df_plot_filtered,
+        df = df_plot,
         x_var = "scan",
         y_var = var,
         plot_type = plot_type,
@@ -921,7 +958,8 @@ generate_plate_analysis_plot_configs <- function(variables, df_before = NULL, df
         color_var = "conc_num_factor_factor",
         size = 0.2,
         error_bar = (plot_type == "scatter"),
-        legend_position = "bottom"
+        legend_position = "bottom",
+        filter_na = TRUE
       )
 
       # Add config to plots list
@@ -1086,7 +1124,7 @@ generate_interaction_plot_configs <- function(df_summary, df_interactions, type)
         x_breaks = unique(group_data$conc_num_factor_factor),
         x_labels = as.character(unique(group_data$conc_num)),
         ylim_vals = y_limits,
-        y_filter = FALSE,
+        filter_na = TRUE,
         lm_line = list(
           intercept = lm_intercept_value,
           slope = lm_slope_value,
@@ -1111,7 +1149,7 @@ generate_interaction_plot_configs <- function(df_summary, df_interactions, type)
   ))
 }
 
-generate_rank_plot_configs <- function(df, is_lm = FALSE, adjust = FALSE, na_rm = FALSE, overlap_color = FALSE) {
+generate_rank_plot_configs <- function(df, is_lm = FALSE, adjust = FALSE, filter_na = FALSE, overlap_color = FALSE) {
   sd_bands <- c(1, 2, 3)
   plot_configs <- list()
   
@@ -1128,7 +1166,7 @@ generate_rank_plot_configs <- function(df, is_lm = FALSE, adjust = FALSE, na_rm
   }
 
   # Helper function to create a rank plot configuration
-  create_plot_config <- function(variable, rank_var, zscore_var, y_label, sd_band, na_rm, with_annotations = TRUE) {
+  create_plot_config <- function(variable, rank_var, zscore_var, y_label, sd_band, filter_na, with_annotations = TRUE) {
     num_enhancers <- sum(df[[zscore_var]] >= sd_band, na.rm = TRUE)
     num_suppressors <- sum(df[[zscore_var]] <= -sd_band, na.rm = TRUE)
 
@@ -1148,7 +1186,7 @@ generate_rank_plot_configs <- function(df, is_lm = FALSE, adjust = FALSE, na_rm
       alpha_negative = 0.3,
       shape = 3,
       size = 0.1,
-      na_rm = na_rm,
+      filter_na = filter_na,
       legend_position = "none"
     )
     
@@ -1181,11 +1219,11 @@ generate_rank_plot_configs <- function(df, is_lm = FALSE, adjust = FALSE, na_rm
     for (sd_band in sd_bands) {
       # Create plot with annotations
       plot_configs[[length(plot_configs) + 1]] <-
-        create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, na_rm, with_annotations = TRUE)
+        create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, filter_na, with_annotations = TRUE)
       
       # Create plot without annotations
       plot_configs[[length(plot_configs) + 1]] <-
-        create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, na_rm, with_annotations = FALSE)
+        create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, filter_na, with_annotations = FALSE)
     }
   }
 
@@ -1198,7 +1236,7 @@ generate_rank_plot_configs <- function(df, is_lm = FALSE, adjust = FALSE, na_rm
   ))
 }
 
-generate_correlation_plot_configs <- function(df) {
+generate_correlation_plot_configs <- function(df, df_reference) {
   # Define relationships for different-variable correlations
   relationships <- list(
     list(x = "L", y = "K"),
@@ -1209,6 +1247,10 @@ generate_correlation_plot_configs <- function(df) {
     list(x = "r", y = "AUC")
   )
 
+  # This filtering was in the original script
+  # df_reference <- df_reference %>%
+  #   filter(!is.na(Z_lm_L))
+
   plot_configs <- list()
 
   # Iterate over the option to highlight cyan points (TRUE/FALSE)
@@ -1221,10 +1263,10 @@ generate_correlation_plot_configs <- function(df) {
       y_var <- paste0("Z_lm_", rel$y)
 
       # Extract the R-squared, intercept, and slope from the df
-      relationship_name <- paste0(rel$x, "_vs_", rel$y)  # Example: L_vs_K
-      intercept <- mean(df[[paste0("lm_intercept_", rel$x)]], na.rm = TRUE)
-      slope <- mean(df[[paste0("lm_slope_", rel$x)]], na.rm = TRUE)
-      r_squared <- mean(df[[paste0("lm_R_squared_", rel$x)]], na.rm = TRUE)
+      relationship_name <- paste0(rel$x, "_vs_", rel$y)
+      intercept <- df[[paste0("lm_intercept_", rel$x)]]
+      slope <- df[[paste0("lm_slope_", rel$x)]]
+      r_squared <- df[[paste0("lm_R_squared_", rel$x)]]
 
       # Generate the label for the plot
       plot_label <- paste("Interaction", rel$x, "vs.", rel$y)
@@ -1232,6 +1274,7 @@ generate_correlation_plot_configs <- function(df) {
       # Construct plot config
       plot_config <- list(
         df = df,
+        df_reference = df_reference,
         x_var = x_var,
         y_var = y_var,
         plot_type = "scatter",
@@ -1248,11 +1291,9 @@ generate_correlation_plot_configs <- function(df) {
           slope = slope,
           color = "tomato3"
         ),
-        shape = 3,
-        size = 0.5,
-        color_var = "Overlap",
-        cyan_points = highlight_cyan, # include cyan points or not based on the loop
-        gray_points = TRUE
+        color = "gray70",
+        filter_na = TRUE,
+        cyan_points = highlight_cyan # include cyan points or not based on the loop
       )
 
       plot_configs <- append(plot_configs, list(plot_config))
@@ -1434,7 +1475,7 @@ main <- function() {
           x_var = "L",
           y_var = "K",
           plot_type = "scatter",
-          title = "Raw L vs K for strains falling outside 2SD of the K mean at each Conc",
+          title = "Raw L vs K for strains falling outside 2 SD of the K mean at each Conc",
           color_var = "conc_num_factor_factor",
           position = "jitter",
           tooltip_vars = c("OrfRep", "Gene", "delta_bg"),
@@ -1459,7 +1500,7 @@ main <- function() {
           x_label = "Delta Background",
           y_var = "K",
           plot_type = "scatter",
-          title = "Delta Background vs K for strains falling outside 2SD of the K mean at each Conc",
+          title = "Delta Background vs K for strains falling outside 2 SD of K",
           color_var = "conc_num_factor_factor",
           position = "jitter",
           tooltip_vars = c("OrfRep", "Gene", "delta_bg"),
@@ -1573,22 +1614,23 @@ main <- function() {
           .groups = "drop"
         )
 
-      # message("Calculating reference strain interaction summary statistics") # formerly X_stats_interaction
-      # df_reference_interaction_stats <- calculate_summary_stats(
-      #   df = df_reference,
-      #   variables = c("L", "K", "r", "AUC"),
-      #   group_vars = c("OrfRep", "Gene", "num", "Drug", "conc_num", "conc_num_factor_factor")
-      #   )$df_with_stats
+      message("Calculating reference strain interaction summary statistics") # formerly X_stats_interaction
+      df_reference_interaction_stats <- calculate_summary_stats(
+        df = df_reference,
+        variables = c("L", "K", "r", "AUC"),
+        group_vars = c("OrfRep", "Gene", "num", "Drug", "conc_num", "conc_num_factor_factor")
+        )$df_with_stats
       
-      # # message("Calculating reference strain interaction scores")
-      # reference_results <- calculate_interaction_scores(df_reference_interaction_stats, df_bg_stats, "reference")
-      # df_reference_interactions_joined <- reference_results$full_data
-      # write.csv(reference_results$calculations, file = file.path(out_dir, "zscore_calculations_reference.csv"), row.names = FALSE)
-      # write.csv(reference_results$interactions, file = file.path(out_dir, "zscore_interactions_reference.csv"), row.names = FALSE)
+      message("Calculating reference strain interaction scores")
+      reference_results <- calculate_interaction_scores(df_reference_interaction_stats, df_bg_stats, "reference")
+      df_reference_interactions_joined <- reference_results$full_data
+      df_reference_interactions <- reference_results$interactions
+      write.csv(reference_results$calculations, file = file.path(out_dir, "zscore_calculations_reference.csv"), row.names = FALSE)
+      write.csv(df_reference_interactions, file = file.path(out_dir, "zscore_interactions_reference.csv"), row.names = FALSE)
 
-      # # message("Generating reference interaction plots")
-      # reference_plot_configs <- generate_interaction_plot_configs(df_reference_summary_stats, df_reference_interactions_joined, "reference")
-      # generate_and_save_plots(out_dir, "interaction_plots_reference", reference_plot_configs, page_width = 16, page_height = 16)
+      message("Generating reference interaction plots")
+      reference_plot_configs <- generate_interaction_plot_configs(df_reference_summary_stats, df_reference_interactions_joined, "reference")
+      generate_and_save_plots(out_dir, "interaction_plots_reference", reference_plot_configs, page_width = 16, page_height = 16)
 
       message("Setting missing deletion values to the highest theoretical value at each drug conc for L")
       df_deletion <- df_na_stats %>% # formerly X2
@@ -1616,9 +1658,9 @@ main <- function() {
       write.csv(deletion_results$calculations, file = file.path(out_dir, "zscore_calculations.csv"), row.names = FALSE)
       write.csv(df_interactions, file = file.path(out_dir, "zscore_interactions.csv"), row.names = FALSE)
 
-      # message("Generating deletion interaction plots")
-      # deletion_plot_configs <- generate_interaction_plot_configs(df_reference_summary_stats, df_interactions_joined, "deletion")
-      # generate_and_save_plots(out_dir, "interaction_plots", deletion_plot_configs, page_width = 16, page_height = 16)
+      message("Generating deletion interaction plots")
+      deletion_plot_configs <- generate_interaction_plot_configs(df_reference_summary_stats, df_interactions_joined, "deletion")
+      generate_and_save_plots(out_dir, "interaction_plots", deletion_plot_configs, page_width = 16, page_height = 16)
 
       message("Writing enhancer/suppressor csv files")
       interaction_threshold <- 2  # TODO add to study config?
@@ -1675,7 +1717,7 @@ main <- function() {
         df_interactions,
         is_lm = FALSE,
         adjust = FALSE,
-        na_rm = TRUE,
+        filter_na = TRUE,
         overlap_color = TRUE
       )
       generate_and_save_plots(out_dir, "rank_plots_na_rm", rank_plot_filtered_configs,
@@ -1686,7 +1728,7 @@ main <- function() {
         df_interactions,
         is_lm = TRUE,
         adjust = FALSE,
-        na_rm = TRUE,
+        filter_na = TRUE,
         overlap_color = TRUE
       )
       generate_and_save_plots(out_dir, "rank_plots_lm_na_rm", rank_plot_lm_filtered_configs,
@@ -1694,7 +1736,8 @@ main <- function() {
 
       message("Generating correlation curve parameter pair plots")
       correlation_plot_configs <- generate_correlation_plot_configs(
-        df_interactions
+        df_interactions,
+        df_reference_interactions
       )
       generate_and_save_plots(out_dir, "correlation_cpps", correlation_plot_configs,
         page_width = 10, page_height = 7)