"""Base class for dynamic discrete choice environments.
This module defines DDCEnvironment, a Gymnasium-compatible base class
for economic decision environments. It extends gym.Env with properties
and methods specific to structural estimation.
Key additions over standard Gym environments:
- Explicit transition matrices (for model-based estimation)
- Feature matrix for utility computation
- Problem specification (DDCProblem)
- True parameter access (for simulation studies)
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, SupportsFloat, Union
import gymnasium as gym
import jax.numpy as jnp
import numpy as np
import pandas as pd
from gymnasium import spaces
from econirl.core.types import DDCProblem, Panel
class DDCEnvironment(gym.Env, ABC):
    """Base class for dynamic discrete choice environments.

    This abstract class defines the interface for environments used in
    structural estimation. Subclasses must implement the abstract properties
    and methods to define specific economic models.

    DDCEnvironment extends gym.Env with:

    - `problem_spec`: The DDCProblem specification
    - `transition_matrices`: Explicit P(s'|s,a) for model-based methods
    - `feature_matrix`: Features φ(s,a) for utility computation
    - `true_parameters`: Ground truth (for simulation studies)

    The environment can be used both for:

    1. Simulation: Generate data from known parameters
    2. Estimation: Provide structure needed by estimators

    Example:
        >>> env = MyDDCEnvironment(params)
        >>> obs, info = env.reset()
        >>> for _ in range(100):
        ...     action = env.action_space.sample()
        ...     obs, reward, terminated, truncated, info = env.step(action)
    """

    metadata = {"render_modes": ["human"]}

    def __init__(
        self,
        discount_factor: float = 0.9999,
        scale_parameter: float = 1.0,
        seed: int | None = None,
    ):
        """Initialize the environment.

        Args:
            discount_factor: Time discount factor β ∈ [0, 1)
            scale_parameter: Logit scale parameter σ > 0
            seed: Random seed for reproducibility
        """
        super().__init__()
        self._discount_factor = discount_factor
        self._scale_parameter = scale_parameter
        self._np_random = np.random.default_rng(seed)

        # Episode bookkeeping; `_state` stays None until reset() is called,
        # which is what step() checks to detect an uninitialized environment.
        self._state: int | None = None
        self._current_period: int = 0

    # --- Problem structure ---

    @property
    @abstractmethod
    def num_states(self) -> int:
        """Number of discrete states in the environment."""
        ...

    @property
    @abstractmethod
    def num_actions(self) -> int:
        """Number of discrete actions available."""
        ...

    @property
    def problem_spec(self) -> DDCProblem:
        """Return the DDCProblem specification for this environment.

        A new DDCProblem is built on every access from the current values
        of the environment's properties.
        """
        return DDCProblem(
            num_states=self.num_states,
            num_actions=self.num_actions,
            discount_factor=self._discount_factor,
            scale_parameter=self._scale_parameter,
            state_dim=self.state_dim,
            state_encoder=self.encode_states,
        )

    @property
    def state_dim(self) -> int:
        """Dimensionality of the continuous state representation."""
        return 1

    def encode_states(self, states: jnp.ndarray) -> jnp.ndarray:
        """Encode flat state indices to continuous features.

        Default: normalized scalar s/(S-1) with shape (batch, 1).
        Override for multi-dimensional environments.
        """
        # Guard against division by zero when there is a single state.
        denom = max(self.num_states - 1, 1)
        return jnp.expand_dims(states.astype(jnp.float32) / denom, axis=-1)

    @property
    @abstractmethod
    def transition_matrices(self) -> jnp.ndarray:
        """Return transition probability matrices.

        Returns:
            Tensor of shape (num_actions, num_states, num_states)
            where result[a, s, s'] = P(s' | s, a)
        """
        ...

    @property
    @abstractmethod
    def feature_matrix(self) -> jnp.ndarray:
        """Return feature matrix for utility computation.

        Features are the observable characteristics that enter the
        utility function: U(s,a;θ) = θ · φ(s,a)

        Returns:
            Tensor of shape (num_states, num_actions, num_features)
        """
        ...

    @property
    @abstractmethod
    def true_parameters(self) -> dict[str, float]:
        """Return the true utility parameters (for simulation studies).

        Returns:
            Dictionary mapping parameter names to values
        """
        ...

    @property
    @abstractmethod
    def parameter_names(self) -> list[str]:
        """Return names of utility parameters in order."""
        ...

    def get_true_parameter_vector(self) -> jnp.ndarray:
        """Return true parameters as a tensor in canonical order.

        The order is the one given by `parameter_names`.

        Raises:
            KeyError: If a name in `parameter_names` is missing from
                `true_parameters`.
        """
        params = self.true_parameters
        return jnp.array(
            [params[name] for name in self.parameter_names], dtype=jnp.float32
        )

    @property
    def current_state(self) -> int | None:
        """Return the current state index (None before the first reset)."""
        return self._state

    # --- Hooks that subclasses must provide ---

    @abstractmethod
    def _get_initial_state_distribution(self) -> np.ndarray:
        """Return the initial state distribution for reset().

        Returns:
            Array of shape (num_states,) with probabilities
        """
        ...

    @abstractmethod
    def _compute_flow_utility(self, state: int, action: int) -> float:
        """Compute the flow utility for a state-action pair.

        This should use the true_parameters to compute utility.

        Args:
            state: Current state index
            action: Chosen action index

        Returns:
            Flow utility value (before adding preference shock)
        """
        ...

    @abstractmethod
    def _sample_next_state(self, state: int, action: int) -> int:
        """Sample the next state given current state and action.

        Args:
            state: Current state index
            action: Chosen action index

        Returns:
            Next state index
        """
        ...

    # --- Gymnasium API ---

    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[int, dict[str, Any]]:
        """Reset the environment to an initial state.

        Args:
            seed: Optional seed for random number generator. When given,
                the environment's generator is replaced by a fresh one.
            options: Optional configuration options (currently unused).

        Returns:
            Tuple of (initial_observation, info_dict)
        """
        if seed is not None:
            self._np_random = np.random.default_rng(seed)

        # Coerce to a NumPy array so subclasses that build the distribution
        # with another array library (e.g. JAX) still hand the sampler a
        # plain ndarray. This is a no-op for a conforming np.ndarray.
        init_dist = np.asarray(self._get_initial_state_distribution())
        self._state = int(self._np_random.choice(self.num_states, p=init_dist))
        self._current_period = 0
        return self._state, {"period": self._current_period}

    def step(
        self, action: int
    ) -> tuple[int, SupportsFloat, bool, bool, dict[str, Any]]:
        """Take an action and observe the result.

        Args:
            action: The action to take

        Returns:
            Tuple of (next_state, reward, terminated, truncated, info)

            - next_state: The new state after transition
            - reward: The flow utility (reward in RL terms)
            - terminated: Always False (infinite horizon)
            - truncated: Always False (no time limit by default)
            - info: Additional information

        Raises:
            RuntimeError: If reset() has not been called yet.
            ValueError: If `action` is outside [0, num_actions).
        """
        if self._state is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        if not 0 <= action < self.num_actions:
            raise ValueError(f"Invalid action {action}. Must be in [0, {self.num_actions})")

        # Compute flow utility (this is the "reward" in RL terms) from the
        # state the agent acted in, before the transition is applied.
        prev_state = self._state
        utility = self._compute_flow_utility(prev_state, action)

        # Sample next state
        self._state = self._sample_next_state(prev_state, action)
        self._current_period += 1

        info = {
            "period": self._current_period,
            "prev_state": prev_state,
            "action": action,
            "flow_utility": utility,
        }
        # DDC models are infinite horizon, so never terminated
        return self._state, utility, False, False, info

    def compute_utility_matrix(self, parameters: jnp.ndarray | None = None) -> jnp.ndarray:
        """Compute the full utility matrix for all state-action pairs.

        Args:
            parameters: Optional parameter vector. If None, uses true_parameters.

        Returns:
            Tensor of shape (num_states, num_actions) with flow utilities
        """
        if parameters is None:
            parameters = self.get_true_parameter_vector()
        # Contract the feature axis k: U[s, a] = sum_k φ[s, a, k] * θ[k].
        return jnp.einsum("sak,k->sa", self.feature_matrix, parameters)

    def render(self) -> None:
        """Render the current state (optional).

        Prints the period and state index; does nothing before reset().
        """
        if self._state is not None:
            print(f"Period {self._current_period}: State = {self._state}")

    # --- Synthetic data generation ---

    def generate_panel(
        self,
        n_individuals: int = 1000,
        n_periods: int = 100,
        seed: int = 42,
        as_dataframe: bool = False,
    ) -> Union[Panel, pd.DataFrame]:
        """Generate synthetic panel data from this environment.

        Delegates the simulation to
        `econirl.simulation.synthetic.simulate_panel` and optionally
        flattens the result into one row per (individual, period).

        Args:
            n_individuals: Number of individuals to simulate.
            n_periods: Number of time periods per individual.
            seed: Random seed for reproducibility.
            as_dataframe: If True, return a DataFrame with
                human-readable columns via _state_to_record().

        Returns:
            Panel object, or DataFrame if as_dataframe=True. The DataFrame
            always has the columns individual_id, period, state, action
            and next_state, plus whatever _state_to_record() adds.
        """
        # Imported lazily, as in the original (presumably to avoid an import
        # cycle with the simulation module — TODO confirm).
        from econirl.simulation.synthetic import simulate_panel

        panel = simulate_panel(
            self,
            n_individuals=n_individuals,
            n_periods=n_periods,
            seed=seed,
        )
        if not as_dataframe:
            return panel

        base_columns = ["individual_id", "period", "state", "action", "next_state"]
        records = []
        for traj in panel.trajectories:
            tid = traj.individual_id
            # Convert each array once per trajectory rather than indexing
            # the original containers element by element.
            states = np.asarray(traj.states)
            actions = np.asarray(traj.actions)
            next_states = np.asarray(traj.next_states)
            # NOTE(review): assumes the three arrays are aligned and of
            # equal length within a trajectory — confirm against Panel.
            for t, (s, a, ns) in enumerate(zip(states, actions, next_states)):
                s, a, ns = int(s), int(a), int(ns)
                record = {
                    "individual_id": tid,
                    "period": t,
                    "state": s,
                    "action": a,
                    "next_state": ns,
                }
                record.update(self._state_to_record(s, a))
                records.append(record)

        if not records:
            # Keep the base schema even when nothing was simulated, so
            # callers can still select the standard columns.
            return pd.DataFrame(columns=base_columns)
        return pd.DataFrame(records)

    def _state_to_record(self, state: int, action: int) -> dict[str, Any]:
        """Convert a state-action pair to human-readable record fields.

        Subclasses override this to add domain-specific columns like
        profit_bin, incumbent_status, etc. The base implementation
        returns an empty dict.
        """
        return {}

    @classmethod
    def info(cls) -> dict:
        """Return metadata about this environment.

        Subclasses should override this to provide name, description,
        source, n_states, n_actions, parameter details, etc.

        Returns:
            Dict with "name" (the class name) and "description" (the first
            non-blank line of the class docstring, or "" if there is none).
        """
        # Strip first so a docstring that opens with a newline or leading
        # indentation still yields its summary line instead of "".
        doc = (cls.__doc__ or "").strip()
        description = doc.splitlines()[0].strip() if doc else ""
        return {
            "name": cls.__name__,
            "description": description,
        }