Untitled

 avatar
unknown
python
2 months ago
4.8 kB
24
Indexable
import operator
from enum import Enum
from functools import lru_cache
from typing import Any, Dict, Optional

from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
from rest_framework.serializers import BaseSerializer


def get_from_attr_or_serializer(field, attrs, serializer, allow_null=False):
    """Look up *field* in ``attrs`` or on the serializer instance.

    NOTE(review): placeholder only — the real lookup logic was not included
    in this snippet. As written every argument is ignored and the function
    returns ``None``. Because ``DateComparisonValidator`` skips validation
    when a compared value is ``None``, this stub effectively disables it.
    Restore the real implementation before use.
    """
    return None


def get_original_field_name(field, serializer):
    """Return the name under which *field* should be reported in errors.

    NOTE(review): placeholder only — the real implementation was not included
    in this snippet. As written both arguments are ignored and ``None`` is
    returned, so an error dict built from it would be keyed on ``None``.
    Restore the real implementation before use.
    """
    return None


class DateComparisonValidator:
    """Validate two fields against each other with a comparison operator.

    Meant to be used as a DRF serializer-level validator. ``requires_context``
    is set, so DRF calls it as ``validator(attrs, serializer)``. Both values
    are read through ``get_from_attr_or_serializer``. Validation can be gated
    on the value of a third field, and is skipped whenever either compared
    value is ``None``.

    Args:
        field1: Left-hand field of the comparison; errors are keyed on it.
        field2: Right-hand field of the comparison.
        operator: Comparison applied as ``field1 <op> field2``. Either an
            ``Operator`` member or its string value (``"gt"``, ``"lte"``, ...).
        message: Optional custom error message replacing the default one.
        condition_field: Optional field whose value decides whether to validate.
        condition_value: Value ``condition_field`` must equal for validation to
            run, or a list/tuple/set of accepted values. ``None`` (the default)
            disables the condition, so validation always runs.
        allow_null: Forwarded as ``allow_null`` to
            ``get_from_attr_or_serializer`` when reading ``field1``/``field2``.
            NOTE(review): its exact effect depends on that helper.

    Raises:
        KeyError: On construction, if ``operator`` is not a supported value.
        serializers.ValidationError: On call, if the comparison fails.
    """

    requires_context = True

    class Operator(str, Enum):
        """Supported comparison operators with human-readable labels.

        Each member is a ``str`` equal to its short code (so ``"gt"`` compares
        and hashes like ``Operator.gt``) and carries a translatable ``label``
        used in the default error message.
        """

        # Without this hook Enum would call ``str("gt", <label>)`` with the
        # whole tuple (a TypeError at import time), and members would have no
        # ``label`` attribute. Unpack the tuple explicitly instead.
        def __new__(cls, value, label):
            member = str.__new__(cls, value)
            member._value_ = value
            member.label = label
            return member

        gt = "gt", _("greater than")
        lt = "lt", _("less than")
        gte = "gte", _("greater than or equal to")
        lte = "lte", _("less than or equal to")
        eq = "eq", _("equal to")
        ne = "ne", _("not equal to")

    # In the class body ``operator`` still refers to the stdlib module; it is
    # only shadowed inside ``__init__`` by the parameter of the same name.
    _OPERATOR_FUNCS = {
        Operator.gt: operator.gt,
        Operator.lt: operator.lt,
        Operator.gte: operator.ge,
        Operator.lte: operator.le,
        Operator.eq: operator.eq,
        Operator.ne: operator.ne,
    }

    def __init__(
        self,
        field1: str,
        field2: str,
        operator: Operator,
        message: Optional[str] = None,
        condition_field: Optional[str] = None,
        condition_value: Optional[Any] = None,
        allow_null: bool = False,
    ):
        self.field1 = field1
        self.field2 = field2
        # Look up first so an unsupported operator raises KeyError. Then
        # normalise plain strings such as "gt" to the enum member so that
        # ``self.operator.label`` is always available for error messages.
        self.operator_func = self._OPERATOR_FUNCS[operator]
        self.operator = self.Operator(operator)
        self.message = message
        self.condition_field = condition_field
        self.condition_value = condition_value
        self.allow_null = allow_null

    def __call__(self, attrs: Dict[str, Any], serializer: BaseSerializer) -> None:
        """Run the comparison if the condition holds and both values are set.

        Raises:
            serializers.ValidationError: if ``field1 <op> field2`` is false.
        """
        if not self._should_validate(attrs, serializer):
            return

        field1_val = self._get_field_value(self.field1, attrs, serializer)
        field2_val = self._get_field_value(self.field2, attrs, serializer)

        # Nothing to compare when either side is missing; ordering operators
        # would also raise TypeError against None.
        if field1_val is None or field2_val is None:
            return

        if not self._compare(field1_val, field2_val):
            self._raise_validation_error(serializer)

    def _should_validate(self, attrs: Dict, serializer: BaseSerializer) -> bool:
        """Return True when the optional condition field permits validation."""
        # No condition configured, or no expected value: always validate.
        # Checking both first avoids a pointless lookup of the condition field.
        if not self.condition_field or self.condition_value is None:
            return True

        condition_val = get_from_attr_or_serializer(
            self.condition_field, attrs, serializer, allow_null=True
        )

        # A collection of accepted values means "any of these".
        if isinstance(self.condition_value, (list, tuple, set, frozenset)):
            return condition_val in self.condition_value
        return condition_val == self.condition_value

    def _get_field_value(
        self, field: str, attrs: Dict, serializer: BaseSerializer
    ) -> Any:
        """Read *field* via ``get_from_attr_or_serializer`` honouring ``allow_null``."""
        return get_from_attr_or_serializer(
            field, attrs, serializer, allow_null=self.allow_null
        )

    def _compare(self, a: Any, b: Any) -> bool:
        """Apply the configured operator as ``a <op> b``.

        Deliberately uncached: ``lru_cache`` on a method would keep every
        validator instance alive and require hashable operands, all to save
        a single comparison.
        """
        return self.operator_func(a, b)

    def _raise_validation_error(self, serializer: BaseSerializer) -> None:
        """Raise a ValidationError keyed on ``field1``'s reported name."""
        field_name = get_original_field_name(self.field1, serializer)
        if self.message:
            message = self.message
        else:
            # Translate a constant msgid and interpolate afterwards. Wrapping
            # an f-string in ``_`` yields a new msgid per call, so it could
            # never match a translation catalog entry.
            message = _("%(field)s must be %(operator)s %(other)s.") % {
                "field": field_name,
                "operator": self.operator.label,
                "other": self.field2,
            }
        raise serializers.ValidationError({field_name: message})
Editor is loading...
Leave a Comment