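# Source path: d:\chatbot\nl2sql03.py
# Paste-site listing residue (kept for provenance, not part of the program):
#   avatar / unknown / python / 2 months ago / 54 kB / 6 / Indexable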
import os
import re
import logging
import duckdb
import pandas as pd
import sqlfluff
from pathlib import Path
from datetime import datetime
from typing import Optional
from dataclasses import dataclass, field

import pyodbc
import paramiko
from sshtunnel import SSHTunnelForwarder

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate

logger = logging.getLogger(__name__)

from dotenv import load_dotenv
load_dotenv()

from qdrant_index import (
    ensure_collections,
    index_table,
    index_nl_sql_pair,
    build_compact_schema_from_qdrant,
    retrieve_few_shot_examples,
)
from redis_cache import get_cached_result, set_cached_result, invalidate_all
 


# ─────────────────────────────────────────────────────────────────
# Config — all values read from environment variables
# ─────────────────────────────────────────────────────────────────

# Each setting is read from the process environment; the second argument of
# os.environ.get() is the fallback used when the variable is not set.
#
# SECURITY NOTE(review): the fallbacks for GEMINI_API_KEY and MSSQL_DB_PASS are
# real-looking credentials committed to source. They should be rotated and the
# defaults removed once every deployment sets the variables explicitly.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "AIzaSyAi-al-Cs4oL-8M_YQqSU8YVqN4lv7itqo")
GEMINI_MODEL   = os.environ.get("GEMINI_MODEL", "gemini-2.5-flash")
# Environment values are strings, so the retry limit is converted to int here.
MAX_RETRIES    = int(os.environ.get("MAX_RETRIES", 5))
# Directories are wrapped in Path; relative values resolve against the CWD.
OUTPUT_DIR     = Path(os.environ.get("OUTPUT_DIR", "output"))
UPLOAD_DIR     = Path(os.environ.get("UPLOAD_DIR", "uploads"))

# ── MS SQL Server via SSH Tunnel ──
# SSH_* describe the jump host; MSSQL_DB_* describe the database server that is
# reached through it (see TunnelledMSSQL).
SSH_HOST      = os.environ.get("SSH_HOST",      "3.108.55.154")
SSH_USER      = os.environ.get("SSH_USER",      "ubuntu")
SSH_PKEY_PATH = os.environ.get("SSH_PKEY_PATH", "/home/shared/pcsap/pcsap-DB.pem")
MSSQL_DB_HOST = os.environ.get("MSSQL_DB_HOST", "16.0.2.77")
MSSQL_DB_USER = os.environ.get("MSSQL_DB_USER", "pcsap@readonly")
MSSQL_DB_PASS = os.environ.get("MSSQL_DB_PASS", "pcsap@123")
MSSQL_DB_NAME = os.environ.get("MSSQL_DB_NAME", "FCI")



class TunnelledMSSQL:
    """
    Manages a persistent SSH tunnel to the MS SQL Server.

    Obtain the shared object with get_instance() and call get_connection()
    each time a pyodbc connection is needed. The tunnel is opened on first
    use, reused while it is active, and re-created when it has gone down.
    """
    _instance = None   # shared singleton, created lazily by get_instance()

    def __init__(self):
        # SSHTunnelForwarder once _ensure_tunnel() has run, otherwise None.
        self._tunnel = None

    @classmethod
    def get_instance(cls):
        """Return the shared instance, creating it on first call."""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def _ensure_tunnel(self):
        """Start the SSH tunnel unless an active one already exists."""
        if self._tunnel is not None:
            if self._tunnel.is_active:
                return
            # The previous forwarder is no longer active. Stop it before
            # replacing it so its threads and sockets are not leaked.
            stale, self._tunnel = self._tunnel, None
            try:
                stale.stop()
            except Exception as exc:   # best effort: the tunnel is dead anyway
                logger.warning("Ignoring error while closing stale SSH tunnel: %s", exc)

        pkey = paramiko.RSAKey.from_private_key_file(SSH_PKEY_PATH)
        tunnel = SSHTunnelForwarder(
            (SSH_HOST, 22),
            ssh_username  = SSH_USER,
            ssh_pkey      = pkey,
            remote_bind_address = (MSSQL_DB_HOST, 1433),
            set_keepalive = 60,
        )
        tunnel.start()
        # Only keep the forwarder once it has actually started.
        self._tunnel = tunnel
        logger.info("SSH tunnel established on localhost:%d", self._tunnel.local_bind_port)

    def get_connection(self):
        """Return a new pyodbc connection routed through the local tunnel port."""
        self._ensure_tunnel()
        conn_str = (
            "DRIVER={ODBC Driver 18 for SQL Server};"
            f"SERVER=127.0.0.1,{self._tunnel.local_bind_port};"
            f"DATABASE={MSSQL_DB_NAME};"
            f"UID={MSSQL_DB_USER};"
            f"PWD={MSSQL_DB_PASS};"
            "Encrypt=yes;"
            "TrustServerCertificate=yes;"
            "Connection Timeout=30;"
        )
        return pyodbc.connect(conn_str)

    def stop(self):
        """Close the tunnel if one exists; a later get_connection() reopens it."""
        # Drop the reference first so a failed stop() cannot leave a dead
        # forwarder behind for the next _ensure_tunnel() call.
        tunnel, self._tunnel = self._tunnel, None
        if tunnel is not None:
            tunnel.stop()
            logger.info("SSH tunnel closed.")
# ─────────────────────────────────────────────────────────────────
# Data-classes
# ─────────────────────────────────────────────────────────────────

@dataclass
class ColumnMeta:
    """Metadata describing one column of a loaded table."""
    name: str
    dtype: str            # pandas dtype string, e.g. "int64"
    sql_type: str         # DuckDB SQL type, e.g. "BIGINT"
    nullable: bool
    sample_values: list


@dataclass
class TableMeta:
    """Metadata describing one table known to the loader."""
    table_name: str
    source: str           # "csv" | "excel" | "mssql:<original>"
    db_name: str          # filename or SQL Server DB name
    row_count: int
    columns: list         # list[ColumnMeta]
    schema_info: str      # rich text block sent to LLM


@dataclass
class GenerateResult:
    """Outcome of the /generate step: SQL is produced but never executed."""
    success: bool
    sql: str
    tables_used: list = field(default_factory=list)
    lint_warnings: list = field(default_factory=list)
    error: Optional[str] = None


@dataclass
class ExecuteResult:
    """Outcome of the /execute step: raw SQL is run, no LLM is involved."""
    success: bool
    sql: str
    rows: int = 0
    columns: list = field(default_factory=list)
    data: list = field(default_factory=list)      # list of dicts
    output_csv: Optional[str] = None
    error: Optional[str] = None


@dataclass
class QueryResult:
    """Result of the full /query pipeline (generate + validate + execute)."""
    success:       bool
    final_sql:     str
    rows:          int           = 0
    columns:       list          = field(default_factory=list)
    data:          list          = field(default_factory=list)
    tables_used:   list          = field(default_factory=list)
    attempts:      int           = 0
    output_csv:    Optional[str] = None
    history:       list          = field(default_factory=list)
    error:         Optional[str] = None
    reasoning:     str           = ""   # Agent 0+1 reasoning narrative
    # Forward reference: VizMeta is declared further down in this module.
    # (The previous annotation used the builtin function `any`, which is
    # not a type.)
    viz_meta:      Optional["VizMeta"] = None
    column_roles:  dict          = field(default_factory=dict)


@dataclass
class VizMeta:
    """Chart recommendation attached to a query result."""
    chart_type: str                        # "bar" | "line" | "pie" | "scatter" | "table" | ...
    title: str                             # e.g. "Monthly Revenue by Category"
    x_axis: Optional[str] = None           # primary axis label
    y_axis: Optional[str] = None           # primary value axis label
    secondary_axis: Optional[str] = None   # only if dual-axis chart
    series_label: Optional[str] = None     # legend / series name
    reasoning: str = ""                    # why this chart type was chosen



# ─────────────────────────────────────────────────────────────────
# Internal helpers
# ─────────────────────────────────────────────────────────────────

_DTYPE_TO_SQL: dict = {
    "int64":           "BIGINT",
    "int32":           "INT",
    "int16":           "SMALLINT",
    "int8":            "TINYINT",
    "float64":         "FLOAT",
    "float32":         "REAL",
    "bool":            "BOOLEAN",
    "object":          "VARCHAR",
    "string":          "VARCHAR",
    "datetime64[ns]":  "TIMESTAMP",
    "date":            "DATE",
    "timedelta64[ns]": "INTERVAL",
}


def _map_sql_type(dtype) -> str:
    key = str(dtype)
    if key in _DTYPE_TO_SQL:
        return _DTYPE_TO_SQL[key]
    if key.startswith("datetime"):
        return "TIMESTAMP"
    if key.startswith("int"):
        return "BIGINT"
    if key.startswith("float"):
        return "FLOAT"
    return "VARCHAR"


def _build_schema_info(table_name: str, source: str, db_name: str, df: pd.DataFrame, keys: dict = None, include_samples: bool = True):
    """
    Describe *df* as a list of ColumnMeta plus a text block for the LLM prompt.

    keys may carry "primary_keys", "foreign_keys" and "unique_keys"; when given,
    columns are tagged accordingly and PK/FK summary lines are added to the
    header. With include_samples=False no sample values are collected or shown.

    Returns (columns, schema_info).
    """
    key_info = keys if keys else {}
    pk_cols  = key_info.get("primary_keys", [])
    fk_defs  = key_info.get("foreign_keys", [])
    uq_cols  = key_info.get("unique_keys",  [])

    # column name → (referenced table, referenced column)
    fk_targets = {}
    for fk in fk_defs:
        fk_targets[fk["column"]] = (fk["references_table"], fk["references_column"])

    columns = []
    col_lines = []
    for col in df.columns:
        series   = df[col]
        sql_type = _map_sql_type(series.dtype)
        nullable = bool(series.isna().any())
        if include_samples:
            sample_vals = series.dropna().head(2).tolist()
        else:
            sample_vals = []

        # Key markers shown after the nullability flag.
        markers = []
        if col in pk_cols:
            markers.append("PK")
        if col in fk_targets:
            ref_table, ref_column = fk_targets[col]
            markers.append(f"FK → {ref_table}.{ref_column}")
        if col in uq_cols:
            markers.append("UNIQUE")
        tag_str = f"  [{', '.join(markers)}]" if markers else ""

        columns.append(ColumnMeta(
            name=col,
            dtype=str(series.dtype),
            sql_type=sql_type,
            nullable=nullable,
            sample_values=sample_vals,
        ))

        null_tag = "NULLABLE" if nullable else "NOT NULL"
        line = f'    "{col}"  {sql_type}  [{null_tag}]{tag_str}'
        if include_samples and sample_vals:
            line += f'  -- e.g. {sample_vals}'
        col_lines.append(line)

    # Optional header sections for keys.
    pk_block = ""
    if pk_cols:
        pk_block = f"│  Primary Key  : {', '.join(pk_cols)}\n"

    fk_lines = [
        f"│    {fk['column']} → {fk['references_table']}.{fk['references_column']}"
        for fk in fk_defs
    ]
    fk_block = ""
    if fk_lines:
        fk_block = "│  Foreign Keys:\n" + "\n".join(fk_lines) + "\n"

    pieces = [
        f"┌─ Table        : {table_name}\n",
        f"│  Source       : {source}\n",
        f"│  Database/File: {db_name}\n",
        f"│  Row count    : {len(df):,}\n",
        pk_block,
        fk_block,
        f"│  Columns ({len(columns)}):\n",
        "\n".join(col_lines),
        "\n└─",
    ]
    schema_info = "".join(pieces)
    return columns, schema_info


def _safe_name(raw: str) -> str:
    return re.sub(r"\W+", "_", raw).lower().strip("_")


def _unique_name(name: str, existing: set) -> str:
    candidate = name
    counter   = 1
    while candidate in existing:
        candidate = f"{name}_{counter}"
        counter  += 1
    return candidate


def _ensure_dir(path: Path) -> Path:
    path.mkdir(parents=True, exist_ok=True)
    return path


def _output_csv_path(label: str) -> Path:
    """
    Build a timestamped CSV path under OUTPUT_DIR, creating the directory.

    The file stem is the first 40 characters of *label* reduced to word
    characters, followed by "__" and the current local time.
    """
    stem  = re.sub(r"\W+", "_", label[:40]).strip("_").lower()
    stamp = f"{datetime.now():%Y%m%d_%H%M%S}"
    out_dir = _ensure_dir(OUTPUT_DIR)
    return out_dir / f"{stem}__{stamp}.csv"


def _df_to_response(df: pd.DataFrame) -> tuple:
    """Convert a DataFrame to (rows, columns, data) ready for JSON serialisation."""
    # Replace NaN/NaT with None so json.dumps doesn't choke
    clean = df.where(pd.notnull(df), other=None)
    return len(clean), list(clean.columns), clean.to_dict(orient="records")

def _sanitize_mssql_sql(sql: str) -> str:
    """
    Convert DuckDB-style double-quoted identifiers to MS SQL bracket style.
    e.g.  T1."QuestionName"  →  T1.[QuestionName]
          "Id"               →  [Id]
    Also strips trailing semicolons.
    """
    # Replace "identifier" with [identifier] — but NOT string literals
    # This regex matches double-quoted words that are used as identifiers
    # (i.e., not preceded by = or LIKE or IN — those are string values)
    sanitized = re.sub(
        r'(?<![=<>!,\(\s])(\."?)([A-Za-z_][A-Za-z0-9_ ]*)"',
        lambda m: f'.[{m.group(2).lstrip(chr(34))}]',
        sql
    )
    # Also handle standalone "ColumnName" not preceded by a dot
    sanitized = re.sub(
        r'(?<!\.)(?<![=<>!(,\s])"([A-Za-z_][A-Za-z0-9_ ]*)"',
        r'[\1]',
        sanitized
    )
    return sanitized.rstrip(";").strip()

# ─────────────────────────────────────────────────────────────────
# DataLoader
# ─────────────────────────────────────────────────────────────────

class DataLoader:
    """
    Single in-memory DuckDB instance that holds all loaded tables.
    Sources: CSV / Excel files   +   SQL Server tables.
    Cross-table JOINs work freely because everything shares one connection.
    """

    def __init__(self):
        """Open the in-memory DuckDB database and start with empty registries."""
        # One shared connection, so every registered table can be joined.
        self.conn = duckdb.connect(database=":memory:")
        # internal table name → TableMeta
        self.tables: dict[str, TableMeta] = {}
        # internal table name → original MS SQL table name
        self._mssql_table_map: dict[str, str] = {}

    # ── registration ──────────────────────────

    def _register_df(self, df: pd.DataFrame, table_name: str, source: str, db_name: str):
        """
        Register *df* in DuckDB under *table_name* and record its metadata.

        The schema text is also sent to the Qdrant index; an indexing failure
        is logged and otherwise ignored.
        """
        n_rows = len(df)
        self.conn.register(table_name, df)

        columns, schema_info = _build_schema_info(table_name, source, db_name, df)
        meta = TableMeta(
            table_name=table_name,
            source=source,
            db_name=db_name,
            row_count=n_rows,
            columns=columns,
            schema_info=schema_info,
        )
        self.tables[table_name] = meta

        # Indexing is optional: the table stays usable without it.
        try:
            index_table(
                internal_name=table_name,
                display_name=table_name,
                schema_info=schema_info,
                row_count=n_rows,
            )
        except Exception as exc:
            logger.warning("Qdrant index failed for '%s': %s — continuing without it", table_name, exc)

        logger.info(
            "Loaded table '%s' — %d rows × %d cols [%s]",
            table_name, n_rows, len(df.columns), source,
        )

    # ── CSV / Excel ───────────────────────────

    def load_file(self, file_path: str) -> str:
        """Load one CSV or Excel file. Returns the DuckDB table name assigned."""
        path = Path(file_path)
        ext  = path.suffix.lower()
        if ext == ".csv":
            df, source = pd.read_csv(path), "csv"
        elif ext in (".xlsx", ".xls"):
            df, source = pd.read_excel(path), "excel"
        else:
            raise ValueError(f"Unsupported file type: {ext}. Use .csv / .xlsx / .xls")

        table_name = _unique_name(_safe_name(path.stem), set(self.tables))
        self._register_df(df, table_name, source, db_name=path.name)
        return table_name

    def load_files(self, file_paths: list) -> list:
        return [self.load_file(fp) for fp in file_paths]

    # ── SQL Server ────────────────────────────





    # ── MS SQL Server via SSH Tunnel ──────────────

    def _list_mssql_tunnelled_tables(self, conn) -> list:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES "
            "WHERE TABLE_TYPE = 'BASE TABLE' "
            "AND TABLE_CATALOG = 'FCI' "
            "AND TABLE_SCHEMA  = 'dbo' "
            "ORDER BY TABLE_NAME"
        )
        return [row[0] for row in cursor.fetchall()]
    
    def _build_mssql_schema_info(self, conn, original_name: str) -> tuple:
        """
        Fetch schema details and sample rows for one dbo table — NO full data load.

        Returns (col_rows, row_count, sample_df):
          col_rows   — (COLUMN_NAME, DATA_TYPE, IS_NULLABLE) rows in column order
          row_count  — result of COUNT(*) on the table
          sample_df  — DataFrame holding at most 3 rows of the table

        Nothing is registered in DuckDB here; that is left to the caller.
        """
        # Table names cannot be bound as parameters, so the name is embedded
        # as a bracket-delimited identifier with any "]" doubled.
        quoted_table = "[dbo].[" + original_name.replace("]", "]]") + "]"

        cursor = conn.cursor()
        try:
            # DB_NAME() is the database this connection is attached to.
            cursor.execute(
                """
                SELECT COLUMN_NAME, DATA_TYPE, IS_NULLABLE
                FROM INFORMATION_SCHEMA.COLUMNS
                WHERE TABLE_CATALOG = DB_NAME()
                AND TABLE_SCHEMA  = 'dbo'
                AND TABLE_NAME    = ?
                ORDER BY ORDINAL_POSITION
                """,
                original_name,
            )
            col_rows = cursor.fetchall()

            cursor.execute(f"SELECT COUNT(*) FROM {quoted_table}")
            row_count = cursor.fetchone()[0]
        finally:
            cursor.close()

        # A few sample rows give the schema builder column names and dtypes.
        sample_df = pd.read_sql(f"SELECT TOP 3 * FROM {quoted_table}", conn)

        return col_rows, row_count, sample_df

    def load_mssql_tunnelled_table(self, original_name: str, conn) -> str:
        """
        Register the schema of one MS SQL table without copying its data.

        An empty frame with the table's columns is registered in DuckDB, the
        metadata is stored under a unique internal name, and the schema text
        is sent to the Qdrant index (failures there are only logged).
        Returns the internal name.
        """
        # NOTE(review): the column metadata rows are fetched but not used
        # below; types and nullability come from the sample frame instead.
        _col_rows, row_count, sample_df = self._build_mssql_schema_info(conn, original_name)

        # Key lookup is optional: fall back to "no keys" on any failure.
        try:
            keys = self._fetch_mssql_keys(conn, original_name)
        except Exception as exc:
            logger.warning("Could not fetch keys for '%s': %s", original_name, exc)
            keys = {"primary_keys": [], "foreign_keys": [], "unique_keys": []}

        display    = f"[dbo].[{original_name}]"        # name shown to the LLM
        source_tag = f"mssql_tunnel:{original_name}"

        name = _unique_name(_safe_name(original_name), set(self.tables))
        # Zero-row slice: only the column layout is registered locally.
        self.conn.register(name, sample_df.iloc[0:0])

        columns, schema_info = _build_schema_info(
            display,
            source_tag,
            MSSQL_DB_NAME,
            sample_df,
            keys=keys,
            include_samples=False,
        )

        # The builder counted the sample rows; substitute the real row count.
        schema_info = re.sub(
            r"│  Row count    : [\d,]+",
            f"│  Row count    : {row_count:,}",
            schema_info,
        )

        self.tables[name] = TableMeta(
            table_name=name,
            source=source_tag,
            db_name=MSSQL_DB_NAME,
            row_count=row_count,
            columns=columns,
            schema_info=schema_info,
        )
        self._mssql_table_map[name] = original_name

        logger.info(
            "Registered MS SQL schema '%s' — %d rows, %d FKs",
            name, row_count, len(keys["foreign_keys"]),
        )

        try:
            index_table(
                internal_name=name,
                display_name=display,
                schema_info=schema_info,
                row_count=row_count,
            )
        except Exception as exc:
            logger.warning("Qdrant index failed for '%s': %s — continuing without it", name, exc)

        return name

        

    def load_mssql_tunnelled_tables(self, table_names: Optional[list] = None) -> list:
        """
        Load MS SQL Server tables over SSH tunnel.

        Pass None (or an empty list) to auto-discover and load ALL tables.
        Tables that fail to load are skipped with a warning. Returns the
        internal names of the tables that were registered.
        """
        tunnelled_db = TunnelledMSSQL.get_instance()
        conn         = tunnelled_db.get_connection()

        # try/finally so the connection is released even if discovery fails.
        try:
            if not table_names:
                logger.info("Discovering all tables in MS SQL DB '%s'", MSSQL_DB_NAME)
                table_names = self._list_mssql_tunnelled_tables(conn)
                logger.info("Found %d table(s): %s", len(table_names), table_names)

            results = []
            for t in table_names:
                try:
                    results.append(self.load_mssql_tunnelled_table(t, conn))
                except Exception as exc:
                    logger.warning("Skipping table '%s': %s", t, exc)
            return results
        finally:
            conn.close()
    
    def execute_on_mssql(self, sql: str) -> pd.DataFrame:
        """
        Run *sql* on MS SQL Server through the tunnel and return the rows.

        The statement is sent unchanged; a fresh connection is opened for the
        call and closed afterwards, whether or not the query succeeds.
        """
        conn = TunnelledMSSQL.get_instance().get_connection()
        try:
            frame = pd.read_sql(sql, conn)
        finally:
            conn.close()
        return frame

    def is_mssql_table(self, name: str) -> bool:
        return name in self._mssql_table_map

    def uses_mssql_tables(self, table_names: list) -> bool:
        return any(self.is_mssql_table(n) for n in table_names)

    # ── query / introspection ─────────────────

    def execute(self, sql: str) -> pd.DataFrame:
        return self.conn.execute(sql).df()

    def schema_for_tables(self, names: list) -> str:
        parts = [self.tables[n].schema_info for n in names if n in self.tables]
        return "\n\n".join(parts) or "(no matching tables)"

    def all_schemas_summary(self) -> str:
        return self.schema_for_tables(list(self.tables.keys()))

    def table_names(self) -> list:
        return list(self.tables.keys())

    def validate_table_names(self, names: list) -> tuple:
        valid   = [n for n in names if n in self.tables]
        unknown = [n for n in names if n not in self.tables]
        return valid, unknown

    def table_info_dict(self, table_name: str) -> dict:
        """Return a JSON-serialisable dict describing one table."""
        meta = self.tables[table_name]
        return {
            "table_name": meta.table_name,
            "source":     meta.source,
            "db_name":    meta.db_name,
            "row_count":  meta.row_count,
            "columns": [
                {
                    "name":          c.name,
                    "sql_type":      c.sql_type,
                    "dtype":         c.dtype,
                    "nullable":      c.nullable,
                    "sample_values": [str(v) for v in c.sample_values],
                }
                for c in meta.columns
            ],
        }
    
    def all_tables_summary_compact(self, max_tables: int = 200) -> str:
        lines = []
        for name, meta in list(self.tables.items())[:max_tables]:
            col_names = ", ".join(c.name for c in meta.columns)

            display_name = (
                f"[dbo].[{self._mssql_table_map[name]}]"
                if name in self._mssql_table_map
                else name
            )

            key_info = ""
            for line in meta.schema_info.splitlines():
                if "Primary Key" in line:
                    key_info += "  " + line.strip().replace("│  ", "")
                if "FK →" in line:
                    key_info += "  " + line.strip().replace("│    ", "") + ";"

            lines.append(
                f"- {name} | display: {display_name} | ({meta.row_count:,} rows){key_info}\n"
                f"  columns: {col_names}"
            )
        return "\n".join(lines)
    
    def _fetch_mssql_keys(self, conn, original_name: str) -> dict:
        """
        Fetch PK, FK, and unique constraints for a table from MS SQL.
        Returns a dict with keys: primary_keys, foreign_keys, unique_keys
        """
        cursor = conn.cursor()

        # ── Primary Keys ──
        cursor.execute("""
            SELECT c.COLUMN_NAME
            FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc
            JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE c
            ON tc.CONSTRAINT_NAME = c.CONSTRAINT_NAME
            WHERE tc.TABLE_CATALOG = 'FCI'
            AND tc.TABLE_SCHEMA  = 'dbo'
            AND tc.TABLE_NAME    = ?
            AND tc.CONSTRAINT_TYPE = 'PRIMARY KEY'
        """, original_name)
        primary_keys = [row[0] for row in cursor.fetchall()]

        # ── Foreign Keys ──
        cursor.execute("""
            SELECT 
                kcu.COLUMN_NAME,
                ccu.TABLE_NAME  AS referenced_table,
                ccu.COLUMN_NAME AS referenced_column
            FROM INFORMATION_SCHEMA.REFERENTIAL_CONSTRAINTS rc
            JOIN INFORMATION_SCHEMA.KEY_COLUMN_USAGE kcu
            ON rc.CONSTRAINT_NAME = kcu.CONSTRAINT_NAME
            JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE ccu
            ON rc.UNIQUE_CONSTRAINT_NAME = ccu.CONSTRAINT_NAME
            WHERE kcu.TABLE_CATALOG = 'FCI'
            AND kcu.TABLE_SCHEMA  = 'dbo'
            AND kcu.TABLE_NAME    = ?
        """, original_name)
        foreign_keys = [
            {
                "column":             row[0],
                "references_table":   row[1],
                "references_column":  row[2],
            }
            for row in cursor.fetchall()
        ]

        # ── Unique Constraints ──
        cursor.execute("""
            SELECT c.COLUMN_NAME
            FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS tc
            JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE c
            ON tc.CONSTRAINT_NAME = c.CONSTRAINT_NAME
            WHERE tc.TABLE_CATALOG = 'FCI'
            AND tc.TABLE_SCHEMA  = 'dbo'
            AND tc.TABLE_NAME    = ?
            AND tc.CONSTRAINT_TYPE = 'UNIQUE'
        """, original_name)
        unique_keys = [row[0] for row in cursor.fetchall()]

        return {
            "primary_keys": primary_keys,
            "foreign_keys": foreign_keys,
            "unique_keys":  unique_keys,
        }


# ─────────────────────────────────────────────────────────────────
# Agent 0 — Table Selector
# ─────────────────────────────────────────────────────────────────

# System prompt for the table-selection step. The reply format it demands
# ("TABLES: ..." / "REASONING: ...") is what TableSelectorAgent.select()
# parses with regular expressions, so both must be changed together.
_TABLE_SELECTOR_SYSTEM = """
You are a data analyst assistant. Decide which database tables are needed to
answer the user's question.

You will be given a list of tables. Each entry shows:
- INTERNAL_NAME | display: [dbo].[ActualName] | (N rows) | columns: col1, col2, ...

IMPORTANT: In your response, use ONLY the INTERNAL_NAME exactly as shown 
before the | symbol. Do NOT use the display name.

Return your response in this EXACT format — nothing else:

TABLES: internal_name1,internal_name2
REASONING: <2-3 sentences explaining which tables you chose and why>
"""

# Human-turn template; {all_schemas} and {question} are filled in at invoke time.
_TABLE_SELECTOR_HUMAN = """
## All Available Table Schemas
{all_schemas}

## User Question
{question}

Return TABLES and REASONING now:
"""

class TableSelectorAgent:
    """Agent 0: asks the LLM which of the available tables a question needs."""

    def __init__(self):
        self._llm    = ChatGoogleGenerativeAI(
            model=GEMINI_MODEL, google_api_key=GEMINI_API_KEY, temperature=0
        )
        self._chain  = ChatPromptTemplate.from_messages([
            ("system", _TABLE_SELECTOR_SYSTEM),
            ("human",  _TABLE_SELECTOR_HUMAN),
        ]) | self._llm

    def select(self, all_schemas, question, available_tables, pinned_tables=None):
        """
        Choose the tables relevant to *question*.

        Returns (table_names, reasoning). No LLM call is made when
        *pinned_tables* is given or only one table is available.

        Raises:
            ValueError:   the LLM reply contained no known table name.
            RuntimeError: any other failure while calling or parsing the LLM.
        """
        if pinned_tables:
            return pinned_tables, "Tables manually pinned by user."

        if len(available_tables) == 1:
            return available_tables, f"Only one table available: '{available_tables[0]}'."

        try:
            response  = self._chain.invoke({"all_schemas": all_schemas, "question": question})
            text      = response.content.strip()

            t_match   = re.search(r"TABLES:\s*(.+)", text, re.IGNORECASE)
            r_match   = re.search(r"REASONING:\s*(.+)", text, re.IGNORECASE | re.DOTALL)

            raw_tables = t_match.group(1) if t_match else ""
            reasoning  = r_match.group(1).strip() if r_match else ""

            valid = []
            for raw in raw_tables.split(","):
                name = raw.strip()
                if name not in available_tables:
                    # Models often decorate names with backticks, quotes or
                    # brackets; retry the lookup without them.
                    name = name.strip("`'\"[] ")
                # Keep first-seen order and drop repeats.
                if name in available_tables and name not in valid:
                    valid.append(name)

            if valid:
                return valid, reasoning

            # LLM responded but returned no valid table names
            raise ValueError(
                f"Table selector returned no valid table names from response: {text!r}"
            )

        except ValueError:
            raise   # re-raise — let query() / generate() handle it

        except Exception as exc:
            raise RuntimeError(f"Table selection failed: {exc}") from exc


# ─────────────────────────────────────────────────────────────────
# Agent 1 — SQL Generator
# ─────────────────────────────────────────────────────────────────

# System prompt for SQL generation in the DuckDB dialect. _SQLGeneratorChain
# derives its T-SQL prompt from this text with exact str.replace() calls, so
# rewording rules 3, 4 or 8 (or the first sentence) silently disables those
# replacements.
_GENERATOR_SYSTEM = """
You are an expert SQL engineer. Convert the user's natural language question into
a valid DuckDB-compatible SQL query using the provided table schemas.

Rules:
1. Return ONLY the raw SQL — no markdown, no backticks, no explanation.
2. Use the EXACT table names and column names from the schema (case-sensitive).
3. Always qualify column names with their table name when multiple tables are involved:
   table_name."Column Name"
4. Use double-quotes around column names that contain spaces or special characters.
5. You MAY use JOINs across tables when the question requires it.
6. If aggregation is needed, include GROUP BY with all non-aggregate SELECT columns.
7. Match column data types — do not compare VARCHAR columns with numeric literals.
8. Do NOT end with a semicolon.
9. ALWAYS alias every aggregated or computed column with a clean, readable name using AS.
   Examples:
     COUNT(*)                        AS total_count
     SUM(orders.revenue)             AS total_revenue
     AVG(CAST(dmart.Price AS DOUBLE)) AS avg_price
     MAX(sales.amount)               AS max_amount
   Never leave an aggregation or CAST expression without an alias.
10. ALWAYS alias every CAST expression, even outside aggregations.
"""

# Human-turn template; {error_context} is empty on the first attempt and holds
# the formatted _ERROR_CONTEXT_TMPL on retries.
_GENERATOR_HUMAN = """
## Relevant Table Schemas
{schema}

## User Question
{question}

{error_context}

Generate the SQL query now:
"""

# Retry context describing the previous failed attempt.
# NOTE(review): the final CRITICAL line asks for [bracket] identifiers, which
# is T-SQL syntax, yet this template is shared by the DuckDB dialect whose
# system prompt asks for double quotes — confirm the intended behaviour.
_ERROR_CONTEXT_TMPL = """
## Previous Attempt — FAILED (please fix)
SQL tried:
{prev_sql}

Error received:
{error}

Lint warnings:
{lint_warnings}

Use the exact column names, SQL types, and table names from the schema above to fix this.
CRITICAL: Use square brackets [ColumnName] NOT double quotes "ColumnName" for all identifiers.
"""


class SQLGeneratorAgent:
    """
    Holds a Gemini chat model and the DuckDB generator prompt chain.

    NOTE(review): generate() reads self._selector, self.loader,
    self._gen_chain, self._validator and self._resolve_pinned, none of which
    is assigned in __init__ or defined in this class as visible here.
    Presumably the method belongs to an orchestrator class; as written,
    calling it on a fresh instance raises AttributeError — confirm.
    """

    def __init__(self):
        # Low, non-zero temperature for generation.
        self._llm   = ChatGoogleGenerativeAI(
            model=GEMINI_MODEL, google_api_key=GEMINI_API_KEY, temperature=0.1
        )
        # Prompt template piped into the model.
        self._chain = ChatPromptTemplate.from_messages([
            ("system", _GENERATOR_SYSTEM),
            ("human",  _GENERATOR_HUMAN),
        ]) | self._llm

    def generate(
        self,
        question:      str,
        pinned_tables: Optional[list] = None,
    ) -> GenerateResult:
        """
        Runs Agent 0 (table selection) + Agent 1 (SQL generation).
        Returns the generated SQL with lint warnings — does NOT execute it.

        Errors from SQL generation are returned as GenerateResult(success=False);
        errors raised by table selection are not caught here.
        """
        pinned_tables = self._resolve_pinned(pinned_tables)

        # Unpack tuple — discard reasoning since /generate doesn't need it
        relevant, _ = self._selector.select(
            all_schemas      = self.loader.all_schemas_summary(),
            question         = question,
            available_tables = self.loader.table_names(),
            pinned_tables    = pinned_tables,
        )
        schema = self.loader.schema_for_tables(relevant)

        try:
            sql = self._gen_chain.generate(schema=schema, question=question)
        except Exception as exc:
            return GenerateResult(success=False, sql="", error=str(exc))

        # Lint only — no execution here
        lint_warns = self._validator._lint(sql)

        return GenerateResult(
            success       = True,
            sql           = sql,
            tables_used   = relevant,
            lint_warnings = lint_warns,
        )


class _SQLGeneratorChain:
    def __init__(self):
        self._llm = ChatGoogleGenerativeAI(
            model=GEMINI_MODEL, google_api_key=GEMINI_API_KEY, temperature=0.1
        )
        self._chain_duckdb = ChatPromptTemplate.from_messages([
            ("system", _GENERATOR_SYSTEM),
            ("human",  _GENERATOR_HUMAN),
        ]) | self._llm

        # Pre-build MS SQL system prompt once
        _mssql_system = _GENERATOR_SYSTEM.replace(
            "valid DuckDB-compatible SQL query",
            "valid MS SQL Server (T-SQL) query"
        ).replace(
            "Use double-quotes around column names that contain spaces or special characters.",
            "Use square brackets around column names and table names: [ColumnName], NOT \"ColumnName\"."
        ).replace(
            "Always qualify column names with their table name when multiple tables are involved:\n   table_name.\"Column Name\"",
            "Always qualify column names with their table name when multiple tables are involved:\n   alias.[ColumnName]"
        ).replace(
            "Do NOT end with a semicolon.",
            "Do NOT end with a semicolon.\n"
            "8b. NEVER use double quotes (\") around column or table names — use square brackets [Name] instead.\n"
            "Use TOP N instead of LIMIT N.\n"
            "Use GETDATE() instead of NOW() or CURRENT_DATE.\n"
            "Use CONVERT or CAST for type conversions.\n"
            "Always use [dbo].[TableName] format for table references.\n"
            "Always use alias.[ColumnName] format, never alias.\"ColumnName\"."
        )
        self._chain_mssql = ChatPromptTemplate.from_messages([
            ("system", _mssql_system),
            ("human",  _GENERATOR_HUMAN),
        ]) | self._llm

    def generate(
        self,
        schema:        str,
        question:      str,
        prev_sql:      Optional[str] = None,
        error:         Optional[str] = None,
        lint_warnings: list          = [],
        dialect:       str           = "duckdb",
    ) -> str:
        error_context = ""
        if prev_sql and error:
            error_context = _ERROR_CONTEXT_TMPL.format(
                prev_sql      = prev_sql,
                error         = error,
                lint_warnings = "\n".join(lint_warnings) if lint_warnings else "none",
            )

        chain = self._chain_mssql if dialect == "mssql" else self._chain_duckdb

        response = chain.invoke({
            "schema":        schema,
            "question":      question,
            "error_context": error_context,
        })

        sql = response.content.strip()

        # Safety net: fix double-quote identifiers for MS SQL
        if dialect == "mssql":
            sql = _sanitize_mssql_sql(sql)

        return sql
    

# ─────────────────────────────────────────────────────────────────
# Agent 2 — SQL Validator
# ─────────────────────────────────────────────────────────────────

# System prompt for the semantic review step. The "VERDICT:" / "REASON:" reply
# format is what SQLValidatorAgent._semantic() parses, so keep them in sync.
_VALIDATOR_SYSTEM = """
You are a SQL review assistant. Decide whether the SQL query logically answers
the user's question given the table schemas.

Respond in this exact format only:
VERDICT: PASS | FAIL
REASON: <one sentence>

Do not rewrite the SQL.
"""

# Human-turn template; {schema}, {question} and {sql} are filled in at invoke time.
_VALIDATOR_HUMAN = """
## Table Schemas
{schema}

## User Question
{question}

## SQL to Review
{sql}
"""


class SQLValidatorAgent:
    """
    Three-layer validation:
      Layer 1 — sqlfluff lint   : syntax / style warnings (non-blocking)
      Layer 2 — execute         : runtime errors (wrong column, type mismatch …),
                                  on MS SQL when the tables live there, otherwise
                                  through the loader's default ``execute``
      Layer 3 — LLM semantic    : does the query logically answer the question?
    """

    def __init__(self, loader: "DataLoader"):
        self._loader = loader
        self._llm    = ChatGoogleGenerativeAI(
            model=GEMINI_MODEL, google_api_key=GEMINI_API_KEY, temperature=0
        )
        self._chain  = ChatPromptTemplate.from_messages([
            ("system", _VALIDATOR_SYSTEM),
            ("human",  _VALIDATOR_HUMAN),
        ]) | self._llm

    @staticmethod
    def _format_violation(violation: dict) -> str:
        """Render one sqlfluff violation dict as ``[CODE] L<line>: <description>``."""
        # sqlfluff 3.x reports the line as ``start_line_no``; older releases
        # used ``line_no``. Accept either so warnings are not lost on upgrade.
        line_no = violation.get("start_line_no", violation.get("line_no", "?"))
        code    = violation.get("code", "?")
        return f"[{code}] L{line_no}: {violation.get('description', '')}"

    def _lint(self, sql: str) -> list:
        """Return formatted lint warnings for ``sql``; empty list if linting fails."""
        try:
            violations = sqlfluff.lint(sql, dialect="ansi")
            return [self._format_violation(v) for v in violations]
        except Exception as exc:
            # Linting is advisory only: never block the pipeline, but leave a trace.
            logger.debug("sqlfluff lint skipped (%s)", exc)
            return []

    def _execute(self, sql: str, tables_used: Optional[list] = None):
        """
        Run ``sql`` and report the outcome.

        Returns:
            ``(True, result, None)`` on success, ``(False, None, message)`` when
            execution raised; ``message`` is ``str(exception)``.
        """
        try:
            # If any table is from MS SQL tunnel, run there directly
            if tables_used and self._loader.uses_mssql_tables(tables_used):
                df = self._loader.execute_on_mssql(sql)
            else:
                df = self._loader.execute(sql)
            return True, df, None
        except Exception as exc:
            return False, None, str(exc)

    def _semantic(self, schema: str, question: str, sql: str) -> tuple:
        """
        Ask the LLM whether ``sql`` answers ``question``.

        Returns:
            ``(ok, reason)``. A reply without a recognisable VERDICT line is
            treated as PASS; a missing REASON yields an empty string.
        """
        resp  = self._chain.invoke({"schema": schema, "question": question, "sql": sql})
        text  = resp.content.strip()
        vm    = re.search(r"VERDICT:\s*(PASS|FAIL)", text, re.IGNORECASE)
        rm    = re.search(r"REASON:\s*(.+)",         text, re.IGNORECASE)
        ok    = (vm.group(1).upper() if vm else "PASS") == "PASS"
        reason= rm.group(1).strip() if rm else ""
        return ok, reason

    def validate(self, sql: str, schema: str, question: str, tables_used: Optional[list] = None):
        """Returns (is_valid, result_df_or_None, error_or_None, lint_warnings)."""
        lint_warns              = self._lint(sql)
        ok, result_df, exec_err = self._execute(sql, tables_used=tables_used)

        if not ok:
            return False, None, exec_err, lint_warns

        # Only spend an LLM call on queries that actually ran.
        sem_ok, sem_reason = self._semantic(schema, question, sql)
        if not sem_ok:
            return False, None, f"Semantic mismatch: {sem_reason}", lint_warns

        return True, result_df, None, lint_warns




# ─────────────────────────────────────────────────────────────────
# Agent 3 — Visualization Metadata Generator
# ─────────────────────────────────────────────────────────────────

# System prompt for the visualization agent. It requests one "KEY: value"
# line per field (CHART_TYPE, TITLE, X_AXIS, Y_AXIS, SECONDARY_AXIS,
# SERIES_LABEL, REASONING); VizMetaAgent.generate reads those keys back
# from the reply, treating the literal NONE as "not set".
_VIZ_SYSTEM = """
You are a data visualization expert. Given a user's question, the SQL query
that answered it, and the resulting column names and sample data, decide the
best chart type and axis configuration.

Return ONLY this exact format — no markdown, no explanation outside the fields:

CHART_TYPE: <one of: bar, line, pie, scatter, area, heatmap, table>
TITLE: <concise chart title, max 10 words>
X_AXIS: <exact column name or alias from Result Columns to use as primary/category axis, or NONE>
Y_AXIS: <exact column name or alias from Result Columns to use as primary value axis, or NONE>
SECONDARY_AXIS: <exact column name or alias from Result Columns for secondary value axis, or NONE>
SERIES_LABEL: <exact column name or alias from Result Columns to distinguish series/legend, or NONE>
REASONING: <2-3 sentences: why this chart type, why these axes>

Chart type selection guide:
- bar        : comparisons across categories, rankings, counts
- line       : trends over time, continuous progression
- pie        : part-to-whole when ≤ 7 slices
- scatter    : correlation between two numeric variables
- area       : cumulative totals or stacked proportions over time
- heatmap    : two categorical dimensions with a numeric intensity
- table      : raw detail, many columns, no clear chart mapping
"""

# Human-turn template for the same call; placeholders: {question}, {sql},
# {columns}, {sample_data}.
_VIZ_HUMAN = """
## User Question
{question}

## SQL Used
{sql}

## Result Columns
{columns}

## Sample Data (first 5 rows)
{sample_data}

Return the visualization metadata now:
"""


class VizMetaAgent:
    """
    Agent 3 — asks the LLM for chart metadata and parses its ``KEY: value`` reply.

    Any failure (LLM error, unusable input) degrades to a plain "table" VizMeta
    instead of raising, so visualization never breaks a successful query.
    """

    def __init__(self):
        self._llm   = ChatGoogleGenerativeAI(
            model=GEMINI_MODEL, google_api_key=GEMINI_API_KEY, temperature=0
        )
        self._chain = ChatPromptTemplate.from_messages([
            ("system", _VIZ_SYSTEM),
            ("human",  _VIZ_HUMAN),
        ]) | self._llm

    @staticmethod
    def _parse_field(text: str, key: str) -> Optional[str]:
        """
        Extract the value of the first ``KEY: value`` line in ``text``.

        Returns ``None`` when the key is absent, its value is empty, or the
        value is the literal ``NONE`` (case-insensitive).
        """
        # (?<!\w) stops "Y_AXIS" from matching inside "SECONDARY_AXIS".
        # [ \t]* and (.*) stay on one line, so an empty field cannot swallow
        # the following line as its value.
        match = re.search(
            rf"(?<!\w){re.escape(key)}[ \t]*:[ \t]*(.*)", text, re.IGNORECASE
        )
        value = match.group(1).strip() if match else ""
        if not value or value.upper() == "NONE":
            return None
        return value

    def generate(
        self,
        question:    str,
        sql:         str,
        columns:     list,
        sample_data: list,
    ) -> "VizMeta":
        """
        Build visualization metadata for a query result.

        Args:
            question:    The user's question.
            sql:         The SQL that produced the result.
            columns:     Result column labels (stringified for the prompt).
            sample_data: Result rows; only the first five are sent.

        Returns:
            A VizMeta parsed from the LLM reply, or a default "table" VizMeta
            titled with the first 60 characters of ``question`` on any error.
        """
        try:
            response = self._chain.invoke({
                "question":    question,
                "sql":         sql,
                # Column labels are not guaranteed to be strings.
                "columns":     ", ".join(map(str, columns)),
                "sample_data": str(sample_data[:5]),
            })
            text = response.content.strip()

            def _field(key):
                return self._parse_field(text, key)

            return VizMeta(
                chart_type     = (_field("CHART_TYPE") or "table").lower(),
                title          = _field("TITLE") or question[:60],
                x_axis         = _field("X_AXIS"),
                y_axis         = _field("Y_AXIS"),
                secondary_axis = _field("SECONDARY_AXIS"),
                series_label   = _field("SERIES_LABEL"),
                reasoning      = _field("REASONING") or "",
            )
        except Exception as exc:
            logger.warning("VizMetaAgent failed (%s) — defaulting to table", exc)
            return VizMeta(chart_type="table", title=question[:60])

# ─────────────────────────────────────────────────────────────────
# NL2SQLPipeline  — the public interface imported by app.py
# ─────────────────────────────────────────────────────────────────

class NL2SQLPipeline:
    """
    Facade that wires DataLoader and the agents together.
    app.py creates one instance at startup and calls its methods per request.

    Public methods
    ──────────────
    load_file(path)                     → load one CSV/Excel at runtime
    load_mssql_tunnelled_tables(names)  → load SQL Server tables (None = all)
    tables_info()                       → list of table metadata dicts
    has_tables()                        → True when the loader holds any table
    generate(question, ...)             → GenerateResult  (SQL only, no execution)
    execute_sql(sql, label)             → ExecuteResult   (run raw SQL, save CSV)
    query(question, ...)                → QueryResult     (full pipeline)
    """

    # Upper bound on the number of tables whose schema is sent to the generator.
    _MAX_TABLES = 5

    # VizMeta fields stored in the result cache so cache hits keep their chart.
    _VIZ_FIELDS = (
        "chart_type", "title", "x_axis", "y_axis",
        "secondary_axis", "series_label", "reasoning",
    )

    # Role columns, in the order they are written into ``column_roles``
    # (a later role overwrites an earlier one that names the same column).
    _ROLE_FIELDS = ("x_axis", "y_axis", "secondary_axis", "series_label")

    def __init__(self):
        self.loader     = DataLoader()
        self._selector  = TableSelectorAgent()
        self._validator = SQLValidatorAgent(self.loader)
        self._viz_agent = VizMetaAgent()
        self._gen_chain = _SQLGeneratorChain()

    # ── data loading ──────────────────────────

    def load_file(self, file_path: str) -> str:
        """Delegate to ``DataLoader.load_file`` and return its result."""
        return self.loader.load_file(file_path)

    def load_mssql_tunnelled_tables(self, table_names: Optional[list] = None) -> list:
        """Delegate to ``DataLoader.load_mssql_tunnelled_tables``."""
        return self.loader.load_mssql_tunnelled_tables(table_names)

    # ── introspection ─────────────────────────

    def tables_info(self) -> list:
        """Return one metadata dict per table known to the loader."""
        return [self.loader.table_info_dict(n) for n in self.loader.table_names()]

    def has_tables(self) -> bool:
        """Return True when the loader currently holds at least one table."""
        return bool(self.loader.tables)

    # ── generate (SQL only, no execution) ─────

    def generate(
        self,
        question:      str,
        pinned_tables: Optional[list] = None,
    ) -> GenerateResult:
        """
        Runs Agent 0 (table selection) + Agent 1 (SQL generation).
        Returns the generated SQL with lint warnings — does NOT execute it.

        Candidate tables are chosen the same way as in ``query``: the full
        table list when tables are pinned, otherwise Qdrant retrieval with a
        logged fallback to the full list.
        """
        pinned_tables = self._resolve_pinned(pinned_tables)
        summary, candidates = self._candidate_tables(question, pinned_tables)

        try:
            relevant, _ = self._selector.select(
                all_schemas      = summary,
                question         = question,
                available_tables = candidates,
                pinned_tables    = pinned_tables,
            )
        except Exception as exc:
            return GenerateResult(success=False, sql="", error=f"Table selection failed: {exc}")

        relevant = self._cap_tables(relevant)
        schema   = self.loader.schema_for_tables(relevant)
        dialect  = "mssql" if self.loader.uses_mssql_tables(relevant) else "duckdb"

        try:
            sql = self._gen_chain.generate(schema=schema, question=question, dialect=dialect)
        except Exception as exc:
            return GenerateResult(success=False, sql="", error=str(exc))

        # Lint only — no execution here
        lint_warns = self._validator._lint(sql)

        return GenerateResult(
            success       = True,
            sql           = sql,
            tables_used   = relevant,
            lint_warnings = lint_warns,
        )

    # ── execute (raw SQL, no LLM) ─────────────

    def execute_sql(self, sql: str, label: str = "manual_query", tables_used: Optional[list] = None) -> ExecuteResult:
        """
        Run raw SQL and, when it returns rows, save them as a CSV named from ``label``.

        Trailing semicolons and whitespace are removed first, including the
        ``"...; "`` form where whitespace follows the semicolon.
        """
        sql = sql.strip().rstrip("; \t\r\n")
        try:
            if tables_used and self.loader.uses_mssql_tables(tables_used):
                df = self.loader.execute_on_mssql(sql)
            else:
                df = self.loader.execute(sql)
        except Exception as exc:
            return ExecuteResult(success=False, sql=sql, error=str(exc))

        rows, cols, data, csv_path = self._save_result(df, label)

        return ExecuteResult(
            success    = True,
            sql        = sql,
            rows       = rows,
            columns    = cols,
            data       = data,
            output_csv = csv_path,
        )

    # ── full pipeline (generate + validate + execute) ──

    def query(self, question, pinned_tables=None, max_retries=MAX_RETRIES) -> QueryResult:
        """
        Full pipeline: select tables, generate SQL, validate/execute, describe a chart.

        Generation is attempted once plus up to ``max_retries`` times, feeding
        the previous SQL, its error and lint warnings back to the generator.
        Un-pinned questions are served from and written to the result cache.
        Failures are reported through ``QueryResult(success=False, ...)``.
        """
        pinned_tables = self._resolve_pinned(pinned_tables)

        if not pinned_tables:   # don't cache pinned queries — they may differ by intent
            cached = get_cached_result(question)
            if cached:
                return self._result_from_cache(cached)

        summary, candidates = self._candidate_tables(question, pinned_tables)

        # Agent 0 — returns (tables, reasoning)
        try:
            relevant, selector_reasoning = self._selector.select(
                all_schemas      = summary,
                question         = question,
                available_tables = candidates,
                pinned_tables    = pinned_tables,
            )
        except Exception as exc:
            return QueryResult(
                success=False, final_sql="",
                error=f"Table selection failed: {exc}",
                reasoning=str(exc),
            )

        relevant = self._cap_tables(relevant)
        schema   = self.loader.schema_for_tables(relevant)
        dialect  = "mssql" if self.loader.uses_mssql_tables(relevant) else "duckdb"

        prev_sql   = None
        error      = None
        lint_warns = []
        history    = []

        # Attempts are numbered from 1: one initial try + ``max_retries`` retries.
        for attempt in range(1, max_retries + 2):
            try:
                sql = self._gen_chain.generate(
                    schema=schema, question=question,
                    prev_sql=prev_sql, error=error,
                    lint_warnings=lint_warns, dialect=dialect,
                )
            except Exception as exc:
                return QueryResult(
                    success=False, final_sql="",
                    tables_used=relevant, attempts=attempt,
                    error=f"Generation failed: {exc}",
                    reasoning=selector_reasoning,
                )

            # The semantic layer calls the LLM and can raise; report that the
            # same way as every other failure instead of letting it escape.
            try:
                is_valid, result_df, val_error, lint_warns = self._validator.validate(
                    sql=sql, schema=schema, question=question, tables_used=relevant
                )
            except Exception as exc:
                return QueryResult(
                    success=False, final_sql=sql,
                    tables_used=relevant, attempts=attempt,
                    history=history, error=f"Validation failed: {exc}",
                    reasoning=selector_reasoning,
                )

            history.append({
                "attempt": attempt, "tables_in_scope": relevant,
                "sql": sql, "valid": is_valid,
                "error": val_error, "lint_warnings": lint_warns,
            })

            if not is_valid:
                prev_sql = sql
                error    = val_error
                continue

            rows, cols, data, csv_path = self._save_result(result_df, question)

            # ── Agent 3 — Viz metadata ──
            viz = self._viz_agent.generate(
                question    = question,
                sql         = sql,
                columns     = cols,
                sample_data = data,
            )

            # Build column_roles map from viz metadata
            column_roles = {}
            for role in self._ROLE_FIELDS:
                column = getattr(viz, role)
                if column:
                    column_roles[column] = role

            if not pinned_tables:
                set_cached_result(question, {
                    "success":      True,
                    "sql":          sql,
                    "rows":         rows,
                    "columns":      cols,
                    "data":         data,
                    "tables_used":  relevant,
                    "attempts":     attempt,
                    "output_csv":   csv_path,
                    "reasoning":    selector_reasoning,
                    # Plain dicts so a cache hit can rebuild the chart metadata.
                    "viz_meta":     {f: getattr(viz, f, None) for f in self._VIZ_FIELDS},
                    "column_roles": column_roles,
                })

            return QueryResult(
                success      = True,
                final_sql    = sql,
                rows         = rows,
                columns      = cols,
                data         = data,
                tables_used  = relevant,
                attempts     = attempt,
                output_csv   = csv_path,
                history      = history,
                reasoning    = selector_reasoning,
                viz_meta     = viz,
                column_roles = column_roles,
            )

        return QueryResult(
            success=False, final_sql=prev_sql or "",
            tables_used=relevant, attempts=max_retries + 1,
            history=history, error=error,
            reasoning=selector_reasoning,
        )

    # ── private helpers ───────────────────────

    def _resolve_pinned(self, pinned_tables: Optional[list]) -> Optional[list]:
        """Validate pinned table names against what's actually loaded."""
        if not pinned_tables:
            return None
        valid, unknown = self.loader.validate_table_names(pinned_tables)
        if unknown:
            logger.warning("Unknown pinned table(s) ignored: %s", unknown)
        return valid or None

    def _candidate_tables(self, question: str, pinned_tables: Optional[list]) -> tuple:
        """
        Return ``(schema_summary, candidate_table_names)`` for the table selector.

        Pinned tables skip Qdrant and use the full table list. Otherwise Qdrant
        is queried, falling back to the full list when it returns no
        candidates or raises.
        """
        if not pinned_tables:
            try:
                summary, candidates = build_compact_schema_from_qdrant(question)
                if candidates:
                    return summary, candidates
                logger.warning("Qdrant returned 0 candidates — falling back to full table list")
            except Exception as exc:
                logger.warning("Qdrant retrieval failed (%s) — falling back to full table list", exc)
        return self.loader.all_tables_summary_compact(), self.loader.table_names()

    def _cap_tables(self, tables: list) -> list:
        """Truncate the selector's answer to at most ``_MAX_TABLES`` tables."""
        if len(tables) > self._MAX_TABLES:
            logger.warning(
                "Selector returned %d tables — capping to %d",
                len(tables), self._MAX_TABLES,
            )
            return tables[:self._MAX_TABLES]
        return tables

    @staticmethod
    def _save_result(df, label: str) -> tuple:
        """
        Convert a result frame for the response and persist it when non-empty.

        Returns ``(rows, columns, data, csv_path)``; ``csv_path`` is ``None``
        for an empty frame, in which case no file is written.
        """
        rows, cols, data = _df_to_response(df)
        csv_path = None
        if not df.empty:
            csv_path = str(_output_csv_path(label))
            df.to_csv(csv_path, index=False)
        return rows, cols, data, csv_path

    def _result_from_cache(self, cached: dict) -> QueryResult:
        """Rebuild a QueryResult from a cached payload written by ``query``."""
        extras = {}
        # Entries written before chart metadata was cached lack these keys;
        # leave QueryResult's own defaults in place for them.
        if cached.get("viz_meta"):
            extras["viz_meta"] = VizMeta(**cached["viz_meta"])
        if cached.get("column_roles"):
            extras["column_roles"] = cached["column_roles"]
        return QueryResult(
            success      = cached.get("success", True),
            final_sql    = cached.get("sql", ""),
            rows         = cached.get("rows", 0),
            columns      = cached.get("columns", []),
            data         = cached.get("data", []),
            tables_used  = cached.get("tables_used", []),
            attempts     = cached.get("attempts", 0),
            output_csv   = cached.get("output_csv"),
            reasoning    = cached.get("reasoning", ""),
            **extras,
        )
    
Editor is loading...