# Source code for econirl.estimators.neural_gladius

"""NeuralGLADIUS: Context-aware Q-learning with Bellman consistency penalty.

Learns Q(s,a,ctx) and EV(s,a,ctx) via mini-batch training, then extracts
structural parameters by projecting implied rewards onto features.

No transition matrix is needed. Supports context conditioning through
pluggable state and context encoders.

Reference:
    Kang, M., et al. (2025). DDC IRL with neural networks.
"""

from __future__ import annotations

from typing import Callable

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import optax
import pandas as pd
from scipy.stats import norm as scipy_norm

from econirl.core.reward_spec import RewardSpec
from econirl.core.types import Panel, TrajectoryPanel
from econirl.estimators.neural_base import NeuralEstimatorMixin

def _to_numpy(values: object) -> np.ndarray:
    """View ``values`` as a NumPy array (no copy when already an ndarray)."""
    converted = np.asarray(values)
    return converted


def _to_jax_float(values: object) -> jax.Array:
    """Convert ``values`` to a JAX array with ``float32`` dtype."""
    target_dtype = jnp.float32
    return jnp.asarray(values, dtype=target_dtype)


def _to_jax_int(values: object) -> jax.Array:
    """Convert ``values`` to a JAX array with ``int32`` dtype."""
    target_dtype = jnp.int32
    return jnp.asarray(values, dtype=target_dtype)


class _MLP(eqx.Module):
    """Fully connected network: ``num_layers`` ReLU hidden layers, linear output.

    Accepts either a single input vector or a batch whose leading axis is
    mapped over with ``jax.vmap``.
    """

    layers: tuple[eqx.nn.Linear, ...]
    output_layer: eqx.nn.Linear

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden_dim: int,
        num_layers: int,
        *,
        key: jax.Array,
    ):
        # Negative layer counts are treated as "no hidden layers".
        depth = max(num_layers, 0)
        # One PRNG key per hidden layer; the last key initialises the output layer.
        *hidden_keys, out_key = jr.split(key, depth + 1)
        widths = [in_dim] + [hidden_dim] * depth
        self.layers = tuple(
            eqx.nn.Linear(widths[i], widths[i + 1], key=hidden_keys[i])
            for i in range(depth)
        )
        self.output_layer = eqx.nn.Linear(widths[-1], out_dim, key=out_key)

    def _forward_single(self, x: jax.Array) -> jax.Array:
        """Run one (unbatched) input vector through the network."""
        activations = x
        for linear in self.layers:
            activations = jax.nn.relu(linear(activations))
        return self.output_layer(activations)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the network to a vector, or vmap it over a batch's first axis."""
        arr = jnp.asarray(x, dtype=jnp.float32)
        forward = (
            self._forward_single
            if arr.ndim == 1
            else jax.vmap(self._forward_single)
        )
        return forward(arr)

    def eval(self) -> _MLP:
        """No-op kept for API parity with stateful frameworks; returns ``self``."""
        return self


class _ContextQNetwork(eqx.Module):
    """Scalar score for ``(state, context, action)`` using a shared MLP.

    The MLP input is the concatenation (last axis) of state features,
    context features and a one-hot action vector; it has a single output.
    """

    n_actions: int = eqx.field(static=True)
    net: _MLP

    def __init__(
        self,
        state_dim: int,
        context_dim: int,
        n_actions: int,
        hidden_dim: int,
        num_layers: int,
        *,
        key: jax.Array,
    ):
        input_width = state_dim + context_dim + n_actions
        self.n_actions = n_actions
        self.net = _MLP(input_width, 1, hidden_dim, num_layers, key=key)

    def __call__(
        self,
        state_feat: object,
        ctx_feat: object,
        action_onehot: object,
    ) -> object:
        """Score each row of (state, context, one-hot action) inputs.

        The trailing size-1 output axis of the MLP is squeezed away.
        """
        pieces = [
            _to_jax_float(part) for part in (state_feat, ctx_feat, action_onehot)
        ]
        joined = jnp.concatenate(pieces, axis=-1)
        return jnp.squeeze(self.net(joined), axis=-1)

    def all_actions(
        self,
        state_feat: object,
        ctx_feat: object,
        n_actions: int,
    ) -> object:
        """Score every action for every row.

        ``state_feat`` and ``ctx_feat`` are indexed ``[row, feature]``; the
        result is indexed ``[row, action]``.
        """
        sf = _to_jax_float(state_feat)
        cf = _to_jax_float(ctx_feat)
        n_rows = sf.shape[0]
        onehots = jnp.eye(n_actions, dtype=jnp.float32)
        # Build a [row, action, feature] grid of all (state, ctx, action) inputs.
        grid = jnp.concatenate(
            [
                jnp.tile(sf[:, None, :], (1, n_actions, 1)),
                jnp.tile(cf[:, None, :], (1, n_actions, 1)),
                jnp.tile(onehots[None, :, :], (n_rows, 1, 1)),
            ],
            axis=-1,
        )
        return jnp.squeeze(jax.vmap(self.net)(grid), axis=-1)

    def eval(self) -> _ContextQNetwork:
        """No-op kept for API parity with stateful frameworks; returns ``self``."""
        return self


class NeuralGLADIUS(NeuralEstimatorMixin):
    """Context-aware GLADIUS estimator with sklearn-style API.

    Two networks with the ``_ContextQNetwork`` architecture are trained from
    observed ``(state, action, next_state, context)`` tuples:

    * ``Q(s, a, ctx)`` minimises a class-weighted negative log-likelihood of
      the observed actions under ``softmax(Q / scale)``.
    * ``EV(s, a, ctx)`` regresses on ``scale * logsumexp(Q(s', ., ctx) / scale)``
      evaluated at the observed next state, with that target held fixed
      (``stop_gradient``).

    The implied reward is ``Q - discount * EV``.  If ``features`` are passed
    to :meth:`fit`, action differences of this reward (relative to action 0,
    at context 0) are projected onto the matching feature differences to
    recover structural parameters.
    """

    def __init__(
        self,
        n_actions: int = 8,
        discount: float = 0.95,
        scale: float = 1.0,
        q_hidden_dim: int = 128,
        q_num_layers: int = 3,
        ev_hidden_dim: int = 128,
        ev_num_layers: int = 3,
        batch_size: int = 512,
        max_epochs: int = 500,
        lr: float = 1e-3,
        bellman_weight: float = 1.0,
        gradient_clip: float = 1.0,
        patience: int = 50,
        alternating_updates: bool = True,
        lr_decay_rate: float = 0.001,
        tikhonov_annealing: bool = False,
        tikhonov_initial_weight: float = 100.0,
        anchor_action: int | None = None,
        state_encoder: Callable[[object], object] | None = None,
        context_encoder: Callable[[object], object] | None = None,
        state_dim: int | None = None,
        context_dim: int = 0,
        feature_names: list[str] | None = None,
        verbose: bool = False,
    ):
        self.n_actions = n_actions
        self.discount = discount
        self.scale = scale
        self.q_hidden_dim = q_hidden_dim
        self.q_num_layers = q_num_layers
        self.ev_hidden_dim = ev_hidden_dim
        self.ev_num_layers = ev_num_layers
        self.batch_size = batch_size
        self.max_epochs = max_epochs
        self.lr = lr
        self.bellman_weight = bellman_weight
        self.gradient_clip = gradient_clip
        self.patience = patience
        self.alternating_updates = alternating_updates
        self.lr_decay_rate = lr_decay_rate
        self.tikhonov_annealing = tikhonov_annealing
        self.tikhonov_initial_weight = tikhonov_initial_weight
        # NOTE(review): ``anchor_action`` is stored but not read anywhere in
        # this module; the feature projection always differences against
        # action 0.
        self.anchor_action = anchor_action
        self.state_encoder = state_encoder
        self.context_encoder = context_encoder
        # When an encoder is supplied and its dim is left unset, the output
        # width is inferred from the encoder during ``fit``.
        self.state_dim = state_dim
        self.context_dim = context_dim
        self.feature_names = feature_names
        self.verbose = verbose

        # Fitted results (populated by ``fit``).
        self.params_: dict[str, float] | None = None
        self.se_: dict[str, float] | None = None
        self.pvalues_: dict[str, float] | None = None
        self.coef_: np.ndarray | None = None
        self.policy_: np.ndarray | None = None
        self.value_: np.ndarray | None = None
        self.projection_r2_: float | None = None
        self.converged_: bool | None = None
        self.n_epochs_: int | None = None

        # Internal fitted state.
        self._q_net: _ContextQNetwork | None = None
        self._ev_net: _ContextQNetwork | None = None
        self._state_encoder: Callable[[object], jax.Array] | None = None
        self._context_encoder: Callable[[object], jax.Array] | None = None
        self._state_dim: int | None = None
        self._context_dim: int | None = None
        self._n_states: int | None = None
        self._n_observations: int | None = None

    def fit(
        self,
        data: pd.DataFrame | Panel | TrajectoryPanel,
        state: str | None = None,
        action: str | None = None,
        id: str | None = None,
        context: str | object | None = None,
        features: RewardSpec | object | None = None,
        transitions: object = None,
    ) -> NeuralGLADIUS:
        """Train the Q and EV networks and optionally project onto features.

        Args:
            data: A DataFrame (requires ``state``, ``action`` and ``id``
                column names) or a ``Panel`` / ``TrajectoryPanel``.
            state: State column name (DataFrame input only).
            action: Action column name (DataFrame input only).
            id: Individual-id column name (DataFrame input only).
            context: A DataFrame column name, or an array whose first axis
                has one entry per observation. Defaults to context 0.
            features: Optional ``RewardSpec`` or array indexed as
                ``[state, action, feature]``. When given, structural
                parameters are extracted after training.
            transitions: Accepted for API compatibility; not used.

        Returns:
            ``self``.

        Raises:
            ValueError: If ``data`` has no observations or inputs disagree.
            TypeError: If ``data`` is of an unsupported type.
        """
        all_states, all_actions, all_next, all_contexts = self._extract_data(
            data, state, action, id, context
        )
        if all_states.shape[0] == 0:
            raise ValueError("data contains no observations")

        # Size the state space so every observed current *and* next state
        # gets a row in ``policy_`` / ``value_``.
        observed = np.concatenate(
            [np.asarray(all_states).ravel(), np.asarray(all_next).ravel()]
        )
        n_states = int(observed.max()) + 1
        self._n_states = n_states
        self._n_observations = int(all_states.shape[0])
        self._build_encoders(all_states, all_contexts, n_states)

        key = jr.PRNGKey(np.random.randint(0, 2**31 - 1))
        q_key, ev_key = jr.split(key, 2)
        self._q_net = _ContextQNetwork(
            self._state_dim,
            self._context_dim,
            self.n_actions,
            self.q_hidden_dim,
            self.q_num_layers,
            key=q_key,
        )
        self._ev_net = _ContextQNetwork(
            self._state_dim,
            self._context_dim,
            self.n_actions,
            self.ev_hidden_dim,
            self.ev_num_layers,
            key=ev_key,
        )

        self._train(all_states, all_actions, all_next, all_contexts)
        self._extract_policy_and_value(all_states, all_contexts, n_states)

        if features is not None:
            self._project_onto_features(
                features, all_states, all_actions, all_contexts
            )
        else:
            self.params_ = None
            self.se_ = None
            self.pvalues_ = None
            self.projection_r2_ = None
            self.coef_ = None

        return self

    def _extract_data(
        self,
        data: pd.DataFrame | Panel | TrajectoryPanel,
        state: str | None,
        action: str | None,
        id: str | None,
        context: str | object | None,
    ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
        """Return ``(states, actions, next_states, contexts)`` as int32 arrays.

        Raises:
            ValueError: If DataFrame column names are missing, or ``context``
                does not provide one entry per observation.
            TypeError: If ``data`` is of an unsupported type.
        """
        if isinstance(data, pd.DataFrame):
            if state is None or action is None or id is None:
                raise ValueError(
                    "state, action, and id column names are required "
                    "when data is a DataFrame"
                )
            panel = TrajectoryPanel.from_dataframe(
                data, state=state, action=action, id=id
            )
            all_states = jnp.asarray(panel.all_states, dtype=jnp.int32)
            all_actions = jnp.asarray(panel.all_actions, dtype=jnp.int32)
            all_next = jnp.asarray(panel.all_next_states, dtype=jnp.int32)

            if isinstance(context, str):
                all_contexts = self._extract_context_from_df(data, id, context, panel)
            elif context is not None:
                all_contexts = _to_jax_int(context)
            else:
                all_contexts = jnp.zeros(len(all_states), dtype=jnp.int32)
        elif isinstance(data, (Panel, TrajectoryPanel)):
            all_states = jnp.asarray(data.get_all_states(), dtype=jnp.int32)
            all_actions = jnp.asarray(data.get_all_actions(), dtype=jnp.int32)
            all_next = jnp.asarray(data.get_all_next_states(), dtype=jnp.int32)
            if context is not None:
                all_contexts = _to_jax_int(context)
            else:
                all_contexts = jnp.zeros(len(all_states), dtype=jnp.int32)
        else:
            raise TypeError(
                f"data must be a DataFrame, Panel, or TrajectoryPanel, got {type(data)}"
            )

        # JAX gathers clamp out-of-range indices, so a length mismatch would
        # otherwise silently pair observations with the wrong contexts.
        if all_contexts.ndim == 0 or all_contexts.shape[0] != all_states.shape[0]:
            raise ValueError(
                "context must provide one entry per observation: got shape "
                f"{tuple(all_contexts.shape)} for {all_states.shape[0]} observations"
            )

        return all_states, all_actions, all_next, all_contexts

    def _extract_context_from_df(
        self,
        df: pd.DataFrame,
        id_col: str,
        context_col: str,
        panel: TrajectoryPanel,
    ) -> jax.Array:
        """Collect ``context_col`` per individual (sorted ids, index order).

        ``panel`` is accepted for interface compatibility and is not used.
        NOTE(review): assumes the panel orders observations the same way
        (ids sorted, rows in index order) — confirm against TrajectoryPanel.
        """
        contexts: list[int] = []
        for _, group in df.groupby(id_col, sort=True):
            contexts.extend(group.sort_index()[context_col].tolist())
        return jnp.asarray(contexts, dtype=jnp.int32)

    def _call_encoder(self, encoder: Callable[[object], object], values: object) -> jax.Array:
        """Apply a user encoder and coerce its output to float32."""
        encoded = encoder(values)
        return _to_jax_float(encoded)

    @staticmethod
    def _encoded_width(encoder: Callable[[object], jax.Array], sample: jax.Array) -> int:
        """Return the trailing feature width ``encoder`` produces on one sample."""
        encoded = encoder(sample[:1])
        return int(encoded.shape[-1]) if encoded.ndim >= 2 else 1

    def _build_encoders(
        self,
        all_states: jax.Array,
        all_contexts: jax.Array,
        n_states: int,
    ) -> None:
        """Set up state/context encoders and their output widths.

        Without a user encoder, indices are scaled to roughly ``[0, 1]`` and
        reshaped to a single feature column.
        """
        if self.state_encoder is not None:
            self._state_encoder = lambda s: self._call_encoder(self.state_encoder, s)
            self._state_dim = self.state_dim or self._encoded_width(
                self._state_encoder, all_states
            )
        else:
            max_s = max(n_states - 1, 1)
            self._state_encoder = lambda s, _ms=max_s: (
                _to_jax_float(s) / float(_ms)
            ).reshape(-1, 1)
            self._state_dim = 1

        if self.context_encoder is not None:
            self._context_encoder = lambda c: self._call_encoder(self.context_encoder, c)
            self._context_dim = self.context_dim or self._encoded_width(
                self._context_encoder, all_contexts
            )
        else:
            n_ctx = max(int(np.asarray(all_contexts).max()), 1) if len(all_contexts) else 1
            self._context_encoder = lambda c, _mc=n_ctx: (
                _to_jax_float(c) / float(_mc)
            ).reshape(-1, 1)
            self._context_dim = 1

    def _train(
        self,
        states: jax.Array,
        actions: jax.Array,
        next_states: jax.Array,
        contexts: jax.Array,
    ) -> None:
        """Mini-batch training of the Q and EV networks with early stopping.

        Leaves the lowest-average-epoch-loss snapshot in ``self._q_net`` /
        ``self._ev_net`` and sets ``converged_`` (True iff early stopping
        triggered) and ``n_epochs_`` (number of epochs actually run).
        """

        def lr_schedule(step: jax.Array) -> jax.Array:
            # Inverse-time decay of the learning rate.
            return self.lr / (1.0 + self.lr_decay_rate * step)

        def make_optimizer() -> optax.GradientTransformation:
            transforms = []
            if self.gradient_clip > 0:
                transforms.append(optax.clip_by_global_norm(self.gradient_clip))
            transforms.append(optax.adam(lr_schedule))
            return optax.chain(*transforms)

        q_optimizer = make_optimizer()
        ev_optimizer = make_optimizer()
        q_net = self._q_net
        ev_net = self._ev_net
        q_opt_state = q_optimizer.init(eqx.filter(q_net, eqx.is_inexact_array))
        ev_opt_state = ev_optimizer.init(eqx.filter(ev_net, eqx.is_inexact_array))

        n_obs = len(states)

        # Inverse-frequency class weights (unseen actions counted once) so
        # rare actions are not swamped in the NLL term.
        action_counts = np.bincount(
            np.asarray(actions), minlength=self.n_actions
        ).astype(np.float32)
        action_counts = np.clip(action_counts, a_min=1.0, a_max=None)
        class_weights = jnp.asarray(
            n_obs / (self.n_actions * action_counts), dtype=jnp.float32
        )

        def q_all(net: _ContextQNetwork, s_feat: jax.Array, ctx_feat: jax.Array) -> jax.Array:
            return jnp.asarray(net.all_actions(s_feat, ctx_feat, self.n_actions), dtype=jnp.float32)

        def weighted_nll(
            q_model: _ContextQNetwork,
            s_feat: jax.Array,
            ctx_feat: jax.Array,
            actions_j: jax.Array,
        ) -> jax.Array:
            """Class-weighted mean NLL of observed actions under softmax(Q/scale)."""
            log_probs = jax.nn.log_softmax(q_all(q_model, s_feat, ctx_feat) / self.scale, axis=1)
            per_obs_nll = -log_probs[jnp.arange(actions_j.shape[0]), actions_j]
            return jnp.mean(per_obs_nll * class_weights[actions_j])

        def bellman_mse(
            ev_model: _ContextQNetwork,
            q_model: _ContextQNetwork,
            s_feat: jax.Array,
            ctx_feat: jax.Array,
            actions_j: jax.Array,
            ns_feat: jax.Array,
        ) -> jax.Array:
            """Mean squared gap between EV(s, a) and the soft value of s'."""
            a_oh = jax.nn.one_hot(actions_j, self.n_actions, dtype=jnp.float32)
            ev_sa = jnp.asarray(ev_model(s_feat, ctx_feat, a_oh), dtype=jnp.float32)
            # The next state is evaluated under the current observation's context.
            q_next_all = q_all(q_model, ns_feat, ctx_feat)
            v_next = self.scale * jax.nn.logsumexp(q_next_all / self.scale, axis=1)
            return jnp.mean((ev_sa - jax.lax.stop_gradient(v_next)) ** 2)

        @eqx.filter_value_and_grad
        def ev_loss_fn(ev_model, q_model, s_feat, ctx_feat, actions_j, ns_feat):
            return bellman_mse(ev_model, q_model, s_feat, ctx_feat, actions_j, ns_feat)

        @eqx.filter_value_and_grad
        def q_nll_loss_fn(q_model, s_feat, ctx_feat, actions_j, ce_weight):
            return ce_weight * weighted_nll(q_model, s_feat, ctx_feat, actions_j)

        @eqx.filter_value_and_grad
        def joint_loss_fn(models, s_feat, ctx_feat, actions_j, ns_feat, ce_weight):
            # Differentiated w.r.t. the (q, ev) tuple so both get gradients.
            q_model, ev_model = models
            nll = weighted_nll(q_model, s_feat, ctx_feat, actions_j)
            bellman = bellman_mse(ev_model, q_model, s_feat, ctx_feat, actions_j, ns_feat)
            return ce_weight * nll + self.bellman_weight * bellman

        @eqx.filter_jit
        def ev_step(ev_model, ev_state, q_model, s_feat, ctx_feat, actions_j, ns_feat):
            loss, grads = ev_loss_fn(ev_model, q_model, s_feat, ctx_feat, actions_j, ns_feat)
            updates, ev_state = ev_optimizer.update(grads, ev_state, ev_model)
            return eqx.apply_updates(ev_model, updates), ev_state, loss

        @eqx.filter_jit
        def q_step(q_model, q_state, s_feat, ctx_feat, actions_j, ce_weight):
            loss, grads = q_nll_loss_fn(q_model, s_feat, ctx_feat, actions_j, ce_weight)
            updates, q_state = q_optimizer.update(grads, q_state, q_model)
            return eqx.apply_updates(q_model, updates), q_state, loss

        @eqx.filter_jit
        def joint_step(
            q_model, q_state, ev_model, ev_state, s_feat, ctx_feat, actions_j, ns_feat, ce_weight
        ):
            loss, (q_grads, ev_grads) = joint_loss_fn(
                (q_model, ev_model), s_feat, ctx_feat, actions_j, ns_feat, ce_weight
            )
            q_updates, q_state = q_optimizer.update(q_grads, q_state, q_model)
            ev_updates, ev_state = ev_optimizer.update(ev_grads, ev_state, ev_model)
            q_model = eqx.apply_updates(q_model, q_updates)
            ev_model = eqx.apply_updates(ev_model, ev_updates)
            return q_model, q_state, ev_model, ev_state, loss

        best_q = q_net
        best_ev = ev_net
        best_loss = float("inf")
        patience_counter = 0
        stopped_early = False
        epochs_run = 0
        # Global alternation counter: it must NOT reset per epoch, otherwise
        # a dataset that fits in one batch would only ever take EV steps.
        step_counter = 0

        for epoch in range(self.max_epochs):
            epochs_run = epoch + 1
            perm = np.random.permutation(n_obs)
            epoch_loss = 0.0
            n_batches = 0
            # Passed as a JAX scalar so filter_jit traces it instead of
            # recompiling for every new (annealed) value.
            ce_weight = jnp.asarray(
                self.tikhonov_initial_weight / (1.0 + epoch)
                if self.tikhonov_annealing
                else 1.0,
                dtype=jnp.float32,
            )

            for start in range(0, n_obs, self.batch_size):
                idx = perm[start : start + self.batch_size]
                a = actions[idx]
                s_feat = self._state_encoder(states[idx])
                ns_feat = self._state_encoder(next_states[idx])
                ctx_feat = self._context_encoder(contexts[idx])

                if self.alternating_updates:
                    if step_counter % 2 == 0:
                        ev_net, ev_opt_state, loss = ev_step(
                            ev_net, ev_opt_state, q_net, s_feat, ctx_feat, a, ns_feat
                        )
                    else:
                        q_net, q_opt_state, loss = q_step(
                            q_net, q_opt_state, s_feat, ctx_feat, a, ce_weight
                        )
                else:
                    q_net, q_opt_state, ev_net, ev_opt_state, loss = joint_step(
                        q_net,
                        q_opt_state,
                        ev_net,
                        ev_opt_state,
                        s_feat,
                        ctx_feat,
                        a,
                        ns_feat,
                        ce_weight,
                    )

                epoch_loss += float(loss)
                n_batches += 1
                step_counter += 1

            avg_loss = epoch_loss / max(n_batches, 1)

            if self.verbose and (epoch + 1) % 50 == 0:
                print(f"  Epoch {epoch + 1}: loss={avg_loss:.4f}")

            if avg_loss < best_loss - 1e-4:
                best_loss = avg_loss
                patience_counter = 0
                best_q = q_net
                best_ev = ev_net
            else:
                patience_counter += 1
                if patience_counter >= self.patience:
                    if self.verbose:
                        print(f"  Early stopping at epoch {epoch + 1}")
                    stopped_early = True
                    break

        self._q_net = best_q
        self._ev_net = best_ev
        # Reaching max_epochs without the loss plateauing is not convergence.
        self.converged_ = stopped_early
        self.n_epochs_ = epochs_run

    def _default_context_features(self, n_s: int) -> tuple[jax.Array, jax.Array]:
        """Encoded features for states ``0..n_s-1``, all at context 0."""
        unique_states = jnp.arange(n_s, dtype=jnp.int32)
        ctx_default = jnp.zeros(n_s, dtype=jnp.int32)
        return self._state_encoder(unique_states), self._context_encoder(ctx_default)

    def _reward_matrix_default_context(self) -> jax.Array:
        """Implied reward ``Q - discount * EV`` indexed ``[state, action]`` at context 0."""
        s_feat, ctx_feat = self._default_context_features(self._n_states)
        q_all = jnp.asarray(
            self._q_net.all_actions(s_feat, ctx_feat, self.n_actions), dtype=jnp.float32
        )
        ev_all = jnp.asarray(
            self._ev_net.all_actions(s_feat, ctx_feat, self.n_actions), dtype=jnp.float32
        )
        return q_all - self.discount * ev_all

    def _extract_policy_and_value(
        self,
        all_states: jax.Array,
        all_contexts: jax.Array,
        n_states: int,
    ) -> None:
        """Set ``policy_`` (softmax of Q/scale) and ``value_`` (soft max of Q).

        Both are evaluated for states ``0..n_states-1`` at context 0;
        ``all_states`` and ``all_contexts`` are not used.
        """
        s_feat, ctx_feat = self._default_context_features(n_states)
        qvals = jnp.asarray(
            self._q_net.all_actions(s_feat, ctx_feat, self.n_actions), dtype=jnp.float32
        )
        scaled = qvals / self.scale
        self.policy_ = np.asarray(jax.nn.softmax(scaled, axis=1))
        self.value_ = np.asarray(self.scale * jax.nn.logsumexp(scaled, axis=1))

    def _project_onto_features(
        self,
        features: RewardSpec | object,
        states: jax.Array,
        actions: jax.Array,
        contexts: jax.Array,
    ) -> None:
        """Project implied reward differences onto feature differences.

        For each action ``a >= 1`` the reward difference ``r(s, a) - r(s, 0)``
        (at context 0) is paired with ``phi(s, a) - phi(s, 0)``; the stacked
        system goes to ``_project_parameters``. ``states``, ``actions`` and
        ``contexts`` are accepted for interface compatibility and unused.

        Raises:
            ValueError: If there are fewer than two actions or ``features`` is
                not a ``[state, action, feature]`` array covering the fitted
                states and actions.
        """
        if isinstance(features, RewardSpec):
            feat_np = _to_numpy(features.feature_matrix)
            names = features.parameter_names
        else:
            feat_np = _to_numpy(features)
            names = self.feature_names or [f"f{i}" for i in range(feat_np.shape[-1])]

        n_s = self._n_states
        if self.n_actions < 2:
            raise ValueError("Projecting onto features requires at least two actions.")
        if feat_np.ndim != 3 or feat_np.shape[0] < n_s or feat_np.shape[1] < self.n_actions:
            raise ValueError(
                "features must be indexed as [state, action, feature] and cover "
                f"{n_s} states and {self.n_actions} actions; got shape {feat_np.shape}"
            )

        r_all = np.asarray(self._reward_matrix_default_context())
        reward_diffs = [r_all[:, a] - r_all[:, 0] for a in range(1, self.n_actions)]
        feature_diffs = [
            feat_np[:n_s, a, :] - feat_np[:n_s, 0, :] for a in range(1, self.n_actions)
        ]
        rewards = np.concatenate(reward_diffs, axis=0).astype(np.float32)
        phi = np.concatenate(feature_diffs, axis=0).astype(np.float32)

        theta, se, r2 = self._project_parameters(phi, rewards)
        self.params_ = {n: float(v) for n, v in zip(names, theta)}
        self.se_ = {n: float(v) for n, v in zip(names, se)}
        self.pvalues_ = self._compute_pvalues(self.params_, self.se_)
        self.projection_r2_ = r2
        self.coef_ = np.asarray(theta)

    @property
    def reward_matrix_(self) -> np.ndarray | None:
        """Implied reward ``[state, action]`` at context 0, or None if unfitted."""
        if self._q_net is None or self._ev_net is None or self._n_states is None:
            return None
        return np.asarray(self._reward_matrix_default_context())

    def predict_proba(self, states: np.ndarray) -> np.ndarray:
        """Return rows of ``policy_`` (context 0) for integer ``states``."""
        if self.policy_ is None:
            raise RuntimeError("Model not fitted. Call fit() first.")
        states = np.asarray(states, dtype=np.int64)
        return self.policy_[states]

    def predict_q_from_features(
        self,
        state_features: object,
        contexts: object | None = None,
    ) -> np.ndarray:
        """Q-values ``[row, action]`` from raw (already-encoded) state features.

        ``contexts`` (default 0) still pass through the context encoder.
        """
        if self._q_net is None:
            raise RuntimeError("Model not fitted. Call fit() first.")
        s_feat = _to_jax_float(state_features)
        if s_feat.ndim == 1:
            s_feat = s_feat[None, :]
        if contexts is None:
            contexts = jnp.zeros(s_feat.shape[0], dtype=jnp.int32)
        ctx_feat = self._context_encoder(contexts)
        qvals = self._q_net.all_actions(s_feat, ctx_feat, self.n_actions)
        return np.asarray(qvals)

    def predict_reward_from_features(
        self,
        state_features: object,
        actions: object,
        contexts: object | None = None,
    ) -> np.ndarray:
        """Implied reward ``Q - discount * EV`` from raw state features."""
        if self._q_net is None:
            raise RuntimeError("Model not fitted. Call fit() first.")
        s_feat = _to_jax_float(state_features)
        if s_feat.ndim == 1:
            s_feat = s_feat[None, :]
        actions_j = _to_jax_int(actions)
        if actions_j.ndim == 0:
            actions_j = actions_j[None]
        if contexts is None:
            contexts = jnp.zeros(s_feat.shape[0], dtype=jnp.int32)
        ctx_feat = self._context_encoder(contexts)
        a_oh = jax.nn.one_hot(actions_j, self.n_actions, dtype=jnp.float32)
        q_vals = jnp.asarray(self._q_net(s_feat, ctx_feat, a_oh), dtype=jnp.float32)
        ev_vals = jnp.asarray(self._ev_net(s_feat, ctx_feat, a_oh), dtype=jnp.float32)
        return np.asarray(q_vals - self.discount * ev_vals)

    def predict_reward(
        self,
        states: object,
        actions: object,
        contexts: object | None = None,
    ) -> object:
        """Implied reward ``Q - discount * EV`` for integer states/actions.

        Scalars are promoted to length-1 arrays; ``contexts`` default to 0.
        Returns a JAX array.
        """
        if self._q_net is None:
            raise RuntimeError("Model not fitted. Call fit() first.")
        states_j = jnp.atleast_1d(_to_jax_int(states))
        actions_j = jnp.atleast_1d(_to_jax_int(actions))
        if contexts is None:
            contexts_j = jnp.zeros(states_j.shape[0], dtype=jnp.int32)
        else:
            contexts_j = _to_jax_int(contexts)
        s_feat = self._state_encoder(states_j)
        ctx_feat = self._context_encoder(contexts_j)
        a_oh = jax.nn.one_hot(actions_j, self.n_actions, dtype=jnp.float32)
        q_vals = jnp.asarray(self._q_net(s_feat, ctx_feat, a_oh), dtype=jnp.float32)
        ev_vals = jnp.asarray(self._ev_net(s_feat, ctx_feat, a_oh), dtype=jnp.float32)
        return q_vals - self.discount * ev_vals

    def conf_int(self, alpha: float = 0.05) -> dict[str, tuple[float, float]]:
        """Normal-approximation ``1 - alpha`` confidence intervals for ``params_``.

        Parameters with a non-finite standard error get ``(nan, nan)``.
        """
        if self.params_ is None or self.se_ is None:
            raise RuntimeError(
                "No projected parameters available. "
                "Call fit() with features= to extract structural parameters."
            )
        z = scipy_norm.ppf(1 - alpha / 2)
        intervals: dict[str, tuple[float, float]] = {}
        for name, est in self.params_.items():
            se = self.se_[name]
            if np.isfinite(se):
                intervals[name] = (est - z * se, est + z * se)
            else:
                intervals[name] = (float("nan"), float("nan"))
        return intervals

    def summary(self) -> str:
        """Human-readable summary of the fitted model."""
        if self.policy_ is None:
            return "NeuralGLADIUS: Not fitted yet. Call fit() first."
        return self._format_neural_summary(
            method_name="NeuralGLADIUS",
            params=self.params_,
            se=self.se_,
            pvalues=self.pvalues_,
            projection_r2=self.projection_r2_,
            n_observations=self._n_observations,
            n_epochs=self.n_epochs_,
            converged=self.converged_,
            discount=self.discount,
            scale=self.scale,
            context_dim=self._context_dim,
            extra_lines=[
                f"Q-network: {self.q_num_layers} layers x {self.q_hidden_dim} hidden",
                f"EV-network: {self.ev_num_layers} layers x {self.ev_hidden_dim} hidden",
            ],
        )

    def __repr__(self) -> str:
        fitted = self.policy_ is not None
        return (
            f"NeuralGLADIUS(n_actions={self.n_actions}, "
            f"discount={self.discount}, "
            f"fitted={fitted})"
        )