# Abstract base classes for critics used by classical RL algorithms based on PyTorch.

from abc import abstractmethod
from models.torch_models import BaseTorchModel
from utils.types import *
import torch
    
    
class BaseStateCritic(BaseTorchModel):
    """
    Abstract interface for critics that score a state (observation) on its own.

    Concrete subclasses are expected to supply both a single-observation
    method (``evaluate``) and a batched method (``evaluate_batch``).

    NOTE(review): ``@abstractmethod`` only prevents instantiation when the
    class hierarchy uses an ``ABCMeta``-derived metaclass; whether
    ``BaseTorchModel`` provides one is not visible from here -- confirm.
    """

    @abstractmethod
    def evaluate(self, observation: ObservationType) -> float:
        """
        Return the value estimate for a single observation / state.

        Args:
            observation: The observation (state) to be scored.

        Returns:
            The critic's value for ``observation``.
        """
        ...

    @abstractmethod
    def evaluate_batch(self, observation_batch: torch.Tensor) -> torch.Tensor:
        """
        Return the value estimates for a whole batch of observations / states.

        Args:
            observation_batch: Tensor holding the observations to be scored
                (presumably one observation per leading index -- TODO confirm
                the layout against concrete subclasses).

        Returns:
            Tensor holding the values of the given observations.
        """
        ...
    
    
class BaseStateActionCritic(BaseTorchModel):
    """
    Abstract interface for critics that score a state-action
    (observation-action) pair.

    Concrete subclasses are expected to supply both a single-pair method
    (``evaluate``) and a batched method (``evaluate_batch``).

    NOTE(review): ``@abstractmethod`` only prevents instantiation when the
    class hierarchy uses an ``ABCMeta``-derived metaclass; whether
    ``BaseTorchModel`` provides one is not visible from here -- confirm.
    """

    @abstractmethod
    def evaluate(self, observation: ObservationType, action: ActionType) -> float:
        """
        Return the value estimate for a single observation-action pair.

        Args:
            observation: The observation (state) part of the pair.
            action: The action part of the pair.

        Returns:
            The critic's value for the given pair.
        """
        ...

    @abstractmethod
    def evaluate_batch(
        self,
        observation_batch: torch.Tensor,
        action_batch: torch.Tensor,
    ) -> torch.Tensor:
        """
        Return the value estimates for a batch of observation-action pairs.

        Args:
            observation_batch: Tensor holding the observations of the pairs.
            action_batch: Tensor holding the actions of the pairs
                (presumably aligned index-by-index with ``observation_batch``
                -- TODO confirm against concrete subclasses).

        Returns:
            Tensor holding the values of the given pairs.
        """
        ...
