
Extrapolation

In this notebook you will find:

- How to get a survival curve using xgbse
- How to extrapolate your predicted survival curve using the xgbse.extrapolation module

METABRIC

We will use the Molecular Taxonomy of Breast Cancer International Consortium (METABRIC) dataset from pycox as the basis for this example.

from xgbse.converters import convert_to_structured
from pycox.datasets import metabric
import numpy as np

# getting data
df = metabric.read_df()

df.head()
         x0        x1         x2        x3   x4   x5   x6   x7         x8    duration  event
0  5.603834  7.811392  10.797988  5.967607  1.0  1.0  0.0  1.0  56.840000   99.333336      0
1  5.284882  9.581043  10.204620  5.664970  1.0  0.0  0.0  1.0  85.940002   95.733330      1
2  5.920251  6.776564  12.431715  5.873857  0.0  1.0  0.0  1.0  48.439999  140.233337      0
3  6.654017  5.341846   8.646379  5.655888  0.0  0.0  0.0  0.0  66.910004  239.300003      0
4  5.456747  5.339741  10.555724  6.008429  1.0  0.0  0.0  1.0  67.849998   56.933334      1

Split and create Time Bins

Split the data into train and test sets using the scikit-learn API. We also set up TIME_BINS with np.arange; these are the time points at which the survival curve will be estimated.

from xgbse.converters import convert_to_structured
from sklearn.model_selection import train_test_split

# splitting to X, T, E format
X = df.drop(['duration', 'event'], axis=1)
T = df['duration']
E = df['event']
y = convert_to_structured(T, E)

# splitting between train and test
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=1/3, random_state=0)
TIME_BINS = np.arange(15, 315, 15)
TIME_BINS
array([ 15,  30,  45,  60,  75,  90, 105, 120, 135, 150, 165, 180, 195,
       210, 225, 240, 255, 270, 285, 300])
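For intuition, the target `y` pairs each duration with its event indicator in a single NumPy structured array, and `np.arange` excludes its stop value, which is why the last bin is 300 rather than 315. The sketch below rebuilds both by hand from the five rows shown in `df.head()`. The field names `event` and `duration` are assumptions made for this illustration and may not match the names `convert_to_structured` uses internally.

```python
import numpy as np

# First five (duration, event) pairs, copied from the df.head() output
durations = np.array([99.333336, 95.733330, 140.233337, 239.300003, 56.933334])
events = np.array([0, 1, 0, 0, 1])

# Hypothetical field names, chosen for this illustration only
y_demo = np.array(
    list(zip(events.astype(bool), durations)),
    dtype=[("event", "?"), ("duration", "f8")],
)

# np.arange excludes the stop value: 20 bins, running from 15 to 300
time_bins = np.arange(15, 315, 15)

# Share of the demo durations that fall inside the range spanned by the bins
covered = np.mean(
    (y_demo["duration"] >= time_bins[0]) & (y_demo["duration"] <= time_bins[-1])
)

n_events = int(y_demo["event"].sum())
print(n_events, len(time_bins), time_bins[-1], covered)
```

Checking how many durations lie outside the bin range is a quick way to see how much of the follow-up period the fitted curve will actually describe.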

Fit model and predict survival curves

The package follows the scikit-learn API, with a minor adaptation to work with time and event data. The model outputs survival probabilities in a pd.DataFrame whose columns correspond to the time bins.

from xgbse import XGBSEDebiasedBCE

# fitting xgbse model
xgbse_model = XGBSEDebiasedBCE()
xgbse_model.fit(X_train, y_train, time_bins=TIME_BINS)

# predicting
survival = xgbse_model.predict(X_test)
survival.head()
15 30 45 60 75 90 105 120 135 150 165 180 195 210 225 240 255 270 285 300
0 0.983502 0.951852 0.923277 0.900028 0.862270 0.799324 0.715860 0.687257 0.651314 0.610916 0.568001 0.513172 0.493194 0.430701 0.377675 0.310496 0.272169 0.225599 0.184878 0.144089
1 0.973506 0.917739 0.839154 0.710431 0.663119 0.558886 0.495204 0.364995 0.311628 0.299939 0.226226 0.191373 0.171697 0.144864 0.112447 0.089558 0.081137 0.057679 0.048563 0.035985
2 0.986894 0.959209 0.919768 0.889910 0.853239 0.777208 0.725381 0.649177 0.582569 0.531787 0.485275 0.451667 0.428899 0.386413 0.344369 0.279685 0.242064 0.187967 0.158121 0.118562
3 0.986753 0.955210 0.910354 0.857684 0.824301 0.769262 0.665805 0.624934 0.583592 0.537261 0.493957 0.443193 0.416702 0.376552 0.308947 0.237033 0.177140 0.141838 0.117917 0.088937
4 0.977348 0.940368 0.873695 0.804796 0.742655 0.632426 0.556008 0.521490 0.493577 0.458477 0.416363 0.391099 0.364431 0.291472 0.223758 0.190398 0.165911 0.120061 0.095512 0.069566
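Because each row is a curve and each column a time, common summaries reduce to simple DataFrame operations. The sketch below copies rows 0 and 1 from the output above and reads off the survival probability at t=60 and a median survival time, taken here as the first bin where the curve drops to 0.5 or below. The helper `median_survival_time` is hypothetical and not part of xgbse.

```python
import numpy as np
import pandas as pd

time_bins = np.arange(15, 315, 15)

# Rows 0 and 1 copied from the survival.head() output
survival_demo = pd.DataFrame(
    [
        [0.983502, 0.951852, 0.923277, 0.900028, 0.862270, 0.799324, 0.715860,
         0.687257, 0.651314, 0.610916, 0.568001, 0.513172, 0.493194, 0.430701,
         0.377675, 0.310496, 0.272169, 0.225599, 0.184878, 0.144089],
        [0.973506, 0.917739, 0.839154, 0.710431, 0.663119, 0.558886, 0.495204,
         0.364995, 0.311628, 0.299939, 0.226226, 0.191373, 0.171697, 0.144864,
         0.112447, 0.089558, 0.081137, 0.057679, 0.048563, 0.035985],
    ],
    columns=time_bins,
)


def median_survival_time(curves):
    """First time bin where each curve is <= 0.5 (NaN if it never gets there)."""
    below = curves.to_numpy() <= 0.5
    first_idx = below.argmax(axis=1)  # position of the first True in each row
    times = curves.columns.to_numpy(dtype=float)[first_idx]
    times[~below.any(axis=1)] = np.nan  # curves that stay above 0.5
    return pd.Series(times, index=curves.index)


surv_at_60 = survival_demo[60]  # P(T > 60) for every sample
medians = median_survival_time(survival_demo)
print(surv_at_60.tolist())  # [0.900028, 0.710431]
print(medians.tolist())     # [195.0, 105.0]
```

The resolution of such a summary is limited by the bin width: with bins 15 time units apart, the median is only located to the nearest bin.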

Survival curves visualization

import matplotlib.pyplot as plt

plt.figure(figsize=(12,4), dpi=120)

plt.plot(
    survival.columns,
    survival.iloc[42],
    'k--',
    label='Survival'
)

plt.title('Sample of predicted survival curves - $P(T>t)$')
plt.legend()

[Figure: predicted survival curve for test sample 42, $P(T>t)$ versus time]

Notice that this predicted survival curve does not reach zero at the last time bin: because part of the data is censored, the model sees an apparent cure fraction. In some cases it is useful to extend the survival curves beyond the last bin using a specific strategy. xgbse.extrapolation implements a constant risk extrapolation strategy.
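To make the idea of constant risk concrete, here is a minimal textbook-style sketch: it assumes the conditional survival over the last observed interval, S(t_last) / S(t_prev), repeats unchanged in every later interval. It is not xgbse's implementation, the function name `extrapolate_constant_hazard_sketch` is hypothetical, and the values it produces need not match those returned by the library.

```python
import numpy as np
import pandas as pd


def extrapolate_constant_hazard_sketch(curves, final_time, step):
    """Extend each curve by repeating the conditional survival of its last interval."""
    last_time = curves.columns[-1]
    last = curves.iloc[:, -1].to_numpy()
    prev = curves.iloc[:, -2].to_numpy()

    # P(T > t_last | T > t_prev), assumed constant from here on
    ratio = last / prev

    # New time points after the last observed bin, up to final_time inclusive
    new_times = np.arange(last_time + step, final_time + step, step)
    k = np.arange(1, len(new_times) + 1)

    # S(t_last + k * step) = S(t_last) * ratio ** k
    ext = last[:, None] * ratio[:, None] ** k[None, :]
    ext_df = pd.DataFrame(ext, index=curves.index, columns=new_times)
    return pd.concat([curves, ext_df], axis=1)


# Last three bins of row 0 from the predicted survival table
demo = pd.DataFrame([[0.225599, 0.184878, 0.144089]], columns=[270, 285, 300])
out = extrapolate_constant_hazard_sketch(demo, final_time=330, step=15)
print(out)
```

Using only the last interval makes the tail sensitive to noise in the final two bins; averaging the conditional survival over several trailing intervals would be a smoother alternative.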

Extrapolation

from xgbse.extrapolation import extrapolate_constant_risk

# extrapolating predicted survival up to t=450, in steps of 15
survival_ext = extrapolate_constant_risk(survival, 450, 15)
survival_ext.head()
15.0 30.0 45.0 60.0 75.0 90.0 105.0 120.0 135.0 150.0 ... 315.0 330.0 345.0 360.0 375.0 390.0 405.0 420.0 435.0 450.0
0 0.983502 0.951852 0.923277 0.900028 0.862270 0.799324 0.715860 0.687257 0.651314 0.610916 ... 0.112299 0.068213 0.032292 0.011915 0.003426 0.000768 0.000134 1.825794e-05 1.937124e-06 1.601799e-07
1 0.973506 0.917739 0.839154 0.710431 0.663119 0.558886 0.495204 0.364995 0.311628 0.299939 ... 0.026665 0.014641 0.005957 0.001796 0.000401 0.000066 0.000008 7.404100e-07 4.986652e-08 2.488634e-09
2 0.986894 0.959209 0.919768 0.889910 0.853239 0.777208 0.725381 0.649177 0.582569 0.531787 ... 0.088900 0.049982 0.021071 0.006660 0.001579 0.000281 0.000037 3.735612e-06 2.798762e-07 1.572266e-08
3 0.986753 0.955210 0.910354 0.857684 0.824301 0.769262 0.665805 0.624934 0.583592 0.537261 ... 0.067080 0.038160 0.016373 0.005299 0.001293 0.000238 0.000033 3.462388e-06 2.734946e-07 1.629408e-08
4 0.977348 0.940368 0.873695 0.804796 0.742655 0.632426 0.556008 0.521490 0.493577 0.458477 ... 0.050668 0.026879 0.010385 0.002923 0.000599 0.000089 0.000010 7.701555e-07 4.442463e-08 1.866412e-09

5 rows × 31 columns
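Before plotting or reusing extended curves, it can help to confirm that they still behave like survival functions. The sketch below runs three checks on a hand-copied slice of row 0 (times 405 to 450): time columns strictly increasing, values within [0, 1], and each curve non-increasing. The helper `check_survival_curves` is hypothetical and not part of xgbse. A failed time check would also flag duplicated time columns.

```python
import numpy as np
import pandas as pd


def check_survival_curves(curves):
    """Return a dict of basic sanity checks for a survival DataFrame."""
    values = curves.to_numpy(dtype=float)
    times = curves.columns.to_numpy(dtype=float)
    return {
        "times_strictly_increasing": bool(np.all(np.diff(times) > 0)),
        "values_in_unit_interval": bool(np.all((values >= 0) & (values <= 1))),
        "curves_non_increasing": bool(np.all(np.diff(values, axis=1) <= 0)),
    }


# Tail of row 0, copied from the extrapolated output
tail = pd.DataFrame(
    [[0.000134, 1.825794e-05, 1.937124e-06, 1.601799e-07]],
    columns=[405.0, 420.0, 435.0, 450.0],
)
ok_report = check_survival_curves(tail)

# A deliberately broken curve: survival goes back up at t=435
broken = tail.copy()
broken.iloc[0, 2] = 0.5
bad_report = check_survival_curves(broken)

print(ok_report)
print(bad_report)
```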

# plotting extrapolation

plt.figure(figsize=(12,4), dpi=120)

plt.plot(
    survival.columns,
    survival.iloc[42],
    'k--',
    label='Survival'
)

plt.plot(
    survival_ext.columns,
    survival_ext.iloc[42],
    'tomato',
    alpha=0.5,
    label='Extrapolated Survival'
)

plt.title('Extrapolation of survival curves')
plt.legend()

[Figure: predicted survival curve for test sample 42 (dashed black) and its constant risk extrapolation (tomato), survival probability versus time]