d:\chatbot\redis_cache.py

 avatar
unknown
python
a month ago
5.9 kB
5
Indexable
"""
redis_cache.py
──────────────
Redis-backed result cache for the NL2SQL pipeline.

Two responsibilities:
  1. Result cache  — skip the full 3-Gemini-call pipeline for repeated queries
     (queries that differ only in case, whitespace or most punctuation share
     one cache entry)
  2. Job state store — track async job status (if you add Celery later)

Minimum viable version: synchronous result cache only.
The job-state helpers are scaffolding for a future Celery integration; no
task-queue wiring exists yet — add it when you are ready to go async.
"""

import os
import json
import hashlib
import logging
import re
from typing import Optional, Any

import redis

logger = logging.getLogger(__name__)

# ── Config ────────────────────────────────────────────────────────

# Connection string handed to redis.from_url(); override with $REDIS_URL.
REDIS_URL = os.environ.get("REDIS_URL", "redis://localhost:6379/0")


def _env_seconds(name: str, default: str) -> int:
    """Read an integer number of seconds from env var *name*, else *default*."""
    return int(os.environ.get(name, default))


# Cache lifetimes in seconds; each is overridable via the env var of the same name.
CACHE_TTL_HOT = _env_seconds("CACHE_TTL_HOT", "3600")      # 1 hour   — used when refreshing an entry's TTL on a cache hit
CACHE_TTL_WARM = _env_seconds("CACHE_TTL_WARM", "86400")   # 24 hours — default TTL for newly stored results

# ── Client singleton ──────────────────────────────────────────────

_redis_client: Optional[redis.Redis] = None


def _redact_url(url: str) -> str:
    """
    Return *url* with any password in its authority part replaced by ``***``.

    Lets connection strings such as ``redis://:secret@host:6379/0`` be logged
    without leaking credentials. URLs without a password are returned as-is.
    """
    from urllib.parse import urlsplit

    parts = urlsplit(url)
    if parts.password is None:
        return url
    host = parts.netloc.rpartition("@")[2]
    user = parts.username or ""
    return parts._replace(netloc=f"{user}:***@{host}").geturl()


def get_redis() -> redis.Redis:
    """
    Return the module-wide Redis client, creating it on first use.

    The client is built from REDIS_URL with ``decode_responses=True`` and is
    PINGed *before* being stored, so an unreachable server raises here
    (fail fast) and the next call retries instead of reusing an unverified
    client.

    Raises:
        Whatever ``redis.from_url`` / ``ping`` raise, e.g.
        ``redis.exceptions.ConnectionError`` when the server is unreachable.

    NOTE(review): creation is not guarded by a lock; concurrent first calls
    from several threads may each build a client (last one wins).
    """
    global _redis_client
    if _redis_client is None:
        client = redis.from_url(REDIS_URL, decode_responses=True)
        client.ping()   # fail fast if Redis is unreachable; nothing cached yet
        _redis_client = client
        # Never log credentials embedded in the URL.
        logger.info("Redis connected at %s", _redact_url(REDIS_URL))
    return _redis_client


# ── Key generation ────────────────────────────────────────────────

# Punctuation that can change what a question asks for — comparison and
# arithmetic operators, "%", and a "." sitting between two digits (a decimal
# point) — is kept; every other non-word, non-space character is dropped.
_NOISE_PUNCT_RE = re.compile(r"(?!(?<=\d)\.(?=\d))[^\w\s<>=!%+\-*/]")


def _normalise(nl_query: str) -> str:
    """
    Canonicalise a query for cache-key purposes.

    Lowercases, drops meaning-free punctuation (quotes, commas, "?", sentence
    full stops, a trailing "!"), and collapses/strips whitespace, so
    "Show sales?" and "  show   SALES " map to the same text.

    Operators, "%" and decimal points are deliberately preserved: stripping
    them would make e.g. "price > 5" / "price < 5" or "over 1.5" / "over 15"
    normalise identically, and one question's cached result would be served
    for a different question.

    NOTE: changing this function changes cache keys; run invalidate_all()
    after deploying such a change so entries keyed under the old scheme
    cannot be hit.
    """
    q = nl_query.lower()
    q = _NOISE_PUNCT_RE.sub("", q)
    # A trailing "!" is only emphasis; "!=" ends in "=" so it is unaffected.
    q = re.sub(r"[\s!]+$", "", q)
    q = re.sub(r"\s+", " ", q)
    return q.strip()


def cache_key(nl_query: str) -> str:
    """
    Map *nl_query* to its Redis result-cache key.

    The query is first passed through _normalise(), so texts that differ only
    in what that function discards share a key. The normalised text is hashed
    with SHA-256, and the first 16 hex characters (64 bits) form the suffix of
    ``nl2sql:result:<digest>``.
    """
    full_digest = hashlib.sha256(_normalise(nl_query).encode()).hexdigest()
    return "nl2sql:result:" + full_digest[:16]


def job_key(job_id: str) -> str:
    """Return the Redis key that holds the state record for job *job_id*."""
    return "nl2sql:job:{}".format(job_id)


# ── Result cache (synchronous pipeline) ───────────────────────────

def get_cached_result(nl_query: str) -> Optional[dict]:
    """
    Return the cached QueryResult dict for *nl_query*, or None on a miss.

    Hits are tagged with ``"_cache_hit": True``. On a hit the entry's
    remaining lifetime is topped up to at least CACHE_TTL_HOT seconds, so
    repeatedly asked queries stay cached. An entry with more time left
    (e.g. one freshly stored with CACHE_TTL_WARM) is never shortened.

    A stored value that does not decode to a JSON object is deleted and
    treated as a miss. Redis errors are logged and also treated as a miss,
    so the caller simply runs the full pipeline.
    """
    try:
        r   = get_redis()
        key = cache_key(nl_query)
        raw = r.get(key)
        if not raw:
            return None

        try:
            result = json.loads(raw)
        except ValueError:          # json.JSONDecodeError subclasses ValueError
            result = None
        if not isinstance(result, dict):
            # Corrupt entry: drop it so it stops failing on every lookup.
            logger.warning("Discarding undecodable cache entry %s", key)
            r.delete(key)
            return None

        # Extend, never shorten. Re-applying CACHE_TTL_HOT unconditionally
        # would cut a fresh CACHE_TTL_WARM entry (24h by default) down to 1h
        # on its first hit. TTL is -2 if the key vanished meanwhile and -1 if
        # it has no expiry; both are left alone.
        remaining = r.ttl(key)
        if 0 <= remaining < CACHE_TTL_HOT:
            r.expire(key, CACHE_TTL_HOT)

        logger.info("Cache HIT for query: %.60s...", nl_query)
        result["_cache_hit"] = True
        return result
    except Exception as exc:
        logger.warning("Redis get failed (bypassing cache): %s", exc)
    return None


def set_cached_result(nl_query: str, result_dict: dict, ttl: int = CACHE_TTL_WARM):
    """
    Store a successful QueryResult dict in Redis under cache_key(nl_query).

    Only results whose "success" field is truthy are cached; failures never are.

    Args:
        nl_query:    the natural-language question the result answers.
        result_dict: the pipeline result to store.
        ttl:         lifetime in seconds (defaults to CACHE_TTL_WARM).

    The "viz_meta" field is NOT stored: it holds a VizMeta dataclass, which
    json.dumps(default=str) would merely stringify. Cache hits therefore come
    back without it. Any other value JSON cannot encode is stored as its str().

    Best-effort: any error is logged as a warning and swallowed.
    """
    try:
        if not result_dict.get("success", False):
            return   # never cache failures

        r   = get_redis()
        key = cache_key(nl_query)

        # Drop viz_meta entirely (see docstring).
        # NOTE(review): a previous comment here said "store as dict instead",
        # i.e. the intent may have been dataclasses.asdict(viz_meta). That is
        # not implemented; confirm what callers expect on a cache hit first.
        storable = {k: v for k, v in result_dict.items() if k != "viz_meta"}

        r.setex(key, ttl, json.dumps(storable, default=str))
        logger.info("Cached result for query: %.60s... (TTL=%ds)", nl_query, ttl)
    except Exception as exc:
        logger.warning("Redis set failed (result not cached): %s", exc)


def invalidate(nl_query: str):
    """
    Delete the cached result for *nl_query*, if one exists.

    Use it e.g. after a schema change affects this query. Any failure is
    logged as a warning and otherwise ignored.
    """
    try:
        client = get_redis()
        client.delete(cache_key(nl_query))
    except Exception as err:
        logger.warning("Redis delete failed: %s", err)


def invalidate_all(*, batch_size: int = 500):
    """
    Delete every ``nl2sql:result:*`` cache entry. Call after a schema refresh.

    Keys are enumerated with SCAN (``scan_iter``) instead of KEYS, so Redis is
    never blocked walking the whole keyspace in a single command. They are
    deleted in batches of at most *batch_size*, so no single DEL is unbounded.
    Keys written while the scan is running may or may not be removed.

    Args:
        batch_size: SCAN COUNT hint and maximum keys per DEL (values < 1 are
                    treated as 1).

    Best-effort: errors are logged and swallowed; entries already deleted
    before an error stay deleted.
    """
    batch_size = max(1, batch_size)
    try:
        r       = get_redis()
        deleted = 0
        batch   = []
        for key in r.scan_iter(match="nl2sql:result:*", count=batch_size):
            batch.append(key)
            if len(batch) >= batch_size:
                # DEL returns how many keys really went away, so keys that
                # SCAN reports twice are not double-counted.
                deleted += r.delete(*batch)
                batch.clear()
        if batch:
            deleted += r.delete(*batch)
        if deleted:
            logger.info("Invalidated %d cached results", deleted)
    except Exception as exc:
        logger.warning("Redis flush failed: %s", exc)


# ── Job state store (scaffold for future Celery async) ────────────

class JobStatus:
    """
    Status strings for job-state records (see set_job_status / get_job_status).

    They are plain str class attributes, so they serialise with json.dumps as-is.
    """

    PENDING = "pending"
    RUNNING = "running"
    DONE = "done"
    FAILED = "failed"


def set_job_status(job_id: str, status: str, payload: Optional[dict] = None, ttl: int = 3600):
    """
    Write a job's state record to Redis as JSON under job_key(job_id).

    The record is ``{"status": status, **payload}``, except that the explicit
    *status* argument always wins over a ``"status"`` key inside *payload*.

    Args:
        job_id:  identifier of the job.
        status:  new state, normally one of the JobStatus constants.
        payload: optional extra fields to store alongside the status.
        ttl:     seconds until the record expires (default 3600 = 1 hour).

    Best-effort: errors are logged as warnings and swallowed.
    """
    try:
        data = {"status": status, **(payload or {})}
        # A payload carrying its own "status" (e.g. a forwarded result dict)
        # must not overwrite the state being recorded. Re-assigning keeps the
        # key in first position.
        data["status"] = status
        get_redis().setex(job_key(job_id), ttl, json.dumps(data, default=str))
    except Exception as exc:
        logger.warning("Job state write failed: %s", exc)


def get_job_status(job_id: str) -> Optional[dict]:
    """
    Fetch the stored state record for *job_id*.

    Returns the decoded JSON value. Returns None in three cases: the key is
    absent or empty (never written, or expired), or reading or decoding fails,
    in which case a warning is logged.
    """
    try:
        stored = get_redis().get(job_key(job_id))
        if not stored:
            return None
        return json.loads(stored)
    except Exception as err:
        logger.warning("Job state read failed: %s", err)
        return None


# ── Health check ──────────────────────────────────────────────────

def redis_healthy() -> bool:
    """
    Report whether Redis answers a PING.

    Never raises: any error, including failing to create the client, yields False.
    """
    try:
        client = get_redis()
        return client.ping()
    except Exception:
        return False
Editor is loading...