import gymnasium as gym
from ea.cem import CEM
from models.actors.mlp_actor import DiscreteMlpActor
from envs.gymnasium_env import GymnasiumEnvironment


def main(env_name, seed):
    """Train a discrete MLP actor with CEM, save it, and evaluate it twice.

    The run has three phases:

    1. 50 calls to ``cem.evolve()``, each followed by a single test
       evaluation of ``cem.evolutionary_actor`` whose result is printed.
    2. The evolved actor is saved to ``actor_CEM_<env_name>.pt`` (a relative
       path, so it lands in the current working directory) and evaluated
       over 5 episodes on the training environment.
    3. The checkpoint is loaded into a sibling of the template actor and
       evaluated over 5 episodes on a fresh ``render_mode="human"``
       environment, so the two printed summaries can be compared by eye.

    Both environments are closed when their phase ends, even on error.

    Args:
        env_name: Gymnasium environment id, passed to ``gym.make``.
        seed: Seed passed to ``CEM`` and to ``set_seed`` on both
            environments before the multi-episode evaluations.
    """
    num_of_iterations = 50
    num_of_test_episodes = 5
    # Single source of truth for the checkpoint location used by save/load.
    actor_path = f"actor_CEM_{env_name}.pt"

    env = GymnasiumEnvironment(gym.make(env_name))
    try:
        # NOTE(review): assumes a 1-D observation space (shape[0]) and a
        # Discrete action space (.n) -- confirm before using other env ids.
        observation_size = int(env.env.observation_space.shape[0]) # type: ignore
        action_size = int(env.env.action_space.n) # type: ignore

        actor = DiscreteMlpActor(observation_size, action_size, hidden_layer_sizes=[4,2])

        cem = CEM(
            environment=env,
            template_actor=actor,
            population_size=10,
            antithetic_sampling=False,
            elite_fraction=0.5,
            use_diagonal_covariance_matrix=False,
            initial_standard_deviation=1.0,
            update_weights_type="quality_based",
            mean_to_use="old",
            prevent_premature_convergence=True,
            added_noise_decay=0.99,
            minimal_standard_deviation=0.01,
            reinforcement_soft_update_strength=0.0,
            seed=seed
        )

        cumulative_interactions = 0
        for i in range(num_of_iterations):
            # Evolve; the returned value is reported as this iteration's
            # training interaction count.
            interactions = cem.evolve()
            cumulative_interactions += interactions

            # Test the resulting actor
            cumulative_reward, _, _ = env.evaluate(cem.evolutionary_actor)
            print(f"Iteration {i+1}; Training interactions in this iteration: {interactions} (Cumulative: {cumulative_interactions}); Test return: {cumulative_reward}")

        # Save the actor
        cem.evolutionary_actor.save(actor_path)

        env.set_seed(seed)
        print()
        mean_return, mean_length = env.multiepisode_evaluation(num_of_test_episodes, cem.evolutionary_actor, verbose=True)
        print(f"Final evaluation summary - Mean return: {mean_return}, Mean episode length: {mean_length}")
    finally:
        # Release the training environment instead of leaking it when the
        # rendering environment is created below.
        env.env.close()

    # Load the actor (to check consistency with the actor before saving)
    test_actor = actor.create_sibling()
    test_actor.load(actor_path)

    render_env = GymnasiumEnvironment(gym.make(env_name, render_mode="human"))
    try:
        # Same seed as the evaluation above so the two summaries are comparable.
        render_env.set_seed(seed)
        print()
        mean_return, mean_length = render_env.multiepisode_evaluation(num_of_test_episodes, test_actor, verbose=True)
        print(f"Final evaluation summary - Mean return: {mean_return}, Mean episode length: {mean_length}")
    finally:
        # Close the environment so the human-mode render window is torn down.
        render_env.env.close()


if __name__ == "__main__":
    # Script entry point: run the CEM demo on CartPole-v1 with a fixed seed.
    main(seed=42, env_name="CartPole-v1")
