# Multi-Layer Perceptron (MLP) critics.

from .base_critic import BaseStateCritic, BaseStateActionCritic
from utils.types import *
from models.torch_models import get_activation, BaseTorchModel
import numpy as np
import torch


class BaseMlpCritic(BaseTorchModel):
    """
    Abstract base class for Multi-Layer Perceptron (MLP) critics.

    Builds ``self.model``, a ``torch.nn.Sequential`` consisting of one ``Linear``
    layer per entry of ``hidden_layer_sizes`` (each optionally followed by a
    ``LayerNorm`` and then by the hidden activation) and a final ``Linear`` layer
    with a single output unit.
    """
    def __init__(
        self,
        input_size: int,
        hidden_layer_sizes: list[int] | tuple[int, ...],
        hidden_activation: str = "ReLU",
        device: torch.device | str = "cpu",
        layer_norm: bool = False,
        seed: int | None = None
    ) -> None:
        """
        Initialize the MLP critic.

        Args:
            input_size: Number of input features of the first linear layer.
            hidden_layer_sizes: Output size of each hidden linear layer, in order.
                May be empty, in which case the network is a single linear layer
                mapping ``input_size`` to 1. The sequence is copied, so later
                changes to the caller's object do not affect this instance.
            hidden_activation: Activation name, resolved through ``get_activation``.
            device: Device (or device string) on which the layers are created.
            layer_norm: If True, a ``LayerNorm`` is inserted after every hidden
                linear layer, before the activation.
            seed: Value forwarded to ``self.set_seed``.
        """
        super().__init__()

        # NOTE(review): set_seed is inherited; presumably it seeds the RNG used for
        # weight initialization, so it is kept ahead of any layer construction.
        self.set_seed(seed)

        self.input_size = input_size
        # Private copy: avoids aliasing the caller's list (and the list of any
        # sibling built from this instance) and lets tuples be passed as well.
        self.hidden_layer_sizes = list(hidden_layer_sizes)
        self.hidden_activation = get_activation(hidden_activation)
        # The name is kept so that subclasses can rebuild an identical architecture.
        self._hidden_activation_name = hidden_activation
        self._device = torch.device(device) if isinstance(device, str) else device
        self.layer_norm = layer_norm

        self.model = torch.nn.Sequential()

        layer_sizes = [input_size, *self.hidden_layer_sizes]

        # Inner layers
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            self.model.add_module(f"layer_{i}", torch.nn.Linear(in_size, out_size, device=self.device))
            if layer_norm:
                self.model.add_module(f"layer_norm_{i}", torch.nn.LayerNorm(out_size, device=self.device))
            # The same activation object is registered after every hidden layer.
            self.model.add_module(f"activation_{i}", self.hidden_activation)

        # Output layer: a single unit
        self.model.add_module("output_layer", torch.nn.Linear(layer_sizes[-1], 1, device=self.device))


    def soft_update_towards_sibling(self: TorchBasedModel, sibling: TorchBasedModel, strength_of_update: float) -> None:
        """
        Move this model's parameters towards those of ``sibling``, in place.

        Every parameter is replaced by
        ``(1 - strength_of_update) * own + strength_of_update * sibling's``.
        Parameters are paired in the order yielded by ``parameters()``; only
        parameters are iterated. The update runs under ``torch.no_grad()``.

        Args:
            sibling: Model whose parameters are blended into this one.
            strength_of_update: Weight given to the sibling's parameters;
                0 leaves this model unchanged, 1 copies the sibling's values.
        """
        with torch.no_grad():
            for p_self, p_other in zip(self.parameters(), sibling.parameters()):
                p_self.copy_((1 - strength_of_update) * p_self + strength_of_update * p_other)
    
    
class MlpStateCritic(BaseMlpCritic, BaseStateCritic):
    """
    MLP critic for the state (observation) values.
    """
    def __init__(
        self,
        observation_size: int,
        hidden_layer_sizes: list[int],
        hidden_activation: str = "ReLU",
        device: torch.device | str = "cpu",
        layer_norm: bool = False,
        seed: int | None = None
    ) -> None:
        """
        Initialize the state critic; the network input size is ``observation_size``.

        All arguments other than ``observation_size`` are forwarded unchanged to
        ``BaseMlpCritic.__init__``.
        """
        super().__init__(observation_size, hidden_layer_sizes, hidden_activation, device, layer_norm, seed)
        self.observation_size = observation_size


    def create_sibling(self, seed: int | None = None) -> "MlpStateCritic":
        """
        Create a new critic with the same architecture and device as this one.

        The sibling is constructed from scratch with the given ``seed``; no
        parameter values are copied from this instance.
        """
        return MlpStateCritic(
            observation_size=self.observation_size,
            hidden_layer_sizes=self.hidden_layer_sizes,
            hidden_activation=self._hidden_activation_name,
            device=self.device,
            layer_norm=self.layer_norm,
            seed=seed
        )


    def evaluate(self, observation: np.ndarray) -> float:
        """
        Evaluate a single observation without tracking gradients.

        Args:
            observation: Array holding exactly ``observation_size`` elements.

        Returns:
            The critic's output for the observation as a Python float.
        """
        with torch.no_grad():
            # NumPy arrays are float64 by default (and may be integer typed), whereas
            # the linear layers hold floating-point parameters of the model's dtype.
            # Building the tensor directly in the parameter dtype avoids a dtype
            # mismatch error in the forward pass.
            param_dtype = next(self.parameters()).dtype
            observation_torch = torch.tensor(
                observation, dtype=param_dtype, device=self.device
            ).view(1, self.observation_size)
            # .item() already transfers the scalar to the host.
            return float(self.evaluate_batch(observation_torch).item())

    def evaluate_batch(self, observation_batch: torch.Tensor) -> torch.Tensor:
        """
        Evaluate a batch of observations.

        The batch is moved to the model's device and reshaped to
        ``(-1, observation_size)`` before being passed to ``self.forward``.

        Returns:
            The result of ``self.forward`` on the reshaped batch
            (presumably one value per row -- confirm against ``forward``).
        """
        observation_batch = observation_batch.to(self.device).view(-1, self.observation_size)
        return self.forward(observation_batch)
    
    
class MlpStateActionCritic(BaseMlpCritic, BaseStateActionCritic):
    """
    MLP critic for the state-action (observation-action) values.
    """
    def __init__(
        self,
        observation_size: int,
        action_size: int,
        hidden_layer_sizes: list[int],
        hidden_activation: str = "ReLU",
        device: torch.device | str = "cpu",
        layer_norm: bool = False,
        seed: int | None = None
    ) -> None:
        """
        Initialize the state-action critic.

        The network input size is ``observation_size + action_size``; the
        remaining arguments are forwarded unchanged to ``BaseMlpCritic.__init__``.
        """
        super().__init__(observation_size + action_size, hidden_layer_sizes, hidden_activation, device, layer_norm, seed)
        self.observation_size = observation_size
        self.action_size = action_size


    def create_sibling(self, seed: int | None = None) -> "MlpStateActionCritic":
        """
        Create a new critic with the same architecture and device as this one.

        The sibling is constructed from scratch with the given ``seed``; no
        parameter values are copied from this instance.
        """
        return MlpStateActionCritic(
            observation_size=self.observation_size,
            action_size=self.action_size,
            hidden_layer_sizes=self.hidden_layer_sizes,
            hidden_activation=self._hidden_activation_name,
            device=self.device,
            layer_norm=self.layer_norm,
            seed=seed
        )


    def evaluate(self, observation: np.ndarray, action: np.ndarray) -> float:
        """
        Evaluate a single observation-action pair without tracking gradients.

        Args:
            observation: Array holding exactly ``observation_size`` elements.
            action: Array holding exactly ``action_size`` elements.

        Returns:
            The critic's output for the pair as a Python float.
        """
        with torch.no_grad():
            # NumPy arrays are float64 by default and action arrays may be integer
            # typed, whereas the linear layers hold floating-point parameters of the
            # model's dtype. Building both tensors in the parameter dtype avoids a
            # dtype mismatch error in the forward pass.
            param_dtype = next(self.parameters()).dtype
            observation_torch = torch.tensor(
                observation, dtype=param_dtype, device=self.device
            ).view(1, self.observation_size)
            action_torch = torch.tensor(
                action, dtype=param_dtype, device=self.device
            ).view(1, self.action_size)
            # .item() already transfers the scalar to the host.
            return float(self.evaluate_batch(observation_torch, action_torch).item())


    def evaluate_batch(self, observation_batch: torch.Tensor, action_batch: torch.Tensor) -> torch.Tensor:
        """
        Evaluate a batch of observation-action pairs.

        Both batches are moved to the model's device, reshaped to
        ``(-1, observation_size)`` and ``(-1, action_size)`` respectively, and
        concatenated along dimension 1 before being passed to ``self.forward``.

        Returns:
            The result of ``self.forward`` on the concatenated batch
            (presumably one value per row -- confirm against ``forward``).
        """
        observation_batch = observation_batch.to(self.device).view(-1, self.observation_size)
        action_batch = action_batch.to(self.device).view(-1, self.action_size)
        return self.forward(torch.cat((observation_batch, action_batch), dim=1))
