# d:\chatbot\redis_cache.py
# (file-viewer metadata: unknown · python · a month ago · 5.9 kB · 5 · Indexable)
"""
redis_cache.py
──────────────
Redis-backed result cache for the NL2SQL pipeline.
Two responsibilities:
1. Result cache — skip the full 3-Gemini-call pipeline for repeat/similar queries
2. Job state store — track async job status (if you add Celery later)
Minimum viable version: synchronous result cache only.
Celery task queue wiring is scaffolded but commented — enable when ready to go async.
"""
import os
import json
import hashlib
import logging
import re
from typing import Optional, Any
import redis
logger = logging.getLogger(__name__)

# ── Config ────────────────────────────────────────────────────────
# Connection URL for the Redis server; override with the REDIS_URL env var.
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
# TTLs are read once at import time; a non-integer value raises ValueError here.
CACHE_TTL_HOT = int(os.getenv("CACHE_TTL_HOT", "3600"))  # 1 hour — repeated queries
CACHE_TTL_WARM = int(os.getenv("CACHE_TTL_WARM", "86400"))  # 24 hours — less frequent
# ── Client singleton ──────────────────────────────────────────────
_redis_client: Optional[redis.Redis] = None


def get_redis() -> redis.Redis:
    """
    Return the module-level Redis client, creating it on first use.

    The client is built from REDIS_URL with decode_responses=True. A newly
    built client is pinged before it is stored in the singleton, so a failed
    ping leaves nothing cached and the next call retries the connection.

    Raises:
        Whatever redis.from_url() or ping() raise when Redis is unreachable.
    """
    global _redis_client
    if _redis_client is None:
        client = redis.from_url(REDIS_URL, decode_responses=True)
        client.ping()  # fail fast if Redis is unreachable
        # Store the client only after the ping succeeded; otherwise later
        # calls would hand out a client that was never verified.
        _redis_client = client
        logger.info("Redis connected at %s", REDIS_URL)
    return _redis_client
# ── Key generation ────────────────────────────────────────────────
# Matches punctuation that is dropped during normalisation: any character that
# is not a word character, whitespace, or one of the operator symbols
# < > = + - * / %. Two exceptions are kept: a "." between two digits (decimal
# point) and a "!" that starts "!=".
_NOISE_PUNCT_RE = re.compile(r"(?!(?<=\d)\.(?=\d))(?!!=)[^\w\s<>=+\-*/%]")
_WHITESPACE_RE = re.compile(r"\s+")


def _normalise(nl_query: str) -> str:
    """
    Lowercase, drop non-operator punctuation, collapse whitespace.

    Operator symbols and decimal points are preserved so that queries such as
    "sales > 100" / "sales < 100" or "1.5" / "15" do not normalise to the
    same text.
    """
    q = _NOISE_PUNCT_RE.sub("", nl_query.lower())
    # Strip last: removing punctuation can expose leading/trailing whitespace.
    return _WHITESPACE_RE.sub(" ", q).strip()


def cache_key(nl_query: str) -> str:
    """
    Return the Redis key for a query's cached result.

    The key is "nl2sql:result:" followed by the first 16 hex characters of
    the SHA-256 digest of the normalised query text, so queries that differ
    only in case, spacing or non-operator punctuation share a key.
    """
    normalised = _normalise(nl_query)
    digest = hashlib.sha256(normalised.encode("utf-8")).hexdigest()[:16]
    return f"nl2sql:result:{digest}"
def job_key(job_id: str) -> str:
    """Build the Redis key under which the state of job *job_id* is stored."""
    return "nl2sql:job:{}".format(job_id)
# ── Result cache (synchronous pipeline) ───────────────────────────
def get_cached_result(nl_query: str) -> Optional[dict]:
    """
    Return the cached QueryResult dict for *nl_query*, or None on a miss.

    On a hit the entry's TTL is raised to CACHE_TTL_HOT when less than that
    remains; an entry with more time left (or with no expiry) is left alone,
    so a hit never shortens an entry's life. The returned dict has
    "_cache_hit" set to True.

    Any exception (Redis error, undecodable JSON, ...) is logged as a warning
    and treated as a miss, so callers fall through to the normal pipeline.
    """
    try:
        r = get_redis()
        key = cache_key(nl_query)
        raw = r.get(key)
        if not raw:
            return None
        # TTL returns -1 for "no expiry" and -2 for "key gone"; only extend
        # when a positive remaining time is below the hot TTL. An
        # unconditional EXPIRE would cut entries stored with a longer TTL.
        remaining = r.ttl(key)
        if 0 <= remaining < CACHE_TTL_HOT:
            r.expire(key, CACHE_TTL_HOT)
        logger.info("Cache HIT for query: %.60s...", nl_query)
        result = json.loads(raw)
        result["_cache_hit"] = True
        return result
    except Exception as exc:
        logger.warning("Redis get failed (bypassing cache): %s", exc)
    return None
def set_cached_result(nl_query: str, result_dict: dict, ttl: int = CACHE_TTL_WARM):
    """
    Write a successful QueryResult dict to the cache with the given TTL.

    A result whose "success" entry is missing or falsy is skipped, so
    failures are never cached. The "viz_meta" entry is left out of the
    stored payload, and values json cannot encode are written via str().
    Any exception is logged as a warning and swallowed.
    """
    try:
        succeeded = result_dict.get("success", False)
        if succeeded:
            client = get_redis()
            redis_key = cache_key(nl_query)
            # Copy everything except the fields that are not stored.
            payload = dict(result_dict)
            payload.pop("viz_meta", None)
            serialised = json.dumps(payload, default=str)
            client.setex(redis_key, ttl, serialised)
            logger.info("Cached result for query: %.60s... (TTL=%ds)", nl_query, ttl)
    except Exception as err:
        logger.warning("Redis set failed (result not cached): %s", err)
def invalidate(nl_query: str):
    """
    Delete the cached result for *nl_query* (e.g. after a schema change).

    Failures are logged as warnings and not raised.
    """
    try:
        client = get_redis()
        client.delete(cache_key(nl_query))
    except Exception as err:
        logger.warning("Redis delete failed: %s", err)
def invalidate_all():
    """
    Delete every "nl2sql:result:*" cache entry. Call after a schema refresh.

    Keys are discovered with SCAN and removed in batches instead of a single
    KEYS call, so the server is not blocked while the whole keyspace is
    walked. Failures are logged as warnings and not raised.
    """
    batch_size = 500
    try:
        r = get_redis()
        deleted = 0
        batch = []
        for key in r.scan_iter(match="nl2sql:result:*", count=batch_size):
            batch.append(key)
            if len(batch) >= batch_size:
                # DELETE reports how many keys were actually removed, which
                # keeps the total right even if SCAN yields a key twice.
                deleted += r.delete(*batch)
                batch.clear()
        if batch:
            deleted += r.delete(*batch)
        if deleted:
            logger.info("Invalidated %d cached results", deleted)
    except Exception as exc:
        logger.warning("Redis flush failed: %s", exc)
# ── Job state store (scaffold for future Celery async) ────────────
class JobStatus:
    """Namespace of the string values used for a job's "status" field."""

    PENDING, RUNNING = "pending", "running"
    DONE, FAILED = "done", "failed"
def set_job_status(job_id: str, status: str, payload: Optional[dict] = None, ttl: int = 3600):
    """
    Write a job's state to Redis as JSON under job_key(job_id).

    The stored object is *payload* (if any) plus a "status" entry holding
    *status*; the record expires after *ttl* seconds. Values json cannot
    encode are written via str(). Failures are logged as warnings and not
    raised.
    """
    try:
        # Spread the payload first so the explicit status argument always
        # wins, even when the payload carries its own "status" key.
        data = {**(payload or {}), "status": status}
        get_redis().setex(job_key(job_id), ttl, json.dumps(data, default=str))
    except Exception as exc:
        logger.warning("Job state write failed: %s", exc)
def get_job_status(job_id: str) -> Optional[dict]:
    """
    Fetch and decode the stored state for *job_id*.

    Returns None when nothing is stored under the job's key, or when the read
    or JSON decoding fails (the failure is logged as a warning).
    """
    try:
        stored = get_redis().get(job_key(job_id))
        if stored:
            return json.loads(stored)
    except Exception as err:
        logger.warning("Job state read failed: %s", err)
    return None
# ── Health check ──────────────────────────────────────────────────
def redis_healthy() -> bool:
    """
    Report whether Redis answers a PING.

    Never raises: any exception while obtaining the client or pinging it
    yields False.
    """
    try:
        return bool(get_redis().ping())
    except Exception:
        return False