import itertools
import matplotlib.pyplot as plt
import numpy as np
from bsix.models import BaseSurvival
from bsix.utils import from_results_to_metrics
from collections import defaultdict
from remayn.report import create_excel_summary_report
from remayn.result_set import ResultFolder
from types import SimpleNamespace
def _filter_search(result, estimator_name, dataset, seed):
"""
Filter function to find the result with the given estimator name, dataset and seed.
"""
if estimator_name is not None and result.config.get('estimator_name') != estimator_name:
return False
if dataset is not None and result.config.get('dataset') != dataset:
return False
if seed is not None and result.config.get('random_state') != seed:
return False
return True
def _sort_results(results, estimator_name, dataset, seed):
"""
Sort the results by the given estimator name, dataset and seed.
"""
sort_fields = []
if estimator_name is None:
sort_fields.append('estimator_name')
if dataset is None:
sort_fields.append('dataset')
if seed is None:
sort_fields.append('random_state')
if sort_fields:
results = sorted(
results,
key=lambda result: tuple(result.config.get(field) for field in sort_fields)
)
return results
[docs]
def get_results(result_folder="./results", estimator_name=None, dataset=None, seed=None):
    """
    Return the results stored in ``result_folder`` that match the given criteria.

    A ``ResultFolder`` is opened on ``result_folder``, its results are filtered with
    ``_filter_search`` and ordered with ``_sort_results``. ``get_data()`` is called on
    every selected result before it is added to the returned list (presumably to
    load its stored data; see the remayn documentation).
    """
    def _matches(candidate):
        return _filter_search(candidate, estimator_name, dataset, seed)

    folder = ResultFolder(result_folder)
    ordered = _sort_results(folder.filter(_matches), estimator_name, dataset, seed)

    loaded = []
    for entry in ordered:
        entry.get_data()
        loaded.append(entry)
    return loaded
def _sort_dict(data_dict):
"""
Rearrange the arrays so that their columns match a unified reference order of feature_names across all items.
"""
def _align_to_reference(item_list, feature_lists, reference_order):
"""
Aligns features of each seed to the unified reference order.
"""
aligned_list = []
for item, current_features in zip(item_list, feature_lists):
current_features_list = list(current_features)
# Ensure item is at least 1D to safely check its dimensions
item_array = np.atleast_1d(item)
is_2d = item_array.ndim == 2
_list = []
for fn in reference_order:
if fn in current_features_list:
idx = current_features_list.index(fn)
# Extract the column for 2D arrays, or the scalar for 1D arrays
_list.append(item_array[:, idx] if is_2d else item_array[idx])
else:
# Pad with NaNs of the correct shape if the feature is missing
if is_2d:
_list.append(np.full(item_array.shape[0], np.nan))
else:
_list.append(np.nan)
# Reconstruct the array with the aligned features
if is_2d:
aligned_list.append(np.column_stack(_list))
else:
aligned_list.append(np.array(_list))
return np.squeeze(np.array(aligned_list, dtype=float))
# Dictionary processing
for identifier_name, data in data_dict.items():
# Create a unified, duplicate-free list of all features preserving order
reference_order = list(dict.fromkeys(itertools.chain.from_iterable(data["feature_names"])))
# Align values_list and safely stack into a NumPy array
data["values_list"] = _align_to_reference(data["values_list"], data["feature_names"], reference_order)
# Align data_list if it exists
if "data_list" in data:
data["data_list"] = _align_to_reference(data["data_list"], data["feature_names"], reference_order)
# Update feature names to the reference order as a safe 1D array
data["feature_names"] = np.atleast_1d(reference_order)
return data_dict
[docs]
def get_xai_from_filter(result_folder="./results", estimator_name=None, dataset=None, seed=None, identifier_index=None):
    """
    Produce the xai output for the stored results matching the given criteria.

    The results of ``result_folder`` are filtered with ``_filter_search``. For each
    match, ``get_data()`` is called and an (estimator_name, dataset, best_model) tuple
    is built from ``result.config`` and ``result.data_``. The resulting list is handed
    to ``get_xai_from_model_list`` together with ``seed`` and ``identifier_index``.
    """
    def _as_model_entry(result):
        # get_data() must run before result.data_ is read
        result.get_data()
        config = result.config
        return config["estimator_name"], config["dataset"], result.data_.best_model

    folder = ResultFolder(result_folder)
    matching = folder.filter(
        lambda candidate: _filter_search(candidate, estimator_name, dataset, seed)
    )
    model_list = [_as_model_entry(result) for result in matching]
    get_xai_from_model_list(model_list, seed, identifier_index)
[docs]
def get_xai_from_model_list(model_list, seed=None, identifier_index=None):
    """
    Get the xai for the given model_list (estimator_name, dataset, model).

    The models are grouped under the identifier ``f"{estimator_name}_{dataset}"``.
    For every model, whatever is available among the following is accumulated:

    - ``coefficients``: its values and keys;
    - ``shap_explainer``: its ``data``, ``values`` and ``feature_names``;
    - ``scaler_``, ``train_idx_``, ``val_idx_``, ``test_idx_``: stored as miscellany.

    Non-empty coefficient and shap dictionaries are aligned with ``_sort_dict``
    (empty ones are replaced by None) and everything is forwarded to
    ``_from_dictionaries_to_xai`` along with ``seed`` and ``identifier_index``.
    """
    dictionary_coefficients = defaultdict(lambda: {'values_list': [], 'feature_names': []})
    dictionary_miscellany = defaultdict(lambda: {'scaler': [], 'train_idx': [], 'val_idx': [], 'test_idx': []})
    dictionary_shap = defaultdict(lambda: {'data_list': [], 'values_list': [], 'feature_names': []})

    # Miscellany key -> name of the model attribute it is read from
    miscellany_attributes = (
        ('scaler', 'scaler_'),
        ('train_idx', 'train_idx_'),
        ('val_idx', 'val_idx_'),
        ('test_idx', 'test_idx_'),
    )

    for estimator_name, dataset, model in model_list:
        # Create an identifier name based on the estimator name and dataset
        identifier_name = f"{estimator_name}_{dataset}"

        # Accumulate data in the relevant dictionary (coefficients)
        if hasattr(model, "coefficients"):
            coefficients_entry = dictionary_coefficients[identifier_name]
            coefficients_entry['values_list'].append(list(model.coefficients.values()))
            coefficients_entry['feature_names'].append(list(model.coefficients.keys()))

        # Accumulate data in the relevant dictionary (shap). The lists are appended
        # to in place; re-copying them on every iteration is unnecessary.
        if hasattr(model, "shap_explainer"):
            shap_entry = dictionary_shap[identifier_name]
            shap_entry['data_list'].append(model.shap_explainer.data)
            shap_entry['values_list'].append(model.shap_explainer.values)
            shap_entry['feature_names'].append(model.shap_explainer.feature_names)

        # Accumulate miscellany; the entry is only created when an attribute exists
        for key, attribute in miscellany_attributes:
            if hasattr(model, attribute):
                dictionary_miscellany[identifier_name][key].append(getattr(model, attribute))

    # An empty dictionary means no model exposed that kind of explanation
    dictionary_coefficients = _sort_dict(dictionary_coefficients) if dictionary_coefficients else None
    dictionary_shap = _sort_dict(dictionary_shap) if dictionary_shap else None

    _from_dictionaries_to_xai(dictionary_coefficients, dictionary_shap, dictionary_miscellany, seed, identifier_index)
def _from_dictionaries_to_xai(dictionary_coefficients, dictionary_shap, dictionary_miscellany, seed, identifier_index):
    """
    Plot the coefficient and shap explanations accumulated per identifier.

    Parameters:
        dictionary_coefficients: None, or a mapping identifier -> dict with
            'values_list' and 'feature_names'.
        dictionary_shap: None, or a mapping identifier -> dict with 'data_list',
            'values_list' and 'feature_names'.
        dictionary_miscellany: mapping identifier -> dict with 'scaler',
            'train_idx' and 'val_idx' lists (only read for shap identifiers).
        seed: forwarded unchanged to the BaseSurvival plotting functions.
        identifier_index: when not None, an individual shap plot is also requested
            for it through ``BaseSurvival.plot_individual_shap``.

    Identifiers are expected to look like ``"<estimator>_<dataset>"``. The function
    returns nothing and ends by calling ``plt.show()``.
    """
    # Calculate average coefficients by dataset_estimator
    if dictionary_coefficients is not None:
        average_coefficients = {}
        for identifier_name, data in dictionary_coefficients.items():
            feature_names = np.atleast_1d(data['feature_names'])
            if feature_names.size == 0:
                # Nothing to average (and reshape(-1, 0) would be ambiguous)
                average_coefficients[identifier_name] = {}
                continue
            # Restore an explicit (n_models, n_features) layout. The stored array may
            # have lost a length-one axis (single model or single feature), so the
            # number of feature names is used to recover the shape.
            values = np.asarray(data['values_list'], dtype=float).reshape(-1, feature_names.size)
            # Average over the models (rows) so that there is one mean per feature;
            # NaN marks a feature that a given model did not have.
            mean_coefficients = np.nanmean(values, axis=0)
            average_coefficients[identifier_name] = dict(zip(feature_names, mean_coefficients))

        # Draw coefficients values means of all seeds by dataset_estimator
        for identifier_name, coefficients in average_coefficients.items():
            # NOTE: maxsplit=1 keeps underscores inside the dataset name; an estimator
            # name containing '_' would still be split at its first underscore.
            estimator_name, dataset_name = identifier_name.split('_', 1)
            BaseSurvival.plot_coefficients(coefficients, estimator_name, dataset_name, seed)

    # Draw shap values of all seeds by dataset_estimator
    if dictionary_shap is not None:
        for identifier_name, data in dictionary_shap.items():
            # Lightweight stand-in exposing the attributes read from a shap explainer
            shap_explainer = SimpleNamespace(
                data=data['data_list'],
                values=data['values_list'],
                feature_names=data['feature_names'],
            )
            estimator_name, dataset_name = identifier_name.split('_', 1)

            # Select index and scaler
            miscellany = dictionary_miscellany[identifier_name]
            # Joins train and validation indices along axis 1, which requires one
            # equally sized index array per accumulated model.
            selected_index = np.concatenate([miscellany['train_idx'], miscellany['val_idx']], axis=1)
            selected_scaler = miscellany['scaler']

            BaseSurvival.plot_shap(shap_explainer, selected_index, selected_scaler, estimator_name, dataset_name, seed)
            if identifier_index is not None:
                BaseSurvival.plot_individual_shap(shap_explainer, identifier_index, selected_index, selected_scaler, estimator_name, dataset_name, seed)

    plt.show()
[docs]
def save_results(result_folder="./results", estimator_name=None, dataset=None, seed=None, report_path="report.xlsx"):
    """
    Save an Excel summary report of the results matching the given criteria.

    The results of ``result_folder`` are filtered with ``_filter_search`` and turned
    into a dataframe holding the 'dataset', 'estimator_name' and 'random_state'
    config columns plus the metrics computed by ``from_results_to_metrics``
    (train included, validation excluded). The dataframe is passed to
    ``create_excel_summary_report``, grouped by 'dataset' and 'estimator_name'.

    Parameters:
        result_folder: folder handed to ``ResultFolder``.
        estimator_name, dataset, seed: filter criteria; None disables a criterion.
        report_path: destination of the report. Defaults to 'report.xlsx', the
            location that was previously hard-coded.
    """
    rf = ResultFolder(result_folder)
    filtered_results = rf.filter(lambda result: _filter_search(result, estimator_name, dataset, seed))

    # Define the columns from the config that we want to include in the dataframe
    config_columns = [
        "dataset",
        "estimator_name",
        "random_state",
    ]
    df = filtered_results.create_dataframe(
        config_columns=config_columns,
        metrics_fn=from_results_to_metrics,
        include_train=True,
        include_val=False,
        config_columns_prefix="",
    )

    # Columns that will be used to group the results and compute means
    group_columns = ["dataset", "estimator_name"]
    create_excel_summary_report(df, report_path, group_columns=group_columns)