# Basic sequential loop for Evolutionary Reinforcement Learning (ERL) algorithms.

from typing import Type, Any
from models.actors.base_actor import BaseActor
from replay_buffers.base_replay_buffer import BaseReplayBuffer
from envs.base_env import BaseEnvironment
from rl.base_rl import BaseRL
from ea.base_ea import BaseEA


class BasicErlLoop:
    """
    Basic sequential loop for Evolutionary Reinforcement Learning (ERL) algorithms.
    """
    def __init__(
        self,
        rl_algorithm_class: Type[BaseRL],
        rl_hyperparameters: dict[str, Any],
        ea_algorithm_class: Type[BaseEA],
        ea_hyperparameters: dict[str, Any],
        environment: BaseEnvironment,
        actor: BaseActor,
        replay_buffer: BaseReplayBuffer,
        *, # TODO - verbose argument na vypisování do konzole (byť to by možná mohl být argument `train` metody), log-related argumenty na logování a ukládání modelů
        num_of_test_episodes: int = 5,
        number_of_training_interactions: int = int(1e6),
        seed: int | None = None
    ) -> None: # TODO - dokončit dokumentaci
        """
        Basic sequential loop for Evolutionary Reinforcement Learning (ERL) algorithms.

        Hyperparameters:
        - `num_of_test_episodes` (int): Number of episodes for evaluation after each iteration (for logging purposes only). (default: 5)
        - `number_of_training_interactions` (int): Number of interactions with the environment that will
        be performed during the training by both the algorithms altogether. (default: 1e6)
        - `seed` (int | None): Random seed for reproducibility. (default: None)
        """
        rl_hyperparameters["seed"] = seed
        ea_hyperparameters["seed"] = seed

        self.rl_algorithm = rl_algorithm_class(
            environment=environment,
            actor=actor,
            replay_buffer=replay_buffer,
            **rl_hyperparameters
        )
        
        self.ea_algorithm = ea_algorithm_class(
            environment=environment,
            template_actor=actor,
            replay_buffer=replay_buffer,
            **ea_hyperparameters
        )
        
        self.num_of_test_episodes = num_of_test_episodes
        
        self.number_of_training_interactions = number_of_training_interactions
        self._seed = seed
        
        self._hyperparameters = dict(
            number_of_training_interactions=number_of_training_interactions,
            seed=seed,
            rl_hyperparameters=rl_hyperparameters,
            ea_hyperparameters=ea_hyperparameters
        )
        
    @property
    def rl_actor(self) -> BaseActor:
        """
        Return the RL actor.
        """
        return self.rl_algorithm.actor
    
    @property
    def ea_actor(self) -> BaseActor:
        """
        Return the EA actor.
        """
        return self.ea_algorithm.evolutionary_actor
    
    @property
    def hyperparameters(self) -> dict[str, Any]:
        """
        Return the hyperparameters of the ERL algorithm.
        """
        return self._hyperparameters
    
    @property
    def rl_hyperparameters(self) -> dict[str, Any]:
        """
        Return the hyperparameters of the used gradient algorithm.
        """
        return self.hyperparameters["rl_hyperparameters"]
    
    @property
    def ea_hyperparameters(self) -> dict[str, Any]:
        """
        Return the hyperparameters of the used evolutionary algorithm.
        """
        return self.hyperparameters["ea_hyperparameters"]

    @property
    def test_environment(self) -> BaseEnvironment:
        """
        Return the test environment.
        """
        return self.ea_algorithm.environment

    def set_seed(self, seed: int | None):
        self._seed = seed
        self.rl_algorithm.set_seed(seed)
        self.ea_algorithm.set_seed(seed)
        
    
    def train(self) -> None:
        """
        Train the actor using both the evolutionary and reinforcement learning algorithms for a number of interactions specified by the hyperparameters.
        """
        interactions = 0
        iteration = 0
        while interactions < self.number_of_training_interactions:
            iteration += 1
            
            ea_interactions = self.ea_algorithm.evolve()
            interactions += ea_interactions
            
            rl_interactions = self.rl_algorithm.training_iteration()
            interactions += rl_interactions

            self.rl_algorithm.update_actor_based_on_ea(self.ea_algorithm.evolutionary_actor)

            print(f"Iteration {iteration}; Cumulative interactions: {interactions}; EA interactions this iteration: {ea_interactions}; RL interactions this iteration: {rl_interactions}")
            
            rl_mean_return, rl_mean_length = self.test_environment.multiepisode_evaluation(self.num_of_test_episodes, self.rl_actor, verbose=False)
            print(f"Iteration {iteration} - RL; Average test return: {rl_mean_return}, Average test runtime: {rl_mean_length}")
            
            ea_mean_return, ea_mean_length = self.test_environment.multiepisode_evaluation(self.num_of_test_episodes, self.ea_actor, verbose=False)
            print(f"Iteration {iteration} - EA; Average test return: {ea_mean_return}, Average test runtime: {ea_mean_length}")
            