
xgbse._kaplan_neighbors.XGBSEKaplanNeighbors

Convert an XGBoost model into a nearest neighbor model, using Hamming distance on terminal-node (leaf) indices so that the most similar samples are the ones that co-occurred most often at the ensemble's terminal nodes.

Then, for each neighbor set, compute survival estimates with the Kaplan-Meier estimator.
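A minimal sketch of the similarity idea, assuming toy data: two samples are close when they land in the same leaf in many trees, so the Hamming distance between their leaf-index vectors is one minus the fraction of trees in which they co-occur. This is not the library's implementation (which builds a search index rather than computing all pairwise distances); the leaf matrix is made up and the helper names `hamming_distances` and `nearest_neighbors` are hypothetical.

```python
import numpy as np

# Hypothetical toy data: rows are samples, columns are trees, and each entry is
# the id of the leaf (terminal node) the sample falls into in that tree. This
# stands in for the leaf-index matrix a fitted XGBoost model would produce.
leaves = np.array([
    [1, 4, 2, 7],
    [1, 4, 3, 7],
    [2, 5, 3, 8],
    [1, 4, 2, 8],
])


def hamming_distances(query, reference):
    """Fraction of trees in which each query sample and each reference sample
    land in different leaves. Shape: (n_query, n_reference)."""
    return (query[:, None, :] != reference[None, :, :]).mean(axis=2)


def nearest_neighbors(query, reference, k):
    """Indices of the k reference samples with the smallest Hamming distance,
    i.e. the highest leaf co-occurrence, for each query sample."""
    dist = hamming_distances(query, reference)
    return np.argsort(dist, axis=1, kind="stable")[:, :k]


dist = hamming_distances(leaves[:1], leaves)
print(dist.tolist())  # [[0.0, 0.25, 1.0, 0.25]]
print(nearest_neighbors(leaves[:1], leaves, k=3).tolist())  # [[0, 1, 3]]
```

The query sample is taken from the reference set here, so it is its own nearest neighbor at distance 0; sample 2 shares no leaf with it and ends up farthest away.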

Note

  • We recommend using dart as the booster to prevent any single tree from dominating variance in the ensemble and breaking the leaf co-occurrence similarity logic.

  • This method can be very expensive at scales of hundreds of thousands of samples because of the nearest neighbor search, both at training time (construction of the search index) and at scoring time (the actual search).

Read more in How XGBSE works.

__init__(self, xgb_params=None, n_neighbors=30, radius=None)

Parameters:

xgb_params (Dict, default: None): Parameters for the XGBoost model. If not passed, the following default parameters are used:

DEFAULT_PARAMS = {
    "objective": "survival:aft",
    "eval_metric": "aft-nloglik",
    "aft_loss_distribution": "normal",
    "aft_loss_distribution_scale": 1,
    "tree_method": "hist",
    "learning_rate": 5e-2,
    "max_depth": 8,
    "booster": "dart",
    "subsample": 0.5,
    "min_child_weight": 50,
    "colsample_bynode": 0.5,
}

See https://xgboost.readthedocs.io/en/latest/parameter.html for more options.

n_neighbors (Int, default: 30): Number of neighbors used to compute each Kaplan-Meier estimate.

radius (Float, default: None): If set, neighbors are searched within this radius around each sample instead of taking a fixed number of nearest neighbors.
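The defaults are described as being used only when xgb_params is not passed, and this page does not say whether a partial dict is merged with them. A cautious pattern, sketched below under that assumption, is to copy the full default dict and override only the keys you want to change, which also keeps the recommended dart booster. The helper `with_overrides` is hypothetical.

```python
import copy

# Default XGBoost parameters, copied from the defaults listed for xgb_params.
DEFAULT_PARAMS = {
    "objective": "survival:aft",
    "eval_metric": "aft-nloglik",
    "aft_loss_distribution": "normal",
    "aft_loss_distribution_scale": 1,
    "tree_method": "hist",
    "learning_rate": 5e-2,
    "max_depth": 8,
    "booster": "dart",
    "subsample": 0.5,
    "min_child_weight": 50,
    "colsample_bynode": 0.5,
}


def with_overrides(base, **overrides):
    """Return a copy of `base` with some keys replaced; `base` itself is unchanged."""
    params = copy.deepcopy(base)
    params.update(overrides)
    return params


# Shallower trees and a faster learning rate, keeping the recommended dart booster.
params = with_overrides(DEFAULT_PARAMS, max_depth=4, learning_rate=0.1)

# Hypothetical usage with the estimator documented on this page (requires xgbse):
# from xgbse import XGBSEKaplanNeighbors
# model = XGBSEKaplanNeighbors(xgb_params=params, n_neighbors=50)
```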

fit(self, X, y, num_boost_round=1000, validation_data=None, early_stopping_rounds=None, verbose_eval=0, persist_train=True, index_id=None, time_bins=None)

Transform the feature space by fitting an XGBoost model and outputting its leaf indices, then build a search index in this new space to allow nearest neighbor queries at scoring time.

Parameters:

X ([pd.DataFrame, np.array], required): Design matrix to fit the XGBoost model.

y (structured array (numpy.bool_, numpy.number), required): Binary event indicator as the first field, and time of event or time of censoring as the second field.

num_boost_round (Int, default: 1000): Number of boosting iterations.

validation_data (Tuple, default: None): Validation data as a tuple (X, y), needed if you want to use early stopping.

early_stopping_rounds (Int, default: None): Activates early stopping. The validation metric needs to improve at least once every early_stopping_rounds round(s) for training to continue. See the xgboost.train documentation.

verbose_eval ([Bool, Int], default: 0): Level of verbosity. See the xgboost.train documentation.

persist_train (Bool, default: True): Whether to persist the training data, which is needed for explainability through prototypes.

index_id (pd.Index, default: None): User-defined index for the training data, if you intend to use explainability through prototypes.

time_bins (np.array, default: None): Time windows to use when making survival predictions.

Returns:

XGBSEKaplanNeighbors: Fitted instance of XGBSEKaplanNeighbors.
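A minimal sketch of preparing y (and, optionally, time_bins) with NumPy. The field names `event` and `time` are assumptions, since the description of y only fixes the field order and types, and the data is made up. If scikit-survival is installed, its `Surv.from_arrays` helper builds an equivalent structured array.

```python
import numpy as np

# Hypothetical raw data: whether the event was observed (False = censored)
# and the time of the event or of censoring for each sample.
event = np.array([True, False, True, True, False])
time = np.array([5.0, 8.0, 3.5, 12.0, 9.0])

# Structured array with the boolean event indicator as the first field and the
# numeric time as the second field, matching the layout described for y.
y = np.empty(len(event), dtype=[("event", np.bool_), ("time", np.float64)])
y["event"] = event
y["time"] = time

# Optional time_bins: evenly spaced times up to the largest observed time
# (an arbitrary choice for illustration).
time_bins = np.linspace(1.0, time.max(), num=5)

print(y.dtype.names)  # ('event', 'time')
print(time_bins.tolist())  # [1.0, 3.75, 6.5, 9.25, 12.0]

# Hypothetical call, given a design matrix X and a model instance:
# model.fit(X, y, time_bins=time_bins)
```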

predict(self, X, time_bins=None, return_ci=False, ci_width=0.683, return_interval_probs=False)

Query the nearest neighbor search index built on the transformed XGBoost (leaf index) space, compute a Kaplan-Meier estimator for each neighbor set, and return the resulting KM survival estimates.
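To make the per-neighbor-set step concrete, here is a plain NumPy Kaplan-Meier estimator for a single neighbor set, evaluated at a set of time bins. It is an illustrative re-implementation, not xgbse's own code; `kaplan_meier` and the neighbor data are hypothetical. In the method described here, the times and events would be those of the training samples returned by the neighbor query.

```python
import numpy as np


def kaplan_meier(times, events, time_bins):
    """Kaplan-Meier estimate of S(t) for one neighbor set, evaluated at time_bins.

    times:  event or censoring time of each neighbor
    events: True if the neighbor's event was observed, False if censored
    """
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    event_times = np.unique(times[events])  # sorted distinct event times t_i

    # n_i: neighbors still at risk at t_i; d_i: events observed exactly at t_i.
    at_risk = np.array([np.sum(times >= t) for t in event_times], dtype=float)
    deaths = np.array([np.sum((times == t) & events) for t in event_times], dtype=float)

    # S(t) = product over t_i <= t of (1 - d_i / n_i); S = 1 before the first event.
    survival = np.concatenate([[1.0], np.cumprod(1.0 - deaths / at_risk)])

    # Step-function lookup: the count of event times <= each bin selects the step.
    idx = np.searchsorted(event_times, np.asarray(time_bins, dtype=float), side="right")
    return survival[idx]


# Hypothetical neighbor set: times and event indicators of 5 neighbors.
neighbor_times = [2, 3, 3, 5, 8]
neighbor_events = [True, True, False, True, False]
curve = kaplan_meier(neighbor_times, neighbor_events, time_bins=[1, 2, 4, 6])
print(np.round(curve, 3).tolist())  # [1.0, 0.8, 0.6, 0.3]
```

The neighbor censored at time 3 still counts as at risk at that time, which is why the second factor is 1 - 1/4 rather than 1 - 1/3.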

Parameters:

X (pd.DataFrame, required): DataFrame with the samples to generate predictions for.

time_bins (np.array, default: None): Time windows to use when making survival predictions.

return_ci (Bool, default: False): Whether to return confidence intervals computed with the Exponential Greenwood formula.

ci_width (Float, default: 0.683): Width (coverage) of the confidence interval. The default of 0.683 corresponds to roughly ±1 standard deviation under a normal approximation.

return_interval_probs (Bool, default: False): Whether to return interval probabilities. If False, cumulative survival probabilities are returned.
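A sketch of the Exponential (log-log) Greenwood interval for a single Kaplan-Meier curve, assuming ci_width is a two-sided coverage converted to a normal quantile z. The standard error of log(-log S(t)) is the square root of the Greenwood sum of d_i / (n_i (n_i - d_i)) divided by |log S(t)|, and the bounds are S(t) raised to exp(±z times that standard error). This is a textbook version for illustration, not necessarily the exact code behind return_ci; the function name and data are hypothetical.

```python
from statistics import NormalDist

import numpy as np


def km_exponential_greenwood(times, events, ci_width=0.683):
    """Kaplan-Meier curve with exponential (log-log) Greenwood confidence bounds,
    evaluated at each distinct observed event time.

    Assumes 0 < S(t) < 1 at every event time, i.e. the last neighbors still at
    risk are not all events (otherwise the log-log transform is undefined).
    """
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=bool)
    event_times = np.unique(times[events])
    at_risk = np.array([np.sum(times >= t) for t in event_times], dtype=float)
    deaths = np.array([np.sum((times == t) & events) for t in event_times], dtype=float)

    survival = np.cumprod(1.0 - deaths / at_risk)
    # Greenwood sum over event times t_i <= t of d_i / (n_i * (n_i - d_i)).
    greenwood = np.cumsum(deaths / (at_risk * (at_risk - deaths)))
    # Standard error of log(-log S(t)).
    se = np.sqrt(greenwood) / np.abs(np.log(survival))

    # Two-sided normal quantile for the requested coverage (about 1.0 for 0.683).
    z = NormalDist().inv_cdf(0.5 + ci_width / 2)
    # S lies in (0, 1), so a larger exponent gives a smaller value: lower bound first.
    lower = survival ** np.exp(z * se)
    upper = survival ** np.exp(-z * se)
    return event_times, survival, lower, upper


event_times, survival, lower, upper = km_exponential_greenwood(
    times=[2, 3, 3, 5, 8], events=[True, True, False, True, False]
)
```

Working on the log(-log S) scale keeps both bounds inside (0, 1), which a plain S ± z·SE interval does not guarantee for small neighbor sets.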

Returns:

(pd.DataFrame): A DataFrame of survival probabilities with one row per sample of X and one column per time in the time_bins array. If return_interval_probs is True, interval probabilities are returned instead of cumulative survival probabilities.

upper_ci (np.array): Upper confidence interval for the survival probability values (only when return_ci is True).

lower_ci (np.array): Lower confidence interval for the survival probability values (only when return_ci is True).
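Interval probabilities relate to the cumulative curve by differencing: the probability that the event falls in (t_{k-1}, t_k] is S(t_{k-1}) - S(t_k). The sketch below applies that relation with an assumed convention of S = 1 before the first bin. This page does not specify how the boundary bins are handled, so treat it as illustrative; the function name is hypothetical.

```python
import numpy as np


def interval_probs_from_survival(survival):
    """Turn cumulative survival S(t_k) (rows: samples, columns: increasing time
    bins) into the probability of the event falling in each interval.

    Assumed convention: S = 1 before the first bin, so column k holds
    S(t_{k-1}) - S(t_k), and column 0 holds 1 - S(t_0).
    """
    survival = np.asarray(survival, dtype=float)
    # Survival at the start of each interval: 1 before the first bin, then S(t_{k-1}).
    previous = np.hstack([np.ones((survival.shape[0], 1)), survival[:, :-1]])
    return previous - survival


survival = np.array([[1.0, 0.8, 0.6, 0.3],
                     [0.9, 0.9, 0.5, 0.1]])
probs = interval_probs_from_survival(survival)
print(np.round(probs, 3).tolist())
# [[0.0, 0.2, 0.2, 0.3], [0.1, 0.0, 0.4, 0.4]]
```

Because the differences telescope, each row sums to 1 - S(t_last); the remaining mass is the probability of surviving past the last bin.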