浏览代码

Standardize plot groups

Bryan Roessler 6 月之前
父节点
当前提交
1548dbf9c1
共有 1 个文件被更改,包括 158 次插入140 次删除
  1. 158 140
      qhtcp-workflow/apps/r/calculate_interaction_zscores.R

+ 158 - 140
qhtcp-workflow/apps/r/calculate_interaction_zscores.R

@@ -399,99 +399,94 @@ calculate_interaction_scores <- function(df, max_conc, bg_stats,
 generate_and_save_plots <- function(out_dir, filename, plot_configs) {
   message("Generating ", filename, ".pdf and ", filename, ".html")
 
-  static_plots <- list()
-  plotly_plots <- list()
+  # Check if we're dealing with multiple plot groups
+  plot_groups <- if ("plots" %in% names(plot_configs)) {
+    list(plot_configs)  # Single group
+  } else {
+    plot_configs  # Multiple groups
+  }
 
-  for (i in seq_along(plot_configs$plots)) {
-    config <- plot_configs$plots[[i]]
-    df <- config$df
+  for (group in plot_groups) {
 
-    # Define aes_mapping, ensuring y_var is only used when it's not NULL
-    aes_mapping <- switch(config$plot_type,
-      "bar" = if (!is.null(config$color_var)) {
-        aes(x = .data[[config$x_var]], fill = .data[[config$color_var]], color = .data[[config$color_var]])
-      } else {
-        aes(x = .data[[config$x_var]])
-      },
-      "density" = if (!is.null(config$color_var)) {
-        aes(x = .data[[config$x_var]], color = .data[[config$color_var]])
-      } else {
-        aes(x = .data[[config$x_var]])
-      },
-      # For other plot types, only include y_var if it's not NULL
-      if (!is.null(config$y_var) && !is.null(config$color_var)) {
-        aes(x = .data[[config$x_var]], y = .data[[config$y_var]], color = .data[[config$color_var]])
-      } else if (!is.null(config$y_var)) {
-        aes(x = .data[[config$x_var]], y = .data[[config$y_var]])
+    static_plots <- list()
+    plotly_plots <- list()
+
+    grid_layout <- group$grid_layout
+    plots <- group$plots
+
+    for (i in seq_along(plots)) {
+      config <- plots[[i]]
+      df <- config$df
+
+      if (config$plot_type == "bar") {
+        if (!is.null(config$color_var)) {
+          aes_mapping <- aes(x = .data[[config$x_var]], fill = .data[[config$color_var]], color = .data[[config$color_var]])
+        } else {
+          aes_mapping <- aes(x = .data[[config$x_var]])
+        }
+      } else if (config$plot_type == "density") {
+        if (!is.null(config$color_var)) {
+          aes_mapping <- aes(x = .data[[config$x_var]], color = .data[[config$color_var]])
+        } else {
+          aes_mapping <- aes(x = .data[[config$x_var]])
+        }
       } else {
-        aes(x = .data[[config$x_var]])  # no y_var needed for density and bar plots
+        # For other plot types
+        if (!is.null(config$y_var) && !is.null(config$color_var)) {
+          aes_mapping <- aes(x = .data[[config$x_var]], y = .data[[config$y_var]], color = .data[[config$color_var]])
+        } else if (!is.null(config$y_var)) {
+          aes_mapping <- aes(x = .data[[config$x_var]], y = .data[[config$y_var]])
+        } else {
+          aes_mapping <- aes(x = .data[[config$x_var]])
+        }
       }
-    )
 
-    plot <- ggplot(df, aes_mapping) + theme_publication(legend_position = config$legend_position)
+      plot <- ggplot(df, aes_mapping) + theme_publication(legend_position = config$legend_position)
 
-    # Apply appropriate plot function
-    plot <- switch(config$plot_type,
-      "scatter" = generate_scatter_plot(plot, config),
-      "box" = generate_box_plot(plot, config),
-      "density" = plot + geom_density(),
-      "bar" = plot + geom_bar(),
-      plot  # default (unused)
-    )
+      plot <- switch(config$plot_type,
+        "scatter" = generate_scatter_plot(plot, config),
+        "box" = generate_boxplot(plot, config),
+        "density" = plot + geom_density(),
+        "bar" = plot + geom_bar(),
+        plot  # default (unused)
+      )
 
-    # Add titles and labels
-    if (!is.null(config$title)) plot <- plot + ggtitle(config$title)
-    if (!is.null(config$x_label)) plot <- plot + xlab(config$x_label)
-    if (!is.null(config$y_label)) plot <- plot + ylab(config$y_label)
-    if (!is.null(config$coord_cartesian)) plot <- plot + coord_cartesian(ylim = config$coord_cartesian)
+      if (!is.null(config$title)) plot <- plot + ggtitle(config$title)
+      if (!is.null(config$x_label)) plot <- plot + xlab(config$x_label)
+      if (!is.null(config$y_label)) plot <- plot + ylab(config$y_label)
+      if (!is.null(config$coord_cartesian)) plot <- plot + coord_cartesian(ylim = config$coord_cartesian)
 
-    # Convert ggplot to plotly, skipping subplot
-    if (!is.null(config$tooltip_vars)) {
-      plotly_plot <- suppressWarnings(plotly::ggplotly(plot, tooltip = config$tooltip_vars))
-    } else {
       plotly_plot <- suppressWarnings(plotly::ggplotly(plot))
-    }
 
-    if (!is.null(plotly_plot[["frames"]])) {
-      plotly_plot[["frames"]] <- NULL
+      static_plots[[i]] <- plot
+      plotly_plots[[i]] <- plotly_plot
     }
 
-    # Adjust legend position in plotly
-    if (!is.null(config$legend_position) && config$legend_position == "bottom") {
-      plotly_plot <- plotly_plot %>% layout(legend = list(orientation = "h"))
-    }
-
-    static_plots[[i]] <- plot
-    plotly_plots[[i]] <- plotly_plot
-  }
-
-  # Save static PDF plots
-  pdf(file.path(out_dir, paste0(filename, ".pdf")), width = 16, height = 9)
+    pdf(file.path(out_dir, paste0(filename, ".pdf")), width = 16, height = 9)
 
-  if (is.null(plot_configs$grid_layout)) {
-    # Print each plot on a new page if grid_layout is not set
-    for (plot in static_plots) {
-      print(plot)
+    if (is.null(grid_layout)) {
+      for (plot in static_plots) {
+        print(plot)
+      }
+    } else {
+      grid.arrange(
+        grobs = static_plots,
+        ncol = grid_layout$ncol,
+        nrow = grid_layout$nrow
+      )
     }
-  } else {
-    # Use grid.arrange if grid_layout is set
-    grid_nrow <- ifelse(is.null(plot_configs$grid_layout$nrow), length(plot_configs$plots), plot_configs$grid_layout$nrow)
-    grid_ncol <- ifelse(is.null(plot_configs$grid_layout$ncol), 1, plot_configs$grid_layout$ncol)
-    grid.arrange(grobs = static_plots, ncol = grid_ncol, nrow = grid_nrow)
-  }
 
-  dev.off()
+    dev.off()
 
-  # Save combined HTML plot
-  out_html_file <- file.path(out_dir, paste0(filename, ".html"))
-  message("Saving combined HTML file: ", out_html_file)
-  htmltools::save_html(
-    htmltools::tagList(plotly_plots),
-    file = out_html_file
-  )
+    out_html_file <- file.path(out_dir, paste0(filename, ".html"))
+    message("Saving combined HTML file: ", out_html_file)
+    htmltools::save_html(
+      htmltools::tagList(plotly_plots),
+      file = out_html_file
+    )
+  }
 }
 
-
 generate_scatter_plot <- function(plot, config) {
 
   # Define the points
@@ -582,39 +577,52 @@ generate_scatter_plot <- function(plot, config) {
   # Add error bars if specified
   if (!is.null(config$error_bar) && config$error_bar && !is.null(config$y_var)) {
     if (!is.null(config$error_bar_params)) {
+      # Error bar params are constants, so set them outside aes
       plot <- plot +
         geom_errorbar(
           aes(
-            ymin = config$error_bar_params$ymin,
-            ymax = config$error_bar_params$ymax
+            ymin = !!sym(config$y_var),   # y_var mapped to y-axis
+            ymax = !!sym(config$y_var)
           ),
+          ymin = config$error_bar_params$ymin,  # Constant values
+          ymax = config$error_bar_params$ymax,  # Constant values
           alpha = 0.3,
           linewidth = 0.5
         )
     } else {
+      # Dynamically generate ymin and ymax based on column names
       y_mean_col <- paste0("mean_", config$y_var)
       y_sd_col <- paste0("sd_", config$y_var)
       
       plot <- plot +
         geom_errorbar(
           aes(
-            ymin = !!sym(y_mean_col) - !!sym(y_sd_col),
-            ymax = !!sym(y_mean_col) + !!sym(y_sd_col)
+            ymin = !!sym(y_mean_col) - !!sym(y_sd_col),  # Calculating ymin in aes
+            ymax = !!sym(y_mean_col) + !!sym(y_sd_col)   # Calculating ymax in aes
           ),
           alpha = 0.3,
           linewidth = 0.5
         )
     }
   }
-  
+
   # Customize X-axis if specified
   if (!is.null(config$x_breaks) && !is.null(config$x_labels) && !is.null(config$x_label)) {
-    plot <- plot +
-      scale_x_discrete(
-        name = config$x_label,
-        breaks = config$x_breaks,
-        labels = config$x_labels
-      )
+    if (is.factor(df[[config$x_var]]) || is.character(df[[config$x_var]])) {
+      plot <- plot +
+        scale_x_discrete(
+          name = config$x_label,
+          breaks = config$x_breaks,
+          labels = config$x_labels
+        )
+    } else {
+      plot <- plot +
+        scale_x_continuous(
+          name = config$x_label,
+          breaks = config$x_breaks,
+          labels = config$x_labels
+        )
+    }
   }
   
   # Set Y-axis limits if specified
@@ -642,17 +650,27 @@ generate_scatter_plot <- function(plot, config) {
   return(plot)
 }
 
-generate_box_plot <- function(plot, config) {
+generate_boxplot <- function(plot, config) {
   # Convert x_var to a factor within aes mapping
   plot <- plot + geom_boxplot(aes(x = factor(.data[[config$x_var]])))
 
-  # Apply scale_x_discrete for breaks, labels, and axis label if provided
+  # Customize X-axis if specified
   if (!is.null(config$x_breaks) && !is.null(config$x_labels) && !is.null(config$x_label)) {
-    plot <- plot + scale_x_discrete(
-      name = config$x_label,
-      breaks = config$x_breaks,
-      labels = config$x_labels
-    )
+    if (is.factor(df[[config$x_var]]) || is.character(df[[config$x_var]])) {
+      plot <- plot +
+        scale_x_discrete(
+          name = config$x_label,
+          breaks = config$x_breaks,
+          labels = config$x_labels
+        )
+    } else {
+      plot <- plot +
+        scale_x_continuous(
+          name = config$x_label,
+          breaks = config$x_breaks,
+          labels = config$x_labels
+        )
+    }
   }
 
   return(plot)
@@ -660,7 +678,7 @@ generate_box_plot <- function(plot, config) {
 
 generate_plate_analysis_plot_configs <- function(variables, df_before = NULL, df_after = NULL,
   plot_type = "scatter", stages = c("before", "after")) {
-  plots <- list()
+  plot_configs <- list()
   
   for (var in variables) {
     for (stage in stages) {
@@ -670,7 +688,7 @@ generate_plate_analysis_plot_configs <- function(variables, df_before = NULL, df
       df_plot_filtered <- df_plot %>% filter(is.finite(!!sym(var)))
 
       # Adjust settings based on plot_type
-      config <- list(
+      plot_config <- list(
         df = df_plot_filtered,
         x_var = "scan",
         y_var = var,
@@ -683,39 +701,41 @@ generate_plate_analysis_plot_configs <- function(variables, df_before = NULL, df
       )
 
       # Add config to plots list
-      plots <- append(plots, list(config))
+      plot_configs <- append(plot_configs, list(plot_config))
     }
   }
 
-  return(list(grid_layout = list(ncol = 1, nrow = length(plots)), plots = plots))
+  return(list(plots = plot_configs))
 }
 
-generate_interaction_plot_configs <- function(df, limits_map = NULL, plot_type = "reference") {
-  if (is.null(limits_map)) {
-    limits_map <- list(
-      L = c(0, 130),
-      K = c(-20, 160),
-      r = c(0, 1),
-      AUC = c(0, 12500),
-      Delta_L = c(-60, 60),
-      Delta_K = c(-60, 60),
-      Delta_r = c(-0.6, 0.6),
-      Delta_AUC = c(-6000, 6000)
-    )
-  }
+generate_interaction_plot_configs <- function(df, plot_type = "reference") {
+  limits_map <- list(
+    L = c(0, 130),
+    K = c(-20, 160),
+    r = c(0, 1),
+    AUC = c(0, 12500)
+  )
+
+  delta_limits_map <- list(
+    Delta_L = c(-60, 60),
+    Delta_K = c(-60, 60),
+    Delta_r = c(-0.6, 0.6),
+    Delta_AUC = c(-6000, 6000)
+  )
 
   group_vars <- if (plot_type == "reference") c("OrfRep", "Gene", "num") else c("OrfRep", "Gene")
+  
   df_filtered <- df %>%
     mutate(OrfRepCombined = if (plot_type == "reference") paste(OrfRep, Gene, num, sep = "_") else paste(OrfRep, Gene, sep = "_"))
 
-  # Separate the plots into two groups: overall variables and delta comparisons
-  overall_plots <- list()
-  delta_plots <- list()
+  overall_plot_configs <- list()
+  delta_plot_configs <- list()
 
-  for (var in c("L", "K", "r", "AUC")) {
+  # Overall plots
+  for (var in names(limits_map)) {
     y_limits <- limits_map[[var]]
 
-    config <- list(
+    plot_config <- list(
       df = df_filtered,
       plot_type = "scatter",
       x_var = "conc_num_factor_factor",
@@ -729,9 +749,10 @@ generate_interaction_plot_configs <- function(df, limits_map = NULL, plot_type =
       position = "jitter",
       smooth = TRUE
     )
-    overall_plots <- append(overall_plots, list(config))
+    overall_plot_configs <- append(overall_plot_configs, list(plot_config))
   }
 
+  # Delta plots
   unique_groups <- df_filtered %>% select(all_of(group_vars)) %>% distinct()
 
   for (i in seq_len(nrow(unique_groups))) {
@@ -742,17 +763,15 @@ generate_interaction_plot_configs <- function(df, limits_map = NULL, plot_type =
     Gene <- if ("Gene" %in% names(group)) as.character(group$Gene) else ""
     num <- if ("num" %in% names(group)) as.character(group$num) else ""
 
-    for (var in c("Delta_L", "Delta_K", "Delta_r", "Delta_AUC")) {
-      y_limits <- limits_map[[var]]
+    for (var in names(delta_limits_map)) {
+      y_limits <- delta_limits_map[[var]]
       y_span <- y_limits[2] - y_limits[1]
 
-      # Error bars
       WT_sd_var <- paste0("WT_sd_", sub("Delta_", "", var))
       WT_sd_value <- group_data[[WT_sd_var]][1]
       error_bar_ymin <- 0 - (2 * WT_sd_value)
       error_bar_ymax <- 0 + (2 * WT_sd_value)
 
-      # Annotations
       Z_Shift_value <- round(group_data[[paste0("Z_Shift_", sub("Delta_", "", var))]][1], 2)
       Z_lm_value <- round(group_data[[paste0("Z_lm_", sub("Delta_", "", var))]][1], 2)
       NG_value <- group_data$NG[1]
@@ -767,7 +786,7 @@ generate_interaction_plot_configs <- function(df, limits_map = NULL, plot_type =
         list(x = 1, y = y_limits[1], label = paste("SM =", SM_value))
       )
 
-      config <- list(
+      plot_config <- list(
         df = group_data,
         plot_type = "scatter",
         x_var = "conc_num_factor_factor",
@@ -786,20 +805,22 @@ generate_interaction_plot_configs <- function(df, limits_map = NULL, plot_type =
         x_labels = as.character(unique(group_data$conc_num)),
         ylim_vals = y_limits
       )
-      delta_plots <- append(delta_plots, list(config))
+      delta_plot_configs <- append(delta_plot_configs, list(plot_config))
     }
   }
 
   return(list(
-    overall_plots = list(grid_layout = list(ncol = 2, nrow = 2), plots = overall_plots),
-    delta_plots = list(grid_layout = list(ncol = 4, nrow = 3), plots = delta_plots)
+    list(grid_layout = list(ncol = 2, nrow = 2), plots = overall_plot_configs),
+    list(grid_layout = list(ncol = 4, nrow = 3), plots = delta_plot_configs)
   ))
 }
 
 generate_rank_plot_configs <- function(df, variables, is_lm = FALSE, adjust = FALSE, overlap_color = FALSE) {
   sd_bands <- c(1, 2, 3)
-  configs <- list()
+  plot_configs <- list()
   
+  variables <- c("L", "K")
+
   # Adjust (if necessary) and rank columns
   for (variable in variables) {
     if (adjust) {
@@ -863,19 +884,19 @@ generate_rank_plot_configs <- function(df, variables, is_lm = FALSE, adjust = FA
     # Loop through SD bands
     for (sd_band in sd_bands) {
       # Create plot with annotations
-      configs[[length(configs) + 1]] <- create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, with_annotations = TRUE)
+      plot_configs[[length(plot_configs) + 1]] <- create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, with_annotations = TRUE)
       
       # Create plot without annotations
-      configs[[length(configs) + 1]] <- create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, with_annotations = FALSE)
+      plot_configs[[length(plot_configs) + 1]] <- create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, with_annotations = FALSE)
     }
   }
 
   # Calculate dynamic grid layout based on the number of plots
-  num_plots <- length(configs)
   grid_ncol <- 3
+  num_plots <- length(plot_configs)
   grid_nrow <- ceiling(num_plots / grid_ncol)  # Automatically calculate the number of rows
 
-  return(list(grid_layout = list(ncol = grid_ncol, nrow = grid_nrow), plots = configs))
+  return(list(grid_layout = list(ncol = grid_ncol, nrow = grid_nrow), plots = plot_configs))
 }
 
 generate_correlation_plot_configs <- function(df, highlight_cyan = FALSE) {
@@ -888,22 +909,23 @@ generate_correlation_plot_configs <- function(df, highlight_cyan = FALSE) {
     list(x = "Z_lm_r", y = "Z_lm_AUC", label = "Interaction r vs. Interaction AUC")
   )
 
-  plots <- list()
+  plot_configs <- list()
 
   for (rel in relationships) {
     lm_model <- lm(as.formula(paste(rel$y, "~", rel$x)), data = df)
     r_squared <- summary(lm_model)$r.squared
 
-    config <- list(
+    plot_config <- list(
       df = df,
       x_var = rel$x,
       y_var = rel$y,
       plot_type = "scatter",
       title = rel$label,
       annotations = list(
-        list(x = mean(df[[rel$x]], na.rm = TRUE),
-             y = mean(df[[rel$y]], na.rm = TRUE),
-             label = paste("R-squared =", round(r_squared, 3)))
+        list(
+          x = mean(df[[rel$x]], na.rm = TRUE),
+          y = mean(df[[rel$y]], na.rm = TRUE),
+          label = paste("R-squared =", round(r_squared, 3)))
       ),
       smooth = TRUE,
       smooth_color = "tomato3",
@@ -914,10 +936,10 @@ generate_correlation_plot_configs <- function(df, highlight_cyan = FALSE) {
       cyan_points = highlight_cyan
     )
 
-    plots <- append(plots, list(config))
+    plot_configs <- append(plot_configs, list(plot_config))
   }
 
-  return(list(grid_layout = list(ncol = 3, nrow = 2), plots = plots))
+  return(list(plots = plot_configs))
 }
 
 main <- function() {
@@ -1332,7 +1354,6 @@ main <- function() {
       message("Generating rank plots")
       rank_plot_configs <- generate_rank_plot_configs(
         df = zscore_interactions_joined,
-        variables = interaction_vars,
         is_lm = FALSE,
         adjust = TRUE
       )
@@ -1342,7 +1363,6 @@ main <- function() {
       message("Generating ranked linear model plots")
       rank_lm_plot_configs <- generate_rank_plot_configs(
         df = zscore_interactions_joined,
-        variables = interaction_vars,
         is_lm = TRUE,
         adjust = TRUE
       )
@@ -1375,7 +1395,6 @@ main <- function() {
       message("Generating filtered ranked plots")
       rank_plot_filtered_configs <- generate_rank_plot_configs(
         df = zscore_interactions_filtered,
-        variables = interaction_vars,
         is_lm = FALSE,
         adjust = FALSE,
         overlap_color = TRUE
@@ -1388,7 +1407,6 @@ main <- function() {
       message("Generating filtered ranked linear model plots")
       rank_plot_lm_filtered_configs <- generate_rank_plot_configs(
         df = zscore_interactions_filtered,
-        variables = interaction_vars,
         is_lm = TRUE,
         adjust = FALSE,
         overlap_color = TRUE