Skip to content

xgbse._kaplan_neighbors.XGBSEKaplanTree

Single-tree simplification of XGBSEKaplanNeighbors. Instead of doing nearest-neighbor searches, it fits a single tree via xgboost and calculates Kaplan-Meier (KM) curves at each of its leaves.

Note

  • It is by far the most efficient implementation, able to scale to millions of examples easily. At fit time, the tree is built and all KM curves are pre-calculated, so that at scoring time a simple query will suffice to get the model's estimates.

Read more in How XGBSE works.

Source code in xgbse/_kaplan_neighbors.py
class XGBSEKaplanTree(XGBSEBaseEstimator):
    """
    Single tree implementation as a simplification to `XGBSEKaplanNeighbors`.
    Instead of doing nearest neighbor searches, fits a single tree via `xgboost`
    and calculates Kaplan-Meier (KM) curves at each of its leaves.

    !!! Note
        * It is by far the most efficient implementation, able to scale to millions of examples easily.
        At fit time, the tree is built and all KM curves are pre-calculated,
        so that at scoring time a simple query will suffice to get the model's estimates.

    Read more in [How XGBSE works](https://loft-br.github.io/xgboost-survival-embeddings/how_xgbse_works.html).
    """

    def __init__(
        self,
        xgb_params: Optional[Dict[str, Any]] = None,
        enable_categorical: bool = False,
    ):
        """
        Args:
            xgb_params (Dict): Parameters for XGBoost model.
                If not passed, a copy of the following default parameters will be used:

                ```
                DEFAULT_PARAMS_TREE = {
                    "objective": "survival:cox",
                    "eval_metric": "cox-nloglik",
                    "tree_method": "hist",
                    "max_depth": 100,
                    "booster": "dart",
                    "subsample": 1.0,
                    "min_child_weight": 30,
                    "colsample_bynode": 1.0,
                }
                ```

                Check <https://xgboost.readthedocs.io/en/latest/parameter.html> for more options.

            enable_categorical (Bool): Forwarded unchanged to the base estimator
                (`XGBSEBaseEstimator`), which defines how it is used.
        """
        if xgb_params is None:
            # Take a copy: storing the module-level dict itself would let any
            # in-place change to one instance's params leak into the defaults
            # shared by every other instance.
            xgb_params = dict(DEFAULT_PARAMS_TREE)

        super().__init__(xgb_params=xgb_params, enable_categorical=enable_categorical)
        # Set by `fit`; index used for explainability through prototypes.
        self.index_id = None

    def fit(
        self,
        X,
        y,
        persist_train: bool = True,
        index_id=None,
        time_bins: Optional[Sequence] = None,
        ci_width: float = 0.683,
    ):
        """
        Fit a single decision tree using xgboost. For each leaf in the tree,
        build a Kaplan-Meier estimator.

        !!! Note
            * Unlike `XGBSEKaplanNeighbors`, `XGBSEKaplanTree` requires the
            width of the confidence interval (`ci_width`) to be specified at fit time.

        Args:

            X ([pd.DataFrame, np.array]): Design matrix to fit XGBoost model

            y (structured array(numpy.bool_, numpy.number)): Binary event indicator as first field,
                and time of event or time of censoring as second field.

            persist_train (Bool): Whether or not to persist training data to use explainability
                through prototypes

            index_id (pd.Index): User defined index if intended to use explainability
                through prototypes. If None and `persist_train` is True, the index of
                `X` is used, or a `pd.RangeIndex` when `X` has no pandas index
                (e.g. a numpy array).

            time_bins (np.array): Specified time windows to use when making survival predictions

            ci_width (Float): Width of confidence interval, strictly between 0 and 1.
                The default 0.683 gives a z-score of about 1.

        Returns:
            XGBSEKaplanTree: Trained instance of XGBSEKaplanTree

        Raises:
            ValueError: If `ci_width` is not strictly between 0 and 1.
        """
        # Validate before training: outside (0, 1) the z-score below would be
        # NaN/inf and silently yield meaningless confidence intervals.
        if not 0 < ci_width < 1:
            raise ValueError(
                f"ci_width must be strictly between 0 and 1, got {ci_width!r}."
            )

        # A single boosting round, i.e. a single tree.
        self.feature_extractor.fit(
            X,
            y,
            time_bins=time_bins,
            num_boost_round=1,
        )
        self.feature_importances_ = self.feature_extractor.feature_importances_

        E_train, T_train = convert_y(y)

        self.time_bins = self.feature_extractor.time_bins
        # Leaf id reached by each training sample.
        leaves = self.feature_extractor.predict_leaves(X)

        # Positional row numbers of the training samples in each leaf, indexed
        # by leaf id. Grouping a plain Series avoids the pandas>=2.2
        # DeprecationWarning of DataFrameGroupBy.apply on grouping columns.
        leaf_neighs = (
            pd.Series(leaves, name="leaf")
            .reset_index()
            .groupby("leaf")["index"]
            .apply(list)
        )

        # Times and events aligned per leaf (layout defined by _align_leaf_target).
        T_leaves = _align_leaf_target(leaf_neighs, T_train)
        E_leaves = _align_leaf_target(leaf_neighs, E_train)

        # Two-sided z-score for the requested interval width.
        z = st.norm.ppf(0.5 + ci_width / 2)

        # Vectorized Kaplan-Meier curves and CI bounds for all leaves at once.
        (
            self._train_survival,
            self._train_upper_ci,
            self._train_lower_ci,
        ) = calculate_kaplan_vectorized(T_leaves, E_leaves, self.time_bins, z)

        # Index each table by leaf id so `predict` can look curves up with .loc.
        self._train_survival = self._train_survival.set_index(leaf_neighs.index)
        self._train_upper_ci = self._train_upper_ci.set_index(leaf_neighs.index)
        self._train_lower_ci = self._train_lower_ci.set_index(leaf_neighs.index)

        if persist_train:
            self.persist_train = True
            if index_id is None:
                # numpy input has no `.index`; fall back to positional ids.
                x_index = getattr(X, "index", None)
                index_id = (
                    x_index.copy()
                    if isinstance(x_index, pd.Index)
                    else pd.RangeIndex(len(X))
                )
            # 1-D index over training leaf ids. With the Hamming metric, two
            # samples are at distance 0 exactly when they share a leaf.
            self.tree = BallTree(leaves.reshape(-1, 1), metric="hamming", leaf_size=40)
        else:
            # Do not keep reporting a stale True from a previous fit.
            self.persist_train = False
        self.index_id = index_id

        return self

    def predict(self, X, return_ci=False, return_interval_probs=False):
        """
        Run samples through the tree down to its terminal nodes. Predict the
        Kaplan-Meier estimator associated with the leaf node each sample ends up in.

        Args:
            X (pd.DataFrame): Data frame with samples to generate predictions

            return_ci (Bool): Whether to return confidence intervals via the Exponential Greenwood formula

            return_interval_probs (Bool): Boolean indicating if interval probabilities are
                supposed to be returned. If False the cumulative survival is returned.


        Returns:
            preds_df (pd.DataFrame): A dataframe of survival probabilities
                for all times (columns), from a time_bins array, for all samples of X
                (rows). If return_interval_probs is True, the interval probabilities are returned
                instead of the cumulative survival probabilities.

            upper_ci (pd.DataFrame): Upper confidence interval for the survival
                probability values (returned only when `return_ci` is True)

            lower_ci (pd.DataFrame): Lower confidence interval for the survival
                probability values (returned only when `return_ci` is True)

        Raises:
            ValueError: If both `return_ci` and `return_interval_probs` are True.
        """
        # Reject the unsupported flag combination before doing any work.
        if return_ci and return_interval_probs:
            raise ValueError(
                "Confidence intervals for interval probabilities is not supported. Choose between return_ci and return_interval_probs."
            )

        # Leaf id reached by each sample in the single fitted tree.
        leaves = self.feature_extractor.predict_leaves(X)

        # Look up each sample's precomputed leaf curve; reset to positional rows.
        preds_df = self._train_survival.loc[leaves].reset_index(drop=True)

        if return_interval_probs:
            return calculate_interval_failures(preds_df)

        if return_ci:
            # CI tables are only looked up when they will actually be returned.
            upper_ci = self._train_upper_ci.loc[leaves].reset_index(drop=True)
            lower_ci = self._train_lower_ci.loc[leaves].reset_index(drop=True)
            return preds_df, upper_ci, lower_ci
        return preds_df

__init__(self, xgb_params=None, enable_categorical=False) special

Parameters:

Name Type Description Default
xgb_params Dict

Parameters for XGBoost model. If not passed, the following default parameters will be used:

DEFAULT_PARAMS_TREE = {
    "objective": "survival:cox",
    "eval_metric": "cox-nloglik",
    "tree_method": "hist",
    "max_depth": 100,
    "booster": "dart",
    "subsample": 1.0,
    "min_child_weight": 30,
    "colsample_bynode": 1.0,
}

Check https://xgboost.readthedocs.io/en/latest/parameter.html for more options.

None
Source code in xgbse/_kaplan_neighbors.py
def __init__(
    self,
    xgb_params: Optional[Dict[str, Any]] = None,
    enable_categorical: bool = False,
):
    """
    Args:
        xgb_params (Dict): Parameters for XGBoost model.
            If not passed, a copy of the following default parameters will be used:

            ```
            DEFAULT_PARAMS_TREE = {
                "objective": "survival:cox",
                "eval_metric": "cox-nloglik",
                "tree_method": "hist",
                "max_depth": 100,
                "booster": "dart",
                "subsample": 1.0,
                "min_child_weight": 30,
                "colsample_bynode": 1.0,
            }
            ```

            Check <https://xgboost.readthedocs.io/en/latest/parameter.html> for more options.

        enable_categorical (Bool): Forwarded unchanged to the base estimator
            (`XGBSEBaseEstimator`), which defines how it is used.
    """
    if xgb_params is None:
        # Take a copy: storing the module-level dict itself would let any
        # in-place change to one instance's params leak into the defaults
        # shared by every other instance.
        xgb_params = dict(DEFAULT_PARAMS_TREE)

    super().__init__(xgb_params=xgb_params, enable_categorical=enable_categorical)
    # Set by `fit`; index used for explainability through prototypes.
    self.index_id = None

fit(self, X, y, persist_train=True, index_id=None, time_bins=None, ci_width=0.683)

Fit a single decision tree using xgboost. For each leaf in the tree, build a Kaplan-Meier estimator.

Note

  • Unlike XGBSEKaplanNeighbors, XGBSEKaplanTree requires the width of the confidence interval (ci_width) to be specified at fit time.

Parameters:

Name Type Description Default
X [pd.DataFrame, np.array]

Design matrix to fit XGBoost model

required
y structured array(numpy.bool_, numpy.number)

Binary event indicator as first field, and time of event or time of censoring as second field.

required
persist_train Bool

Whether or not to persist training data to use explainability through prototypes

True
index_id pd.Index

User defined index if intended to use explainability through prototypes

None
time_bins np.array

Specified time windows to use when making survival predictions

None
ci_width Float

Width of confidence interval

0.683

Returns:

Type Description
XGBSEKaplanTree

Trained instance of XGBSEKaplanTree

Source code in xgbse/_kaplan_neighbors.py
def fit(
    self,
    X,
    y,
    persist_train: bool = True,
    index_id=None,
    time_bins: Optional[Sequence] = None,
    ci_width: float = 0.683,
):
    """
    Fit a single decision tree using xgboost. For each leaf in the tree,
    build a Kaplan-Meier estimator.

    !!! Note
        * Unlike `XGBSEKaplanNeighbors`, `XGBSEKaplanTree` requires the
        width of the confidence interval (`ci_width`) to be specified at fit time.

    Args:

        X ([pd.DataFrame, np.array]): Design matrix to fit XGBoost model

        y (structured array(numpy.bool_, numpy.number)): Binary event indicator as first field,
            and time of event or time of censoring as second field.

        persist_train (Bool): Whether or not to persist training data to use explainability
            through prototypes

        index_id (pd.Index): User defined index if intended to use explainability
            through prototypes. If None and `persist_train` is True, the index of
            `X` is used, or a `pd.RangeIndex` when `X` has no pandas index
            (e.g. a numpy array).

        time_bins (np.array): Specified time windows to use when making survival predictions

        ci_width (Float): Width of confidence interval, strictly between 0 and 1.
            The default 0.683 gives a z-score of about 1.

    Returns:
        XGBSEKaplanTree: Trained instance of XGBSEKaplanTree

    Raises:
        ValueError: If `ci_width` is not strictly between 0 and 1.
    """
    # Validate before training: outside (0, 1) the z-score below would be
    # NaN/inf and silently yield meaningless confidence intervals.
    if not 0 < ci_width < 1:
        raise ValueError(
            f"ci_width must be strictly between 0 and 1, got {ci_width!r}."
        )

    # A single boosting round, i.e. a single tree.
    self.feature_extractor.fit(
        X,
        y,
        time_bins=time_bins,
        num_boost_round=1,
    )
    self.feature_importances_ = self.feature_extractor.feature_importances_

    E_train, T_train = convert_y(y)

    self.time_bins = self.feature_extractor.time_bins
    # Leaf id reached by each training sample.
    leaves = self.feature_extractor.predict_leaves(X)

    # Positional row numbers of the training samples in each leaf, indexed
    # by leaf id. Grouping a plain Series avoids the pandas>=2.2
    # DeprecationWarning of DataFrameGroupBy.apply on grouping columns.
    leaf_neighs = (
        pd.Series(leaves, name="leaf")
        .reset_index()
        .groupby("leaf")["index"]
        .apply(list)
    )

    # Times and events aligned per leaf (layout defined by _align_leaf_target).
    T_leaves = _align_leaf_target(leaf_neighs, T_train)
    E_leaves = _align_leaf_target(leaf_neighs, E_train)

    # Two-sided z-score for the requested interval width.
    z = st.norm.ppf(0.5 + ci_width / 2)

    # Vectorized Kaplan-Meier curves and CI bounds for all leaves at once.
    (
        self._train_survival,
        self._train_upper_ci,
        self._train_lower_ci,
    ) = calculate_kaplan_vectorized(T_leaves, E_leaves, self.time_bins, z)

    # Index each table by leaf id so `predict` can look curves up with .loc.
    self._train_survival = self._train_survival.set_index(leaf_neighs.index)
    self._train_upper_ci = self._train_upper_ci.set_index(leaf_neighs.index)
    self._train_lower_ci = self._train_lower_ci.set_index(leaf_neighs.index)

    if persist_train:
        self.persist_train = True
        if index_id is None:
            # numpy input has no `.index`; fall back to positional ids.
            x_index = getattr(X, "index", None)
            index_id = (
                x_index.copy()
                if isinstance(x_index, pd.Index)
                else pd.RangeIndex(len(X))
            )
        # 1-D index over training leaf ids. With the Hamming metric, two
        # samples are at distance 0 exactly when they share a leaf.
        self.tree = BallTree(leaves.reshape(-1, 1), metric="hamming", leaf_size=40)
    else:
        # Do not keep reporting a stale True from a previous fit.
        self.persist_train = False
    self.index_id = index_id

    return self

predict(self, X, return_ci=False, return_interval_probs=False)

Run samples through the tree down to its terminal nodes. Predict the Kaplan-Meier estimator associated with the leaf node each sample ends up in.

Parameters:

Name Type Description Default
X pd.DataFrame

Data frame with samples to generate predictions

required
return_ci Bool

Whether to return confidence intervals via the Exponential Greenwood formula

False
return_interval_probs Bool

Boolean indicating if interval probabilities are supposed to be returned. If False the cumulative survival is returned.

False

Returns:

Type Description
preds_df (pd.DataFrame)

A dataframe of survival probabilities for all times (columns), from a time_bins array, for all samples of X (rows). If return_interval_probs is True, the interval probabilities are returned instead of the cumulative survival probabilities.

upper_ci (pd.DataFrame): Upper confidence interval for the survival probability values (returned only when return_ci is True)

lower_ci (pd.DataFrame): Lower confidence interval for the survival probability values (returned only when return_ci is True)

Source code in xgbse/_kaplan_neighbors.py
def predict(self, X, return_ci=False, return_interval_probs=False):
    """
    Run samples through the tree down to its terminal nodes. Predict the
    Kaplan-Meier estimator associated with the leaf node each sample ends up in.

    Args:
        X (pd.DataFrame): Data frame with samples to generate predictions

        return_ci (Bool): Whether to return confidence intervals via the Exponential Greenwood formula

        return_interval_probs (Bool): Boolean indicating if interval probabilities are
            supposed to be returned. If False the cumulative survival is returned.


    Returns:
        preds_df (pd.DataFrame): A dataframe of survival probabilities
            for all times (columns), from a time_bins array, for all samples of X
            (rows). If return_interval_probs is True, the interval probabilities are returned
            instead of the cumulative survival probabilities.

        upper_ci (pd.DataFrame): Upper confidence interval for the survival
            probability values (returned only when `return_ci` is True)

        lower_ci (pd.DataFrame): Lower confidence interval for the survival
            probability values (returned only when `return_ci` is True)

    Raises:
        ValueError: If both `return_ci` and `return_interval_probs` are True.
    """
    # Reject the unsupported flag combination before doing any work.
    if return_ci and return_interval_probs:
        raise ValueError(
            "Confidence intervals for interval probabilities is not supported. Choose between return_ci and return_interval_probs."
        )

    # Leaf id reached by each sample in the single fitted tree.
    leaves = self.feature_extractor.predict_leaves(X)

    # Look up each sample's precomputed leaf curve; reset to positional rows.
    preds_df = self._train_survival.loc[leaves].reset_index(drop=True)

    if return_interval_probs:
        return calculate_interval_failures(preds_df)

    if return_ci:
        # CI tables are only looked up when they will actually be returned.
        upper_ci = self._train_upper_ci.loc[leaves].reset_index(drop=True)
        lower_ci = self._train_lower_ci.loc[leaves].reset_index(drop=True)
        return preds_df, upper_ci, lower_ci
    return preds_df

set_fit_request(self, *, ci_width='$UNCHANGED$', index_id='$UNCHANGED$', persist_train='$UNCHANGED$', time_bins='$UNCHANGED$')

Request metadata passed to the fit method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config). Please see the scikit-learn User Guide section on metadata routing for how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to fit if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to fit.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

.. versionadded:: 1.3

Note: This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. inside a sklearn.pipeline.Pipeline. Otherwise it has no effect.

Parameters

ci_width : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED Metadata routing for ci_width parameter in fit.

index_id : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED Metadata routing for index_id parameter in fit.

persist_train : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED Metadata routing for persist_train parameter in fit.

time_bins : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED Metadata routing for time_bins parameter in fit.

Returns

self : object The updated object.

Source code in xgbse/_kaplan_neighbors.py
def func(*args, **kw):
    """Record the metadata-routing requests passed as keyword arguments.

    This placeholder docstring is overwritten later; see REQUESTER_DOC for
    the documented behaviour.
    """
    # Only usable once metadata routing has been switched on globally.
    if not _routing_enabled():
        raise RuntimeError(
            "This method is only available when metadata routing is enabled."
            " You can enable it using"
            " sklearn.set_config(enable_metadata_routing=True)."
        )

    if self.validate_keys:
        unknown = set(kw) - set(self.keys)
        if unknown:
            raise TypeError(
                f"Unexpected args: {unknown} in {self.name}. "
                f"Accepted arguments are: {set(self.keys)}"
            )

    # When bound, `instance` is the estimator. When used unbound (e.g. when
    # monkeypatching, see scikit-learn issue #28632), the estimator arrives
    # as the first positional argument.
    # https://github.com/scikit-learn/scikit-learn/issues/28632
    if instance is not None:
        target = instance
    else:
        target, args = args[0], args[1:]

    # Like a regular method: nothing positional is accepted besides the estimator.
    if args:
        raise TypeError(
            f"set_{self.name}_request() takes 0 positional argument but"
            f" {len(args)} were given"
        )

    requests = target._get_metadata_request()
    method_request = getattr(requests, self.name)
    changed = {param: alias for param, alias in kw.items() if alias is not UNCHANGED}
    for param, alias in changed.items():
        method_request.add_request(param=param, alias=alias)
    target._metadata_request = requests

    return target

set_predict_request(self, *, return_ci='$UNCHANGED$', return_interval_probs='$UNCHANGED$')

Request metadata passed to the predict method.

Note that this method is only relevant if enable_metadata_routing=True (see sklearn.set_config). Please see the scikit-learn User Guide section on metadata routing for how the routing mechanism works.

The options for each parameter are:

  • True: metadata is requested, and passed to predict if provided. The request is ignored if metadata is not provided.

  • False: metadata is not requested and the meta-estimator will not pass it to predict.

  • None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.

  • str: metadata should be passed to the meta-estimator with this given alias instead of the original name.

The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.

.. versionadded:: 1.3

Note: This method is only relevant if this estimator is used as a sub-estimator of a meta-estimator, e.g. inside a sklearn.pipeline.Pipeline. Otherwise it has no effect.

Parameters

return_ci : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED Metadata routing for return_ci parameter in predict.

return_interval_probs : str, True, False, or None, default=sklearn.utils.metadata_routing.UNCHANGED Metadata routing for return_interval_probs parameter in predict.

Returns

self : object The updated object.

Source code in xgbse/_kaplan_neighbors.py
def func(*args, **kw):
    """Record the metadata-routing requests passed as keyword arguments.

    This placeholder docstring is overwritten later; see REQUESTER_DOC for
    the documented behaviour.
    """
    # Only usable once metadata routing has been switched on globally.
    if not _routing_enabled():
        raise RuntimeError(
            "This method is only available when metadata routing is enabled."
            " You can enable it using"
            " sklearn.set_config(enable_metadata_routing=True)."
        )

    if self.validate_keys:
        unknown = set(kw) - set(self.keys)
        if unknown:
            raise TypeError(
                f"Unexpected args: {unknown} in {self.name}. "
                f"Accepted arguments are: {set(self.keys)}"
            )

    # When bound, `instance` is the estimator. When used unbound (e.g. when
    # monkeypatching, see scikit-learn issue #28632), the estimator arrives
    # as the first positional argument.
    # https://github.com/scikit-learn/scikit-learn/issues/28632
    if instance is not None:
        target = instance
    else:
        target, args = args[0], args[1:]

    # Like a regular method: nothing positional is accepted besides the estimator.
    if args:
        raise TypeError(
            f"set_{self.name}_request() takes 0 positional argument but"
            f" {len(args)} were given"
        )

    requests = target._get_metadata_request()
    method_request = getattr(requests, self.name)
    changed = {param: alias for param, alias in kw.items() if alias is not UNCHANGED}
    for param, alias in changed.items():
        method_request.add_request(param=param, alias=alias)
    target._metadata_request = requests

    return target