# TD3 algorithm class

from .base_rl import BaseOffpolicyRL, BaseEnvironment, BaseReplayBuffer
from models.critics.base_critic import BaseStateActionCritic
from models.actors.base_actor import BaseContinuousDeterministicActor
from utils.types import *
from models.torch_models import get_optimizer_class
import numpy as np
import torch


class TD3(BaseOffpolicyRL):
    """
    TD3 (Twin Delayed DDPG) algorithm class with soft update towards the EA actor.

    Besides the standard TD3 machinery (twin critics with target networks, delayed actor
    updates, target policy smoothing), the online actor can be softly pulled towards an
    externally supplied evolutionary actor via `update_actor_based_on_ea`.
    """
    # The actor used by the TD3 algorithm is a continuous deterministic actor,
    # hence an instance of a subclass of BaseContinuousDeterministicActor.
    actor: BaseContinuousDeterministicActor

    def __init__(
        self,
        environment: BaseEnvironment,
        actor: BaseContinuousDeterministicActor,
        replay_buffer: BaseReplayBuffer,
        template_critic: BaseStateActionCritic,
        maximum_number_of_rollouts_per_training_iteration: int | None = 1,
        maximum_number_of_steps_per_training_iteration: int | None = None,
        rollout_steps_per_training_step: int = 1,
        number_of_collected_transitions_before_training: int | None = None,
        discount_factor: float = 0.99,
        target_soft_update_strength: float = 0.005,
        critics_updates_per_training_step: int = 2,
        policy_smoothing_noise_std: float = 0.2,
        policy_smoothing_noise_maximum: float = 0.5,
        batch_size: int = 256,
        actor_exploration_noise_std_initial: float = 0.2,
        actor_exploration_noise_std_decay: float = 0.9999,
        actor_exploration_noise_std_minimum: float = 0.0,
        optimizer: str = "adam",
        actor_learning_rate: float = 3e-4,
        critic_learning_rate: float = 3e-4,
        actor_maximum_gradient_norm_for_clipping: float | None = None,
        critic_maximum_gradient_norm_for_clipping: float | None = None,
        other_optimizer_hyperparameters: dict[str, Any] | None = None,
        evolutionary_soft_update_strength: float = 0.1,
        seed: int | None = None
    ) -> None:
        """
        TD3 algorithm class with soft update towards the EA actor.

        Arguments:
        - `environment` (BaseEnvironment): The environment the transitions are collected from.
        - `actor` (BaseContinuousDeterministicActor): The actor being trained. A frozen copy of it
        serves as the target actor.
        - `replay_buffer` (BaseReplayBuffer): The buffer the training batches are sampled from.
        - `template_critic` (BaseStateActionCritic): Template from which the two critics (Q1, Q2)
        are created as siblings (seeded with `seed` and `seed + 1` respectively).

        Hyperparameters:
        - `maximum_number_of_rollouts_per_training_iteration` (int | None): The maximum number of rollouts
        to be performed in each training iteration. (default: 1)
        - `maximum_number_of_steps_per_training_iteration` (int | None): The maximum number of steps to be taken
        in the environment in each training iteration. (default: None)
        - `rollout_steps_per_training_step` (int): The number of rollout steps taken in the environment
        before each training step. (default: 1)
        - `number_of_collected_transitions_before_training` (int | None): The minimal number of transitions
        that has to be gathered in the replay buffer before any training occurs. (default: None)
        - `discount_factor` (float): The discount factor for future rewards - the gamma value for the RL algorithm.
        (default: 0.99)
        - `target_soft_update_strength` (float): The soft update strength for the target networks.
        (They are updated each time the actor is updated, so at the end of each training step.) (default: 0.005)
        - `critics_updates_per_training_step` (int): The number of critic updates to be performed in each
        training step, so per one actor update, basically a policy delay. Must be at least 1. (default: 2)
        - `policy_smoothing_noise_std` (float): The standard deviation of the Gaussian noise added to the
        predicted target actions during training of the critics for target policy smoothing. (default: 0.2)
        - `policy_smoothing_noise_maximum` (float): The clipping bound of the target policy smoothing noise;
        the noise is clamped to [-maximum, maximum] before being added to the target actions. (default: 0.5)
        - `batch_size` (int): The size of transition batches used for training the critics and the actor
        each training step. (The critics sample multiple such batches each training step, depending on the
        `critics_updates_per_training_step` hyperparameter value.) (default: 256)
        - `actor_exploration_noise_std_initial` (float): The initial standard deviation of the exploration
        noise added to the actor's actions during collecting transitions via rollouts in the environment.
        The type of the noise depends solely on the actor. (It might be a Gaussian noise, Ornstein-Uhlenbeck
        noise, or other similar noise type parametrized by the standard deviation.) (default: 0.2)
        - `actor_exploration_noise_std_decay` (float): The multiplicative decay rate of the exploration noise
        standard deviation. (The standard deviation undergoes decay after each training step.) (default: 0.9999)
        - `actor_exploration_noise_std_minimum` (float): The minimum value of the exploration noise standard deviation.
        (default: 0.0)
        - `optimizer` (str): The type of the optimizer used for training the actor and the critics. (default: "adam")
        - `actor_learning_rate` (float): The learning rate for the actor's optimizer. (default: 3e-4)
        - `critic_learning_rate` (float): The learning rate for the critics' optimizer. (default: 3e-4)
        - `actor_maximum_gradient_norm_for_clipping` (float | None): The maximum norm for gradient clipping applied during
        the training of the actor. If None, no clipping is applied. (default: None)
        - `critic_maximum_gradient_norm_for_clipping` (float | None): The maximum norm for gradient clipping applied during
        the training of the critics. If None, no clipping is applied. (default: None)
        - `other_optimizer_hyperparameters` (dict[str, Any] | None): Other hyperparameters for the optimizers.
        (Might be momentum for SGD, weight decay, beta and epsilon setting for Adam, etc.) (default: None)
        - `evolutionary_soft_update_strength` (float): The strength of the soft update towards the evolutionary actor
        when used as part of an ERL algorithm. (default: 0.1)
        - `seed` (int | None): Random seed for reproducibility. (default: None)
        """
        super().__init__(
            environment,
            actor,
            replay_buffer,
            maximum_number_of_rollouts_per_training_iteration=maximum_number_of_rollouts_per_training_iteration,
            maximum_number_of_steps_per_training_iteration=maximum_number_of_steps_per_training_iteration,
            rollout_steps_per_training_step=rollout_steps_per_training_step,
            number_of_collected_transitions_before_training=number_of_collected_transitions_before_training,
            discount_factor=discount_factor,
            target_soft_update_strength=target_soft_update_strength,
            critics_updates_per_training_step=critics_updates_per_training_step,
            policy_smoothing_noise_std=policy_smoothing_noise_std,
            policy_smoothing_noise_maximum=policy_smoothing_noise_maximum,
            batch_size=batch_size,
            actor_exploration_noise_std_initial=actor_exploration_noise_std_initial,
            actor_exploration_noise_std_decay=actor_exploration_noise_std_decay,
            actor_exploration_noise_std_minimum=actor_exploration_noise_std_minimum,
            optimizer=optimizer,
            actor_learning_rate=actor_learning_rate,
            critic_learning_rate=critic_learning_rate,
            actor_maximum_gradient_norm_for_clipping=actor_maximum_gradient_norm_for_clipping,
            critic_maximum_gradient_norm_for_clipping=critic_maximum_gradient_norm_for_clipping,
            other_optimizer_hyperparameters=other_optimizer_hyperparameters,
            evolutionary_soft_update_strength=evolutionary_soft_update_strength,
            seed=seed
        )

        # Actor exploration schedule hyperparameters (the noise std decays after every training step)
        self.actor.train()
        self.actor.set_exploration_parameters(exploration_noise_std=actor_exploration_noise_std_initial)
        self.actor_exploration_noise_std_minimum = actor_exploration_noise_std_minimum
        self.actor_exploration_noise_std_decay = actor_exploration_noise_std_decay

        # Target actor: a frozen copy that only moves via soft updates
        self.actor_target = self.actor.copy().to(self.actor.device)
        for p in self.actor_target.parameters():
            p.requires_grad_(False)
        self.actor_target.eval()

        # Critic networks (differently seeded so the twin critics start from different weights) and their targets
        self.Q1 = template_critic.create_sibling(seed=self.seed).to(self.actor.device)
        self.Q2 = template_critic.create_sibling(seed=(self.seed + 1 if self.seed is not None else None)).to(self.actor.device)

        self.Q1_target = self.Q1.copy().to(self.Q1.device)
        for p in self.Q1_target.parameters():
            p.requires_grad_(False)
        self.Q1_target.eval()
        self.Q2_target = self.Q2.copy().to(self.Q2.device)
        for p in self.Q2_target.parameters():
            p.requires_grad_(False)
        self.Q2_target.eval()

        # TD3 hyperparameters
        self.gamma = discount_factor
        self.tau = target_soft_update_strength
        self.batch_size = batch_size
        self.policy_smoothing_noise_std = policy_smoothing_noise_std
        self.policy_smoothing_noise_maximum = policy_smoothing_noise_maximum
        self.critics_updates_per_training_step = critics_updates_per_training_step # One training step = one actor and targets updates

        assert self.critics_updates_per_training_step >= 1, "At least one critic update per training step is required."

        # Models optimizers
        self.optimizer_type = optimizer
        self.actor_learning_rate = actor_learning_rate
        self.critic_learning_rate = critic_learning_rate
        self.actor_maximum_gradient_norm_for_clipping = actor_maximum_gradient_norm_for_clipping
        self.critic_maximum_gradient_norm_for_clipping = critic_maximum_gradient_norm_for_clipping
        self.other_optimizer_hyperparameters = other_optimizer_hyperparameters or {}

        optimizer_class = get_optimizer_class(self.optimizer_type)
        self.actor_optimizer = optimizer_class(
            self.actor.parameters(),
            **{"lr": self.actor_learning_rate, **self.other_optimizer_hyperparameters}
        )
        self.Q1_optimizer = optimizer_class(
            self.Q1.parameters(),
            **{"lr": self.critic_learning_rate, **self.other_optimizer_hyperparameters}
        )
        self.Q2_optimizer = optimizer_class(
            self.Q2.parameters(),
            **{"lr": self.critic_learning_rate, **self.other_optimizer_hyperparameters}
        )

        # Evolutionary actor soft update strength
        self.evolutionary_soft_update_strength = evolutionary_soft_update_strength

        # Dedicated generator for the target policy smoothing noise
        self.training_rng = torch.Generator(device=self.actor.device)

        # Synchronize the randomness of the algorithm with the given seed
        self.set_seed(self.seed)


    def set_seed(self, seed: int | None) -> None:
        """
        Seed every random component of the algorithm: the environment, the replay buffer,
        all (online and target) networks and the training noise generator.

        If `seed` is None, the components receive None and the training generator is left as is.
        """
        self._seed = seed
        self.environment.set_seed(seed)
        self.replay_buffer.set_seed(seed)
        self.actor.set_seed(seed)
        self.actor_target.set_seed(seed)
        self.Q1.set_seed(seed)
        self.Q1_target.set_seed(seed)
        self.Q2.set_seed(seed)
        self.Q2_target.set_seed(seed)
        if seed is not None:
            self.training_rng.manual_seed(seed)


    def update_actor_based_on_ea(self, evolutionary_actor: BaseContinuousDeterministicActor) -> None:
        """
        Softly move the online actor towards `evolutionary_actor` with strength
        `evolutionary_soft_update_strength`. The target actor is not touched.
        """
        # The online actor's parameters require grad; the soft update must not be recorded by autograd
        # (same guard as the target-network soft updates).
        with torch.no_grad():
            self.actor.soft_update_towards_sibling(
                evolutionary_actor,
                self.evolutionary_soft_update_strength
            )


    def _sample_batch_from_replay_buffer(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Sample `batch_size` transitions and convert them to tensors on the actor's device.

        Returns (states, actions, rewards, next_states, terminateds). States, actions and rewards
        use the actor's dtype and terminateds are boolean. Rewards and terminateds are reshaped to
        column vectors of shape (batch, 1).
        """
        batch = self.replay_buffer.sample(self.batch_size)
        states, actions, rewards, next_states, terminateds = \
            Transition.convert_batch_of_transitions_to_batches_of_elements(batch)
        # np.array(...) always copies, so the tensors never alias the replay buffer's storage
        return (
            torch.tensor(np.array(states), device=self.actor.device, dtype=self.actor.dtype),
            torch.tensor(np.array(actions), device=self.actor.device, dtype=self.actor.dtype),
            torch.tensor(np.array(rewards), device=self.actor.device, dtype=self.actor.dtype).view(-1, 1),
            torch.tensor(np.array(next_states), device=self.actor.device, dtype=self.actor.dtype),
            torch.tensor(np.array(terminateds), device=self.actor.device, dtype=torch.bool).view(-1, 1)
        )


    @staticmethod
    def _gradient_step(
        model: BaseContinuousDeterministicActor | BaseStateActionCritic,
        model_optimizer: torch.optim.Optimizer,
        loss: torch.Tensor,
        maximum_gradient_norm: float | None
    ) -> None:
        """Backpropagate `loss` and take one optimizer step on `model`, optionally clipping the gradient norm first."""
        model_optimizer.zero_grad()
        loss.backward()
        if maximum_gradient_norm is not None:
            torch.nn.utils.clip_grad_norm_(
                model.parameters(),
                max_norm=maximum_gradient_norm
            )
        model_optimizer.step()


    def _compute_target_q_values(
        self,
        rewards: torch.Tensor,
        next_states: torch.Tensor,
        terminateds: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the clipped double-Q learning target
        r + gamma * (1 - terminated) * min(Q1'(s', a'), Q2'(s', a')),
        where a' is the target actor's action perturbed by clipped Gaussian noise (target policy smoothing)
        and clamped to the actor's action constraints.
        """
        with torch.no_grad():
            # Target policy actions
            next_actions = self.actor_target.act_batch(next_states)

            # Target policy smoothing (out-of-place, so the tensor returned by act_batch is not mutated)
            smoothing_noise = torch.randn(
                next_actions.size(),
                dtype=next_actions.dtype,
                layout=next_actions.layout,
                device=next_actions.device,
                generator=self.training_rng
            ) * self.policy_smoothing_noise_std # NOTE - Could use torch.randn_like, but for some reason it, at least currently, does not support generator argument
            next_actions = next_actions + torch.clamp(
                smoothing_noise,
                -self.policy_smoothing_noise_maximum,
                self.policy_smoothing_noise_maximum
            )
            next_actions = next_actions.clamp(self.actor.constraints.lower, self.actor.constraints.upper)

            # Target Q-values
            target_q1 = self.Q1_target.evaluate_batch(next_states, next_actions)
            target_q2 = self.Q2_target.evaluate_batch(next_states, next_actions)
            not_terminateds = torch.logical_not(terminateds).to(dtype=target_q1.dtype)
            return rewards + (self.gamma * not_terminateds * torch.minimum(target_q1, target_q2))


    def _update_critics(self) -> None:
        """Take one gradient step for each critic on a freshly sampled batch, regressing towards the shared target."""
        states, actions, rewards, next_states, terminateds = self._sample_batch_from_replay_buffer()
        target_q = self._compute_target_q_values(rewards, next_states, terminateds)

        current_q1 = self.Q1.evaluate_batch(states, actions)
        current_q2 = self.Q2.evaluate_batch(states, actions)

        loss_Q1 = torch.nn.functional.mse_loss(current_q1, target_q)
        loss_Q2 = torch.nn.functional.mse_loss(current_q2, target_q)

        self._gradient_step(self.Q1, self.Q1_optimizer, loss_Q1, self.critic_maximum_gradient_norm_for_clipping)
        self._gradient_step(self.Q2, self.Q2_optimizer, loss_Q2, self.critic_maximum_gradient_norm_for_clipping)


    def _update_actor(self) -> None:
        """Take one gradient step for the actor on a freshly sampled batch, maximizing Q1 of its own actions."""
        states, *_ = self._sample_batch_from_replay_buffer()

        actor_loss = -self.Q1.evaluate_batch(states, self.actor.act_batch(states)).mean()

        # NOTE: this backward pass also leaves gradients in Q1's parameters; they never reach a Q1 update
        # because the Q1 optimizer zeroes gradients before every critic backward pass.
        self._gradient_step(self.actor, self.actor_optimizer, actor_loss, self.actor_maximum_gradient_norm_for_clipping)


    def _soft_update_target_networks(self) -> None:
        """Move all target networks towards their online counterparts with strength `tau`."""
        with torch.no_grad():
            self.actor_target.soft_update_towards_sibling(self.actor, self.tau)
            self.Q1_target.soft_update_towards_sibling(self.Q1, self.tau)
            self.Q2_target.soft_update_towards_sibling(self.Q2, self.tau)


    def _decay_exploration_noise(self) -> None:
        """Multiply the actor's exploration noise std by the decay rate, never going below the configured minimum."""
        self.actor.set_exploration_parameters(
            exploration_noise_std=max(
                self.actor_exploration_noise_std_minimum,
                self.actor.exploration_noise_std * self.actor_exploration_noise_std_decay
            )
        )


    def training_step(self) -> None:
        """
        Perform one TD3 training step:
        1. `critics_updates_per_training_step` updates of both critics (the policy delay),
           each on its own freshly sampled batch,
        2. one actor update on another freshly sampled batch,
        3. soft update of all target networks,
        4. decay of the actor's exploration noise standard deviation.
        """
        for _ in range(self.critics_updates_per_training_step):
            self._update_critics()
        self._update_actor()
        self._soft_update_target_networks()
        self._decay_exploration_noise()
