Browse Source

Improve interactions grids with gridExtra

Bryan Roessler 6 months ago
parent
commit
f005155d08
2 changed files with 179 additions and 199 deletions
  1. 178 198
      qhtcp-workflow/apps/r/calculate_interaction_zscores.R
  2. 1 1
      qhtcp-workflow/qhtcp-workflow

+ 178 - 198
qhtcp-workflow/apps/r/calculate_interaction_zscores.R

@@ -6,6 +6,7 @@ suppressMessages({
   library("rlang")
   library("ggthemes")
   library("data.table")
+  library("gridExtra")
   library("future")
   library("furrr")
   library("purrr")
@@ -371,124 +372,128 @@ calculate_interaction_scores <- function(df, max_conc, bg_stats,
     interactions_joined = interactions_joined))
 }
 
-generate_and_save_plots <- function(out_dir, filename, plot_configs, grid_layout = NULL) {
+generate_and_save_plots <- function(out_dir, filename, plot_configs) {
   message("Generating ", filename, ".pdf and ", filename, ".html")
 
-  # Prepare lists to collect plots
-  static_plots <- list()
-  plotly_plots <- list()
-
-  for (i in seq_along(plot_configs)) {
-    config <- plot_configs[[i]]
-    df <- config$df
-
-    # Create the base plot
-    aes_mapping <- if (config$plot_type == "bar") {
-      if (!is.null(config$color_var)) {
-        aes(x = .data[[config$x_var]], fill = as.factor(.data[[config$color_var]]), color = as.factor(.data[[config$color_var]]))
-      } else {
-        aes(x = .data[[config$x_var]])
-      }
-    } else if (config$plot_type == "density") {
-      if (!is.null(config$color_var)) {
-        aes(x = .data[[config$x_var]], color = as.factor(.data[[config$color_var]]))
-      } else {
-        aes(x = .data[[config$x_var]])
-      }
-    } else {
-      if (!is.null(config$color_var)) {
-        aes(x = .data[[config$x_var]], y = .data[[config$y_var]], color = as.factor(.data[[config$color_var]]))
+  # Iterate through the plot_configs (which contain both plots and grid_layout)
+  for (config_group in plot_configs) {
+    plot_list <- config_group$plots
+    grid_nrow <- config_group$grid_layout$nrow
+    grid_ncol <- config_group$grid_layout$ncol
+
+    # Prepare lists to collect static and interactive plots
+    static_plots <- list()
+    plotly_plots <- list()
+
+    # Generate each individual plot based on the configuration
+    for (i in seq_along(plot_list)) {
+      config <- plot_list[[i]]
+      df <- config$df
+
+      # Create the base plot
+      aes_mapping <- if (config$plot_type == "bar") {
+        if (!is.null(config$color_var)) {
+          aes(x = .data[[config$x_var]], fill = as.factor(.data[[config$color_var]]), color = as.factor(.data[[config$color_var]]))
+        } else {
+          aes(x = .data[[config$x_var]])
+        }
+      } else if (config$plot_type == "density") {
+        if (!is.null(config$color_var)) {
+          aes(x = .data[[config$x_var]], color = as.factor(.data[[config$color_var]]))
+        } else {
+          aes(x = .data[[config$x_var]])
+        }
       } else {
-        aes(x = .data[[config$x_var]], y = .data[[config$y_var]])
+        if (!is.null(config$color_var)) {
+          aes(x = .data[[config$x_var]], y = .data[[config$y_var]], color = as.factor(.data[[config$color_var]]))
+        } else {
+          aes(x = .data[[config$x_var]], y = .data[[config$y_var]])
+        }
       }
-    }
 
-    plot <- ggplot(df, aes_mapping)
+      plot <- ggplot(df, aes_mapping)
 
-    # Apply theme_publication with legend_position from config
-    legend_position <- if (!is.null(config$legend_position)) config$legend_position else "bottom"
-    plot <- plot + theme_publication(legend_position = legend_position)
+      # Apply theme_publication with legend_position from config
+      legend_position <- if (!is.null(config$legend_position)) config$legend_position else "bottom"
+      plot <- plot + theme_publication(legend_position = legend_position)
 
-    # Use appropriate helper function based on plot type
-    plot <- switch(config$plot_type,
-      "scatter" = generate_scatter_plot(plot, config),
-      "box" = generate_box_plot(plot, config),
-      "density" = plot + geom_density(),
-      "bar" = plot + geom_bar(),
-      plot  # default case if no type matches
-    )
+      # Use appropriate helper function based on plot type
+      plot <- switch(config$plot_type,
+                     "scatter" = generate_scatter_plot(plot, config),
+                     "box" = generate_box_plot(plot, config),
+                     "density" = plot + geom_density(),
+                     "bar" = plot + geom_bar(),
+                     plot  # default case if no type matches
+      )
 
-    # Add title and labels
-    if (!is.null(config$title)) {
-      plot <- plot + ggtitle(config$title)
-    }
-    if (!is.null(config$x_label)) {
-      plot <- plot + xlab(config$x_label)
-    }
-    if (!is.null(config$y_label)) {
-      plot <- plot + ylab(config$y_label)
-    }
+      # Add title and labels
+      if (!is.null(config$title)) {
+        plot <- plot + ggtitle(config$title)
+      }
+      if (!is.null(config$x_label)) {
+        plot <- plot + xlab(config$x_label)
+      }
+      if (!is.null(config$y_label)) {
+        plot <- plot + ylab(config$y_label)
+      }
 
-    # Add cartesian coordinates if specified
-    if (!is.null(config$coord_cartesian)) {
-      plot <- plot + coord_cartesian(ylim = config$coord_cartesian)
-    }
+      # Add cartesian coordinates if specified
+      if (!is.null(config$coord_cartesian)) {
+        plot <- plot + coord_cartesian(ylim = config$coord_cartesian)
+      }
 
-    # Apply scale_color_discrete(guide = FALSE) when color_var is NULL
-    if (is.null(config$color_var)) {
-      plot <- plot + scale_color_discrete(guide = "none")
-    }
+      # Apply scale_color_discrete(guide = FALSE) when color_var is NULL
+      if (is.null(config$color_var)) {
+        plot <- plot + scale_color_discrete(guide = "none")
+      }
 
-    # Add interactive tooltips for plotly
-    tooltip_vars <- c()
-    if (config$plot_type == "scatter") {
-      if (!is.null(config$delta_bg_point) && config$delta_bg_point) {
-        tooltip_vars <- c(tooltip_vars, "OrfRep", "Gene", "delta_bg")
-      } else if (!is.null(config$gene_point) && config$gene_point) {
-        tooltip_vars <- c(tooltip_vars, "OrfRep", "Gene")
-      } else if (!is.null(config$y_var) && !is.null(config$x_var)) {
-        tooltip_vars <- c(config$x_var, config$y_var)
+      # Add interactive tooltips for plotly
+      tooltip_vars <- c()
+      if (config$plot_type == "scatter") {
+        if (!is.null(config$delta_bg_point) && config$delta_bg_point) {
+          tooltip_vars <- c(tooltip_vars, "OrfRep", "Gene", "delta_bg")
+        } else if (!is.null(config$gene_point) && config$gene_point) {
+          tooltip_vars <- c(tooltip_vars, "OrfRep", "Gene")
+        } else if (!is.null(config$y_var) && !is.null(config$x_var)) {
+          tooltip_vars <- c(config$x_var, config$y_var)
+        }
       }
-    }
 
-    # Convert to plotly object and suppress warnings here
-    plotly_plot <- suppressWarnings({
-      if (length(tooltip_vars) > 0) {
-        ggplotly(plot, tooltip = tooltip_vars)
-      } else {
-        ggplotly(plot, tooltip = "none")
+      # Convert to plotly object and suppress warnings here
+      plotly_plot <- suppressWarnings({
+        if (length(tooltip_vars) > 0) {
+          ggplotly(plot, tooltip = tooltip_vars)
+        } else {
+          ggplotly(plot, tooltip = "none")
+        }
+      })
+
+      # Adjust legend position if specified
+      if (!is.null(config$legend_position) && config$legend_position == "bottom") {
+        plotly_plot <- plotly_plot %>% layout(legend = list(orientation = "h"))
       }
-    })
 
-    # Adjust legend position if specified
-    if (!is.null(config$legend_position) && config$legend_position == "bottom") {
-      plotly_plot <- plotly_plot %>% layout(legend = list(orientation = "h"))
+      # Add plots to lists
+      static_plots[[i]] <- plot
+      plotly_plots[[i]] <- plotly_plot
     }
 
-    # Add plots to lists
-    static_plots[[i]] <- plot
-    plotly_plots[[i]] <- plotly_plot
-  }
-
-  # Save static PDF plot(s)
-  pdf(file.path(out_dir, paste0(filename, ".pdf")), width = 14, height = 9)
-  lapply(static_plots, print)
-  dev.off()
-
-  # Combine and save interactive HTML plot(s)
-  combined_plot <- subplot(
-    plotly_plots,
-    nrows = if (!is.null(grid_layout) && !is.null(grid_layout$nrow)) {
-      grid_layout$nrow
-    } else {
-      # Calculate nrow based on the length of plotly_plots
-      ceiling(length(plotly_plots) / ifelse(!is.null(grid_layout) && !is.null(grid_layout$ncol), grid_layout$ncol, 1))
-    },
-    margin = 0.05
-  )
+    # Save static PDF plot(s) for the current grid
+    pdf(file.path(out_dir, paste0(filename, ".pdf")), width = 16, height = 9)
+    grid.arrange(grobs = static_plots, ncol = grid_ncol, nrow = grid_nrow)
+    dev.off()
+
+    # Combine and save interactive HTML plot(s)
+    combined_plot <- subplot(
+      plotly_plots,
+      nrows = grid_nrow,
+      ncols = grid_ncol,
+      margin = 0.05
+    )
 
-  # Save combined html plot(s)
-  saveWidget(combined_plot, file = file.path(out_dir, paste0(filename, ".html")), selfcontained = TRUE)
+    # Save combined HTML plot(s)
+    saveWidget(combined_plot, file = file.path(out_dir, paste0(filename, ".html")), selfcontained = TRUE)
+  }
 }
 
 generate_scatter_plot <- function(plot, config) {
@@ -686,102 +691,76 @@ generate_plate_analysis_plot_configs <- function(variables, stages = c("before",
   return(plots)
 }
 
-generate_interaction_plot_configs <- function(df, limits_map = NULL) {
-  # Default limits_map if not provided
+generate_interaction_plot_configs <- function(df, limits_map = NULL, stats_df = NULL) {
   if (is.null(limits_map)) {
     limits_map <- list(
-      L = c(-65, 65),
-      K = c(-65, 65),
-      r = c(-0.65, 0.65),
-      AUC = c(-6500, 6500)
+      L = c(0, 130),
+      K = c(-20, 160),
+      r = c(0, 1),
+      AUC = c(0, 12500)
     )
   }
 
-  # Filter data
-  df_filtered <- df
-  for (var in names(limits_map)) {
-    df_filtered <- df_filtered %>%
-      filter(!is.na(!!sym(var)) &
-        !!sym(var) >= limits_map[[var]][1] &
-        !!sym(var) <= limits_map[[var]][2])
-  }
+  # Ensure proper grouping by OrfRep, Gene, and num
+  df_filtered <- df %>%
+    filter(
+      !is.na(L) & L >= limits_map$L[1] & L <= limits_map$L[2],
+      !is.na(K) & K >= limits_map$K[1] & K <= limits_map$K[2],
+      !is.na(r) & r >= limits_map$r[1] & r <= limits_map$r[2],
+      !is.na(AUC) & AUC >= limits_map$AUC[1] & AUC <= limits_map$AUC[2]
+    ) %>%
+    group_by(OrfRep, Gene, num)  # Group by OrfRep, Gene, and num
 
-  configs <- list()
+  scatter_configs <- list()
+  box_configs <- list()
 
+  # Generate scatter and box plots for each variable (L, K, r, AUC)
   for (var in names(limits_map)) {
-    y_range <- limits_map[[var]]
-    
-    # Calculate annotation positions
-    y_min <- min(y_range)
-    y_max <- max(y_range)
-    y_span <- y_max - y_min
-    annotation_positions <- list(
-      ZShift = y_max - 0.1 * y_span,
-      lm_ZScore = y_max - 0.2 * y_span,
-      NG = y_min + 0.2 * y_span,
-      DB = y_min + 0.1 * y_span,
-      SM = y_min + 0.05 * y_span
-    )
-
-    # Prepare linear model line
-    lm_line <- list(
-      intercept = df_filtered[[paste0("lm_intercept_", var)]],
-      slope = df_filtered[[paste0("lm_slope_", var)]]
-    )
-
-    # Calculate x-axis position for annotations
-    num_levels <- length(levels(df_filtered$conc_num_factor))
-    x_pos <- (1 + num_levels) / 2
-
-    # Generate annotations
-    annotations <- lapply(names(annotation_positions), function(annotation_name) {
-      label <- switch(annotation_name,
-        ZShift = paste("ZShift =", round(df_filtered[[paste0("Z_Shift_", var)]], 2)),
-        lm_ZScore = paste("lm ZScore =", round(df_filtered[[paste0("Z_lm_", var)]], 2)),
-        NG = paste("NG =", df_filtered$NG),
-        DB = paste("DB =", df_filtered$DB),
-        SM = paste("SM =", df_filtered$SM),
-        NULL
-      )
-      if (!is.null(label)) {
-        list(x = x_pos, y = annotation_positions[[annotation_name]], label = label)
-      } else {
-        NULL
-      }
-    })
-    annotations <- Filter(Negate(is.null), annotations)
-
-    # Shared plot settings
-    plot_settings <- list(
+    scatter_configs[[length(scatter_configs) + 1]] <- list(
       df = df_filtered,
-      x_var = "conc_num_factor",
-      y_var = var,
-      ylim_vals = y_range,
-      annotations = annotations,
-      lm_line = lm_line,
-      x_breaks = levels(df_filtered$conc_num_factor),
-      x_labels = levels(df_filtered$conc_num_factor),
-      x_label = unique(df_filtered$Drug[1]),
-      coord_cartesian = y_range,
-    )
-
-    # Scatter plot config
-    configs[[length(configs) + 1]] <- modifyList(plot_settings, list(
+      x_var = "conc_num",  # X-axis variable
+      y_var = var,         # Y-axis variable (Delta_L, Delta_K, Delta_r, Delta_AUC)
       plot_type = "scatter",
-      title = sprintf("%s      %s", df_filtered$OrfRep[1], df_filtered$Gene[1]),
-      error_bar = TRUE,
-      position = "jitter",
-      size = 1
-    ))
-
-    # Box plot config
-    configs[[length(configs) + 1]] <- modifyList(plot_settings, list(
+      title = sprintf("Scatter RF for %s with SD", var),
+      coord_cartesian = limits_map[[var]],  # Set limits for Y-axis
+      annotations = list(
+        list(x = -0.25, y = 10, label = "NG"),
+        list(x = -0.25, y = 5, label = "DB"),
+        list(x = -0.25, y = 0, label = "SM")
+      ),
+      grid_layout = list(ncol = 4, nrow = 3)
+    )
+    box_configs[[length(box_configs) + 1]] <- list(
+      df = df_filtered,
+      x_var = "conc_num",  # X-axis variable
+      y_var = var,         # Y-axis variable (Delta_L, Delta_K, Delta_r, Delta_AUC)
       plot_type = "box",
-      title = sprintf("%s      %s (box plot)", df_filtered$OrfRep[1], df_filtered$Gene[1]),
-      error_bar = FALSE
-    ))
+      title = sprintf("Boxplot RF for %s with SD", var),
+      coord_cartesian = limits_map[[var]],
+      grid_layout = list(ncol = 4, nrow = 3)
+    )
   }
 
+  # Combine scatter and box plots into grids
+  configs <- list(
+    list(
+      grid_layout = list(nrow = 2, ncol = 2),  # Scatter plots in a 2x2 grid (for the 8 plots)
+      plots = scatter_configs[1:4]
+    ),
+    list(
+      grid_layout = list(nrow = 2, ncol = 2),  # Box plots in a 2x2 grid (for the 8 plots)
+      plots = box_configs
+    ),
+    list(
+      grid_layout = list(nrow = 3, ncol = 4),  # Delta_ plots in a 3x4 grid
+      plots = scatter_configs
+    ),
+    list(
+      grid_layout = list(nrow = 3, ncol = 4),  # Delta_ box plots in a 3x4 grid
+      plots = box_configs
+    )
+  )
+
   return(configs)
 }
 
@@ -864,7 +843,8 @@ generate_rank_plot_configs <- function(df, variables, is_lm = FALSE, adjust = FA
         size = 0.1,
         y_label = y_label,
         x_label = "Rank",
-        legend_position = "none"
+        legend_position = "none",
+        grid_layout = list(ncol = 3, nrow = 2)
       )
       
       # Non-Annotated Plot Configuration
@@ -884,7 +864,8 @@ generate_rank_plot_configs <- function(df, variables, is_lm = FALSE, adjust = FA
         size = 0.1,
         y_label = y_label,
         x_label = "Rank",
-        legend_position = "none"
+        legend_position = "none",
+        grid_layout = list(ncol = 3, nrow = 2)
       )
     }
   }
@@ -1006,7 +987,8 @@ generate_correlation_plot_configs <- function(df) {
           fill = NA, color = "grey20", alpha = 0.1
         )
       ),
-      cyan_points = TRUE
+      cyan_points = TRUE,
+      grid_layout = list(ncol = 2, nrow = 2)
     )
 
     configs[[length(configs) + 1]] <- config
@@ -1258,9 +1240,9 @@ main <- function() {
     )
 
     # Generating quality control plots in parallel
-    furrr::future_map(plot_configs, function(config) {
-      generate_and_save_plots(config$out_dir, config$filename, config$plot_configs)
-    }, .options = furrr_options(seed = TRUE))
+    # furrr::future_map(plot_configs, function(config) {
+    #   generate_and_save_plots(config$out_dir, config$filename, config$plot_configs)
+    # }, .options = furrr_options(seed = TRUE))
 
     # Process background strains
     bg_strains <- c("YDL227C")
@@ -1345,11 +1327,11 @@ main <- function() {
       # Create interaction plots
       message("Generating reference interaction plots")
       reference_plot_configs <- generate_interaction_plot_configs(zscore_interactions_reference_joined)
-      generate_and_save_plots(out_dir, "interaction_plots_reference", reference_plot_configs, grid_layout = list(ncol = 4, nrow = 3))
+      generate_and_save_plots(out_dir, "interaction_plots_reference", reference_plot_configs)
 
       message("Generating deletion interaction plots")
       deletion_plot_configs <- generate_interaction_plot_configs(zscore_interactions_joined)
-      generate_and_save_plots(out_dir, "interaction_plots", deletion_plot_configs, grid_layout = list(ncol = 4, nrow = 3))
+      generate_and_save_plots(out_dir, "interaction_plots", deletion_plot_configs)
 
       # Define conditions for enhancers and suppressors
       # TODO Add to study config?
@@ -1408,7 +1390,7 @@ main <- function() {
         adjust = TRUE
       )
       generate_and_save_plots(out_dir = out_dir, filename = "rank_plots",
-        plot_configs = rank_plot_configs, grid_layout = list(ncol = 3, nrow = 2))
+        plot_configs = rank_plot_configs)
 
       message("Generating ranked linear model plots")
       rank_lm_plot_configs <- generate_rank_plot_configs(
@@ -1418,7 +1400,7 @@ main <- function() {
         adjust = TRUE
       )
       generate_and_save_plots(out_dir = out_dir, filename = "rank_plots_lm",
-        plot_configs = rank_lm_plot_configs, grid_layout = list(ncol = 3, nrow = 2))
+        plot_configs = rank_lm_plot_configs)
 
       message("Filtering and reranking plots")
       interaction_threshold <- 2 # TODO add to study config?
@@ -1454,8 +1436,7 @@ main <- function() {
       generate_and_save_plots(
         out_dir = out_dir,
         filename = "RankPlots_na_rm",
-        plot_configs = rank_plot_filtered_configs,
-        grid_layout = list(ncol = 3, nrow = 2))
+        plot_configs = rank_plot_filtered_configs)
 
       message("Generating filtered ranked linear model plots")
       rank_plot_lm_filtered_configs <- generate_rank_plot_configs(
@@ -1468,8 +1449,7 @@ main <- function() {
       generate_and_save_plots(
         out_dir = out_dir,
         filename = "rank_plots_lm_na_rm",
-        plot_configs = rank_plot_lm_filtered_configs,
-        grid_layout = list(ncol = 3, nrow = 2))
+        plot_configs = rank_plot_lm_filtered_configs)
 
       message("Generating correlation curve parameter pair plots")
       correlation_plot_configs <- generate_correlation_plot_configs(zscore_interactions_filtered)
@@ -1477,7 +1457,7 @@ main <- function() {
         out_dir = out_dir,
         filename = "correlation_cpps",
         plot_configs = correlation_plot_configs,
-        grid_layout = list(ncol = 2, nrow = 2))
+      )
     })
   })
 }

+ 1 - 1
qhtcp-workflow/qhtcp-workflow

@@ -1260,7 +1260,7 @@ qhtcp() {
   # done
 
   # Run R interactions script on all studies
-  calculate_interaction_zscores \
+  calculate_interaction_zscores; exit \
   && join_interaction_zscores \
   && remc \
   && gtf \