"""Unified feature specification for all estimators.

RewardSpec replaces the three separate classes (LinearUtility, LinearReward,
ActionDependentReward) with a single, clean interface. Internally it always
stores features as an (S, A, K) array and exposes the same compute/gradient/
hessian interface that estimators rely on.

Backward-compatibility adapters (.to_linear_utility(), etc.) allow gradual
migration without breaking existing code.

Usage:

>>> features_sak = jnp.zeros((10, 2, 3))
>>> spec = RewardSpec(features_sak, names=["cost", "benefit", "distance"])
>>> R = spec.compute(jnp.array([1.0, -0.5, 0.3]))  # shape (10, 2)

State-only (S, K) features are broadcast across ``n_actions`` actions:

>>> features_sk = jnp.zeros((10, 3))
>>> spec = RewardSpec(features_sk, names=["a", "b", "c"], n_actions=4)

The same specifications can be built with the explicit constructors:

>>> spec = RewardSpec.state_dependent(features_sk, names=["a", "b", "c"], n_actions=4)
>>> spec = RewardSpec.state_action_dependent(features_sak, names=["cost", "benefit", "distance"])
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import jax.numpy as jnp
if TYPE_CHECKING:
pass
class RewardSpec:
    """Unified feature specification for structural estimation and IRL.

    Stores features as a dense (S, A, K) array and provides ``compute``,
    ``compute_gradient`` and ``compute_hessian`` methods compatible with the
    BaseUtilityFunction protocol. The reward is linear in the parameters:
    ``R(s, a) = sum_k parameters[k] * features[s, a, k]``.

    Parameters
    ----------
    features : jnp.ndarray
        Either (S, A, K) for action-dependent features, or (S, K) for
        state-only features (broadcast to all actions). Any input accepted
        by ``jnp.asarray`` (JAX arrays, NumPy arrays, nested lists) works.
    names : list[str]
        Human-readable name for each feature/parameter dimension. Must have
        exactly K entries.
    n_actions : int, optional
        Required when ``features`` is (S, K) to specify the number of
        actions for broadcasting. May be omitted when ``features`` is
        (S, A, K); if given in that case it must equal A.

    Raises
    ------
    ValueError
        If ``features`` is neither 2D nor 3D; if ``n_actions`` is missing
        or < 1 for 2D features; if ``n_actions`` disagrees with axis 1 of
        3D features; or if ``len(names)`` differs from K.
    """

    def __init__(
        self,
        features: jnp.ndarray,
        names: list[str],
        n_actions: int | None = None,
    ):
        # Normalise array-likes (NumPy arrays, nested lists) so the
        # ``.ndim`` / ``.shape`` checks below work uniformly.
        features = jnp.asarray(features)
        if features.ndim == 3:
            # (S, A, K) — action-dependent features, stored as a copy.
            A = features.shape[1]
            if n_actions is not None and n_actions != A:
                raise ValueError(
                    f"features already have {A} actions on axis 1 "
                    f"but n_actions={n_actions} was also provided"
                )
            self._feature_matrix = jnp.array(features)
            self._is_state_only = False
        elif features.ndim == 2:
            # (S, K) — state-only features, broadcast to (S, A, K).
            S, K = features.shape
            if n_actions is None:
                raise ValueError(
                    "n_actions is required when features is 2D (S, K)"
                )
            if n_actions < 1:
                raise ValueError(f"n_actions must be >= 1, got {n_actions}")
            self._feature_matrix = jnp.broadcast_to(
                features[:, None, :], (S, n_actions, K)
            ).copy()
            self._is_state_only = True
        else:
            raise ValueError(
                f"features must be 2D (S, K) or 3D (S, A, K), "
                f"got {features.ndim}D with shape {features.shape}"
            )

        # Materialise once so tuples/iterables are accepted and len() works.
        names = list(names)
        K = self._feature_matrix.shape[2]
        if len(names) != K:
            raise ValueError(
                f"names must have {K} elements to match feature dimension, "
                f"got {len(names)}"
            )
        self._parameter_names = names

    # ------------------------------------------------------------------
    # Alternative constructors
    # ------------------------------------------------------------------

    @classmethod
    def state_dependent(
        cls,
        state_features: jnp.ndarray,
        names: list[str],
        n_actions: int,
    ) -> RewardSpec:
        """Create from state-only features (S, K), broadcast to all actions.

        Parameters
        ----------
        state_features : jnp.ndarray
            Shape (S, K). Any array-like accepted by ``jnp.asarray``.
        names : list[str]
            One name per feature.
        n_actions : int
            Number of actions to broadcast to.

        Raises
        ------
        ValueError
            If ``state_features`` is not 2D, or for any reason listed on
            the class constructor.
        """
        state_features = jnp.asarray(state_features)
        if state_features.ndim != 2:
            raise ValueError(
                f"state_features must be 2D (S, K), got shape {state_features.shape}"
            )
        return cls(features=state_features, names=names, n_actions=n_actions)

    @classmethod
    def state_action_dependent(
        cls,
        features: jnp.ndarray,
        names: list[str],
    ) -> RewardSpec:
        """Create from action-dependent features (S, A, K).

        Parameters
        ----------
        features : jnp.ndarray
            Shape (S, A, K). Any array-like accepted by ``jnp.asarray``.
        names : list[str]
            One name per feature.

        Raises
        ------
        ValueError
            If ``features`` is not 3D, or if ``len(names)`` differs from K.
        """
        features = jnp.asarray(features)
        if features.ndim != 3:
            raise ValueError(
                f"features must be 3D (S, A, K), got shape {features.shape}"
            )
        return cls(features=features, names=names)

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def feature_matrix(self) -> jnp.ndarray:
        """Feature array of shape (S, A, K)."""
        return self._feature_matrix

    @property
    def parameter_names(self) -> list[str]:
        """Human-readable names for each parameter (a fresh copy)."""
        return self._parameter_names.copy()

    @property
    def num_parameters(self) -> int:
        """Number of parameters (K)."""
        return self._feature_matrix.shape[2]

    @property
    def num_states(self) -> int:
        """Number of states (S)."""
        return self._feature_matrix.shape[0]

    @property
    def num_actions(self) -> int:
        """Number of actions (A)."""
        return self._feature_matrix.shape[1]

    @property
    def is_state_only(self) -> bool:
        """Whether the spec was constructed from state-only (S, K) features."""
        return self._is_state_only

    # ------------------------------------------------------------------
    # Compute interface (matches BaseUtilityFunction protocol)
    # ------------------------------------------------------------------

    def compute(self, parameters: jnp.ndarray) -> jnp.ndarray:
        """Compute reward matrix R(s, a) = sum_k params[k] * features[s, a, k].

        Parameters
        ----------
        parameters : jnp.ndarray
            Shape (K,). Any array-like accepted by ``jnp.asarray``.

        Returns
        -------
        jnp.ndarray
            Shape (S, A).

        Raises
        ------
        ValueError
            If ``parameters`` does not have shape (K,).
        """
        # jnp.einsum does not accept Python lists, so convert up front.
        parameters = jnp.asarray(parameters)
        self.validate_parameters(parameters)
        return jnp.einsum("sak,k->sa", self._feature_matrix, parameters)

    def compute_gradient(self, parameters: jnp.ndarray) -> jnp.ndarray:
        """Gradient of reward w.r.t. parameters.

        For a linear specification the gradient is the feature matrix
        itself, independent of the parameter values.

        Parameters
        ----------
        parameters : jnp.ndarray
            Shape (K,). Unused but kept for protocol compatibility.

        Returns
        -------
        jnp.ndarray
            Shape (S, A, K); a copy of the stored feature matrix.
        """
        return jnp.array(self._feature_matrix)

    def compute_hessian(self, parameters: jnp.ndarray) -> jnp.ndarray:
        """Hessian of reward w.r.t. parameters.

        For a linear specification the Hessian is identically zero. Note
        that this allocates a dense (S, A, K, K) array.

        Parameters
        ----------
        parameters : jnp.ndarray
            Shape (K,). Unused.

        Returns
        -------
        jnp.ndarray
            Shape (S, A, K, K) of zeros with the feature matrix's dtype.
        """
        K = self.num_parameters
        return jnp.zeros(
            (self.num_states, self.num_actions, K, K),
            dtype=self._feature_matrix.dtype,
        )

    # ------------------------------------------------------------------
    # Parameter helpers
    # ------------------------------------------------------------------

    def get_initial_parameters(self) -> jnp.ndarray:
        """Return float32 zeros of shape (K,) as a starting point."""
        return jnp.zeros(self.num_parameters, dtype=jnp.float32)

    def get_parameter_bounds(
        self,
    ) -> tuple[jnp.ndarray | None, jnp.ndarray | None]:
        """Return (None, None) indicating unbounded parameters."""
        return (None, None)

    def validate_parameters(self, parameters: jnp.ndarray) -> None:
        """Check that parameters have shape (K,).

        Parameters
        ----------
        parameters : jnp.ndarray
            Candidate parameter vector; any array-like is accepted.

        Raises
        ------
        ValueError
            If shape does not match.
        """
        # jnp.shape also works on lists/tuples, unlike the ``.shape`` attribute.
        shape = jnp.shape(parameters)
        if shape != (self.num_parameters,):
            raise ValueError(
                f"Expected parameters of shape ({self.num_parameters},), "
                f"got {shape}"
            )

    # ------------------------------------------------------------------
    # Subset utilities
    # ------------------------------------------------------------------

    def subset_states(self, indices: jnp.ndarray) -> RewardSpec:
        """Return a new spec of the same class containing only the given states.

        Parameters
        ----------
        indices : jnp.ndarray
            1-D integer array (or boolean mask, or slice) selecting the
            states to keep. A scalar index is treated as a length-1 array
            so the result keeps its (S', A, K) layout.

        Returns
        -------
        RewardSpec
            Instance of ``type(self)`` sharing names and ``is_state_only``.
        """
        # A scalar index would drop the state axis entirely; promote it.
        if not isinstance(indices, slice) and jnp.ndim(indices) == 0:
            indices = jnp.reshape(jnp.asarray(indices), (1,))
        cls = type(self)
        # Bypass __init__: the stored matrix is already validated (S, A, K).
        new = cls.__new__(cls)
        new._feature_matrix = self._feature_matrix[indices, :, :]
        new._parameter_names = self._parameter_names.copy()
        new._is_state_only = self._is_state_only
        return new

    # ------------------------------------------------------------------
    # Backward-compatibility adapters
    # ------------------------------------------------------------------

    def to_linear_utility(self) -> "LinearUtility":
        """Convert to a LinearUtility with the same (S, A, K) feature matrix.

        Returns
        -------
        LinearUtility
            Equivalent LinearUtility instance.
        """
        # Imported lazily, presumably to avoid a circular import.
        from econirl.preferences.linear import LinearUtility

        return LinearUtility(
            feature_matrix=jnp.array(self._feature_matrix),
            parameter_names=self._parameter_names.copy(),
        )

    def to_action_dependent_reward(self) -> "ActionDependentReward":
        """Convert to an ActionDependentReward with the same (S, A, K) features.

        Returns
        -------
        ActionDependentReward
            Equivalent ActionDependentReward instance.
        """
        from econirl.preferences.action_reward import ActionDependentReward

        return ActionDependentReward(
            feature_matrix=jnp.array(self._feature_matrix),
            parameter_names=self._parameter_names.copy(),
        )

    def to_linear_reward(self) -> "LinearReward":
        """Convert to a LinearReward with state-only (S, K) features.

        This only works when features are state-only, i.e. every action's
        features match action 0's within ``jnp.allclose``'s default
        tolerances. The action-0 slice is used as the state features.

        Returns
        -------
        LinearReward
            Equivalent LinearReward instance.

        Raises
        ------
        ValueError
            If features vary across actions.
        """
        from econirl.preferences.reward import LinearReward

        # Compare every action's features against action 0's.
        ref = self._feature_matrix[:, 0:1, :]  # (S, 1, K)
        if not jnp.allclose(
            self._feature_matrix, jnp.broadcast_to(ref, self._feature_matrix.shape)
        ):
            raise ValueError(
                "Cannot convert to LinearReward: features differ across "
                "actions. LinearReward requires state-only features."
            )
        state_features = self._feature_matrix[:, 0, :]  # (S, K)
        return LinearReward(
            state_features=jnp.array(state_features),
            parameter_names=self._parameter_names.copy(),
            n_actions=self.num_actions,
        )

    # ------------------------------------------------------------------
    # Dunder methods
    # ------------------------------------------------------------------

    def __repr__(self) -> str:
        kind = "state_only" if self._is_state_only else "state_action"
        return (
            f"{type(self).__name__}("
            f"num_states={self.num_states}, "
            f"num_actions={self.num_actions}, "
            f"num_parameters={self.num_parameters}, "
            f"kind={kind}, "
            f"parameters={self._parameter_names})"
        )