xgbse._kaplan_neighbors.XGBSEKaplanNeighbors

Convert XGBoost into a nearest-neighbor model, where we use Hamming distance to define similar elements as the ones that co-occurred the most at the ensemble's terminal nodes.

Then, for each neighbor set, compute survival estimates with the Kaplan-Meier estimator.
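For intuition, here is a small sketch (with hypothetical leaf indices) of the distance involved: the Hamming distance between two samples' leaf-occupancy vectors is the fraction of trees in which they do not share a terminal node.

```python
import numpy as np

# hypothetical terminal-leaf indices of two samples across a 5-tree ensemble
leaves_a = np.array([3, 7, 2, 9, 4])
leaves_b = np.array([3, 7, 5, 9, 4])

# Hamming distance: fraction of trees where the samples land in different leaves
dist = np.mean(leaves_a != leaves_b)  # 0.2 -> they co-occur in 4 of the 5 trees
```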

Note

  • We recommend using dart as the booster to prevent any single tree from dominating variance in the ensemble and breaking the leaf co-occurrence similarity logic.

  • This method can be very expensive at the scale of hundreds of thousands of samples, due to the nearest-neighbor search, both at training time (construction of the search index) and at scoring time (the actual search).

Read more in How XGBSE works.
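Below is a minimal usage sketch with synthetic data (the data, time bins, and parameter values are illustrative assumptions, not recommendations); it follows the dart recommendation from the note above:

```python
import numpy as np
import pandas as pd
from xgbse import XGBSEKaplanNeighbors

# synthetic toy data, for illustration only
rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(200, 5)), columns=[f"f{i}" for i in range(5)])

# structured array: event indicator as first field, time as second (see fit below)
y = np.empty(200, dtype=[("event", "?"), ("time", "<f8")])
y["event"] = rng.random(200) < 0.7
y["time"] = rng.exponential(10.0, size=200)

model = XGBSEKaplanNeighbors(
    xgb_params={"objective": "survival:aft", "booster": "dart"},
    n_neighbors=30,
)
model.fit(X, y, time_bins=np.arange(1, 31))

surv = model.predict(X)  # rows: samples of X, columns: time bins
```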

Source code in xgbse/_kaplan_neighbors.py
class XGBSEKaplanNeighbors(XGBSEBaseEstimator):
    """
    Convert XGBoost into a nearest-neighbor model, where we use Hamming distance to define
    similar elements as the ones that co-occurred the most at the ensemble's terminal nodes.

    Then, for each neighbor set, compute survival estimates with the Kaplan-Meier estimator.

    !!! Note
        * We recommend using dart as the booster to prevent any single tree
        from dominating variance in the ensemble and breaking the leaf co-occurrence similarity logic.

        * This method can be very expensive at the scale of hundreds of thousands of samples,
        due to the nearest-neighbor search, both at training time (construction of the search index) and at scoring time (the actual search).

    Read more in [How XGBSE works](https://loft-br.github.io/xgboost-survival-embeddings/how_xgbse_works.html).
    """

    def __init__(
        self,
        xgb_params: Optional[Dict[str, Any]] = None,
        n_neighbors: int = 30,
        radius: Optional[float] = None,
        enable_categorical: bool = False,
    ):
        """
        Args:
            xgb_params (Dict, None): Parameters for XGBoost model.
                If None, will use XGBoost defaults and set objective as `survival:aft`.
                Check <https://xgboost.readthedocs.io/en/latest/parameter.html> for options.

            n_neighbors (Int): Number of neighbors for computing KM estimates

            radius (Float): If set, uses a radius around the point for neighbors search

            enable_categorical (bool): Enable categorical feature support on xgboost model
        """

        super().__init__(xgb_params=xgb_params, enable_categorical=enable_categorical)
        self.n_neighbors = n_neighbors
        self.radius = radius
        self.index_id = None

    def fit(
        self,
        X,
        y,
        time_bins: Optional[Sequence] = None,
        validation_data: Optional[List[Tuple[Any, Any]]] = None,
        num_boost_round: int = 10,
        early_stopping_rounds: Optional[int] = None,
        verbose_eval: int = 0,
        persist_train: bool = False,
        index_id=None,
    ):
        """
        Transform the feature space by fitting an XGBoost model and outputting its leaf indices.
        Build a search index in the new space to allow nearest-neighbor queries at scoring time.

        Args:
            X ([pd.DataFrame, np.array]): Features to be used while fitting XGBoost model

            y (structured array(numpy.bool_, numpy.number)): Binary event indicator as first field,
                and time of event or time of censoring as second field.

            time_bins (np.array): Specified time windows to use when making survival predictions

            validation_data (List[Tuple]): Validation data in the format of a list of tuples [(X, y)]
                if user desires to use early stopping

            num_boost_round (Int): Number of boosting iterations.

            early_stopping_rounds (Int): Activates early stopping.
                Validation metric needs to improve at least once
                in every **early_stopping_rounds** round(s) to continue training.
                See xgboost.train documentation.

            verbose_eval ([Bool, Int]): Level of verbosity. See xgboost.train documentation.

            persist_train (Bool): Whether or not to persist training data to use explainability
                through prototypes

            index_id (pd.Index): User defined index if intended to use explainability
                through prototypes


        Returns:
            XGBSEKaplanNeighbors: Fitted instance of XGBSEKaplanNeighbors
        """

        self.fit_feature_extractor(
            X,
            y,
            time_bins=time_bins,
            validation_data=validation_data,
            num_boost_round=num_boost_round,
            early_stopping_rounds=early_stopping_rounds,
            verbose_eval=verbose_eval,
        )

        self.E_train, self.T_train = convert_y(y)

        # creating nearest neighbor index
        leaves = self.feature_extractor.predict_leaves(X)

        self.tree = BallTree(leaves, metric="hamming", leaf_size=40)

        if persist_train:
            self.persist_train = True
            if index_id is None:
                index_id = X.index.copy()
        self.index_id = index_id

        return self

    def predict(
        self,
        X,
        time_bins=None,
        return_ci=False,
        ci_width=0.683,
        return_interval_probs=False,
    ):
        """
        Query the nearest-neighbor search index built on the transformed XGBoost space.
        Compute a Kaplan-Meier estimator for each neighbor set and return its survival estimates.

        Args:
            X (pd.DataFrame): Dataframe with samples to generate predictions

            time_bins (np.array): Specified time windows to use when making survival predictions

            return_ci (Bool): Whether to return confidence intervals via the Exponential Greenwood formula

            ci_width (Float): Width of confidence interval

            return_interval_probs (Bool): Boolean indicating if interval probabilities are
                supposed to be returned. If False the cumulative survival is returned.


        Returns:
            (pd.DataFrame): A dataframe of survival probabilities
            for all times (columns), from a time_bins array, for all samples of X
            (rows). If return_interval_probs is True, the interval probabilities are returned
            instead of the cumulative survival probabilities.

            upper_ci (np.array): Upper confidence interval for the survival
            probability values

            lower_ci (np.array): Lower confidence interval for the survival
            probability values
        """

        leaves = self.feature_extractor.predict_leaves(X)

        if self.radius:
            assert self.radius > 0, "Radius must be positive"

            neighs, _ = self.tree.query_radius(
                leaves, r=self.radius, return_distance=True
            )

            number_of_neighbors = np.array([len(neigh) for neigh in neighs])

            if np.argwhere(number_of_neighbors == 1).shape[0] > 0:
                # if at least one sample has no neighbors other than itself,
                # a warning is raised suggesting a radius increase
                warnings.warn(
                    "Some samples have no neighbors other than themselves. Increase the radius.",
                    RuntimeWarning,
                )
        else:
            _, neighs = self.tree.query(leaves, k=self.n_neighbors)

        # gathering times and events/censors for neighbor sets
        T_neighs = self.T_train[neighs]
        E_neighs = self.E_train[neighs]

        # vectorized (very fast!) implementation of Kaplan Meier curves
        if time_bins is None:
            time_bins = self.time_bins

        # calculating z-score from width
        z = st.norm.ppf(0.5 + ci_width / 2)

        preds_df, upper_ci, lower_ci = calculate_kaplan_vectorized(
            T_neighs, E_neighs, time_bins, z
        )

        if return_ci and return_interval_probs:
            raise ValueError(
                "Confidence intervals for interval probabilities is not supported. Choose between return_ci and return_interval_probs."
            )

        if return_interval_probs:
            preds_df = calculate_interval_failures(preds_df)
            return preds_df

        if return_ci:
            return preds_df, upper_ci, lower_ci

        return preds_df

__init__(self, xgb_params=None, n_neighbors=30, radius=None, enable_categorical=False) special

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `xgb_params` | `Dict, None` | Parameters for the XGBoost model. If None, will use XGBoost defaults and set the objective as `survival:aft`. Check https://xgboost.readthedocs.io/en/latest/parameter.html for options. | `None` |
| `n_neighbors` | `Int` | Number of neighbors for computing KM estimates. | `30` |
| `radius` | `Float` | If set, uses a radius around the point for the neighbors search. | `None` |
| `enable_categorical` | `bool` | Enable categorical feature support on the XGBoost model. | `False` |
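A brief construction sketch (the radius value is a hypothetical choice, not a recommendation): with the Hamming metric used by the search index, `radius=0.2` keeps neighbors that land in the same terminal node as the query in at least 80% of the trees.

```python
from xgbse import XGBSEKaplanNeighbors

model = XGBSEKaplanNeighbors(
    xgb_params={"objective": "survival:aft", "booster": "dart"},
    radius=0.2,  # hypothetical: neighbors differing in at most 20% of trees
)
```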
Source code in xgbse/_kaplan_neighbors.py
def __init__(
    self,
    xgb_params: Optional[Dict[str, Any]] = None,
    n_neighbors: int = 30,
    radius: Optional[float] = None,
    enable_categorical: bool = False,
):
    """
    Args:
        xgb_params (Dict, None): Parameters for XGBoost model.
            If None, will use XGBoost defaults and set objective as `survival:aft`.
            Check <https://xgboost.readthedocs.io/en/latest/parameter.html> for options.

        n_neighbors (Int): Number of neighbors for computing KM estimates

        radius (Float): If set, uses a radius around the point for neighbors search

        enable_categorical (bool): Enable categorical feature support on xgboost model
    """

    super().__init__(xgb_params=xgb_params, enable_categorical=enable_categorical)
    self.n_neighbors = n_neighbors
    self.radius = radius
    self.index_id = None

fit(self, X, y, time_bins=None, validation_data=None, num_boost_round=10, early_stopping_rounds=None, verbose_eval=0, persist_train=False, index_id=None)

Transform the feature space by fitting an XGBoost model and outputting its leaf indices. Build a search index in the new space to allow nearest-neighbor queries at scoring time.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `[pd.DataFrame, np.array]` | Features to be used while fitting the XGBoost model. | required |
| `y` | `structured array(numpy.bool_, numpy.number)` | Binary event indicator as first field, and time of event or time of censoring as second field. | required |
| `time_bins` | `np.array` | Specified time windows to use when making survival predictions. | `None` |
| `validation_data` | `List[Tuple]` | Validation data in the format of a list of tuples `[(X, y)]`, if the user desires to use early stopping. | `None` |
| `num_boost_round` | `Int` | Number of boosting iterations. | `10` |
| `early_stopping_rounds` | `Int` | Activates early stopping. The validation metric needs to improve at least once in every `early_stopping_rounds` round(s) to continue training. See the `xgboost.train` documentation. | `None` |
| `verbose_eval` | `[Bool, Int]` | Level of verbosity. See the `xgboost.train` documentation. | `0` |
| `persist_train` | `Bool` | Whether or not to persist training data to use explainability through prototypes. | `False` |
| `index_id` | `pd.Index` | User-defined index, if intending to use explainability through prototypes. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `XGBSEKaplanNeighbors` | Fitted instance of `XGBSEKaplanNeighbors`. |
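For illustration, a hedged sketch of fitting with early stopping (it assumes `X_train`, `y_train`, `X_val`, `y_val` already exist from your own split, with `y_*` as structured arrays of event indicator and time):

```python
import numpy as np
from xgbse import XGBSEKaplanNeighbors

# X_train / y_train / X_val / y_val are assumed to come from your own split
model = XGBSEKaplanNeighbors(
    xgb_params={"objective": "survival:aft", "booster": "dart"}
)
model.fit(
    X_train,
    y_train,
    time_bins=np.arange(1, 31),        # illustrative time windows
    validation_data=[(X_val, y_val)],  # enables the early-stopping metric
    num_boost_round=1000,
    early_stopping_rounds=10,
    persist_train=True,                # keep training data for prototype explanations
)
```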

Source code in xgbse/_kaplan_neighbors.py
def fit(
    self,
    X,
    y,
    time_bins: Optional[Sequence] = None,
    validation_data: Optional[List[Tuple[Any, Any]]] = None,
    num_boost_round: int = 10,
    early_stopping_rounds: Optional[int] = None,
    verbose_eval: int = 0,
    persist_train: bool = False,
    index_id=None,
):
    """
    Transform the feature space by fitting an XGBoost model and outputting its leaf indices.
    Build a search index in the new space to allow nearest-neighbor queries at scoring time.

    Args:
        X ([pd.DataFrame, np.array]): Features to be used while fitting XGBoost model

        y (structured array(numpy.bool_, numpy.number)): Binary event indicator as first field,
            and time of event or time of censoring as second field.

        time_bins (np.array): Specified time windows to use when making survival predictions

        validation_data (List[Tuple]): Validation data in the format of a list of tuples [(X, y)]
            if user desires to use early stopping

        num_boost_round (Int): Number of boosting iterations.

        early_stopping_rounds (Int): Activates early stopping.
            Validation metric needs to improve at least once
            in every **early_stopping_rounds** round(s) to continue training.
            See xgboost.train documentation.

        verbose_eval ([Bool, Int]): Level of verbosity. See xgboost.train documentation.

        persist_train (Bool): Whether or not to persist training data to use explainability
            through prototypes

        index_id (pd.Index): User defined index if intended to use explainability
            through prototypes


    Returns:
        XGBSEKaplanNeighbors: Fitted instance of XGBSEKaplanNeighbors
    """

    self.fit_feature_extractor(
        X,
        y,
        time_bins=time_bins,
        validation_data=validation_data,
        num_boost_round=num_boost_round,
        early_stopping_rounds=early_stopping_rounds,
        verbose_eval=verbose_eval,
    )

    self.E_train, self.T_train = convert_y(y)

    # creating nearest neighbor index
    leaves = self.feature_extractor.predict_leaves(X)

    self.tree = BallTree(leaves, metric="hamming", leaf_size=40)

    if persist_train:
        self.persist_train = True
        if index_id is None:
            index_id = X.index.copy()
    self.index_id = index_id

    return self

predict(self, X, time_bins=None, return_ci=False, ci_width=0.683, return_interval_probs=False)

Query the nearest-neighbor search index built on the transformed XGBoost space. Compute a Kaplan-Meier estimator for each neighbor set and return its survival estimates.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `X` | `pd.DataFrame` | Dataframe with samples to generate predictions. | required |
| `time_bins` | `np.array` | Specified time windows to use when making survival predictions. | `None` |
| `return_ci` | `Bool` | Whether to return confidence intervals via the Exponential Greenwood formula. | `False` |
| `ci_width` | `Float` | Width of the confidence interval. | `0.683` |
| `return_interval_probs` | `Bool` | Whether interval probabilities should be returned. If False, the cumulative survival is returned. | `False` |

Returns:

| Type | Description |
| --- | --- |
| `preds_df (pd.DataFrame)` | A dataframe of survival probabilities for all times (columns), from a `time_bins` array, for all samples of `X` (rows). If `return_interval_probs` is True, interval probabilities are returned instead of cumulative survival probabilities. |
| `upper_ci (np.array)` | Upper confidence interval for the survival probability values (returned only when `return_ci` is True). |
| `lower_ci (np.array)` | Lower confidence interval for the survival probability values (returned only when `return_ci` is True). |
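A usage sketch continuing the fitting example above (`model` and `X_val` are assumed from there); note the default `ci_width=0.683` corresponds to z ≈ 1, i.e. roughly one standard deviation:

```python
# cumulative survival curves with 95% confidence bounds
surv, upper_ci, lower_ci = model.predict(X_val, return_ci=True, ci_width=0.95)

# per-interval failure probabilities instead of cumulative survival;
# note this cannot be combined with return_ci
interval_probs = model.predict(X_val, return_interval_probs=True)
```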

Source code in xgbse/_kaplan_neighbors.py
def predict(
    self,
    X,
    time_bins=None,
    return_ci=False,
    ci_width=0.683,
    return_interval_probs=False,
):
    """
    Query the nearest-neighbor search index built on the transformed XGBoost space.
    Compute a Kaplan-Meier estimator for each neighbor set and return its survival estimates.

    Args:
        X (pd.DataFrame): Dataframe with samples to generate predictions

        time_bins (np.array): Specified time windows to use when making survival predictions

        return_ci (Bool): Whether to return confidence intervals via the Exponential Greenwood formula

        ci_width (Float): Width of confidence interval

        return_interval_probs (Bool): Boolean indicating if interval probabilities are
            supposed to be returned. If False the cumulative survival is returned.


    Returns:
        (pd.DataFrame): A dataframe of survival probabilities
        for all times (columns), from a time_bins array, for all samples of X
        (rows). If return_interval_probs is True, the interval probabilities are returned
        instead of the cumulative survival probabilities.

        upper_ci (np.array): Upper confidence interval for the survival
        probability values

        lower_ci (np.array): Lower confidence interval for the survival
        probability values
    """

    leaves = self.feature_extractor.predict_leaves(X)

    if self.radius:
        assert self.radius > 0, "Radius must be positive"

        neighs, _ = self.tree.query_radius(
            leaves, r=self.radius, return_distance=True
        )

        number_of_neighbors = np.array([len(neigh) for neigh in neighs])

        if np.argwhere(number_of_neighbors == 1).shape[0] > 0:
            # if at least one sample has no neighbors other than itself,
            # a warning is raised suggesting a radius increase
            warnings.warn(
                "Some samples have no neighbors other than themselves. Increase the radius.",
                RuntimeWarning,
            )
    else:
        _, neighs = self.tree.query(leaves, k=self.n_neighbors)

    # gathering times and events/censors for neighbor sets
    T_neighs = self.T_train[neighs]
    E_neighs = self.E_train[neighs]

    # vectorized (very fast!) implementation of Kaplan Meier curves
    if time_bins is None:
        time_bins = self.time_bins

    # calculating z-score from width
    z = st.norm.ppf(0.5 + ci_width / 2)

    preds_df, upper_ci, lower_ci = calculate_kaplan_vectorized(
        T_neighs, E_neighs, time_bins, z
    )

    if return_ci and return_interval_probs:
        raise ValueError(
            "Confidence intervals for interval probabilities is not supported. Choose between return_ci and return_interval_probs."
        )

    if return_interval_probs:
        preds_df = calculate_interval_failures(preds_df)
        return preds_df

    if return_ci:
        return preds_df, upper_ci, lower_ci

    return preds_df

set_fit_request(self, *, early_stopping_rounds='$UNCHANGED$', index_id='$UNCHANGED$', num_boost_round='$UNCHANGED$', persist_train='$UNCHANGED$', time_bins='$UNCHANGED$', validation_data='$UNCHANGED$', verbose_eval='$UNCHANGED$')

Request metadata passed to the fit method.

Note that this method is only relevant if `enable_metadata_routing=True` (see `sklearn.set_config`). Please see the scikit-learn User Guide on metadata routing for how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

New in version 1.3.

Note: This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. inside a `sklearn.pipeline.Pipeline`. Otherwise it has no effect.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `early_stopping_rounds` | `str, True, False, or None` | Metadata routing for the `early_stopping_rounds` parameter in `fit`. | `sklearn.utils.metadata_routing.UNCHANGED` |
| `index_id` | `str, True, False, or None` | Metadata routing for the `index_id` parameter in `fit`. | `sklearn.utils.metadata_routing.UNCHANGED` |
| `num_boost_round` | `str, True, False, or None` | Metadata routing for the `num_boost_round` parameter in `fit`. | `sklearn.utils.metadata_routing.UNCHANGED` |
| `persist_train` | `str, True, False, or None` | Metadata routing for the `persist_train` parameter in `fit`. | `sklearn.utils.metadata_routing.UNCHANGED` |
| `time_bins` | `str, True, False, or None` | Metadata routing for the `time_bins` parameter in `fit`. | `sklearn.utils.metadata_routing.UNCHANGED` |
| `validation_data` | `str, True, False, or None` | Metadata routing for the `validation_data` parameter in `fit`. | `sklearn.utils.metadata_routing.UNCHANGED` |
| `verbose_eval` | `str, True, False, or None` | Metadata routing for the `verbose_eval` parameter in `fit`. | `sklearn.utils.metadata_routing.UNCHANGED` |

Returns:

| Type | Description |
| --- | --- |
| `self (object)` | The updated object. |
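A short sketch of setting requests (it applies equally to `set_predict_request` below); routing must be enabled globally first, and the calls only matter when the estimator sits inside a meta-estimator:

```python
import sklearn
from xgbse import XGBSEKaplanNeighbors

# metadata routing must be enabled globally, otherwise these calls raise
sklearn.set_config(enable_metadata_routing=True)

# request that persist_train be routed to fit() and return_ci to predict()
# when this estimator is wrapped by a meta-estimator
model = (
    XGBSEKaplanNeighbors()
    .set_fit_request(persist_train=True)
    .set_predict_request(return_ci=True)
)
```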


set_predict_request(self, *, ci_width='$UNCHANGED$', return_ci='$UNCHANGED$', return_interval_probs='$UNCHANGED$', time_bins='$UNCHANGED$')

Request metadata passed to the predict method.

Note that this method is only relevant if `enable_metadata_routing=True` (see `sklearn.set_config`). Please see the scikit-learn User Guide on metadata routing for how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to predict if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to predict.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

New in version 1.3.

Note: This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. inside a `sklearn.pipeline.Pipeline`. Otherwise it has no effect.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `ci_width` | `str, True, False, or None` | Metadata routing for the `ci_width` parameter in `predict`. | `sklearn.utils.metadata_routing.UNCHANGED` |
| `return_ci` | `str, True, False, or None` | Metadata routing for the `return_ci` parameter in `predict`. | `sklearn.utils.metadata_routing.UNCHANGED` |
| `return_interval_probs` | `str, True, False, or None` | Metadata routing for the `return_interval_probs` parameter in `predict`. | `sklearn.utils.metadata_routing.UNCHANGED` |
| `time_bins` | `str, True, False, or None` | Metadata routing for the `time_bins` parameter in `predict`. | `sklearn.utils.metadata_routing.UNCHANGED` |

Returns:

| Type | Description |
| --- | --- |
| `self (object)` | The updated object. |
