# Types and data structures

from typing import Any, NamedTuple
from dataclasses import dataclass, field
from models.torch_models import TorchBasedModel
import torch


# Type aliases for observations and actions. Both are bound to ``Any``, so no
# particular structure is assumed or enforced for either.
ObservationType = Any
ActionType = Any

class Transition(NamedTuple):
    """
    One transition record: observation, action, reward, next observation and
    a termination flag.
    """
    observation: ObservationType
    action: ActionType
    reward: float
    next_observation: ObservationType
    terminated: bool

    @classmethod
    def convert_batch_of_transitions_to_batches_of_elements(
        cls,
        batch_of_transitions: list["Transition"],
    ) -> tuple[list[ObservationType], list[ActionType], list[float], list[ObservationType], list[bool]]:
        """
        Split a batch of transitions into one list per field.

        Returns a 5-tuple of lists in the order
        (observations, actions, rewards, next_observations, terminateds).
        Each list preserves the order of the input batch; an empty batch
        yields five empty lists.
        """
        # Materialise once so the input is traversed a single time even if
        # it is a one-shot iterable.
        rows = list(batch_of_transitions)
        return (
            [row.observation for row in rows],
            [row.action for row in rows],
            [row.reward for row in rows],
            [row.next_observation for row in rows],
            [row.terminated for row in rows],
        )

@dataclass(frozen=True)
class Constraints:
    """
    Constraints for the actor's actions.
    """
    lower: torch.Tensor | None = None
    upper: torch.Tensor | None = None
    _lower_cpu: torch.Tensor | None = field(default=None, init=False, repr=False, compare=False)
    _upper_cpu: torch.Tensor | None = field(default=None, init=False, repr=False, compare=False)

    @property
    def lower_cpu(self) -> torch.Tensor | None:
        """
        Return the lower constraints on CPU.
        """
        if self._lower_cpu is None and self.lower is not None:
            object.__setattr__(self, "_lower_cpu", self.lower.cpu())
        return self._lower_cpu

    @property
    def upper_cpu(self) -> torch.Tensor | None:
        """
        Return the upper constraints on CPU.
        """
        if self._upper_cpu is None and self.upper is not None:
            object.__setattr__(self, "_upper_cpu", self.upper.cpu())
        return self._upper_cpu
