Source code for econirl.datasets.icu_sepsis

"""ICU-Sepsis benchmark MDP dataset for clinical treatment IRL.

This module provides access to the ICU-Sepsis benchmark MDP derived from
MIMIC-III patient records by Komorowski et al. (2018). The MDP abstracts
ICU sepsis treatment into 716 discrete states (713 patient physiology
clusters plus the death, survival, and absorbing terminal states),
25 discrete actions (5 IV fluid levels times 5 vasopressor dose levels),
and transition probabilities estimated from real clinical data.

The dataset supports two use cases. First, load the raw MDP components
(transition matrices, rewards, expert policy) for direct policy evaluation
or custom analysis. Second, generate trajectory data by rolling out the
expert clinician policy through the MDP, producing Panel objects suitable
for IRL estimation with any econirl estimator.

The expert policy represents the aggregate treatment behavior of ICU
clinicians observed in MIMIC-III. Inverse reinforcement learning on this
data recovers the implicit reward function driving clinical decisions.

References:
    Komorowski, M., Celi, L.A., Badawi, O., Gordon, A.C., & Faisal, A.A.
    (2018). "The Artificial Intelligence Clinician learns optimal treatment
    strategies for sepsis in intensive care." Nature Medicine, 24, 1716-1720.

    Killian, T.W. et al. (2024). "ICU-Sepsis: A Benchmark MDP Built from
    Real Medical Data." NeurIPS Datasets and Benchmarks.
"""

from __future__ import annotations

from pathlib import Path
from typing import Union

import numpy as np
import pandas as pd

from econirl.core.types import Panel, Trajectory

# Bundled NPZ archive that ships in the same directory as this module.
_DATA_PATH = Path(__file__).with_name("icu_sepsis_mdp.npz")

# Indices of the special states at the top of the 716-state space: the two
# outcome states (death, survival) and the absorbing state, which ends a
# simulated rollout in ``load_icu_sepsis``.
DEATH_STATE, SURVIVAL_STATE, ABSORBING_STATE = 713, 714, 715


[docs] def load_icu_sepsis_mdp(data_path: str | Path | None = None) -> dict: """Load the raw ICU-Sepsis MDP components as numpy arrays. Returns a dictionary with the following keys: - transitions: shape (25, 716, 716), ``transitions[a, s, s'] = P(s'|s,a)`` - rewards: shape (716,), state rewards (+1 at survival state 714) - initial_distribution: shape (716,), starting state probabilities - expert_policy: shape (716, 25), clinician behavior policy ``pi(a|s)`` - sofa_scores: shape (716,), mean SOFA score per state cluster Args: data_path: Path to NPZ file. If None, uses bundled data. Returns: Dictionary of numpy arrays. """ path = Path(data_path) if data_path is not None else _DATA_PATH if not path.exists(): raise FileNotFoundError( f"ICU-Sepsis MDP data not found at {path}. " "Ensure the bundled icu_sepsis_mdp.npz file is present." ) data = np.load(path) return { "transitions": np.array(data["transitions"]), "rewards": np.array(data["rewards"]), "initial_distribution": np.array(data["initial_distribution"]), "expert_policy": np.array(data["expert_policy"]), "sofa_scores": np.array(data["sofa_scores"]), }
def load_icu_sepsis(
    n_individuals: int = 500,
    max_steps: int = 20,
    as_panel: bool = False,
    seed: int = 42,
    data_path: str | Path | None = None,
) -> Union[pd.DataFrame, Panel]:
    """Generate ICU sepsis treatment trajectories from the expert policy.

    Simulates patient trajectories by rolling out the clinician behavior
    policy (from MIMIC-III) through the MDP transition dynamics. Each
    trajectory starts from the empirical initial state distribution and runs
    until the patient reaches the absorbing terminal state or the maximum
    number of steps is reached.

    State and action counts are taken from the loaded arrays. A custom
    ``data_path`` therefore need not have exactly 716 states and 25 actions.
    It must still use ``ABSORBING_STATE`` as its absorbing index, and use the
    5 x 5 fluid/vasopressor action encoding.

    Args:
        n_individuals: Number of patient trajectories to generate.
        max_steps: Maximum steps per trajectory. ICU stays in MIMIC-III are
            typically 5 to 20 four-hour windows.
        as_panel: If True, return a Panel object for econirl estimators.
        seed: Random seed for reproducibility.
        data_path: Path to NPZ file. If None, uses bundled data.

    Returns:
        DataFrame with columns: patient_id, period, state, action, next_state,
        sofa_score, fluid_level, vaso_level, reward, terminated. The columns
        are present even when no rows are generated. If as_panel=True, returns
        a Panel object instead. Trajectories that start in the absorbing state
        (and so contain no steps) are omitted from the Panel.

    Raises:
        FileNotFoundError: If the MDP data file cannot be found.
    """
    columns = [
        "patient_id", "period", "state", "action", "next_state",
        "sofa_score", "fluid_level", "vaso_level", "reward", "terminated",
    ]

    mdp = load_icu_sepsis_mdp(data_path)
    transitions = mdp["transitions"]  # (n_actions, n_states, n_states)
    expert_policy = mdp["expert_policy"]  # (n_states, n_actions)
    initial_dist = mdp["initial_distribution"]  # (n_states,)
    sofa = mdp["sofa_scores"]  # (n_states,)
    rewards = mdp["rewards"]  # (n_states,)

    # Derive sizes from the data rather than hard-coding 716 / 25. For the
    # bundled data these equal the old constants, so the RNG draws (and hence
    # the generated trajectories) are unchanged for a given seed.
    n_actions, n_states = transitions.shape[0], transitions.shape[1]

    rng = np.random.default_rng(seed)
    records: list[dict] = []
    trajectories: list[Trajectory] = []

    for i in range(n_individuals):
        state = int(rng.choice(n_states, p=initial_dist))
        states_list: list[int] = []
        actions_list: list[int] = []
        next_states_list: list[int] = []

        for t in range(max_steps):
            # The absorbing state ends the episode; no action is taken there.
            if state == ABSORBING_STATE:
                break

            policy = expert_policy[state]
            policy_sum = policy.sum()
            if policy_sum > 0:
                action = int(rng.choice(n_actions, p=policy / policy_sum))
            else:
                # Policy row has no probability mass: pick an action uniformly.
                action = int(rng.integers(n_actions))

            next_state = int(rng.choice(n_states, p=transitions[action, state, :]))

            if not as_panel:
                # Only the DataFrame output needs per-step records.
                records.append({
                    "patient_id": i,
                    "period": t,
                    "state": state,
                    "action": action,
                    "next_state": next_state,
                    "sofa_score": float(sofa[state]),
                    # Action index encodes fluid level (major) x vaso level (minor).
                    "fluid_level": action // 5,
                    "vaso_level": action % 5,
                    "reward": float(rewards[state]),
                    "terminated": next_state == ABSORBING_STATE,
                })
            states_list.append(state)
            actions_list.append(action)
            next_states_list.append(next_state)
            state = next_state

        if as_panel and states_list:
            trajectories.append(Trajectory(
                individual_id=i,
                states=np.array(states_list, dtype=np.int32),
                actions=np.array(actions_list, dtype=np.int32),
                next_states=np.array(next_states_list, dtype=np.int32),
            ))

    if as_panel:
        return Panel(trajectories=trajectories)
    # Explicit columns keep the documented schema even when ``records`` is empty.
    return pd.DataFrame(records, columns=columns)
def get_icu_sepsis_info() -> dict:
    """Describe the ICU-Sepsis benchmark MDP.

    Returns:
        A freshly built dictionary of static metadata, with these keys:

        - name, description, source and license
        - state and action counts
        - textual descriptions of the state and action spaces
        - indices of the special death, survival and absorbing states
          (under ``special_states``)
        - a ``ground_truth`` flag
        - the intended use case
    """
    terminal_states = {
        "death": DEATH_STATE,
        "survival": SURVIVAL_STATE,
        "absorbing": ABSORBING_STATE,
    }
    summary = (
        "Sepsis treatment MDP derived from MIMIC-III ICU records. "
        "716 states (patient physiology clusters), 25 actions "
        "(5 IV fluid x 5 vasopressor dose levels)."
    )
    info = dict(
        name="ICU-Sepsis Benchmark MDP",
        description=summary,
        source="Komorowski et al. (2018) Nature Medicine; Killian et al. (2024) NeurIPS",
        license="MIT",
        n_states=716,
        n_actions=25,
        n_features=4,
        state_description="Patient physiology clusters from k-means on MIMIC-III vitals",
        action_description="5 IV fluid levels x 5 vasopressor dose levels",
        special_states=terminal_states,
        ground_truth=False,
        use_case="Clinical treatment IRL, offline policy evaluation",
    )
    return info