# app_main.py
# (paste-site listing metadata: author unknown · python · a month ago · 17 kB)


import os
import logging
import traceback
from pathlib import Path
from datetime import datetime

from flask import Flask, request, jsonify, send_from_directory
from werkzeug.utils import secure_filename

from dotenv import load_dotenv
load_dotenv()


from nl_to_sql_main import (
    NL2SQLPipeline,
    OUTPUT_DIR,
    UPLOAD_DIR,
    MAX_RETRIES,
    TunnelledMSSQL,
    SSH_HOST,
    SSH_PKEY_PATH,
    MSSQL_DB_USER,
    MSSQL_DB_PASS,
    MSSQL_DB_NAME,
)


# ─────────────────────────────────────────────────────────────────
# Logging
# ─────────────────────────────────────────────────────────────────

# Configure root logging once at import time: INFO and above, with a
# timestamp, padded level name, logger name and message on each line.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s  %(levelname)-8s  %(name)s — %(message)s",
    level=logging.INFO,
)

# Module-level logger used by the app factory and every route handler.
logger = logging.getLogger(__name__)


# ─────────────────────────────────────────────────────────────────
# App factory
# ─────────────────────────────────────────────────────────────────

# File extensions (lower-case, with leading dot) that POST /upload accepts.
# The upload handler lower-cases the incoming suffix before checking membership.
ALLOWED_EXTENSIONS = {".csv", ".xlsx", ".xls"}


def create_app() -> Flask:
    """
    Build and configure the NL2SQL Flask application.

    Startup steps performed here:
      1. Ensure OUTPUT_DIR and UPLOAD_DIR exist.
      2. Create the NL2SQLPipeline instance shared by all route handlers.
      3. If every MS SQL / SSH tunnel setting is present, try to pre-load
         the MS SQL tables. A failure is logged (with traceback) and the
         app still starts, so file uploads keep working.

    Returns
    ───────
    The configured Flask application.
    """
    app = Flask(__name__)

    # Ensure folders exist
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

    # ── Initialise pipeline ───────────────────────────────────────
    pipeline = NL2SQLPipeline()

    # ── MS SQL Server via SSH tunnel pre-load ─────────────────────
    # All five settings must be truthy; a partial configuration is skipped.
    mssql_tunnel_configured = all(
        (MSSQL_DB_USER, MSSQL_DB_PASS, MSSQL_DB_NAME, SSH_HOST, SSH_PKEY_PATH)
    )
    if mssql_tunnel_configured:
        logger.info("MS SQL credentials found — loading tables from '%s' via SSH tunnel", MSSQL_DB_NAME)
        try:
            loaded = pipeline.load_mssql_tunnelled_tables(table_names=None)
            logger.info("Pre-loaded %d MS SQL table(s): %s", len(loaded), loaded)
        except Exception as exc:
            # Best-effort: keep the app bootable, but record the traceback so
            # tunnel / credential problems can be diagnosed from the log.
            logger.exception("Failed to pre-load MS SQL Server tables: %s", exc)
    else:
        logger.info("MS SQL tunnel env vars not fully set — skipping MS SQL pre-load.")
        
    # ── Helpers ───────────────────────────────────────────────────

    def _allowed(filename: str) -> bool:
        """Return True when *filename*'s extension (case-insensitive) is in ALLOWED_EXTENSIONS."""
        extension = Path(filename).suffix
        return extension.lower() in ALLOWED_EXTENSIONS

    def _ok(payload: dict, status: int = 200):
        """
        Build a success response.

        Returns a ``(json_response, status)`` tuple whose body is *payload*
        with ``"status": "success"`` placed first; a ``"status"`` key in
        *payload* overrides it.
        """
        envelope = {"status": "success"}
        envelope.update(payload)
        return jsonify(envelope), status

    def _err(message: str, status: int = 400, details: str = None):
        """
        Build an error response.

        Returns a ``(json_response, status)`` tuple. The body always carries
        ``status`` and ``message``; ``details`` is included only when truthy.
        """
        extra = {"details": details} if details else {}
        return jsonify({"status": "error", "message": message, **extra}), status

    # ── POST /upload ──────────────────────────────────────────────

    @app.route("/upload", methods=["POST"])
    def upload():
        """
        Upload one or more CSV / Excel files and load each into the pipeline.

        Form-data fields
        ────────────────
        files   : one or more files (field name "files")

        Response (success, HTTP 200)
        ────────────────────────────
        {
          "status":       "success",
          "uploaded":     [                                  // one entry per loaded file
            { "filename": "orders.csv", "table_name": "orders" }
          ],
          "skipped":      ["bad.txt"],                       // unsupported extension OR failed to save/load
          "total_tables": 5                                  // total tables now loaded
        }

        Response (error, HTTP 400)
        ──────────────────────────
        Returned when the "files" field is missing or when no file could be
        loaded; in the latter case "details" lists the skipped filenames.
        """
        if "files" not in request.files:
            return _err("No files provided. Send files under the field name 'files'.")

        uploaded = []
        skipped  = []

        for f in request.files.getlist("files"):
            # secure_filename strips path components / unsafe characters and
            # may return "" — such entries are ignored entirely.
            fname = secure_filename(f.filename or "")
            if not fname:
                continue
            if not _allowed(fname):
                skipped.append(fname)
                continue

            save_path = UPLOAD_DIR / fname
            try:
                # Saving is inside the try so that a disk/permission error on
                # one file is reported as "skipped" instead of aborting the
                # whole request with an unhandled 500.
                f.save(save_path)
                table_name = pipeline.load_file(str(save_path))
            except Exception as exc:
                logger.exception("Failed to load '%s': %s", fname, exc)
                skipped.append(fname)
                continue

            uploaded.append({"filename": fname, "table_name": table_name})
            logger.info("Uploaded and loaded '%s' as table '%s'", fname, table_name)

        if not uploaded:
            return _err(
                "No files were loaded successfully.",
                details=f"Skipped: {skipped}" if skipped else None,
            )

        return _ok({
            "uploaded":     uploaded,
            "skipped":      skipped,
            "total_tables": len(pipeline.loader.table_names()),
        })

    # ── GET /tables ───────────────────────────────────────────────

    @app.route("/tables", methods=["GET"])
    def tables():
        """
        Report every table currently loaded in the pipeline.

        Response
        ────────
        {
          "status": "success",
          "total":  3,                 // number of loaded tables (0 when none)
          "tables": [ ... ]            // whatever pipeline.tables_info() returns
        }

        Example of one "tables" entry
        ─────────────────────────────
        {
          "table_name": "orders",
          "source":     "mssql:Orders",
          "db_name":    "SalesDB",
          "row_count":  12450,
          "columns": [
            { "name": "OrderID", "sql_type": "BIGINT", "nullable": false, "sample_values": [...] },
            ...
          ]
        }
        """
        if pipeline.has_tables():
            table_count = len(pipeline.loader.table_names())
            payload = {"total": table_count, "tables": pipeline.tables_info()}
        else:
            # Nothing loaded yet: answer with an empty listing, not an error.
            payload = {"total": 0, "tables": []}

        return _ok(payload)

    # ── POST /generate ────────────────────────────────────────────

    @app.route("/generate", methods=["POST"])
    def generate():
        """
        Generate SQL from a natural language question.
        Does NOT execute the SQL — use /execute or /query for that.

        Request body (JSON)
        ───────────────────
        {
          "question":      "Total revenue by product category",   // required, non-empty string
          "pinned_tables": ["orders", "products"]                  // optional list of strings
        }

        Response (success, HTTP 200)
        ────────────────────────────
        {
          "status":         "success",
          "sql":            "SELECT ...",
          "tables_used":    ["orders", "products"],
          "lint_warnings":  []
        }

        Errors
        ──────
        400 — missing/invalid "question" or "pinned_tables", or no tables loaded.
        500 — the pipeline raised, or reported an unsuccessful result.
        """
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            # Missing/invalid JSON, or a JSON array/scalar: there are no
            # fields to read, so fall through to the "required" error below
            # instead of crashing on `.get`.
            body = {}

        question = body.get("question") or ""
        if not isinstance(question, str):
            return _err("'question' must be a string.")
        question = question.strip()
        if not question:
            return _err("'question' is required.")

        if not pipeline.has_tables():
            return _err("No tables loaded. Upload a CSV or ensure SQL Server env vars are set.")

        # Empty list / null both mean "no pinning".
        pinned = body.get("pinned_tables") or None
        if pinned is not None and not (
            isinstance(pinned, list) and all(isinstance(name, str) for name in pinned)
        ):
            return _err("'pinned_tables' must be a list of table name strings.")

        try:
            result = pipeline.generate(question=question, pinned_tables=pinned)
        except Exception as exc:
            logger.exception("Generate error")
            return _err("SQL generation failed.", details=str(exc), status=500)

        if not result.success:
            return _err("SQL generation failed.", details=result.error, status=500)

        return _ok({
            "sql":           result.sql,
            "tables_used":   result.tables_used,
            "lint_warnings": result.lint_warnings,
        })

    # ── POST /execute ─────────────────────────────────────────────

    @app.route("/execute", methods=["POST"])
    def execute():
        """
        Execute a caller-supplied SQL string via the pipeline and return results.
        No LLM involved — you supply the SQL directly.

        Request body (JSON)
        ───────────────────
        {
          "sql":         "SELECT * FROM orders LIMIT 10",   // required, non-empty string
          "label":       "my_query",                        // optional — used for CSV filename
          "tables_used": ["orders"]                         // optional — forwarded to the pipeline as-is
        }

        Response (success, HTTP 200)
        ────────────────────────────
        {
          "status":     "success",
          "sql":        "SELECT ...",
          "rows":       10,
          "columns":    ["OrderID", "Amount", ...],
          "data":       [{ "OrderID": 1, "Amount": 250.0 }, ...],
          "output_csv": "output/my_query__20250430_120000.csv"
        }

        Errors
        ──────
        400 — missing/invalid "sql" or "label", no tables loaded, or the
              pipeline reported an unsuccessful result.
        500 — the pipeline raised an exception.
        """
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            # Missing/invalid JSON, or a JSON array/scalar: treat as no fields.
            body = {}

        sql = body.get("sql") or ""
        if not isinstance(sql, str):
            return _err("'sql' must be a string.")
        sql = sql.strip()
        if not sql:
            return _err("'sql' is required.")

        if not pipeline.has_tables():
            return _err("No tables loaded.")

        label = body.get("label") or "manual_query"
        if not isinstance(label, str):
            return _err("'label' must be a string.")
        # NOTE(review): label is client-controlled and ends up in a filename;
        # presumably the pipeline sanitises it — confirm in execute_sql.
        label = label.strip()

        tables_used = body.get("tables_used") or None

        try:
            result = pipeline.execute_sql(sql=sql, label=label, tables_used=tables_used)
        except Exception as exc:
            logger.exception("Execute error")
            return _err("SQL execution failed.", details=str(exc), status=500)

        if not result.success:
            return _err("SQL execution failed.", details=result.error)

        return _ok({
            "sql":        result.sql,
            "rows":       result.rows,
            "columns":    result.columns,
            "data":       result.data,
            "output_csv": result.output_csv,
        })

    # ── POST /query ───────────────────────────────────────────────

    @app.route("/query", methods=["POST"])
    def query():
        """
        Full pipeline: natural language → SQL generation → validation → execution.
        The retry budget is passed to the pipeline as max_retries.

        Request body (JSON)
        ───────────────────
        {
          "question":      "Top 5 customers by total order value",   // required, non-empty string
          "pinned_tables": ["orders", "customers"],                   // optional list of strings
          "max_retries":   5                                          // optional non-negative integer
                                                                      // (missing / null / 0 → MAX_RETRIES)
        }

        Response (success, HTTP 200)
        ────────────────────────────
        {
          "status":        "success",
          "sql":           "SELECT ...",
          "rows":          5,
          "columns":       ["customer_name", "total_value"],
          "data":          [{ "customer_name": "Acme", "total_value": 12400.0 }, ...],
          "tables_used":   ["orders", "customers"],
          "attempts":      2,
          "output_csv":    "output/top_5_customers__20250430_120000.csv",
          "history":       [...],
          "reasoning":     ...,
          "column_roles":  ...,
          "visualization": {                 // null when the result has no viz metadata
            "chart_type": ..., "title": ..., "x_axis": ..., "y_axis": ...,
            "secondary_axis": ..., "series_label": ..., "reasoning": ...
          }
        }

        Response (pipeline failure, HTTP 422)
        ─────────────────────────────────────
        {
          "status":    "error",
          "message":   "Pipeline failed after 6 attempt(s).",
          "details":   "<last error message>",
          "history":   [...]
        }

        Other errors
        ────────────
        400 — missing/invalid "question", "pinned_tables" or "max_retries",
              or no tables loaded.
        500 — the pipeline raised an exception.
        """
        body = request.get_json(silent=True)
        if not isinstance(body, dict):
            # Missing/invalid JSON, or a JSON array/scalar: treat as no fields.
            body = {}

        question = body.get("question") or ""
        if not isinstance(question, str):
            return _err("'question' must be a string.")
        question = question.strip()
        if not question:
            return _err("'question' is required.")

        if not pipeline.has_tables():
            return _err("No tables loaded. Upload a CSV or configure SQL Server env vars.")

        # Empty list / null both mean "no pinning".
        pinned = body.get("pinned_tables") or None
        if pinned is not None and not (
            isinstance(pinned, list) and all(isinstance(name, str) for name in pinned)
        ):
            return _err("'pinned_tables' must be a list of table name strings.")

        # Any falsy value (missing, null, 0, "") falls back to MAX_RETRIES.
        # Bad client input is answered with a 400 rather than an unhandled
        # ValueError/TypeError from int().
        try:
            max_retries = int(body.get("max_retries") or MAX_RETRIES)
        except (TypeError, ValueError, OverflowError):
            return _err("'max_retries' must be a non-negative integer.")
        if max_retries < 0:
            return _err("'max_retries' must be a non-negative integer.")

        try:
            result = pipeline.query(
                question      = question,
                pinned_tables = pinned,
                max_retries   = max_retries,
            )
        except Exception as exc:
            logger.exception("Query pipeline error")
            return _err("Pipeline encountered an unexpected error.", details=str(exc), status=500)

        if not result.success:
            return jsonify({
                "status":  "error",
                "message": f"Pipeline failed after {result.attempts} attempt(s).",
                "details": result.error,
                "history": result.history,
            }), 422

        # Visualization metadata is optional; expose null when absent.
        viz = result.viz_meta
        visualization = None
        if viz:
            visualization = {
                "chart_type":     viz.chart_type,
                "title":          viz.title,
                "x_axis":         viz.x_axis,
                "y_axis":         viz.y_axis,
                "secondary_axis": viz.secondary_axis,
                "series_label":   viz.series_label,
                "reasoning":      viz.reasoning,
            }

        return _ok({
            "sql":           result.final_sql,
            "rows":          result.rows,
            "columns":       result.columns,
            "data":          result.data,
            "tables_used":   result.tables_used,
            "attempts":      result.attempts,
            "output_csv":    result.output_csv,
            "history":       result.history,
            "reasoning":     result.reasoning,
            "column_roles":  result.column_roles,
            "visualization": visualization,
        })

    # ── GET /download/<filename> ──────────────────────────────────

    @app.route("/download/<path:filename>", methods=["GET"])
    def download(filename):
        """
        Download a previously saved output CSV from OUTPUT_DIR.

        Only the final path component of *filename* is used, so both a bare
        name and a directory-prefixed value (the form shown for "output_csv"
        in the /execute and /query docstrings) resolve to the same file.
        Files are only ever served from OUTPUT_DIR itself.

        Examples
        ────────
        GET /download/top_5_customers__20250430_120000.csv
        GET /download/output/top_5_customers__20250430_120000.csv

        Errors
        ──────
        400 — the name is empty after sanitisation.
        404 — no regular file with that name exists in OUTPUT_DIR.
        """
        # Take the basename BEFORE sanitising: secure_filename would otherwise
        # turn "output/x.csv" into "output_x.csv", which never exists.
        safe = secure_filename(Path(filename).name)
        if not safe:
            return _err("Invalid filename.")
        # is_file() (not exists()) so a directory of that name is a clean 404.
        if not (OUTPUT_DIR / safe).is_file():
            return _err(f"File '{safe}' not found in output directory.", status=404)
        return send_from_directory(
            directory     = str(OUTPUT_DIR.resolve()),
            path          = safe,
            as_attachment = True,
        )

    # ── GET /health ───────────────────────────────────────────────

    @app.route("/health", methods=["GET"])
    def health():
        """
        Quick liveness check.

        Response
        ────────
        {
          "status":        "success",
          "tables_loaded": 3,
          "table_names":   ["orders", ...],
          "timestamp":     "2025-04-30T12:00:00.123456Z"   // current UTC time
        }
        """
        # Local import: the module header only imports `datetime` itself.
        from datetime import timezone

        # Read the table list once so the count and the names always agree.
        names = pipeline.loader.table_names()

        # datetime.utcnow() is deprecated; take an aware UTC "now" and drop
        # the tzinfo so isoformat() + "Z" yields the same string as before.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)

        return _ok({
            "tables_loaded": len(names),
            "table_names":   names,
            "timestamp":     now_utc.isoformat() + "Z",
        })

    # ── Global error handlers ─────────────────────────────────────

    @app.errorhandler(404)
    def not_found(e):
        """Answer HTTP 404 with the API's JSON error envelope."""
        return _err(message="Endpoint not found.", status=404)

    @app.errorhandler(405)
    def method_not_allowed(e):
        """Answer HTTP 405 with the API's JSON error envelope."""
        return _err(message="Method not allowed.", status=405)

    @app.errorhandler(500)
    def internal_error(e):
        """Answer HTTP 500 with the API's JSON error envelope."""
        return _err(message="Internal server error.", status=500)

    return app


# ─────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────


if __name__ == "__main__":
    # Build the app first, then read the listen settings from the environment:
    #   FLASK_HOST  (default "0.0.0.0"), FLASK_PORT (default 6112),
    #   FLASK_DEBUG (debug is on only when the value is "true", any case).
    app = create_app()

    bind_host = os.getenv("FLASK_HOST", "0.0.0.0")
    bind_port = int(os.getenv("FLASK_PORT", 6112))
    debug_flag = os.getenv("FLASK_DEBUG", "false")
    debug_on = debug_flag.lower() == "true"

    logger.info(
        "Starting NL2SQL Flask API on %s:%d  (debug=%s)",
        bind_host,
        bind_port,
        debug_on,
    )
    app.run(host=bind_host, port=bind_port, debug=debug_on)
# (end of app_main.py)