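"""Source code for ``bsix.models.base``: the abstract ``BaseSurvival`` survival-model class and its shared plotting and data-handling helpers."""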

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import mplcursors
import numpy as np
import pandas as pd
import torch

from abc import ABC, abstractmethod
from lifelines import CoxPHFitter, KaplanMeierFitter, PiecewiseExponentialFitter, statistics
from scipy import stats
from sklearn.base import BaseEstimator

def _tool_setTimeTicksAxisX(ax):

    """
    Tool for setting time ticks on X-axis.
    """

    max_time = max(abs(ax.get_xlim()[0]), abs(ax.get_xlim()[1])) # (min, max)

    # Time (days)
    if max_time > 3650: # More than 10 years
        major, minor = 1825, 365 # 5 years - 1 year
    elif max_time > 365: # Between 1 year and 10 years
        major, minor = 365, 73 # 1 year - 5 splits per year
    elif max_time > 90: # Less than 1 year
        major, minor = 30, 6 # 1 month - 5 splits per month

    # Time (years)
    elif max_time > 10: # More than 10 years
        major, minor = 5, 1 # 5 years - 1 year
    elif max_time > 1: # Between 1 year and 10 years
        major, minor = 1, 0.2 # 1 year - 5 splits per year
    else: # Less than 1 year
        major, minor = 0.5, 0.1 # 1 month - 5 splits per month

    return major, minor

def _tool_setXaiTicksAxisX(ax):

    """
    Tool for setting XAI ticks on X-axis.
    """

    max_shap = max(abs(ax.get_xlim()[0]), abs(ax.get_xlim()[1])) # (min, max)

    if max_shap > 50: # More than 30 (xai)
        major, minor = 10, 2 # 10 - 2 (xai)
    elif max_shap > 30: # More than 30 (xai)
        major, minor = 5, 1 # 5 - 1 (xai)
    elif max_shap > 10: # More than 10 (xai)
        major, minor = 1, 0.2 # 1 - 0.2 (xai)
    elif max_shap > 1: # More than 1 (xai)
        major, minor = 0.5, 0.1 # 0.5 - 0.1 (xai)
    elif max_shap > 0.1: # More than 0.1 (xai)
        major, minor = 0.1, 0.02 # 0.1 - 0.02 (xai)
    else:
        major, minor = 0.05, 0.01 # 0.05 - 0.01 (xai)

    return major, minor

def _tool_setRiskTicksAxisY(ax):

    """
    Tool for setting risk ticks on Y-axis.
    """

    max_risk = max(abs(ax.get_ylim()[0]), abs(ax.get_ylim()[1])) # (min, max)

    if max_risk > 5: # More than 5 (risk)
        major, minor = 1, 0.25 # 1 - 0.5 (risk)
    else: # Less than 5 (risk)
        major, minor = 0.5, 0.1 # 0.2 - 0.1 (risk)

    return major, minor

def _tool_toDataframe(data, columns=None):

    """
    Tool for converting X, y to a DataFrame.
    """

    if columns == None: # Without columns names
        dataframe = pd.DataFrame(data, columns=[str(l) for l in range(data.shape[1])])
    else: # With columns names
        dataframe = pd.DataFrame(data, columns=columns)

    return dataframe

def _tool_extractSHAP(shap_explainer, seed):

    """
    Tool for extracting data from a SHAP explainer and standardising dimensions.
    """

    values = shap_explainer.values.copy()
    data = shap_explainer.data.copy()
    names = np.array(shap_explainer.feature_names.copy(), str)
    
    # Standardize dimensions: ensure values and data are 3D (num_seeds, num_samples, num_features)
    if shap_explainer.values.ndim == 1: # (num_samples,)
        _values = values[np.newaxis, :, np.newaxis]
        _data = data[np.newaxis, :, np.newaxis]
    elif shap_explainer.values.ndim == 2:
        if seed == None: # (num_seeds, num_samples,)
            _values = values[..., np.newaxis]
            _data = data[..., np.newaxis]
        else: # (, num_samples, num_features)
            _values = values[np.newaxis, ...]
            _data = data[np.newaxis, ...]
    else: # (num_seeds, num_samples, num_features)
        _values = values
        _data = data

    return _values, _data, names

    
class BaseSurvival(BaseEstimator, ABC):

    """
    Abstract Class for Survival Analysis models.

    Concrete models implement ``fit``, ``predict``,
    ``predict_cumulative_hazard_function``, ``predict_survival_function`` and
    ``calculate_xai``. The class also provides static helpers to:

    - discretise follow-up time;
    - select Cox-significant features;
    - simulate survival data;
    - reshape data into time-dependent / start-stop formats;
    - plot XAI and survival results.
    """

    @abstractmethod
    def calculate_xai(self, X, **kwargs):
        """
        Calculate XAI values for ``X``.
        """
        raise NotImplementedError

    @abstractmethod
    def fit(self, X, y, **kwargs):
        """
        Fit the model on ``X`` and ``y``.
        """
        raise NotImplementedError

    @abstractmethod
    def predict(self, X, **kwargs):
        """
        Predict on ``X``.
        """
        raise NotImplementedError

    @abstractmethod
    def predict_cumulative_hazard_function(self, X, **kwargs):
        """
        Predict the cumulative hazard H(x, t) = H0(t) * exp(g(x)).
        """
        raise NotImplementedError

    @abstractmethod
    def predict_survival_function(self, X, **kwargs):
        """
        Predict the survival function S(x, t) = exp(-H(x, t)).
        """
        raise NotImplementedError

    #-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#

    @staticmethod
    def dinamic_discretise(y, dataset, seed=0, plot=False, *, num_splits=(3, 4, 5), n_iter=5, alpha=0.5):
        """
        Discretise follow-up time with a piecewise exponential model and,
        optionally, show the chosen cut points on a Kaplan-Meier curve.

        A random search draws ``n_iter`` candidate cut sets for every number
        of cuts in ``num_splits``. Cuts are sampled without replacement from
        15 evenly spaced values between the 5% and 95% quantiles of
        ``y["time"]``.

        Each candidate is scored as
        ``alpha * AIC_norm + (1 - alpha) * entropy_norm``, where:

        - ``AIC`` comes from a lifelines ``PiecewiseExponentialFitter`` with
          those breakpoints;
        - ``entropy`` is ``1 - H / log(n_cuts + 1)`` of the number of events
          per interval (0 = events spread evenly, 1 = all in one interval).

        Both terms are min-max normalised over all candidates, and the lowest
        score wins.

        Parameters
        ----------
        y : structured array or DataFrame
            Must support ``y["time"]`` and ``y["event"]`` (events coded 1).
        dataset : str
            Dataset name; only used in the plot title.
        seed : int, default 0
            Seed of the random search; also shown in the plot title.
        plot : bool, default False
            Draw and show the Kaplan-Meier estimate with the chosen cuts.
        num_splits : iterable of int, default (3, 4, 5)
            Numbers of cuts to try (intervals = cuts + 1). Each must be
            between 1 and 15.
        n_iter : int, default 5
            Random candidates per number of cuts (100 was found ideal in
            testing; 5 keeps the search cheap).
        alpha : float, default 0.5
            Weight of the AIC term; ``1 - alpha`` weights the entropy term.

        Returns
        -------
        list
            Bin edges ``[0, cut_1, ..., cut_k, inf]``.

        Raises
        ------
        ValueError
            If ``num_splits`` / ``n_iter`` produce no candidate.
        """
        rng = np.random.default_rng(seed=seed)

        # Candidate cut points ignore the extreme 5% tails of the times.
        min_time = np.quantile(y["time"], 0.05)
        max_time = np.quantile(y["time"], 0.95)
        possible_splits = np.linspace(min_time, max_time, num=15).tolist()

        # Times of observed events; loop-invariant, so computed once.
        event_times = y["time"][y["event"] == 1]

        results = []
        for n_splits in num_splits:
            for _ in range(n_iter):
                # Randomised split = (n1, n2, n3, ...)
                split = sorted(rng.choice(possible_splits, n_splits, replace=False))

                pf = PiecewiseExponentialFitter(breakpoints=list(split))
                pf.fit(y["time"], y["event"])

                # Number of events falling into each interval
                lengths = pd.cut(event_times, bins=[0] + list(split) + [np.inf]).value_counts()

                results.append({
                    "n_splits": n_splits,
                    "split": split,
                    "aic": pf.AIC_,
                    "entropy": (1.0 - (stats.entropy(lengths) / np.log(n_splits + 1))),
                })

        if not results:
            raise ValueError("num_splits and n_iter must produce at least one candidate split.")

        df_results = pd.DataFrame(results)

        def _min_max(column):
            # Min-max scaling. A constant column maps to 0 instead of 0/0 = NaN,
            # which would otherwise turn every score into NaN and break idxmin.
            spread = column.max() - column.min()
            if spread == 0:
                return column - column.min()
            return (column - column.min()) / spread

        df_results["aic_norm"] = _min_max(df_results["aic"])
        df_results["entropy_norm"] = _min_max(df_results["entropy"])
        df_results["score"] = (alpha * df_results["aic_norm"]) + ((1 - alpha) * df_results["entropy_norm"])

        best_splits = list(df_results["split"].loc[df_results["score"].idxmin()])

        if plot:
            BaseSurvival._plot_discretised_km(y, best_splits, dataset, seed)

        return [0] + best_splits + [np.inf]

    @staticmethod
    def _plot_discretised_km(y, splits, dataset, seed):
        """
        Draw the Kaplan-Meier estimate of ``y`` with dashed vertical lines at
        ``splits`` and show the figure (helper of ``dinamic_discretise``).
        """
        kmf = KaplanMeierFitter()
        kmf.fit(durations=y["time"], event_observed=y["event"])

        plt.figure(figsize=(10, 6))
        ax = kmf.plot(color="#C1502E", label="KM estimate")

        for split in splits:
            plt.axvline(x=split, color="#2EC192", linestyle="-.", alpha=0.5)

        # Title and axis labels
        ax.set_title(f"Discretised Kaplan-Meier\n{dataset} - seed {seed}", fontsize=12)
        ax.set_xlabel("Time (days)", fontsize=10)
        ax.set_ylabel("Survival Probability", fontsize=10)

        # Tick spacing depends on the autoscaled limits, so it is set before
        # the limits are changed below.
        majorX, minorX = _tool_setTimeTicksAxisX(ax)
        ax.xaxis.set_major_locator(ticker.MultipleLocator(majorX))
        ax.xaxis.set_minor_locator(ticker.MultipleLocator(minorX))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(0.5))
        ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.1))
        plt.xticks(rotation=45, ha="right")

        # Axis limits and spines
        ax.set_xlim(left=0)
        ax.set_ylim(bottom=0, top=1.05)
        ax.spines["left"].set_position(("outward", 5))
        ax.spines["bottom"].set_position(("outward", 5))
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)

        # Grid and legend
        plt.grid(True, which="major", linestyle="-", alpha=0.7)
        plt.grid(True, which="minor", linestyle="--", alpha=0.7, linewidth=0.5)
        plt.legend(frameon=True, facecolor="white", edgecolor="0.8")

        plt.tight_layout()
        plt.show()

    @staticmethod
    def feature_selection(X, y, *, p_threshold=0.05, penalizer=0.01):
        """
        Select covariates by backward elimination on Cox p-values.

        A penalised lifelines ``CoxPHFitter`` is fitted on all columns of
        ``X``. The covariate with the largest p-value is dropped and the
        model refitted, until every remaining p-value is below
        ``p_threshold`` or no covariate is left.

        Parameters
        ----------
        X : array-like, 2-D
            Covariates; columns are identified by their position.
        y : array-like
            Converted with ``pd.DataFrame(y, columns=["event", "time"])``
            (e.g. a structured array with ``event`` and ``time`` fields).
        p_threshold : float, default 0.05
            Covariates with a p-value at or above this are removed.
        penalizer : float, default 0.01
            Passed to ``CoxPHFitter``.

        Returns
        -------
        list of int
            Column positions of ``X`` that were kept (possibly empty).
        """
        model = CoxPHFitter(penalizer=penalizer)
        labels_covariables = ["event", "time"]
        event_col, duration_col = labels_covariables

        dataframe = pd.concat([_tool_toDataframe(X), _tool_toDataframe(y, labels_covariables)], axis=1)
        significance_covariables = [c for c in dataframe.columns.to_list() if c not in labels_covariables]

        # Stop before fitting a model with no covariates: when every
        # covariate is non-significant the result is simply empty.
        while significance_covariables:
            model.fit(
                dataframe[significance_covariables + labels_covariables],
                duration_col=duration_col,
                event_col=event_col,
                show_progress=False,
            )
            # lifelines private helper; assumes p-values come in the order of
            # the fitted covariate columns.
            significance = np.asarray(model._compute_p_values())
            worst = int(np.argmax(significance))
            if significance[worst] < p_threshold:
                break
            # Drop the least significant covariate (a NaN p-value counts as worst)
            significance_covariables.pop(worst)

        # Column names are "0", "1", ...; convert back to positions.
        return [int(c) for c in significance_covariables]

    @staticmethod
    def generate_simulated_survival_data(number_rows=1000, number_columns=10, censored=0.75, relation=None, seed=0):
        """
        Generate simulated survival data under an exponential hazard model.

        Data generation:

        - Covariates are drawn from N(0, 1) and coefficients from U(-1, 1).
        - The log-risk is linear in the covariates (``relation=None`` or any
          unrecognised value), quadratic (``"cuadratic"`` or ``"quadratic"``)
          or sinusoidal (``"sin"``).
        - Survival times follow ``S(t) = exp(-baseline * risk * t)`` with
          ``baseline = 0.15``.
        - A random ``censored`` fraction of rows is then censored at a
          uniform time before their true event time.

        Note: this seeds NumPy's global random state with ``np.random.seed``.

        Parameters
        ----------
        number_rows, number_columns : int
            Size of the covariate matrix.
        censored : float, default 0.75
            Fraction of rows to censor (between 0 and 1).
        relation : {None, "cuadratic", "quadratic", "sin"}, optional
            Shape of the covariate/log-risk relation.
        seed : int, default 0

        Returns
        -------
        pandas.DataFrame
            Columns ``feature_0 ... feature_{number_columns - 1}``, ``event``
            (1 = event, 0 = censored) and ``time`` (rounded to 2 decimals).
        """
        np.random.seed(seed)

        X = np.random.normal(0, 1, size=(number_rows, number_columns))
        names_columns = [f"feature_{i}" for i in range(number_columns)]
        coeffs = np.random.uniform(-1, 1, size=number_columns)

        # log h(x) = beta . f(x); no random draws here, so the stream is unaffected.
        if relation in ("cuadratic", "quadratic"):
            log_risk = np.dot(X**2, coeffs)
        elif relation == "sin":
            log_risk = np.dot(np.sin(X * np.pi), coeffs)
        else:
            log_risk = np.dot(X, coeffs)
        risk = np.exp(log_risk)

        # Inverse-transform sampling of S(t) = e^(-lambda t), lambda = baseline * risk
        baseline = 0.15
        S = np.random.uniform(0, 1, size=number_rows)
        time_survival = -np.log(S) / (baseline * risk)

        # Censor a random subset at a uniform time before its event time
        number_censored = int(round(number_rows * censored))
        idx_censored = np.random.choice(number_rows, size=number_censored, replace=False)
        event = np.ones(number_rows, dtype=int)
        time_observed = time_survival.copy()
        time_observed[idx_censored] = np.random.uniform(0, time_survival[idx_censored])
        event[idx_censored] = 0

        dataframe = pd.DataFrame(X, columns=names_columns)
        dataframe["event"] = event
        dataframe["time"] = np.round(time_observed, 2)

        return dataframe

    @staticmethod
    def logrank_test(y, groups, weights=None):
        """
        Run the lifelines multivariate log-rank test across ``groups``.

        Parameters
        ----------
        y : structured array or DataFrame
            Must support ``y["time"]`` and ``y["event"]``.
        groups : array-like
            Group label of each observation.
        weights : array-like, optional
            Forwarded as the ``weights`` argument of
            ``lifelines.statistics.multivariate_logrank_test``.

        Returns
        -------
        The lifelines test result object.
        """
        return statistics.multivariate_logrank_test(
            y["time"], groups, event_observed=y["event"], weights=weights
        )

    @staticmethod
    def to_time_dependent(dataframe, splits, identifier="identifier", time="time", event="event"):
        """
        Transform one-row-per-subject data into a time-dependent format with
        one row per time interval.

        Intervals are ``(splits[i], splits[i + 1]]``, with the first one also
        including ``splits[0]``. A subject whose ``time`` falls in interval
        ``k`` gets ``k + 1`` rows, one for each interval ``0..k``:

        - every row but the last has ``event = 0`` and ``time`` equal to its
          interval's upper edge;
        - the last row keeps the subject's observed ``event`` and ``time``.

        Parameters
        ----------
        dataframe : pandas.DataFrame
            One row per subject.
        splits : sequence of float
            Increasing bin edges, e.g. the output of ``dinamic_discretise``.
        identifier, time, event : str
            Column names; renamed to ``"identifier"``, ``"time"`` and
            ``"event"`` in the output.

        Returns
        -------
        pandas.DataFrame

        Raises
        ------
        ValueError
            If some ``time`` lies outside ``[splits[0], splits[-1]]``.
        """
        dataframe_transformed = dataframe.sort_values(by=[identifier]).copy()
        dataframe_transformed = dataframe_transformed.rename(columns={identifier: "identifier", event: "event", time: "time"})

        # include_lowest=True: time == splits[0] (e.g. 0.0) belongs to the
        # first interval instead of becoming NaN (which broke the repeat below).
        time_frame = pd.cut(dataframe_transformed["time"], bins=splits, labels=False, include_lowest=True)
        if time_frame.isna().any():
            raise ValueError("Some times fall outside the range covered by 'splits'.")
        dataframe_transformed["time_frame"] = time_frame.astype(int)

        # Repeat each row once per interval up to its own (positional, so a
        # non-unique index cannot duplicate rows).
        repeats = dataframe_transformed["time_frame"].to_numpy() + 1
        dataframe_transformed = dataframe_transformed.iloc[np.repeat(np.arange(len(dataframe_transformed)), repeats)]
        dataframe_transformed = dataframe_transformed.reset_index(drop=True)

        # Number the copies 0..k within each subject
        dataframe_transformed["time_frame"] = dataframe_transformed.groupby("identifier").cumcount()

        # Only the last row of each subject keeps its event
        last_row_index = dataframe_transformed.groupby("identifier").tail(1).index
        dataframe_transformed.loc[~dataframe_transformed.index.isin(last_row_index), "event"] = 0.0

        # Intermediate rows end at their interval's upper edge; the last row
        # keeps the observed time.
        dataframe_transformed["days_risk"] = dataframe_transformed["time_frame"].map(dict(enumerate(splits[1:])))
        dataframe_transformed.loc[last_row_index, "days_risk"] = dataframe_transformed.loc[last_row_index, "time"]

        dataframe_transformed["time"] = dataframe_transformed["days_risk"]
        dataframe_transformed = dataframe_transformed.drop(columns=["days_risk", "time_frame"])

        return dataframe_transformed.reset_index(drop=True)

    @staticmethod
    def to_time_varying(dataframe, identifier="identifier", time="time", event="event"):
        """
        Transform data with multiple measurements per subject into a
        start-stop format.

        Rows are sorted by subject and time. Each row becomes the interval
        ``(time_start, time_stop]``, where ``time_start`` is the subject's
        previous measurement time (``1e-15`` for the first row). Rows
        measured after a subject's first event are dropped, and any interval
        with ``time_stop <= time_start`` is widened to be strictly positive.

        Parameters
        ----------
        dataframe : pandas.DataFrame
        identifier, time, event : str
            Column names; renamed to ``"identifier"``, ``"time_stop"`` and
            ``"event"`` in the output.

        Returns
        -------
        pandas.DataFrame
            Columns: ``identifier``, the remaining covariates, ``event``,
            ``time_start``, ``time_stop``.
        """
        dataframe_transformed = dataframe.sort_values(by=[identifier, time]).reset_index(drop=True)
        dataframe_transformed = dataframe_transformed.rename(columns={identifier: "identifier", event: "event", time: "time_stop"})

        # time_start = previous stop of the same subject; a tiny positive
        # value (1e-15) for the first row.
        dataframe_transformed["time_start"] = dataframe_transformed.groupby("identifier")["time_stop"].shift(1).fillna(1e-15)
        dataframe_transformed = dataframe_transformed.astype({"time_start": float, "time_stop": float})

        # Shift events down one row (the event row itself is kept), then
        # propagate: every row after the first event is removed.
        shift_event = dataframe_transformed.groupby("identifier")["event"].shift(1).fillna(0).astype(int)
        forward_fill_event = shift_event.groupby(dataframe_transformed["identifier"]).cummax()
        dataframe_transformed = dataframe_transformed[forward_fill_event == 0]

        # Reorder columns
        cols = [col for col in dataframe_transformed.columns if col not in ["identifier", "time_start", "time_stop", "event"]]
        dataframe_transformed = dataframe_transformed[["identifier"] + cols + ["event", "time_start", "time_stop"]].copy()

        # Ensure time_stop > time_start. Adding 1e-15 alone is absorbed by
        # float rounding once time_start reaches 16, so fall back to the next
        # representable float above time_start.
        invalid_mask = dataframe_transformed["time_stop"] <= dataframe_transformed["time_start"]
        start = dataframe_transformed.loc[invalid_mask, "time_start"].to_numpy()
        dataframe_transformed.loc[invalid_mask, "time_stop"] = np.maximum(start + 1e-15, np.nextafter(start, np.inf))

        return dataframe_transformed.reset_index(drop=True)

    #-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#-#

    @staticmethod
    def _format_title_parts(estimator_name, dataset, seed=None, progression=None):
        """
        Join ``"<estimator_name> - <dataset>"`` with the optional
        ``"seed <seed>"`` and ``"progression <progression>"`` parts using
        ``" - "`` (used in every plot title).
        """
        title_parts = [f"{estimator_name} - {dataset}"]
        if seed is not None:
            title_parts.append(f"seed {seed}")
        if progression is not None:
            title_parts.append(f"progression {progression}")
        return " - ".join(title_parts)

    @staticmethod
    def plot_coefficients(coefficients, estimator_name, dataset, seed=None, progression=None):
        """
        Plot XAI coefficients for the data (lollipop plot).

        Parameters
        ----------
        coefficients : dict
            Feature name -> coefficient, or -> sequence of coefficients (one
            per seed; averaged while ignoring NaN entries).
        estimator_name, dataset : str
            Shown in the title.
        seed, progression : optional
            Added to the title when not ``None``.

        Returns
        -------
        tuple
            ``(figure, ax)``.
        """
        coefficients_values = np.array(list(coefficients.values()), dtype=np.float32)
        names = np.array(list(coefficients.keys()), dtype=str)

        if coefficients_values.ndim == 2:
            # (num_features, num_seeds): mean over seeds, skipping NaN (a
            # feature missing in some seeds).
            coefficients_mean = np.nanmean(coefficients_values, axis=1)
        else:
            # (num_features,): single seed
            coefficients_mean = coefficients_values

        # Ascending |mean| so the most important feature is drawn at the top;
        # features with no value (NaN) go to the bottom.
        importance = np.abs(coefficients_mean)
        sort_idx = np.argsort(np.where(np.isnan(importance), -1.0, importance))
        names = names[sort_idx]
        coefficients_mean = coefficients_mean[sort_idx]

        figure, ax = plt.subplots(figsize=(10, 6))
        cmap = plt.get_cmap("coolwarm")

        # Colour scale symmetric around 0 (1e-15 avoids a zero-width range)
        max_abs = np.nanmax(np.abs(coefficients_mean)) + 1e-15
        normalise = plt.Normalize(vmin=-max_abs, vmax=max_abs)
        color = cmap(normalise(coefficients_mean))

        # Lollipop: line from 0 to the mean plus a marker at the mean
        ax.hlines(y=names, xmin=0, xmax=coefficients_mean, colors=color, linewidth=3, alpha=1.0, zorder=2)
        ax.scatter(x=coefficients_mean, y=names, marker="o", c=color, s=200, alpha=1.0, zorder=3)

        # Value labels next to the zero axis, on the side opposite the bar
        for name, val in zip(names, coefficients_mean):
            x_pos = (-(max_abs * 0.01) if val >= 0.0 else (max_abs * 0.01))
            ha_pos = ("right" if val >= 0.0 else "left")
            ax.text(
                x=x_pos, y=name,
                s=f"$\\bf{{{val:.2f}}}$",
                color="#353535", fontsize=6, ha=ha_pos, va="center",
                bbox=dict(facecolor='white', edgecolor='none', pad=1, alpha=0.5),
                zorder=4,
            )

        ax.axvline(x=0, color="#000000", linewidth=0.75, zorder=1)

        title = BaseSurvival._format_title_parts(estimator_name, dataset, seed, progression)
        plt.title(f"XAI\n{title}", fontsize=12)
        plt.xlabel("$\\bf{Mean}$ $\\bf{Coefficient}$ $\\bf{values}$", fontsize=10)
        plt.ylabel("$\\bf{Features}$", fontsize=10)

        majorX, minorX = _tool_setXaiTicksAxisX(ax)
        ax.xaxis.set_major_locator(ticker.MultipleLocator(majorX))
        ax.xaxis.set_minor_locator(ticker.MultipleLocator(minorX))
        plt.xticks(rotation=45, ha="right")

        plt.grid(True, which="major", linestyle="-", alpha=0.7, zorder=0)
        plt.grid(True, which="minor", linestyle="--", alpha=0.7, linewidth=0.5, zorder=0)
        plt.tight_layout()

        return figure, ax

    @staticmethod
    def plot_individual_shap(shap_explainer, identifier_index, index, scaler, estimator_name, dataset, seed=None, progression=None):
        """
        Plot SHAP values for an individual instance (horizontal bar plot).

        The individual is located in every seed via ``index``. Its SHAP
        values and (inverse-scaled) feature values are averaged across those
        seeds, ignoring NaN.

        Parameters
        ----------
        shap_explainer : object
            SHAP explanation (see ``_tool_extractSHAP``).
        identifier_index : object
            Identifier of the individual to plot.
        index : array-like
            Identifier of every sample, shaped ``(num_seeds, num_samples)``.
            A 1-D ``(num_samples,)`` array is accepted for a single seed.
        scaler : object or list
            One fitted scaler per seed (a single scaler is wrapped in a
            list). Each must expose ``feature_names_in_`` and
            ``inverse_transform``.
        estimator_name, dataset : str
            Shown in the title.
        seed, progression : optional
            Added to the title when not ``None``; ``seed`` also selects how
            2-D SHAP values are interpreted.

        Returns
        -------
        tuple
            ``(figure, ax)``.

        Raises
        ------
        ValueError
            If ``identifier_index`` does not occur in ``index``.
        """
        index = np.asarray(index)
        if not isinstance(scaler, list):
            scaler = [scaler]

        _values, _data, names = _tool_extractSHAP(shap_explainer, seed)

        # A 1-D index (single seed) must gain the seed axis; otherwise np.where
        # would return sample positions that get used to index the seed axis.
        if index.ndim == 1:
            index = index.reshape(_values.shape[0], -1)

        # (seed positions, sample positions) of the individual in every seed
        target_idxs = np.where(index == identifier_index)
        if target_idxs[0].size == 0:
            raise ValueError(f"Identifier {identifier_index} not found in 'index'.")

        _values = _values[target_idxs]
        _data = _data[target_idxs]

        for d, s in enumerate(target_idxs[0]):
            # Columns this seed's scaler was fitted on
            sort_idx = [list(names).index(c) for c in list(scaler[s].feature_names_in_)]
            transformed_valid_data = scaler[s].inverse_transform(_data[d][sort_idx].reshape(1, -1))
            # Only scaled columns are overwritten; others (e.g. NaN) are kept
            _data[d][sort_idx] = transformed_valid_data

        # Clean floating-point inconsistencies
        _data = np.round(_data, decimals=5)

        _values_mean = np.nanmean(_values, axis=0)
        _data_mean = np.nanmean(_data, axis=0)

        # Ascending |SHAP| so the most important feature ends up at the top
        importance = np.abs(_values_mean)
        sort_idx = np.argsort(importance)
        sorted_values = _values_mean[sort_idx]
        sorted_data = _data_mean[sort_idx]
        sorted_names = names[sort_idx]

        figure, ax = plt.subplots(figsize=(10, 6))
        cmap = plt.get_cmap("coolwarm")

        # Colour scale symmetric around 0
        max_abs = np.nanmax(np.abs(sorted_values)) + 1e-15
        normalise = plt.Normalize(vmin=-max_abs, vmax=max_abs)
        colors = cmap(normalise(sorted_values))

        ax.barh(sorted_names, sorted_values, color=colors, edgecolor="#000000", linewidth=0.75, alpha=0.8, zorder=3)

        # Colour bar used as a Low/High legend
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=1))
        color_bar = figure.colorbar(sm, ax=ax)
        color_bar.set_label("Feature value", labelpad=-15, fontsize=10)
        color_bar.set_ticks([0, 1])
        color_bar.set_ticklabels(["Low", "High"])

        ax.axvline(x=0, color="#000000", linewidth=0.75, zorder=2)

        title = BaseSurvival._format_title_parts(estimator_name, dataset, seed, progression)
        plt.title(f"XAI\n{title}\nIndividual {identifier_index}", fontsize=12)
        plt.xlabel("$\\bf{SHAP}$ $\\bf{values}$", fontsize=10)
        plt.ylabel("$\\bf{Features}$", fontsize=10)

        majorX, minorX = _tool_setXaiTicksAxisX(ax)
        ax.xaxis.set_major_locator(ticker.MultipleLocator(majorX))
        ax.xaxis.set_minor_locator(ticker.MultipleLocator(minorX))
        plt.xticks(rotation=45, ha="right")

        # Y labels show the feature name and its (inverse-scaled) value
        y_labels = [f"{n} (= {d:.2f})" for n, d in zip(sorted_names, sorted_data)]
        plt.yticks(ticks=np.arange(len(sort_idx)), labels=y_labels)

        plt.grid(True, which="major", linestyle="-", alpha=0.7, zorder=0)
        plt.grid(True, which="minor", linestyle="--", alpha=0.7, linewidth=0.5, zorder=0)
        plt.tight_layout()

        return figure, ax

    @staticmethod
    def plot_shap(shap_explainer, index, scaler, estimator_name, dataset, seed=None, progression=None):
        """
        Plot SHAP values for the data (interactive beeswarm plot).

        Every feature gets one row of jittered points covering all seeds and
        samples, coloured by the (inverse-scaled) feature value. Clicking a
        point highlights the same individual in every row and shows its
        SHAP values.

        Parameters
        ----------
        shap_explainer : object
            SHAP explanation (see ``_tool_extractSHAP``).
        index : array-like
            Identifier of every sample; flattened in C order, so it must
            follow the ``(num_seeds, num_samples)`` layout of the SHAP values.
        scaler : object or list
            One fitted scaler per seed (a single scaler is wrapped in a list).
        estimator_name, dataset : str
            Shown in the title.
        seed, progression : optional
            Added to the title when not ``None``.

        Returns
        -------
        tuple
            ``(figure, ax)``.
        """
        if not isinstance(index, np.ndarray):
            index = np.array(index)
        if not isinstance(scaler, list):
            scaler = [scaler]

        _values, _data, names = _tool_extractSHAP(shap_explainer, seed)

        for s in range(_data.shape[0]):
            # Columns this seed's scaler was fitted on
            sort_idx = [list(names).index(c) for c in list(scaler[s].feature_names_in_)]
            transformed_valid_data = scaler[s].inverse_transform(_data[s][:, sort_idx])
            # Only scaled columns are overwritten; others (e.g. NaN) are kept
            _data[s][:, sort_idx] = transformed_valid_data

        # Clean floating-point inconsistencies
        _data = np.round(_data, decimals=5)

        # Mean |SHAP| over seeds and samples, ignoring NaN (features missing in
        # some seeds). Sorted ascending: row 1 (bottom) is the least
        # important, the top row the most important. All-NaN features go last.
        global_importance = np.nanmean(np.abs(_values), axis=(0, 1))
        global_sort_idx = np.argsort(np.where(np.isnan(global_importance), -1.0, global_importance))

        figure, ax = plt.subplots(figsize=(10, 6))
        cmap = plt.get_cmap("coolwarm")

        dots = []
        for y_pos, feature_idx in enumerate(global_sort_idx, start=1):
            # Flattened in (seed, sample) order so point k maps back to
            # seed k // num_samples, sample k % num_samples.
            x = _values[:, :, feature_idx].flatten()
            x_original = _data[:, :, feature_idx].flatten()

            # Per-feature colour range
            min_val = np.nanmin(x_original)
            max_val = np.nanmax(x_original) + 1e-15

            # Vertical jitter (beeswarm effect)
            y = y_pos + np.random.normal(0, 0.075, size=len(x))

            dot = ax.scatter(x, y, s=10, c=x_original, cmap=cmap, vmin=min_val, vmax=max_val, alpha=0.8, edgecolors="none", zorder=3)
            dot.data = _data
            dot.feature_name = names[feature_idx]
            dot.identifier_indexes = index.flatten()
            dots.append(dot)

        # Colour bar used as a Low/High legend
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=0, vmax=1))
        color_bar = figure.colorbar(sm, ax=ax)
        color_bar.set_label("Feature value", labelpad=-15, fontsize=10)
        color_bar.set_ticks([0, 1])
        color_bar.set_ticklabels(["Low", "High"])

        ax.axvline(x=0, color="#000000", linewidth=0.75, zorder=2)

        title = BaseSurvival._format_title_parts(estimator_name, dataset, seed, progression)
        plt.title(f"XAI\n{title}", fontsize=12)
        plt.xlabel("$\\bf{Shap}$ $\\bf{values}$", fontsize=10)
        plt.ylabel("$\\bf{Features}$", fontsize=10)

        majorX, minorX = _tool_setXaiTicksAxisX(ax)
        ax.xaxis.set_major_locator(ticker.MultipleLocator(majorX))
        ax.xaxis.set_minor_locator(ticker.MultipleLocator(minorX))
        plt.xticks(rotation=45, ha="right")
        plt.yticks(ticks=np.arange(1, len(global_sort_idx) + 1), labels=names[global_sort_idx])

        plt.grid(True, which="major", linestyle="-", alpha=0.7, zorder=0)
        plt.grid(True, which="minor", linestyle="--", alpha=0.7, linewidth=0.5, zorder=0)
        plt.tight_layout()

        # Single highlight layer shared by all clicks (filled on selection)
        highlight_scatter = ax.scatter([], [], s=20, c='black', zorder=5)

        cursor = mplcursors.cursor(dots)

        @cursor.connect("add")
        def _tool_setInteractivePlot(selection):
            cursor_idx = selection.index

            seed_idx = cursor_idx // _values.shape[1]
            identifier_idxs = selection.artist.identifier_indexes
            clicked_individual_idx = cursor_idx % _values.shape[1]
            clicked_feature_name = selection.artist.feature_name

            # Positions of the same individual across all seeds
            target_idxs = np.where(identifier_idxs == identifier_idxs[cursor_idx])[0]

            # Highlight those positions in every feature row
            hx, hy = [], []
            for dot in dots:
                offsets = dot.get_offsets()
                for t_idx in target_idxs:
                    hx.append(offsets[t_idx, 0])
                    hy.append(offsets[t_idx, 1])
            highlight_scatter.set_offsets(np.c_[hx, hy])

            text_lines = [
                f"$\\bf{{Seed:}}$ {seed_idx}",
                f"$\\bf{{Identifier\\ Index:}}$ {identifier_idxs[cursor_idx]}",
                f"$\\bf{{Individual\\ Index:}}$ {clicked_individual_idx}",
                f"$\\bf{{SHAP\\ values:}}$"
            ]
            for f_idx in global_sort_idx:
                f_name = names[f_idx]
                s_val = _values[seed_idx, clicked_individual_idx, f_idx]
                prefix = "•" if f_name == clicked_feature_name else " "
                text_lines.append(f"{prefix}{f_name}: {s_val:.4f}")

            selection.annotation.set_text("\n".join(text_lines))
            selection.annotation.set_ha("left")
            selection.annotation.set_multialignment("left")

            # Tooltip background takes the clicked point's colour
            idx_color_value = selection.artist.get_array()[cursor_idx]
            color = selection.artist.cmap(selection.artist.norm(idx_color_value))
            bbox = selection.annotation.get_bbox_patch()
            bbox.set_facecolor(color)
            bbox.set_edgecolor("black")
            bbox.set_alpha(0.7)

        return figure, ax

    @staticmethod
    def _plot_survival_hazard_functions(X, index, estimator_name, dataset, function_type="Survival", seed=0, progression=None):
        """
        Plot survival or cumulative hazard step functions (one line each).

        Parameters
        ----------
        X : iterable
            Step functions. Each is assumed to expose its time grid as ``.X``
            and to be callable on it. NOTE(review): confirm against the
            step-function type returned by subclasses.
        index : sequence
            Identifier attached to each function (shown on hover).
        estimator_name, dataset : str
            Shown in the title.
        function_type : str, default "Survival"
            ``"Survival"`` fixes the y-range to [0, 1.05]; ``"CumulativeRisk"``
            uses adaptive y-ticks starting at 0.
        seed, progression : optional
            Added to the title when not ``None``.

        Returns
        -------
        tuple
            ``(figure, ax)``.
        """
        figure, ax = plt.subplots(figsize=(10, 6))

        lines = []
        for i, step_function in enumerate(X):
            times = step_function.X
            probabilities = step_function(times)
            line, = ax.step(times, probabilities, where="post", alpha=0.6)
            line.identifier_index = index[i]
            lines.append(line)

        title = BaseSurvival._format_title_parts(estimator_name, dataset, seed, progression)
        plt.title(f"{function_type}\n{title}", fontsize=12)
        plt.xlabel("Time (days)", fontsize=10)
        plt.ylabel(f"{function_type} probability", fontsize=10)

        # Tick spacing is derived from the autoscaled limits, so it is set
        # before the limits are changed below.
        majorX, minorX = _tool_setTimeTicksAxisX(ax)
        ax.xaxis.set_major_locator(ticker.MultipleLocator(majorX))
        ax.xaxis.set_minor_locator(ticker.MultipleLocator(minorX))
        if function_type == "Survival":
            ax.yaxis.set_major_locator(ticker.MultipleLocator(0.5))
            ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.1))
        if function_type == "CumulativeRisk":
            majorY, minorY = _tool_setRiskTicksAxisY(ax)
            ax.yaxis.set_major_locator(ticker.MultipleLocator(majorY))
            ax.yaxis.set_minor_locator(ticker.MultipleLocator(minorY))
        plt.xticks(rotation=45, ha="right")

        ax.set_xlim(left=0)
        if function_type == "Survival":
            ax.set_ylim(bottom=0, top=1.05)
        if function_type == "CumulativeRisk":
            ax.set_ylim(bottom=0)
        ax.spines["left"].set_position(("outward", 5))
        ax.spines["bottom"].set_position(("outward", 5))
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)

        plt.grid(True, which="major", linestyle="-", alpha=0.7)
        plt.grid(True, which="minor", linestyle="--", alpha=0.7, linewidth=0.5)
        figure.tight_layout()

        cursor = mplcursors.cursor(lines)

        @cursor.connect("add")
        def _tool_setInteractivePlot(selection):
            idx = selection.artist.identifier_index

            selection.annotation.set_text(
                f"$\\bf{{Time:}}$ {selection.target[0]:.2f}\n"
                f"$\\bf{{{function_type}:}}$ {selection.target[1]:.2f}\n"
                f"$\\bf{{Individual:}}$ {idx}"
            )
            selection.annotation.set_ha("left")
            selection.annotation.set_multialignment("left")

            # Tooltip background takes the line's colour
            bbox = selection.annotation.get_bbox_patch()
            bbox.set_facecolor(selection.artist.get_color())
            bbox.set_edgecolor("black")
            bbox.set_alpha(0.7)

        return figure, ax

    def _sort(self, X, y, time="time"):
        """
        Sort ``X`` and ``y`` by descending ``y[time]``.

        Both must support NumPy fancy indexing; ``y`` must also support
        ``y[time]``.
        """
        sort_idx = np.argsort(y[time])[::-1]
        X = X[sort_idx]
        y = y[sort_idx]
        return X, y

    def _sort_multitask(self, risk, t, e):
        """
        Sort the tensors ``risk``, ``t`` and ``e`` by descending ``t``.
        """
        _, idx = torch.sort(t, descending=True)
        risk = risk[idx]
        t = t[idx]
        e = e[idx]
        return risk, t, e