# Various noise processes to be used for exploration

import torch


class OrnsteinUhlenbeckNoise:
    """Ornstein-Uhlenbeck process."""

    def __init__(
        self,
        shape: int | tuple[int, ...],
        mu: float = 0.,
        sigma: float = 0.2,
        theta: float = 0.15,
        device: torch.device | str = "cpu",
        rng: torch.Generator | None = None
    ) -> None:
        self.device = torch.device(device) if isinstance(device, str) else device

        self.mu = mu * torch.ones(shape, device=self.device)
        self.sigma = sigma
        self.theta = theta
        
        self.rng = rng
        if self.rng is not None:
            assert self.rng.device == self.device, "Random generator must be on the same device as the noise process."

        self.reset()


    def reset(self) -> None:
        self.state = self.mu.clone()


    def sample(self) -> torch.Tensor:
        self.state += torch.normal(mean=(self.theta * (self.mu - self.state)), std=self.sigma, generator=self.rng)
        return self.state
