# Abstract base class for replay buffers.

from abc import ABC, abstractmethod
from utils.types import *


class BaseReplayBuffer(ABC):
    """
    Interface that every replay buffer implementation must satisfy.

    Concrete subclasses are required to implement `add`, `sample`,
    `set_seed`, `__len__` and `reset`. The convenience methods `size`
    and `clear` are defined here and simply forward to `len(self)` and
    `reset()` respectively.
    """

    # ------------------------------------------------------------------
    # Abstract interface (must be overridden by subclasses)
    # ------------------------------------------------------------------

    @abstractmethod
    def add(self, experience: Transition) -> None:
        """
        Store a single experience in the buffer.

        Args:
            experience: The transition to insert.
        """
        ...

    @abstractmethod
    def sample(self, batch_size: int) -> list[Transition]:
        """
        Draw a batch of stored experiences.

        Args:
            batch_size: Requested size of the batch.

        Returns:
            A list of sampled transitions.
        """
        ...

    @abstractmethod
    def set_seed(self, seed: int | None) -> None:
        """
        Seed the randomness used by the buffer.

        Implementations are expected to make the seed affect `sample`.

        Args:
            seed: Seed value, or None. What None means is left to the
                implementation.
        """
        ...

    @abstractmethod
    def __len__(self) -> int:
        """
        Report how many experiences the buffer currently holds.
        """
        ...

    @abstractmethod
    def reset(self) -> None:
        """
        Discard every experience held by the buffer.
        """
        ...

    # ------------------------------------------------------------------
    # Convenience aliases built on the abstract interface
    # ------------------------------------------------------------------

    def size(self) -> int:
        """
        Report how many experiences the buffer currently holds.

        Equivalent to `len(self)`.
        """
        # Go through the builtin so Python's `__len__` result checks still apply.
        return len(self)

    def clear(self) -> None:
        """
        Discard every experience held by the buffer.

        Equivalent to calling `reset()`; its return value is discarded.
        """
        self.reset()
