# Source code for econirl.estimators.ufxp_neural

"""Neural-utility UFXP (Oguz and Bray 2026).

The unnested fixed point trained with a neural utility. The linear UFXP scores
random projections of the Bellman first-order conditions and, because the value
function is removed by a precomputed dual (Proposition 2), solves a single
least-squares problem. This estimator keeps the same precompute and replaces the
closed-form solve with a gradient loop over a neural utility ``u_w(s, a)``.

The point of the method: the dual ``lam = (I - beta F_P')^{-1} W`` is the only
linear solve, computed once before training. Every gradient step is one network
forward pass through the objective

    Q(w) = sum_i ( b_i - <Z_i, du_w> + lam_i' u_P^w )^2,
    du_w[s, a] = u_w(s, a) - u_w(s, ref),
    u_P^w[s]   = sum_a P_hat[s, a] u_w(s, a),

with ``b_i`` precomputed and utility-free. No Bellman equation is solved inside
the loop, which is what lets a neural utility train at scale.

Scope (Phase 1): the neural-utility training and behavior recovery, with the
learned utility projected onto features for an interpretable coefficient vector
and pseudo standard errors. The paper's optimal-weight second step (neural
OUFXP) and its efficient standard errors are not implemented here; the standard
errors reported are projection pseudo-SEs, not the efficient UFXP variance.
"""

from __future__ import annotations

import time

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import optax
import pandas as pd

from econirl.core.types import DDCProblem, Panel, TrajectoryPanel
from econirl.estimation.ufxp import _ufxp_precompute, _ufxp_random_duals
from econirl.estimators.neural_base import NeuralEstimatorMixin
from econirl.preferences.linear import LinearUtility


class _UtilityNet(eqx.Module):
    """Feed-forward utility network: feature vector in R^K -> scalar utility.

    The network is a stack of ``num_layers`` ReLU-activated linear layers of
    width ``hidden_dim`` followed by a linear read-out to a single unit.
    """

    layers: list
    output_layer: eqx.nn.Linear

    def __init__(self, n_features: int, hidden_dim: int, num_layers: int, *, key):
        # One PRNG key per hidden layer, plus a final one for the read-out.
        subkeys = jax.random.split(key, num_layers + 1)
        # Layer widths, input first: [K, hidden, hidden, ...].
        widths = [n_features] + [hidden_dim] * num_layers
        self.layers = [
            eqx.nn.Linear(fan_in, fan_out, key=subkeys[idx])
            for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]))
        ]
        # With no hidden layers the read-out acts directly on the features.
        self.output_layer = eqx.nn.Linear(widths[-1], 1, key=subkeys[-1])

    def __call__(self, phi_sa: jax.Array) -> jax.Array:
        """Evaluate the utility of a single feature vector."""
        hidden = phi_sa
        for dense in self.layers:
            hidden = jax.nn.relu(dense(hidden))
        # The read-out has one unit; drop that trailing axis.
        return jnp.squeeze(self.output_layer(hidden), axis=-1)

    def utility_matrix(self, phi: jax.Array) -> jax.Array:
        """Evaluate ``u_w(s, a)`` on a full (S, A, K) feature tensor.

        The state and action axes are flattened so the network can be mapped
        over every pair at once, then the result is folded back to (S, A).
        """
        n_states, n_actions, n_feat = phi.shape
        flat_features = phi.reshape(n_states * n_actions, n_feat)
        flat_utility = jax.vmap(self)(flat_features)
        return flat_utility.reshape(n_states, n_actions)


class NeuralUFXP(NeuralEstimatorMixin):
    """Neural-utility UFXP estimator (Oguz and Bray 2026).

    Trains a neural utility ``u_w(s, a)`` by minimizing the UFXP
    random-projection objective, reusing the linear estimator's precomputed
    dual so no Bellman equation is solved during training.

    Parameters
    ----------
    n_states, n_actions : int, optional
        Sizes of the state and action spaces. They are stored on the
        estimator; ``fit`` reads the sizes it works with from the shape of
        ``transitions``.
    discount : float, default=0.95
        Discount factor ``beta``.
    scale : float, default=1.0
        Logit scale ``sigma``.
    num_projections : int, default=64
        Number of random projections ``m``.
    reward_hidden_dim : int, default=64
        Hidden width of the utility network.
    reward_num_layers : int, default=2
        Hidden depth of the utility network.
    max_epochs : int, default=2000
        Adam steps over the projection objective.
    lr : float, default=1e-2
        Adam learning rate.
    gradient_clip : float, default=10.0
        Global-norm gradient clip (<=0 disables).
    ccp_min_count : int, default=1
        Minimum visits for a state's first-order conditions to be scored.
    ccp_smoothing : float, default=1e-6
        Additive smoothing for the frequency CCPs.
    seed : int, default=0
        Seed for the projections and the network initialization.
    verbose : bool, default=False
        Print the objective during training.

    Attributes
    ----------
    policy_ : numpy.ndarray
        Estimated choice probabilities, shape (n_states, n_actions).
    value_ : numpy.ndarray
        Estimated value function, shape (n_states,).
    reward_ : numpy.ndarray
        Learned utility ``u_w(s, a)``, shape (n_states, n_actions).
    params_ : dict
        The learned utility projected onto the features. The objective
        constrains the choice-relevant utility, not the utility level, so this
        is a best-effort linear summary of a partially identified function; a
        low ``projection_r2_`` flags that the utility is not linear in the
        features.
    se_ : dict
        Projection pseudo standard errors (not the efficient UFXP variance).
    coef_ : numpy.ndarray
        Projected coefficients in array form.
    projection_r2_ : float
        R-squared of the feature projection.
    converged_ : bool
        Whether the final objective is finite and no larger than the
        objective at initialization.
    n_epochs_ : int
        Number of Adam steps run (equal to ``max_epochs``).
    """

    def __init__(
        self,
        n_states: int | None = None,
        n_actions: int | None = None,
        discount: float = 0.95,
        scale: float = 1.0,
        num_projections: int = 64,
        reward_hidden_dim: int = 64,
        reward_num_layers: int = 2,
        max_epochs: int = 2000,
        lr: float = 1e-2,
        gradient_clip: float = 10.0,
        ccp_min_count: int = 1,
        ccp_smoothing: float = 1e-6,
        seed: int = 0,
        verbose: bool = False,
    ):
        """Store the hyperparameters and reset every fitted attribute to None."""
        # Hyperparameters.
        self.n_states = n_states
        self.n_actions = n_actions
        self.discount = discount
        self.scale = scale
        self.num_projections = num_projections
        self.reward_hidden_dim = reward_hidden_dim
        self.reward_num_layers = reward_num_layers
        self.max_epochs = max_epochs
        self.lr = lr
        self.gradient_clip = gradient_clip
        self.ccp_min_count = ccp_min_count
        self.ccp_smoothing = ccp_smoothing
        self.seed = seed
        self.verbose = verbose

        # Fitted state, populated by ``fit``.
        self.policy_ = None
        self.value_ = None
        self.reward_ = None
        self.params_ = None
        self.se_ = None
        self.coef_ = None
        self.projection_r2_ = None
        self.converged_ = None
        self.n_epochs_ = None

    def _extract_panel(self, data, state, action, id) -> TrajectoryPanel:
        """Convert the supported input containers to a ``TrajectoryPanel``.

        Raises
        ------
        ValueError
            If ``data`` is a DataFrame and any column name is missing.
        TypeError
            If ``data`` is not a DataFrame, Panel, or TrajectoryPanel.
        """
        if isinstance(data, pd.DataFrame):
            if state is None or action is None or id is None:
                raise ValueError(
                    "state, action, and id column names are required "
                    "when data is a DataFrame"
                )
            return TrajectoryPanel.from_dataframe(data, state=state, action=action, id=id)
        if isinstance(data, (Panel, TrajectoryPanel)):
            return TrajectoryPanel.from_panel(data)
        raise TypeError(
            f"data must be a DataFrame, Panel, or TrajectoryPanel, got {type(data)}"
        )

    def fit(
        self,
        data: pd.DataFrame | Panel | TrajectoryPanel,
        state: str | None = None,
        action: str | None = None,
        id: str | None = None,
        features: np.ndarray | None = None,
        transitions: np.ndarray | None = None,
    ) -> "NeuralUFXP":
        """Fit the neural utility to data.

        Parameters
        ----------
        data : pandas.DataFrame or Panel or TrajectoryPanel
            Panel of observed choices.
        state, action, id : str, optional
            Column names (required when ``data`` is a DataFrame).
        features : numpy.ndarray
            Reward features ``phi(s, a)`` of shape (n_states, n_actions, K).
            The utility network maps each feature vector to a scalar utility,
            so the features set the inputs the network can combine. An object
            exposing ``features`` (and optionally ``names``) attributes is
            also accepted.
        transitions : numpy.ndarray
            Transition matrices ``P(s'|s,a)`` of shape
            (n_actions, n_states, n_states). The state and action counts used
            throughout the fit are read from this shape.

        Returns
        -------
        NeuralUFXP
            The fitted estimator (``self``).

        Raises
        ------
        ValueError
            If ``features`` or ``transitions`` is missing, or if their shapes
            are malformed or inconsistent with each other.
        """
        if features is None:
            raise ValueError("NeuralUFXP requires features (n_states, n_actions, K).")
        if transitions is None:
            raise ValueError("NeuralUFXP requires transitions (n_actions, n_states, n_states).")

        t0 = time.perf_counter()

        phi = np.asarray(getattr(features, "features", features), dtype=np.float64)
        if phi.ndim != 3:
            raise ValueError(
                f"features must have shape (n_states, n_actions, K), got {phi.shape}"
            )
        n_features = phi.shape[2]
        names = list(getattr(features, "names", [])) or [
            f"theta_{i}" for i in range(n_features)
        ]

        panel = self._extract_panel(data, state, action, id)

        trans_shape = np.asarray(transitions).shape
        if len(trans_shape) != 3 or trans_shape[1] != trans_shape[2]:
            raise ValueError(
                "transitions must have shape (n_actions, n_states, n_states), "
                f"got {trans_shape}"
            )
        A, S, _ = trans_shape
        if phi.shape[:2] != (S, A):
            raise ValueError(
                f"features shape {phi.shape} is inconsistent with transitions "
                f"shape {trans_shape}: expected features of shape ({S}, {A}, K)"
            )

        problem = DDCProblem(
            num_states=S,
            num_actions=A,
            discount_factor=self.discount,
            scale_parameter=self.scale,
        )
        utility = LinearUtility(feature_matrix=jnp.asarray(phi), parameter_names=names)

        # Reuse the linear estimator's theta-independent precompute and duals.
        pre = _ufxp_precompute(
            panel, utility, transitions, problem, self.ccp_min_count, self.ccp_smoothing
        )
        Z, lam = _ufxp_random_duals(pre, self.num_projections, self.seed)

        # b_i = <Z_i, logratio> + lam_i' ent -- precomputed, utility-free.
        b = np.einsum("msa,sa->m", Z, pre.logratio) + lam.T @ pre.ent

        phi_j = jnp.asarray(pre.phi)
        Z_j = jnp.asarray(Z)
        lamT = jnp.asarray(lam.T)  # (m, S)
        b_j = jnp.asarray(b)  # (m,)
        P_j = jnp.asarray(pre.P)  # (S, A)
        ref = pre.ref

        key = jax.random.PRNGKey(self.seed)
        net = _UtilityNet(
            n_features, self.reward_hidden_dim, self.reward_num_layers, key=key
        )

        def objective(model):
            """Sum of squared projected first-order-condition residuals."""
            uall = model.utility_matrix(phi_j)  # (S, A)
            # NOTE(review): the slice keeps the first ``ref`` columns, so the
            # (S, A-1) shape only holds if ``ref`` is the last action index --
            # confirm against ``_ufxp_precompute``.
            du = uall[:, :ref] - uall[:, ref:ref + 1]  # (S, A-1)
            u_P = jnp.sum(P_j * uall, axis=1)  # (S,)
            resid = b_j - jnp.einsum("msa,sa->m", Z_j, du) + lamT @ u_P
            return jnp.sum(resid ** 2)

        transforms = []
        if self.gradient_clip > 0:
            transforms.append(optax.clip_by_global_norm(self.gradient_clip))
        transforms.append(optax.adam(self.lr))
        opt = optax.chain(*transforms)
        opt_state = opt.init(eqx.filter(net, eqx.is_inexact_array))

        @eqx.filter_jit
        def step(model, opt_state):
            """Run one clipped-Adam update; return the pre-update loss."""
            loss, grads = eqx.filter_value_and_grad(objective)(model)
            # Hand the optimizer the same filtered parameter tree it was
            # initialized with.
            params = eqx.filter(model, eqx.is_inexact_array)
            updates, opt_state = opt.update(grads, opt_state, params)
            model = eqx.apply_updates(model, updates)
            return model, opt_state, loss

        init_loss = float(objective(net))
        loss = init_loss
        for epoch in range(self.max_epochs):
            net, opt_state, loss = step(net, opt_state)
            # Only pull the loss back to the host when it is actually printed,
            # so the training loop does not synchronize on every step.
            if self.verbose and epoch % 200 == 0:
                print(f" epoch {epoch:5d} Q={float(loss):.6e}")
        last = float(loss)

        self.n_epochs_ = self.max_epochs
        self.converged_ = bool(np.isfinite(last) and last <= init_loss)

        # One soft-Bellman solve at the learned utility for policy and value.
        # Imported here, as in the original code, rather than at module level.
        from econirl.core.bellman import SoftBellmanOperator
        from econirl.core.solvers import hybrid_iteration

        reward_mat = net.utility_matrix(phi_j)
        op = SoftBellmanOperator(problem, jnp.asarray(transitions, dtype=jnp.float64))
        sol = hybrid_iteration(op, reward_mat, tol=1e-8, max_iter=5000)

        self.policy_ = np.asarray(sol.policy)
        self.value_ = np.asarray(sol.V)
        self.reward_ = np.asarray(reward_mat)

        # Project the learned utility onto the features for interpretable coefs.
        # ``_project_parameters`` is not defined in this class; presumably it
        # is inherited from NeuralEstimatorMixin.
        theta, se, r2 = self._project_parameters(
            phi.reshape(S * A, n_features),
            np.asarray(reward_mat).reshape(S * A),
        )
        self.coef_ = np.asarray(theta)
        self.params_ = dict(zip(names, [float(v) for v in theta]))
        self.se_ = dict(zip(names, [float(v) for v in se]))
        self.projection_r2_ = float(r2)

        self._fit_time = time.perf_counter() - t0
        return self