# Multi-Layer Perceptron (MLP) actors.

from os import PathLike
from .base_actor import BaseActor, BaseContinuousDeterministicActor
from utils.types import *
from utils.noise import OrnsteinUhlenbeckNoise
from models.torch_models import get_activation
import numpy as np
import torch


class BaseMlpActor(BaseActor):
    """
    Abstract base Multi-Layer Perceptron (MLP) actor.

    Builds ``self.model`` as a ``torch.nn.Sequential`` made of one
    ``Linear -> [LayerNorm] -> hidden activation`` group per hidden layer,
    followed by ``Linear -> output activation``. Also provides flat-vector
    (de)serialization of the parameters and Polyak-style soft updates
    towards a sibling network.
    """
    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_layer_sizes: list[int],
        hidden_activation: str,
        output_activation: str,
        device: torch.device | str,
        layer_norm: bool,
        seed: int | None
    ) -> None:
        """
        Initialize the MLP actor.

        Args:
            input_size: Number of input features.
            output_size: Number of outputs of the final linear layer.
            hidden_layer_sizes: Width of each hidden layer, in order (may be empty).
                Any sequence of ints is accepted; a private copy is stored.
            hidden_activation: Name passed to ``get_activation`` for hidden layers.
            output_activation: Name passed to ``get_activation`` for the output layer.
            device: Device on which the layers are created.
            layer_norm: If True, a ``LayerNorm`` follows every hidden ``Linear``.
            seed: Seed forwarded to ``set_seed``.
        """
        super().__init__()

        # CPU generator used for all stochastic sampling done by the actor.
        self.rng = torch.Generator(device="cpu")
        self.set_seed(seed)

        # Copy so later mutation of the caller's list cannot desynchronize the
        # stored architecture (used by ``create_sibling``) from the real model.
        hidden_layer_sizes = list(hidden_layer_sizes)

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_layer_sizes = hidden_layer_sizes
        self.hidden_activation = get_activation(hidden_activation)
        self._hidden_activation_name = hidden_activation
        self.output_activation = get_activation(output_activation)
        self._output_activation_name = output_activation
        self._device = torch.device(device) if isinstance(device, str) else device
        self.layer_norm = layer_norm
        self.seed = seed

        self.model = torch.nn.Sequential()

        layer_sizes = [input_size] + hidden_layer_sizes

        # Inner layers.
        # NOTE(review): the same activation instance is registered after every
        # hidden layer, so any parameters it might have would be shared.
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            self.model.add_module(f"layer_{i}", torch.nn.Linear(in_size, out_size, device=self.device))
            if layer_norm:
                self.model.add_module(f"layer_norm_{i}", torch.nn.LayerNorm(out_size, device=self.device))
            self.model.add_module(f"activation_{i}", self.hidden_activation)

        # Output layer.
        self.model.add_module("output_layer", torch.nn.Linear(layer_sizes[-1], output_size, device=self.device))
        self.model.add_module("output_activation", self.output_activation)


    def vectorize(self) -> np.ndarray:
        """
        Return all parameters concatenated into a flat numpy array, in
        ``self.parameters()`` order.
        """
        return torch.cat([p.detach().flatten() for p in self.parameters()]).cpu().numpy()


    def set_parameters(self, parameters: np.ndarray) -> None:
        """
        Overwrite all parameters from a flat vector laid out like ``vectorize``.

        Values are copied into the existing parameter storage. The model
        therefore never aliases ``parameters``, and later in-place changes to
        the array do not leak into the network on any device.

        Raises:
            ValueError: If ``parameters`` does not hold exactly as many values
                as the model has parameters.
        """
        flat = torch.as_tensor(np.ascontiguousarray(parameters)).reshape(-1)
        params = list(self.parameters())
        expected = sum(p.numel() for p in params)
        if flat.numel() != expected:
            raise ValueError(f"Expected {expected} parameter values, got {flat.numel()}.")

        flat = flat.to(device=self.device, dtype=self.dtype)
        with torch.no_grad():
            offset = 0
            for p in params:
                count = p.numel()
                p.copy_(flat[offset:offset + count].view_as(p))
                offset += count


    def soft_update_towards_sibling(self, sibling: TorchBasedModel, strength_of_update: float) -> None:
        """
        Polyak-average parameters in place: ``p <- (1 - tau) * p + tau * p_sibling``.

        Args:
            sibling: Model with the same parameter layout as ``self``.
            strength_of_update: Interpolation factor ``tau``.

        Raises:
            ValueError: If the two models have a different number of parameter tensors.
        """
        with torch.no_grad():
            for p_self, p_other in zip(self.parameters(), sibling.parameters(), strict=True):
                p_self.copy_((1 - strength_of_update) * p_self + strength_of_update * p_other)


class DiscreteMlpActor(BaseMlpActor):
    """
    Discrete Multi-Layer Perceptron (MLP) actor.

    The network uses a "Softmax" output activation, and actions are sampled
    from its output with ``torch.multinomial``. When exploring, an extra
    epsilon-greedy step returns a uniformly random action with probability
    ``random_action_probability_when_exploring``.
    """
    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_layer_sizes: list[int],
        hidden_activation: str = "ReLU",
        device: torch.device | str = "cpu",
        layer_norm: bool = False,
        random_action_probability_when_exploring: float = 0.01,
        seed: int | None = None
    ) -> None:
        super().__init__(input_size, output_size, hidden_layer_sizes, hidden_activation, "Softmax", device, layer_norm, seed)
        self.epsilon_greedy = random_action_probability_when_exploring
        self.np_rng = np.random.default_rng(seed)


    def act(self, observation: np.ndarray, explore: bool = False) -> int:
        """
        Choose an action for a single observation.

        Args:
            observation: One observation with ``input_size`` elements.
            explore: If True, with probability ``epsilon_greedy`` return a
                uniformly random action instead of sampling from the policy.

        Returns:
            The action index as a Python ``int``.
        """
        if explore and self.np_rng.random() < self.epsilon_greedy:
            # Generator.integers returns a NumPy integer; honor the ``int`` contract.
            return int(self.np_rng.integers(0, self.output_size))

        with torch.no_grad():
            # Cast to the model dtype so e.g. float64 observations don't clash
            # with float32 weights.
            observation_torch = torch.tensor(observation, dtype=self.dtype, device=self.device).view(1, self.input_size)
            return int(self.act_batch(observation_torch).detach().cpu().item())


    def act_batch(self, observations: torch.Tensor) -> torch.Tensor:
        """
        Sample one action per row of ``observations``.

        Returns:
            A 1-D tensor of action indices on the same device as the network output.
        """
        probabilities = self.forward(observations.to(self.device).view(-1, self.input_size))
        # torch.multinomial requires its generator to be on the input's device,
        # and self.rng is a CPU generator. Sample there, then move the result back.
        sampling_input = probabilities
        if probabilities.device != self.rng.device:
            sampling_input = probabilities.to(self.rng.device)
        actions = torch.multinomial(sampling_input, num_samples=1, generator=self.rng).view(-1)
        return actions.to(probabilities.device)


    def set_exploration_parameters(self, *, random_action_probability_when_exploring: float) -> None:
        """Update the epsilon-greedy probability used when exploring."""
        self.epsilon_greedy = random_action_probability_when_exploring


    def set_seed(self, seed: int | None) -> None:
        """Seed the base RNG state and the NumPy generator used for epsilon-greedy."""
        super().set_seed(seed)
        self.np_rng = np.random.default_rng(seed)


    def create_sibling(self, seed: int | None = None) -> "DiscreteMlpActor":
        """Create a freshly initialized actor with the same configuration."""
        return DiscreteMlpActor(
            input_size=self.input_size,
            output_size=self.output_size,
            hidden_layer_sizes=list(self.hidden_layer_sizes),
            hidden_activation=self._hidden_activation_name,
            device=self.device,
            layer_norm=self.layer_norm,
            random_action_probability_when_exploring=self.epsilon_greedy,
            seed=seed
        )


class ContinuousMlpActor(BaseMlpActor, BaseContinuousDeterministicActor):
    """
    Continuous Multi-Layer Perceptron (MLP) actor.

    When both ``low_constraints`` and ``high_constraints`` are given, the
    network ends in ``Tanh`` and its output is rescaled from [-1, 1] to
    [low, high]. Otherwise the output is linear and is clamped to whichever
    single bound was provided, or returned as-is when there is none.
    Exploration adds Gaussian noise scaled by ``exploration_noise_std`` and
    clamps the result to the available bounds.
    """
    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_layer_sizes: list[int],
        hidden_activation: str = "ReLU",
        low_constraints: torch.Tensor | np.ndarray | float | None = None,
        high_constraints: torch.Tensor | np.ndarray | float | None = None,
        device: torch.device | str = "cpu",
        layer_norm: bool = False,
        exploration_noise_std: float = 0.2,
        seed: int | None = None
    ) -> None:
        bounded = low_constraints is not None and high_constraints is not None
        output_activation = "Tanh" if bounded else "Linear"
        super().__init__(input_size, output_size, hidden_layer_sizes, hidden_activation, output_activation, device, layer_norm, seed)

        self._constraints = Constraints(
            lower=self._to_constraint_tensor(low_constraints),
            upper=self._to_constraint_tensor(high_constraints),
        )

        self.exploration_noise_std = exploration_noise_std


    def _to_constraint_tensor(
        self, bound: torch.Tensor | np.ndarray | float | None
    ) -> torch.Tensor | None:
        """
        Convert a bound to a tensor with the actor's device and dtype; ``None``
        passes through.

        Casting to the model dtype keeps e.g. float64 numpy bounds from
        promoting the actions to float64.
        """
        if bound is None:
            return None
        if isinstance(bound, torch.Tensor):
            return bound.to(device=self.device, dtype=self.dtype)
        return torch.tensor(bound, device=self.device, dtype=self.dtype)


    def _get_exploratory_noise(self) -> torch.Tensor:
        """Gaussian noise of shape ``(output_size,)`` on the CPU, drawn from ``self.rng``."""
        return self.exploration_noise_std * torch.randn(self.output_size, device="cpu", generator=self.rng)


    def act(self, observation: np.ndarray, explore: bool = False) -> np.ndarray:
        """
        Compute the action for a single observation.

        Args:
            observation: One observation with ``input_size`` elements.
            explore: If True, add exploratory noise and clamp to the bounds.

        Returns:
            The action as a numpy array of shape ``(output_size,)``.
        """
        with torch.no_grad():
            observation_torch = torch.tensor(observation, dtype=self.dtype, device=self.device).view(1, self.input_size)
            action = self.act_batch(observation_torch).detach().view(self.output_size).cpu()

            if explore:
                action = action + self._get_exploratory_noise().view_as(action)
                constraints = self.constraints
                # torch.clamp raises when both bounds are None, so only clamp
                # when at least one bound exists.
                if constraints.lower is not None or constraints.upper is not None:
                    action = action.clamp(min=constraints.lower_cpu, max=constraints.upper_cpu)

            return action.numpy()


    def act_batch(self, observations: torch.Tensor) -> torch.Tensor:
        """
        Compute deterministic actions for a batch of observations.

        Returns:
            A tensor of shape ``(batch, output_size)``.
        """
        low, high = self.constraints.lower, self.constraints.upper
        unscaled_action = self.forward(observations.to(self.device).view(-1, self.input_size))

        if low is not None and high is not None:
            # Tanh output in [-1, 1] -> [low, high].
            return (high - low) * (unscaled_action + 1) / 2 + low
        if low is None and high is None:
            # Unconstrained: torch.clamp would raise with both bounds None.
            return unscaled_action
        return unscaled_action.clamp(min=low, max=high)


    def create_sibling(self, seed: int | None = None) -> "ContinuousMlpActor":
        """Create a freshly initialized actor with the same configuration."""
        return ContinuousMlpActor(
            input_size=self.input_size,
            output_size=self.output_size,
            hidden_layer_sizes=list(self.hidden_layer_sizes),
            hidden_activation=self._hidden_activation_name,
            low_constraints=self.constraints.lower,
            high_constraints=self.constraints.upper,
            device=self.device,
            layer_norm=self.layer_norm,
            exploration_noise_std=self.exploration_noise_std,
            seed=seed
        )


class ContinuousMlpActorWithOUNoise(ContinuousMlpActor):
    """
    Continuous MLP actor that explores with Ornstein-Uhlenbeck noise.

    Exploration noise comes from an ``OrnsteinUhlenbeckNoise`` process
    (mean 0, scale ``exploration_noise_std``, mean-reversion
    ``exploration_noise_theta``) that lives on the CPU and draws from
    ``self.rng``, replacing the parent's Gaussian noise.
    """
    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_layer_sizes: list[int],
        hidden_activation: str = "ReLU",
        low_constraints: torch.Tensor | np.ndarray | float | None = None,
        high_constraints: torch.Tensor | np.ndarray | float | None = None,
        device: torch.device | str = "cpu",
        layer_norm: bool = False,
        exploration_noise_std: float = 0.2,
        exploration_noise_theta: float = 0.15,
        seed: int | None = None
    ) -> None:
        parent_kwargs = dict(
            input_size=input_size,
            output_size=output_size,
            hidden_layer_sizes=hidden_layer_sizes,
            hidden_activation=hidden_activation,
            low_constraints=low_constraints,
            high_constraints=high_constraints,
            device=device,
            layer_norm=layer_norm,
            exploration_noise_std=exploration_noise_std,
            seed=seed,
        )
        super().__init__(**parent_kwargs)

        self.exploration_noise_theta = exploration_noise_theta
        self.noise_generator = self._make_noise_generator()


    def _make_noise_generator(self) -> OrnsteinUhlenbeckNoise:
        """Build a fresh OU process from the actor's current settings and RNG."""
        return OrnsteinUhlenbeckNoise(
            shape=self.output_size,
            mu=0.,
            sigma=self.exploration_noise_std,
            theta=self.exploration_noise_theta,
            device="cpu",
            rng=self.rng
        )


    def reset(self) -> None:
        """Reset the OU process, then defer to the parent reset."""
        self.noise_generator.reset()
        super().reset()


    def set_exploration_parameters(self, *, exploration_noise_std: float) -> None:
        """Update the exploration scale in the parent and in the live OU process."""
        super().set_exploration_parameters(exploration_noise_std=exploration_noise_std)
        self.noise_generator.sigma = exploration_noise_std


    def _get_exploratory_noise(self) -> torch.Tensor:
        """Draw the next sample of the OU process."""
        sample = self.noise_generator.sample()
        return sample


    def create_sibling(self, seed: int | None = None) -> "ContinuousMlpActorWithOUNoise":
        """Create a freshly initialized actor with the same configuration."""
        sibling_kwargs = dict(
            input_size=self.input_size,
            output_size=self.output_size,
            hidden_layer_sizes=self.hidden_layer_sizes,
            hidden_activation=self._hidden_activation_name,
            low_constraints=self.constraints.lower,
            high_constraints=self.constraints.upper,
            device=self.device,
            layer_norm=self.layer_norm,
            exploration_noise_std=self.exploration_noise_std,
            exploration_noise_theta=self.exploration_noise_theta,
            seed=seed,
        )
        return ContinuousMlpActorWithOUNoise(**sibling_kwargs)

    def load(self, path: str | PathLike, device: torch.device | str = "cpu") -> None:
        """
        Load model parameters from a file. The model must have the same architecture
        as the one used to save the parameters and it should be on the same device.

        After loading, the OU process is rebuilt from the actor's current
        settings and ``self.rng``, keeping its previous ``state``.
        """
        super().load(path, device)

        # NOTE(review): presumably rebuilt because super().load() may replace
        # self.rng / exploration settings; the process state is carried over.
        carried_state = self.noise_generator.state
        rebuilt = self._make_noise_generator()
        rebuilt.state = carried_state
        self.noise_generator = rebuilt
