Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Fri May 26 14:07:22 2023 | |
| @author: vibin | |
| """ | |
| import streamlit as st | |
| from pandasql import sqldf | |
| import pandas as pd | |
| import re | |
| from typing import List | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
| import re | |
# Streamlit demo with two pages, selected from the sidebar:
#   * TAPAS    - answer a natural-language question directly over data.csv.
#   * Text2SQL - turn a question into SQL with a seq2seq model and run it
#                against data.csv through pandasql.

DATA_PATH = "data.csv"
TAPAS_MODEL = "google/tapas-base-finetuned-wtq"
TEXT2SQL_MODEL = "juierror/flan-t5-text2sql-with-schema"

# Name under which the dataset is exposed to the generated SQL.
TABLE_NAME = "df_qna"

TAPAS_QUESTIONS = [
    "Which country has low medicare",
    "Who are the patients from india",
    "Patients who have Edema",
    "CUI code for diabetes patients",
    "Patients having oxygen less than 94 but 91",
]

# NOTE(review): these examples mention columns (measure indicator code,
# Class_Code, version, source) that do not look like the patient data the
# TAPAS examples ask about; confirm they match data.csv.
TEXT2SQL_QUESTIONS = [
    "what interface is measure indicator code = 72_HR_ABX and version is 1 and source is TD",
    "get class code with measure = 72_HR_ABX",
    "get sum of version for Class_Code is Antibiotic Stewardship",
    "what interface is measure indicator code = 72_HR_ABX",
]

# The "table" identifier that directly follows FROM in the generated SQL.
_FROM_TABLE_PLACEHOLDER = re.compile(r"(\bFROM\s+)table\b", re.IGNORECASE)

# Either an already-quoted literal (kept as-is), or the bare value on the
# right-hand side of "=" (this also covers ">=", "<=", "!="), running up to
# the next AND/OR keyword or the end of the query.
_LITERAL_OR_COMPARISON_VALUE = re.compile(
    r"""'[^']*'|"[^"]*"|(=\s*)(?!['"])(\S.*?)(?=\s+(?:AND|OR)\b|\s*$)"""
)


def qualify_table_name(sql: str, table_name: str) -> str:
    """Replace the ``FROM table`` placeholder in *sql* with *table_name*.

    Only the identifier directly after ``FROM`` is rewritten, so values that
    merely contain the letters "table" (e.g. "Stable angina") are untouched.
    """
    return _FROM_TABLE_PLACEHOLDER.sub(lambda m: m.group(1) + table_name, sql)


def quote_comparison_values(sql: str) -> str:
    """Wrap the bare right-hand values of ``=`` comparisons in single quotes.

    The generated SQL leaves values unquoted (e.g. ``Measure = 72_HR_ABX``),
    which SQLite would not read as string literals. Each value is quoted in
    place, at its own position only; embedded single quotes are doubled.
    Literals that are already quoted are kept as they are. A value ends at the
    next ``AND``/``OR`` keyword or at the end of the query.

    >>> quote_comparison_values("SELECT a FROM t WHERE b = x y AND c = 1")
    "SELECT a FROM t WHERE b = 'x y' AND c = '1'"
    """

    def _quote(match: "re.Match") -> str:
        if match.group(1) is None:  # an existing quoted literal
            return match.group(0)
        value = match.group(2).replace("'", "''")
        return f"{match.group(1)}'{value}'"

    return _LITERAL_OR_COMPARISON_VALUE.sub(_quote, sql)


def prepare_input(question: str, table: List[str], tokenizer):
    """Tokenize the prompt ``"question: <question> table: <col1>,<col2>,..."``.

    Args:
        question: natural-language question.
        table: column names describing the schema the SQL should target.
        tokenizer: the text2sql tokenizer.

    Returns:
        The PyTorch ``input_ids`` tensor, truncated to at most 512 tokens.
    """
    prompt = f"question: {question} table: {','.join(table)}"
    encoded = tokenizer(prompt, max_length=512, truncation=True, return_tensors="pt")
    return encoded.input_ids


def inference(question: str, table: List[str], tokenizer, model) -> str:
    """Generate a SQL string for *question* over a table with columns *table*.

    Uses beam search with 10 beams and at most 700 generated tokens.
    """
    input_ids = prepare_input(question, table, tokenizer).to(model.device)
    # NOTE(review): top_k only matters when sampling is enabled; kept as-is.
    outputs = model.generate(inputs=input_ids, num_beams=10, top_k=10, max_length=700)
    return tokenizer.decode(token_ids=outputs[0], skip_special_tokens=True)


def _load_tapas_pipeline():
    """Build the TAPAS table-question-answering pipeline."""
    return pipeline(task="table-question-answering", model=TAPAS_MODEL)


def _load_text2sql_model():
    """Load the text2sql tokenizer and seq2seq model as a ``(tokenizer, model)`` pair."""
    tokenizer = AutoTokenizer.from_pretrained(TEXT2SQL_MODEL)
    model = AutoModelForSeq2SeqLM.from_pretrained(TEXT2SQL_MODEL)
    return tokenizer, model


def _page_header(title: str, subtitle: str, subtitle_ratio: List[int]) -> None:
    """Render *title* in the middle of three columns and *subtitle* below it."""
    _, middle, _ = st.columns(3)
    middle.title(title)
    _, right = st.columns(subtitle_ratio)
    right.text(subtitle)


def _render_tapas_page() -> None:
    """TAPAS page: show the dataset and answer a question over it."""
    _page_header("TAPAS", "Tabular Data Text Extraction using text", [3, 12])
    # The TAPAS pipeline is given a table whose cells are all strings.
    table = pd.read_csv(DATA_PATH).astype(str)
    st.text("DataSet - ")
    st.dataframe(table, width=3000, height=400)
    st.title("")  # vertical spacer
    example = st.selectbox("Choose your text", TAPAS_QUESTIONS, index=0)
    st.title("")
    question = st.text_area("TAPAS Input", example)
    if not st.button("Predict"):
        return
    if not question.strip():
        st.warning("Please enter a question.")
        return
    # Cache the pipeline so it is built once, not on every click.
    tqa = st.cache_resource(_load_tapas_pipeline)()
    answer = tqa(table=table, query=question)["answer"]
    st.text("Output - ")
    st.success(f"{answer}")


def _render_text2sql_page() -> None:
    """Text2SQL page: generate SQL for a question and run it on the dataset."""
    _page_header(
        "Text2SQL",
        "Text will be converted to SQL Query and can extract the data from DataSet",
        [1, 20],
    )
    df_qna = pd.read_csv(DATA_PATH, encoding="unicode_escape")
    st.title("")
    st.text("DataSet - ")
    st.dataframe(df_qna, width=3000, height=500)
    st.title("")
    example = st.selectbox("Choose your text", TEXT2SQL_QUESTIONS, index=0)
    st.title("")
    question = st.text_area("Text for SQL Conversion", example)
    if not st.button("Predict"):
        return
    if not question.strip():
        st.warning("Please enter a question.")
        return
    # Cache the tokenizer/model so they are loaded once, not on every click.
    tokenizer, model = st.cache_resource(_load_text2sql_model)()
    # Describe the real dataset columns to the model so the SQL targets them.
    columns = [str(column) for column in df_qna.columns]
    sql = inference(question=question, table=columns, tokenizer=tokenizer, model=model)
    sql = quote_comparison_values(qualify_table_name(sql, TABLE_NAME))
    st.success(f"{sql}")
    try:
        # Pass the table explicitly rather than relying on frame inspection.
        result = sqldf(sql, {TABLE_NAME: df_qna})
    except Exception as err:  # UI boundary: report model/SQL failures to the user
        st.error(f"Could not run the generated SQL: {err}")
        return
    st.text("Output - ")
    st.write(result)


def main() -> None:
    """Dispatch to the page selected in the sidebar."""
    nav = st.sidebar.radio("Navigation", ["TAPAS", "Text2SQL"])
    if nav == "TAPAS":
        _render_tapas_page()
    elif nav == "Text2SQL":
        _render_text2sql_page()


# Streamlit executes the app script as ``__main__`` on every rerun.
if __name__ == "__main__":
    main()