Spaces:
Build error
Build error
add new page and refine sql query ui
Browse files- poetry.lock +16 -1
- pyproject.toml +1 -0
- src/pages/1_Chart_Generation.py +3 -2
- src/pages/2_Query_And_Answer.py +159 -0
- src/utils.py +9 -0
poetry.lock
CHANGED
|
@@ -974,6 +974,21 @@ files = [
|
|
| 974 |
{file = "smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5"},
|
| 975 |
]
|
| 976 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 977 |
[[package]]
|
| 978 |
name = "streamlit"
|
| 979 |
version = "1.46.1"
|
|
@@ -1139,4 +1154,4 @@ watchmedo = ["PyYAML (>=3.10)"]
|
|
| 1139 |
[metadata]
|
| 1140 |
lock-version = "2.0"
|
| 1141 |
python-versions = ">=3.12,<3.13"
|
| 1142 |
-
content-hash = "
|
|
|
|
| 974 |
{file = "smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5"},
|
| 975 |
]
|
| 976 |
|
| 977 |
+
[[package]]
|
| 978 |
+
name = "sqlparse"
|
| 979 |
+
version = "0.5.3"
|
| 980 |
+
description = "A non-validating SQL parser."
|
| 981 |
+
optional = false
|
| 982 |
+
python-versions = ">=3.8"
|
| 983 |
+
files = [
|
| 984 |
+
{file = "sqlparse-0.5.3-py3-none-any.whl", hash = "sha256:cf2196ed3418f3ba5de6af7e82c694a9fbdbfecccdfc72e281548517081f16ca"},
|
| 985 |
+
{file = "sqlparse-0.5.3.tar.gz", hash = "sha256:09f67787f56a0b16ecdbde1bfc7f5d9c3371ca683cfeaa8e6ff60b4807ec9272"},
|
| 986 |
+
]
|
| 987 |
+
|
| 988 |
+
[package.extras]
|
| 989 |
+
dev = ["build", "hatch"]
|
| 990 |
+
doc = ["sphinx"]
|
| 991 |
+
|
| 992 |
[[package]]
|
| 993 |
name = "streamlit"
|
| 994 |
version = "1.46.1"
|
|
|
|
| 1154 |
[metadata]
|
| 1155 |
lock-version = "2.0"
|
| 1156 |
python-versions = ">=3.12,<3.13"
|
| 1157 |
+
content-hash = "67b448796799eb25e83725fc27c23ff400273d48c767151c6462d8e1545052fe"
|
pyproject.toml
CHANGED
|
@@ -12,6 +12,7 @@ python = ">=3.12,<3.13"
|
|
| 12 |
streamlit = "^1.46.1"
|
| 13 |
requests = "^2.32.4"
|
| 14 |
watchdog = "^6.0.0"
|
|
|
|
| 15 |
|
| 16 |
[tool.poetry.group.dev.dependencies]
|
| 17 |
python-dotenv = "^1.1.1"
|
|
|
|
| 12 |
streamlit = "^1.46.1"
|
| 13 |
requests = "^2.32.4"
|
| 14 |
watchdog = "^6.0.0"
|
| 15 |
+
sqlparse = "^0.5.3"
|
| 16 |
|
| 17 |
[tool.poetry.group.dev.dependencies]
|
| 18 |
python-dotenv = "^1.1.1"
|
src/pages/1_Chart_Generation.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
|
| 3 |
from apis import generate_sql, generate_chart
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
def main():
|
|
@@ -46,7 +47,7 @@ def main():
|
|
| 46 |
st.write(message["content"])
|
| 47 |
if "sql" in message:
|
| 48 |
with st.expander("π Generated SQL Query", expanded=False):
|
| 49 |
-
st.code(message["sql"], language="sql")
|
| 50 |
if "vega_spec" in message:
|
| 51 |
try:
|
| 52 |
with st.expander("π Chart Specification", expanded=False):
|
|
@@ -87,7 +88,7 @@ def main():
|
|
| 87 |
|
| 88 |
# Display SQL query
|
| 89 |
with st.expander("π Generated SQL Query", expanded=False):
|
| 90 |
-
st.code(sql_query, language="sql")
|
| 91 |
|
| 92 |
# Generate chart
|
| 93 |
with st.spinner("Generating chart..."):
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
|
| 3 |
from apis import generate_sql, generate_chart
|
| 4 |
+
from utils import format_sql
|
| 5 |
|
| 6 |
|
| 7 |
def main():
|
|
|
|
| 47 |
st.write(message["content"])
|
| 48 |
if "sql" in message:
|
| 49 |
with st.expander("π Generated SQL Query", expanded=False):
|
| 50 |
+
st.code(format_sql(message["sql"]), language="sql")
|
| 51 |
if "vega_spec" in message:
|
| 52 |
try:
|
| 53 |
with st.expander("π Chart Specification", expanded=False):
|
|
|
|
| 88 |
|
| 89 |
# Display SQL query
|
| 90 |
with st.expander("π Generated SQL Query", expanded=False):
|
| 91 |
+
st.code(format_sql(sql_query), language="sql")
|
| 92 |
|
| 93 |
# Generate chart
|
| 94 |
with st.spinner("Generating chart..."):
|
src/pages/2_Query_And_Answer.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
from apis import ask, run_sql
|
| 5 |
+
from utils import format_sql
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main():
|
| 9 |
+
st.title("π¬ Wren AI Cloud API Demo - Query and Answer")
|
| 10 |
+
|
| 11 |
+
if "api_key" not in st.session_state or "project_id" not in st.session_state:
|
| 12 |
+
st.error("Please enter your API Key and Project ID in the sidebar of Home page to get started.")
|
| 13 |
+
return
|
| 14 |
+
if not st.session_state.api_key or not st.session_state.project_id:
|
| 15 |
+
st.error("Please enter your API Key and Project ID in the sidebar of Home page to get started.")
|
| 16 |
+
return
|
| 17 |
+
|
| 18 |
+
api_key = st.session_state.api_key
|
| 19 |
+
project_id = st.session_state.project_id
|
| 20 |
+
|
| 21 |
+
st.markdown('Using APIs: [Ask](https://wrenai.readme.io/reference/post_ask-1), [Run SQL](https://wrenai.readme.io/reference/cloud_post_run-sql)')
|
| 22 |
+
|
| 23 |
+
# Sidebar for API configuration
|
| 24 |
+
with st.sidebar:
|
| 25 |
+
st.header("π§ Configuration")
|
| 26 |
+
# Sample size configuration
|
| 27 |
+
sample_size = st.slider(
|
| 28 |
+
"Sample Size",
|
| 29 |
+
min_value=100,
|
| 30 |
+
max_value=10000,
|
| 31 |
+
value=1000,
|
| 32 |
+
step=100,
|
| 33 |
+
help="Number of data points to include in results"
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
# Initialize chat history
|
| 37 |
+
if "qa_messages" not in st.session_state:
|
| 38 |
+
st.session_state.qa_messages = []
|
| 39 |
+
if "qa_thread_id" not in st.session_state:
|
| 40 |
+
st.session_state.qa_thread_id = ""
|
| 41 |
+
|
| 42 |
+
# Display chat history
|
| 43 |
+
for message in st.session_state.qa_messages:
|
| 44 |
+
with st.chat_message(message["role"]):
|
| 45 |
+
if message["role"] == "user":
|
| 46 |
+
st.write(message["content"])
|
| 47 |
+
else:
|
| 48 |
+
st.write(message["content"])
|
| 49 |
+
if "sql" in message:
|
| 50 |
+
with st.expander("π Generated SQL Query", expanded=False):
|
| 51 |
+
st.code(format_sql(message["sql"]), language="sql")
|
| 52 |
+
|
| 53 |
+
# Add button to run SQL
|
| 54 |
+
if st.button("π Run SQL Query", key=f"run_sql_{message.get('message_id', 'unknown')}"):
|
| 55 |
+
with st.spinner("Executing SQL query..."):
|
| 56 |
+
sql_result, error = run_sql(api_key, project_id, message["sql"], st.session_state.qa_thread_id, sample_size)
|
| 57 |
+
|
| 58 |
+
if sql_result:
|
| 59 |
+
data = sql_result.get("records", [])
|
| 60 |
+
if data:
|
| 61 |
+
# Convert to DataFrame for better display
|
| 62 |
+
df = pd.DataFrame(data)
|
| 63 |
+
st.success("SQL query executed successfully!")
|
| 64 |
+
st.dataframe(df, use_container_width=True)
|
| 65 |
+
else:
|
| 66 |
+
st.info("Query executed but returned no data.")
|
| 67 |
+
else:
|
| 68 |
+
st.error(f"Error executing SQL: {error}")
|
| 69 |
+
|
| 70 |
+
if "sql_results" in message:
|
| 71 |
+
st.subheader("π Query Results")
|
| 72 |
+
if message["sql_results"]:
|
| 73 |
+
st.dataframe(message["sql_results"], use_container_width=True)
|
| 74 |
+
else:
|
| 75 |
+
st.info("No results returned from the query.")
|
| 76 |
+
|
| 77 |
+
# Chat input
|
| 78 |
+
if prompt := st.chat_input("Ask a question about your data..."):
|
| 79 |
+
# Add user message to chat history
|
| 80 |
+
st.session_state.qa_messages.append({"role": "user", "content": prompt})
|
| 81 |
+
|
| 82 |
+
# Display user message
|
| 83 |
+
with st.chat_message("user"):
|
| 84 |
+
st.write(prompt)
|
| 85 |
+
|
| 86 |
+
# Generate response using ask API
|
| 87 |
+
with st.chat_message("assistant"):
|
| 88 |
+
with st.spinner("Generating answer..."):
|
| 89 |
+
ask_response, error = ask(api_key, project_id, prompt, st.session_state.qa_thread_id, sample_size=sample_size)
|
| 90 |
+
|
| 91 |
+
if ask_response:
|
| 92 |
+
answer = ask_response.get("summary", "")
|
| 93 |
+
sql_query = ask_response.get("sql", "")
|
| 94 |
+
st.session_state.qa_thread_id = ask_response.get("threadId", "")
|
| 95 |
+
|
| 96 |
+
if answer:
|
| 97 |
+
st.toast("Answer generated successfully!", icon="π")
|
| 98 |
+
|
| 99 |
+
# Create unique message ID
|
| 100 |
+
message_id = len(st.session_state.qa_messages)
|
| 101 |
+
|
| 102 |
+
# Store the response
|
| 103 |
+
assistant_message = {
|
| 104 |
+
"role": "assistant",
|
| 105 |
+
"content": answer,
|
| 106 |
+
"message_id": message_id
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
if sql_query:
|
| 110 |
+
assistant_message["sql"] = sql_query
|
| 111 |
+
|
| 112 |
+
st.session_state.qa_messages.append(assistant_message)
|
| 113 |
+
st.write(answer)
|
| 114 |
+
|
| 115 |
+
# Display SQL query if available
|
| 116 |
+
if sql_query:
|
| 117 |
+
with st.expander("π Generated SQL Query", expanded=False):
|
| 118 |
+
st.code(format_sql(sql_query), language="sql")
|
| 119 |
+
|
| 120 |
+
# Add button to run SQL
|
| 121 |
+
if st.button("π Run SQL Query", key=f"run_sql_{message_id}"):
|
| 122 |
+
with st.spinner("Executing SQL query..."):
|
| 123 |
+
sql_result, error = run_sql(api_key, project_id, sql_query, st.session_state.qa_thread_id, sample_size)
|
| 124 |
+
|
| 125 |
+
if sql_result:
|
| 126 |
+
data = sql_result.get("records", [])
|
| 127 |
+
if data:
|
| 128 |
+
# Convert to DataFrame for better display
|
| 129 |
+
df = pd.DataFrame(data)
|
| 130 |
+
st.success("SQL query executed successfully!")
|
| 131 |
+
st.dataframe(df, use_container_width=True)
|
| 132 |
+
else:
|
| 133 |
+
st.info("Query executed but returned no data.")
|
| 134 |
+
else:
|
| 135 |
+
st.error(f"Error executing SQL: {error}")
|
| 136 |
+
else:
|
| 137 |
+
st.toast("No answer was generated. Please try rephrasing your question.", icon="π€")
|
| 138 |
+
assistant_message = {
|
| 139 |
+
"role": "assistant",
|
| 140 |
+
"content": "I couldn't generate an answer for your question. Please try rephrasing it or make sure it's related to your data."
|
| 141 |
+
}
|
| 142 |
+
st.session_state.qa_messages.append(assistant_message)
|
| 143 |
+
else:
|
| 144 |
+
st.toast(f"Error generating answer: {error}", icon="π€")
|
| 145 |
+
assistant_message = {
|
| 146 |
+
"role": "assistant",
|
| 147 |
+
"content": "Sorry, I couldn't process your request. Please check your API credentials and try again."
|
| 148 |
+
}
|
| 149 |
+
st.session_state.qa_messages.append(assistant_message)
|
| 150 |
+
|
| 151 |
+
# Clear chat button
|
| 152 |
+
if st.sidebar.button("π§Ή Clear Chat History"):
|
| 153 |
+
st.session_state.qa_messages = []
|
| 154 |
+
st.session_state.qa_thread_id = ""
|
| 155 |
+
st.rerun()
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
if __name__ == "__main__":
|
| 159 |
+
main()
|
src/utils.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlparse
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def format_sql(sql: str) -> str:
|
| 5 |
+
return sqlparse.format(
|
| 6 |
+
sql,
|
| 7 |
+
reindent=True,
|
| 8 |
+
keyword_case="upper",
|
| 9 |
+
)
|