Source code for bsix.utils.estimators

import numpy as np

from .survival_metrics import scorerConcordanceIndex

from sklearn.metrics import make_scorer

# Estimator names accepted by get_estimator(); that function raises
# ValueError for any name that is not listed here.
CLASSIFIERS = [
    # "Base*"-prefixed estimators.
    "BaseCoxRegression",
    "BaseCoxRegressionWithTimeVarying",
    "BaseRandomSurvivalForest",
    "BaseSurvivalTree",

    # Remaining estimators, in alphabetical order.
    "AcceleratedFailureTime",
    "CoxRegression",
    "CoxRegressionWithTimeVarying",
    "DeepMultiTaskFFNN",
    "DeepMultiTaskMultiLossFFNN",
    "DeepSurvFFNN",
    "DeepTimeVaryingFFNN",
    "RandomSurvForest",
    "SurvTree",
]

def _log_grid(low, high, num):
    """Return ``num`` log-spaced values from 10**low to 10**high, rounded to 8 decimals."""
    return np.round(np.logspace(low, high, num), 8)


def _dirichlet_weights(rng, labels, size=7):
    """Draw ``size`` weight vectors from a Dirichlet distribution with all-ones alpha.

    Each vector has one entry per column of ``labels`` (a single entry when
    ``labels`` is one-dimensional). Values are rounded to 8 decimals and
    returned as a nested Python list.
    """
    n_columns = 1 if labels.ndim == 1 else labels.shape[1]
    return np.round(rng.dirichlet(alpha=np.ones(n_columns), size=size), 8).tolist()


def get_estimator(
    estimator_name, inputs, labels, valid_data, seed, n_jobs=-1, n_iter=30
):
    """Build the randomized hyper-parameter search for a named estimator.

    Parameters
    ----------
    estimator_name : str
        One of the names listed in ``CLASSIFIERS``.
    inputs : array-like with a ``shape`` attribute
        Only ``inputs.shape[1]`` is read; it is passed as the first positional
        argument to the ``Deep*`` estimators (presumably the number of input
        features -- confirm against the model constructors).
    labels : array-like with ``ndim`` and ``shape`` attributes
        Only used to size the Dirichlet weight vectors sampled for the
        multi-task estimators.
    valid_data : object
        Forwarded unchanged as the ``cv`` argument of ``RandomizedSearchCV``.
    seed : int
        Passed to estimators that take a ``seed`` argument, to the NumPy
        generator used for the Dirichlet draws, and as ``random_state`` of the
        search.
    n_jobs : int, default -1
        Forwarded to ``RandomizedSearchCV``.
    n_iter : int, default 30
        Forwarded to ``RandomizedSearchCV``.

    Returns
    -------
    RandomizedSearchCV or estimator
        An unfitted ``RandomizedSearchCV`` scored with
        ``scorerConcordanceIndex`` (greater is better). If the parameter grid
        were empty, the bare estimator would be returned instead; every branch
        below currently defines a non-empty grid.

    Raises
    ------
    ValueError
        If ``estimator_name`` is not in ``CLASSIFIERS``.
    NotImplementedError
        If ``estimator_name`` is in ``CLASSIFIERS`` but has no branch here.
    """
    if estimator_name not in CLASSIFIERS:
        raise ValueError(f"Estimator {estimator_name} not recognised.")

    # scikit-learn and the model classes are imported inside the function, as
    # in the original layout (presumably so that only the selected model's
    # backend gets imported -- confirm before hoisting to module level).
    from sklearn.model_selection import RandomizedSearchCV

    # ---- Standard models ------------------------------------------------
    if estimator_name == "AcceleratedFailureTime":
        from ..models import AcceleratedFailureTime

        param_grid = [
            {
                "type": ["LogLogisticAFT", "WeibullAFT"],
                "penalizer": _log_grid(-1, 1, 3),
                "l1_ratio": np.round(np.linspace(0, 1, 5, endpoint=False), 8),
            }
        ]
        estimator = AcceleratedFailureTime()

    elif estimator_name == "CoxRegression":
        from ..models import CoxRegression

        param_grid = [
            {
                "alpha": _log_grid(-3, -1, 3),
                "ties": ["efron", "breslow"],
                "n_iter": [100, 200, 300, 400, 500],
            }
        ]
        estimator = CoxRegression()

    elif estimator_name == "DeepSurvFFNN":
        from ..models import DeepSurv

        param_grid = [
            {
                "epochs": [250, 500],
                "hidden_layers": [[4], [8], [16], [32]],
                "learn_rate": _log_grid(-5, -3, 3),
                "lr_decay": _log_grid(-8, -6, 3),
                "l1_reg": _log_grid(-5, -3, 3),
                "l2_reg": _log_grid(-4, -2, 3),
                "dropout": np.round(np.linspace(0.25, 0.75, 3), 8),
                "activation": ["relu", "selu", "tanh", "sigmoid"],
            }
        ]
        estimator = DeepSurv(inputs.shape[1], seed=seed)

    elif estimator_name == "RandomSurvForest":
        from ..models import RandomSurvForest

        param_grid = [
            {
                "n_estimators": [100, 300, 500],
                "max_depth": [3, 5, 7],
                "min_samples_leaf": [2, 3, 5],
                "min_samples_split": [2, 6, 10],
            }
        ]
        estimator = RandomSurvForest(seed=seed)

    elif estimator_name == "SurvTree":
        from ..models import SurvTree

        param_grid = [
            {
                "max_depth": [3, 5, 7],
                "min_samples_split": [2, 6, 10],
                "min_samples_leaf": [2, 3, 5],
            }
        ]
        estimator = SurvTree(seed=seed)

    # ---- Time-varying models --------------------------------------------
    elif estimator_name == "CoxRegressionWithTimeVarying":
        from ..models import CoxRegressionWithTimeVarying

        param_grid = [
            {
                "alpha": _log_grid(-1, 1, 3),
                "ties": ["efron", "breslow"],
                "n_iter": [100, 200, 300],
            }
        ]
        estimator = CoxRegressionWithTimeVarying()

    elif estimator_name == "DeepTimeVaryingFFNN":
        from ..models import DeepTimeVarying

        # Every hyper-parameter is pinned to a single value, so this grid
        # holds exactly one combination.
        param_grid = [
            {
                "epochs": [500],
                "hidden_layers": [[4]],
                "learn_rate": [0.001],
                "lr_decay": [1e-8],
                "l1_reg": [0.0001],
                "l2_reg": [0.0001],
                "dropout": [0.75],
                "activation": ["relu"],
            }
        ]
        estimator = DeepTimeVarying(inputs.shape[1], seed=seed)

    # ---- Multi-task models ----------------------------------------------
    elif estimator_name == "DeepMultiTaskFFNN":
        from ..models import DeepMultiTask

        rng = np.random.default_rng(seed=seed)
        param_grid = [
            {
                "epochs": [250, 500],
                "hidden_layers": [[4], [8], [16], [32]],
                "learn_rate": _log_grid(-5, -3, 3),
                "lr_decay": _log_grid(-8, -6, 3),
                "l1_reg": _log_grid(-5, -3, 3),
                "l2_reg": _log_grid(-4, -2, 3),
                "cox_reg": _log_grid(-3, 3, 7),
                "dropout": np.round(np.linspace(0.25, 0.75, 3), 8),
                "activation": ["relu", "selu", "tanh"],
                "coef_likelihood": _dirichlet_weights(rng, labels),
            }
        ]
        estimator = DeepMultiTask(inputs.shape[1], seed=seed)

    elif estimator_name == "DeepMultiTaskMultiLossFFNN":
        from ..models import DeepMultiTaskMultiLoss

        rng = np.random.default_rng(seed=seed)
        # Both weight sets are drawn from the same generator, so
        # "coef_likelihood" must be sampled before "coef_binary" to keep the
        # values reproducible for a given seed.
        param_grid = [
            {
                "epochs": [250, 500],
                "hidden_layers": [[4], [8], [16], [32]],
                "learn_rate": _log_grid(-5, -3, 3),
                "lr_decay": _log_grid(-8, -6, 3),
                "l1_reg": _log_grid(-5, -3, 3),
                "l2_reg": _log_grid(-4, -2, 3),
                "cox_reg": _log_grid(-3, 3, 7),
                "bin_reg": _log_grid(-3, 3, 7),
                "dropout": np.round(np.linspace(0.25, 0.75, 3), 8),
                "activation": ["relu", "selu", "tanh"],
                "coef_likelihood": _dirichlet_weights(rng, labels),
                "coef_binary": _dirichlet_weights(rng, labels),
            }
        ]
        estimator = DeepMultiTaskMultiLoss(inputs.shape[1], seed=seed)

    # ---- "Base*" models -------------------------------------------------
    elif estimator_name == "BaseCoxRegression":
        from ..models import BaseCoxRegression

        param_grid = [
            {
                "alpha": _log_grid(-3, -1, 3),
                "ties": ["efron", "breslow"],
                "n_iter": [100, 200, 300, 400, 500],
            }
        ]
        estimator = BaseCoxRegression()

    elif estimator_name == "BaseCoxRegressionWithTimeVarying":
        from ..models import BaseCoxRegressionWithTimeVarying

        param_grid = [
            {
                "penalizer": _log_grid(-1, 1, 3),
                "l1_ratio": np.round(np.linspace(0, 1, 5, endpoint=False), 8),
            }
        ]
        estimator = BaseCoxRegressionWithTimeVarying()

    elif estimator_name == "BaseRandomSurvivalForest":
        from ..models import BaseRandomSurvivalForest

        param_grid = [
            {
                "n_estimators": [100, 300, 500],
                "max_depth": [3, 5, 7],
                "min_samples_leaf": [2, 3, 5],
                "min_samples_split": [2, 6, 10],
            }
        ]
        estimator = BaseRandomSurvivalForest(seed=seed)

    elif estimator_name == "BaseSurvivalTree":
        from ..models import BaseSurvivalTree

        param_grid = [
            {
                "max_depth": [3, 5, 7],
                "min_samples_split": [2, 6, 10],
                "min_samples_leaf": [2, 3, 5],
            }
        ]
        estimator = BaseSurvivalTree(seed=seed)

    else:
        # Only reachable when a name is added to CLASSIFIERS without a
        # matching branch above.
        raise NotImplementedError(
            f"Estimator {estimator_name} not implemented in get_estimator function."
        )

    if not param_grid:
        return estimator

    return RandomizedSearchCV(
        estimator=estimator,
        param_distributions=param_grid,
        n_iter=n_iter,
        n_jobs=n_jobs,
        cv=valid_data,
        scoring=make_scorer(scorerConcordanceIndex, greater_is_better=True),
        error_score="raise",
        random_state=seed,
        verbose=10,
    )