Spaces:
Running
Running
| import streamlit as st | |
| import requests | |
| import subprocess | |
| import re | |
| import sys | |
| import urllib.request | |
| import json | |
| import os | |
| import ssl | |
| import time | |
# Prompt layout sent to the model: an instruction, an optional schema section
# ("Input"), the user's question, and the response header the model completes.
PROMPT_TEMPLATE = (
    "### Instruction:\n{instruction}\n\n"
    "### Input:\n{input}\n"
    "### Question:\n{question}\n\n"
    "### Response (use duckdb shorthand if possible):\n"
)

# Instruction sentence; ``has_schema`` receives the sentence ending.
INSTRUCTION_TEMPLATE = (
    "Your task is to generate valid duckdb SQL to answer the following question{has_schema}"
)

# Markdown shown when validation rejects the generated SQL. It embeds the
# offending query and the validator's error text in two fenced blocks.
ERROR_MESSAGE = (
    ":red[ Quack! Much to our regret, SQL generation has gone a tad duck-side-down.\n"
    "The model is currently not able to craft a correct SQL query for this request. \n"
    "Sorry my duck friend. ]\n\n"
    ":red[If the question is about your own database, make sure to set the correct schema. "
    "Otherwise, try to rephrase your request. ]\n\n"
    "```sql\n{sql_query}\n```\n\n"
    "```sql\n{error_msg}\n```"
)

# NOTE(review): not referenced anywhere in the visible code (the completion
# request uses its own "stop" value) -- confirm whether it is still needed.
STOP_TOKENS = ["###", ";", "--", "```"]
def allowSelfSignedHttps(allowed):
    """Optionally switch off HTTPS certificate verification for this process.

    The ``ssl`` module's default HTTPS context factory is replaced with the
    unverified one only when all of the following hold: ``allowed`` is truthy,
    the ``PYTHONHTTPSVERIFY`` environment variable is unset or empty, and
    ``ssl`` exposes ``_create_unverified_context``. Otherwise nothing changes.
    """
    if not allowed:
        return
    if os.environ.get('PYTHONHTTPSVERIFY', ''):
        # An explicit environment setting takes precedence; leave ssl alone.
        return
    unverified_factory = getattr(ssl, '_create_unverified_context', None)
    if unverified_factory:
        ssl._create_default_https_context = unverified_factory


# Only needed when the scoring service uses a self-signed certificate.
# NOTE(review): this is process-wide; it presumably also removes certificate
# checks from the urllib request that carries the bearer token -- confirm the
# endpoint really needs it.
allowSelfSignedHttps(True)
# Splits a CREATE TABLE statement into header, column body and terminator.
# The match is case-sensitive, so only upper-case "CREATE TABLE" is rewritten.
_CREATE_TABLE_RE = re.compile(r"(CREATE TABLE [^(]+\()(.*?)(\);)", flags=re.DOTALL)
# A "name TYPE" style pair: two words separated by a single space.
_WORD_PAIR_RE = re.compile(r"(\w+) (\w+)")


def _lowercase_column_types(create_table):
    """Lowercase the second word of every ``word word`` pair in a table body."""
    head, body, tail = create_table.groups()
    body = _WORD_PAIR_RE.sub(
        lambda pair: f"{pair.group(1)} {pair.group(2).lower()}", body
    )
    return f"{head}{body}{tail}"


def generate_prompt(question, schema):
    """Build the complete model prompt for ``question``.

    Args:
        question: Natural-language request to turn into DuckDB SQL.
        schema: DDL text (``CREATE TABLE ...;`` statements) describing the
            database, or an empty string / ``None`` when there is none.

    Returns:
        ``PROMPT_TEMPLATE`` filled in with the instruction sentence, the
        optional schema section and the question.
    """
    schema_section = ""
    if schema:
        # Lowercase types inside each CREATE TABLE (...) statement. The
        # substitution is done in place on each matched pair rather than with
        # a global str.replace, which also rewrote longer names that merely
        # contained the pair (e.g. "id INT" turning "uuid INTEGER" into
        # "uuid intEGER").
        schema = _CREATE_TABLE_RE.sub(_lowercase_column_types, schema)
        schema_section = """Here is the database schema that the SQL query will run on:\n{schema}\n""".format(  # noqa: E501
            schema=schema
        )
    # Use the same truthiness test as above so a None schema does not claim
    # that a schema was supplied.
    instruction = INSTRUCTION_TEMPLATE.format(
        has_schema="." if not schema else ", given a duckdb database schema."
    )
    return PROMPT_TEMPLATE.format(
        instruction=instruction,
        input=schema_section,
        question=question,
    )
def generate_sql_azure(question, schema, timeout=60):
    """Generate SQL for ``question`` through the Azure ML scoring endpoint.

    Args:
        question: Natural-language request.
        schema: Database schema text passed on to ``generate_prompt``.
        timeout: Seconds to wait on the HTTP request before giving up.
            Without it a stalled endpoint blocks the app indefinitely.

    Returns:
        The text returned by the service with the first ``len(prompt)``
        characters removed.

    Raises:
        urllib.error.URLError: The request failed or timed out.
    """
    prompt = generate_prompt(question, schema)
    start = time.time()
    payload = {
        "input_data": {
            "input_string": [prompt],
            "parameters": {
                "top_p": 0.9,
                "temperature": 0.1,
                "max_new_tokens": 200,
                "do_sample": True,
            },
        }
    }
    body = json.dumps(payload).encode("utf-8")
    url = 'https://motherduck-eu-west2-xbdfd.westeurope.inference.ml.azure.com/score'
    headers = {
        'Content-Type': 'application/json',
        'Authorization': ('Bearer ' + st.secrets['azure_ai_token']),
        'azureml-model-deployment': 'motherduckdb-duckdb-nsql-7b-v-1',
    }
    req = urllib.request.Request(url, body, headers)
    # The context manager closes the connection even if parsing fails.
    with urllib.request.urlopen(req, timeout=timeout) as raw_resp:
        resp = json.loads(raw_resp.read().decode("utf-8"))[0]["0"]
    # NOTE(review): assumes the returned text begins with an echo of the
    # prompt -- confirm against the deployment's actual output.
    sql_query = resp[len(prompt):]
    # Log the round-trip time in seconds.
    print(time.time() - start)
    return sql_query
def generate_sql(question, schema, timeout=60):
    """Generate SQL for ``question`` through the OctoAI completions endpoint.

    Args:
        question: Natural-language request.
        schema: Database schema text passed on to ``generate_prompt``.
        timeout: Seconds to wait for the HTTP request. ``requests`` applies
            no timeout by default, so an unresponsive service would hang
            the app.

    Returns:
        The ``text`` field of the first completion choice.

    Raises:
        requests.RequestException: Connection failure, timeout, or an HTTP
            error status.
    """
    print(question)
    prompt = generate_prompt(question, schema)
    start = time.time()
    api_base = "https://text-motherduck-sql-fp16-4vycuix6qcp2.octoai.run"
    url = f"{api_base}/v1/completions"
    body = {
        "model": "motherduck-sql-fp16",
        "prompt": prompt,
        "temperature": 0.1,
        "max_tokens": 200,
        "stop": "<s>",
        "n": 1,
    }
    headers = {"Authorization": f"Bearer {st.secrets['octoml_token']}"}
    # Close the session (and its pooled connections) once the call is done.
    with requests.Session() as session:
        with session.post(url, json=body, headers=headers, timeout=timeout) as resp:
            # Surface HTTP errors directly instead of as a KeyError or JSON
            # decoding error raised while reading the error payload.
            resp.raise_for_status()
            sql_query = resp.json()["choices"][0]["text"]
    # Log the round-trip time in seconds.
    print(time.time() - start)
    return sql_query
def validate_sql(query, schema, timeout=0.5):
    """Check ``query`` against ``schema`` by running ``./validate_sql.py``.

    The helper script is started in a separate interpreter with the query and
    the schema as command-line arguments. Anything it writes to stderr is
    treated as a validation error.

    Args:
        query: SQL text to validate.
        schema: Schema text handed to the helper script.
        timeout: Seconds to wait for the helper script.

    Returns:
        A ``(valid, message)`` tuple. ``message`` is always a string: empty
        when ``valid`` is True, otherwise the helper's stderr output with the
        first three lines (the traceback header) removed when there are more
        than three lines.
    """
    try:
        # subprocess.run kills and reaps the child when the timeout expires,
        # so no zombie process or open pipe is left behind.
        completed = subprocess.run(
            [sys.executable, './validate_sql.py', query, schema],
            stdout=subprocess.DEVNULL,  # the helper's stdout is not used
            stderr=subprocess.PIPE,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        # timeout reached, so parsing and binding was very likely successful
        return True, ""

    if not completed.stderr:
        return True, ""

    lines = completed.stderr.decode('utf8', errors='replace').split("\n")
    # skip traceback
    if len(lines) > 3:
        lines = lines[3:]
    # Always return text; short error output used to be returned as a list.
    return False, "\n".join(lines)
st.title("DuckDB-NSQL-7B Demo")

# Expander section in which the user can replace the default schema.
expander = st.expander("Customize Schema (Optional)")
expander.markdown(
    "If your DuckDB database is `database.duckdb`, execute this command in your terminal to get your current schema:"
)
# Shell one-liner that dumps the schema and reflows it with sed.
expander.markdown(
    """```bash\necho ".schema" | duckdb database.duckdb | sed 's/(/(\\n /g' | sed 's/, /,\\n /g' | sed 's/);/\\n);\\n/g'\n```""",
)
| # Default schema pre-filled in the "Current schema:" text area; it defines three example tables (rideshare, service_requests, taxi). | |
| default_schema = """CREATE TABLE rideshare( | |
| hvfhs_license_num VARCHAR, | |
| dispatching_base_num VARCHAR, | |
| originating_base_num VARCHAR, | |
| request_datetime TIMESTAMP, | |
| on_scene_datetime TIMESTAMP, | |
| pickup_datetime TIMESTAMP, | |
| dropoff_datetime TIMESTAMP, | |
| PULocationID BIGINT, | |
| DOLocationID BIGINT, | |
| trip_miles DOUBLE, | |
| trip_time BIGINT, | |
| base_passenger_fare DOUBLE, | |
| tolls DOUBLE, | |
| bcf DOUBLE, | |
| sales_tax DOUBLE, | |
| congestion_surcharge DOUBLE, | |
| airport_fee DOUBLE, | |
| tips DOUBLE, | |
| driver_pay DOUBLE, | |
| shared_request_flag VARCHAR, | |
| shared_match_flag VARCHAR, | |
| access_a_ride_flag VARCHAR, | |
| wav_request_flag VARCHAR, | |
| wav_match_flag VARCHAR | |
| ); | |
| CREATE TABLE service_requests( | |
| unique_key BIGINT, | |
| created_date TIMESTAMP, | |
| closed_date TIMESTAMP, | |
| agency VARCHAR, | |
| agency_name VARCHAR, | |
| complaint_type VARCHAR, | |
| descriptor VARCHAR, | |
| location_type VARCHAR, | |
| incident_zip VARCHAR, | |
| incident_address VARCHAR, | |
| street_name VARCHAR, | |
| cross_street_1 VARCHAR, | |
| cross_street_2 VARCHAR, | |
| intersection_street_1 VARCHAR, | |
| intersection_street_2 VARCHAR, | |
| address_type VARCHAR, | |
| city VARCHAR, | |
| landmark VARCHAR, | |
| facility_type VARCHAR, | |
| status VARCHAR, | |
| due_date TIMESTAMP, | |
| resolution_description VARCHAR, | |
| resolution_action_updated_date TIMESTAMP, | |
| community_board VARCHAR, | |
| bbl VARCHAR, | |
| borough VARCHAR, | |
| x_coordinate_state_plane VARCHAR, | |
| y_coordinate_state_plane VARCHAR, | |
| open_data_channel_type VARCHAR, | |
| park_facility_name VARCHAR, | |
| park_borough VARCHAR, | |
| vehicle_type VARCHAR, | |
| taxi_company_borough VARCHAR, | |
| taxi_pick_up_location VARCHAR, | |
| bridge_highway_name VARCHAR, | |
| bridge_highway_direction VARCHAR, | |
| road_ramp VARCHAR, | |
| bridge_highway_segment VARCHAR, | |
| latitude DOUBLE, | |
| longitude DOUBLE | |
| ); | |
| CREATE TABLE taxi( | |
| VendorID BIGINT, | |
| tpep_pickup_datetime TIMESTAMP, | |
| tpep_dropoff_datetime TIMESTAMP, | |
| passenger_count DOUBLE, | |
| trip_distance DOUBLE, | |
| RatecodeID DOUBLE, | |
| store_and_fwd_flag VARCHAR, | |
| PULocationID BIGINT, | |
| DOLocationID BIGINT, | |
| payment_type BIGINT, | |
| fare_amount DOUBLE, | |
| extra DOUBLE, | |
| mta_tax DOUBLE, | |
| tip_amount DOUBLE, | |
| tolls_amount DOUBLE, | |
| improvement_surcharge DOUBLE, | |
| total_amount DOUBLE, | |
| congestion_surcharge DOUBLE, | |
| airport_fee DOUBLE, | |
| drivers VARCHAR[], | |
| speeding_tickets STRUCT(date TIMESTAMP, speed VARCHAR)[], | |
| other_violations JSON | |
| );""" | |
# Editable schema, pre-filled with the default example schema.
schema = expander.text_area("Current schema:", value=default_schema, height=500)

# Input field for text prompt
text_prompt = st.text_input(
    "What DuckDB SQL query can I write for you?", value="Read a CSV file from test.csv"
)

# Skip blank or whitespace-only input so no request is sent for it.
if text_prompt and text_prompt.strip():
    try:
        sql_query = generate_sql(text_prompt, schema)
    except (requests.RequestException, KeyError, IndexError, ValueError) as err:
        # Network failures and malformed service responses would otherwise
        # crash the page with a raw traceback; log them and tell the user.
        print(f"SQL generation failed: {err!r}")
        st.error(
            "The SQL generation service could not be reached or returned an "
            "unexpected response. Please try again in a moment."
        )
    else:
        valid, msg = validate_sql(sql_query, schema)
        if valid:
            st.markdown(f"""```sql\n{sql_query}\n```""")
        else:
            # Show the rejected query together with the validator's message.
            st.markdown(ERROR_MESSAGE.format(sql_query=sql_query, error_msg=msg))