# Author: Enoch Yi-Tung Chen 
# Website: enochytchen.com
# Title: Estimating survival when individual-participant data can’t be pooled: potential solutions: an example of evaluating all-cause survival for cancer patients with/without fertility preservation
# Purpose: Use Poisson regression with time-splitting to replicate/approximate Cox PH model
# Hypothetical research question:
# "Does the hazard of all-cause death differ between patients who did and did
#  not undergo fertility preservation, after adjusting for confounders?"

#### Load necessary packages ####

library(survival)
library(dplyr)
library(ggplot2)
library(broom)
library(tidyr)
library(gridExtra)
library(knitr)


#### Simulation Process ####
# Set seed
# Every random draw below (sample(), and runif() in Step 2) consumes the same
# seeded stream, so the simulated data depend on the exact order of these calls.
set.seed(13579)

#### Step 1: Simulate covariates to create confounding ####

# Create the data frame with country assignments
# Each of the n individuals is assigned to country A, B or C with equal
# probability.
n <- 3000
simdata <- data.frame(country = sample(c("A", "B", "C"), 
                                       n, replace = TRUE, prob = c(1/3, 1/3, 1/3)))

# Add covariates
# The columns are created as NA placeholders here and filled in below.
covariates <- c("agegrp", "fp", "birthplace_europe", "education", "parity",
                "yeardx", "lymphnode", "tumorsize", "chemo")
simdata[, covariates] <- NA

# Get indices for each country to create different covariate distributions
idx_A <- which(simdata$country == "A")
idx_B <- which(simdata$country == "B")
idx_C <- which(simdata$country == "C")
n_A <- length(idx_A) 
n_B <- length(idx_B)
n_C <- length(idx_C)

# The age distribution differs by country:
#   Country A: intermediate (older than B, younger than C).
#   Country B: younger.
#   Country C: older.

# agegrp (1=young, 2=mid, 3=old)
# `prob` gives the sampling probabilities of age groups 1, 2 and 3.
simdata$agegrp[idx_A] <- sample(1:3, n_A, replace = TRUE, prob = c(0.4, 0.4, 0.2)) # A: Middle population
simdata$agegrp[idx_B] <- sample(1:3, n_B, replace = TRUE, prob = c(0.6, 0.3, 0.1)) # B: Younger population
simdata$agegrp[idx_C] <- sample(1:3, n_C, replace = TRUE, prob = c(0.1, 0.3, 0.6)) # C: Older population

# Make fertility preservation (fp) more likely among those getting chemo:
#   chemo = 1 -> P(fp = 1) = 0.6
#   chemo = 0 -> P(fp = 1) = 0.2

# First assign chemotherapy status (80% receive chemo)
simdata$chemo <- sample(0:1, n, replace = TRUE, prob = c(0.2, 0.8))

# Then assign fertility preservation status based on chemotherapy status
# For patients with chemo, higher probability of fertility preservation
simdata$fp <- rep(NA, n)
# For those not receiving chemo
simdata$fp[simdata$chemo == 0] <- sample(0:1, sum(simdata$chemo == 0), 
                                         replace = TRUE, prob = c(0.8, 0.2))
# For those receiving chemo - higher chance of fertility preservation
simdata$fp[simdata$chemo == 1] <- sample(0:1, sum(simdata$chemo == 1), 
                                         replace = TRUE, prob = c(0.4, 0.6))
# Two-by-two table (rows: chemo, columns: fp) as a sanity check that the
# association was generated as intended
table(simdata$chemo, simdata$fp)

# Simulate other covariates uniformly
# Just for illustration: each level is drawn with equal probability, and none
# of these covariates enters the linear predictor built in Step 2.
simdata$birthplace_europe <- sample(0:1, n, replace = TRUE)
simdata$education <- sample(1:3, n, replace = TRUE)
simdata$parity <- sample(0:2, n, replace = TRUE)
simdata$yeardx <- sample(0:1, n, replace = TRUE)
simdata$lymphnode <- sample(1:3, n, replace = TRUE)
simdata$tumorsize <- sample(0:3, n, replace = TRUE)

# Convert all covariates to factors for modeling
# Going through as.character() makes the level labels the original codes
# ("0", "1", ...).
simdata[covariates] <- lapply(simdata[covariates], function(x) as.factor(as.character(x)))

#### Step 2: Generate survival times based on covariates ####

# Define regression coefficients (log-hazard ratios) for the TRUE estimates
# Only agegrp, chemo and fp enter the linear predictor below, so country and
# all remaining covariates have a true hazard ratio of 1.
beta_agegrp3 <- log(2)    # HR=2.0 for agegrp 3 vs 1
beta_agegrp2 <- log(1.5)  # HR=1.5 for agegrp 2 vs 1
beta_chemo1  <- log(5)    # HR=5.0 for chemo vs no chemo
beta_fp1     <- log(1.2) # HR=1.2 for fp vs no fp

# Create the linear predictor from the covariates
# The covariates are factors at this point: `==` compares against the level
# label, and as.numeric() returns the internal level code (1 for "0", 2 for
# "1"), hence the `- 1` to recover the 0/1 indicator.
lp <- ifelse(simdata$agegrp == 2, beta_agegrp2, 0) + 
  ifelse(simdata$agegrp == 3, beta_agegrp3, 0) +
  (as.numeric(simdata$chemo) - 1) * beta_chemo1 +
  (as.numeric(simdata$fp) - 1) * beta_fp1

# Use a single baseline hazard for all individuals (Weibull distribution)
# with the proportional hazards assumption.
weibull_shape <- 1.2
weibull_scale <- 80

# Generate true survival times using the inversion method
# Now every individual has survival time according to their linear predictors
# With lambda = 1 / scale^shape, the line below solves
#   U = exp(-lambda * exp(lp) * t^shape)
# for t, i.e. t = (-log(U) / (lambda * exp(lp)))^(1 / shape).
U <- runif(n) # Uniform random variables
time_true <- (-log(U) / ( (1/weibull_scale^weibull_shape) * exp(lp) ) )^(1/weibull_shape)

# Apply right-censoring at 10 years
# status = 1 marks an event observed at or before 10 years, status = 0 marks
# censoring; the observed time is truncated at the censoring time.
time_censor <- 10
simdata$status <- ifelse(time_true > time_censor, 0, 1)
simdata$time <- pmin(time_true, time_censor)

#### Step 3: Fit Cox Models ####

# Crude model: fertility preservation is the only term
coxmod1 <- coxph(formula = Surv(time, status) ~ fp, data = simdata)

# Adjusted model: fertility preservation plus country and all simulated
# covariates
coxmod2 <- coxph(
  formula = Surv(time, status) ~ fp + country + agegrp +
    birthplace_europe + education +
    parity + yeardx + lymphnode + tumorsize + chemo,
  data = simdata
)


#### Step 4: Replicating the results using Poisson GLM ####
# The part between START HERE and END HERE is what each site
# statistician/data handler runs locally in order to report aggregate-level
# data only.

# START HERE
# Yearly cut points from 0 up to the last whole year of follow-up
max_time <- floor(max(simdata$time))
cut_points <- seq(0, max_time, by = 1)

# Turn one country's individual records into an aggregate table: follow-up is
# split at `cut_points`, then person-time and events are summed within each
# combination of follow-up interval and covariates.
aggregate_person_time <- function(country_df) {
  split_rows <- survSplit(Surv(time, status) ~ ., data = country_df,
                          cut = cut_points, end = "time", start = "start",
                          event = "status")
  # Person-time contributed to the interval, and the interval label
  split_rows$pt <- split_rows$time - split_rows$start
  split_rows$fu <- as.factor(split_rows$start)

  by_cell <- group_by(split_rows, fu, country, agegrp, fp, birthplace_europe,
                      education, parity, yeardx, lymphnode, tumorsize, chemo)
  summarise(by_cell, pt = sum(pt), status = sum(status), .groups = 'drop')
}

# One data frame per country, each aggregated separately
list_by_country <- split(simdata, simdata$country)
df_list <- lapply(list_by_country, aggregate_person_time)

# END HERE

# The main analyst then stacks the aggregate tables received from each country
simdata_ABC <- dplyr::bind_rows(df_list)

# Unadjusted GLM: events ~ follow-up interval + fp, with log person-time as
# the offset
glm1 <- glm(status ~ fu + fp + offset(log(pt)),
            family = poisson, data = simdata_ABC)

# Adjusted GLM: same structure plus country and all covariates
glm2 <- glm(status ~ fu + country + agegrp + fp + birthplace_europe + education + 
              parity + yeardx + lymphnode + tumorsize + chemo + offset(log(pt)),
            family = poisson, data = simdata_ABC)


#### Step 5: Plots ####

# Give country and fp the same explicit factor levels in the individual-level
# and the aggregate data
country_levels <- c("A", "B", "C")
fp_levels <- c("0", "1")
simdata$country <- factor(as.character(simdata$country), levels = country_levels)
simdata_ABC$country <- factor(as.character(simdata_ABC$country), levels = country_levels)
simdata$fp <- factor(as.character(simdata$fp), levels = fp_levels)
simdata_ABC$fp <- factor(as.character(simdata_ABC$fp), levels = fp_levels)

# Covariate profile used for all adjusted predictions (country A); each value
# becomes a one-element factor carrying the levels found in the aggregate data
profile_values <- c(
  country = "A",
  agegrp = "1",
  birthplace_europe = "0",
  education = "2",
  parity = "1",
  yeardx = "0",
  lymphnode = "2",
  tumorsize = "1",
  chemo = "0"
)
avg_profile <- as.data.frame(
  Map(
    function(value, column) factor(value, levels = levels(simdata_ABC[[column]])),
    profile_values,
    names(profile_values)
  )
)

# --- Cox Model 1 (Unadjusted) ---
# One predicted survival curve per fp level
cox1_curves_list <- lapply(fp_levels, function(fp_val) {
  curve <- tidy(
    survfit(coxmod1, newdata = data.frame(fp = factor(fp_val, levels = c("0", "1"))))
  )
  curve$fp <- fp_val
  curve
})
cox1_curves <- bind_rows(cox1_curves_list)

# --- Cox Model 2 (Adjusted) ---
# Same as above, but evaluated at the covariate profile
cox2_curves_list <- lapply(fp_levels, function(fp_val) {
  profile <- avg_profile
  profile$fp <- factor(fp_val, levels = c("0", "1"))
  curve <- tidy(survfit(coxmod2, newdata = profile))
  curve$fp <- fp_val
  curve
})
cox2_curves <- bind_rows(cox2_curves_list)

# Cox curves sorted for plotting
plot_data_cox1 <- arrange(cox1_curves, fp, time)
plot_data_cox2 <- arrange(cox2_curves, fp, time)

# --- GLM Models: generate newdata for predictions ---
# Numeric start of every follow-up interval present in the aggregate data
fu_numeric <- as.numeric(as.character(simdata_ABC$fu))
fu_levels <- sort(unique(fu_numeric))

# Every fp level crossed with every follow-up interval ...
glm_grid <- expand_grid(
  fp = factor(c("0", "1"), levels = levels(simdata_ABC$fp)),
  fu = fu_levels
)
# ... at the covariate profile, with one unit of person-time so that the
# predicted count equals the rate
newdata_glm <- mutate(
  glm_grid,
  country = avg_profile$country[1],
  agegrp = avg_profile$agegrp[1],
  birthplace_europe = avg_profile$birthplace_europe[1],
  education = avg_profile$education[1],
  parity = avg_profile$parity[1],
  yeardx = avg_profile$yeardx[1],
  lymphnode = avg_profile$lymphnode[1],
  tumorsize = avg_profile$tumorsize[1],
  chemo = avg_profile$chemo[1],
  pt = 1,
  fu = factor(fu, levels = as.character(fu_levels))
)

# Function to calculate survival from GLM predictions
#
# Turns predicted event counts from a Poisson model into a survival curve per
# level of `fp`, treating the hazard as constant within each follow-up
# interval.
#
# Arguments:
#   newdata         Data frame the predictions were made for. Must contain the
#                   columns `fp`, `fu` (interval start, numeric or a factor
#                   whose labels are numbers) and `pt` (person-time the
#                   predicted count refers to).
#   predictions     Numeric vector of predicted counts, one per row of
#                   `newdata`.
#   fp_col          Not used by the body; kept so that existing calls that
#                   pass it keep working.
#   interval_width  Length of each follow-up interval on the time scale of
#                   `fu`. Defaults to 1, matching yearly cut points.
#
# Returns a data frame with columns `fp`, `time` and `survival`, sorted by
# `fp` and `time`, including a starting row (time = 0, survival = 1) for each
# `fp` level.
calculate_glm_surv <- function(newdata, predictions, fp_col = "fp",
                               interval_width = 1) {
  df <- newdata %>%
    mutate(
      pred = predictions,
      hazard_rate = pred / pt,
      start = as.numeric(as.character(fu)),
      # Survival accumulated over an interval applies at the END of that
      # interval, not at its start; stamping it at the start would shift the
      # whole curve one interval to the left and duplicate time 0.
      time = start + interval_width
    )
  
  df <- df %>%
    arrange(fp, start) %>%
    group_by(fp) %>%
    # Cumulative hazard over one interval = rate * interval length
    mutate(survival = cumprod(exp(-hazard_rate * interval_width))) %>%
    ungroup() %>%
    select(fp, time, survival)
  
  # Add baseline (time = 0, survival = 1)
  baseline <- df %>%
    distinct(fp) %>%
    mutate(time = 0, survival = 1)
  
  bind_rows(baseline, df) %>% arrange(fp, time)
}

# Predicted event counts on the response scale for both Poisson models ...
pred_glm1 <- predict(glm1, newdata = newdata_glm, type = "response")
pred_glm2 <- predict(glm2, newdata = newdata_glm, type = "response")

# ... converted into survival curves for plotting
plot_data_glm1 <- calculate_glm_surv(newdata = newdata_glm, predictions = pred_glm1)
plot_data_glm2 <- calculate_glm_surv(newdata = newdata_glm, predictions = pred_glm2)


#### Step 6: Create 2x2 Panel Plot for FP=0 vs FP=1 ####

color_palette_fp <- c("0" = "#0072B2", "1" = "#D55E00")
fp_labels <- c("0" = "No Fertility Preservation", "1" = "Fertility Preservation")

# Layers shared by all four panels; only the title and the y-axis label vary.
# A fresh list is built on every call so no layer object is shared between
# plots.
fp_panel_layers <- function(title, y_label) {
  list(
    geom_step(linewidth = 1.1),
    scale_color_manual(values = color_palette_fp, labels = fp_labels),
    labs(title = title,
         x = "Time (Years)",
         y = y_label,
         color = "Treatment"),
    ylim(0, 1),
    theme_minimal(),
    theme(plot.title = element_text(hjust = 0.5))
  )
}

# Plot 1: Unadjusted Cox - FP
p1 <- ggplot(plot_data_cox1, aes(x = time, y = estimate, color = fp)) +
  fp_panel_layers("Unadjusted Cox Model", "Survival Probability") +
  theme(legend.position = "none")

# Plot 2: Adjusted Cox - FP
# This panel keeps its legend (single row, no title) so it can be extracted
# and shared by the whole figure.
p2 <- ggplot(plot_data_cox2, aes(x = time, y = estimate, color = fp)) +
  fp_panel_layers("Adjusted Cox Model", "") +
  theme(legend.title = element_blank()) +
  guides(color = guide_legend(nrow = 1))

# Plot 3: Unadjusted GLM - FP
p3 <- ggplot(plot_data_glm1, aes(x = time, y = survival, color = fp)) +
  fp_panel_layers("Unadjusted Poisson GLM", "Survival Probability") +
  theme(legend.position = "none")

# Plot 4: Adjusted GLM - FP
p4 <- ggplot(plot_data_glm2, aes(x = time, y = survival, color = fp)) +
  fp_panel_layers("Adjusted Poisson GLM", "") +
  theme(legend.position = "none")


# Extract legend function
#
# Builds the plot and returns its legend as a grob so it can be placed
# separately (e.g. below a grid of panels).
#
# Arguments:
#   plot  A ggplot object that draws a legend.
#
# Returns the first grob named "guide-box" in the plot's gtable.
# Stops with an error if the plot has no such grob (for example when the
# legend has been switched off).
get_legend <- function(plot) {
  tmp <- ggplot_gtable(ggplot_build(plot))
  # vapply + identical() always yields one TRUE/FALSE per grob, whatever the
  # grob's `name` field contains
  is_legend <- vapply(
    tmp$grobs,
    function(x) identical(x$name, "guide-box"),
    logical(1)
  )
  leg <- which(is_legend)
  if (length(leg) == 0) {
    stop("No legend ('guide-box' grob) found in the plot.", call. = FALSE)
  }
  # `[[` needs exactly one index, so take the first match
  tmp$grobs[[leg[1]]]
}

# Take the shared legend from p2 while p2 still draws one ...
legend <- get_legend(p2)

# ... then switch the legend off in p2 like in the other panels
p2 <- p2 + theme(legend.position = "none")

# 2x2 grid of panels on top, shared legend in a short row underneath
panel_grid <- arrangeGrob(p1, p2, p3, p4, ncol = 2)
grid.arrange(panel_grid, legend, nrow = 2, heights = c(10, 1))

#### Step 7 : Create coefficient table to compare hazard ratios ####

# Exponentiated coefficients of one model as a two-column table (term plus a
# column named `column`). With drop_baseline = TRUE the intercept and the
# follow-up interval terms ("fu...") are removed.
exponentiated_terms <- function(fit, column, drop_baseline = FALSE) {
  est <- tidy(fit, exponentiate = TRUE)
  if (drop_baseline) {
    est <- filter(est, !grepl("^fu", term), term != "(Intercept)")
  }
  est <- select(est, term, estimate)
  names(est)[names(est) == "estimate"] <- column
  est
}

coef_cox1 <- exponentiated_terms(coxmod1, "cox_unadj")
coef_cox2 <- exponentiated_terms(coxmod2, "cox_adj")
coef_glm1 <- exponentiated_terms(glm1, "glm_unadj", drop_baseline = TRUE)
coef_glm2 <- exponentiated_terms(glm2, "glm_adj", drop_baseline = TRUE)

# Join the four tables on the term name, one after another
coef_table <- Reduce(
  function(left, right) full_join(left, right, by = "term"),
  list(coef_cox1, coef_cox2, coef_glm1, coef_glm2)
)
coef_table <- filter(coef_table, !is.na(term))
# Put each Cox estimate next to its Poisson counterpart
coef_table <- select(coef_table, term, cox_unadj, glm_unadj, cox_adj, glm_adj)

# Print the table
print(kable(coef_table, digits = 4, 
            caption = "Comparing HRs: Cox vs Poisson GLM"))


############################
# Author: Enoch Yi-Tung Chen

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.