# Utilities and base class for PyTorch-based models

from os import PathLike
import copy
from abc import ABC, abstractmethod
from typing import TypeVar
import torch


def get_activation(name: str) -> torch.nn.Module:
    """
    Return a new activation module for the given (case-insensitive) name.

    Supported names: "relu", "tanh", "sigmoid", "leakyrelu", "elu", "selu",
    "gelu", "softmax" (applied over the last dimension) and "linear" (identity).

    Raises:
        KeyError: If the name is not one of the supported activations.
    """
    # Map names to factories (not instances), so that only the requested module gets constructed
    factories = {
        "relu": torch.nn.ReLU,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "selu": torch.nn.SELU,
        "gelu": torch.nn.GELU,
        "softmax": lambda: torch.nn.Softmax(dim=-1),
        "linear": torch.nn.Identity,
    }

    try:
        factory = factories[name.lower()]
    except KeyError:
        supported = ", ".join(sorted(factories))
        raise KeyError(f"Unknown activation {name!r}; supported activations: {supported}") from None

    return factory()


def get_optimizer_class(name: str) -> type[torch.optim.Optimizer]:
    """
    Return the optimizer class (not an instance) for the given (case-insensitive) name.

    Supported names: "adam", "adamw", "sgd", "rmsprop", "adagrad" and "adamax".

    Raises:
        KeyError: If the name is not one of the supported optimizers.
    """
    optimizers = {
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
        "sgd": torch.optim.SGD,
        "rmsprop": torch.optim.RMSprop,
        "adagrad": torch.optim.Adagrad,
        "adamax": torch.optim.Adamax,
    }

    try:
        return optimizers[name.lower()]
    except KeyError:
        supported = ", ".join(sorted(optimizers))
        raise KeyError(f"Unknown optimizer {name!r}; supported optimizers: {supported}") from None


TorchBasedModel = TypeVar("TorchBasedModel", bound="BaseTorchModel")


class BaseTorchModel(ABC, torch.nn.Module):
    """
    Abstract base class for PyTorch models (for actors, critics, etc.).
    """
    model: torch.nn.Module # The PyTorch model providing the main calculation the module is based on
    rng: torch.Generator | None = None # Random number generator for the cases when the model uses some stochasticity (not initialization or training, just behavioral stochasticity, like actor exploration)

    @property
    def device(self) -> torch.device:
        """
        Return the device on which the model is located.
        """
        param = next(self.parameters(), None)
        if param is None:
            if hasattr(self, "_device"):
                if isinstance(self._device, (str, int, torch.device)):
                    return torch.device(self._device)
                
            return torch.device('cpu')

        return param.device

    @property
    def dtype(self) -> torch.dtype:
        """
        Return the dtype used by the model's tensors.
        Falls back to any buffer's dtype, then to the global default dtype.
        """
        param = next(self.parameters(), None)
        if param is not None:
            return param.dtype

        buffer = next(self.buffers(), None)
        if buffer is not None:
            return buffer.dtype

        return torch.get_default_dtype()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the model.
        """
        return self.model(x.to(device=self.device, dtype=self.dtype))

    def set_seed(self, seed: int | None) -> None:
        """
        Set the random seed for the model.
        """
        if seed is not None:
            torch.manual_seed(seed)
            if self.rng is not None:
                self.rng.manual_seed(seed)
    
    @abstractmethod
    def create_sibling(self: TorchBasedModel, seed: int | None = None) -> TorchBasedModel:
        """
        Create a sibling model based on the current model (possibly based on the given seed).

        The sibling model has the same structure but different parameters, possibly sharing some parts with the current model.
        """
        pass
    
    def copy(self: TorchBasedModel) -> TorchBasedModel:
        """
        Create a copy of the current model with the same parameters.
        """
        memo = {}
        
        # Specifically deal with the torch.Generator possibly saved in self.rng and possibly referenced
        # in submodules, because torch.Generator cannot be pickled and therefore cannot be deepcopied.
        if self.rng is not None:
            new_rng = torch.Generator(device=self.rng.device)
            new_rng.manual_seed(self.rng.initial_seed())
            new_rng.set_state(self.rng.get_state())
            memo[id(self.rng)] = new_rng

        return copy.deepcopy(self, memo)

    @abstractmethod
    def soft_update_towards_sibling(self: TorchBasedModel, sibling: TorchBasedModel, strength_of_update: float) -> None:
        """
        Soft update the parameters towards another model (which should be a sibling).

        The strength_of_update parameter controls how much the parameters are updated towards the sibling's parameters.
        A value of 1.0 means a full update (copy of the sibling), while a value of 0.0 means no update at all.
        """
        pass
    
    def copy_parameters_from_sibling(self: TorchBasedModel, sibling: TorchBasedModel) -> None:
        """
        Copy the parameters from another model (which should be a sibling).
        """
        self.soft_update_towards_sibling(sibling, 1.0)

    def save(self, path: str | PathLike) -> None:
        """
        Save the model parameters and all attributes to a file.
        """
        model_to_save = self
        
        # NOTE: The rng attribute should be deleted from what is to be saved, since torch.Generator cannot
        # be pickled, hence it cannot be saved directly - and it doesn't even need to be saved. Hovewer,
        # if the rng is also referenced somewhere else deeper in the model (e.g. inside its attributes),
        # then the loading needs to be adjusted as well, because all the references to the rng will be None.
        if self.rng is not None:
            model_to_save = copy.deepcopy(model_to_save, {id(self.rng): None})

        # The torch module related stuff needs to be excluded, since those will be saved via the state_dict method
        exclude = set(model_to_save._parameters.keys()) | set(model_to_save._buffers.keys()) | set(model_to_save._modules.keys())

        custom_attributes = {k: v for k, v in model_to_save.__dict__.items() if k not in exclude}
        data = {
            "state_dict": model_to_save.state_dict(),
            "attributes": custom_attributes,
        }
        torch.save(data, path)

    def load(self, path: str | PathLike, device: torch.device | str = "cpu") -> None:
        """
        Load model parameters from a file. The model must have the same architecture
        as the one used to save the parameters and it should be on the same device.
        """
        
        # NOTE: In case the model has any attributes that depend on the self.rng,
        # those will need to be recreated and linked manually, or this function
        # needs to be overridden, because torch.Generator is not serializable /
        # picklable, and thus cannot be saved or loaded.

        # Carry over the RNG, if it exists, since in the saved data, only None is present for this attribute
        rng = self.rng
        
        data = torch.load(path, map_location=device, weights_only=False)
        self.__dict__.update(data["attributes"])
        self.load_state_dict(data["state_dict"])
        
        self.rng = rng
