

Train a set of logistic regressions on top of the leaf embedding produced by XGBoost, each predicting survival at one of several user-defined discrete time windows. Each classifier excludes individuals who have already been censored, and its target is an indicator of survival at that window.


  • Training and scoring of the logistic regression models is efficient: both run in parallel through joblib, so the model can scale to hundreds of thousands or millions of samples.
  • However, with many windows and a large dataset, training the logistic regression models may become the bottleneck, taking longer than training the underlying XGBoost model.
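
The per-window setup described above (drop censored individuals, then use survival at the window as the target) can be sketched in plain Python. This is a minimal illustration rather than the library's code: the function name `window_targets` is hypothetical, and the exact rule for dropping individuals (anyone censored strictly before the window) is an assumption.

```python
def window_targets(times, events, time_bins):
    """For each window, return (kept_indices, targets).

    Assumption: an individual is dropped from a window's training set
    when censored strictly before that window, because survival status
    at the window is then unknown. Otherwise the target is 0 if the
    event happened at or before the window, and 1 if still surviving.
    """
    out = {}
    for t in time_bins:
        kept, targets = [], []
        for i, (time, event) in enumerate(zip(times, events)):
            if not event and time < t:
                continue  # censored before the window: excluded
            kept.append(i)
            targets.append(0 if (event and time <= t) else 1)
        out[t] = (kept, targets)
    return out


# Toy data: observed times and event indicators (True = event observed).
times = [2.0, 5.0, 7.0, 10.0]
events = [True, False, True, False]
result = window_targets(times, events, time_bins=[3.0, 6.0, 9.0])
```

In this toy data the second individual is censored at time 5.0, so it contributes to the window at 3.0 but is excluded from the windows at 6.0 and 9.0. One classifier would then be trained per window on the kept rows only.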

Read more in How XGBSE works.

__init__(self, xgb_params=None, lr_params=None, n_jobs=-1)


Name Type Description Default
xgb_params Dict, None

Parameters for XGBoost model. If not passed, the following default parameters will be used:

    "objective": "survival:aft",
    "eval_metric": "aft-nloglik",
    "aft_loss_distribution": "normal",
    "aft_loss_distribution_scale": 1,
    "tree_method": "hist",
    "learning_rate": 5e-2,
    "max_depth": 8,
    "booster": "dart",
    "subsample": 0.5,
    "min_child_weight": 50,
    "colsample_bynode": 0.5,

Check the XGBoost parameters documentation for more options.

lr_params Dict, None

Parameters for Logistic Regression models. If not passed, the following default parameters will be used:

DEFAULT_PARAMS_LR = {"C": 1e-3, "max_iter": 500}

Check the scikit-learn LogisticRegression documentation for more options.

n_jobs Int

Number of CPU cores used to fit logistic regressions via joblib.


fit(self, X, y, num_boost_round=1000, validation_data=None, early_stopping_rounds=None, verbose_eval=0, persist_train=False, index_id=None, time_bins=None)

Transform the feature space by fitting an XGBoost model and returning its leaf indices. The leaves are then encoded as dummy variables and used to fit one logistic regression model per evaluated time bin.
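
To make the leaf-to-dummy-variable step concrete, here is a minimal sketch in plain Python. It assumes the leaf indices arrive as a matrix with one row per sample and one column per tree; the helper name `one_hot_leaves` is hypothetical, and the library's actual encoder may differ.

```python
def one_hot_leaves(leaf_indices):
    """Encode leaf indices (rows: samples, columns: trees) as dummy
    variables, one output column per (tree, leaf) pair seen in the data."""
    n_trees = len(leaf_indices[0])

    # Collect the distinct leaves reached in each tree, in sorted order.
    columns = []
    for tree in range(n_trees):
        for leaf in sorted({row[tree] for row in leaf_indices}):
            columns.append((tree, leaf))
    position = {col: j for j, col in enumerate(columns)}

    # Each sample activates exactly one dummy column per tree.
    encoded = []
    for row in leaf_indices:
        vec = [0] * len(columns)
        for tree, leaf in enumerate(row):
            vec[position[(tree, leaf)]] = 1
        encoded.append(vec)
    return columns, encoded


# Three samples, two trees: each entry is the leaf a sample lands in.
leaves = [[3, 1], [4, 1], [3, 2]]
columns, encoded = one_hot_leaves(leaves)
```

Every encoded row sums to the number of trees, since a sample falls into exactly one leaf of each tree. The resulting binary matrix is the kind of input the per-window logistic regressions would consume.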


Name Type Description Default
X [pd.DataFrame, np.array]

Features used to fit the XGBoost model.

y structured array(numpy.bool_, numpy.number)

Binary event indicator as first field, and time of event or time of censoring as second field.

num_boost_round Int

Number of boosting iterations.

validation_data Tuple

Validation data in the format of a list of tuples [(X, y)], for use when early stopping is desired.

early_stopping_rounds Int

Activates early stopping. Validation metric needs to improve at least once in every early_stopping_rounds round(s) to continue training. See xgboost.train documentation.

verbose_eval [Bool, Int]

Level of verbosity. See xgboost.train documentation.

persist_train Bool

Whether to persist the training data so it can be used for explainability through prototypes.

index_id pd.Index

User-defined index, for use with explainability through prototypes.

time_bins np.array

Specified time windows to use when making survival predictions
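
The `y` argument above is a structured array with the event indicator as the first field and the time as the second. A minimal way to build one with NumPy is sketched below; the field names `event` and `time` are illustrative assumptions, since this page only specifies the field order and types.

```python
import numpy as np

# True = event observed, False = censored.
events = [True, False, True]
# Time of event, or time of censoring when no event was observed.
times = [5.0, 12.0, 3.5]

# Field names are assumptions; only the order (indicator, then time)
# and the types (boolean, numeric) come from the description above.
y = np.array(
    list(zip(events, times)),
    dtype=[("event", np.bool_), ("time", np.float64)],
)
```

Each element of `y` is one `(event, time)` record, and the fields can be read back as whole columns with `y["event"]` and `y["time"]`.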



Type Description

Trained XGBSEDebiasedBCE instance
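
The `early_stopping_rounds` rule described for fit (the validation metric must improve at least once in every `early_stopping_rounds` rounds) can be illustrated with a small standalone sketch. This is not xgboost's implementation; the function name `stopping_round` is hypothetical, and the metric is assumed to be one that should decrease.

```python
def stopping_round(metric_per_round, early_stopping_rounds):
    """Return the index of the last round that would run.

    Assumes a lower metric is better. Training stops as soon as
    `early_stopping_rounds` consecutive rounds pass without a new best.
    """
    best = float("inf")
    best_round = -1
    for i, value in enumerate(metric_per_round):
        if value < best:
            best, best_round = value, i
        elif i - best_round >= early_stopping_rounds:
            return i  # patience exhausted
    return len(metric_per_round) - 1  # ran through every round


# Hypothetical validation losses, one per boosting round.
history = [0.9, 0.8, 0.75, 0.76, 0.77, 0.78, 0.70]
stopped_at = stopping_round(history, early_stopping_rounds=3)
ran_all = stopping_round(history, early_stopping_rounds=10)
```

With a patience of 3 the loop stops before reaching the later improvement at the last round, which shows the trade-off: a small `early_stopping_rounds` saves time but can end training prematurely.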

predict(self, X, return_interval_probs=False)

Predicts survival probabilities using the XGBoost + Logistic Regression pipeline.


Name Type Description Default
X pd.DataFrame

Dataframe of features to be used as input for the XGBoost model.

return_interval_probs Bool

Whether to return interval probabilities. If False, the cumulative survival is returned. Default is False.



Type Description

A dataframe of survival probabilities for all times (columns), from a time_bins array, for all samples of X (rows). If return_interval_probs is True, the interval probabilities are returned instead of the cumulative survival probabilities.
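
The relationship between the two output modes can be sketched as follows. This page does not define interval probabilities precisely, so the sketch assumes they are the probability of failing within each interval, computed as the drop in cumulative survival between consecutive time bins, with survival taken to be 1 before the first bin. The function name `interval_probs` is hypothetical.

```python
def interval_probs(cumulative_survival):
    """Convert one sample's cumulative survival curve into per-interval
    failure probabilities (assumed definition, see the note above)."""
    probs = []
    previous = 1.0  # assumed survival at time zero
    for s in cumulative_survival:
        probs.append(previous - s)
        previous = s
    return probs


# Cumulative survival for one sample at three increasing time bins.
survival = [0.9, 0.75, 0.5]
probs = interval_probs(survival)
```

Under this assumption the interval probabilities plus the final survival value sum to 1, which is a quick sanity check when comparing the two outputs for the same sample.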