# Source code for econirl.estimation.iq_learn

"""IQ-Learn: Inverse soft-Q Learning for Imitation.

This module implements IQ-Learn (Garg et al. 2021) adapted for tabular DDC models.
IQ-Learn learns a single soft Q-function that implicitly represents both reward
and policy, avoiding adversarial training entirely.

Algorithm:
    1. Parameterize Q(s,a) as tabular, linear in features, or a small neural network
    2. Optimize the IQ-Learn objective (concave in Q), with
       td = Q(s,a) - gamma * E[V*(s')]:
       - Chi-squared (offline): min_Q -E_rho[Q(s,a)-V*(s)] + (1/(4*alpha))E_rho[td^2]
       - Simple (TV distance): max_Q E_rho[td] - (1-gamma)E_p0[V*(s0)]
    3. Extract policy: pi(a|s) = softmax(Q(s,a)/sigma)
    4. Recover reward: r(s,a) = Q(s,a) - gamma * E[V*(s')]

Reference:
    Garg, D., Chakraborty, S., Cundy, C., Song, J., & Ermon, S. (2021).
    "IQ-Learn: Inverse soft-Q Learning for Imitation." NeurIPS.
"""

from __future__ import annotations

import time
from dataclasses import dataclass
from typing import Literal

import jax
import jax.numpy as jnp
import numpy as np

from econirl.core.types import DDCProblem, Panel
from econirl.estimation.base import BaseEstimator, EstimationResult
from econirl.inference.results import EstimationSummary, GoodnessOfFit
from econirl.preferences.action_reward import ActionDependentReward
from econirl.preferences.base import BaseUtilityFunction
from econirl.preferences.reward import LinearReward


@dataclass
class IQLearnConfig:
    """Configuration for IQ-Learn estimation.

    Attributes:
        q_type: Parameterization of the Q-function: "tabular" (free Q(s,a)
            table), "linear" (linear in the utility's features) or "neural"
            (feedforward Q-network).
        divergence: Statistical divergence for the objective ("chi2" or
            "simple").
        alpha: Regularization strength for chi-squared divergence; the
            squared TD term is weighted by 1 / (4 * alpha).
        optimizer: Optimization method for the tabular and linear
            parameterizations ("L-BFGS-B" or "adam"). The neural
            parameterization always trains with Adam.
        learning_rate: Learning rate for Adam optimizer
        max_iter: Maximum optimization iterations
        convergence_tol: Convergence tolerance (gradient norm for Adam;
            passed as ``tol`` to the L-BFGS-B routine)
        hidden_dim: Width of each hidden layer of the neural Q-network
        num_layers: Depth of the neural Q-network MLP
        seed: Seed of the PRNG key used to initialize the neural Q-network
        verbose: Whether to print progress
    """

    q_type: Literal["tabular", "linear", "neural"] = "tabular"
    divergence: Literal["chi2", "simple"] = "chi2"
    alpha: float = 1.0
    optimizer: Literal["L-BFGS-B", "adam"] = "L-BFGS-B"
    learning_rate: float = 0.01
    max_iter: int = 500
    convergence_tol: float = 1e-6
    # The next three fields are only read by the neural Q-head.
    hidden_dim: int = 64
    num_layers: int = 2
    seed: int = 0
    verbose: bool = False


class IQLearnEstimator(BaseEstimator):
    """Inverse soft-Q Learning for tabular MDPs.

    IQ-Learn learns a soft Q-function that implicitly defines both the optimal
    policy and reward function. The key insight is that the IRL min-max problem
    over (reward, policy) collapses to a concave maximization over Q alone,
    since the optimal policy is deterministically given by softmax(Q).

    Parameters
    ----------
    config : IQLearnConfig, optional
        Configuration object with algorithm parameters.
    **kwargs
        Override individual config parameters.

    Examples
    --------
    >>> from econirl.estimation.iq_learn import IQLearnEstimator, IQLearnConfig
    >>> config = IQLearnConfig(q_type="linear", divergence="chi2")
    >>> estimator = IQLearnEstimator(config=config)
    >>> result = estimator.estimate(panel, utility, problem, transitions)
    """

    def __init__(
        self,
        config: IQLearnConfig | None = None,
        **kwargs,
    ):
        """Create an IQ-Learn estimator.

        Args:
            config: Algorithm configuration. When omitted, a new
                ``IQLearnConfig`` is built from ``kwargs``.
            **kwargs: Overrides for individual config fields. When ``config``
                is given, keys that are not attributes of the config are
                ignored, and the overrides are applied to a copy so the
                caller's config object is left unchanged.
        """
        if config is None:
            config = IQLearnConfig(**kwargs)
        elif kwargs:
            import copy

            # Apply overrides to a shallow copy: the same config instance may
            # be shared by several estimators, and mutating it in place would
            # silently change their settings too.
            config = copy.copy(config)
            for key, value in kwargs.items():
                if hasattr(config, key):
                    setattr(config, key, value)

        super().__init__(
            se_method="asymptotic",
            compute_hessian=False,
            verbose=config.verbose,
        )
        self.config = config

    @property
    def name(self) -> str:
        """Human-readable estimator name; used as ``method`` in the summary."""
        return "IQ-Learn (Garg et al. 2021)"

    def estimate(
        self,
        panel: Panel,
        utility: BaseUtilityFunction,
        problem: DDCProblem,
        transitions: np.ndarray | jnp.ndarray,
        initial_params: np.ndarray | None = None,
        **kwargs,
    ) -> EstimationSummary:
        """Run IQ-Learn and package the outcome as an ``EstimationSummary``.

        The optimization itself is delegated to ``_optimize``. This method
        picks parameter names, fills in goodness-of-fit statistics, and
        reports NaN standard errors (no Hessian or covariance is computed).

        Args:
            panel: Observed expert trajectories.
            utility: Utility specification; supplies parameter names when the
                number of estimated parameters matches ``num_parameters``.
            problem: Problem description (state/action counts, discount
                factor, scale parameter).
            transitions: Transition probabilities, forwarded to ``_optimize``.
            initial_params: Optional starting values, forwarded to
                ``_optimize``.
            **kwargs: Forwarded to ``_optimize``.

        Returns:
            Summary holding the parameters, policy, value function, fit
            statistics, and optimizer diagnostics.
        """
        start_time = time.time()

        result = self._optimize(
            panel=panel,
            utility=utility,
            problem=problem,
            transitions=transitions,
            initial_params=initial_params,
            **kwargs,
        )

        # Use the utility's own names when the dimensions agree; otherwise
        # fall back to one label per (state, action) cell.
        if len(result.parameters) == utility.num_parameters:
            param_names = utility.parameter_names
        else:
            param_names = []
            for s in range(problem.num_states):
                for a in range(problem.num_actions):
                    param_names.append(f"R({s},{a})")

        # No inference is available for this estimator: report NaN.
        standard_errors = jnp.full_like(result.parameters, float("nan"))

        n_obs = panel.num_observations
        n_params = len(result.parameters)
        ll = result.log_likelihood

        aic = -2 * ll + 2 * n_params
        bic = -2 * ll + n_params * np.log(n_obs)
        accuracy = self._compute_prediction_accuracy(
            panel, jnp.asarray(result.policy)
        )
        goodness_of_fit = GoodnessOfFit(
            log_likelihood=ll,
            num_parameters=n_params,
            num_observations=n_obs,
            aic=aic,
            bic=bic,
            prediction_accuracy=accuracy,
        )

        total_time = time.time() - start_time

        summary_fields = dict(
            parameters=result.parameters,
            parameter_names=param_names,
            standard_errors=standard_errors,
            hessian=None,
            variance_covariance=None,
            method=self.name,
            num_observations=n_obs,
            num_individuals=panel.num_individuals,
            num_periods=max(panel.num_periods_per_individual),
            discount_factor=problem.discount_factor,
            scale_parameter=problem.scale_parameter,
            log_likelihood=ll,
            goodness_of_fit=goodness_of_fit,
            identification=None,
            converged=result.converged,
            num_iterations=result.num_iterations,
            convergence_message=result.message,
            value_function=result.value_function,
            policy=result.policy,
            estimation_time=total_time,
            metadata=result.metadata,
        )
        return EstimationSummary(**summary_fields)

    def _optimize_neural(
        self,
        panel: Panel,
        utility: BaseUtilityFunction,
        problem: DDCProblem,
        transitions: jnp.ndarray,
        expert_states_jax: jnp.ndarray,
        expert_actions_jax: jnp.ndarray,
        trans_f64: jnp.ndarray,
        initial_dist: jnp.ndarray,
        n_states: int,
        n_actions: int,
        sigma: float,
        gamma: float,
        alpha: float,
        expert_state_coverage: float,
        expert_state_action_coverage: float,
        start_time: float,
    ) -> EstimationResult:
        """Neural Q-head variant.

        Trains a small feedforward Q-network mapping state features to a
        vector of action-conditional Q-values, optimizing the IQ-Learn
        objective with Adam (gradients clipped to global norm 1.0). Provided
        to fulfill the JSS paper claim of a neural Q head alongside the
        tabular and linear parametrizations.

        Args:
            panel: Expert panel (not read here; kept for a uniform signature).
            utility: Utility specification used to project the recovered
                reward onto structural parameters.
            problem: Problem description; ``state_encoder`` supplies the
                network inputs when it is not None, otherwise the state index
                scaled to [0, 1] is used as a single feature.
            transitions: Transition probabilities, used for reward recovery.
            expert_states_jax: State index of every expert observation.
            expert_actions_jax: Action index of every expert observation.
            trans_f64: Transition probabilities used inside the loss.
            initial_dist: Initial state distribution ("simple" divergence).
            n_states: Number of states.
            n_actions: Number of actions.
            sigma: Scale parameter of the soft value function.
            gamma: Discount factor.
            alpha: Chi-squared regularization strength.
            expert_state_coverage: Diagnostic copied into the metadata.
            expert_state_action_coverage: Diagnostic copied into the metadata.
            start_time: Wall-clock start used for ``optimization_time``.

        Returns:
            Estimation result. ``parameters`` holds the projected reward
            coefficients when the utility exposes features, otherwise the
            flattened recovered reward table. ``converged`` is always True
            because the loop runs a fixed ``max_iter`` steps with no stopping
            rule.
        """
        import equinox as eqx
        import optax

        all_states = jnp.arange(n_states)
        if problem.state_encoder is not None:
            state_features = problem.state_encoder(all_states)
        else:
            # No encoder: feed the normalized state index as a 1-d feature.
            denom = max(n_states - 1, 1)
            state_features = (all_states.astype(jnp.float64) / denom)[:, None]
        state_dim = int(state_features.shape[1])

        _hidden = self.config.hidden_dim
        _layers = self.config.num_layers

        class _QNet(eqx.Module):
            mlp: eqx.nn.MLP

            def __init__(self, *, key):
                self.mlp = eqx.nn.MLP(
                    in_size=state_dim,
                    out_size=n_actions,
                    width_size=_hidden,
                    depth=_layers,
                    activation=jax.nn.relu,
                    key=key,
                )

            def __call__(self, x):
                return self.mlp(x)

        key = jax.random.PRNGKey(self.config.seed)
        q_net = _QNet(key=key)

        optimizer = optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.adam(self.config.learning_rate),
        )
        opt_state = optimizer.init(eqx.filter(q_net, eqx.is_array))

        divergence = self.config.divergence

        @eqx.filter_jit
        def train_step(q_net, opt_state):
            def loss_fn(model):
                Q = jax.vmap(model)(state_features)  # (S, A)
                V_star = sigma * jax.scipy.special.logsumexp(Q / sigma, axis=1)
                # Expected next-state value, transposed to (S, A)
                EV = jnp.einsum("ast,t->as", trans_f64, V_star).T
                td = Q - gamma * EV
                Q_expert = Q[expert_states_jax, expert_actions_jax]
                V_expert = V_star[expert_states_jax]
                td_expert = td[expert_states_jax, expert_actions_jax]
                if divergence == "chi2":
                    loss = (
                        -(Q_expert - V_expert).mean()
                        + (1.0 / (4 * alpha)) * (td_expert ** 2).mean()
                    )
                else:
                    loss = -td_expert.mean() + (1 - gamma) * jnp.dot(
                        initial_dist, V_star
                    )
                return loss

            loss, grads = eqx.filter_value_and_grad(loss_fn)(q_net)
            updates, new_opt = optimizer.update(
                grads, opt_state, eqx.filter(q_net, eqx.is_array)
            )
            q_net = eqx.apply_updates(q_net, updates)
            return q_net, new_opt, loss

        loss_history: list[float] = []
        from tqdm import tqdm
        pbar = tqdm(
            range(self.config.max_iter),
            desc="IQ-Learn (neural)",
            disable=not self.config.verbose,
            leave=True,
        )
        for _ in pbar:
            q_net, opt_state, loss = train_step(q_net, opt_state)
            # Convert once: each float() forces a device-to-host transfer.
            loss_value = float(loss)
            loss_history.append(loss_value)
            pbar.set_postfix({"loss": f"{loss_value:.4f}"})

        Q_table = jax.vmap(q_net)(state_features).astype(jnp.float32)
        policy = jax.nn.softmax(Q_table / sigma, axis=1)
        V = sigma * jax.scipy.special.logsumexp(Q_table / sigma, axis=1)
        # Reward via inverse Bellman: r(s,a) = Q(s,a) - gamma * E[V*(s')]
        transitions_f32 = jnp.asarray(transitions, dtype=jnp.float32)
        EV = jnp.einsum("ast,t->as", transitions_f32, V).T
        reward_table = Q_table - gamma * EV
        reward_params, projected_reward_matrix = self._project_reward_to_utility_basis(
            utility,
            n_states,
            n_actions,
            reward_table,
        )

        log_probs = jax.nn.log_softmax(Q_table / sigma, axis=1)
        ll = float(log_probs[expert_states_jax, expert_actions_jax].sum())
        final_objective = loss_history[-1] if loss_history else float("nan")
        # Without a feature basis, fall back to the recovered reward table
        # (not the Q-table): the tabular/linear path does the same and
        # `estimate` labels these entries "R(s,a)".
        parameters = (
            reward_params.astype(jnp.float32)
            if reward_params is not None
            else reward_table.flatten().astype(jnp.float32)
        )

        return EstimationResult(
            parameters=parameters,
            log_likelihood=ll,
            value_function=V,
            policy=policy,
            hessian=None,
            converged=True,
            num_iterations=self.config.max_iter,
            num_function_evals=self.config.max_iter,
            message="Neural Q-head trained",
            optimization_time=time.time() - start_time,
            metadata={
                "q_type": "neural",
                "divergence": self.config.divergence,
                "alpha": self.config.alpha,
                "q_table": np.asarray(Q_table).tolist(),
                "reward_matrix": np.asarray(reward_table).tolist(),
                "reward_table": np.asarray(reward_table).tolist(),
                "raw_bellman_reward_table": np.asarray(reward_table).tolist(),
                "projected_reward_matrix": (
                    np.asarray(projected_reward_matrix).tolist()
                    if projected_reward_matrix is not None
                    else None
                ),
                "reward_params": (
                    np.asarray(reward_params).tolist()
                    if reward_params is not None
                    else None
                ),
                "counterfactual_reward_source": "raw_bellman_reward_table",
                "loss_history": loss_history,
                "final_objective": final_objective,
                "expert_state_coverage": float(expert_state_coverage),
                "expert_state_action_coverage": float(expert_state_action_coverage),
            },
        )

    def _project_reward_to_utility_basis(
        self,
        utility: BaseUtilityFunction,
        n_states: int,
        n_actions: int,
        reward_table: jnp.ndarray,
    ) -> tuple[jnp.ndarray | None, jnp.ndarray | None]:
        """Least-squares projection of a recovered reward table onto truth features."""

        feat = None
        if hasattr(utility, "feature_matrix"):
            feat = jnp.asarray(utility.feature_matrix, dtype=jnp.float32)
        elif isinstance(utility, LinearReward) and hasattr(utility, "state_features"):
            sf = jnp.asarray(utility.state_features, dtype=jnp.float32)
            feat = jnp.broadcast_to(sf[:, None, :], (n_states, n_actions, sf.shape[1]))

        if feat is None:
            return None, None

        phi = feat.reshape(-1, feat.shape[2])
        r_flat = reward_table.flatten()
        phi_aug = jnp.concatenate([phi, jnp.ones((phi.shape[0], 1))], axis=1)
        params_aug = jnp.linalg.lstsq(phi_aug, r_flat, rcond=None)[0]
        reward_params = params_aug[:-1]
        projected_reward_matrix = (phi @ reward_params).reshape(n_states, n_actions)
        return reward_params, projected_reward_matrix

    def _compute_initial_distribution(
        self,
        panel: Panel,
        n_states: int,
    ) -> jnp.ndarray:
        """Compute empirical initial state distribution from data."""
        counts = np.zeros(n_states, dtype=np.float64)
        for traj in panel.trajectories:
            if len(traj) > 0:
                s0 = int(np.asarray(traj.states[0]))
                counts[s0] += 1
        total = counts.sum()
        if total > 0:
            counts = counts / total
        else:
            counts = np.ones(n_states, dtype=np.float64) / n_states
        return jnp.array(counts)

    def _optimize(
        self,
        panel: Panel,
        utility: BaseUtilityFunction,
        problem: DDCProblem,
        transitions: np.ndarray | jnp.ndarray,
        initial_params: np.ndarray | None = None,
        **kwargs,
    ) -> EstimationResult:
        """Run IQ-Learn optimization.

        Learns a soft Q-function by optimizing the IQ-Learn objective,
        then extracts policy and reward via the inverse Bellman operator.
        The "neural" ``q_type`` is delegated to ``_optimize_neural``.

        Args:
            panel: Expert panel providing the observed states and actions.
            utility: Utility specification. For ``q_type="linear"`` it must
                be an ``ActionDependentReward`` or a ``LinearReward``, whose
                features parameterize Q.
            problem: Problem description (state/action counts, discount
                factor, scale parameter).
            transitions: Transition probabilities, indexed
                (action, state, next state).
            initial_params: Optional starting values for the Q parameters;
                zeros are used when omitted.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Estimation result with the policy, value function, expert
            log-likelihood, and metadata (Q-table, recovered reward table,
            projection onto the utility features, coverage diagnostics).
            ``parameters`` holds the Q coefficients for ``q_type="linear"``;
            otherwise the projected reward coefficients when available, else
            the flattened reward table.

        Raises:
            TypeError: If ``q_type="linear"`` and ``utility`` is neither an
                ``ActionDependentReward`` nor a ``LinearReward``.
        """
        start_time = time.time()

        n_states = problem.num_states
        n_actions = problem.num_actions
        gamma = problem.discount_factor
        sigma = problem.scale_parameter
        alpha = self.config.alpha

        # Extract expert (s, a) pairs from panel
        expert_states = np.asarray(panel.get_all_states(), dtype=np.int64)
        expert_actions = np.asarray(panel.get_all_actions(), dtype=np.int64)

        # Convert to JAX for use inside objective
        expert_states_jax = jnp.array(expert_states)
        expert_actions_jax = jnp.array(expert_actions)

        # Initial state distribution (needed for simple divergence)
        initial_dist = self._compute_initial_distribution(panel, n_states)
        # Share of states / (state, action) cells that appear in the data
        expert_state_coverage = len(np.unique(expert_states)) / max(n_states, 1)
        expert_state_action_coverage = len(
            set(zip(expert_states.tolist(), expert_actions.tolist(), strict=False))
        ) / max(n_states * n_actions, 1)

        # Transitions as float64 for numerical precision
        trans_f64 = jnp.asarray(transitions, dtype=jnp.float64)

        # Setup Q parameterization
        if self.config.q_type == "neural":
            return self._optimize_neural(
                panel, utility, problem, transitions,
                expert_states_jax, expert_actions_jax,
                trans_f64, initial_dist, n_states, n_actions,
                sigma, gamma, alpha,
                expert_state_coverage, expert_state_action_coverage,
                start_time,
            )
        if self.config.q_type == "linear":
            if isinstance(utility, ActionDependentReward):
                fm = np.asarray(utility.feature_matrix, dtype=np.float64)
                feature_matrix = jnp.array(fm)
                n_params = feature_matrix.shape[2]
            elif isinstance(utility, LinearReward):
                # State-only features: repeat them for every action
                sf = np.asarray(utility.state_features, dtype=np.float64)
                sf_jax = jnp.array(sf)
                feature_matrix = jnp.broadcast_to(
                    sf_jax[:, None, :], (n_states, n_actions, sf_jax.shape[1])
                )
                n_params = sf_jax.shape[1]
            else:
                raise TypeError(f"Unsupported utility type for linear q_type: {type(utility)}")
        else:
            # Tabular: free Q(s,a) matrix
            feature_matrix = None
            n_params = n_states * n_actions

        if initial_params is not None:
            theta_init = np.asarray(initial_params, dtype=np.float64)
        else:
            theta_init = np.zeros(n_params)

        divergence = self.config.divergence

        def objective(theta):
            """Compute IQ-Learn objective (JAX, differentiable)."""
            # Build Q table
            if feature_matrix is not None:
                Q = jnp.einsum("sak,k->sa", feature_matrix, theta)
            else:
                Q = theta.reshape(n_states, n_actions)

            # V*(s) = sigma * logsumexp(Q(s,:) / sigma)
            V_star = sigma * jax.scipy.special.logsumexp(Q / sigma, axis=1)

            # E_{s'~P(.|s,a)}[V*(s')] = sum_{s'} P(s'|s,a) V*(s')
            # transitions shape: (A, S, S') -> einsum to get (A, S)
            EV = jnp.einsum("ast,t->as", trans_f64, V_star).T  # (S, A)

            # Temporal difference: Q(s,a) - gamma * E[V*(s')]
            td = Q - gamma * EV

            # Expert terms
            Q_expert = Q[expert_states_jax, expert_actions_jax]
            V_expert = V_star[expert_states_jax]
            td_expert = td[expert_states_jax, expert_actions_jax]

            if divergence == "chi2":
                loss = -(Q_expert - V_expert).mean() + (1.0 / (4 * alpha)) * (td_expert**2).mean()
            else:
                loss = -td_expert.mean() + (1 - gamma) * jnp.dot(initial_dist, V_star)

            return loss

        obj_and_grad = jax.value_and_grad(objective)

        def objective_and_gradient(theta_np):
            """Wrapper for the Adam loop: numpy in, numpy out."""
            theta_jax = jnp.array(theta_np)
            loss, grad = obj_and_grad(theta_jax)
            return float(loss), np.asarray(grad, dtype=np.float64)

        # Optimize
        if self.config.optimizer == "L-BFGS-B":
            from econirl.core.optimizer import minimize_lbfgsb
            result_opt = minimize_lbfgsb(
                objective,
                jnp.array(theta_init, dtype=jnp.float64),
                maxiter=self.config.max_iter,
                tol=self.config.convergence_tol,
                verbose=self.config.verbose,
                desc="IQ-Learn L-BFGS-B",
            )
            theta_opt = result_opt.x
            converged = result_opt.success
            num_iterations = result_opt.nit
            num_fevals = result_opt.nfev
            message = result_opt.message
            # Plain float so the metadata matches the other code paths
            final_obj = float(result_opt.fun)
        else:
            # Adam optimizer (manual implementation)
            theta_np = theta_init.copy()
            m = np.zeros_like(theta_np)
            v = np.zeros_like(theta_np)
            lr = self.config.learning_rate
            beta1, beta2, eps = 0.9, 0.999, 1e-8
            loss_history = []
            grad_norm = float("inf")
            # Stays 0 when max_iter <= 0 and the loop body never runs
            num_iterations = 0

            from tqdm import tqdm
            pbar = tqdm(
                range(1, self.config.max_iter + 1),
                desc="IQ-Learn Adam",
                disable=not self.config.verbose,
                leave=True,
            )
            for t in pbar:
                num_iterations = t
                obj, grad_np = objective_and_gradient(theta_np)
                loss_history.append(obj)
                # Compute the norm first so the bar shows this step's value
                grad_norm = float(np.linalg.norm(grad_np))
                pbar.set_postfix({"loss": f"{obj:.4f}", "|g|": f"{grad_norm:.2e}"})

                # Bias-corrected first/second moment update
                m = beta1 * m + (1 - beta1) * grad_np
                v = beta2 * v + (1 - beta2) * grad_np**2
                m_hat = m / (1 - beta1**t)
                v_hat = v / (1 - beta2**t)
                theta_np = theta_np - lr * m_hat / (np.sqrt(v_hat) + eps)

                if grad_norm < self.config.convergence_tol:
                    break

            theta_opt = jnp.array(theta_np)
            converged = grad_norm < self.config.convergence_tol
            num_fevals = num_iterations
            message = "Converged" if converged else "Max iterations reached"
            final_obj = loss_history[-1] if loss_history else float("nan")

        # Extract results from optimal Q
        if feature_matrix is not None:
            Q_table = jnp.einsum("sak,k->sa", feature_matrix, theta_opt).astype(jnp.float32)
        else:
            Q_table = theta_opt.reshape(n_states, n_actions).astype(jnp.float32)

        # Policy: softmax(Q/sigma)
        policy = jax.nn.softmax(Q_table / sigma, axis=1)

        # Value function: V*(s) = sigma * logsumexp(Q/sigma)
        V = sigma * jax.scipy.special.logsumexp(Q_table / sigma, axis=1)

        # Reward via inverse Bellman: r(s,a) = Q(s,a) - gamma * E[V*(s')]
        transitions_f32 = jnp.asarray(transitions, dtype=jnp.float32)
        EV = jnp.einsum("ast,t->as", transitions_f32, V).T
        reward_table = Q_table - gamma * EV

        # Log-likelihood of the expert choices under the learned policy
        log_probs = jax.nn.log_softmax(Q_table / sigma, axis=1)
        ll = float(log_probs[expert_states_jax, expert_actions_jax].sum())

        # Project reward onto feature space for structural parameters.
        reward_params, projected_reward_matrix = self._project_reward_to_utility_basis(
            utility,
            n_states,
            n_actions,
            reward_table,
        )

        # Parameters to return
        if self.config.q_type == "linear":
            parameters = theta_opt.astype(jnp.float32)
        elif reward_params is not None:
            parameters = reward_params.astype(jnp.float32)
        else:
            parameters = reward_table.flatten().astype(jnp.float32)

        optimization_time = time.time() - start_time

        return EstimationResult(
            parameters=parameters,
            log_likelihood=ll,
            value_function=V,
            policy=policy,
            hessian=None,
            converged=converged,
            num_iterations=num_iterations,
            num_function_evals=num_fevals,
            message=message,
            optimization_time=optimization_time,
            metadata={
                "q_type": self.config.q_type,
                "divergence": self.config.divergence,
                "alpha": self.config.alpha,
                "q_table": np.asarray(Q_table).tolist(),
                "reward_matrix": np.asarray(reward_table).tolist(),
                "reward_table": np.asarray(reward_table).tolist(),
                "raw_bellman_reward_table": np.asarray(reward_table).tolist(),
                "projected_reward_matrix": (
                    np.asarray(projected_reward_matrix).tolist()
                    if projected_reward_matrix is not None
                    else None
                ),
                "reward_params": (
                    np.asarray(reward_params).tolist()
                    if reward_params is not None
                    else None
                ),
                "counterfactual_reward_source": "raw_bellman_reward_table",
                "final_objective": final_obj,
                "expert_state_coverage": float(expert_state_coverage),
                "expert_state_action_coverage": float(expert_state_action_coverage),
            },
        )