Standardize plot groups

This commit is contained in:
2024-10-01 18:44:21 -04:00
parent 1e816e1a71
commit 1548dbf9c1

View File

@@ -399,99 +399,94 @@ calculate_interaction_scores <- function(df, max_conc, bg_stats,
generate_and_save_plots <- function(out_dir, filename, plot_configs) { generate_and_save_plots <- function(out_dir, filename, plot_configs) {
message("Generating ", filename, ".pdf and ", filename, ".html") message("Generating ", filename, ".pdf and ", filename, ".html")
static_plots <- list() # Check if we're dealing with multiple plot groups
plotly_plots <- list() plot_groups <- if ("plots" %in% names(plot_configs)) {
list(plot_configs) # Single group
for (i in seq_along(plot_configs$plots)) {
config <- plot_configs$plots[[i]]
df <- config$df
# Define aes_mapping, ensuring y_var is only used when it's not NULL
aes_mapping <- switch(config$plot_type,
"bar" = if (!is.null(config$color_var)) {
aes(x = .data[[config$x_var]], fill = .data[[config$color_var]], color = .data[[config$color_var]])
} else {
aes(x = .data[[config$x_var]])
},
"density" = if (!is.null(config$color_var)) {
aes(x = .data[[config$x_var]], color = .data[[config$color_var]])
} else {
aes(x = .data[[config$x_var]])
},
# For other plot types, only include y_var if it's not NULL
if (!is.null(config$y_var) && !is.null(config$color_var)) {
aes(x = .data[[config$x_var]], y = .data[[config$y_var]], color = .data[[config$color_var]])
} else if (!is.null(config$y_var)) {
aes(x = .data[[config$x_var]], y = .data[[config$y_var]])
} else {
aes(x = .data[[config$x_var]]) # no y_var needed for density and bar plots
}
)
plot <- ggplot(df, aes_mapping) + theme_publication(legend_position = config$legend_position)
# Apply appropriate plot function
plot <- switch(config$plot_type,
"scatter" = generate_scatter_plot(plot, config),
"box" = generate_box_plot(plot, config),
"density" = plot + geom_density(),
"bar" = plot + geom_bar(),
plot # default (unused)
)
# Add titles and labels
if (!is.null(config$title)) plot <- plot + ggtitle(config$title)
if (!is.null(config$x_label)) plot <- plot + xlab(config$x_label)
if (!is.null(config$y_label)) plot <- plot + ylab(config$y_label)
if (!is.null(config$coord_cartesian)) plot <- plot + coord_cartesian(ylim = config$coord_cartesian)
# Convert ggplot to plotly, skipping subplot
if (!is.null(config$tooltip_vars)) {
plotly_plot <- suppressWarnings(plotly::ggplotly(plot, tooltip = config$tooltip_vars))
} else {
plotly_plot <- suppressWarnings(plotly::ggplotly(plot))
}
if (!is.null(plotly_plot[["frames"]])) {
plotly_plot[["frames"]] <- NULL
}
# Adjust legend position in plotly
if (!is.null(config$legend_position) && config$legend_position == "bottom") {
plotly_plot <- plotly_plot %>% layout(legend = list(orientation = "h"))
}
static_plots[[i]] <- plot
plotly_plots[[i]] <- plotly_plot
}
# Save static PDF plots
pdf(file.path(out_dir, paste0(filename, ".pdf")), width = 16, height = 9)
if (is.null(plot_configs$grid_layout)) {
# Print each plot on a new page if grid_layout is not set
for (plot in static_plots) {
print(plot)
}
} else { } else {
# Use grid.arrange if grid_layout is set plot_configs # Multiple groups
grid_nrow <- ifelse(is.null(plot_configs$grid_layout$nrow), length(plot_configs$plots), plot_configs$grid_layout$nrow)
grid_ncol <- ifelse(is.null(plot_configs$grid_layout$ncol), 1, plot_configs$grid_layout$ncol)
grid.arrange(grobs = static_plots, ncol = grid_ncol, nrow = grid_nrow)
} }
dev.off() for (group in plot_groups) {
# Save combined HTML plot static_plots <- list()
out_html_file <- file.path(out_dir, paste0(filename, ".html")) plotly_plots <- list()
message("Saving combined HTML file: ", out_html_file)
htmltools::save_html( grid_layout <- group$grid_layout
htmltools::tagList(plotly_plots), plots <- group$plots
file = out_html_file
) for (i in seq_along(plots)) {
config <- plots[[i]]
df <- config$df
if (config$plot_type == "bar") {
if (!is.null(config$color_var)) {
aes_mapping <- aes(x = .data[[config$x_var]], fill = .data[[config$color_var]], color = .data[[config$color_var]])
} else {
aes_mapping <- aes(x = .data[[config$x_var]])
}
} else if (config$plot_type == "density") {
if (!is.null(config$color_var)) {
aes_mapping <- aes(x = .data[[config$x_var]], color = .data[[config$color_var]])
} else {
aes_mapping <- aes(x = .data[[config$x_var]])
}
} else {
# For other plot types
if (!is.null(config$y_var) && !is.null(config$color_var)) {
aes_mapping <- aes(x = .data[[config$x_var]], y = .data[[config$y_var]], color = .data[[config$color_var]])
} else if (!is.null(config$y_var)) {
aes_mapping <- aes(x = .data[[config$x_var]], y = .data[[config$y_var]])
} else {
aes_mapping <- aes(x = .data[[config$x_var]])
}
}
plot <- ggplot(df, aes_mapping) + theme_publication(legend_position = config$legend_position)
plot <- switch(config$plot_type,
"scatter" = generate_scatter_plot(plot, config),
"box" = generate_boxplot(plot, config),
"density" = plot + geom_density(),
"bar" = plot + geom_bar(),
plot # default (unused)
)
if (!is.null(config$title)) plot <- plot + ggtitle(config$title)
if (!is.null(config$x_label)) plot <- plot + xlab(config$x_label)
if (!is.null(config$y_label)) plot <- plot + ylab(config$y_label)
if (!is.null(config$coord_cartesian)) plot <- plot + coord_cartesian(ylim = config$coord_cartesian)
plotly_plot <- suppressWarnings(plotly::ggplotly(plot))
static_plots[[i]] <- plot
plotly_plots[[i]] <- plotly_plot
}
pdf(file.path(out_dir, paste0(filename, ".pdf")), width = 16, height = 9)
if (is.null(grid_layout)) {
for (plot in static_plots) {
print(plot)
}
} else {
grid.arrange(
grobs = static_plots,
ncol = grid_layout$ncol,
nrow = grid_layout$nrow
)
}
dev.off()
out_html_file <- file.path(out_dir, paste0(filename, ".html"))
message("Saving combined HTML file: ", out_html_file)
htmltools::save_html(
htmltools::tagList(plotly_plots),
file = out_html_file
)
}
} }
generate_scatter_plot <- function(plot, config) { generate_scatter_plot <- function(plot, config) {
# Define the points # Define the points
@@ -582,24 +577,28 @@ generate_scatter_plot <- function(plot, config) {
# Add error bars if specified # Add error bars if specified
if (!is.null(config$error_bar) && config$error_bar && !is.null(config$y_var)) { if (!is.null(config$error_bar) && config$error_bar && !is.null(config$y_var)) {
if (!is.null(config$error_bar_params)) { if (!is.null(config$error_bar_params)) {
# Error bar params are constants, so set them outside aes
plot <- plot + plot <- plot +
geom_errorbar( geom_errorbar(
aes( aes(
ymin = config$error_bar_params$ymin, ymin = !!sym(config$y_var), # y_var mapped to y-axis
ymax = config$error_bar_params$ymax ymax = !!sym(config$y_var)
), ),
ymin = config$error_bar_params$ymin, # Constant values
ymax = config$error_bar_params$ymax, # Constant values
alpha = 0.3, alpha = 0.3,
linewidth = 0.5 linewidth = 0.5
) )
} else { } else {
# Dynamically generate ymin and ymax based on column names
y_mean_col <- paste0("mean_", config$y_var) y_mean_col <- paste0("mean_", config$y_var)
y_sd_col <- paste0("sd_", config$y_var) y_sd_col <- paste0("sd_", config$y_var)
plot <- plot + plot <- plot +
geom_errorbar( geom_errorbar(
aes( aes(
ymin = !!sym(y_mean_col) - !!sym(y_sd_col), ymin = !!sym(y_mean_col) - !!sym(y_sd_col), # Calculating ymin in aes
ymax = !!sym(y_mean_col) + !!sym(y_sd_col) ymax = !!sym(y_mean_col) + !!sym(y_sd_col) # Calculating ymax in aes
), ),
alpha = 0.3, alpha = 0.3,
linewidth = 0.5 linewidth = 0.5
@@ -609,12 +608,21 @@ generate_scatter_plot <- function(plot, config) {
# Customize X-axis if specified # Customize X-axis if specified
if (!is.null(config$x_breaks) && !is.null(config$x_labels) && !is.null(config$x_label)) { if (!is.null(config$x_breaks) && !is.null(config$x_labels) && !is.null(config$x_label)) {
plot <- plot + if (is.factor(df[[config$x_var]]) || is.character(df[[config$x_var]])) {
scale_x_discrete( plot <- plot +
name = config$x_label, scale_x_discrete(
breaks = config$x_breaks, name = config$x_label,
labels = config$x_labels breaks = config$x_breaks,
) labels = config$x_labels
)
} else {
plot <- plot +
scale_x_continuous(
name = config$x_label,
breaks = config$x_breaks,
labels = config$x_labels
)
}
} }
# Set Y-axis limits if specified # Set Y-axis limits if specified
@@ -642,17 +650,27 @@ generate_scatter_plot <- function(plot, config) {
return(plot) return(plot)
} }
generate_box_plot <- function(plot, config) { generate_boxplot <- function(plot, config) {
# Convert x_var to a factor within aes mapping # Convert x_var to a factor within aes mapping
plot <- plot + geom_boxplot(aes(x = factor(.data[[config$x_var]]))) plot <- plot + geom_boxplot(aes(x = factor(.data[[config$x_var]])))
# Apply scale_x_discrete for breaks, labels, and axis label if provided # Customize X-axis if specified
if (!is.null(config$x_breaks) && !is.null(config$x_labels) && !is.null(config$x_label)) { if (!is.null(config$x_breaks) && !is.null(config$x_labels) && !is.null(config$x_label)) {
plot <- plot + scale_x_discrete( if (is.factor(df[[config$x_var]]) || is.character(df[[config$x_var]])) {
name = config$x_label, plot <- plot +
breaks = config$x_breaks, scale_x_discrete(
labels = config$x_labels name = config$x_label,
) breaks = config$x_breaks,
labels = config$x_labels
)
} else {
plot <- plot +
scale_x_continuous(
name = config$x_label,
breaks = config$x_breaks,
labels = config$x_labels
)
}
} }
return(plot) return(plot)
@@ -660,7 +678,7 @@ generate_box_plot <- function(plot, config) {
generate_plate_analysis_plot_configs <- function(variables, df_before = NULL, df_after = NULL, generate_plate_analysis_plot_configs <- function(variables, df_before = NULL, df_after = NULL,
plot_type = "scatter", stages = c("before", "after")) { plot_type = "scatter", stages = c("before", "after")) {
plots <- list() plot_configs <- list()
for (var in variables) { for (var in variables) {
for (stage in stages) { for (stage in stages) {
@@ -670,7 +688,7 @@ generate_plate_analysis_plot_configs <- function(variables, df_before = NULL, df
df_plot_filtered <- df_plot %>% filter(is.finite(!!sym(var))) df_plot_filtered <- df_plot %>% filter(is.finite(!!sym(var)))
# Adjust settings based on plot_type # Adjust settings based on plot_type
config <- list( plot_config <- list(
df = df_plot_filtered, df = df_plot_filtered,
x_var = "scan", x_var = "scan",
y_var = var, y_var = var,
@@ -683,39 +701,41 @@ generate_plate_analysis_plot_configs <- function(variables, df_before = NULL, df
) )
# Add config to plots list # Add config to plots list
plots <- append(plots, list(config)) plot_configs <- append(plot_configs, list(plot_config))
} }
} }
return(list(grid_layout = list(ncol = 1, nrow = length(plots)), plots = plots)) return(list(plots = plot_configs))
} }
generate_interaction_plot_configs <- function(df, limits_map = NULL, plot_type = "reference") { generate_interaction_plot_configs <- function(df, plot_type = "reference") {
if (is.null(limits_map)) { limits_map <- list(
limits_map <- list( L = c(0, 130),
L = c(0, 130), K = c(-20, 160),
K = c(-20, 160), r = c(0, 1),
r = c(0, 1), AUC = c(0, 12500)
AUC = c(0, 12500), )
Delta_L = c(-60, 60),
Delta_K = c(-60, 60), delta_limits_map <- list(
Delta_r = c(-0.6, 0.6), Delta_L = c(-60, 60),
Delta_AUC = c(-6000, 6000) Delta_K = c(-60, 60),
) Delta_r = c(-0.6, 0.6),
} Delta_AUC = c(-6000, 6000)
)
group_vars <- if (plot_type == "reference") c("OrfRep", "Gene", "num") else c("OrfRep", "Gene") group_vars <- if (plot_type == "reference") c("OrfRep", "Gene", "num") else c("OrfRep", "Gene")
df_filtered <- df %>% df_filtered <- df %>%
mutate(OrfRepCombined = if (plot_type == "reference") paste(OrfRep, Gene, num, sep = "_") else paste(OrfRep, Gene, sep = "_")) mutate(OrfRepCombined = if (plot_type == "reference") paste(OrfRep, Gene, num, sep = "_") else paste(OrfRep, Gene, sep = "_"))
# Separate the plots into two groups: overall variables and delta comparisons overall_plot_configs <- list()
overall_plots <- list() delta_plot_configs <- list()
delta_plots <- list()
for (var in c("L", "K", "r", "AUC")) { # Overall plots
for (var in names(limits_map)) {
y_limits <- limits_map[[var]] y_limits <- limits_map[[var]]
config <- list( plot_config <- list(
df = df_filtered, df = df_filtered,
plot_type = "scatter", plot_type = "scatter",
x_var = "conc_num_factor_factor", x_var = "conc_num_factor_factor",
@@ -729,9 +749,10 @@ generate_interaction_plot_configs <- function(df, limits_map = NULL, plot_type =
position = "jitter", position = "jitter",
smooth = TRUE smooth = TRUE
) )
overall_plots <- append(overall_plots, list(config)) overall_plot_configs <- append(overall_plot_configs, list(plot_config))
} }
# Delta plots
unique_groups <- df_filtered %>% select(all_of(group_vars)) %>% distinct() unique_groups <- df_filtered %>% select(all_of(group_vars)) %>% distinct()
for (i in seq_len(nrow(unique_groups))) { for (i in seq_len(nrow(unique_groups))) {
@@ -742,17 +763,15 @@ generate_interaction_plot_configs <- function(df, limits_map = NULL, plot_type =
Gene <- if ("Gene" %in% names(group)) as.character(group$Gene) else "" Gene <- if ("Gene" %in% names(group)) as.character(group$Gene) else ""
num <- if ("num" %in% names(group)) as.character(group$num) else "" num <- if ("num" %in% names(group)) as.character(group$num) else ""
for (var in c("Delta_L", "Delta_K", "Delta_r", "Delta_AUC")) { for (var in names(delta_limits_map)) {
y_limits <- limits_map[[var]] y_limits <- delta_limits_map[[var]]
y_span <- y_limits[2] - y_limits[1] y_span <- y_limits[2] - y_limits[1]
# Error bars
WT_sd_var <- paste0("WT_sd_", sub("Delta_", "", var)) WT_sd_var <- paste0("WT_sd_", sub("Delta_", "", var))
WT_sd_value <- group_data[[WT_sd_var]][1] WT_sd_value <- group_data[[WT_sd_var]][1]
error_bar_ymin <- 0 - (2 * WT_sd_value) error_bar_ymin <- 0 - (2 * WT_sd_value)
error_bar_ymax <- 0 + (2 * WT_sd_value) error_bar_ymax <- 0 + (2 * WT_sd_value)
# Annotations
Z_Shift_value <- round(group_data[[paste0("Z_Shift_", sub("Delta_", "", var))]][1], 2) Z_Shift_value <- round(group_data[[paste0("Z_Shift_", sub("Delta_", "", var))]][1], 2)
Z_lm_value <- round(group_data[[paste0("Z_lm_", sub("Delta_", "", var))]][1], 2) Z_lm_value <- round(group_data[[paste0("Z_lm_", sub("Delta_", "", var))]][1], 2)
NG_value <- group_data$NG[1] NG_value <- group_data$NG[1]
@@ -767,7 +786,7 @@ generate_interaction_plot_configs <- function(df, limits_map = NULL, plot_type =
list(x = 1, y = y_limits[1], label = paste("SM =", SM_value)) list(x = 1, y = y_limits[1], label = paste("SM =", SM_value))
) )
config <- list( plot_config <- list(
df = group_data, df = group_data,
plot_type = "scatter", plot_type = "scatter",
x_var = "conc_num_factor_factor", x_var = "conc_num_factor_factor",
@@ -786,19 +805,21 @@ generate_interaction_plot_configs <- function(df, limits_map = NULL, plot_type =
x_labels = as.character(unique(group_data$conc_num)), x_labels = as.character(unique(group_data$conc_num)),
ylim_vals = y_limits ylim_vals = y_limits
) )
delta_plots <- append(delta_plots, list(config)) delta_plot_configs <- append(delta_plot_configs, list(plot_config))
} }
} }
return(list( return(list(
overall_plots = list(grid_layout = list(ncol = 2, nrow = 2), plots = overall_plots), list(grid_layout = list(ncol = 2, nrow = 2), plots = overall_plot_configs),
delta_plots = list(grid_layout = list(ncol = 4, nrow = 3), plots = delta_plots) list(grid_layout = list(ncol = 4, nrow = 3), plots = delta_plot_configs)
)) ))
} }
generate_rank_plot_configs <- function(df, variables, is_lm = FALSE, adjust = FALSE, overlap_color = FALSE) { generate_rank_plot_configs <- function(df, variables, is_lm = FALSE, adjust = FALSE, overlap_color = FALSE) {
sd_bands <- c(1, 2, 3) sd_bands <- c(1, 2, 3)
configs <- list() plot_configs <- list()
variables <- c("L", "K")
# Adjust (if necessary) and rank columns # Adjust (if necessary) and rank columns
for (variable in variables) { for (variable in variables) {
@@ -863,19 +884,19 @@ generate_rank_plot_configs <- function(df, variables, is_lm = FALSE, adjust = FA
# Loop through SD bands # Loop through SD bands
for (sd_band in sd_bands) { for (sd_band in sd_bands) {
# Create plot with annotations # Create plot with annotations
configs[[length(configs) + 1]] <- create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, with_annotations = TRUE) plot_configs[[length(plot_configs) + 1]] <- create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, with_annotations = TRUE)
# Create plot without annotations # Create plot without annotations
configs[[length(configs) + 1]] <- create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, with_annotations = FALSE) plot_configs[[length(plot_configs) + 1]] <- create_plot_config(variable, rank_var, zscore_var, y_label, sd_band, with_annotations = FALSE)
} }
} }
# Calculate dynamic grid layout based on the number of plots # Calculate dynamic grid layout based on the number of plots
num_plots <- length(configs)
grid_ncol <- 3 grid_ncol <- 3
num_plots <- length(plot_configs)
grid_nrow <- ceiling(num_plots / grid_ncol) # Automatically calculate the number of rows grid_nrow <- ceiling(num_plots / grid_ncol) # Automatically calculate the number of rows
return(list(grid_layout = list(ncol = grid_ncol, nrow = grid_nrow), plots = configs)) return(list(grid_layout = list(ncol = grid_ncol, nrow = grid_nrow), plots = plot_configs))
} }
generate_correlation_plot_configs <- function(df, highlight_cyan = FALSE) { generate_correlation_plot_configs <- function(df, highlight_cyan = FALSE) {
@@ -888,22 +909,23 @@ generate_correlation_plot_configs <- function(df, highlight_cyan = FALSE) {
list(x = "Z_lm_r", y = "Z_lm_AUC", label = "Interaction r vs. Interaction AUC") list(x = "Z_lm_r", y = "Z_lm_AUC", label = "Interaction r vs. Interaction AUC")
) )
plots <- list() plot_configs <- list()
for (rel in relationships) { for (rel in relationships) {
lm_model <- lm(as.formula(paste(rel$y, "~", rel$x)), data = df) lm_model <- lm(as.formula(paste(rel$y, "~", rel$x)), data = df)
r_squared <- summary(lm_model)$r.squared r_squared <- summary(lm_model)$r.squared
config <- list( plot_config <- list(
df = df, df = df,
x_var = rel$x, x_var = rel$x,
y_var = rel$y, y_var = rel$y,
plot_type = "scatter", plot_type = "scatter",
title = rel$label, title = rel$label,
annotations = list( annotations = list(
list(x = mean(df[[rel$x]], na.rm = TRUE), list(
y = mean(df[[rel$y]], na.rm = TRUE), x = mean(df[[rel$x]], na.rm = TRUE),
label = paste("R-squared =", round(r_squared, 3))) y = mean(df[[rel$y]], na.rm = TRUE),
label = paste("R-squared =", round(r_squared, 3)))
), ),
smooth = TRUE, smooth = TRUE,
smooth_color = "tomato3", smooth_color = "tomato3",
@@ -914,10 +936,10 @@ generate_correlation_plot_configs <- function(df, highlight_cyan = FALSE) {
cyan_points = highlight_cyan cyan_points = highlight_cyan
) )
plots <- append(plots, list(config)) plot_configs <- append(plot_configs, list(plot_config))
} }
return(list(grid_layout = list(ncol = 3, nrow = 2), plots = plots)) return(list(plots = plot_configs))
} }
main <- function() { main <- function() {
@@ -1332,7 +1354,6 @@ main <- function() {
message("Generating rank plots") message("Generating rank plots")
rank_plot_configs <- generate_rank_plot_configs( rank_plot_configs <- generate_rank_plot_configs(
df = zscore_interactions_joined, df = zscore_interactions_joined,
variables = interaction_vars,
is_lm = FALSE, is_lm = FALSE,
adjust = TRUE adjust = TRUE
) )
@@ -1342,7 +1363,6 @@ main <- function() {
message("Generating ranked linear model plots") message("Generating ranked linear model plots")
rank_lm_plot_configs <- generate_rank_plot_configs( rank_lm_plot_configs <- generate_rank_plot_configs(
df = zscore_interactions_joined, df = zscore_interactions_joined,
variables = interaction_vars,
is_lm = TRUE, is_lm = TRUE,
adjust = TRUE adjust = TRUE
) )
@@ -1375,7 +1395,6 @@ main <- function() {
message("Generating filtered ranked plots") message("Generating filtered ranked plots")
rank_plot_filtered_configs <- generate_rank_plot_configs( rank_plot_filtered_configs <- generate_rank_plot_configs(
df = zscore_interactions_filtered, df = zscore_interactions_filtered,
variables = interaction_vars,
is_lm = FALSE, is_lm = FALSE,
adjust = FALSE, adjust = FALSE,
overlap_color = TRUE overlap_color = TRUE
@@ -1388,7 +1407,6 @@ main <- function() {
message("Generating filtered ranked linear model plots") message("Generating filtered ranked linear model plots")
rank_plot_lm_filtered_configs <- generate_rank_plot_configs( rank_plot_lm_filtered_configs <- generate_rank_plot_configs(
df = zscore_interactions_filtered, df = zscore_interactions_filtered,
variables = interaction_vars,
is_lm = TRUE, is_lm = TRUE,
adjust = FALSE, adjust = FALSE,
overlap_color = TRUE overlap_color = TRUE