Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, Form, Request | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.staticfiles import StaticFiles | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_community.llms import Ollama | |
| from langchain.chains import LLMChain | |
| import sqlite3 | |
| import os | |
# ASGI application object; everything below hangs off this instance.
app = FastAPI()

# Jinja2 environment rooted at ./templates (the handlers render "chat.html").
templates = Jinja2Templates(directory="templates")

# Serve files from ./static under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Prompt text sent to the model. It has two placeholders:
#   {user_question}            - the natural-language question (used twice)
#   {create_table_statements}  - the DDL describing the target schema
# The <|...|> markers delimit a "user" turn followed by the start of an
# "assistant" turn, so the model continues right after the final line.
# NOTE(review): the bare "sql" line at the end looks like the remnant of a
# Markdown code-fence opener - confirm against the intended prompt.
_SQL_PROMPT_TEMPLATE = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Generate a SQL query to answer this question: {user_question}
Instructions:
- Use valid SQL syntax compatible with SQLite.
- Use the exact table and column names provided in the question.
- For date comparisons, do NOT use `INTERVAL`. Use SQLite-compatible expressions like `DATE('now', '-6 months')` or `DATE('now', '-30 days')`.
- Do not use `NULLS LAST` or `NULLS FIRST` in `ORDER BY` clauses unless explicitly required, as SQLite does not support it.
- Do not use `RANK()` or `OVER()` unless the question specifically requires ranking or window functions. SQLite has limited support for window functions.
- When using joins, ensure that the correct join condition is specified between the related tables. Use `INNER JOIN`, `LEFT JOIN`, etc., and ensure the conditions in `ON` are correct.
- If aggregating data (e.g., `SUM()`, `COUNT()`), ensure that grouping is done correctly using `GROUP BY` to avoid errors.
- Avoid complex SQL expressions that may not be supported in SQLite, such as `INTERVAL` for date calculations or subqueries that are not supported by SQLite.
- Return only the SQL query, no explanation.
DDL statements:
{create_table_statements}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
The following SQL query best answers the question {user_question}:
sql
"""

# Text-to-SQL model, accessed through the Ollama LLM wrapper.
llm = Ollama(model="mannix/defog-llama3-sqlcoder-8b")

# Prompt object binding the template to its two input variables.
prompt = PromptTemplate(
    template=_SQL_PROMPT_TEMPLATE,
    input_variables=["user_question", "create_table_statements"],
)

# Chain that formats the prompt and passes it to the LLM.
chain = LLMChain(llm=llm, prompt=prompt)
# Module-level state shared by every request handled by this process:
#   ddl          - DDL text submitted through setup(), or None before setup
#   db_path      - path of the SQLite file to query, or None before setup
#   chat_history - list of per-question records appended by ask()
session_config = dict(
    ddl=None,
    db_path=None,
    chat_history=[],
)
def index(request: Request):
    """Render ``chat.html`` with ``ddl_mode`` on and an empty chat history.

    NOTE(review): no route decorator is visible on this handler in this
    file - confirm how it is registered with ``app``.
    """
    page_context = {
        "request": request,
        "ddl_mode": True,
        "chat_history": [],
    }
    return templates.TemplateResponse("chat.html", page_context)
async def setup(request: Request, ddl: str = Form(...), db_path: str = Form(...)):
    """Store the submitted DDL and database path, then render the chat page.

    Both form fields are required. Any previous chat history is discarded
    by rebinding ``chat_history`` to a fresh list, and ``chat.html`` is
    rendered with ``ddl_mode`` off.

    NOTE(review): no route decorator is visible on this handler in this
    file - confirm how it is registered with ``app``.
    """
    # One update call replaces the three individual key assignments.
    session_config.update(ddl=ddl, db_path=db_path, chat_history=[])

    page_context = {
        "request": request,
        "ddl_mode": False,
        "chat_history": session_config["chat_history"],
    }
    return templates.TemplateResponse("chat.html", page_context)
def _extract_sql(chain_output):
    """Return the generated SQL as a whitespace-stripped string.

    ``LLMChain.invoke`` returns a dict holding the inputs plus the
    completion under the chain's output key ("text" by default); a plain
    string is accepted as well so the helper keeps working if ``chain``
    is ever replaced by something that returns the text directly.
    """
    if isinstance(chain_output, dict):
        chain_output = chain_output.get("text", "")
    return str(chain_output).strip()


def _run_query(db_path, sql):
    """Execute *sql* against the SQLite file at *db_path*.

    Returns a ``(rows, error)`` pair: the fetched rows and ``None`` on
    success, or an empty list and the error message on failure.  The
    connection is opened and closed inside this call, so it never
    crosses threads.
    """
    # NOTE(review): the statement is model-generated and executed as-is;
    # nothing here stops it from being DDL/DML. Consider opening the
    # database read-only if only SELECTs are intended.
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        return conn.execute(sql).fetchall(), None
    except (sqlite3.Error, sqlite3.Warning) as exc:
        # sqlite3.Warning is raised (Python < 3.12) when the text holds
        # more than one statement, and it is not a sqlite3.Error subclass.
        return [], str(exc)
    finally:
        if conn is not None:
            conn.close()


async def ask(request: Request, user_question: str = Form(...)):
    """Turn a question into SQL, run it, and render the updated chat page.

    The question and the stored DDL are passed to ``chain``; the SQL text
    it produces is executed against the stored SQLite path.  A record
    with the question, SQL, result rows and error (if any) is appended to
    ``session_config["chat_history"]`` and ``chat.html`` is rendered with
    the full history.

    If ``setup`` has not stored a DDL and database path yet, the page is
    rendered with ``ddl_mode`` on and an empty history, exactly as
    ``index`` does, instead of failing on ``sqlite3.connect(None)``.

    NOTE(review): no route decorator is visible on this handler in this
    file - confirm how it is registered with ``app``.
    """
    # Imported here so this block does not depend on a file-level change.
    from fastapi.concurrency import run_in_threadpool

    ddl = session_config["ddl"]
    db_path = session_config["db_path"]
    if ddl is None or db_path is None:
        return templates.TemplateResponse(
            "chat.html",
            {"request": request, "ddl_mode": True, "chat_history": []},
        )

    # Both the LLM call and the SQLite query block; run them in a worker
    # thread so the event loop stays free for other requests.
    chain_output = await run_in_threadpool(
        chain.invoke,
        {
            "user_question": user_question,
            "create_table_statements": ddl,
        },
    )
    generated_sql = _extract_sql(chain_output)
    result_rows, error = await run_in_threadpool(
        _run_query, db_path, generated_sql
    )

    session_config["chat_history"].append({
        "question": user_question,
        "sql": generated_sql,
        "result": result_rows,
        "error": error,
    })
    return templates.TemplateResponse("chat.html", {
        "request": request,
        "ddl_mode": False,
        "chat_history": session_config["chat_history"],
    })