d:\chatbot\nl2sql03.py
unknown
python
a month ago
54 kB
5
Indexable
import os
import re
import logging
import duckdb
import pandas as pd
import sqlfluff
from pathlib import Path
from datetime import datetime
from typing import Optional
from dataclasses import dataclass, field
import pyodbc
import paramiko
from sshtunnel import SSHTunnelForwarder
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
logger = logging.getLogger(__name__)
from dotenv import load_dotenv
load_dotenv()
from qdrant_index import (
ensure_collections,
index_table,
index_nl_sql_pair,
build_compact_schema_from_qdrant,
retrieve_few_shot_examples,
)
from redis_cache import get_cached_result, set_cached_result, invalidate_all
# ─────────────────────────────────────────────────────────────────
# Config — all values read from environment variables
# ─────────────────────────────────────────────────────────────────
# Gemini settings.
# NOTE(review): the fallback value below is a literal API key committed to source.
# Prefer supplying GEMINI_API_KEY through the environment (or .env) and rotating
# this key; it is left in place here so existing deployments keep working.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "AIzaSyAi-al-Cs4oL-8M_YQqSU8YVqN4lv7itqo")
GEMINI_MODEL = os.environ.get("GEMINI_MODEL", "gemini-2.5-flash")
# Environment values arrive as strings, hence the int() coercion.
MAX_RETRIES = int(os.environ.get("MAX_RETRIES", 5))
# Directories are relative to the working directory unless overridden.
OUTPUT_DIR = Path(os.environ.get("OUTPUT_DIR", "output"))
UPLOAD_DIR = Path(os.environ.get("UPLOAD_DIR", "uploads"))
# ── MS SQL Server via SSH Tunnel ──
# SSH_* describe the jump host; MSSQL_* describe the database reached through it.
SSH_HOST = os.environ.get("SSH_HOST", "3.108.55.154")
SSH_USER = os.environ.get("SSH_USER", "ubuntu")
SSH_PKEY_PATH = os.environ.get("SSH_PKEY_PATH", "/home/shared/pcsap/pcsap-DB.pem")
MSSQL_DB_HOST = os.environ.get("MSSQL_DB_HOST", "16.0.2.77")
MSSQL_DB_USER = os.environ.get("MSSQL_DB_USER", "pcsap@readonly")
# NOTE(review): a database password is hard-coded as the fallback here as well;
# it should come from the environment only.
MSSQL_DB_PASS = os.environ.get("MSSQL_DB_PASS", "pcsap@123")
MSSQL_DB_NAME = os.environ.get("MSSQL_DB_NAME", "FCI")
class TunnelledMSSQL:
    """
    Holder for the SSH tunnel that fronts the MS SQL Server.

    Obtain the shared object through ``get_instance()`` and ask it for a new
    pyodbc connection with ``get_connection()`` whenever one is needed; the
    tunnel itself is opened lazily and reused while it stays active.
    """

    _instance = None  # shared object, created on first get_instance() call

    def __init__(self):
        self._tunnel = None

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it the first time."""
        if cls._instance is not None:
            return cls._instance
        cls._instance = cls()
        return cls._instance

    def _ensure_tunnel(self):
        """Open the SSH tunnel unless an active one already exists."""
        current = self._tunnel
        if current and current.is_active:
            return
        private_key = paramiko.RSAKey.from_private_key_file(SSH_PKEY_PATH)
        forwarder = SSHTunnelForwarder(
            (SSH_HOST, 22),
            ssh_username=SSH_USER,
            ssh_pkey=private_key,
            remote_bind_address=(MSSQL_DB_HOST, 1433),
            set_keepalive=60,
        )
        # Store the forwarder before starting it, as the tunnel is looked up
        # through the attribute from here on.
        self._tunnel = forwarder
        self._tunnel.start()
        logger.info("SSH tunnel established on localhost:%d", self._tunnel.local_bind_port)

    def get_connection(self):
        """Return a new pyodbc connection routed through the tunnel's local port."""
        self._ensure_tunnel()
        settings = (
            "DRIVER={ODBC Driver 18 for SQL Server};",
            f"SERVER=127.0.0.1,{self._tunnel.local_bind_port};",
            f"DATABASE={MSSQL_DB_NAME};",
            f"UID={MSSQL_DB_USER};",
            f"PWD={MSSQL_DB_PASS};",
            "Encrypt=yes;",
            "TrustServerCertificate=yes;",
            "Connection Timeout=30;",
        )
        return pyodbc.connect("".join(settings))

    def stop(self):
        """Stop the tunnel if one was ever created."""
        if not self._tunnel:
            return
        self._tunnel.stop()
        logger.info("SSH tunnel closed.")
# ─────────────────────────────────────────────────────────────────
# Data-classes
# ─────────────────────────────────────────────────────────────────
@dataclass
class ColumnMeta:
    """Description of one column of a loaded table."""
    name: str  # column name as it appears in the DataFrame
    dtype: str # pandas dtype string e.g. "int64"
    sql_type: str # DuckDB SQL type e.g. "BIGINT"
    nullable: bool  # True when the inspected data contained at least one null
    sample_values: list  # a few non-null example values (may be empty)
@dataclass
class TableMeta:
    """Description of one table known to the DataLoader."""
    table_name: str  # internal (DuckDB-safe) table name
    source: str # "csv" | "excel" | "mssql_tunnel:<original>"
    db_name: str # filename or SQL Server DB name
    row_count: int  # number of rows in the table
    columns: list # list[ColumnMeta]
    schema_info: str # rich text block sent to LLM
@dataclass
class GenerateResult:
    """Result of the /generate step (SQL generation only, no execution)."""
    success: bool  # False when table selection or SQL generation failed
    sql: str  # generated SQL; empty string on failure
    tables_used: list = field(default_factory=list)  # internal names of the selected tables
    lint_warnings: list = field(default_factory=list)  # formatted sqlfluff messages
    error: Optional[str] = None  # failure description, None on success
@dataclass
class ExecuteResult:
    """Result of the /execute step (raw SQL run, no LLM involved)."""
    success: bool  # whether the statement ran without error
    sql: str  # the SQL that was executed
    rows: int = 0  # presumably the number of result rows — confirm against the producer
    columns: list = field(default_factory=list)  # result column names
    data: list = field(default_factory=list) # list of dicts
    output_csv: Optional[str] = None  # presumably the path of the saved CSV, if any — confirm
    error: Optional[str] = None  # failure description, None on success
@dataclass
class QueryResult:
    """Result of the full /query pipeline (generate + validate + execute)."""
    success: bool  # whether a validated query produced a result
    final_sql: str  # the last SQL statement that was tried
    rows: int = 0  # number of result rows
    columns: list = field(default_factory=list)  # result column names
    data: list = field(default_factory=list)  # result records as a list of dicts
    tables_used: list = field(default_factory=list)  # internal names of the selected tables
    attempts: int = 0  # how many generate/validate rounds were run
    output_csv: Optional[str] = None  # path of the saved CSV, if any
    history: list = field(default_factory=list)  # per-attempt records
    error: Optional[str] = None  # failure description, None on success
    reasoning: str = ""  # Agent 0+1 reasoning narrative
    # The original annotation was Optional[any], i.e. the builtin *function*
    # any() used as a type. A forward reference names the intended class
    # (defined later in this module) without requiring an extra import.
    viz_meta: Optional["VizMeta"] = None  # VizMeta instance
    column_roles: dict = field(default_factory=dict)
@dataclass
class VizMeta:
    """Chart recommendation describing how a query result should be visualized."""
    chart_type: str # "bar" | "line" | "pie" | "scatter" | "table" | ...
    title: str # e.g. "Monthly Revenue by Category"
    x_axis: Optional[str] = None # primary axis label
    y_axis: Optional[str] = None # primary value axis label
    secondary_axis: Optional[str] = None # only if dual-axis chart
    series_label: Optional[str] = None # legend / series name
    reasoning: str = "" # why this chart type was chosen
# ─────────────────────────────────────────────────────────────────
# Internal helpers
# ─────────────────────────────────────────────────────────────────
_DTYPE_TO_SQL: dict = {
"int64": "BIGINT",
"int32": "INT",
"int16": "SMALLINT",
"int8": "TINYINT",
"float64": "FLOAT",
"float32": "REAL",
"bool": "BOOLEAN",
"object": "VARCHAR",
"string": "VARCHAR",
"datetime64[ns]": "TIMESTAMP",
"date": "DATE",
"timedelta64[ns]": "INTERVAL",
}
def _map_sql_type(dtype) -> str:
key = str(dtype)
if key in _DTYPE_TO_SQL:
return _DTYPE_TO_SQL[key]
if key.startswith("datetime"):
return "TIMESTAMP"
if key.startswith("int"):
return "BIGINT"
if key.startswith("float"):
return "FLOAT"
return "VARCHAR"
def _build_schema_info(table_name: str, source: str, db_name: str, df: pd.DataFrame, keys: dict = None, include_samples: bool = True):
    """Describe *df* both as ColumnMeta objects and as a text block for the LLM.

    ``keys`` may carry "primary_keys", "foreign_keys" and "unique_keys"; when
    given, matching columns are tagged and PK/FK summary lines are added.
    ``include_samples`` controls whether example values are collected and shown.

    Returns a ``(columns, schema_info)`` pair.
    """
    key_info = keys or {}
    pk_cols = key_info.get("primary_keys", [])
    fk_defs = key_info.get("foreign_keys", [])
    uq_cols = key_info.get("unique_keys", [])

    # column name -> (referenced table, referenced column)
    fk_targets = {}
    for fk in fk_defs:
        fk_targets[fk["column"]] = (fk["references_table"], fk["references_column"])

    columns: list = []
    col_lines: list = []
    for col in df.columns:
        values = df[col]
        mapped_type = _map_sql_type(values.dtype)
        has_nulls = bool(values.isna().any())
        if include_samples:
            samples = values.dropna().head(2).tolist()
        else:
            samples = []

        # Constraint markers shown next to the column.
        markers = []
        if col in pk_cols:
            markers.append("PK")
        if col in fk_targets:
            ref_table, ref_column = fk_targets[col]
            markers.append(f"FK → {ref_table}.{ref_column}")
        if col in uq_cols:
            markers.append("UNIQUE")
        marker_text = f" [{', '.join(markers)}]" if markers else ""

        columns.append(ColumnMeta(
            name=col,
            dtype=str(values.dtype),
            sql_type=mapped_type,
            nullable=has_nulls,
            sample_values=samples,
        ))

        nullability = "NULLABLE" if has_nulls else "NOT NULL"
        line = f' "{col}" {mapped_type} [{nullability}]{marker_text}'
        if include_samples and samples:
            line += f' -- e.g. {samples}'
        col_lines.append(line)

    # Optional summary lines for the keys.
    fk_lines = [
        f"│ {fk['column']} → {fk['references_table']}.{fk['references_column']}"
        for fk in fk_defs
    ]
    if fk_lines:
        fk_block = "│ Foreign Keys:\n" + "\n".join(fk_lines) + "\n"
    else:
        fk_block = ""
    if pk_cols:
        pk_block = f"│ Primary Key : {', '.join(pk_cols)}\n"
    else:
        pk_block = ""

    pieces = [
        f"┌─ Table : {table_name}\n",
        f"│ Source : {source}\n",
        f"│ Database/File: {db_name}\n",
        f"│ Row count : {len(df):,}\n",
        pk_block,
        fk_block,
        f"│ Columns ({len(columns)}):\n",
        "\n".join(col_lines),
        "\n└─",
    ]
    schema_info = "".join(pieces)
    return columns, schema_info
def _safe_name(raw: str) -> str:
return re.sub(r"\W+", "_", raw).lower().strip("_")
def _unique_name(name: str, existing: set) -> str:
candidate = name
counter = 1
while candidate in existing:
candidate = f"{name}_{counter}"
counter += 1
return candidate
def _ensure_dir(path: Path) -> Path:
path.mkdir(parents=True, exist_ok=True)
return path
def _output_csv_path(label: str) -> Path:
    """Build a timestamped CSV path under OUTPUT_DIR, creating the directory.

    The file name is a slug made from the first 40 characters of *label*
    followed by the current local time down to the second.
    """
    head = label[:40]
    slug = re.sub(r"\W+", "_", head).strip("_").lower()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target_dir = _ensure_dir(OUTPUT_DIR)
    return target_dir / f"{slug}__{timestamp}.csv"
def _df_to_response(df: pd.DataFrame) -> tuple:
"""Convert a DataFrame to (rows, columns, data) ready for JSON serialisation."""
# Replace NaN/NaT with None so json.dumps doesn't choke
clean = df.where(pd.notnull(df), other=None)
return len(clean), list(clean.columns), clean.to_dict(orient="records")
def _sanitize_mssql_sql(sql: str) -> str:
"""
Convert DuckDB-style double-quoted identifiers to MS SQL bracket style.
e.g. T1."QuestionName" → T1.[QuestionName]
"Id" → [Id]
Also strips trailing semicolons.
"""
# Replace "identifier" with [identifier] — but NOT string literals
# This regex matches double-quoted words that are used as identifiers
# (i.e., not preceded by = or LIKE or IN — those are string values)
sanitized = re.sub(
r'(?<![=<>!,\(\s])(\."?)([A-Za-z_][A-Za-z0-9_ ]*)"',
lambda m: f'.[{m.group(2).lstrip(chr(34))}]',
sql
)
# Also handle standalone "ColumnName" not preceded by a dot
sanitized = re.sub(
r'(?<!\.)(?<![=<>!(,\s])"([A-Za-z_][A-Za-z0-9_ ]*)"',
r'[\1]',
sanitized
)
return sanitized.rstrip(";").strip()
# ─────────────────────────────────────────────────────────────────
# DataLoader
# ─────────────────────────────────────────────────────────────────
class DataLoader:
    """
    Single in-memory DuckDB instance that holds all loaded tables.
    Sources: CSV / Excel files + SQL Server tables.
    File-based tables are registered with their data; MS SQL tables are
    registered as empty placeholders (schema only) and queried on the server.
    """

    def __init__(self):
        self.conn = duckdb.connect(database=":memory:")
        # internal table name -> TableMeta
        self.tables: dict[str, TableMeta] = {}
        # internal table name -> original MS SQL table name
        self._mssql_table_map: dict[str, str] = {}

    @staticmethod
    def _mssql_table_ref(original_name: str) -> str:
        """Return ``[dbo].[name]`` with any ``]`` in the name doubled.

        Doubling the closing bracket is the T-SQL escape, so a table name can
        never terminate the bracketed identifier early when it is placed
        into a statement.
        """
        return "[dbo].[" + original_name.replace("]", "]]") + "]"

    # ── registration ──────────────────────────
    def _register_df(self, df: pd.DataFrame, table_name: str, source: str, db_name: str):
        """Register *df* in DuckDB, record its metadata and index it in Qdrant.

        A Qdrant failure is logged and otherwise ignored so loading continues.
        """
        self.conn.register(table_name, df)
        columns, schema_info = _build_schema_info(table_name, source, db_name, df)
        self.tables[table_name] = TableMeta(
            table_name=table_name,
            source=source,
            db_name=db_name,
            row_count=len(df),
            columns=columns,
            schema_info=schema_info,
        )
        try:
            index_table(
                internal_name=table_name,
                display_name=table_name,
                schema_info=schema_info,
                row_count=len(df),
            )
        except Exception as exc:
            logger.warning("Qdrant index failed for '%s': %s — continuing without it", table_name, exc)
        logger.info("Loaded table '%s' — %d rows × %d cols [%s]",
                    table_name, len(df), len(df.columns), source)

    # ── CSV / Excel ───────────────────────────
    def load_file(self, file_path: str) -> str:
        """Load one CSV or Excel file. Returns the DuckDB table name assigned.

        Raises ValueError for any extension other than .csv / .xlsx / .xls.
        """
        path = Path(file_path)
        ext = path.suffix.lower()
        if ext == ".csv":
            df, source = pd.read_csv(path), "csv"
        elif ext in (".xlsx", ".xls"):
            df, source = pd.read_excel(path), "excel"
        else:
            raise ValueError(f"Unsupported file type: {ext}. Use .csv / .xlsx / .xls")
        table_name = _unique_name(_safe_name(path.stem), set(self.tables))
        self._register_df(df, table_name, source, db_name=path.name)
        return table_name

    def load_files(self, file_paths: list) -> list:
        """Load several files in order; returns the assigned table names."""
        return [self.load_file(fp) for fp in file_paths]

    # ── MS SQL Server via SSH Tunnel ──────────────
    def _list_mssql_tunnelled_tables(self, conn) -> list:
        """Return the names of all base tables in the ``dbo`` schema.

        The catalog filter uses MSSQL_DB_NAME (the database the connection
        was opened against) rather than a fixed literal.
        """
        cursor = conn.cursor()
        cursor.execute(
            "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES "
            "WHERE TABLE_TYPE = 'BASE TABLE' "
            "AND TABLE_CATALOG = ? "
            "AND TABLE_SCHEMA = 'dbo' "
            "ORDER BY TABLE_NAME",
            MSSQL_DB_NAME,
        )
        return [row[0] for row in cursor.fetchall()]

    def _build_mssql_schema_info(self, conn, original_name: str) -> tuple:
        """
        Fetch only schema + sample rows from MS SQL — NO full data load.

        Returns ``(col_rows, row_count, sample_df)``: the INFORMATION_SCHEMA
        column rows, the table's row count and a DataFrame with up to three
        sample rows. Nothing is registered in DuckDB here.
        """
        cursor = conn.cursor()
        # Column metadata
        cursor.execute(
            """
            SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE
            FROM INFORMATION_SCHEMA.COLUMNS
            WHERE TABLE_CATALOG = ?
            AND TABLE_SCHEMA = 'dbo'
            AND TABLE_NAME = ?
            ORDER BY ORDINAL_POSITION
            """,
            MSSQL_DB_NAME,
            original_name,
        )
        col_rows = cursor.fetchall()
        # Identifiers cannot be bound as parameters, so the name is escaped.
        table_ref = self._mssql_table_ref(original_name)
        # Row count
        cursor.execute(f"SELECT COUNT(*) FROM {table_ref}")
        row_count = cursor.fetchone()[0]
        # Three sample rows for schema context
        sample_df = pd.read_sql(f"SELECT TOP 3 * FROM {table_ref}", conn)
        return col_rows, row_count, sample_df

    def load_mssql_tunnelled_table(self, original_name: str, conn) -> str:
        """Register one MS SQL table (schema only) and return its internal name.

        An empty DataFrame with the table's columns is registered in DuckDB so
        the table is known to the pipeline; the rows themselves stay on the
        server. Key lookup and Qdrant indexing failures are logged and skipped.
        """
        # col_rows is returned by the helper but not used by this method.
        col_rows, row_count, sample_df = self._build_mssql_schema_info(conn, original_name)
        # ── Fetch key relationships ──
        try:
            keys = self._fetch_mssql_keys(conn, original_name)
        except Exception as exc:
            logger.warning("Could not fetch keys for '%s': %s", original_name, exc)
            keys = {"primary_keys": [], "foreign_keys": [], "unique_keys": []}
        name = _unique_name(_safe_name(original_name), set(self.tables))
        self.conn.register(name, sample_df.iloc[0:0])
        # Pass keys into schema builder
        columns, schema_info = _build_schema_info(
            f"[dbo].[{original_name}]",  # ← LLM sees this as the table name
            f"mssql_tunnel:{original_name}",
            MSSQL_DB_NAME,
            sample_df,
            keys=keys,
            include_samples=False
        )
        # The builder counted the sample rows; substitute the real row count.
        schema_info = re.sub(
            r"│ Row count : [\d,]+",
            f"│ Row count : {row_count:,}",
            schema_info
        )
        self.tables[name] = TableMeta(
            table_name=name,
            source=f"mssql_tunnel:{original_name}",
            db_name=MSSQL_DB_NAME,
            row_count=row_count,
            columns=columns,
            schema_info=schema_info,
        )
        self._mssql_table_map[name] = original_name
        logger.info("Registered MS SQL schema '%s' — %d rows, %d FKs",
                    name, row_count, len(keys["foreign_keys"]))
        try:
            display = f"[dbo].[{original_name}]"
            index_table(
                internal_name=name,
                display_name=display,
                schema_info=schema_info,
                row_count=row_count,
            )
        except Exception as exc:
            logger.warning("Qdrant index failed for '%s': %s — continuing without it", name, exc)
        return name

    def load_mssql_tunnelled_tables(self, table_names: Optional[list] = None) -> list:
        """
        Load MS SQL Server tables over SSH tunnel.
        Pass None to auto-discover and load ALL tables.

        Tables that fail to load are logged and skipped. Returns the internal
        names of the tables that were registered.
        """
        tunnelled_db = TunnelledMSSQL.get_instance()
        conn = tunnelled_db.get_connection()
        results = []
        try:
            if not table_names:
                logger.info("Discovering all tables in MS SQL DB '%s'", MSSQL_DB_NAME)
                table_names = self._list_mssql_tunnelled_tables(conn)
                logger.info("Found %d table(s): %s", len(table_names), table_names)
            for t in table_names:
                try:
                    name = self.load_mssql_tunnelled_table(t, conn)
                    results.append(name)
                except Exception as exc:
                    logger.warning("Skipping table '%s': %s", t, exc)
        finally:
            # Close the connection even when discovery itself raises.
            conn.close()
        return results

    def execute_on_mssql(self, sql: str) -> pd.DataFrame:
        """
        Run SQL directly on MS SQL Server via tunnel.
        No translation needed — LLM generates [dbo].[TableName] directly for mssql dialect.
        """
        tunnelled_db = TunnelledMSSQL.get_instance()
        conn = tunnelled_db.get_connection()
        try:
            return pd.read_sql(sql, conn)
        finally:
            conn.close()

    def is_mssql_table(self, name: str) -> bool:
        """True when *name* is the internal name of a registered MS SQL table."""
        return name in self._mssql_table_map

    def uses_mssql_tables(self, table_names: list) -> bool:
        """True when at least one of *table_names* is an MS SQL table."""
        return any(self.is_mssql_table(n) for n in table_names)

    # ── query / introspection ─────────────────
    def execute(self, sql: str) -> pd.DataFrame:
        """Run *sql* on the in-memory DuckDB connection and return a DataFrame."""
        return self.conn.execute(sql).df()

    def schema_for_tables(self, names: list) -> str:
        """Join the schema blocks of the known tables in *names*.

        Unknown names are ignored; "(no matching tables)" is returned when
        nothing matches.
        """
        parts = [self.tables[n].schema_info for n in names if n in self.tables]
        return "\n\n".join(parts) or "(no matching tables)"

    def all_schemas_summary(self) -> str:
        """Schema blocks of every loaded table."""
        return self.schema_for_tables(list(self.tables.keys()))

    def table_names(self) -> list:
        """Internal names of all loaded tables, in load order."""
        return list(self.tables.keys())

    def validate_table_names(self, names: list) -> tuple:
        """Split *names* into ``(valid, unknown)`` lists."""
        valid = [n for n in names if n in self.tables]
        unknown = [n for n in names if n not in self.tables]
        return valid, unknown

    def table_info_dict(self, table_name: str) -> dict:
        """Return a JSON-serialisable dict describing one table."""
        meta = self.tables[table_name]
        return {
            "table_name": meta.table_name,
            "source": meta.source,
            "db_name": meta.db_name,
            "row_count": meta.row_count,
            "columns": [
                {
                    "name": c.name,
                    "sql_type": c.sql_type,
                    "dtype": c.dtype,
                    "nullable": c.nullable,
                    "sample_values": [str(v) for v in c.sample_values],
                }
                for c in meta.columns
            ],
        }

    def all_tables_summary_compact(self, max_tables: int = 200) -> str:
        """One compact entry per table (at most *max_tables*) for table selection.

        Each entry lists the internal name, the display name, the row count,
        any key lines found in the table's schema block and the column names.
        """
        lines = []
        for name, meta in list(self.tables.items())[:max_tables]:
            col_names = ", ".join(c.name for c in meta.columns)
            display_name = (
                f"[dbo].[{self._mssql_table_map[name]}]"
                if name in self._mssql_table_map
                else name
            )
            # Key information is recovered from the rendered schema text.
            key_info = ""
            for line in meta.schema_info.splitlines():
                if "Primary Key" in line:
                    key_info += " " + line.strip().replace("│ ", "")
                if "FK →" in line:
                    key_info += " " + line.strip().replace("│ ", "") + ";"
            lines.append(
                f"- {name} | display: {display_name} | ({meta.row_count:,} rows){key_info}\n"
                f" columns: {col_names}"
            )
        return "\n".join(lines)

    def _fetch_mssql_keys(self, conn, original_name: str) -> dict:
        """
        Fetch PK, FK, and unique constraints for a table from MS SQL.
        Returns a dict with keys: primary_keys, foreign_keys, unique_keys

        The catalog filter uses MSSQL_DB_NAME, passed as a bound parameter.
        """
        cursor = conn.cursor()
        # ── Primary Keys ──
        cursor.execute(
            """
            SELECT c.COLUMN_NAME
            FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc
            JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE c
            ON tc.CONSTRAINT_NAME = c.CONSTRAINT_NAME
            WHERE tc.TABLE_CATALOG = ?
            AND tc.TABLE_SCHEMA = 'dbo'
            AND tc.TABLE_NAME = ?
            AND tc.CONSTRAINT_TYPE = 'PRIMARY KEY'
            """,
            MSSQL_DB_NAME,
            original_name,
        )
        primary_keys = [row[0] for row in cursor.fetchall()]
        # ── Foreign Keys ──
        cursor.execute(
            """
            SELECT
            kcu.COLUMN_NAME,
            ccu.TABLE_NAME AS referenced_table,
            ccu.COLUMN_NAME AS referenced_column
            FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc
            JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu
            ON rc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME
            JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE ccu
            ON rc.UNIQUE_CONSTRAINT_NAME = ccu.CONSTRAINT_NAME
            WHERE kcu.TABLE_CATALOG = ?
            AND kcu.TABLE_SCHEMA = 'dbo'
            AND kcu.TABLE_NAME = ?
            """,
            MSSQL_DB_NAME,
            original_name,
        )
        foreign_keys = [
            {
                "column": row[0],
                "references_table": row[1],
                "references_column": row[2],
            }
            for row in cursor.fetchall()
        ]
        # ── Unique Constraints ──
        cursor.execute(
            """
            SELECT c.COLUMN_NAME
            FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc
            JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE c
            ON tc.CONSTRAINT_NAME = c.CONSTRAINT_NAME
            WHERE tc.TABLE_CATALOG = ?
            AND tc.TABLE_SCHEMA = 'dbo'
            AND tc.TABLE_NAME = ?
            AND tc.CONSTRAINT_TYPE = 'UNIQUE'
            """,
            MSSQL_DB_NAME,
            original_name,
        )
        unique_keys = [row[0] for row in cursor.fetchall()]
        return {
            "primary_keys": primary_keys,
            "foreign_keys": foreign_keys,
            "unique_keys": unique_keys,
        }
# ─────────────────────────────────────────────────────────────────
# Agent 0 — Table Selector
# ─────────────────────────────────────────────────────────────────
# System prompt for the table selector. The reply format it demands
# ("TABLES: ..." / "REASONING: ...") is what TableSelectorAgent.select parses.
_TABLE_SELECTOR_SYSTEM = """
You are a data analyst assistant. Decide which database tables are needed to
answer the user's question.
You will be given a list of tables. Each entry shows:
- INTERNAL_NAME | display: [dbo].[ActualName] | (N rows) | columns: col1, col2, ...
IMPORTANT: In your response, use ONLY the INTERNAL_NAME exactly as shown
before the | symbol. Do NOT use the display name.
Return your response in this EXACT format — nothing else:
TABLES: internal_name1,internal_name2
REASONING: <2-3 sentences explaining which tables you chose and why>
"""
# Human prompt for the table selector; placeholders: {all_schemas}, {question}.
_TABLE_SELECTOR_HUMAN = """
## All Available Table Schemas
{all_schemas}
## User Question
{question}
Return TABLES and REASONING now:
"""
class TableSelectorAgent:
    """Agent 0: asks the LLM which of the available tables a question needs."""

    def __init__(self):
        llm = ChatGoogleGenerativeAI(
            model=GEMINI_MODEL, google_api_key=GEMINI_API_KEY, temperature=0
        )
        prompt = ChatPromptTemplate.from_messages([
            ("system", _TABLE_SELECTOR_SYSTEM),
            ("human", _TABLE_SELECTOR_HUMAN),
        ])
        self._llm = llm
        self._chain = prompt | llm

    def select(self, all_schemas, question, available_tables, pinned_tables=None):
        """Return ``(table_names, reasoning)`` for *question*.

        Pinned tables, or a single available table, short-circuit the LLM call.
        Otherwise the reply is parsed and names not in *available_tables* are
        dropped.

        Raises:
            ValueError: the reply contained no valid table name (a ValueError
                raised while invoking the chain passes through unchanged too).
            RuntimeError: any other failure, with the cause chained.
        """
        if pinned_tables:
            return pinned_tables, "Tables manually pinned by user."
        if len(available_tables) == 1:
            return available_tables, f"Only one table available: '{available_tables[0]}'."
        try:
            reply = self._chain.invoke({"all_schemas": all_schemas, "question": question})
            text = reply.content.strip()
            tables_hit = re.search(r"TABLES:\s*(.+)", text, re.IGNORECASE)
            reason_hit = re.search(r"REASONING:\s*(.+)", text, re.IGNORECASE | re.DOTALL)
            listed = tables_hit.group(1) if tables_hit else ""
            reasoning = reason_hit.group(1).strip() if reason_hit else ""
            valid = []
            for piece in listed.split(","):
                candidate = piece.strip()
                if candidate and candidate in available_tables:
                    valid.append(candidate)
            if not valid:
                # The LLM answered, but none of its names is a known table.
                raise ValueError(
                    f"Table selector returned no valid table names from response: {text!r}"
                )
            return valid, reasoning
        except ValueError:
            # Passed through untouched for the caller to handle.
            raise
        except Exception as exc:
            raise RuntimeError(f"Table selection failed: {exc}") from exc
# ─────────────────────────────────────────────────────────────────
# Agent 1 — SQL Generator
# ─────────────────────────────────────────────────────────────────
# System prompt for SQL generation (DuckDB wording). _SQLGeneratorChain derives
# its MS SQL variant from this text with str.replace(), so the sentences it
# targets must stay in sync with the replacements made there.
_GENERATOR_SYSTEM = """
You are an expert SQL engineer. Convert the user's natural language question into
a valid DuckDB-compatible SQL query using the provided table schemas.
Rules:
1. Return ONLY the raw SQL — no markdown, no backticks, no explanation.
2. Use the EXACT table names and column names from the schema (case-sensitive).
3. Always qualify column names with their table name when multiple tables are involved:
table_name."Column Name"
4. Use double-quotes around column names that contain spaces or special characters.
5. You MAY use JOINs across tables when the question requires it.
6. If aggregation is needed, include GROUP BY with all non-aggregate SELECT columns.
7. Match column data types — do not compare VARCHAR columns with numeric literals.
8. Do NOT end with a semicolon.
9. ALWAYS alias every aggregated or computed column with a clean, readable name using AS.
Examples:
COUNT(*) AS total_count
SUM(orders.revenue) AS total_revenue
AVG(CAST(dmart.Price AS DOUBLE)) AS avg_price
MAX(sales.amount) AS max_amount
Never leave an aggregation or CAST expression without an alias.
10. ALWAYS alias every CAST expression, even outside aggregations.
"""
# Human prompt for SQL generation; placeholders: {schema}, {question},
# {error_context} (empty on the first attempt).
_GENERATOR_HUMAN = """
## Relevant Table Schemas
{schema}
## User Question
{question}
{error_context}
Generate the SQL query now:
"""
# Retry context filled in when a previous attempt failed; placeholders:
# {prev_sql}, {error}, {lint_warnings}.
# NOTE(review): the final "CRITICAL" line asks for square-bracket identifiers,
# which is MS SQL syntax, yet this template is formatted for every dialect in
# _SQLGeneratorChain.generate — confirm whether DuckDB retries should see it.
_ERROR_CONTEXT_TMPL = """
## Previous Attempt — FAILED (please fix)
SQL tried:
{prev_sql}
Error received:
{error}
Lint warnings:
{lint_warnings}
Use the exact column names, SQL types, and table names from the schema above to fix this.
CRITICAL: Use square brackets [ColumnName] NOT double quotes "ColumnName" for all identifiers.
"""
class SQLGeneratorAgent:
    # Builds a Gemini chat chain from the generator prompts.
    # NOTE(review): generate() below reads self._resolve_pinned, self._selector,
    # self.loader, self._gen_chain and self._validator, none of which are set by
    # __init__ here; as written a call would raise AttributeError. The body
    # mirrors NL2SQLPipeline.generate — confirm whether this method is dead code.
    def __init__(self):
        self._llm = ChatGoogleGenerativeAI(
            model=GEMINI_MODEL, google_api_key=GEMINI_API_KEY, temperature=0.1
        )
        # Prompt template piped into the model.
        self._chain = ChatPromptTemplate.from_messages([
            ("system", _GENERATOR_SYSTEM),
            ("human", _GENERATOR_HUMAN),
        ]) | self._llm
    def generate(
        self,
        question: str,
        pinned_tables: Optional[list] = None,
    ) -> GenerateResult:
        """
        Runs Agent 0 (table selection) + Agent 1 (SQL generation).
        Returns the generated SQL with lint warnings — does NOT execute it.
        """
        pinned_tables = self._resolve_pinned(pinned_tables)
        # Unpack tuple — discard reasoning since /generate doesn't need it
        relevant, _ = self._selector.select(
            all_schemas = self.loader.all_schemas_summary(),
            question = question,
            available_tables = self.loader.table_names(),
            pinned_tables = pinned_tables,
        )
        schema = self.loader.schema_for_tables(relevant)
        try:
            sql = self._gen_chain.generate(schema=schema, question=question)
        except Exception as exc:
            # Generation failures are reported in the result, not raised.
            return GenerateResult(success=False, sql="", error=str(exc))
        # Lint only — no execution here
        lint_warns = self._validator._lint(sql)
        return GenerateResult(
            success = True,
            sql = sql,
            tables_used = relevant,
            lint_warnings = lint_warns,
        )
class _SQLGeneratorChain:
def __init__(self):
self._llm = ChatGoogleGenerativeAI(
model=GEMINI_MODEL, google_api_key=GEMINI_API_KEY, temperature=0.1
)
self._chain_duckdb = ChatPromptTemplate.from_messages([
("system", _GENERATOR_SYSTEM),
("human", _GENERATOR_HUMAN),
]) | self._llm
# Pre-build MS SQL system prompt once
_mssql_system = _GENERATOR_SYSTEM.replace(
"valid DuckDB-compatible SQL query",
"valid MS SQL Server (T-SQL) query"
).replace(
"Use double-quotes around column names that contain spaces or special characters.",
"Use square brackets around column names and table names: [ColumnName], NOT \"ColumnName\"."
).replace(
"Always qualify column names with their table name when multiple tables are involved:\n table_name.\"Column Name\"",
"Always qualify column names with their table name when multiple tables are involved:\n alias.[ColumnName]"
).replace(
"Do NOT end with a semicolon.",
"Do NOT end with a semicolon.\n"
"8b. NEVER use double quotes (\") around column or table names — use square brackets [Name] instead.\n"
"Use TOP N instead of LIMIT N.\n"
"Use GETDATE() instead of NOW() or CURRENT_DATE.\n"
"Use CONVERT or CAST for type conversions.\n"
"Always use [dbo].[TableName] format for table references.\n"
"Always use alias.[ColumnName] format, never alias.\"ColumnName\"."
)
self._chain_mssql = ChatPromptTemplate.from_messages([
("system", _mssql_system),
("human", _GENERATOR_HUMAN),
]) | self._llm
def generate(
self,
schema: str,
question: str,
prev_sql: Optional[str] = None,
error: Optional[str] = None,
lint_warnings: list = [],
dialect: str = "duckdb",
) -> str:
error_context = ""
if prev_sql and error:
error_context = _ERROR_CONTEXT_TMPL.format(
prev_sql = prev_sql,
error = error,
lint_warnings = "\n".join(lint_warnings) if lint_warnings else "none",
)
chain = self._chain_mssql if dialect == "mssql" else self._chain_duckdb
response = chain.invoke({
"schema": schema,
"question": question,
"error_context": error_context,
})
sql = response.content.strip()
# Safety net: fix double-quote identifiers for MS SQL
if dialect == "mssql":
sql = _sanitize_mssql_sql(sql)
return sql
# ─────────────────────────────────────────────────────────────────
# Agent 2 — SQL Validator
# ─────────────────────────────────────────────────────────────────
# System prompt for the semantic review. The "VERDICT:" / "REASON:" reply
# format is what SQLValidatorAgent._semantic parses.
_VALIDATOR_SYSTEM = """
You are a SQL review assistant. Decide whether the SQL query logically answers
the user's question given the table schemas.
Respond in this exact format only:
VERDICT: PASS | FAIL
REASON: <one sentence>
Do not rewrite the SQL.
"""
# Human prompt for the semantic review; placeholders: {schema}, {question}, {sql}.
_VALIDATOR_HUMAN = """
## Table Schemas
{schema}
## User Question
{question}
## SQL to Review
{sql}
"""
class SQLValidatorAgent:
    """
    Three-layer validation:
    Layer 1 — sqlfluff lint : syntax / style warnings (non-blocking)
    Layer 2 — execution : runtime errors (wrong column, type mismatch …),
    on MS SQL when a tunnelled table is involved, otherwise on DuckDB
    Layer 3 — LLM semantic : does the query logically answer the question?
    """

    def __init__(self, loader: "DataLoader"):
        self._loader = loader
        self._llm = ChatGoogleGenerativeAI(
            model=GEMINI_MODEL, google_api_key=GEMINI_API_KEY, temperature=0
        )
        self._chain = ChatPromptTemplate.from_messages([
            ("system", _VALIDATOR_SYSTEM),
            ("human", _VALIDATOR_HUMAN),
        ]) | self._llm

    @staticmethod
    def _format_violation(violation: dict) -> str:
        """Render one sqlfluff violation dict as ``[CODE] L<line>: description``.

        The line number is read from ``start_line_no`` when present and from
        ``line_no`` otherwise, so both key spellings are accepted.
        """
        line_no = violation.get("start_line_no", violation.get("line_no", "?"))
        return f"[{violation['code']}] L{line_no}: {violation['description']}"

    def _lint(self, sql: str, dialect: str = "ansi") -> list:
        """Lint *sql* and return formatted warnings; never raises.

        If linting with the requested dialect fails, one more attempt is made
        with "ansi". Failures are logged at debug level and yield ``[]``.
        """
        # dict.fromkeys de-duplicates while keeping order, so "ansi" is only
        # tried once when it is also the requested dialect.
        for candidate in dict.fromkeys((dialect, "ansi")):
            try:
                violations = sqlfluff.lint(sql, dialect=candidate)
                return [self._format_violation(v) for v in violations]
            except Exception as exc:
                logger.debug("sqlfluff lint failed (dialect=%s): %s", candidate, exc)
        return []

    def _execute(self, sql: str, tables_used: list = None):
        """Run *sql*; returns ``(ok, DataFrame or None, error text or None)``."""
        try:
            # If any table is from MS SQL tunnel, run there directly
            if tables_used and self._loader.uses_mssql_tables(tables_used):
                df = self._loader.execute_on_mssql(sql)
            else:
                df = self._loader.execute(sql)
            return True, df, None
        except Exception as exc:
            return False, None, str(exc)

    def _semantic(self, schema: str, question: str, sql: str) -> tuple:
        """Ask the LLM whether *sql* answers *question*; returns ``(ok, reason)``.

        A reply without a recognisable verdict counts as PASS.
        """
        resp = self._chain.invoke({"schema": schema, "question": question, "sql": sql})
        text = resp.content.strip()
        vm = re.search(r"VERDICT:\s*(PASS|FAIL)", text, re.IGNORECASE)
        rm = re.search(r"REASON:\s*(.+)", text, re.IGNORECASE)
        ok = (vm.group(1).upper() if vm else "PASS") == "PASS"
        reason = rm.group(1).strip() if rm else ""
        return ok, reason

    def validate(self, sql: str, schema: str, question: str, tables_used: list = None):
        """Returns (is_valid, result_df_or_None, error_or_None, lint_warnings)."""
        # Lint with the dialect the statement will actually run under.
        on_mssql = bool(tables_used) and self._loader.uses_mssql_tables(tables_used)
        lint_warns = self._lint(sql, dialect="tsql" if on_mssql else "duckdb")
        ok, result_df, exec_err = self._execute(sql, tables_used=tables_used)
        if not ok:
            return False, None, exec_err, lint_warns
        sem_ok, sem_reason = self._semantic(schema, question, sql)
        if not sem_ok:
            return False, None, f"Semantic mismatch: {sem_reason}", lint_warns
        return True, result_df, None, lint_warns
# ─────────────────────────────────────────────────────────────────
# Agent 3 — Visualization Metadata Generator
# ─────────────────────────────────────────────────────────────────
# System prompt for the visualization agent. The field labels it lists
# (CHART_TYPE, TITLE, X_AXIS, ...) are the keys VizMetaAgent.generate parses.
_VIZ_SYSTEM = """
You are a data visualization expert. Given a user's question, the SQL query
that answered it, and the resulting column names and sample data, decide the
best chart type and axis configuration.
Return ONLY this exact format — no markdown, no explanation outside the fields:
CHART_TYPE: <one of: bar, line, pie, scatter, area, heatmap, table>
TITLE: <concise chart title, max 10 words>
X_AXIS: <exact column name or alias from Result Columns to use as primary/category axis, or NONE>
Y_AXIS: <exact column name or alias from Result Columns to use as primary value axis, or NONE>
SECONDARY_AXIS: <exact column name or alias from Result Columns for secondary value axis, or NONE>
SERIES_LABEL: <exact column name or alias from Result Columns to distinguish series/legend, or NONE>
REASONING: <2-3 sentences: why this chart type, why these axes>
Chart type selection guide:
- bar : comparisons across categories, rankings, counts
- line : trends over time, continuous progression
- pie : part-to-whole when ≤ 7 slices
- scatter : correlation between two numeric variables
- area : cumulative totals or stacked proportions over time
- heatmap : two categorical dimensions with a numeric intensity
- table : raw detail, many columns, no clear chart mapping
"""
# Human prompt for the visualization agent; placeholders: {question}, {sql},
# {columns}, {sample_data}.
_VIZ_HUMAN = """
## User Question
{question}
## SQL Used
{sql}
## Result Columns
{columns}
## Sample Data (first 5 rows)
{sample_data}
Return the visualization metadata now:
"""
class VizMetaAgent:
    """Agent 3: asks the LLM for a chart recommendation for a query result."""

    def __init__(self):
        self._llm = ChatGoogleGenerativeAI(
            model=GEMINI_MODEL, google_api_key=GEMINI_API_KEY, temperature=0
        )
        self._chain = ChatPromptTemplate.from_messages([
            ("system", _VIZ_SYSTEM),
            ("human", _VIZ_HUMAN),
        ]) | self._llm

    @staticmethod
    def _parse_field(text: str, key: str):
        """Return the value of the ``KEY: value`` line in *text*, or None.

        The key must start its own line, so "Y_AXIS" is never matched inside
        "SECONDARY_AXIS", and the value is confined to that line, so an empty
        value cannot swallow the following field. Missing fields, empty
        values and the literal "NONE" (any case) all give None.
        """
        m = re.search(
            rf"^[ \t]*{re.escape(key)}:[ \t]*(.*)$",
            text,
            re.IGNORECASE | re.MULTILINE,
        )
        value = m.group(1).strip() if m else ""
        if not value or value.upper() == "NONE":
            return None
        return value

    def generate(
        self,
        question: str,
        sql: str,
        columns: list,
        sample_data: list,
    ) -> "VizMeta":
        """Build a VizMeta for the result; never raises.

        Only the first five sample rows are sent to the model. Any failure is
        logged and a plain "table" recommendation titled with the first 60
        characters of the question is returned instead.
        """
        try:
            response = self._chain.invoke({
                "question": question,
                "sql": sql,
                "columns": ", ".join(columns),
                "sample_data": str(sample_data[:5]),
            })
            text = response.content.strip()
            return VizMeta(
                chart_type=(self._parse_field(text, "CHART_TYPE") or "table").lower(),
                title=self._parse_field(text, "TITLE") or question[:60],
                x_axis=self._parse_field(text, "X_AXIS"),
                y_axis=self._parse_field(text, "Y_AXIS"),
                secondary_axis=self._parse_field(text, "SECONDARY_AXIS"),
                series_label=self._parse_field(text, "SERIES_LABEL"),
                reasoning=self._parse_field(text, "REASONING") or "",
            )
        except Exception as exc:
            logger.warning("VizMetaAgent failed (%s) — defaulting to table", exc)
            return VizMeta(chart_type="table", title=question[:60])
# ─────────────────────────────────────────────────────────────────
# NL2SQLPipeline — the public interface imported by app.py
# ─────────────────────────────────────────────────────────────────
class NL2SQLPipeline:
"""
Facade that wires DataLoader and the agents (table selector, SQL generator, validator, visualization metadata) together.
app.py creates one instance at startup and calls its methods per request.
Public methods
──────────────
load_file(path) → load one CSV/Excel at runtime
load_mssql_tunnelled_tables(names) → load SQL Server tables over the SSH tunnel (None = all)
tables_info() → list of table metadata dicts
generate(question, ...) → GenerateResult (SQL only, no execution)
execute_sql(sql, label) → ExecuteResult (run raw SQL, save CSV)
query(question, ...) → QueryResult (full pipeline)
"""
def __init__(self):
    """Create the shared DataLoader and one instance of each agent."""
    loader = DataLoader()
    self.loader = loader
    self._selector = TableSelectorAgent()
    # The validator executes SQL through the same loader.
    self._validator = SQLValidatorAgent(loader)
    self._viz_agent = VizMetaAgent()
    self._gen_chain = _SQLGeneratorChain()
# ── data loading ──────────────────────────
def load_file(self, file_path: str) -> str:
    """Load one CSV/Excel file through the loader; returns the table name it assigned."""
    table_name = self.loader.load_file(file_path)
    return table_name
def load_mssql_tunnelled_tables(self, table_names: Optional[list] = None) -> list:
return self.loader.load_mssql_tunnelled_tables(table_names)
# ── introspection ─────────────────────────
def tables_info(self) -> list:
return [self.loader.table_info_dict(n) for n in self.loader.table_names()]
def has_tables(self) -> bool:
return bool(self.loader.tables)
# ── generate (SQL only, no execution) ─────
def generate(
self,
question: str,
pinned_tables: Optional[list] = None,
) -> GenerateResult:
"""
Runs Agent 0 (table selection) + Agent 1 (SQL generation).
Returns the generated SQL with lint warnings — does NOT execute it.
"""
pinned_tables = self._resolve_pinned(pinned_tables)
try:
qdrant_summary, qdrant_candidates = build_compact_schema_from_qdrant(question)
if not qdrant_candidates:
qdrant_summary = self.loader.all_tables_summary_compact()
qdrant_candidates = self.loader.table_names()
except Exception:
qdrant_summary = self.loader.all_tables_summary_compact()
qdrant_candidates = self.loader.table_names()
try:
relevant, _ = self._selector.select(
all_schemas = qdrant_summary,
question = question,
available_tables = qdrant_candidates,
pinned_tables = pinned_tables,
)
except Exception as exc:
return GenerateResult(success=False, sql="", error=f"Table selection failed: {exc}")
if len(relevant) > 5:
relevant = relevant[:5]
schema = self.loader.schema_for_tables(relevant)
dialect = "mssql" if self.loader.uses_mssql_tables(relevant) else "duckdb"
try:
sql = self._gen_chain.generate(schema=schema, question=question, dialect=dialect)
except Exception as exc:
return GenerateResult(success=False, sql="", error=str(exc))
# Lint only — no execution here
lint_warns = self._validator._lint(sql)
return GenerateResult(
success = True,
sql = sql,
tables_used = relevant,
lint_warnings = lint_warns,
)
# ── execute (raw SQL, no LLM) ─────────────
def execute_sql(self, sql: str, label: str = "manual_query", tables_used: list = None) -> ExecuteResult:
sql = sql.rstrip(";").strip()
try:
if tables_used and self.loader.uses_mssql_tables(tables_used):
df = self.loader.execute_on_mssql(sql)
else:
df = self.loader.execute(sql)
except Exception as exc:
return ExecuteResult(success=False, sql=sql, error=str(exc))
rows, cols, data = _df_to_response(df)
csv_path = None
if not df.empty:
csv_path = str(_output_csv_path(label))
df.to_csv(csv_path, index=False)
return ExecuteResult(
success = True,
sql = sql,
rows = rows,
columns = cols,
data = data,
output_csv = csv_path,
)
# ── full pipeline (generate + validate + execute) ──
def query(self, question, pinned_tables=None, max_retries=MAX_RETRIES) -> QueryResult:
pinned_tables = self._resolve_pinned(pinned_tables)
if not pinned_tables: # don't cache pinned queries — they may differ by intent
cached = get_cached_result(question)
if cached:
return QueryResult(
success = cached.get("success", True),
final_sql = cached.get("sql", ""),
rows = cached.get("rows", 0),
columns = cached.get("columns", []),
data = cached.get("data", []),
tables_used = cached.get("tables_used", []),
attempts = cached.get("attempts", 0),
output_csv = cached.get("output_csv"),
reasoning = cached.get("reasoning", ""),
)
try:
if pinned_tables:
# If user pinned tables, use those directly — skip Qdrant
qdrant_summary = self.loader.all_tables_summary_compact()
qdrant_candidates = self.loader.table_names()
else:
qdrant_summary, qdrant_candidates = build_compact_schema_from_qdrant(question)
# Fallback: if Qdrant returns nothing, use full list
if not qdrant_candidates:
logger.warning("Qdrant returned 0 candidates — falling back to full table list")
qdrant_summary = self.loader.all_tables_summary_compact()
qdrant_candidates = self.loader.table_names()
except Exception as exc:
logger.warning("Qdrant retrieval failed (%s) — falling back to full table list", exc)
qdrant_summary = self.loader.all_tables_summary_compact()
qdrant_candidates = self.loader.table_names()
# Agent 0 — now returns (tables, reasoning)
try:
relevant, selector_reasoning = self._selector.select(
all_schemas = qdrant_summary,
question = question,
available_tables = qdrant_candidates,
pinned_tables = pinned_tables,
)
except Exception as exc:
return QueryResult(
success=False, final_sql="",
error=f"Table selection failed: {exc}",
reasoning=str(exc),
)
if len(relevant) > 5:
logger.warning("Selector returned %d tables — capping to 5", len(relevant))
relevant = relevant[:5]
# schema = self.loader.schema_for_tables(relevant)
schema = self.loader.schema_for_tables(relevant)
prev_sql = None
error = None
lint_warns = []
history = []
dialect = "mssql" if self.loader.uses_mssql_tables(relevant) else "duckdb"
for attempt in range(1, max_retries + 2):
try:
sql = self._gen_chain.generate(
schema=schema, question=question,
prev_sql=prev_sql, error=error, lint_warnings=lint_warns,dialect=dialect,
)
except Exception as exc:
return QueryResult(
success=False, final_sql="",
tables_used=relevant, attempts=attempt,
error=f"Generation failed: {exc}",
reasoning=selector_reasoning,
)
is_valid, result_df, val_error, lint_warns = self._validator.validate(
sql=sql, schema=schema, question=question, tables_used=relevant
)
history.append({
"attempt": attempt, "tables_in_scope": relevant,
"sql": sql, "valid": is_valid,
"error": val_error, "lint_warnings": lint_warns,
})
if is_valid:
rows, cols, data = _df_to_response(result_df)
csv_path = None
if not result_df.empty:
csv_path = str(_output_csv_path(question))
result_df.to_csv(csv_path, index=False)
# ── Agent 3 — Viz metadata ──
viz = self._viz_agent.generate(
question = question,
sql = sql,
columns = cols,
sample_data = data,
)
# Build column_roles map from viz metadata
column_roles = {}
if viz.x_axis: column_roles[viz.x_axis] = "x_axis"
if viz.y_axis: column_roles[viz.y_axis] = "y_axis"
if viz.secondary_axis: column_roles[viz.secondary_axis] = "secondary_axis"
if viz.series_label: column_roles[viz.series_label] = "series_label"
if not pinned_tables:
set_cached_result(question, {
"success": True,
"sql": sql,
"rows": rows,
"columns": cols,
"data": data,
"tables_used": relevant,
"attempts": attempt,
"output_csv": csv_path,
"reasoning": selector_reasoning,
})
return QueryResult(
success = True,
final_sql = sql,
rows = rows,
columns = cols,
data = data,
tables_used = relevant,
attempts = attempt,
output_csv = csv_path,
history = history,
reasoning = selector_reasoning,
viz_meta = viz,
column_roles = column_roles,
)
prev_sql = sql
error = val_error
if attempt > max_retries:
break
return QueryResult(
success=False, final_sql=prev_sql or "",
tables_used=relevant, attempts=max_retries + 1,
history=history, error=error,
reasoning=selector_reasoning,
)
# ── private helpers ───────────────────────
def _resolve_pinned(self, pinned_tables: Optional[list]) -> Optional[list]:
"""Validate pinned table names against what's actually loaded."""
if not pinned_tables:
return None
valid, unknown = self.loader.validate_table_names(pinned_tables)
if unknown:
logger.warning("Unknown pinned table(s) ignored: %s", unknown)
return valid or None
Editor is loading...