Vectorize plot generation

This commit is contained in:
2024-09-10 20:50:56 -04:00
parent 97de737ed5
commit d8eb162fb5

View File

@@ -155,9 +155,8 @@ update_gene_names <- function(df, sgd_gene_list) {
# Calculate summary statistics for all variables # Calculate summary statistics for all variables
calculate_summary_stats <- function(df, variables, group_vars = c("conc_num", "conc_num_factor")) { calculate_summary_stats <- function(df, variables, group_vars = c("conc_num", "conc_num_factor")) {
df <- df %>%
mutate(across(all_of(variables), ~ ifelse(. == 0, NA, .)))
# Summarize the variables within the grouped data
summary_stats <- df %>% summary_stats <- df %>%
group_by(across(all_of(group_vars))) %>% group_by(across(all_of(group_vars))) %>%
summarise( summarise(
@@ -165,16 +164,18 @@ calculate_summary_stats <- function(df, variables, group_vars = c("conc_num", "c
across(all_of(variables), list( across(all_of(variables), list(
mean = ~mean(., na.rm = TRUE), mean = ~mean(., na.rm = TRUE),
median = ~median(., na.rm = TRUE), median = ~median(., na.rm = TRUE),
max = ~max(., na.rm = TRUE), max = ~ ifelse(all(is.na(.)), NA, max(., na.rm = TRUE)),
min = ~min(., na.rm = TRUE), min = ~ ifelse(all(is.na(.)), NA, min(., na.rm = TRUE)),
sd = ~sd(., na.rm = TRUE), sd = ~sd(., na.rm = TRUE),
se = ~sd(., na.rm = TRUE) / sqrt(sum(!is.na(.)) - 1) se = ~ ifelse(all(is.na(.)), NA, sd(., na.rm = TRUE) / sqrt(sum(!is.na(.)) - 1))
), .names = "{.fn}_{.col}") ), .names = "{.fn}_{.col}")
) )
# Prevent .x and .y suffix issues by renaming columns
df_cleaned <- df %>% df_cleaned <- df %>%
select(-any_of(names(summary_stats))) select(-any_of(setdiff(names(summary_stats), group_vars))) # Avoid duplicate columns in the final join
# Join the stats back to the original dataframe
df_with_stats <- left_join(df_cleaned, summary_stats, by = group_vars) df_with_stats <- left_join(df_cleaned, summary_stats, by = group_vars)
return(list(summary_stats = summary_stats, df_with_stats = df_with_stats)) return(list(summary_stats = summary_stats, df_with_stats = df_with_stats))
@@ -336,45 +337,57 @@ calculate_interaction_scores <- function(df, max_conc, variables, group_vars = c
generate_and_save_plots <- function(output_dir, file_name, plot_configs, grid_layout = NULL) { generate_and_save_plots <- function(output_dir, file_name, plot_configs, grid_layout = NULL) {
`%||%` <- function(a, b) if (!is.null(a)) a else b # Helper function for plot type logic
apply_plot_type <- function(plot, config) {
plots <- lapply(plot_configs, function(config) { switch(config$plot_type,
df <- config$df "rank" = {
plot <- ggplot(df, aes(x = !!sym(config$x_var), y = !!sym(config$y_var), color = as.factor(!!sym(config$color_var)))) plot <- plot + geom_point(size = 0.1, shape = 3)
if (!is.null(config$sd_band)) {
# Handle plot types like "rank", "correlation", and default scatter/box/density for (i in seq_len(config$sd_band)) {
if (config$plot_type == "rank") { plot <- plot +
plot <- plot + geom_point(size = 0.1, shape = 3) annotate("rect", xmin = -Inf, xmax = Inf, ymin = i, ymax = Inf, fill = "#542788", alpha = 0.3) +
if (!is.null(config$sd_band)) { annotate("rect", xmin = -Inf, xmax = Inf, ymin = -i, ymax = -Inf, fill = "orange", alpha = 0.3) +
for (i in seq_len(config$sd_band)) { geom_hline(yintercept = c(-i, i), color = "gray")
plot <- plot + }
annotate("rect", xmin = -Inf, xmax = Inf, ymin = i, ymax = Inf, fill = "#542788", alpha = 0.3) +
annotate("rect", xmin = -Inf, xmax = Inf, ymin = -i, ymax = -Inf, fill = "orange", alpha = 0.3) +
geom_hline(yintercept = c(-i, i), color = "gray")
} }
} plot
if (!is.null(config$enhancer_label)) { },
plot <- plot + annotate("text", x = config$enhancer_label$x, y = config$enhancer_label$y, label = config$enhancer_label$label) + "correlation" = {
annotate("text", x = config$suppressor_label$x, y = config$suppressor_label$y, label = config$suppressor_label$label) plot + geom_point(shape = 3, color = "gray70") + geom_smooth(method = "lm", color = "tomato3") +
} annotate("text", x = 0, y = 0, label = config$correlation_text)
} else if (config$plot_type == "correlation") { },
plot <- plot + geom_point(shape = 3, color = "gray70") + geom_smooth(method = "lm", color = "tomato3") + "box" = plot + geom_boxplot(),
annotate("text", x = 0, y = 0, label = config$correlation_text) "density" = plot + geom_density(),
} else { "bar" = plot + geom_bar(stat = "identity"),
plot <- plot + aes(y = !!sym(config$y_var)) + plot + geom_point(shape = 3) + geom_smooth(method = "lm", se = FALSE) # Default scatter plot
if (config$plot_type == "box") geom_boxplot() else )
if (config$plot_type == "density") geom_density() else }
if (config$plot_type == "bar") geom_bar(stat = "identity") else geom_point(shape = 3) + geom_smooth(method = "lm", se = FALSE)
}
# Add error bars for "delta_bg" or general cases # Helper function for error bars
if (config$error_bar %||% FALSE) { apply_error_bars <- function(plot, config) {
if (!is.null(config$error_bar) && config$error_bar) {
y_mean_col <- paste0("mean_", config$y_var) y_mean_col <- paste0("mean_", config$y_var)
y_sd_col <- paste0("sd_", config$y_var) y_sd_col <- paste0("sd_", config$y_var)
plot <- plot + geom_errorbar(aes(ymin = !!sym(y_mean_col) - !!sym(y_sd_col), plot <- plot + geom_errorbar(aes(ymin = !!sym(y_mean_col) - !!sym(y_sd_col),
ymax = !!sym(y_mean_col) + !!sym(y_sd_col)), width = 0.1) + ymax = !!sym(y_mean_col) + !!sym(y_sd_col)), width = 0.1) +
geom_point(aes(y = !!sym(y_mean_col)), size = 0.6) geom_point(aes(y = !!sym(y_mean_col)), size = 0.6)
} }
plot
}
# Helper function for annotations
apply_annotations <- function(plot, config) {
if (!is.null(config$annotations)) {
plot <- plot + geom_text(aes(x = config$annotations$x, y = config$annotations$y, label = config$annotations$label))
}
plot
}
# Generate each plot
plots <- lapply(plot_configs, function(config) {
plot <- ggplot(config$df, aes(x = !!sym(config$x_var), y = !!sym(config$y_var), color = as.factor(!!sym(config$color_var))))
plot <- apply_plot_type(plot, config)
plot <- apply_error_bars(plot, config)
# Apply y-limits if provided # Apply y-limits if provided
if (!is.null(config$ylim_vals)) { if (!is.null(config$ylim_vals)) {
@@ -382,27 +395,25 @@ generate_and_save_plots <- function(output_dir, file_name, plot_configs, grid_la
} }
# Apply labels, titles, and legends # Apply labels, titles, and legends
plot <- plot + ggtitle(config$title) + theme_publication(legend_position = config$legend_position %||% "bottom") + plot <- plot + ggtitle(config$title) +
if (!is.null(config$x_label)) xlab(config$x_label) else NULL + theme_publication(legend_position = if (!is.null(config$legend_position)) config$legend_position else "bottom") +
if (!is.null(config$y_label)) ylab(config$y_label) else NULL xlab(config$x_label %||% "") + ylab(config$y_label %||% "")
# Add annotations if available plot <- apply_annotations(plot, config)
if (!is.null(config$annotations)) {
plot <- plot + geom_text(aes(x = config$annotations$x, y = config$annotations$y, label = config$annotations$label))
}
return(plot) return(plot)
}) })
# Save the plots # Save plots to PDF
pdf(file.path(output_dir, paste0(file_name, ".pdf")), width = 14, height = 9) pdf(file.path(output_dir, paste0(file_name, ".pdf")), width = 14, height = 9)
lapply(plots, print) lapply(plots, print)
dev.off() dev.off()
# Generate Plotly versions for interactive HTML
plotly_plots <- lapply(plots, function(plot) suppressWarnings(ggplotly(plot) %>% layout(legend = list(orientation = "h")))) plotly_plots <- lapply(plots, function(plot) suppressWarnings(ggplotly(plot) %>% layout(legend = list(orientation = "h"))))
# Handle grid layout # Handle grid layout
combined_plot <- subplot(plotly_plots, nrows = grid_layout$nrow %||% length(plots), margin = 0.05) combined_plot <- subplot(plotly_plots, nrows = if (!is.null(grid_layout)) grid_layout$nrow else length(plots), margin = 0.05)
saveWidget(combined_plot, file = file.path(output_dir, paste0(file_name, ".html")), selfcontained = TRUE) saveWidget(combined_plot, file = file.path(output_dir, paste0(file_name, ".html")), selfcontained = TRUE)
} }
@@ -727,7 +738,7 @@ main <- function() {
} }
l_outside_2sd_k_plots <- list( l_outside_2sd_k_plots <- list(
list(df = X_outside_2SD_K, x_var = "l", y_var = "K", plot_type = "scatter", list(df = df_na_l_outside_2sd_k_stats, x_var = "l", y_var = "K", plot_type = "scatter",
title = "Raw L vs K for strains falling outside 2SD of the K mean at each Conc", title = "Raw L vs K for strains falling outside 2SD of the K mean at each Conc",
color_var = "conc_num", color_var = "conc_num",
legend_position = "right" legend_position = "right"
@@ -735,7 +746,7 @@ main <- function() {
) )
delta_bg_outside_2sd_k_plots <- list( delta_bg_outside_2sd_k_plots <- list(
list(df = X_outside_2SD_K, x_var = "delta_bg", y_var = "K", plot_type = "scatter", list(df = df_na_l_outside_2sd_k_stats, x_var = "delta_bg", y_var = "K", plot_type = "scatter",
title = "Delta Background vs K for strains falling outside 2SD of the K mean at each Conc", title = "Delta Background vs K for strains falling outside 2SD of the K mean at each Conc",
color_var = "conc_num", color_var = "conc_num",
legend_position = "right" legend_position = "right"