import logging
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import shap
import warnings
from ..base import BaseSurvival
from .utils import StepFunction
from numba import njit
from sklearn.utils.validation import check_random_state
# NOTE(review): this installs a process-wide "ignore" filter at import time, so every
# warning from every library is silenced as soon as this module is imported.
warnings.simplefilter("ignore")
@njit(fastmath=True, cache=True)
def _calculate_log_rank_njit(times_left, events_left, unique_times, n_j, d_j):
    """
    Calculate the Log-Rank statistic between two groups defined by left mask.

    Only the left group is passed in explicitly. The right group is implied as
    the rest of the parent node, whose per-time totals are ``n_j`` and ``d_j``.

    Parameters
    ----------
    times_left : ndarray
        Observed times of the left group. Must be sorted in ascending order,
        because the at-risk counts are obtained with ``np.searchsorted``.
    events_left : ndarray
        Event indicators aligned with ``times_left``; non-zero marks an event.
    unique_times : ndarray
        Sorted distinct event times of the parent node. Each event time in
        ``times_left`` is assumed to be one of them: its position is located
        with ``np.searchsorted`` and is not validated.
    n_j : ndarray of float
        Parent at-risk count at each entry of ``unique_times``.
    d_j : ndarray of float
        Parent event count at each entry of ``unique_times``.

    Returns
    -------
    float
        ``U ** 2 / V``. ``U`` is the sum over ``unique_times`` of observed
        minus expected left-group events, and ``V`` is its hypergeometric
        variance. Returns ``0.0`` when ``V`` is not positive.
    """
    # At-risk count in the left group (nleft_j) at each unique event time:
    # the number of left samples whose time is >= that event time
    idx_left = np.searchsorted(times_left, unique_times, side="left")
    nleft_j = (len(times_left) - idx_left).astype(np.float64)
    # Event count in the left group (dleft_j) at each unique event time
    mask = events_left != 0
    tleft_events = times_left[mask]
    if len(tleft_events) > 0:
        dleft_j = np.bincount(np.searchsorted(unique_times, tleft_events), minlength=len(unique_times)).astype(np.float64)
    else:
        # No events on the left: every per-time count is zero
        dleft_j = np.zeros(len(unique_times), dtype=np.float64)
    # Division guard; entries with n_j == 0 are zeroed out by the np.where below
    safe_n_j = np.where(n_j > 0, n_j, 1.0)
    # Calculate U statistic (Observed - Expected events)
    U = np.sum(np.where(n_j > 0, dleft_j - d_j * (nleft_j / safe_n_j), 0.0))
    # Calculate Variance (V) assuming a hypergeometric distribution.
    # Times with n_j <= 1 are skipped: the (nv - 1.0) factor in the
    # denominator would be zero there.
    valid = n_j > 1.0
    nv = n_j[valid]
    V = np.sum(d_j[valid] * nleft_j[valid] * (nv - nleft_j[valid]) * (nv - d_j[valid]) / (nv ** 2 * (nv - 1.0)))
    # Return Log-Rank score U^2 / V; a degenerate (non-positive) variance scores 0
    return (U ** 2) / V if V > 0.0 else 0.0
@njit(fastmath=True, cache=True)
def _best_split_njit(X, events, times, unique_times, n_j, d_j, features, min_samples_leaf):
    """
    Exhaustively search the candidate features for the split with the highest
    log-rank score.

    Features are visited in the order given by ``features``. For each one,
    every distinct value except the largest is tried as a threshold. Samples
    with ``X[:, feature] <= threshold`` go to the left child. Splits that leave
    fewer than ``min_samples_leaf`` samples on either side are ignored.

    Ties keep the first split encountered, so the order of ``features``
    decides between equally scored splits.

    Returns
    -------
    tuple
        ``(best_feature_idx, best_threshold)``, or ``(-1, nan)`` when no
        admissible split exists.
    """
    n_rows = X.shape[0]
    winner_score = -1.0
    winner_feature = -1
    winner_threshold = np.nan
    for feat in features:
        values = X[:, feat]
        distinct = np.unique(values)
        # The largest distinct value is never tried: it would leave the right child empty
        for k in range(distinct.shape[0] - 1):
            cut = distinct[k]
            goes_left = values <= cut
            n_left = int(goes_left.sum())
            n_right = n_rows - n_left
            # Both children must respect the minimum leaf size
            if n_left >= min_samples_leaf and n_right >= min_samples_leaf:
                candidate = _calculate_log_rank_njit(times[goes_left], events[goes_left], unique_times, n_j, d_j)
                # Strict comparison: the earliest split wins on ties
                if candidate > winner_score:
                    winner_score = candidate
                    winner_feature = feat
                    winner_threshold = cut
    return winner_feature, winner_threshold
class LeafEstimator:
    """
    Non-parametric survival estimate for the samples that reach one leaf.

    After ``fit``, the estimator exposes these arrays on a shared time grid:

    * ``times`` - the grid itself (the ``global_times`` passed to ``fit``);
    * ``cumulative_hazard`` - Nelson-Aalen estimate, the running sum of ``d_i / n_i``;
    * ``survival`` - product-limit estimate, the running product of ``1 - d_i / n_i``.
    """

    def __init__(self):
        """
        Create an unfitted estimator; all attributes stay ``None`` until ``fit``.
        """
        self.times = self.survival = self.cumulative_hazard = None

    def fit(self, events, times, global_times):
        """
        Estimate the leaf's hazard on the time grid ``global_times``.

        Parameters
        ----------
        events : ndarray
            Event indicators of the leaf samples; truthy means an event occurred.
        times : ndarray
            Observed times of the leaf samples, aligned with ``events``.
        global_times : ndarray
            Sorted grid on which the curves are evaluated. Every event time in
            ``times`` is expected to be present in it, because event times are
            mapped onto the grid with ``np.searchsorted``.
        """
        grid = global_times
        self.times = grid
        # Order the leaf samples by time (the input may already be sorted)
        order = np.argsort(times)
        ordered_times = times[order]
        ordered_events = events[order].astype(bool)

        # n_i: number of leaf samples still at risk (time >= t) at each grid point
        at_risk = len(times) - np.searchsorted(ordered_times, grid, side="left")

        # d_i: number of events exactly at each grid point.
        # np.add.at accumulates repeated indices, so tied event times are all counted.
        n_events = np.zeros(len(grid), dtype=np.float64)
        observed = ordered_times[ordered_events]
        if observed.size:
            np.add.at(n_events, np.searchsorted(grid, observed), 1)

        # Discrete hazard d_i / n_i, defined as 0 where nobody is at risk
        denominator = np.where(at_risk > 0, at_risk, 1.0)
        step_hazard = np.where(at_risk > 0, n_events / denominator, 0.0)

        self.cumulative_hazard = np.cumsum(step_hazard)
        self.survival = np.cumprod(1.0 - step_hazard)
class TreeNode:
    """
    A single node of the Survival Tree.

    An internal node stores the split (``feature``, ``threshold``) and its two
    children: samples with ``x[feature] <= threshold`` go ``left``, the others
    go ``right``. A terminal node has ``is_leaf=True`` and carries the leaf's
    scalar ``risk_value`` and its fitted ``estimator``.
    """

    def __init__(self, feature=None, threshold=None, left=None, right=None, *, is_leaf=False, risk_value=None, estimator=None):
        """
        Store the split definition, the children and the leaf payload as given.
        """
        # Split definition (internal nodes)
        self.feature, self.threshold = feature, threshold
        # Children (internal nodes)
        self.left, self.right = left, right
        # Leaf payload
        self.is_leaf = is_leaf
        self.risk_value, self.estimator = risk_value, estimator
class SurvTree(BaseSurvival):
    """
    Survival Tree model.

    The tree is grown greedily. At every node, each (feature, threshold) pair
    is scored with the log-rank statistic between the two children, and the
    best one is kept. Each leaf holds a ``LeafEstimator`` fitted on the grid of
    unique training times. Its scalar risk is the sum of that leaf's cumulative
    hazard over the grid.

    Parameters
    ----------
    max_depth : int or None, default = ``None``
        Maximum depth of the tree. ``None`` means depth is not limited.
    min_samples_split : int, default = 6
        Nodes with fewer samples than this become leaves.
    min_samples_leaf : int, default = 3
        Minimum number of samples each child of a split must receive.
    seed : int, RandomState instance or None, default = 0
        Passed to ``check_random_state``. Drives the feature shuffling used to
        break ties between equally scored splits.
    """

    def __init__(self, max_depth=None, min_samples_split=6, min_samples_leaf=3, seed=0):
        """
        Initialise model with specified parameters.
        """
        # Hyper-parameters
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.seed = seed
        # Fitted state (populated by ``fit``)
        self.root = None
        self.unique_times = None
        # Field names of the structured target (``fit`` reads y["event"] and y["time"]);
        # presumably also consumed by BaseSurvival helpers -- confirm
        self.labels_covariables = ["event", "time"]

    def _best_split(self, X, events, times, n_features):
        """
        Return (best_feature_idx, best_threshold) maximising log-rank score.

        ``times`` must be sorted in ascending order, since the at-risk counts
        are computed with ``np.searchsorted``. Boolean-mask subsets preserve the
        ordering established in ``fit``.

        Returns ``(None, None)`` when the node has no events or when no split
        satisfies ``min_samples_leaf``.
        """
        # Distinct event times of this node and the number of events at each
        event_mask = events.astype(bool)
        unique_times, d_j_int = np.unique(times[event_mask], return_counts=True)
        # Without events the log-rank statistic is undefined: no split
        if len(unique_times) == 0:
            return None, None
        # Parent at-risk counts n_j (samples with time >= t_j) and event counts d_j,
        # computed once and reused for every candidate split
        n_j = (len(times) - np.searchsorted(times, unique_times, side="left")).astype(np.float64)
        d_j = d_j_int.astype(np.float64)
        # The kernel keeps the first of several equally good splits, so visiting
        # features in a seeded random order makes tie-breaking random but reproducible
        features = np.arange(n_features)
        self.rng.shuffle(features)
        best_feature, best_threshold = _best_split_njit(
            X, events, times, unique_times, n_j, d_j, features, self.min_samples_leaf
        )
        if best_feature == -1:
            return None, None
        return best_feature, best_threshold

    def _create_leaf(self, events, times):
        """
        Instantiate a terminal node and fit the local survival estimators.
        """
        estimator = LeafEstimator()
        estimator.fit(events, times, self.unique_times)
        # Risk value: sum of the cumulative hazard over the global time grid
        # (a plain sum over grid points, not a time-weighted area)
        risk_value = float(np.sum(estimator.cumulative_hazard))
        return TreeNode(is_leaf=True, risk_value=risk_value, estimator=estimator)

    def _build_tree(self, X, events, times, depth):
        """
        Recursively build the tree.

        Returns
        -------
        TreeNode
            Root of the subtree grown from the given samples.
        """
        n_samples, n_features = X.shape
        # Stop on the depth limit, on too few samples, or when there are no events to split on
        stop = (
            (self.max_depth is not None and depth >= self.max_depth)
            or n_samples < self.min_samples_split
            or int(events.sum()) == 0
        )
        if stop:
            return self._create_leaf(events, times)
        # Search for the optimal split
        best_feature, best_threshold = self._best_split(X, events, times, n_features)
        # If no valid split was found, convert to a leaf
        if best_feature is None:
            return self._create_leaf(events, times)
        # Samples with X[:, feature] <= threshold go left, the rest go right.
        # The left subtree is built first, which fixes the order in which RNG draws are consumed.
        left_mask = X[:, best_feature] <= best_threshold
        right_mask = ~left_mask
        left = self._build_tree(X[left_mask], events[left_mask], times[left_mask], depth + 1)
        right = self._build_tree(X[right_mask], events[right_mask], times[right_mask], depth + 1)
        return TreeNode(feature=best_feature, threshold=best_threshold, left=left, right=right)

    def fit(self, X, y):
        """
        Fit the model to the data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Training data.
        y : structured array-like, shape (n_samples,)
            Target training values with fields ``"event"`` and ``"time"``.

        Returns
        -------
        self : SurvTree
            Fitted estimator.
        """
        # Inherited helper. The split search needs samples ordered by time --
        # NOTE(review): confirm that ``_sort`` sorts by y["time"].
        X, y = self._sort(X, y)
        events = y["event"]
        times = y["time"]
        self.rng = check_random_state(self.seed)
        # Shared time grid (all distinct training times) on which every leaf evaluates its curves
        self.unique_times = np.unique(times)
        self.root = self._build_tree(X, events, times, depth=0)
        return self

    def predict(self, X):
        """
        Predict risk scores for the given data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Input data.

        Returns
        -------
        risk : ndarray, shape (n_samples,)
            Risk value of the leaf each sample falls into.

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        if self.root is None:
            raise ValueError("When using `predict` with a model, first fit the model.")
        leaves = self._get_leaves(X)
        return np.array([node.risk_value for node in leaves], dtype=np.float64)

    def score(self, X, y):
        """
        Placeholder: scoring is not implemented and always returns ``None``.
        """
        return None

    def _get_leaves(self, X):
        """
        Finds and returns the leaf node for each sample.

        ``X`` is first converted with ``np.asarray`` so that array-likes such as
        pandas DataFrames can be routed with positional feature indices.
        """
        X = np.asarray(X)
        leaves = np.empty(X.shape[0], dtype=object)
        for i, row in enumerate(X):
            node = self.root
            while not node.is_leaf:
                node = node.left if row[node.feature] <= node.threshold else node.right
            leaves[i] = node
        return leaves

    # ----------------------
    # Base Survival methods
    # ----------------------
    def _compute_survival_hazard_functions(self, X, survival=True):
        """
        Auxiliary method for computing the survival or cumulative hazard function.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Input data.
        survival : bool, default = ``True``
            If ``True``, return ``exp(-H(t))`` built from each leaf's cumulative
            hazard ``H``. Otherwise return ``H`` itself.

        Returns
        -------
        functions : ndarray of StepFunction, shape (n_samples,)
            One step function per sample, on its leaf's time grid.

        Raises
        ------
        ValueError
            If the model has not been fitted.
        """
        if self.root is None:
            function_name = "survival_function" if survival else "cumulative_hazard_function"
            raise ValueError(f"When computing `{function_name}` with a model, first fit the model.")
        leaves = self._get_leaves(X)
        functions = []
        for node in leaves:
            estimator = node.estimator
            # NOTE(review): survival is derived as exp(-H) rather than taken from the
            # product-limit ``estimator.survival``; confirm this is intended.
            values = np.exp(-estimator.cumulative_hazard) if survival else estimator.cumulative_hazard
            functions.append(StepFunction(estimator.times, values, is_survival=survival))
        return np.array(functions, dtype=object)

    def predict_survival_function(self, X, index, dataset, seed, plot=False):
        """
        Predict the survival function for the given data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Input data.
        index : array-like, shape (n_samples,)
            Index for the samples.
        dataset : str
            Name of the dataset.
        seed : int
            Random seed for reproducibility.
        plot : bool, default = ``False``
            Whether to plot the survival function.

        Returns
        -------
        survival_function : ndarray of StepFunction, shape (n_samples,)
            Predicted survival function of each sample.

        Raises
        ------
        ValueError
            If ``seed`` cannot be converted to ``int`` or the model is not fitted.
        """
        try:
            seed = int(seed)
        except (TypeError, ValueError):
            raise ValueError(f"When using `predict_survival_function`, the seed must be an integer. Value received: {seed}")
        self.survival_function = self._compute_survival_hazard_functions(X, survival=True)
        if plot:
            figure, ax = self._plot_survival_hazard_functions(self.survival_function, index, "Survival Tree", dataset, "Survival", seed)
            plt.show()
        return self.survival_function

    def predict_cumulative_hazard_function(self, X, index, dataset, seed, plot=False):
        """
        Predict the cumulative hazard function for the given data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Input data.
        index : array-like, shape (n_samples,)
            Index for the samples.
        dataset : str
            Name of the dataset.
        seed : int
            Random seed for reproducibility.
        plot : bool, default = ``False``
            Whether to plot the cumulative hazard function.

        Returns
        -------
        cumulative_hazard_function : ndarray of StepFunction, shape (n_samples,)
            Predicted cumulative hazard function of each sample.

        Raises
        ------
        ValueError
            If ``seed`` cannot be converted to ``int`` or the model is not fitted.
        """
        try:
            seed = int(seed)
        except (TypeError, ValueError):
            raise ValueError(f"When using `predict_cumulative_hazard_function`, the seed must be an integer. Value received: {seed}")
        self.cumulative_hazard_function = self._compute_survival_hazard_functions(X, survival=False)
        if plot:
            figure, ax = self._plot_survival_hazard_functions(self.cumulative_hazard_function, index, "Survival Tree", dataset, "CumulativeRisk", seed)
            plt.show()
        return self.cumulative_hazard_function

    # ----------------------
    # XAI
    # ----------------------
    def calculate_xai(self, X, index, scaler, dataset, seed, feature_names, background=False, plot=False):
        """
        Calculate XAI values.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Input data. It is also used as the background data of the SHAP masker.
        index : array-like, shape (n_samples,)
            Index for the samples.
        scaler : object
            Scaler used for the data.
        dataset : str
            Name of the dataset.
        seed : int
            Random seed for reproducibility.
        feature_names : list of str
            Names of the features.
        background : int or bool, default = ``False``
            When truthy, ``X`` is summarised with ``shap.kmeans(X, background)``,
            so the value is used as the number of centroids. SHAP values are
            then computed for those centroids instead of for every row of
            ``X``. When falsy, a copy of ``X`` is explained.
        plot : bool, default = ``False``
            Whether to plot the XAI values.

        Returns
        -------
        shap_explainer : shap.Explanation
            SHAP values of the risk score for the explained samples, also stored
            in ``self.shap_explainer``.

        Raises
        ------
        ValueError
            If ``seed`` cannot be converted to ``int``.
        """
        try:
            seed = int(seed)
        except (TypeError, ValueError):
            raise ValueError(f"When using `calculate_xai` with a model, the seed must be an integer. Value received: {seed}")
        # NOTE(review): shap's own loggers are named "shap"; confirm "xai" is the logger meant to be quieted
        logging.getLogger("xai").setLevel(logging.WARNING)
        # Model-agnostic explainer over the risk score; X is the masker's background data
        masker = shap.maskers.Independent(X, max_samples=X.shape[0])
        explainer_risk = shap.Explainer(self.predict, masker, feature_names=feature_names, seed=seed)
        # Samples to explain: all of X, or (faster) its k-means centroids when ``background`` is set
        X_background = X.copy()
        if background:
            X_background = pd.DataFrame(shap.kmeans(X, background).data, columns=feature_names)
        self.shap_explainer = explainer_risk(X_background)
        if plot:
            figure, ax = BaseSurvival.plot_shap(self.shap_explainer, index, scaler, "Survival Tree", dataset, seed)
            plt.show()
        return self.shap_explainer