# Author: Enoch Yi-Tung Chen 
# Website: enochytchen.com
# Title: Plotting Survival Curves after fitting Cox Proportional Hazards Models

# Load necessary libraries
library(survival)
library(ggplot2)
library(dplyr)

# Load the colon cancer dataset from the survival package.
# survival::colon stores two records per patient: etype == 1 is recurrence and
# etype == 2 is death. Keep only the death records so each patient contributes
# exactly one row and the model describes overall survival, instead of mixing
# two event types and counting every patient twice as independent subjects.
# Note: na.omit() drops rows with a missing value in ANY column, not only in the
# columns used by the model below.
colon <- na.omit(survival::colon[survival::colon$etype == 2, ])

# Fit the Cox model for overall survival with age and sex as covariates
coxph_model <- coxph(Surv(time, status) ~ age + sex, data = colon)

# Approach 1: Predict survival curve at the mean values of covariates ---
# With no newdata argument, survfit() predicts for a single hypothetical
# subject whose covariates sit at their means (mean age and mean sex).
fit1 <- survfit(coxph_model)

# Gather the time points and the matching survival estimates into a
# plotting-ready data frame, tagged with the approach name.
dffit1 <- data.frame(
  time = fit1[["time"]],
  surv_prob = fit1[["surv"]],
  method = "Default (at mean covariates)"
)

# Approach 2: Predict for each individual and then average the curves ---
# Passing the original data as newdata gives one predicted survival curve per
# row of colon, all evaluated on the same time points (fit2$time).
fit2 <- survfit(coxph_model, newdata = colon)

# fit2$surv holds one column per individual (one row per time point). It can
# be a plain vector when newdata has a single row, so as.matrix() makes both
# cases a matrix that rowMeans() accepts. Averaging across the columns gives
# the marginal (population-averaged) survival probability at each time point.
surv_matrix <- as.matrix(fit2$surv)
avg_surv_prob <- rowMeans(surv_matrix)

# Create a data frame for the marginal survival curve
dffit2 <- data.frame(
  time = fit2$time,
  surv_prob = avg_surv_prob,
  method = "Marginal survival curve (average across individuals)"
)

# Plot ---
# Combine the two data frames into one long data frame for ggplot2
df <- bind_rows(dffit1, dffit2)

# survfit() results begin at the first observed time, not at time 0. Add an
# explicit starting point (time 0, survival probability 1) for every curve so
# each step line starts at the origin as a survival curve should.
curve_starts <- data.frame(time = 0, surv_prob = 1, method = unique(df$method))
df <- bind_rows(curve_starts, df)

# Plot both curves on the same axes
survival_plot <- ggplot(df, aes(x = time, y = surv_prob, color = method)) +
  geom_step(linewidth = 1) + # step function: survival is piecewise constant
  labs(
    x = "Time (in days)",
    y = "Survival Probability",
    color = "Approach"
  ) +
  theme_minimal() +
  scale_color_manual(
    values = c(
      "Default (at mean covariates)" = "#870052",
      "Marginal survival curve (average across individuals)" = "#4DB5BC"
    )
  ) +
  coord_cartesian(ylim = c(0, 1))

# Print explicitly: top-level values are not auto-printed under source()
print(survival_plot)

############################
# Author: Enoch Yi-Tung Chen

# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.