
IHM Example Using HistGBT

This notebook shows an example of using HistGBT (scikit-learn's HistGradientBoostingClassifier) to model in-hospital mortality (IHM) on the MIMIC-III dataset.

The data are assumed to have already been extracted from the cohort and described by a YAML configuration file such as the one below:

# USER DEFINED
tgt_col: y_true
idx_cols: stay
time_order_col: 
    - Hours
    - seqnum

feat_cols: null

train:
    tgt_file: '{DATA_DIR}/IHM_V0_COHORT_OUT_EXP-SPLIT0-train.csv'
    feat_file: '{DATA_DIR}/IHM_V0_FEAT_EXP-SPLIT0-train.csv'

val:
    tgt_file: '{DATA_DIR}/IHM_V0_COHORT_OUT_EXP-SPLIT0-val.csv'
    feat_file: '{DATA_DIR}/IHM_V0_FEAT_EXP-SPLIT0-val.csv'

test:
    tgt_file: '{DATA_DIR}/IHM_V0_COHORT_OUT_EXP-SPLIT0-test.csv'
    feat_file: '{DATA_DIR}/IHM_V0_FEAT_EXP-SPLIT0-test.csv'

# DATA DEFINITIONS

## Definitions of categorical data in the dataset
category_map:
  Capillary refill rate: ['0.0', '1.0']
  Glascow coma scale eye opening: ['To Pain', '3 To speech', '1 No Response', '4 Spontaneously',
                                   'To Speech', 'Spontaneously', '2 To pain', 'None'] 
  Glascow coma scale motor response: ['1 No Response' , '3 Abnorm flexion' , 'Abnormal extension' , 'No response',
                                      '4 Flex-withdraws' , 'Localizes Pain' , 'Flex-withdraws' , 'Obeys Commands',
                                      'Abnormal Flexion' , '6 Obeys Commands' , '5 Localizes Pain' , '2 Abnorm extensn']
  Glascow coma scale total: ['11', '10', '13', '12', '15', '14', '3', '5', '4', '7', '6', '9', '8']
  Glascow coma scale verbal response: ['1 No Response', 'No Response', 'Confused', 'Inappropriate Words', 'Oriented', 
                                       'No Response-ETT', '5 Oriented', 'Incomprehensible sounds', '1.0 ET/Trach', 
                                       '4 Confused', '2 Incomp sounds', '3 Inapprop words']

numerical: ['Heart Rate', 'Fraction inspired oxygen', 'Weight', 'Respiratory rate', 
            'pH', 'Diastolic blood pressure', 'Glucose', 'Systolic blood pressure',
            'Height', 'Oxygen saturation', 'Temperature', 'Mean blood pressure']

## Definitions of normal values in the dataset
normal_values:
  Capillary refill rate: 0.0
  Diastolic blood pressure: 59.0
  Fraction inspired oxygen: 0.21
  Glucose: 128.0
  Heart Rate: 86
  Height: 170.0
  Mean blood pressure: 77.0
  Oxygen saturation: 98.0
  Respiratory rate: 19
  Systolic blood pressure: 118.0
  Temperature: 36.6
  Weight: 81.0
  pH: 7.4
  Glascow coma scale eye opening: '4 Spontaneously'
  Glascow coma scale motor response: '6 Obeys Commands'
  Glascow coma scale total:  '15'
  Glascow coma scale verbal response: '5 Oriented'
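The configuration does two things. The category_map lists the admissible levels of each categorical variable, and normal_values gives a clinically normal value for each variable. The notebook passes both to SKDataLoader (as category_map and fill_value). One plausible reading, assumed here and not taken from lightsaber's source, is that missing measurements are filled with the normal value and each categorical column is then one-hot encoded over its declared levels. The pandas sketch below applies that idea to a made-up two-stay frame. Note that impute_and_encode is a hypothetical helper, not a lightsaber function.

```python
import pandas as pd

def impute_and_encode(df, category_map, normal_values):
    """Hypothetical helper (not lightsaber's API): fill gaps with normal values,
    then one-hot encode categorical columns over their declared levels."""
    # Fill missing values column by column with the configured normal value
    fill = {col: val for col, val in normal_values.items() if col in df.columns}
    out = df.fillna(value=fill)
    for col, levels in category_map.items():
        if col not in out.columns:
            continue
        # A fixed level list gives every split the same dummy columns
        cat = pd.Categorical(out[col].astype(str), categories=levels)
        dummies = pd.get_dummies(cat, prefix=col).astype(int)
        dummies.index = out.index
        out = pd.concat([out.drop(columns=col), dummies], axis=1)
    return out

# Trimmed excerpt of the configuration above
category_map = {'Glascow coma scale total': ['3', '8', '15']}
normal_values = {'Heart Rate': 86, 'Glascow coma scale total': '15'}

df = pd.DataFrame({'Heart Rate': [110.0, None],
                   'Glascow coma scale total': ['8', None]},
                  index=['stay_a', 'stay_b'])
encoded = impute_and_encode(df, category_map, normal_values)
print(encoded)
```

Declaring the levels up front matters. A split that happens not to contain some level (for example, a GCS total of 3) still gets that dummy column, so the train, validation, and test matrices line up.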

Preamble

The following code cell imports the required libraries, sets up the notebook, and loads the experiment configuration.

# Jupyter notebook specific imports
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')

# Imports injecting into namespace
from tqdm.auto import tqdm
tqdm.pandas()

# General imports
import os
import json
import pickle
from pathlib import Path

import pandas as pd
import numpy as np
from getpass import getpass
import argparse

from sklearn.preprocessing import StandardScaler
from sklearn.exceptions import NotFittedError

from lightsaber import constants as C
import lightsaber.data_utils.utils as du
from lightsaber.data_utils.pt_dataset import (filter_preprocessor)
from lightsaber.data_utils import sk_dataloader as skd
from lightsaber.trainers import sk_trainer as skr

from sklearn.ensemble import HistGradientBoostingClassifier

import logging
log = logging.getLogger()
data_dir = Path(os.environ.get('LS_DATA_PATH', './data'))
assert data_dir.is_dir()

conf_path = os.environ.get('LS_CONF_PATH', os.path.abspath('./ihm_expt_config.yml'))
# Fill in the {DATA_DIR} placeholder, then parse the YAML
with open(conf_path) as f:
    expt_conf = du.yaml.load(f.read().format(DATA_DIR=data_dir),
                             Loader=du._Loader)
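The preamble reads the configuration as plain text and calls str.format(DATA_DIR=data_dir) before parsing the YAML. As a result, every {DATA_DIR} placeholder is replaced by the resolved data directory. A side effect is that any other literal brace in the file is also treated as a format field and raises an error unless it is doubled as {{ or }. The configuration shown above contains no other braces, so it is unaffected.

The stdlib-only snippet below shows both behaviours on a two-line excerpt of the configuration. The path /tmp/ihm is an arbitrary example.

```python
from pathlib import PurePosixPath

# Two-line excerpt of ihm_expt_config.yml containing the placeholder
template = (
    "train:\n"
    "    tgt_file: '{DATA_DIR}/IHM_V0_COHORT_OUT_EXP-SPLIT0-train.csv'\n"
)
data_dir = PurePosixPath('/tmp/ihm')  # example path; the notebook reads LS_DATA_PATH
resolved = template.format(DATA_DIR=data_dir)
print(resolved)

# Any other unescaped brace is parsed as a format field as well
try:
    "note: {other}".format(DATA_DIR=data_dir)
except KeyError as err:
    print("unescaped brace, missing field:", err)

# Doubled braces survive as literal characters
print("note: {{literal}}".format(DATA_DIR=data_dir))  # note: {literal}
```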

IHM Model Training

In general, training a HistGBT model for IHM involves the following steps:

  • Data Ingestion: The first step is to set up the pre-processors used to train the IHM model. In this example, we use scikit-learn's StandardScaler, applied through filters defined within lightsaber.

  • Next, we read the train, validation, and test datasets. In some cases, users may also want to define a calibration dataset.

  • Model Definition: Next, we define a base classification model. In this example, we use scikit-learn's standard HistGradientBoostingClassifier (HistGBT).

  • Model Training: Once the model is defined, we use lightsaber to train it via the pre-packaged SKModel wrapper and the corresponding trainer code. This step also generates the relevant metrics for the problem.

  • We also show how to train a model with a single hyper-parameter setting, as well as how to run a grid search over a pre-specified hyper-parameter space.

Data Ingestion

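We first read the extracted cohort data and use a StandardScaler to demonstrate the proper usage of a pre-processor. It is fit on the training data only (refit=True) and then reused, without refitting, on the validation and test data (refit=False).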

flatten = 'sum'
preprocessor = StandardScaler()
train_filter = [filter_preprocessor(cols=expt_conf['numerical'], 
                                    preprocessor=preprocessor,
                                    refit=True),
               ]

train_dataloader = skd.SKDataLoader(tgt_file=expt_conf['train']['tgt_file'],
                                    feat_file=expt_conf['train']['feat_file'],
                                    idx_col=expt_conf['idx_cols'],
                                    tgt_col=expt_conf['tgt_col'],
                                    feat_columns=expt_conf['feat_cols'],
                                    time_order_col=expt_conf['time_order_col'],
                                    category_map=expt_conf['category_map'],
                                    filter=train_filter,
                                    fill_value=expt_conf['normal_values'],
                                    flatten=flatten,
                                   )
print(train_dataloader.shape, len(train_dataloader))

# For other datasets use fitted preprocessors
fitted_filter = [filter_preprocessor(cols=expt_conf['numerical'], 
                                     preprocessor=preprocessor, refit=False),
                 ]
val_dataloader = skd.SKDataLoader(tgt_file=expt_conf['val']['tgt_file'],
                                  feat_file=expt_conf['val']['feat_file'],
                                  idx_col=expt_conf['idx_cols'],
                                  tgt_col=expt_conf['tgt_col'],
                                  feat_columns=expt_conf['feat_cols'],
                                  time_order_col=expt_conf['time_order_col'],
                                  category_map=expt_conf['category_map'],
                                  filter=fitted_filter,
                                  fill_value=expt_conf['normal_values'],
                                  flatten=flatten,
                                )

test_dataloader = skd.SKDataLoader(tgt_file=expt_conf['test']['tgt_file'],
                                   feat_file=expt_conf['test']['feat_file'],
                                   idx_col=expt_conf['idx_cols'],
                                   tgt_col=expt_conf['tgt_col'],
                                   feat_columns=expt_conf['feat_cols'],
                                   time_order_col=expt_conf['time_order_col'],
                                   category_map=expt_conf['category_map'],
                                   filter=fitted_filter,
                                   fill_value=expt_conf['normal_values'],
                                   flatten=flatten,
                                  )

print(val_dataloader.shape, len(val_dataloader))
print(test_dataloader.shape, len(test_dataloader))
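The refit=True / refit=False split above is the usual guard against leakage. Scaling statistics are learned from the training data only and then reused unchanged on the other splits. The NumPy sketch below spells out what StandardScaler computes, z = (x - mean_train) / std_train with the population standard deviation, on small made-up arrays.

```python
import numpy as np

train = np.array([[80.0, 7.35], [90.0, 7.40], [100.0, 7.45]])  # e.g. heart rate, pH
val = np.array([[120.0, 7.30]])

# "refit=True": learn the statistics from the training split only
mean = train.mean(axis=0)
std = train.std(axis=0)  # population std (ddof=0), as StandardScaler uses

# "refit=False": reuse the training statistics on every other split
train_z = (train - mean) / std
val_z = (val - mean) / std
print(train_z.round(3))
print(val_z.round(3))
```

The validation row lands well outside the training range (about 3.7 standard deviations for heart rate). That is expected, and it is exactly the information a refit on the validation data would have hidden.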

Training a Single Model

Model definition

We define a base classification model using the standard scikit-learn workflow:

model_name = 'HistGBT'
hparams = argparse.Namespace(learning_rate=0.01,
                             max_iter=100,
                             l2_regularization=0.01
                             )

base_model = HistGradientBoostingClassifier(learning_rate=hparams.learning_rate, 
                                            l2_regularization=hparams.l2_regularization, 
                                            max_iter=hparams.max_iter)

wrapped_model = skr.SKModel(base_model, hparams, name=model_name)
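Here hparams is a plain argparse.Namespace, so vars(hparams) turns it into a dict. This is why the grid-search cell below can build the same model with HistGradientBoostingClassifier(**vars(hparams)). The stdlib-only check below confirms that the explicit keyword arguments used here and the **vars(...) form pass identical values. The function make_model is a hypothetical stand-in for the estimator constructor.

```python
import argparse

hparams = argparse.Namespace(learning_rate=0.01, max_iter=100, l2_regularization=0.01)

def make_model(**kwargs):
    """Hypothetical stand-in for HistGradientBoostingClassifier(...)"""
    return kwargs

explicit = make_model(learning_rate=hparams.learning_rate,
                      l2_regularization=hparams.l2_regularization,
                      max_iter=hparams.max_iter)
unpacked = make_model(**vars(hparams))
print(explicit == unpacked)  # True
```

Unpacking only works while every attribute of the namespace is a valid estimator parameter. Adding an unrelated field (say, a data path) to hparams would make the scikit-learn constructor raise a TypeError.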

Model training with built-in model tracking and evaluation

mlflow_conf = dict(experiment_name=f'classifier_ihm')
artifacts = dict(preprocessor=preprocessor)
experiment_tags = dict(model=model_name, 
                       tune=False)

(run_id, metrics, 
 val_y, val_yhat, val_pred_proba, 
 test_y, test_yhat, test_pred_proba) = skr.run_training_with_mlflow(mlflow_conf, 
                                                                    wrapped_model,
                                                                    train_dataloader=train_dataloader,
                                                                    val_dataloader=val_dataloader,
                                                                    test_dataloader=test_dataloader,
                                                                    artifacts=artifacts,
                                                                    **experiment_tags)

print(f"MLFlow Experiment: {mlflow_conf['experiment_name']} \t | Run ID: {run_id}")
print(metrics)
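The exact contents of metrics are determined by lightsaber's trainer. For a binary outcome such as IHM, ranking metrics like the area under the ROC curve are the usual headline numbers. As an illustration, not lightsaber's implementation, the sketch below computes AUROC from binary labels and positive-class scores. It counts the fraction of positive/negative pairs that are ranked correctly, with ties counting as one half.

```python
from itertools import product

def auroc(y_true, scores):
    """AUROC as the probability that a random positive outranks a random negative."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative")
    # Compare every positive with every negative; ties count as half a win
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p, n in product(pos, neg))
    return wins / (len(pos) * len(neg))

print(auroc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

The pairwise loop is quadratic, which is fine for illustration. Library implementations use a sort-based method instead.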

Lightsaber also natively supports hyper-parameter search to find the best model with respect to a pre-defined metric, following the same workflow as above.

To conduct a grid search, we follow two steps:

  • We define a grid, h_search, over the model parameter space.
  • We pass the experiment tag tune=True, along with the grid h_search, to the trainer code.

model_name = 'HistGBT'
hparams = argparse.Namespace(learning_rate=0.01,
                             max_iter=100,
                             l2_regularization=0.01
                             )
h_search = dict(
    learning_rate=[0.01, 0.1, 0.02],
    max_iter=[50, 100]
)

base_model = HistGradientBoostingClassifier(**vars(hparams))

wrapped_model = skr.SKModel(base_model, hparams, name=model_name)
mlflow_conf = dict(experiment_name=f'classifier_ihm')
artifacts = dict(preprocessor=preprocessor)
experiment_tags = dict(model=model_name, 
                       tune=True)

(run_id, metrics, 
 val_y, val_yhat, val_pred_proba, 
 test_y, test_yhat, test_pred_proba) = skr.run_training_with_mlflow(mlflow_conf, 
                                                                    wrapped_model,
                                                                    train_dataloader=train_dataloader,
                                                                    val_dataloader=val_dataloader,
                                                                    test_dataloader=test_dataloader,
                                                                    artifacts=artifacts,
                                                                    h_search=h_search,
                                                                    **experiment_tags)

print(f"MLFlow Experiment: {mlflow_conf['experiment_name']} \t | Run ID: {run_id}")
print(metrics)
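With tune=True, the trainer searches over h_search. We assume a full Cartesian grid, with the remaining values taken from hparams; this is our assumption about how the two are combined. Under it, the grid above yields 3 × 2 = 6 candidate settings, all sharing l2_regularization=0.01.

The sketch below enumerates the candidates with itertools.product and picks the one with the best validation score. Here score_fn is a hypothetical placeholder for fitting a model and scoring it on the validation split.

```python
import argparse
from itertools import product

hparams = argparse.Namespace(learning_rate=0.01, max_iter=100, l2_regularization=0.01)
h_search = dict(learning_rate=[0.01, 0.1, 0.02], max_iter=[50, 100])

def expand_grid(base, grid):
    """Yield one full hyper-parameter dict per point of the Cartesian grid."""
    keys = list(grid)
    for values in product(*(grid[k] for k in keys)):
        # Grid values override the defaults; untouched keys keep the base value
        yield {**vars(base), **dict(zip(keys, values))}

def score_fn(params):
    """Hypothetical placeholder: fit on train, return a validation metric."""
    return params['learning_rate'] * params['max_iter']

candidates = list(expand_grid(hparams, h_search))
best = max(candidates, key=score_fn)
print(len(candidates), best)
```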

IHM Model Registration

This block shows how to register a trained model for use in subsequent steps. Given a run_id, this block can be run independently of the rest of the notebook.

Internally, the following steps happen:

  • The saved model (along with its hyper-parameters and weights) is retrieved using run_id.
  • The model is initialized using the saved weights.
  • The model is logged to MLflow under the registered model name.

print(f"Registering model for run: {run_id}")
# Re-read the YAML config so this cell can run on its own and log the test files
data_dir = Path(os.environ.get('LS_DATA_PATH', './data'))
assert data_dir.is_dir()

conf_path = os.environ.get('LS_CONF_PATH', os.path.abspath('./ihm_expt_config.yml'))
with open(conf_path) as f:
    expt_conf = du.yaml.load(f.read().format(DATA_DIR=data_dir),
                             Loader=du._Loader)

mlflow_conf = dict(experiment_name=f'classifier_ihm')
registered_model_name = 'classifier_ihm_HistGBT_v0'

print("model ready to be registered") 
# Register model
skr.register_model_with_mlflow(run_id, mlflow_conf, 
                               registered_model_name=registered_model_name,
                               test_feat_file=expt_conf['test']['feat_file'],
                               test_tgt_file=expt_conf['test']['tgt_file'],
                               config=conf_path  # honours LS_CONF_PATH like the rest of the cell
                              )
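In an MLflow model registry, each registration under an existing name creates a new, incrementing version of that model. Rerunning this cell with another run_id should therefore produce version 2, 3, and so on of classifier_ihm_HistGBT_v0.

The toy in-memory registry below only mimics that versioning behaviour, using pickle, so that it runs without an MLflow server. ToyRegistry is hypothetical and is not part of lightsaber or MLflow.

```python
import pickle

class ToyRegistry:
    """Hypothetical stand-in for a model registry: name -> list of pickled versions."""
    def __init__(self):
        self._store = {}

    def register(self, name, model, run_id):
        # Each registration under a name appends a new version
        versions = self._store.setdefault(name, [])
        versions.append({'run_id': run_id, 'blob': pickle.dumps(model)})
        return len(versions)  # 1-based version number

    def load(self, name, version):
        return pickle.loads(self._store[name][version - 1]['blob'])

registry = ToyRegistry()
v1 = registry.register('classifier_ihm_HistGBT_v0', {'max_iter': 100}, run_id='run-a')
v2 = registry.register('classifier_ihm_HistGBT_v0', {'max_iter': 50}, run_id='run-b')
print(v1, v2, registry.load('classifier_ihm_HistGBT_v0', 1))
```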

IHM Model Inference

Lightsaber also natively supports running inference on new patients using the registered model. The key steps are:

  • Load the registered model from MLflow (here, via its run_id).
  • Ingest the new test data using SKDataLoader in inference mode (setting tgt_file to None).
  • Use the SKModel.predict_patient method to generate the inference for the patient of interest.

print(f"Inference using model for run: {run_id}")
# Re-read the YAML config so this cell can locate the test feature file
data_dir = Path(os.environ.get('LS_DATA_PATH', './data'))
assert data_dir.is_dir()

conf_path = os.environ.get('LS_CONF_PATH', os.path.abspath('./ihm_expt_config.yml'))
with open(conf_path) as f:
    expt_conf = du.yaml.load(f.read().format(DATA_DIR=data_dir),
                             Loader=du._Loader)

mlflow_conf = dict(experiment_name=f'classifier_ihm')
registered_model_name = 'classifier_ihm_HistGBT_v0'

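wrapped_model = skr.load_model_from_mlflow(run_id, mlflow_conf)
print("model ready for inference")
# NOTE: fitted_filter and flatten are defined in the Data Ingestion cell above;
# the preprocessor inside fitted_filter must already be fitted on the training data
inference_dataloader = skd.SKDataLoader(tgt_file=None,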
                                        feat_file=expt_conf['test']['feat_file'],
                                        idx_col=expt_conf['idx_cols'],
                                        tgt_col=expt_conf['tgt_col'],
                                        feat_columns=expt_conf['feat_cols'],
                                        time_order_col=expt_conf['time_order_col'],
                                        category_map=expt_conf['category_map'],
                                        filter=fitted_filter,
                                        fill_value=expt_conf['normal_values'],
                                        flatten=flatten,
                                        )

print(inference_dataloader.shape, len(inference_dataloader))
patient_id = inference_dataloader.sample_idx.index[0]
print(f"Inference for patient: {patient_id}")

# patient_id = '10011_episode1_timeseries.csv'
wrapped_model.predict_patient(patient_id, inference_dataloader)
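Conceptually, predicting for one patient means selecting that patient's row of the flattened feature matrix and passing it to the model as a one-row 2-D input. This notebook does not show how predict_patient works internally. The pandas sketch below illustrates the detail that matters for scikit-learn style models. A list lookup, .loc[[patient_id]], keeps a one-row DataFrame, whereas a scalar lookup, .loc[patient_id], returns a 1-D Series that such models reject. The feature values, the second stay identifier, and ToyRiskModel are all made up.

```python
import numpy as np
import pandas as pd

# Made-up flattened features indexed by stay (patient) identifier
features = pd.DataFrame({'Heart Rate': [0.5, 2.1], 'pH': [0.1, -1.8]},
                        index=['10011_episode1_timeseries.csv', 'example_stay_2'])

class ToyRiskModel:
    """Hypothetical model: risk rises with heart rate and falls with pH."""
    def predict_proba(self, X):
        X = np.asarray(X, dtype=float)
        if X.ndim != 2:
            raise ValueError("expected shape (n_samples, n_features)")
        p = 1.0 / (1.0 + np.exp(-(X[:, 0] - X[:, 1])))  # logistic link
        return np.column_stack([1.0 - p, p])

patient_id = features.index[0]
row = features.loc[[patient_id]]       # list lookup keeps a (1, n_features) frame
series = features.loc[patient_id]      # scalar lookup drops to a 1-D Series
print(row.shape, series.shape)         # (1, 2) (2,)
print(ToyRiskModel().predict_proba(row)[0, 1])
```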