# Cross-Entropy Method classes.

from .base_ea import BaseEA, BaseActor, BaseReplayBuffer, BaseEnvironment
from utils.types import *

import numpy as np


class CEM(BaseEA):
    """
    Basic Cross-Entropy Method class with sequential population evaluation in a given environment
    and a soft update towards the RL actor.

    The search distribution is a Gaussian whose mean is stored in `mean_actor` and whose spread is stored
    either in `variance_vector` (diagonal covariance) or in `covariance_matrix` (full covariance).
    """
    def __init__(
        self,
        environment: BaseEnvironment,
        template_actor: BaseActor,
        replay_buffer: BaseReplayBuffer | None = None,
        population_size: int = 10,
        antithetic_sampling: bool = False,
        elite_fraction: float = 0.5,
        use_diagonal_covariance_matrix: bool = False,
        initial_standard_deviation: float = 1.0,
        update_weights_type: str = "uniform",
        mean_to_use: str = "old",
        prevent_premature_convergence: bool = True,
        added_noise_decay: float = 0.99,
        minimal_standard_deviation: float = 0.01,
        reinforcement_soft_update_strength: float = 0.1,
        seed: int | None = None
    ) -> None:
        """
        Basic Cross-Entropy Method class with sequential population evaluation in a given environment
        and a soft update towards the RL actor.

        The provided actor is copied and the copy is used as a template for the actors for evolution
        via the create_sibling method.
        
        The replay buffer is used to store transitions encountered during the evolution process
        for usage in the classical RL algorithm, if provided.
        
        Hyperparameters:
        - `population_size` (int): Number of tested perturbations from the distribution each generation. (default: 10)
        - `antithetic_sampling` (bool): If True, only half the population is generated and the rest of the population
        is filled with antithetic counterparts of the first half. (default: False)
        - `elite_fraction` (float): Fraction of the best members of the population that will form the new distribution
        each generation. It has to select at least one and at most `population_size` individuals. (default: 0.5)
        - `use_diagonal_covariance_matrix` (bool): If True, will use just the variances of the parameters (diagonal
        covariance matrix), otherwise uses the full covariance matrix. (default: False)
        - `initial_standard_deviation` (float): Initial standard deviation for the distribution. (default: 1.0)
        - `update_weights_type` (str): Type of update weights for the mean and covariance matrix.
        ("uniform" / "quality_based", default: "uniform")
        - `mean_to_use` (str): Which mean to use when computing the new covariance matrix. ("new" / "old", default: "old")
        - `prevent_premature_convergence` (bool): If True, will add exponentially decreasing noise to the variances
        in each covariance matrix to prevent premature convergence. (default: True)
        - `added_noise_decay` (float): Rate of decay of the added noise preventing the premature convergence;
        used when prevent_premature_convergence is True. (default: 0.99)
        - `minimal_standard_deviation` (float): Standard deviation whose square is the lower bound towards which
        the variance of the added noise decays; used when prevent_premature_convergence is True. (default: 0.01)
        - `reinforcement_soft_update_strength` (float): Strength of the soft update towards the reinforcement learning
        actor. (default: 0.1)
        - `seed` (int | None): Random seed for reproducibility. (default: None)

        Raises:
        - `ValueError`: If the population size is not positive, if it is odd while antithetic sampling is enabled,
        if the elite fraction selects fewer than one or more than `population_size` individuals,
        or if `update_weights_type` / `mean_to_use` is not one of the supported options.
        """
        super().__init__(
            environment,
            template_actor,
            replay_buffer,
            population_size=population_size,
            elite_fraction=elite_fraction,
            use_diagonal_covariance_matrix=use_diagonal_covariance_matrix,
            initial_standard_deviation=initial_standard_deviation,
            update_weights_type=update_weights_type,
            mean_to_use=mean_to_use,
            prevent_premature_convergence=prevent_premature_convergence,
            added_noise_decay=added_noise_decay,
            minimal_standard_deviation=minimal_standard_deviation,
            reinforcement_soft_update_strength=reinforcement_soft_update_strength,
            seed=seed
        )
        
        # Random number generator for the EA
        self.rng = np.random.default_rng(self.seed)

        # Distribution mean
        self.mean_actor = self._template_actor.create_sibling(self.seed)
        
        # Distribution covariance matrix setting
        # If True, we will keep just the variances of the parameters (in a vector for higher efficiency), otherwise we will keep the full covariance matrix
        parameter_count = len(self.mean_actor.vectorize())
        initial_variance = initial_standard_deviation ** 2
        self.diagonal_covariance = use_diagonal_covariance_matrix
        if self.diagonal_covariance:
            self.variance_vector = np.ones(parameter_count) * initial_variance
        else:
            self.covariance_matrix = np.eye(parameter_count) * initial_variance

        # Number of individuals generated from the distribution each generation
        self.population_size = population_size
        if self.population_size < 1:
            raise ValueError("Population size must be at least 1.")
        # Use antithetic sampling to reduce variance of the estimates
        self.antithetic_sampling = antithetic_sampling
        if self.antithetic_sampling and self.population_size % 2 != 0:
            raise ValueError("Population size must be even when using antithetic sampling.")
        # Number of elite individuals to select from the population to form the new distribution
        self.elite_size = int(self.population_size * elite_fraction)
        # An elite size of 0 would make the slice [-0:] in evolve select the whole population, and an elite size
        # above the population size would not match the number of update weights, so both are rejected here
        if not 1 <= self.elite_size <= self.population_size:
            raise ValueError(
                f"Elite fraction {elite_fraction} selects {self.elite_size} individuals from a population of "
                f"{self.population_size}; it must select between 1 and {self.population_size} individuals."
            )

        # Update weights setting (The weights are used to compute the new mean and covariance matrix from the elite individuals)
        # Type of update weights for the mean and covariance matrix ("uniform" / "quality_based")
        update_weights_type = update_weights_type.lower()
        if update_weights_type == "uniform":
            self.update_weights = np.ones(self.elite_size) / self.elite_size
        elif update_weights_type == "quality_based":
            # Weights proportional to 1 / rank (rank 1 is the best individual), normalized to sum to one
            self.update_weights = np.log(1 + self.elite_size) / np.arange(1, self.elite_size + 1)
            self.update_weights = self.update_weights / np.sum(self.update_weights)
        else:
            raise ValueError(f"Unknown update weights type: {update_weights_type}")
        
        # Which mean to use when computing the new covariance matrix
        # Either the new mean computed from the elite individuals (which is the classical CEM),
        # or the old mean from the previous iteration (which is more like CMA, should be more efficient and is the one used in CEM-RL)
        self.mean_to_use = mean_to_use.lower()
        if self.mean_to_use not in ["new", "old"]:
            raise ValueError(f"Unknown mean to use: {self.mean_to_use}. Use 'new' or 'old'.")
        
        # Premature convergence prevention setting
        # If True, will be adding exponentially decreasing noise to the variances in each covariance matrix to prevent premature convergence
        self.prevent_premature_convergence = prevent_premature_convergence
        if self.prevent_premature_convergence:
            # Added noise variance
            self.added_noise = initial_variance
            # Rate of decay of the added noise preventing the premature convergence
            self.added_noise_decay = added_noise_decay
            # Minimal variance of the added noise preventing the premature convergence (the added noise will decay towards this value)
            self.added_noise_lower_bound = minimal_standard_deviation ** 2

        # Strength of the soft update towards the reinforcement learning actor
        self.reinforcement_soft_update_strength = reinforcement_soft_update_strength

        # Actor used for evaluations in the environment
        self.evaluation_actor = self.mean_actor.copy()
        
        # Synchronize the randomness of the EA with the given seed
        self.set_seed(self.seed)


    @property
    def evolutionary_actor(self) -> BaseActor:
        """
        Actor holding the current mean of the search distribution.
        """
        return self.mean_actor
    
    
    def set_seed(self, seed: int | None) -> None:
        """
        Stores the seed, recreates the random number generator of the EA and passes the seed on
        to the evaluation actor, the mean actor and the environment.
        """
        self._seed = seed
        # NOTE(review): self.seed is defined outside this class; presumably it exposes self._seed - confirm in BaseEA
        self.rng = np.random.default_rng(self.seed)
        self.evaluation_actor.set_seed(self.seed)
        self.mean_actor.set_seed(self.seed)
        self.environment.set_seed(self.seed)


    def update_evolution_state_based_on_rl(self, reinforcement_actor: BaseActor) -> None:
        """
        Softly updates the mean actor towards the given reinforcement learning actor
        with the strength `reinforcement_soft_update_strength`.
        """
        self.mean_actor.soft_update_towards_sibling(reinforcement_actor, self.reinforcement_soft_update_strength)


    def evolve(self) -> int:
        """
        Performs one generation of CEM: samples a population around the current mean, evaluates every individual
        sequentially in the environment and updates the mean and the (co)variance from the elite individuals.

        Returns the sum of the episode lengths reported by the environment for all evaluated individuals.
        """
        interactions = 0
        old_mean = self.mean_actor.vectorize()
        
        # Population generation
        # With antithetic sampling only half of the deviations are sampled, the other half are their negations
        sample_count = self.population_size // 2 if self.antithetic_sampling else self.population_size
        if self.diagonal_covariance:
            population_deviations = self.rng.normal(
                np.zeros_like(old_mean),
                np.sqrt(self.variance_vector),
                size=(sample_count, len(old_mean))
            )
        else:
            population_deviations = self.rng.multivariate_normal(
                np.zeros_like(old_mean),
                self.covariance_matrix,
                size=sample_count
            )
            
        if self.antithetic_sampling:
            # Fill the second half of the population with antithetic counterparts
            population_deviations = np.concatenate((population_deviations, -population_deviations), axis=0)
            
        population = old_mean + population_deviations

        # Population evaluation
        fitnesses = []
        for individual in population:
            self.evaluation_actor.set_parameters(individual)
            fitness, episode_length, _ = self.environment.evaluate(self.evaluation_actor, self.replay_buffer)
            fitnesses.append(fitness)
            interactions += episode_length

        # Update of the parameters of CEM (mean and covariance matrix)
        # Get indices of the elite individuals in descending order of fitness
        elite_indices = np.argsort(fitnesses)[-self.elite_size:][::-1]
        elite_individuals = population[elite_indices]
        new_mean = self.update_weights @ elite_individuals
        self.mean_actor.set_parameters(new_mean)

        # The constructor guarantees that mean_to_use is either "new" or "old"
        reference_mean = new_mean if self.mean_to_use == "new" else old_mean
        elite_deviations = elite_individuals - reference_mean
        
        if self.diagonal_covariance:
            # Weighted sum of the squared deviations of the elite individuals
            self.variance_vector = self.update_weights @ elite_deviations ** 2
            if self.prevent_premature_convergence:
                self.variance_vector = self.variance_vector + self.added_noise
        else:
            # Weighted sum of the outer products of the elite deviations computed as a single matrix product
            self.covariance_matrix = (self.update_weights[:, np.newaxis] * elite_deviations).T @ elite_deviations
            if self.prevent_premature_convergence:
                self.covariance_matrix = self.covariance_matrix + self.added_noise * np.eye(len(reference_mean))
                
        if self.prevent_premature_convergence:
            # Exponential decay of the added noise variance towards its lower bound
            self.added_noise = self.added_noise_decay * self.added_noise + (1 - self.added_noise_decay) * self.added_noise_lower_bound
        
        return interactions
        