"""NeuralAIRL: Context-aware Adversarial Inverse Reinforcement Learning.
Learns a disentangled reward r(s,a,ctx) and shaping potential h(s,ctx) via
adversarial training against a learned policy network, then extracts
structural parameters by projecting implied rewards onto features.
No transition matrix is needed. Supports context conditioning through
pluggable state and context encoders.
Reference:
Fu, J., Luo, K., & Levine, S. (2018). Learning robust rewards with
adversarial inverse reinforcement learning. ICLR.
"""
from __future__ import annotations
from typing import Callable
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import optax
import pandas as pd
from scipy.stats import norm as scipy_norm
from econirl.core.reward_spec import RewardSpec
from econirl.core.types import Panel, TrajectoryPanel
from econirl.estimators.neural_base import NeuralEstimatorMixin
def _to_numpy(values: object) -> np.ndarray:
    """Return *values* as a NumPy array via ``np.asarray``."""
    array = np.asarray(values)
    return array
def _to_jax_float(values: object) -> jax.Array:
    """Return *values* as a JAX array with dtype ``float32``."""
    as_float = jnp.asarray(values, dtype=jnp.float32)
    return as_float
def _to_jax_int(values: object) -> jax.Array:
    """Return *values* as a JAX array with dtype ``int32``."""
    as_int = jnp.asarray(values, dtype=jnp.int32)
    return as_int
def _bce_with_logits(logits: jax.Array, targets: jax.Array) -> jax.Array:
    """Element-wise binary cross-entropy computed directly from logits.

    Uses the form ``max(x, 0) - x * t + softplus(-|x|)`` so that no
    explicit sigmoid or log of a probability is evaluated. No reduction
    is applied; the result has the broadcast shape of the inputs.
    """
    positive_part = jnp.maximum(logits, 0.0)
    target_term = logits * targets
    log_tail = jax.nn.softplus(-jnp.abs(logits))
    return positive_part - target_term + log_tail
def _sample_actions(policy_probs: jax.Array, key: jax.Array) -> jax.Array:
    """Draw one action index per row of *policy_probs*.

    Probabilities are clipped to ``[1e-8, 1]`` and renormalized along the
    last axis before sampling. *key* is split into one sub-key per row and
    each row is sampled with ``jax.random.categorical`` on the log
    probabilities. The result is cast to ``int32``.
    """

    def _draw(row_key: jax.Array, row_probs: jax.Array) -> jax.Array:
        return jr.categorical(row_key, jnp.log(row_probs))

    safe = jnp.clip(policy_probs, 1e-8, 1.0)
    normalized = safe / safe.sum(axis=-1, keepdims=True)
    row_keys = jr.split(key, normalized.shape[0])
    draws = jax.vmap(_draw)(row_keys, normalized)
    return draws.astype(jnp.int32)
class _MLP(eqx.Module):
    """Fully connected network: ReLU hidden layers plus a linear output.

    With ``num_layers <= 0`` the network reduces to a single linear map
    from ``in_dim`` to ``out_dim``.
    """

    layers: tuple[eqx.nn.Linear, ...]
    output_layer: eqx.nn.Linear

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden_dim: int,
        num_layers: int,
        *,
        key: jax.Array,
    ):
        """Create ``num_layers`` hidden layers of width ``hidden_dim``.

        *key* is split into one sub-key per hidden layer plus a final one
        for the output layer.
        """
        depth = max(num_layers, 0)
        subkeys = jr.split(key, depth + 1)
        # Input width of every Linear: in_dim first, hidden_dim afterwards.
        # The last entry is the input width of the output layer.
        input_widths = [in_dim] + [hidden_dim] * depth
        self.layers = tuple(
            eqx.nn.Linear(width, hidden_dim, key=subkey)
            for width, subkey in zip(input_widths[:-1], subkeys)
        )
        self.output_layer = eqx.nn.Linear(
            input_widths[-1], out_dim, key=subkeys[-1]
        )

    def _forward_single(self, x: jax.Array) -> jax.Array:
        """Apply the network to one un-batched input vector."""
        activation = x
        for hidden in self.layers:
            activation = jax.nn.relu(hidden(activation))
        return self.output_layer(activation)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the network to a vector, or map it over the leading axis."""
        inputs = jnp.asarray(x, dtype=jnp.float32)
        if inputs.ndim != 1:
            return jax.vmap(self._forward_single)(inputs)
        return self._forward_single(inputs)

    def eval(self) -> _MLP:
        """Return ``self`` unchanged (no train/eval mode is tracked)."""
        return self
class _RewardNetwork(eqx.Module):
    """Scalar reward model over (state features, context features, action).

    The three inputs are concatenated along the last axis and passed
    through an ``_MLP`` with a single output unit.
    """

    n_actions: int = eqx.field(static=True)
    net: _MLP

    def __init__(
        self,
        state_dim: int,
        context_dim: int,
        n_actions: int,
        hidden_dim: int,
        num_layers: int,
        *,
        key: jax.Array,
    ):
        """Build the MLP with input width ``state_dim + context_dim + n_actions``."""
        input_width = state_dim + context_dim + n_actions
        self.n_actions = n_actions
        self.net = _MLP(input_width, 1, hidden_dim, num_layers, key=key)

    def __call__(
        self,
        state_feat: object,
        ctx_feat: object,
        action_onehot: object,
    ) -> object:
        """Return the reward for each (state, context, one-hot action) row.

        The trailing singleton output axis of the MLP is squeezed away.
        """
        pieces = [
            _to_jax_float(state_feat),
            _to_jax_float(ctx_feat),
            _to_jax_float(action_onehot),
        ]
        stacked = jnp.concatenate(pieces, axis=-1)
        return jnp.squeeze(self.net(stacked), axis=-1)

    def all_actions(
        self,
        state_feat: object,
        ctx_feat: object,
        n_actions: int,
    ) -> object:
        """Return rewards for every action at every row of the inputs.

        ``state_feat`` and ``ctx_feat`` are indexed as 2-D
        ``(rows, features)`` arrays; the result has one column per action
        (``n_actions`` columns).
        """
        sf = _to_jax_float(state_feat)
        cf = _to_jax_float(ctx_feat)

        def _tile_over_actions(feats: jax.Array) -> jax.Array:
            # (rows, d) -> (rows, n_actions, d)
            return jnp.repeat(feats[:, None, :], n_actions, axis=1)

        one_hots = jnp.eye(n_actions, dtype=jnp.float32)
        per_row_actions = jnp.repeat(one_hots[None, :, :], sf.shape[0], axis=0)
        grid = jnp.concatenate(
            [_tile_over_actions(sf), _tile_over_actions(cf), per_row_actions],
            axis=-1,
        )
        # Outer vmap runs over rows; the MLP itself maps over the action axis.
        return jnp.squeeze(jax.vmap(self.net)(grid), axis=-1)

    def eval(self) -> _RewardNetwork:
        """Return ``self`` unchanged (no train/eval mode is tracked)."""
        return self
class _ShapingNetwork(eqx.Module):
    """Scalar shaping potential over (state features, context features)."""

    net: _MLP

    def __init__(
        self,
        state_dim: int,
        context_dim: int,
        hidden_dim: int,
        num_layers: int,
        *,
        key: jax.Array,
    ):
        """Build the MLP with input width ``state_dim + context_dim``."""
        input_width = state_dim + context_dim
        self.net = _MLP(input_width, 1, hidden_dim, num_layers, key=key)

    def __call__(self, state_feat: object, ctx_feat: object) -> object:
        """Return the potential for each (state, context) row.

        Inputs are cast to ``float32``, concatenated on the last axis, and
        the MLP's trailing singleton output axis is squeezed away.
        """
        joined = jnp.concatenate(
            [_to_jax_float(state_feat), _to_jax_float(ctx_feat)],
            axis=-1,
        )
        return jnp.squeeze(self.net(joined), axis=-1)

    def eval(self) -> _ShapingNetwork:
        """Return ``self`` unchanged (no train/eval mode is tracked)."""
        return self
class _PolicyNetwork(eqx.Module):
    """Softmax policy over actions given (state features, context features)."""

    n_actions: int = eqx.field(static=True)
    net: _MLP

    def __init__(
        self,
        state_dim: int,
        context_dim: int,
        n_actions: int,
        hidden_dim: int,
        num_layers: int,
        *,
        key: jax.Array,
    ):
        """Build the MLP mapping ``state_dim + context_dim`` to ``n_actions``."""
        input_width = state_dim + context_dim
        self.n_actions = n_actions
        self.net = _MLP(input_width, n_actions, hidden_dim, num_layers, key=key)

    def logits(self, state_feat: object, ctx_feat: object) -> jax.Array:
        """Return un-normalized action scores as ``float32``."""
        joined = jnp.concatenate(
            [_to_jax_float(state_feat), _to_jax_float(ctx_feat)],
            axis=-1,
        )
        scores = self.net(joined)
        return jnp.asarray(scores, dtype=jnp.float32)

    def __call__(self, state_feat: object, ctx_feat: object) -> object:
        """Return action probabilities (softmax over the last axis)."""
        scores = self.logits(state_feat, ctx_feat)
        return jax.nn.softmax(scores, axis=-1)

    def log_prob(
        self,
        state_feat: object,
        ctx_feat: object,
        actions: object,
    ) -> object:
        """Return the log-probability of ``actions[i]`` at row ``i``."""
        chosen = _to_jax_int(actions)
        log_pi = jax.nn.log_softmax(self.logits(state_feat, ctx_feat), axis=-1)
        rows = jnp.arange(chosen.shape[0])
        return log_pi[rows, chosen]

    def eval(self) -> _PolicyNetwork:
        """Return ``self`` unchanged (no train/eval mode is tracked)."""
        return self
class _DiscriminatorBundle(eqx.Module):
    """Container pairing the reward and shaping networks.

    Bundling both networks in one module lets a single optimizer state and
    a single gradient computation cover the two of them during training.
    """

    # Reward term g(s, a, ctx) of the discriminator.
    reward_net: _RewardNetwork
    # Shaping potential h(s, ctx) of the discriminator.
    shaping_net: _ShapingNetwork
class NeuralAIRL(NeuralEstimatorMixin):
    """Context-aware AIRL estimator with sklearn-style API.

    Training alternates between a discriminator with logit
    ``g(s, a, ctx) + discount * h(s', ctx) - h(s, ctx) - log pi(a | s, ctx)``
    and a softmax policy network. After training, ``policy_`` and
    ``value_`` are tabulated per state index at context 0 and, when
    ``features`` is supplied to :meth:`fit`, the learned reward is projected
    onto those features to fill ``params_``, ``se_``, ``pvalues_``,
    ``projection_r2_`` and ``coef_``.
    """

    def __init__(
        self,
        n_actions: int = 8,
        discount: float = 0.95,
        reward_hidden_dim: int = 128,
        reward_num_layers: int = 3,
        shaping_hidden_dim: int = 128,
        shaping_num_layers: int = 3,
        policy_hidden_dim: int = 128,
        policy_num_layers: int = 3,
        batch_size: int = 512,
        max_epochs: int = 500,
        disc_lr: float = 1e-3,
        policy_lr: float = 1e-3,
        disc_steps: int = 5,
        gradient_clip: float = 1.0,
        patience: int = 50,
        label_smoothing: float = 0.0,
        state_encoder: Callable[[object], object] | None = None,
        context_encoder: Callable[[object], object] | None = None,
        state_dim: int | None = None,
        context_dim: int = 0,
        feature_names: list[str] | None = None,
        verbose: bool = False,
    ):
        """Store hyper-parameters; no networks are built until :meth:`fit`.

        Args:
            n_actions: Number of discrete actions.
            discount: Discount factor applied to the next-state potential.
            reward_hidden_dim, reward_num_layers: Reward MLP size.
            shaping_hidden_dim, shaping_num_layers: Shaping MLP size.
            policy_hidden_dim, policy_num_layers: Policy MLP size.
            batch_size: Mini-batch size for both players.
            max_epochs: Upper bound on training epochs.
            disc_lr, policy_lr: Adam learning rates.
            disc_steps: Discriminator updates per policy update.
            gradient_clip: Global-norm clip; values ``<= 0`` disable it.
            patience: Epochs without discriminator-loss improvement before
                early stopping.
            label_smoothing: Expert targets become ``1 - label_smoothing``
                and policy targets ``label_smoothing``.
            state_encoder, context_encoder: Optional callables mapping index
                arrays to feature arrays. When omitted, indices are scaled
                into a single column.
            state_dim, context_dim: Width of the encoder outputs. When left
                unset (``None`` / ``0``) and an encoder is given, the width
                is inferred from the encoder output.
            feature_names: Parameter names used when ``features`` passed to
                :meth:`fit` is a plain array.
            verbose: Print progress every 50 epochs and on early stopping.
        """
        self.n_actions = n_actions
        self.discount = discount
        self.reward_hidden_dim = reward_hidden_dim
        self.reward_num_layers = reward_num_layers
        self.shaping_hidden_dim = shaping_hidden_dim
        self.shaping_num_layers = shaping_num_layers
        self.policy_hidden_dim = policy_hidden_dim
        self.policy_num_layers = policy_num_layers
        self.batch_size = batch_size
        self.max_epochs = max_epochs
        self.disc_lr = disc_lr
        self.policy_lr = policy_lr
        self.disc_steps = disc_steps
        self.gradient_clip = gradient_clip
        self.patience = patience
        self.label_smoothing = label_smoothing
        self.state_encoder = state_encoder
        self.context_encoder = context_encoder
        self.state_dim = state_dim
        self.context_dim = context_dim
        self.feature_names = feature_names
        self.verbose = verbose

        # Fitted results (populated by fit()).
        self.params_: dict[str, float] | None = None
        self.se_: dict[str, float] | None = None
        self.pvalues_: dict[str, float] | None = None
        self.coef_: np.ndarray | None = None
        self.policy_: np.ndarray | None = None
        self.value_: np.ndarray | None = None
        self.projection_r2_: float | None = None
        self.converged_: bool | None = None
        self.n_epochs_: int | None = None

        # Internal state (populated by fit()).
        self._reward_net: _RewardNetwork | None = None
        self._shaping_net: _ShapingNetwork | None = None
        self._policy_net: _PolicyNetwork | None = None
        self._state_encoder: Callable[[object], jax.Array] | None = None
        self._context_encoder: Callable[[object], jax.Array] | None = None
        self._state_dim: int | None = None
        self._context_dim: int | None = None
        self._n_states: int | None = None
        self._n_obs: int | None = None

    def fit(
        self,
        data: pd.DataFrame | Panel | TrajectoryPanel,
        state: str | None = None,
        action: str | None = None,
        id: str | None = None,
        context: str | object | None = None,
        features: RewardSpec | object | None = None,
        transitions: object = None,
    ) -> NeuralAIRL:
        """Train the networks and populate the fitted attributes.

        Args:
            data: Observations as a DataFrame, ``Panel`` or
                ``TrajectoryPanel``.
            state, action, id: Column names; required when ``data`` is a
                DataFrame.
            context: A context column name (DataFrame input only) or an
                array of per-observation context indices. Defaults to
                context 0 for every observation.
            features: ``RewardSpec`` or feature array indexed as
                ``[state, action, :]``. When given, the learned reward is
                projected onto it; otherwise the projection attributes are
                reset to ``None``.
            transitions: Accepted for API compatibility; not used.

        Returns:
            ``self``.
        """
        all_states, all_actions, all_next, all_contexts = self._extract_data(
            data, state, action, id, context
        )
        n_states = int(np.asarray(all_states).max()) + 1
        self._n_states = n_states
        # Kept separately so summary() can report the real sample size.
        self._n_obs = int(len(all_states))
        self._build_encoders(all_states, all_contexts, n_states)

        key = jr.PRNGKey(np.random.randint(0, 2**31 - 1))
        reward_key, shaping_key, policy_key = jr.split(key, 3)
        self._reward_net = _RewardNetwork(
            self._state_dim,
            self._context_dim,
            self.n_actions,
            self.reward_hidden_dim,
            self.reward_num_layers,
            key=reward_key,
        )
        self._shaping_net = _ShapingNetwork(
            self._state_dim,
            self._context_dim,
            self.shaping_hidden_dim,
            self.shaping_num_layers,
            key=shaping_key,
        )
        self._policy_net = _PolicyNetwork(
            self._state_dim,
            self._context_dim,
            self.n_actions,
            self.policy_hidden_dim,
            self.policy_num_layers,
            key=policy_key,
        )

        self._train(all_states, all_actions, all_next, all_contexts)
        self._extract_policy_and_value(all_states, all_contexts, n_states)

        if features is not None:
            self._project_onto_features(
                features, all_states, all_actions, all_contexts
            )
        else:
            self.params_ = None
            self.se_ = None
            self.pvalues_ = None
            self.projection_r2_ = None
            self.coef_ = None
        return self

    def _extract_data(
        self,
        data: pd.DataFrame | Panel | TrajectoryPanel,
        state: str | None,
        action: str | None,
        id: str | None,
        context: str | object | None,
    ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
        """Return ``(states, actions, next_states, contexts)`` as int32 arrays.

        Raises:
            ValueError: ``data`` is a DataFrame and a column name is missing.
            TypeError: ``data`` is not a supported container.
        """
        if isinstance(data, pd.DataFrame):
            if state is None or action is None or id is None:
                raise ValueError(
                    "state, action, and id column names are required "
                    "when data is a DataFrame"
                )
            panel = TrajectoryPanel.from_dataframe(
                data, state=state, action=action, id=id
            )
            all_states = jnp.asarray(panel.all_states, dtype=jnp.int32)
            all_actions = jnp.asarray(panel.all_actions, dtype=jnp.int32)
            all_next = jnp.asarray(panel.all_next_states, dtype=jnp.int32)
            if isinstance(context, str):
                all_contexts = self._extract_context_from_df(data, id, context, panel)
            elif context is not None:
                all_contexts = _to_jax_int(context)
            else:
                all_contexts = jnp.zeros(len(all_states), dtype=jnp.int32)
        elif isinstance(data, (Panel, TrajectoryPanel)):
            all_states = jnp.asarray(data.get_all_states(), dtype=jnp.int32)
            all_actions = jnp.asarray(data.get_all_actions(), dtype=jnp.int32)
            all_next = jnp.asarray(data.get_all_next_states(), dtype=jnp.int32)
            if context is not None:
                all_contexts = _to_jax_int(context)
            else:
                all_contexts = jnp.zeros(len(all_states), dtype=jnp.int32)
        else:
            raise TypeError(
                f"data must be a DataFrame, Panel, or TrajectoryPanel, got {type(data)}"
            )
        return all_states, all_actions, all_next, all_contexts

    def _extract_context_from_df(
        self,
        df: pd.DataFrame,
        id_col: str,
        context_col: str,
        panel: TrajectoryPanel,
    ) -> jax.Array:
        """Collect the context column grouped by id, each group in index order.

        NOTE(review): this assumes the resulting row order (ids sorted, rows
        sorted by index within an id) matches the order of ``panel``'s
        flattened arrays; ``panel`` itself is not consulted. Confirm against
        ``TrajectoryPanel.from_dataframe``.
        """
        contexts: list[int] = []
        for _, group in df.groupby(id_col, sort=True):
            ordered = group.sort_index()
            contexts.extend(ordered[context_col].values.tolist())
        return jnp.asarray(contexts, dtype=jnp.int32)

    def _call_encoder(self, encoder: Callable[[object], object], values: object) -> jax.Array:
        """Run a user encoder and cast its output to a float32 JAX array."""
        return _to_jax_float(encoder(values))

    def _infer_encoded_dim(
        self,
        encoder: Callable[[object], object],
        sample: object,
        declared: int | None,
    ) -> int:
        """Return the feature width produced by a user encoder.

        A truthy ``declared`` width is trusted as-is. Otherwise the encoder
        is applied to ``sample`` and the size of the last axis is used
        (falling back to 1 for outputs with fewer than two dimensions).
        """
        if declared:
            return int(declared)
        encoded = self._call_encoder(encoder, sample)
        return int(encoded.shape[-1]) if encoded.ndim >= 2 else 1

    def _build_encoders(
        self,
        all_states: jax.Array,
        all_contexts: jax.Array,
        n_states: int,
    ) -> None:
        """Set ``_state_encoder`` / ``_context_encoder`` and their widths.

        Without a user encoder, indices are divided by the largest index
        (at least 1) and reshaped to a single column.
        """
        if self.state_encoder is not None:
            self._state_encoder = lambda s: self._call_encoder(self.state_encoder, s)
            # Infer the width instead of silently assuming 1, which would
            # build networks whose input size disagrees with the encoder.
            self._state_dim = self._infer_encoded_dim(
                self.state_encoder, all_states[:1], self.state_dim
            )
        else:
            max_s = max(n_states - 1, 1)
            self._state_encoder = lambda s, _ms=max_s: (
                _to_jax_float(s) / float(_ms)
            ).reshape(-1, 1)
            self._state_dim = 1

        if self.context_encoder is not None:
            self._context_encoder = lambda c: self._call_encoder(self.context_encoder, c)
            self._context_dim = self._infer_encoded_dim(
                self.context_encoder, all_contexts[:1], self.context_dim
            )
        else:
            n_ctx = max(int(np.asarray(all_contexts).max()), 1) if len(all_contexts) else 1
            self._context_encoder = lambda c, _mc=n_ctx: (
                _to_jax_float(c) / float(_mc)
            ).reshape(-1, 1)
            self._context_dim = 1

    def _train(
        self,
        states: jax.Array,
        actions: jax.Array,
        next_states: jax.Array,
        contexts: jax.Array,
    ) -> None:
        """Run adversarial training and keep the best-scoring networks.

        Each mini-batch performs ``disc_steps`` discriminator updates
        followed by one policy update. The networks stored at the end are
        those from the epoch with the lowest average discriminator loss;
        training stops early after ``patience`` epochs without an
        improvement of more than ``1e-4``.
        """
        disc_model = _DiscriminatorBundle(self._reward_net, self._shaping_net)
        policy_net = self._policy_net

        def build_optimizer(learning_rate: float) -> optax.GradientTransformation:
            # Optional global-norm clipping followed by Adam.
            transforms = []
            if self.gradient_clip > 0:
                transforms.append(optax.clip_by_global_norm(self.gradient_clip))
            transforms.append(optax.adam(learning_rate))
            return optax.chain(*transforms)

        disc_optimizer = build_optimizer(self.disc_lr)
        policy_optimizer = build_optimizer(self.policy_lr)
        disc_opt_state = disc_optimizer.init(eqx.filter(disc_model, eqx.is_inexact_array))
        policy_opt_state = policy_optimizer.init(eqx.filter(policy_net, eqx.is_inexact_array))

        N = len(states)
        best_loss = float("inf")
        patience_counter = 0
        expert_label = 1.0 - self.label_smoothing
        policy_label = 0.0 + self.label_smoothing

        def compute_disc_logits(
            disc: _DiscriminatorBundle,
            policy: _PolicyNetwork,
            s_feat: jax.Array,
            ctx_feat: jax.Array,
            action_idx: jax.Array,
            ns_feat: jax.Array,
        ) -> jax.Array:
            # Discriminator logit: g + discount * h(s') - h(s) - log pi(a|s).
            action_oh = jax.nn.one_hot(action_idx, self.n_actions, dtype=jnp.float32)
            g = disc.reward_net(s_feat, ctx_feat, action_oh)
            h_s = disc.shaping_net(s_feat, ctx_feat)
            h_ns = disc.shaping_net(ns_feat, ctx_feat)
            log_pi = policy.log_prob(s_feat, ctx_feat, action_idx)
            return g + self.discount * h_ns - h_s - log_pi

        @eqx.filter_value_and_grad
        def disc_loss_fn(
            disc: _DiscriminatorBundle,
            policy: _PolicyNetwork,
            s_feat: jax.Array,
            ctx_feat: jax.Array,
            action_idx: jax.Array,
            ns_feat: jax.Array,
            key: jax.Array,
        ) -> jax.Array:
            # Gradients are taken with respect to `disc` (first argument) only.
            expert_logits = compute_disc_logits(disc, policy, s_feat, ctx_feat, action_idx, ns_feat)
            policy_probs = policy(s_feat, ctx_feat)
            policy_actions = _sample_actions(policy_probs, key)
            # Policy samples reuse the observed next state of the same row.
            policy_logits = compute_disc_logits(disc, policy, s_feat, ctx_feat, policy_actions, ns_feat)
            expert_targets = jnp.full_like(expert_logits, expert_label)
            policy_targets = jnp.full_like(policy_logits, policy_label)
            loss = _bce_with_logits(expert_logits, expert_targets).mean()
            loss = loss + _bce_with_logits(policy_logits, policy_targets).mean()
            return loss

        @eqx.filter_value_and_grad
        def policy_loss_fn(
            policy: _PolicyNetwork,
            disc: _DiscriminatorBundle,
            s_feat: jax.Array,
            ctx_feat: jax.Array,
            ns_feat: jax.Array,
            key: jax.Array,
        ) -> jax.Array:
            policy_probs = policy(s_feat, ctx_feat)
            policy_actions = _sample_actions(policy_probs, key)
            log_pi = policy.log_prob(s_feat, ctx_feat, policy_actions)
            disc_logits = compute_disc_logits(disc, policy, s_feat, ctx_feat, policy_actions, ns_feat)
            # log sigmoid(logit), used as a fixed per-sample reward signal.
            disc_reward = -jax.nn.softplus(-disc_logits)
            entropy = -(policy_probs * jnp.log(policy_probs + 1e-10)).sum(axis=-1)
            # Score-function surrogate plus a small entropy bonus.
            return -(log_pi * jax.lax.stop_gradient(disc_reward)).mean() - 0.01 * entropy.mean()

        @eqx.filter_jit
        def disc_step(
            disc: _DiscriminatorBundle,
            disc_state: optax.OptState,
            policy: _PolicyNetwork,
            s_feat: jax.Array,
            ctx_feat: jax.Array,
            action_idx: jax.Array,
            ns_feat: jax.Array,
            key: jax.Array,
        ) -> tuple[_DiscriminatorBundle, optax.OptState, jax.Array]:
            loss, grads = disc_loss_fn(disc, policy, s_feat, ctx_feat, action_idx, ns_feat, key)
            updates, disc_state = disc_optimizer.update(grads, disc_state, disc)
            disc = eqx.apply_updates(disc, updates)
            return disc, disc_state, loss

        @eqx.filter_jit
        def policy_step(
            policy: _PolicyNetwork,
            policy_state: optax.OptState,
            disc: _DiscriminatorBundle,
            s_feat: jax.Array,
            ctx_feat: jax.Array,
            ns_feat: jax.Array,
            key: jax.Array,
        ) -> tuple[_PolicyNetwork, optax.OptState, jax.Array]:
            loss, grads = policy_loss_fn(policy, disc, s_feat, ctx_feat, ns_feat, key)
            updates, policy_state = policy_optimizer.update(grads, policy_state, policy)
            policy = eqx.apply_updates(policy, updates)
            return policy, policy_state, loss

        best_disc_model = disc_model
        best_policy_net = policy_net
        # Tracked explicitly so the bookkeeping after the loop does not rely
        # on the loop variable, which is undefined when max_epochs <= 0.
        epochs_run = 0
        early_stopped = False

        for epoch in range(self.max_epochs):
            epochs_run = epoch + 1
            perm = np.random.permutation(N)
            epoch_disc_loss = 0.0
            epoch_policy_loss = 0.0
            n_batches = 0

            for start in range(0, N, self.batch_size):
                idx = perm[start : start + self.batch_size]
                s = states[idx]
                a = actions[idx]
                ns = next_states[idx]
                ctx = contexts[idx]
                s_feat = self._state_encoder(s)
                ns_feat = self._state_encoder(ns)
                ctx_feat = self._context_encoder(ctx)

                disc_loss = None
                for _ in range(self.disc_steps):
                    disc_key = jr.PRNGKey(np.random.randint(0, 2**31 - 1))
                    disc_model, disc_opt_state, disc_loss = disc_step(
                        disc_model,
                        disc_opt_state,
                        policy_net,
                        s_feat,
                        ctx_feat,
                        a,
                        ns_feat,
                        disc_key,
                    )

                policy_key = jr.PRNGKey(np.random.randint(0, 2**31 - 1))
                policy_net, policy_opt_state, policy_loss = policy_step(
                    policy_net,
                    policy_opt_state,
                    disc_model,
                    s_feat,
                    ctx_feat,
                    ns_feat,
                    policy_key,
                )

                # Only the last discriminator loss of the batch is recorded;
                # converting once avoids a host sync on every inner step.
                if disc_loss is not None:
                    epoch_disc_loss += float(disc_loss)
                epoch_policy_loss += float(policy_loss)
                n_batches += 1

            avg_disc_loss = epoch_disc_loss / max(n_batches, 1)
            avg_policy_loss = epoch_policy_loss / max(n_batches, 1)

            if self.verbose and (epoch + 1) % 50 == 0:
                print(
                    f" Epoch {epoch + 1}: disc_loss={avg_disc_loss:.4f} "
                    f"policy_loss={avg_policy_loss:.4f}"
                )

            if avg_disc_loss < best_loss - 1e-4:
                best_loss = avg_disc_loss
                patience_counter = 0
                best_disc_model = disc_model
                best_policy_net = policy_net
            else:
                patience_counter += 1

            if patience_counter >= self.patience:
                if self.verbose:
                    print(f" Early stopping at epoch {epoch + 1}")
                early_stopped = True
                break

        self._reward_net = best_disc_model.reward_net
        self._shaping_net = best_disc_model.shaping_net
        self._policy_net = best_policy_net
        # NOTE(review): as in the original rule, finishing every epoch also
        # counts as converged, so this is True whenever at least one epoch
        # ran. Tighten if "converged" should mean early stopping only.
        ran_all_epochs = epochs_run > 0 and epochs_run == self.max_epochs
        self.converged_ = early_stopped or ran_all_epochs
        self.n_epochs_ = epochs_run

    def _extract_policy_and_value(
        self,
        all_states: jax.Array,
        all_contexts: jax.Array,
        n_states: int,
    ) -> None:
        """Tabulate ``policy_`` and ``value_`` for state indices ``0..n_states-1``.

        Both tables are evaluated at context 0; ``value_`` holds the shaping
        network output. ``all_states`` and ``all_contexts`` are not used.
        """
        unique_states = jnp.arange(n_states, dtype=jnp.int32)
        ctx_default = jnp.zeros(n_states, dtype=jnp.int32)
        s_feat = self._state_encoder(unique_states)
        ctx_feat = self._context_encoder(ctx_default)
        self.policy_ = np.asarray(self._policy_net(s_feat, ctx_feat))
        self.value_ = np.asarray(self._shaping_net(s_feat, ctx_feat))

    def _project_onto_features(
        self,
        features: RewardSpec | object,
        states: jax.Array,
        actions: jax.Array,
        contexts: jax.Array,
    ) -> None:
        """Project learned rewards at the observed data onto ``features``.

        The feature array is indexed as ``[state, action, :]``. Parameter
        names come from the ``RewardSpec``, else ``feature_names``, else
        ``f0, f1, ...``. The projection and p-values are delegated to
        ``_project_parameters`` / ``_compute_pvalues``, which are not
        defined in this class (presumably supplied by the mixin).
        """
        if isinstance(features, RewardSpec):
            feat_matrix = features.feature_matrix
            names = features.parameter_names
        else:
            feat_matrix = features
            names = self.feature_names or [f"f{i}" for i in range(np.asarray(features).shape[-1])]

        states_j = _to_jax_int(states)
        actions_j = _to_jax_int(actions)
        contexts_j = _to_jax_int(contexts)
        s_feat = self._state_encoder(states_j)
        ctx_feat = self._context_encoder(contexts_j)
        a_oh = jax.nn.one_hot(actions_j, self.n_actions, dtype=jnp.float32)
        rewards = jnp.asarray(self._reward_net(s_feat, ctx_feat, a_oh), dtype=jnp.float32)

        feat_np = _to_numpy(feat_matrix)
        phi = feat_np[np.asarray(states_j), np.asarray(actions_j), :]
        theta, se, r2 = self._project_parameters(phi, rewards)

        self.params_ = {n: float(v) for n, v in zip(names, theta)}
        self.se_ = {n: float(v) for n, v in zip(names, se)}
        self.pvalues_ = self._compute_pvalues(self.params_, self.se_)
        self.projection_r2_ = r2
        self.coef_ = np.asarray(theta)

    @property
    def reward_matrix_(self) -> np.ndarray | None:
        """Reward for every (state index, action) pair at context 0.

        Returns ``None`` before fitting.
        """
        if self._reward_net is None or self._n_states is None:
            return None
        unique_states = jnp.arange(self._n_states, dtype=jnp.int32)
        ctx_default = jnp.zeros(self._n_states, dtype=jnp.int32)
        s_feat = self._state_encoder(unique_states)
        ctx_feat = self._context_encoder(ctx_default)
        r_all = self._reward_net.all_actions(s_feat, ctx_feat, self.n_actions)
        return np.asarray(r_all)

    def predict_proba(self, states: np.ndarray) -> np.ndarray:
        """Look up rows of the tabulated ``policy_`` for the given state indices.

        Raises:
            RuntimeError: The model has not been fitted.
        """
        if self.policy_ is None:
            raise RuntimeError("Model not fitted. Call fit() first.")
        states = np.asarray(states, dtype=np.int64)
        return self.policy_[states]

    def predict_proba_from_features(
        self,
        state_features: object,
        contexts: object | None = None,
    ) -> np.ndarray:
        """Return action probabilities for already-encoded state features.

        A 1-D ``state_features`` is treated as a single row. ``contexts``
        are raw context values passed through the context encoder and
        default to 0 for every row.

        Raises:
            RuntimeError: The model has not been fitted.
        """
        if self._policy_net is None:
            raise RuntimeError("Model not fitted. Call fit() first.")
        s_feat = _to_jax_float(state_features)
        if s_feat.ndim == 1:
            s_feat = s_feat[None, :]
        if contexts is None:
            contexts = jnp.zeros(s_feat.shape[0], dtype=jnp.int32)
        ctx_feat = self._context_encoder(contexts)
        probs = self._policy_net(s_feat, ctx_feat)
        return np.asarray(probs)

    def predict_reward_from_features(
        self,
        state_features: object,
        actions: object,
        contexts: object | None = None,
    ) -> np.ndarray:
        """Return rewards for already-encoded state features and action indices.

        A 1-D ``state_features`` is treated as a single row and a scalar
        ``actions`` as a length-1 array. ``contexts`` default to 0.

        Raises:
            RuntimeError: The model has not been fitted.
        """
        if self._reward_net is None:
            raise RuntimeError("Model not fitted. Call fit() first.")
        s_feat = _to_jax_float(state_features)
        if s_feat.ndim == 1:
            s_feat = s_feat[None, :]
        actions_j = _to_jax_int(actions)
        if actions_j.ndim == 0:
            actions_j = actions_j[None]
        if contexts is None:
            contexts = jnp.zeros(s_feat.shape[0], dtype=jnp.int32)
        ctx_feat = self._context_encoder(contexts)
        a_oh = jax.nn.one_hot(actions_j, self.n_actions, dtype=jnp.float32)
        rewards = self._reward_net(s_feat, ctx_feat, a_oh)
        return np.asarray(rewards)

    def predict_reward(
        self,
        states: object,
        actions: object,
        contexts: object | None = None,
    ) -> object:
        """Return rewards for state indices and action indices.

        Scalar inputs are promoted to length-1 arrays. ``contexts`` default
        to 0 for every row. The result is a float32 JAX array.

        Raises:
            RuntimeError: The model has not been fitted.
        """
        if self._reward_net is None:
            raise RuntimeError("Model not fitted. Call fit() first.")
        # atleast_1d lets a single (state, action) pair be passed as scalars.
        states_j = jnp.atleast_1d(_to_jax_int(states))
        actions_j = jnp.atleast_1d(_to_jax_int(actions))
        if contexts is None:
            contexts_j = jnp.zeros(states_j.shape[0], dtype=jnp.int32)
        else:
            contexts_j = jnp.atleast_1d(_to_jax_int(contexts))
        s_feat = self._state_encoder(states_j)
        ctx_feat = self._context_encoder(contexts_j)
        a_oh = jax.nn.one_hot(actions_j, self.n_actions, dtype=jnp.float32)
        rewards = jnp.asarray(self._reward_net(s_feat, ctx_feat, a_oh), dtype=jnp.float32)
        return rewards

    def conf_int(self, alpha: float = 0.05) -> dict[str, tuple[float, float]]:
        """Return normal-approximation ``1 - alpha`` intervals per parameter.

        Parameters whose standard error is not finite get ``(nan, nan)``.

        Raises:
            RuntimeError: No projected parameters are available.
        """
        if self.params_ is None or self.se_ is None:
            raise RuntimeError(
                "No projected parameters available. "
                "Call fit() with features= to extract structural parameters."
            )
        z = scipy_norm.ppf(1 - alpha / 2)
        intervals: dict[str, tuple[float, float]] = {}
        for name, est in self.params_.items():
            se = self.se_[name]
            if np.isfinite(se):
                intervals[name] = (est - z * se, est + z * se)
            else:
                intervals[name] = (float("nan"), float("nan"))
        return intervals

    def summary(self) -> str:
        """Return a formatted text summary, or a notice when not fitted."""
        if self.policy_ is None:
            return "NeuralAIRL: Not fitted yet. Call fit() first."
        return self._format_neural_summary(
            method_name="NeuralAIRL",
            params=self.params_,
            se=self.se_,
            pvalues=self.pvalues_,
            projection_r2=self.projection_r2_,
            # Number of training rows, not the number of distinct states.
            n_observations=getattr(self, "_n_obs", None),
            n_epochs=self.n_epochs_,
            converged=self.converged_,
            discount=self.discount,
            context_dim=self._context_dim,
            extra_lines=[
                f"Reward network: {self.reward_num_layers} layers x {self.reward_hidden_dim} hidden",
                f"Shaping network: {self.shaping_num_layers} layers x {self.shaping_hidden_dim} hidden",
                f"Policy network: {self.policy_num_layers} layers x {self.policy_hidden_dim} hidden",
            ],
        )

    def __repr__(self) -> str:
        """Return a short description including whether the model is fitted."""
        fitted = self.policy_ is not None
        return (
            f"NeuralAIRL(n_actions={self.n_actions}, "
            f"discount={self.discount}, "
            f"fitted={fitted})"
        )