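# app_main.py
# NOTE: the lines below are listing metadata from the snippet host this file was
# copied from ("unknown", "python", "a month ago", "17 kB", "4", "Indexable").
# They are not part of the program and are kept only as a comment.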
import os
import logging
import traceback
from pathlib import Path
from datetime import datetime
from flask import Flask, request, jsonify, send_from_directory
from werkzeug.utils import secure_filename
from dotenv import load_dotenv
load_dotenv()
from nl_to_sql_main import (
NL2SQLPipeline,
OUTPUT_DIR,
UPLOAD_DIR,
MAX_RETRIES,
TunnelledMSSQL,
SSH_HOST,
SSH_PKEY_PATH,
MSSQL_DB_USER,
MSSQL_DB_PASS,
MSSQL_DB_NAME,
)
# ─────────────────────────────────────────────────────────────────
# Logging
# ─────────────────────────────────────────────────────────────────
# Configure the root logger once at import time: INFO and above, one
# timestamped line per record. Keyword arguments are listed alphabetically.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(name)s — %(message)s",
    level=logging.INFO,
)

# Module-level logger used by the app factory and the entry point.
logger = logging.getLogger(__name__)
# ─────────────────────────────────────────────────────────────────
# App factory
# ─────────────────────────────────────────────────────────────────
# Lower-case file suffixes (with the leading dot) that uploads may carry.
ALLOWED_EXTENSIONS = {
    ".csv",
    ".xls",
    ".xlsx",
}
def create_app() -> Flask:
    """
    Build and configure the NL2SQL Flask application.

    Creates the output and upload directories, instantiates one
    NL2SQLPipeline that every route closes over, pre-loads MS SQL Server
    tables through the SSH tunnel when all tunnel settings are present
    (a pre-load failure is logged and start-up continues), and registers
    the HTTP routes and JSON error handlers.

    Returns
    ───────
    The configured Flask application.
    """
    # Local import so the file's top-level import block (and the position of
    # load_dotenv() inside it) does not have to change.
    from datetime import timezone

    app = Flask(__name__)

    # Ensure folders exist
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)

    # ── Initialise pipeline + pre-load SQL Server at startup ──────
    pipeline = NL2SQLPipeline()

    # ── MS SQL Server via SSH tunnel pre-load ─────────────────────
    mssql_tunnel_configured = all(
        [MSSQL_DB_USER, MSSQL_DB_PASS, MSSQL_DB_NAME, SSH_HOST, SSH_PKEY_PATH]
    )
    if mssql_tunnel_configured:
        logger.info("MS SQL credentials found — loading tables from '%s' via SSH tunnel", MSSQL_DB_NAME)
        try:
            loaded = pipeline.load_mssql_tunnelled_tables(table_names=None)
            logger.info("Pre-loaded %d MS SQL table(s): %s", len(loaded), loaded)
        except Exception as exc:
            # Best-effort: a failed pre-load must not prevent the API from starting.
            logger.error("Failed to pre-load MS SQL Server tables: %s", exc)
    else:
        logger.info("MS SQL tunnel env vars not fully set — skipping MS SQL pre-load.")

    # ── Helpers ───────────────────────────────────────────────────
    def _allowed(filename: str) -> bool:
        """Return True when the file suffix (case-insensitive) is in ALLOWED_EXTENSIONS."""
        return Path(filename).suffix.lower() in ALLOWED_EXTENSIONS

    def _ok(payload: dict, status: int = 200):
        """Build a JSON success response: {"status": "success", **payload}."""
        return jsonify({"status": "success", **payload}), status

    def _err(message: str, status: int = 400, details: str = None):
        """Build a JSON error response; "details" is included only when truthy."""
        body = {"status": "error", "message": message}
        if details:
            body["details"] = details
        return jsonify(body), status

    def _json_body() -> dict:
        """
        Return the request's JSON payload when it is an object, else {}.

        A missing/invalid body or a non-object payload (list, number, string)
        is treated as empty so the routes answer with a 400 validation error
        instead of failing on `.get`.
        """
        parsed = request.get_json(silent=True)
        return parsed if isinstance(parsed, dict) else {}

    def _text_field(body: dict, key: str, default: str = "") -> str:
        """
        Return body[key] stripped of surrounding whitespace.

        Missing, empty, or non-string values yield `default` (unstripped),
        so a non-string value can never reach `.strip()`.
        """
        value = body.get(key)
        if not isinstance(value, str) or not value:
            return default
        return value.strip()

    # ── POST /upload ──────────────────────────────────────────────
    @app.route("/upload", methods=["POST"])
    def upload():
        """
        Upload one or more CSV / Excel files.

        Form-data fields
        ────────────────
        files : one or more files (field name "files")

        Response
        ────────
        {
            "status": "success",
            "uploaded": [                        // one entry per loaded file
                { "filename": "orders.csv", "table_name": "orders" },
                ...
            ],
            "skipped": ["bad.txt"],              // unsupported extension or failed to load
            "total_tables": 5                    // total tables now loaded
        }

        Returns a 400 error when the "files" field is absent or when no file
        could be loaded; in the latter case "details" lists the skipped names.
        """
        if "files" not in request.files:
            return _err("No files provided. Send files under the field name 'files'.")

        uploaded = []
        skipped = []
        for f in request.files.getlist("files"):
            fname = secure_filename(f.filename or "")
            if not fname:
                # Empty form slots (or names that sanitise to nothing) are ignored.
                continue
            if not _allowed(fname):
                skipped.append(fname)
                continue

            save_path = UPLOAD_DIR / fname
            f.save(save_path)
            try:
                table_name = pipeline.load_file(str(save_path))
            except Exception as exc:
                # One bad file must not abort the rest of the batch.
                logger.error("Failed to load '%s': %s", fname, exc)
                skipped.append(fname)
            else:
                uploaded.append({"filename": fname, "table_name": table_name})
                logger.info("Uploaded and loaded '%s' as table '%s'", fname, table_name)

        if not uploaded:
            return _err(
                "No files were loaded successfully.",
                details=f"Skipped: {skipped}" if skipped else None,
            )
        return _ok({
            "uploaded": uploaded,
            "skipped": skipped,
            "total_tables": len(pipeline.loader.table_names()),
        })

    # ── GET /tables ───────────────────────────────────────────────
    @app.route("/tables", methods=["GET"])
    def tables():
        """
        List all currently loaded tables with their full schema.

        Response
        ────────
        {
            "status": "success",
            "total": 3,
            "tables": [
                {
                    "table_name": "orders",
                    "source": "mssql:Orders",
                    "db_name": "SalesDB",
                    "row_count": 12450,
                    "columns": [
                        { "name": "OrderID", "sql_type": "BIGINT", "nullable": false, "sample_values": [...] },
                        ...
                    ]
                },
                ...
            ]
        }
        """
        if not pipeline.has_tables():
            return _ok({"total": 0, "tables": []})
        return _ok({
            "total": len(pipeline.loader.table_names()),
            "tables": pipeline.tables_info(),
        })

    # ── POST /generate ────────────────────────────────────────────
    @app.route("/generate", methods=["POST"])
    def generate():
        """
        Generate SQL from a natural language question.
        Does NOT execute the SQL — use /execute or /query for that.

        Request body (JSON)
        ───────────────────
        {
            "question": "Total revenue by product category",   // required
            "pinned_tables": ["orders", "products"]            // optional
        }

        Response
        ────────
        {
            "status": "success",
            "sql": "SELECT ...",
            "tables_used": ["orders", "products"],
            "lint_warnings": []
        }

        Validation problems return 400; a generation failure returns 500.
        """
        body = _json_body()
        question = _text_field(body, "question")
        if not question:
            return _err("'question' is required.")
        if not pipeline.has_tables():
            return _err("No tables loaded. Upload a CSV or ensure SQL Server env vars are set.")

        pinned = body.get("pinned_tables") or None
        if pinned and not isinstance(pinned, list):
            return _err("'pinned_tables' must be a list of table name strings.")

        try:
            result = pipeline.generate(question=question, pinned_tables=pinned)
        except Exception as exc:
            logger.error("Generate error: %s", traceback.format_exc())
            return _err("SQL generation failed.", details=str(exc), status=500)

        if not result.success:
            return _err("SQL generation failed.", details=result.error, status=500)
        return _ok({
            "sql": result.sql,
            "tables_used": result.tables_used,
            "lint_warnings": result.lint_warnings,
        })

    # ── POST /execute ─────────────────────────────────────────────
    @app.route("/execute", methods=["POST"])
    def execute():
        """
        Execute a raw SQL string against DuckDB and return results.
        No LLM involved — you supply the SQL directly.

        Request body (JSON)
        ───────────────────
        {
            "sql": "SELECT * FROM orders LIMIT 10",   // required
            "label": "my_query",                      // optional — used for CSV filename
            "tables_used": ["orders"]                 // optional — forwarded to the pipeline
        }

        Response
        ────────
        {
            "status": "success",
            "sql": "SELECT ...",
            "rows": 10,
            "columns": ["OrderID", "Amount", ...],
            "data": [{ "OrderID": 1, "Amount": 250.0 }, ...],
            "output_csv": "output/my_query__20250430_120000.csv"
        }
        """
        # NOTE(review): this endpoint runs caller-supplied SQL as-is; confirm the
        # deployment restricts who can reach it.
        body = _json_body()
        sql = _text_field(body, "sql")
        if not sql:
            return _err("'sql' is required.")
        if not pipeline.has_tables():
            return _err("No tables loaded.")

        label = _text_field(body, "label", default="manual_query")
        tables_used = body.get("tables_used") or None

        try:
            result = pipeline.execute_sql(sql=sql, label=label, tables_used=tables_used)
        except Exception as exc:
            logger.error("Execute error: %s", traceback.format_exc())
            return _err("SQL execution failed.", details=str(exc), status=500)

        if not result.success:
            return _err("SQL execution failed.", details=result.error)
        return _ok({
            "sql": result.sql,
            "rows": result.rows,
            "columns": result.columns,
            "data": result.data,
            "output_csv": result.output_csv,
        })

    # ── POST /query ───────────────────────────────────────────────
    @app.route("/query", methods=["POST"])
    def query():
        """
        Full pipeline: natural language → SQL generation → validation → execution.
        Retries on failure up to max_retries times.

        Request body (JSON)
        ───────────────────
        {
            "question": "Top 5 customers by total order value",   // required
            "pinned_tables": ["orders", "customers"],             // optional
            "max_retries": 5                                      // optional (default: MAX_RETRIES)
        }

        Response (success)
        ──────────────────
        {
            "status": "success",
            "sql": "SELECT ...",
            "rows": 5,
            "columns": ["customer_name", "total_value"],
            "data": [{ "customer_name": "Acme", "total_value": 12400.0 }, ...],
            "tables_used": ["orders", "customers"],
            "attempts": 2,
            "output_csv": "output/top_5_customers__20250430_120000.csv",
            "history": [...],
            "reasoning": ...,
            "column_roles": ...,
            "visualization": { "chart_type": ..., "title": ..., "x_axis": ..., "y_axis": ...,
                               "secondary_axis": ..., "series_label": ..., "reasoning": ... }
                             // or null when the pipeline produced no visualization metadata
        }

        Response (failure, HTTP 422)
        ────────────────────────────
        {
            "status": "error",
            "message": "Pipeline failed after 6 attempt(s).",
            "details": "<last error message>",
            "history": [...]
        }
        """
        body = _json_body()
        question = _text_field(body, "question")
        if not question:
            return _err("'question' is required.")
        if not pipeline.has_tables():
            return _err("No tables loaded. Upload a CSV or configure SQL Server env vars.")

        pinned = body.get("pinned_tables") or None
        if pinned and not isinstance(pinned, list):
            return _err("'pinned_tables' must be a list of table name strings.")

        try:
            # Falsy values (missing, null, 0, "") fall back to the configured default.
            max_retries = int(body.get("max_retries") or MAX_RETRIES)
        except (TypeError, ValueError, OverflowError):
            # Client-supplied garbage is a 400, not an unhandled 500.
            return _err("'max_retries' must be an integer.")
        if max_retries < 0:
            return _err("'max_retries' must not be negative.")

        try:
            result = pipeline.query(
                question=question,
                pinned_tables=pinned,
                max_retries=max_retries,
            )
        except Exception as exc:
            logger.error("Query pipeline error: %s", traceback.format_exc())
            return _err("Pipeline encountered an unexpected error.", details=str(exc), status=500)

        if not result.success:
            # Built directly (not via _err) because the failure payload always
            # carries "details" and additionally the attempt history.
            return jsonify({
                "status": "error",
                "message": f"Pipeline failed after {result.attempts} attempt(s).",
                "details": result.error,
                "history": result.history,
            }), 422

        viz = result.viz_meta
        visualization = None
        if viz:
            visualization = {
                "chart_type": viz.chart_type,
                "title": viz.title,
                "x_axis": viz.x_axis,
                "y_axis": viz.y_axis,
                "secondary_axis": viz.secondary_axis,
                "series_label": viz.series_label,
                "reasoning": viz.reasoning,
            }

        return _ok({
            "sql": result.final_sql,
            "rows": result.rows,
            "columns": result.columns,
            "data": result.data,
            "tables_used": result.tables_used,
            "attempts": result.attempts,
            "output_csv": result.output_csv,
            "history": result.history,
            "reasoning": result.reasoning,
            "column_roles": result.column_roles,
            "visualization": visualization,
        })

    # ── GET /download/<filename> ──────────────────────────────────
    @app.route("/download/<path:filename>", methods=["GET"])
    def download(filename):
        """
        Download a previously saved output CSV.

        Only the final path component of `filename` is used, so both a bare
        file name and a value with a directory prefix resolve to the same
        file inside OUTPUT_DIR.

        Example
        ───────
        GET /download/top_5_customers__20250430_120000.csv
        """
        # Take the basename before sanitising: secure_filename() would otherwise
        # flatten "output/x.csv" into "output_x.csv", which never exists.
        # NOTE(review): assumes `output_csv` values may carry a directory prefix,
        # as the docstring examples of /execute and /query show — confirm.
        safe = secure_filename(Path(filename).name)
        if not safe:
            return _err("Invalid filename.")
        # is_file() (not exists()) so a directory of that name is reported as missing.
        if not (OUTPUT_DIR / safe).is_file():
            return _err(f"File '{safe}' not found in output directory.", status=404)
        return send_from_directory(
            directory=str(OUTPUT_DIR.resolve()),
            path=safe,
            as_attachment=True,
        )

    # ── GET /health ───────────────────────────────────────────────
    @app.route("/health", methods=["GET"])
    def health():
        """Quick liveness check: loaded table names, their count, and a UTC timestamp."""
        # Read the table list once so the count and the names always agree.
        names = pipeline.loader.table_names()
        # Same naive-UTC ISO string as before, without the deprecated utcnow().
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        return _ok({
            "tables_loaded": len(names),
            "table_names": names,
            "timestamp": now_utc.isoformat() + "Z",
        })

    # ── Global error handlers ─────────────────────────────────────
    @app.errorhandler(404)
    def not_found(e):
        """Answer unknown routes with the JSON error envelope."""
        return _err("Endpoint not found.", status=404)

    @app.errorhandler(405)
    def method_not_allowed(e):
        """Answer wrong HTTP methods with the JSON error envelope."""
        return _err("Method not allowed.", status=405)

    @app.errorhandler(500)
    def internal_error(e):
        """Answer unhandled server errors with the JSON error envelope."""
        return _err("Internal server error.", status=500)

    return app
# ─────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Development entry point: build the app and serve it with Flask's
    # built-in server, configured through FLASK_HOST / FLASK_PORT / FLASK_DEBUG.
    app = create_app()
    # NOTE(review): the default host "0.0.0.0" listens on every interface;
    # avoid combining it with FLASK_DEBUG=true on an untrusted network.
    host = os.environ.get("FLASK_HOST", "0.0.0.0")
    port = int(os.environ.get("FLASK_PORT", 6112))
    debug = os.environ.get("FLASK_DEBUG", "false").lower() == "true"
    logger.info("Starting NL2SQL Flask API on %s:%d (debug=%s)", host, port, debug)
    app.run(host=host, port=port, debug=debug)