# Defines a classical replay buffer, backed by collections.deque, for storing and sampling transitions.

from collections import deque
import random
from .base_replay_buffer import BaseReplayBuffer
from utils.types import *


class ReplayBuffer(BaseReplayBuffer):
    """
    Classical experience-replay buffer backed by a ``collections.deque``.

    Transitions are stored in arrival order. When ``capacity`` is given the
    deque is bounded, so appending to a full buffer silently drops the oldest
    transition. Sampling uses a private ``random.Random`` instance, so it can
    be seeded independently of the global ``random`` module.
    """

    def __init__(self, capacity: int | None = None) -> None:
        """
        Create an empty buffer.

        Args:
            capacity: Maximum number of transitions to retain. ``None`` leaves
                the buffer unbounded, so it grows indefinitely.
        """
        self.capacity = capacity
        # maxlen=None gives an unbounded deque; any other value makes the
        # deque discard from the left once it is full.
        self.buffer = deque(maxlen=capacity)
        self.rng = random.Random()

    def add(self, experience: Transition) -> None:
        """Append one transition, evicting the oldest one if the buffer is full."""
        self.buffer.append(experience)

    def sample(self, batch_size: int) -> list[Transition]:
        """
        Draw a batch of stored transitions without replacement.

        Args:
            batch_size: Requested number of transitions. If fewer are stored,
                every stored transition is returned (in random order) instead
                of raising.

        Returns:
            A new list holding the sampled transitions.

        Raises:
            ValueError: If ``batch_size`` is negative (raised by
                ``random.Random.sample``).

        Note:
            Deque random access is O(n) toward the middle, so sampling from
            very large buffers can be comparatively slow.
        """
        # Never ask for more items than the buffer currently holds.
        count = min(batch_size, len(self.buffer))
        return self.rng.sample(self.buffer, count)

    def set_seed(self, seed: int | None) -> None:
        """
        Re-seed the buffer's private sampling RNG.

        Args:
            seed: Seed value; ``None`` lets ``random.Random.seed`` pick one
                from the system time or OS randomness source.
        """
        self.rng.seed(seed)

    def __len__(self) -> int:
        """Return how many transitions are currently stored."""
        return len(self.buffer)

    def reset(self) -> None:
        """
        Reset the buffer, discarding every stored transition.

        The capacity and the RNG state are left unchanged.
        """
        self.buffer.clear()
