# Abstract base classes for classical Reinforcement Learning (RL) algorithms.

from abc import ABC, abstractmethod
from utils.types import *
from models.actors.base_actor import BaseActor
from replay_buffers.base_replay_buffer import BaseReplayBuffer
from envs.base_env import BaseEnvironment


class BaseRL(ABC):
    """
    Abstract base class for classical Reinforcement Learning (RL) algorithms.

    Concrete algorithms must implement `set_seed`, `update_actor_based_on_ea` and `training_step`.
    The default `training_iteration` first collects rollouts into the replay buffer and then
    performs a single call to `training_step`.
    """
    def __init__(
        self,
        environment: BaseEnvironment,
        actor: BaseActor,
        replay_buffer: BaseReplayBuffer,
        *,
        maximum_number_of_rollouts_per_training_iteration: int | None = 1,
        maximum_number_of_steps_per_training_iteration: int | None = None,
        seed: int | None = None,
        **hyperparameters: Any
    ) -> None:
        """
        Initialize the shared state of an RL algorithm.

        The given actor is the one trained by the algorithm. Transitions collected during rollouts
        are stored in the given replay buffer for later training; the buffer may be shared with an
        evolutionary algorithm (EA) and thus also receive transitions from outside sources.

        Hyperparameters:
        - `maximum_number_of_rollouts_per_training_iteration` (int | None): Upper bound on the number
        of rollouts performed in one training iteration; None means unbounded. (default: 1)
        - `maximum_number_of_steps_per_training_iteration` (int | None): Upper bound on the number of
        environment steps taken in one training iteration; None means unbounded. (default: None)
        - `seed` (int | None): Random seed for reproducibility. (default: None)

        Any further keyword arguments are kept verbatim in the `hyperparameters` dictionary.
        At least one of the two iteration limits must be given, and each given limit must be positive
        (checked with `assert`).
        """
        self._environment = environment
        self._actor = actor
        self._replay_buffer = replay_buffer

        rollout_cap = maximum_number_of_rollouts_per_training_iteration
        step_cap = maximum_number_of_steps_per_training_iteration

        # Without any limit a training iteration would never finish.
        assert not (rollout_cap is None and step_cap is None), (
            "At least one of the two constraints (`maximum_number_of_rollouts_per_training_iteration` or "
            "`maximum_number_of_steps_per_training_iteration`) on the length of the training iteration must be specified!"
        )
        assert rollout_cap is None or rollout_cap > 0, \
            "The maximum number of rollouts per training iteration must be positive or None!"
        assert step_cap is None or step_cap > 0, \
            "The maximum number of steps per training iteration must be positive or None!"

        self.maximum_number_of_rollouts_per_training_iteration = rollout_cap
        self.maximum_number_of_steps_per_training_iteration = step_cap
        self._seed = seed

        # Record the explicitly named hyperparameters next to the extra ones passed by subclasses.
        hyperparameters["maximum_number_of_rollouts_per_training_iteration"] = rollout_cap
        hyperparameters["maximum_number_of_steps_per_training_iteration"] = step_cap
        hyperparameters["seed"] = seed
        self._hyperparameters = hyperparameters

    @property
    def hyperparameters(self) -> dict[str, Any]:
        """
        Dictionary of all hyperparameters the algorithm was constructed with.
        """
        return self._hyperparameters

    @property
    def environment(self) -> BaseEnvironment:
        """
        Environment in which the rollouts are performed.
        """
        return self._environment

    @property
    def actor(self) -> BaseActor:
        """
        Actor trained by this RL algorithm.
        """
        return self._actor

    @property
    def replay_buffer(self) -> BaseReplayBuffer:
        """
        Replay buffer receiving the transitions collected during rollouts.
        """
        return self._replay_buffer

    @property
    def seed(self) -> int | None:
        """
        Random seed currently associated with the algorithm (None if unset).
        """
        return self._seed

    @abstractmethod
    def set_seed(self, seed: int | None) -> None:
        """
        Set the random seed for the RL algorithm and for every component of it that uses randomness.

        The base implementation only records the seed, so overriding methods may call
        `super().set_seed(seed)` before seeding their own components.
        """
        self._seed = seed
        # Subclasses should propagate the seed wherever randomness is used.

    @abstractmethod
    def update_actor_based_on_ea(self, evolutionary_actor: BaseActor) -> None:
        """
        Update the RL actor using an actor produced by the evolutionary algorithm.
        """
        ...

    @abstractmethod
    def training_step(self) -> None:
        """
        Perform one training (meta)step using the contents of the associated replay buffer.

        A single call may consist of several updates of the actor and critic networks.
        Exploration hyperparameters of the actor should also be updated here, if required.
        """
        ...

    def training_iteration(self) -> int:
        """
        Run one training iteration: collect rollouts, then train once.

        Exploratory rollouts are performed via `environment.evaluate`, which stores the
        transitions into the replay buffer. Rollouts continue until either the rollout limit
        (`maximum_number_of_rollouts_per_training_iteration`) or the step budget
        (`maximum_number_of_steps_per_training_iteration`) is exhausted; the remaining step
        budget is passed to each rollout as `max_steps`. Afterwards `training_step()` is called once.

        Returns
        -------
            int
                The number of interactions with the environment during the training iteration.
        """
        rollout_cap = self.maximum_number_of_rollouts_per_training_iteration
        step_cap = self.maximum_number_of_steps_per_training_iteration

        total_steps = 0
        completed_rollouts = 0
        steps_left = step_cap

        while rollout_cap is None or completed_rollouts < rollout_cap:
            _, episode_steps, _ = self.environment.evaluate(
                self.actor,
                self.replay_buffer,
                exploratory_actions_required=True,
                max_steps=steps_left
            )
            total_steps += episode_steps
            completed_rollouts += 1

            if step_cap is None:
                continue

            steps_left = step_cap - total_steps
            if steps_left <= 0:
                break

        self.training_step()
        return total_steps


class BaseOnpolicyRL(BaseRL):
    """
    Abstract base class for classical on-policy Reinforcement Learning (RL) algorithms.

    On-policy algorithms train only on data collected by the current actor, so the replay
    buffer is cleared at the start of every training iteration.
    """
    def __init__(
        self,
        environment: BaseEnvironment,
        actor: BaseActor,
        replay_buffer: BaseReplayBuffer,
        *,
        maximum_number_of_rollouts_per_training_iteration: int | None = 1,
        maximum_number_of_steps_per_training_iteration: int | None = None,
        seed: int | None = None,
        **hyperparameters: Any
    ) -> None:
        """
        Abstract base class for classical on-policy Reinforcement Learning (RL) algorithms.

        Hyperparameters:
        - `maximum_number_of_rollouts_per_training_iteration` (int | None): The maximum number of rollouts
        to be performed in each training iteration. (default: 1)
        - `maximum_number_of_steps_per_training_iteration` (int | None): The maximum number of steps to be taken
        in the environment in each training iteration. (default: None)
        - `seed` (int | None): Random seed for reproducibility. (default: None)

        Any additional keyword arguments are forwarded to `BaseRL` and stored in `hyperparameters`.
        """
        super().__init__(
            environment,
            actor,
            replay_buffer,
            maximum_number_of_rollouts_per_training_iteration=maximum_number_of_rollouts_per_training_iteration,
            maximum_number_of_steps_per_training_iteration=maximum_number_of_steps_per_training_iteration,
            seed=seed,
            **hyperparameters
        )

    def training_iteration(self) -> int:
        """
        Carry out one iteration of on-policy training of the RL agent using the associated replay buffer.

        One iteration consists of clearing the associated replay buffer, performing rollouts in the
        environment (limited by the `maximum_number_of_rollouts_per_training_iteration` and
        `maximum_number_of_steps_per_training_iteration` hyperparameter values) that collect data into
        the now-empty replay buffer, and finally training the agent on that data using the
        training_step() function.

        Returns
        -------
            int
                The number of interactions with the environment during the training iteration.
        """
        # Discard data gathered by previous versions of the policy before collecting fresh rollouts.
        self.replay_buffer.clear()
        return super().training_iteration()


class BaseOffpolicyRL(BaseRL):
    """
    Abstract base class for classical off-policy Reinforcement Learning (RL) algorithms.

    Unlike `BaseRL`, rollouts are stepped here directly so that training steps can be
    interleaved with interaction with the environment.
    """
    def __init__(
        self,
        environment: BaseEnvironment,
        actor: BaseActor,
        replay_buffer: BaseReplayBuffer,
        *,
        maximum_number_of_rollouts_per_training_iteration: int | None = 1,
        maximum_number_of_steps_per_training_iteration: int | None = None,
        rollout_steps_per_training_step: int = 1,
        number_of_collected_transitions_before_training: int | None = None,
        seed: int | None = None,
        **hyperparameters: Any
    ) -> None:
        """
        Abstract base class for classical off-policy Reinforcement Learning (RL) algorithms.

        Hyperparameters:
        - `maximum_number_of_rollouts_per_training_iteration` (int | None): The maximum number of rollouts
        to be performed in each training iteration. (default: 1)
        - `maximum_number_of_steps_per_training_iteration` (int | None): The maximum number of steps to be taken
        in the environment in each training iteration. (default: None)
        - `rollout_steps_per_training_step` (int): The number of rollout steps taken in the environment
        before each training step. Must be positive. (default: 1)
        - `number_of_collected_transitions_before_training` (int | None): The minimal number of transitions
        that have to be gathered in the replay buffer before any training occurs; None disables the
        requirement. (default: None)
        - `seed` (int | None): Random seed for reproducibility. (default: None)
        """
        super().__init__(
            environment,
            actor,
            replay_buffer,
            maximum_number_of_rollouts_per_training_iteration=maximum_number_of_rollouts_per_training_iteration,
            maximum_number_of_steps_per_training_iteration=maximum_number_of_steps_per_training_iteration,
            rollout_steps_per_training_step=rollout_steps_per_training_step,
            number_of_collected_transitions_before_training=number_of_collected_transitions_before_training,
            seed=seed,
            **hyperparameters
        )

        # Fail at construction time rather than with a ZeroDivisionError in the training cadence check.
        assert rollout_steps_per_training_step > 0, \
            "The number of rollout steps per training step must be positive!"

        self.rollout_steps_per_training_step = rollout_steps_per_training_step
        self.number_of_collected_transitions_before_training = number_of_collected_transitions_before_training

    @property
    def has_enough_data_for_training(self) -> bool:
        """
        Return whether the replay buffer has enough data for training (based on the
        `number_of_collected_transitions_before_training` hyperparameter value).
        """
        return self.number_of_collected_transitions_before_training is None or \
            len(self.replay_buffer) >= self.number_of_collected_transitions_before_training

    def training_iteration(self) -> int:
        """
        Carry out one iteration of training the RL agent using the associated replay buffer.

        One iteration consists of multiple rollouts in the environment (limited by the
        `maximum_number_of_rollouts_per_training_iteration` and `maximum_number_of_steps_per_training_iteration`
        hyperparameter values) that collect data into the replay buffer. Every
        `rollout_steps_per_training_step` environment steps (counted from the start of the iteration),
        and once more at the end of the iteration, the agent is trained on the replay buffer contents
        via training_step() — but only if there are enough samples in the replay buffer, as given by
        the `number_of_collected_transitions_before_training` hyperparameter.

        Returns
        -------
            int
                The number of interactions with the environment during the training iteration.
        """
        interactions = 0
        remaining_steps_budget = self.maximum_number_of_steps_per_training_iteration

        rollouts_processed = 0
        while self.maximum_number_of_rollouts_per_training_iteration is None or \
            rollouts_processed < self.maximum_number_of_rollouts_per_training_iteration:

            done = False
            episode_length = 0

            observation, _ = self.environment.reset()
            self.actor.reset()

            # An episode ends on termination/truncation, on the environment's timestep limit,
            # or when the iteration's step budget runs out (possibly mid-episode).
            while not done and \
                (self.environment.timestep_limit is None or episode_length < self.environment.timestep_limit) and \
                    (remaining_steps_budget is None or remaining_steps_budget > 0):

                action = self.actor.act(observation, explore=True)
                next_observation, reward, terminated, truncated, _ = self.environment.step(action)
                episode_length += 1
                if remaining_steps_budget is not None:
                    remaining_steps_budget -= 1

                # Only `terminated` is stored (not `truncated`); presumably so that value targets
                # can still bootstrap from `next_observation` after a time-limit cut-off.
                transition = Transition(
                    observation=observation,
                    action=action,
                    reward=reward,
                    next_observation=next_observation,
                    terminated=terminated
                )
                self.replay_buffer.add(transition)

                observation = next_observation
                done = terminated or truncated

                # Training cadence is counted in environment steps since the start of this iteration.
                if (interactions + episode_length) % self.rollout_steps_per_training_step == 0 and \
                    self.has_enough_data_for_training:
                    self.training_step()

            interactions += episode_length
            rollouts_processed += 1

            if remaining_steps_budget is not None and remaining_steps_budget <= 0:
                break

        # Final training step at the end of every iteration.
        if self.has_enough_data_for_training:
            self.training_step()

        return interactions
