Source code for econirl.estimators.mceirl_neural

"""MCEIRLNeural: Neural Maximum Causal Entropy IRL.

Supports two reward parameterizations:
- ``reward_type="state_action"`` (default): learns R(s,a) via a neural
  network that takes [state_features, action_onehot] as input.  This is
  more general and correctly handles environments with action-dependent
  rewards (e.g., gridworlds where moving has a cost but staying is free).
- ``reward_type="state"``: learns R(s) only, broadcasting the same reward
  to all actions (original behaviour).

Training loop (MCE-IRL objective, Ziebart 2010):
    for epoch in range(max_epochs):
        1. Compute reward matrix R(s,a) for all (state, action) pairs
        2. Solve soft Bellman with this reward (transitions required)
        3. Compute state visitation frequencies via forward pass
        4. Loss = -E_expert[R] + E_policy[R]  (feature matching)
        5. Backprop through reward network

After training, implied rewards are projected onto features via
least-squares to extract interpretable theta (same as NeuralGLADIUS).

Reference:
    Ziebart, B. D. (2010). Modeling purposeful adaptive behavior with the
        principle of maximum causal entropy. PhD thesis, CMU.
    Wulfmeier, M., Ondruska, P., & Posner, I. (2015). Maximum entropy
        deep inverse reinforcement learning. arXiv:1507.04888.
"""

from __future__ import annotations

from typing import Callable

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd
from scipy.stats import norm as scipy_norm

from econirl.core.bellman import SoftBellmanOperator
from econirl.core.occupancy import compute_state_action_visitation
from econirl.core.reward_spec import RewardSpec
from econirl.core.solvers import hybrid_iteration, value_iteration
from econirl.core.types import DDCProblem, Panel, TrajectoryPanel
from econirl.estimators.neural_base import NeuralEstimatorMixin


# ---------------------------------------------------------------------------
# Internal network modules (Equinox)
# ---------------------------------------------------------------------------


class _StateRewardNetwork(eqx.Module):
    """State-only reward MLP computing R(s).

    Takes one state-feature vector of shape (state_dim,) and produces a
    scalar reward.  The body is ``num_layers`` linear layers of width
    ``hidden_dim``, each followed by a ReLU, then a linear read-out to a
    single unit.
    """

    layers: list
    output_layer: eqx.nn.Linear

    def __init__(
        self,
        state_dim: int,
        hidden_dim: int,
        num_layers: int,
        *,
        key: jax.Array,
    ):
        # One PRNG key per hidden layer; the last key initialises the read-out.
        subkeys = jax.random.split(key, num_layers + 1)
        # widths[i] -> widths[i + 1] is the shape of hidden layer i.
        widths = [state_dim] + [hidden_dim] * num_layers
        self.layers = [
            eqx.nn.Linear(widths[idx], widths[idx + 1], key=subkeys[idx])
            for idx in range(num_layers)
        ]
        self.output_layer = eqx.nn.Linear(widths[-1], 1, key=subkeys[-1])

    def __call__(self, state_feat: jax.Array) -> jax.Array:
        """Evaluate R(s) for one state.

        Parameters
        ----------
        state_feat : jax.Array
            State features of shape (state_dim,).

        Returns
        -------
        jax.Array
            Scalar reward.
        """
        hidden = state_feat
        for dense in self.layers:
            hidden = jax.nn.relu(dense(hidden))
        # The read-out has a single unit; drop that trailing axis.
        readout = self.output_layer(hidden)
        return readout.squeeze(-1)


class _StateActionRewardNetwork(eqx.Module):
    """State-action reward MLP computing R(s,a).

    The network input is the state-feature vector (state_dim,) concatenated
    with a one-hot action encoding (n_actions,); the output is a scalar
    reward.
    """

    layers: list
    output_layer: eqx.nn.Linear
    _n_actions: int = eqx.field(static=True)

    def __init__(
        self,
        state_dim: int,
        n_actions: int,
        hidden_dim: int,
        num_layers: int,
        *,
        key: jax.Array,
    ):
        self._n_actions = n_actions
        # One PRNG key per hidden layer; the last key initialises the read-out.
        subkeys = jax.random.split(key, num_layers + 1)
        # First layer consumes [state features | action one-hot].
        widths = [state_dim + n_actions] + [hidden_dim] * num_layers
        self.layers = [
            eqx.nn.Linear(widths[idx], widths[idx + 1], key=subkeys[idx])
            for idx in range(num_layers)
        ]
        self.output_layer = eqx.nn.Linear(widths[-1], 1, key=subkeys[-1])

    def __call__(
        self, state_feat: jax.Array, action_onehot: jax.Array
    ) -> jax.Array:
        """Evaluate R(s,a) for one (state, action) pair.

        Parameters
        ----------
        state_feat : jax.Array
            State features of shape (state_dim,).
        action_onehot : jax.Array
            One-hot action encoding of shape (n_actions,).

        Returns
        -------
        jax.Array
            Scalar reward.
        """
        hidden = jnp.concatenate([state_feat, action_onehot])
        for dense in self.layers:
            hidden = jax.nn.relu(dense(hidden))
        # The read-out has a single unit; drop that trailing axis.
        readout = self.output_layer(hidden)
        return readout.squeeze(-1)

    def all_actions(self, state_feat: jax.Array) -> jax.Array:
        """Evaluate R(s,a) for every action at every given state.

        Parameters
        ----------
        state_feat : jax.Array
            State features of shape (S, state_dim).

        Returns
        -------
        jax.Array
            Reward matrix of shape (S, A).
        """
        n_states = state_feat.shape[0]
        n_actions = self._n_actions
        # Row order is state-major: each state's features are repeated A
        # times in a row, and the A one-hot rows cycle once per state, so
        # flat row s * A + a pairs state s with action a.
        states_rep = jnp.repeat(state_feat, n_actions, axis=0)
        onehots = jnp.tile(jnp.eye(n_actions), (n_states, 1))
        flat_rewards = jax.vmap(self)(states_rep, onehots)
        return flat_rewards.reshape(n_states, n_actions)


# ---------------------------------------------------------------------------
# MCEIRLNeural estimator
# ---------------------------------------------------------------------------


class MCEIRLNeural(NeuralEstimatorMixin):
    """Neural Maximum Causal Entropy IRL.

    Learns a neural reward function using the MCE-IRL objective:
        maximize  E_expert[R] - log Z(R)
    where Z(R) is the partition function (soft value at initial state).

    Supports two reward types:
    - ``reward_type="state_action"`` (default): R(s,a) via a network that
      takes [state_features, action_onehot].  This is more general and
      correctly handles action-dependent rewards.
    - ``reward_type="state"``: R(s) broadcast to all actions (original).

    For v1, transitions must be available so that exact soft value
    iteration and state visitation frequencies can be computed.

    Parameters
    ----------
    n_states : int, optional
        Number of discrete states. Inferred from data if None.
    n_actions : int, optional
        Number of discrete actions. Inferred from data if None.
    discount : float, default=0.95
        Time discount factor beta.
    reward_type : str, default="state_action"
        Type of reward function: ``"state_action"`` for R(s,a) or
        ``"state"`` for R(s) broadcast to all actions.
    reward_hidden_dim : int, default=64
        Hidden dimension for the reward MLP.
    reward_num_layers : int, default=2
        Number of hidden layers in the reward MLP.
    max_epochs : int, default=200
        Maximum number of training epochs.
    lr : float, default=1e-3
        Learning rate for the AdamW optimizer.
    inner_solver : str, default="hybrid"
        Solver for soft value iteration: "hybrid" or "value".
    inner_tol : float, default=1e-8
        Convergence tolerance for inner solver.
    inner_max_iter : int, default=5000
        Maximum iterations for inner solver.
    state_encoder : callable, optional
        Function mapping state indices (int array) to feature vectors.
        Receives shape (B,) and should return shape (B, state_dim).
        If None, a default normalizing encoder is created.
    state_dim : int, optional
        Dimension of state features. Should be given when state_encoder
        is provided; if omitted it is inferred from the encoder output.
    feature_names : list of str, optional
        Names for features when projecting rewards onto linear features.
    anchor_action : int, optional
        Action whose reward is fixed to zero. This is useful for identified
        action-dependent IRL designs with a normalized outside/exit action.
    absorbing_state : int, optional
        State whose reward row is fixed to zero.
    seed : int, default=0
        Random seed for network initialization.
    verbose : bool, default=False
        Whether to print progress during training.

    Attributes
    ----------
    params_ : dict or None
        Projected structural parameters after fitting. None if no features
        were provided for projection.
    se_ : dict or None
        Pseudo standard errors from the projection regression.
    pvalues_ : dict or None
        P-values from Wald t-test on pseudo SEs.
    coef_ : numpy.ndarray or None
        Coefficient array (same values as ``params_`` in array form).
    policy_ : numpy.ndarray or None
        Estimated choice probabilities P(a|s) of shape (n_states, n_actions).
    value_ : numpy.ndarray or None
        Estimated value function V(s) of shape (n_states,).
    reward_ : numpy.ndarray or None
        Neural reward. Shape (n_states,) for ``reward_type="state"`` or
        (n_states, n_actions) for ``reward_type="state_action"``.
    projection_r2_ : float or None
        R-squared of the feature projection.
    converged_ : bool or None
        Whether training converged.
    n_epochs_ : int or None
        Number of training epochs completed.
    feature_difference_ : float or None
        L2 norm of (empirical - model) state-action occupancy for the
        final reward network.
    occupancy_moment_residual_ : float or None
        Largest absolute entry of the same occupancy residual.

    Examples
    --------
    >>> from econirl.estimators import MCEIRLNeural
    >>> import numpy as np
    >>>
    >>> # R(s,a) -- default, more general
    >>> model = MCEIRLNeural(n_states=25, n_actions=4, discount=0.95)
    >>> model.fit(data=df, state="state", action="action", id="agent_id",
    ...           transitions=T)
    >>> print(model.reward_.shape)   # (25, 4)
    >>> print(model.policy_.shape)   # (25, 4)
    >>>
    >>> # R(s) -- state-only, backward compatible
    >>> model = MCEIRLNeural(n_states=25, n_actions=4, reward_type="state")
    >>> model.fit(...)
    >>> print(model.reward_.shape)   # (25,)
    """

    def __init__(
        self,
        n_states: int | None = None,
        n_actions: int | None = None,
        discount: float = 0.95,
        # Reward type
        reward_type: str = "state_action",
        # Network
        reward_hidden_dim: int = 64,
        reward_num_layers: int = 2,
        # Training
        max_epochs: int = 200,
        lr: float = 1e-3,
        # Inner solver
        inner_solver: str = "hybrid",
        inner_tol: float = 1e-8,
        inner_max_iter: int = 5000,
        # Encoders
        state_encoder: Callable | None = None,
        state_dim: int | None = None,
        # Projection
        feature_names: list[str] | None = None,
        anchor_action: int | None = None,
        absorbing_state: int | None = None,
        seed: int = 0,
        verbose: bool = False,
    ):
        """Store hyper-parameters and reset all fitted state.

        Raises
        ------
        ValueError
            If ``reward_type`` is neither ``"state"`` nor ``"state_action"``.
        """
        if reward_type not in ("state", "state_action"):
            raise ValueError(
                f"reward_type must be 'state' or 'state_action', "
                f"got '{reward_type}'"
            )

        self.n_states = n_states
        self.n_actions = n_actions
        self.discount = discount
        self.reward_type = reward_type
        self.reward_hidden_dim = reward_hidden_dim
        self.reward_num_layers = reward_num_layers
        self.max_epochs = max_epochs
        self.lr = lr
        self.inner_solver = inner_solver
        self.inner_tol = inner_tol
        self.inner_max_iter = inner_max_iter
        self.state_encoder = state_encoder
        self.state_dim = state_dim
        self.feature_names = feature_names
        self.anchor_action = anchor_action
        self.absorbing_state = absorbing_state
        self.seed = seed
        self.verbose = verbose

        # Fitted attributes (set after fit())
        self.params_: dict[str, float] | None = None
        self.se_: dict[str, float] | None = None
        self.pvalues_: dict[str, float] | None = None
        self.coef_: np.ndarray | None = None
        self.policy_: np.ndarray | None = None
        self.value_: np.ndarray | None = None
        self.reward_: np.ndarray | None = None
        self.projection_r2_: float | None = None
        self.converged_: bool | None = None
        self.n_epochs_: int | None = None
        self.feature_difference_: float | None = None
        self.occupancy_moment_residual_: float | None = None

        # Internal state
        self._reward_net = None
        self._state_encoder: Callable | None = None
        self._state_dim: int | None = None
        self._n_states: int | None = None
        self._n_actions: int | None = None
        self._n_observations: int | None = None
        self._empirical_sa: jnp.ndarray | None = None
        self._initial_distribution: jnp.ndarray | None = None

    def fit(
        self,
        data: pd.DataFrame | Panel | TrajectoryPanel,
        state: str | None = None,
        action: str | None = None,
        id: str | None = None,
        features: RewardSpec | np.ndarray | None = None,
        transitions: np.ndarray | None = None,
        context: object = None,
    ) -> "MCEIRLNeural":
        """Fit the MCEIRLNeural estimator to data.

        Parameters
        ----------
        data : pandas.DataFrame or Panel or TrajectoryPanel
            Panel data with demonstrations. When a DataFrame is passed,
            ``state``, ``action``, and ``id`` column names are required.
        state : str, optional
            Column name for the state variable (required for DataFrame).
        action : str, optional
            Column name for the action variable (required for DataFrame).
        id : str, optional
            Column name for the individual identifier (required for
            DataFrame).
        features : RewardSpec or numpy.ndarray, optional
            Feature specification for parameter projection. If provided,
            the neural reward is projected onto these features to extract
            interpretable theta.
        transitions : numpy.ndarray
            Transition matrices ``P(s'|s,a)``, shape
            (n_actions, n_states, n_states). Required for v1 (exact soft
            value iteration).
        context : ignored
            Accepted for API compatibility but not used.

        Returns
        -------
        self : MCEIRLNeural
            Returns self for method chaining.

        Raises
        ------
        ValueError
            If ``transitions`` is missing or does not have shape
            (n_actions, n_states, n_states), or if ``anchor_action`` /
            ``absorbing_state`` fall outside the action / state range.
        """
        if transitions is None:
            raise ValueError(
                "MCEIRLNeural v1 requires transitions. "
                "Pass transitions as (n_actions, n_states, n_states) array."
            )

        # --- Step 1: Extract arrays from data ---
        panel, all_states, all_actions, _all_next = self._extract_data(
            data, state, action, id
        )

        n_states = self.n_states or int(all_states.max()) + 1
        n_actions = self.n_actions or int(all_actions.max()) + 1
        if self.anchor_action is not None and not 0 <= self.anchor_action < n_actions:
            raise ValueError(
                f"anchor_action must be in [0, {n_actions}), got {self.anchor_action}"
            )
        if self.absorbing_state is not None and not 0 <= self.absorbing_state < n_states:
            raise ValueError(
                f"absorbing_state must be in [0, {n_states}), got {self.absorbing_state}"
            )
        self._n_states = n_states
        self._n_actions = n_actions
        # Number of (state, action) observations, reported by summary().
        self._n_observations = int(all_states.shape[0])

        # Convert transitions to JAX
        transitions_jax = jnp.asarray(transitions, dtype=jnp.float32)
        # Fail early with a readable message instead of a shape error deep
        # inside the inner solver.
        expected_shape = (n_actions, n_states, n_states)
        if tuple(transitions_jax.shape) != expected_shape:
            raise ValueError(
                "transitions must have shape (n_actions, n_states, n_states) "
                f"= {expected_shape}, got {tuple(transitions_jax.shape)}"
            )

        # --- Step 2: Build encoder ---
        self._build_encoder(n_states)

        # --- Step 3: Compute empirical state-action occupancy ---
        empirical_sa = self._compute_empirical_occupancy(
            panel, n_states, n_actions, discount=self.discount
        )
        self._empirical_sa = empirical_sa
        self._initial_distribution = self._compute_initial_distribution(
            panel, n_states
        )

        # --- Step 4: Build reward network ---
        key = jax.random.PRNGKey(self.seed)
        if self.reward_type == "state_action":
            self._reward_net = _StateActionRewardNetwork(
                self._state_dim,
                n_actions,
                self.reward_hidden_dim,
                self.reward_num_layers,
                key=key,
            )
        else:
            self._reward_net = _StateRewardNetwork(
                self._state_dim,
                self.reward_hidden_dim,
                self.reward_num_layers,
                key=key,
            )

        # --- Step 5: Training loop ---
        self._train_mce(
            transitions_jax,
            empirical_sa,
            n_states,
            n_actions,
        )

        # --- Step 6: Extract policy, value, and reward ---
        self._extract_final(transitions_jax, n_states, n_actions)

        # --- Step 7: Feature projection ---
        if features is not None:
            self._project_onto_features(features, n_states, n_actions)
        else:
            self.params_ = None
            self.se_ = None
            self.pvalues_ = None
            self.projection_r2_ = None
            self.coef_ = None

        return self

    # ------------------------------------------------------------------
    # Data extraction
    # ------------------------------------------------------------------

    def _extract_data(
        self,
        data: pd.DataFrame | Panel | TrajectoryPanel,
        state: str | None,
        action: str | None,
        id: str | None,
    ) -> tuple[TrajectoryPanel, np.ndarray, np.ndarray, np.ndarray]:
        """Extract state/action/next_state arrays from input data.

        Returns
        -------
        tuple
            ``(panel, all_states, all_actions, all_next_states)`` where
            the three arrays are int64 NumPy arrays.

        Raises
        ------
        ValueError
            If ``data`` is a DataFrame and a column name is missing.
        TypeError
            If ``data`` is not a DataFrame, Panel, or TrajectoryPanel.
        """
        if isinstance(data, pd.DataFrame):
            if state is None or action is None or id is None:
                raise ValueError(
                    "state, action, and id column names are required "
                    "when data is a DataFrame"
                )
            panel = TrajectoryPanel.from_dataframe(
                data, state=state, action=action, id=id
            )
            all_states = np.asarray(panel.all_states, dtype=np.int64)
            all_actions = np.asarray(panel.all_actions, dtype=np.int64)
            all_next = np.asarray(panel.all_next_states, dtype=np.int64)
        elif isinstance(data, (Panel, TrajectoryPanel)):
            panel = TrajectoryPanel.from_panel(data)
            all_states = np.asarray(panel.get_all_states(), dtype=np.int64)
            all_actions = np.asarray(panel.get_all_actions(), dtype=np.int64)
            all_next = np.asarray(panel.get_all_next_states(), dtype=np.int64)
        else:
            raise TypeError(
                f"data must be a DataFrame, Panel, or TrajectoryPanel, "
                f"got {type(data)}"
            )

        return panel, all_states, all_actions, all_next

    # ------------------------------------------------------------------
    # Encoder setup
    # ------------------------------------------------------------------

    def _build_encoder(self, n_states: int) -> None:
        """Set ``_state_encoder`` and ``_state_dim``.

        With a user-supplied ``state_encoder`` the feature width is
        ``state_dim`` when given; otherwise it is read off the encoder's
        output on ``arange(n_states)`` (falling back to 1 when that output
        is not 2-D).  Without an encoder, states are mapped to the single
        feature ``s / max(n_states - 1, 1)``.
        """
        if self.state_encoder is not None:
            self._state_encoder = self.state_encoder
            if self.state_dim:
                self._state_dim = self.state_dim
            else:
                # Previously this silently assumed width 1, which made the
                # reward network's input size wrong for wider encoders.
                probe = jnp.asarray(self.state_encoder(jnp.arange(n_states)))
                self._state_dim = int(probe.shape[-1]) if probe.ndim == 2 else 1
        else:
            max_s = max(n_states - 1, 1)

            def _default_encoder(s, _ms=max_s):
                s_float = jnp.asarray(s, dtype=jnp.float32)
                return (s_float / _ms).reshape(-1, 1)

            self._state_encoder = _default_encoder
            self._state_dim = 1

    # ------------------------------------------------------------------
    # Empirical occupancy
    # ------------------------------------------------------------------

    def _compute_empirical_occupancy(
        self,
        panel: TrajectoryPanel,
        n_states: int,
        n_actions: int,
        discount: float = 1.0,
    ) -> jnp.ndarray:
        """Compute empirical state-action occupancy from demonstrations.

        Each visit at step ``t`` of a trajectory contributes weight
        ``discount ** t`` (weight 1 when ``discount == 1.0``).

        Returns
        -------
        jnp.ndarray
            State-action occupancy of shape (n_states, n_actions).
            Normalized to sum to 1 (all zeros if there are no visits).
        """
        sa_counts = np.zeros((n_states, n_actions), dtype=np.float32)
        total = 0.0
        for traj in panel.trajectories:
            states = np.asarray(traj.states, dtype=np.int64)
            actions = np.asarray(traj.actions, dtype=np.int64)
            if len(states) == 0:
                continue
            if discount == 1.0:
                weights = np.ones(len(states), dtype=np.float32)
            else:
                weights = np.power(float(discount), np.arange(len(states))).astype(
                    np.float32
                )
            flat_idx = states * n_actions + actions
            # ravel() is a view of the contiguous array, and add.at
            # accumulates correctly when an index repeats.
            np.add.at(sa_counts.ravel(), flat_idx, weights)
            total += float(weights.sum())
        if total > 0:
            sa_counts = sa_counts / total
        return jnp.array(sa_counts)

    def _compute_initial_distribution(
        self,
        panel: TrajectoryPanel,
        n_states: int,
    ) -> jnp.ndarray:
        """Compute the empirical initial-state distribution.

        Counts the first state of every non-empty trajectory and
        normalizes; falls back to uniform when no trajectory has a state.
        """
        counts = np.zeros(n_states, dtype=np.float32)
        for traj in panel.trajectories:
            if len(traj.states):
                counts[int(traj.states[0])] += 1.0
        total = counts.sum()
        if total > 0:
            counts = counts / total
        else:
            counts[:] = 1.0 / n_states
        return jnp.asarray(counts)

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    def _compute_reward_matrix(
        self,
        reward_net,
        state_feat: jax.Array,
        n_states: int,
        n_actions: int,
    ) -> jax.Array:
        """Compute R(s,a) for all states and actions.

        The ``absorbing_state`` row and the ``anchor_action`` column are
        set to zero when those options are configured.

        Returns
        -------
        jax.Array
            Reward matrix of shape (n_states, n_actions).
        """
        if self.reward_type == "state_action":
            rewards = reward_net.all_actions(state_feat)
        else:
            rewards_s = jax.vmap(reward_net)(state_feat)
            rewards = jnp.broadcast_to(
                rewards_s[:, None], (n_states, n_actions)
            )
        if self.absorbing_state is not None:
            rewards = rewards.at[int(self.absorbing_state), :].set(0.0)
        if self.anchor_action is not None:
            rewards = rewards.at[:, int(self.anchor_action)].set(0.0)
        return rewards

    def _train_mce(
        self,
        transitions: jnp.ndarray,
        empirical_sa: jnp.ndarray,
        n_states: int,
        n_actions: int,
    ) -> None:
        """Run MCE-IRL training with neural reward network.

        Training loop:
        1. Forward: compute R(s,a) for all states and actions
        2. Solve soft Bellman: V, policy = soft_value_iteration(R, transitions)
        3. Compute state visitation: D(s) = forward_pass(policy, transitions)
        4. Expected occupancy: E_policy[sa] = D(s) * pi(a|s)
        5. Gradient: grad_R = policy_sa - empirical_sa
        6. Backprop through reward network via surrogate loss

        Stops early after 100 consecutive epochs without the squared
        occupancy residual improving by more than 1e-5, then restores the
        network that achieved the best residual.  Sets ``_reward_net``,
        ``converged_``, ``n_epochs_`` and ``feature_difference_``.
        """
        optimizer = optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.adamw(learning_rate=self.lr, weight_decay=1e-5),
        )
        opt_state = optimizer.init(eqx.filter(self._reward_net, eqx.is_array))

        problem = DDCProblem(
            num_states=n_states,
            num_actions=n_actions,
            discount_factor=self.discount,
            scale_parameter=1.0,
        )
        bellman = SoftBellmanOperator(problem=problem, transitions=transitions)

        best_loss = float("inf")
        best_net = self._reward_net
        patience_counter = 0
        patience = 100

        reward_net = self._reward_net
        # The encoder input never changes, so encode the states once.
        state_feat = self._state_encoder(jnp.arange(n_states))

        from tqdm import tqdm

        pbar = tqdm(
            range(self.max_epochs),
            desc="MCE-IRL-NN",
            disable=not self.verbose,
            leave=True,
        )

        # Defined up front so the bookkeeping below works when
        # max_epochs <= 0 and the loop body never runs.
        epoch = -1
        for epoch in pbar:
            # The loss computed this epoch belongs to the network as it is
            # *before* the optimizer step; remember it for checkpointing.
            evaluated_net = reward_net

            # 1. Compute reward matrix R(s,a) (no gradient tracking needed here)
            reward_matrix = self._compute_reward_matrix(
                reward_net, state_feat, n_states, n_actions
            )

            # 2. Solve soft Bellman (no gradient through VI)
            if self.inner_solver == "hybrid":
                result = hybrid_iteration(
                    bellman,
                    reward_matrix,
                    tol=self.inner_tol,
                    max_iter=self.inner_max_iter,
                )
            else:
                result = value_iteration(
                    bellman,
                    reward_matrix,
                    tol=self.inner_tol,
                    max_iter=self.inner_max_iter,
                )

            policy = result.policy

            # 3. Compute state-action occupancy via discounted forward pass
            policy_sa = self._forward_pass(
                policy, transitions, n_states, self.discount
            )

            # 4. Feature matching gradient w.r.t. R(s,a)
            grad_r = policy_sa - empirical_sa

            # 5. Compute network parameter gradients via surrogate loss.
            # The surrogate loss L = sum(R * grad_r) has gradient
            # dL/d_params = sum(grad_r * dR/d_params), which is exactly
            # the chain rule for the MCE-IRL objective.
            def surrogate_loss(net):
                R = self._compute_reward_matrix(
                    net, state_feat, n_states, n_actions
                )
                return jnp.sum(R * grad_r)

            _, grads = eqx.filter_value_and_grad(surrogate_loss)(reward_net)
            updates, opt_state = optimizer.update(
                grads, opt_state, eqx.filter(reward_net, eqx.is_array)
            )
            reward_net = eqx.apply_updates(reward_net, updates)

            # Monitor feature matching residual
            loss_val = float(jnp.sum(grad_r ** 2))
            feature_diff = float(jnp.linalg.norm(empirical_sa - policy_sa))

            pbar.set_postfix({
                "loss": f"{loss_val:.4f}",
                "fdiff": f"{feature_diff:.4f}",
                "best": f"{best_loss:.4f}",
                "no_imp": patience_counter,
            })

            # Early stopping with best model checkpoint.  Checkpoint the
            # network that produced loss_val, not the freshly updated one
            # (whose residual has not been measured yet).
            if loss_val < best_loss - 1e-5:
                best_loss = loss_val
                best_net = evaluated_net
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    if self.verbose:
                        print(f" Early stopping at epoch {epoch + 1}")
                    break

        # Restore best model
        self._reward_net = best_net
        n_done = epoch + 1
        # NOTE(review): for any run of at least one epoch this is True both
        # on early stopping and on exhausting max_epochs; kept as-is for
        # backward compatibility.
        self.converged_ = n_done > 0 and (
            patience_counter >= patience or n_done == self.max_epochs
        )
        self.n_epochs_ = n_done
        self.feature_difference_ = float(np.sqrt(best_loss))

    def _forward_pass(
        self,
        policy: jnp.ndarray,
        transitions: jnp.ndarray,
        n_states: int,
        discount: float,
    ) -> jnp.ndarray:
        """Compute normalized discounted state-action visitation.

        Delegates to ``compute_state_action_visitation`` using the
        empirical initial-state distribution stored by ``fit``.

        Parameters
        ----------
        policy : jnp.ndarray
            Policy pi(a|s), shape (n_states, n_actions).
        transitions : jnp.ndarray
            Transition matrices ``P(s'|s,a)``, shape
            (n_actions, n_states, n_states).
        n_states : int
            Number of states.
        discount : float
            Discount factor.

        Returns
        -------
        jnp.ndarray
            State-action visitation frequencies, shape
            (n_states, n_actions).
        """
        problem = DDCProblem(
            num_states=n_states,
            num_actions=policy.shape[1],
            discount_factor=discount,
            scale_parameter=1.0,
        )
        return compute_state_action_visitation(
            policy,
            transitions,
            problem,
            self._initial_distribution,
        )

    # ------------------------------------------------------------------
    # Post-training extraction
    # ------------------------------------------------------------------

    def _extract_final(
        self,
        transitions: jnp.ndarray,
        n_states: int,
        n_actions: int,
    ) -> None:
        """Extract policy, value function, and reward from trained network.

        Re-solves the soft Bellman equation for the restored network and
        stores ``policy_``, ``value_``, ``reward_`` and, when the empirical
        occupancy is available, the occupancy residual diagnostics.
        """
        all_state_indices = jnp.arange(n_states)
        state_feat = self._state_encoder(all_state_indices)
        reward_matrix = self._compute_reward_matrix(
            self._reward_net, state_feat, n_states, n_actions
        )

        problem = DDCProblem(
            num_states=n_states,
            num_actions=n_actions,
            discount_factor=self.discount,
            scale_parameter=1.0,
        )
        bellman = SoftBellmanOperator(
            problem=problem, transitions=transitions
        )

        if self.inner_solver == "hybrid":
            result = hybrid_iteration(
                bellman,
                reward_matrix,
                tol=self.inner_tol,
                max_iter=self.inner_max_iter,
            )
        else:
            result = value_iteration(
                bellman,
                reward_matrix,
                tol=self.inner_tol,
                max_iter=self.inner_max_iter,
            )

        self.policy_ = np.asarray(result.policy)
        self.value_ = np.asarray(result.V)

        if self.reward_type == "state_action":
            self.reward_ = np.asarray(reward_matrix)
        else:
            # NOTE(review): this is the raw network output; unlike
            # reward_matrix it does not apply the absorbing_state /
            # anchor_action zeroing used when solving for the policy.
            rewards_s = jax.vmap(self._reward_net)(state_feat)
            self.reward_ = np.asarray(rewards_s)

        if self._empirical_sa is not None:
            policy_sa = self._forward_pass(
                result.policy,
                transitions,
                n_states,
                self.discount,
            )
            residual = self._empirical_sa - policy_sa
            self.feature_difference_ = float(jnp.linalg.norm(residual))
            self.occupancy_moment_residual_ = float(jnp.max(jnp.abs(residual)))

    def _project_onto_features(
        self,
        features: RewardSpec | np.ndarray,
        n_states: int,
        n_actions: int,
    ) -> None:
        """Project neural rewards onto features for interpretable theta.

        For ``reward_type="state_action"``, R(s,a) is projected onto
        (S*A, K) features:
            theta = argmin ||Phi_flat @ theta - R_flat||^2

        For ``reward_type="state"``, R(s) is projected onto (S, K) state
        features (original behaviour), taken from action 0 of the feature
        tensor.

        Parameters
        ----------
        features : RewardSpec or numpy.ndarray
            Feature specification. RewardSpec provides (S, A, K) matrix.
            An array of shape (S, K) or (S, A, K) is also accepted.
        n_states : int
            Number of states.
        n_actions : int
            Number of actions.

        Raises
        ------
        ValueError
            If an array ``features`` is neither 2-D nor 3-D.
        """
        # Extract feature matrix and names
        if isinstance(features, RewardSpec):
            feat_3d = np.asarray(features.feature_matrix)
            names = features.parameter_names
        else:
            feat_arr = np.asarray(features)
            if feat_arr.ndim == 3:
                feat_3d = feat_arr
            elif feat_arr.ndim == 2:
                # (S, K) -> broadcast to (S, A, K)
                feat_3d = np.broadcast_to(
                    feat_arr[:, None, :],
                    (feat_arr.shape[0], n_actions, feat_arr.shape[1]),
                ).copy()
            else:
                raise ValueError(
                    f"features must be 2D (S, K) or 3D (S, A, K), "
                    f"got {feat_arr.ndim}D"
                )
            names = self.feature_names or [
                f"f{i}" for i in range(feat_3d.shape[-1])
            ]

        rewards = self.reward_.astype(np.float32)

        if self.reward_type == "state_action":
            phi = feat_3d.reshape(-1, feat_3d.shape[-1]).astype(np.float32)
            r_flat = rewards.reshape(-1)
        else:
            phi = feat_3d[:, 0, :].astype(np.float32)
            r_flat = rewards

        theta, se, r2 = self._project_parameters(phi, r_flat)

        self.params_ = {n: float(v) for n, v in zip(names, theta)}
        self.se_ = {n: float(v) for n, v in zip(names, se)}
        self.pvalues_ = self._compute_pvalues(self.params_, self.se_)
        self.projection_r2_ = r2
        self.coef_ = theta

    # ------------------------------------------------------------------
    # Prediction methods
    # ------------------------------------------------------------------

    @property
    def reward_matrix_(self) -> np.ndarray | None:
        """Reward matrix R(s,a) of shape (n_states, n_actions).

        For ``reward_type="state_action"``, ``self.reward_`` already has
        shape (n_states, n_actions) and is returned directly.  For
        ``reward_type="state"``, the state-only reward is broadcast to all
        actions.  Returns None before fitting.
        """
        if self.reward_ is None:
            return None
        if self.reward_.ndim == 2:
            return self.reward_
        # State-only reward: broadcast to all actions
        n_actions = self._n_actions or self.n_actions
        return np.tile(self.reward_[:, np.newaxis], (1, n_actions))

    def predict_proba(self, states: np.ndarray) -> np.ndarray:
        """Predict choice probabilities for given states.

        Parameters
        ----------
        states : numpy.ndarray
            Array of state indices.

        Returns
        -------
        numpy.ndarray
            Choice probabilities of shape (len(states), n_actions).

        Raises
        ------
        RuntimeError
            If the model has not been fitted yet.
        """
        if self.policy_ is None:
            raise RuntimeError("Model not fitted. Call fit() first.")
        states = np.asarray(states, dtype=np.int64)
        return self.policy_[states]

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def conf_int(self, alpha: float = 0.05) -> dict[str, tuple[float, float]]:
        """Compute confidence intervals for projected parameters.

        Uses the normal approximation ``estimate +/- z * se``; parameters
        with a non-finite standard error get ``(nan, nan)``.

        Parameters
        ----------
        alpha : float, default=0.05
            Significance level. Returns (1 - alpha) confidence intervals.

        Returns
        -------
        dict
            ``{param_name: (lower, upper)}`` confidence intervals.

        Raises
        ------
        RuntimeError
            If no projected parameters are available.
        """
        if self.params_ is None or self.se_ is None:
            raise RuntimeError(
                "No projected parameters available. "
                "Call fit() with features= to extract structural parameters."
            )

        z = scipy_norm.ppf(1 - alpha / 2)
        intervals: dict[str, tuple[float, float]] = {}
        for name, est in self.params_.items():
            se = self.se_[name]
            if np.isfinite(se):
                intervals[name] = (est - z * se, est + z * se)
            else:
                intervals[name] = (float("nan"), float("nan"))
        return intervals

    # ------------------------------------------------------------------
    # Summary
    # ------------------------------------------------------------------

    def summary(self) -> str:
        """Generate a formatted summary of estimation results.

        Returns
        -------
        str
            Human-readable summary including neural reward info, parameter
            estimates, and projection R-squared.
        """
        if self.policy_ is None:
            return "MCEIRLNeural: Not fitted yet. Call fit() first."

        # Report the number of observations seen by fit(); fall back to the
        # previously reported value if that count is unavailable.
        n_obs = getattr(self, "_n_observations", None)
        if n_obs is None:
            n_obs = self._n_states

        return self._format_neural_summary(
            method_name="MCEIRLNeural (Deep MCE-IRL)",
            params=self.params_,
            se=self.se_,
            pvalues=self.pvalues_,
            projection_r2=self.projection_r2_,
            n_observations=n_obs,
            n_epochs=self.n_epochs_,
            converged=self.converged_,
            discount=self.discount,
            extra_lines=[
                f"Reward type: {self.reward_type}",
                f"Reward network: {self.reward_num_layers} layers x {self.reward_hidden_dim} hidden",
                f"Inner solver: {self.inner_solver}",
                f"Anchor action: {self.anchor_action}",
                f"Absorbing state: {self.absorbing_state}",
            ],
        )

    # ------------------------------------------------------------------
    # Repr
    # ------------------------------------------------------------------

    def __repr__(self) -> str:
        """Return a short description; fitted sizes take precedence."""
        fitted = self.policy_ is not None
        return (
            f"MCEIRLNeural(n_states={self._n_states or self.n_states}, "
            f"n_actions={self._n_actions or self.n_actions}, "
            f"discount={self.discount}, "
            f"fitted={fitted})"
        )