# Abstract base class for environments.

from abc import ABC, abstractmethod
from utils.types import *
from typing import Iterator
from models.actors.base_actor import BaseActor
from replay_buffers.base_replay_buffer import BaseReplayBuffer


class BaseEnvironment(ABC):
    """
    Abstract base class for environments.

    Subclasses implement `timestep_limit`, `reset`, `step` and `set_seed`; the
    rollout helpers `evaluate` and `multiepisode_evaluation` are built on top of
    those primitives.
    """
    @property
    @abstractmethod
    def timestep_limit(self) -> int | None:
        """
        Return the maximum number of steps allowed in an episode.
        Returns None if there is no limit.
        """
        pass

    @abstractmethod
    def reset(self) -> tuple[ObservationType, dict[str, Any]]:
        """
        Reset the environment to an initial state.

        Returns
        -------
            ObservationType
                The initial observation.
            dict[str, Any]
                An info dictionary containing additional information about the initial state.
        """
        pass

    @abstractmethod
    def step(self, action: ActionType) -> tuple[ObservationType, float, bool, bool, dict[str, Any]]:
        """
        Take a step in the environment given an action.

        Returns
        -------
            ObservationType
                The new observation after taking the action.
            float
                The reward received after taking the action.
            bool
                Whether the episode has terminated.
            bool
                Whether the episode has been truncated.
            dict[str, Any]
                Additional information about the step and the new state.
        """
        pass

    @abstractmethod
    def set_seed(self, seed: int | None) -> None:
        """
        Set the random seed for the environment.
        """
        pass

    def evaluate(
        self,
        actor: BaseActor,
        replay_buffer: BaseReplayBuffer | None = None,
        exploratory_actions_required: bool = False,
        max_steps: int | None = None
    ) -> tuple[float, int, dict[str, Any]]:
        """
        Evaluate the given actor in the environment, saving the encountered transitions in the given replay buffer
        (if provided). If `exploratory_actions_required` is True, then informs the actor that it should use
        exploratory actions via the `explore` argument of the `actor.act(...)` method.
        If `max_steps` is provided, it limits the number of steps taken in the evaluation.

        The episode ends when `step` reports `terminated` or `truncated`, or when the
        smaller of `self.timestep_limit` and `max_steps` (ignoring None) is reached.

        Returns
        -------
            float
                The cumulative reward (fitness) achieved by the actor.
            int
                The length of the episode (number of steps taken).
            dict[str, Any]
                Additional information about the rollout / the episode.
                This base implementation always returns an empty dict.
        """
        cumulative_reward, length = 0.0, 0

        observation, _ = self.reset()
        actor.reset()

        # Read the environment limit once per episode (after reset, as the original
        # loop did on its first check) and fold it together with `max_steps` into a
        # single step budget; None means "no limit".
        limits = [limit for limit in (self.timestep_limit, max_steps) if limit is not None]
        step_budget = min(limits) if limits else None

        done = False
        while not done and (step_budget is None or length < step_budget):
            action = actor.act(observation, explore=exploratory_actions_required)
            next_observation, reward, terminated, truncated, _ = self.step(action)
            cumulative_reward += reward
            length += 1

            if replay_buffer is not None:
                # Only `terminated` is recorded; `truncated` merely ends the rollout
                # (presumably so learners can still bootstrap from time-limit
                # cutoffs — confirm against the replay buffer's consumers).
                replay_buffer.add(Transition(
                    observation=observation,
                    action=action,
                    reward=reward,
                    next_observation=next_observation,
                    terminated=terminated
                ))

            observation = next_observation
            done = terminated or truncated

        return cumulative_reward, length, {}

    def multiepisode_evaluation(
        self,
        number_of_episodes: int,
        actor: BaseActor,
        replay_buffer: BaseReplayBuffer | None = None,
        exploratory_actions_required: bool = False,
        verbose: bool = False,
        max_steps: int | None = None
    ) -> tuple[float, int]:
        """
        Evaluate the given actor (in exploratory or greedy mode based on `exploratory_actions_required`)
        in the environment for a specified number of episodes, saving the encountered transitions
        in the given replay buffer (if provided).

        If `verbose`, prints the results for each episode as they are evaluated.
        If `max_steps` is provided, it is forwarded to `evaluate` to cap each episode's length.

        Returns
        -------
            float
                The mean cumulative reward achieved by the actor over the episodes.
            int
                The mean episode length over the episodes, rounded down (integer division).

        Raises
        ------
            ValueError
                If `number_of_episodes` is not a positive integer.
        """
        if number_of_episodes <= 0:
            raise ValueError(f"number_of_episodes must be positive, got {number_of_episodes}")

        # Forward `max_steps` only when it is set, so subclasses that override
        # `evaluate` with the older three-argument signature keep working.
        evaluate_kwargs = {} if max_steps is None else {"max_steps": max_steps}

        total_reward, total_length = 0.0, 0

        for i in range(number_of_episodes):
            if verbose:
                print(f"Evaluation episode {i+1} / {number_of_episodes}")

            episode_return, episode_length, episode_info = self.evaluate(
                actor, replay_buffer, exploratory_actions_required, **evaluate_kwargs
            )

            if verbose:
                print(f"Episode {i+1} - Cumulative reward: {episode_return}, Episode length: {episode_length}, Info: {episode_info}")

            total_reward += episode_return
            total_length += episode_length

        # Floor division keeps the declared `int` return type for the mean length.
        return total_reward / number_of_episodes, total_length // number_of_episodes
