Update app.py
app.py
CHANGED
@@ -144,17 +144,31 @@ def run_query(input_text, country, model_sel):
     docs = get_docs(input_text, country=country,vulnerability_cat=vulnerabilities_cat)
     # st.write('Selected country: ', country) # Debugging country
     if model_sel == "chatGPT":
-
-
-
+
+        response = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": get_prompt(docs, input_text)}], stream=True)
+        # iterate through the stream of events
+
+        report = []
+        for chunk in response:
+            chunk_message = chunk['choices'][0]['delta']
+            # collected_chunks.append(chunk) # save the event response
+            if 'content' in chunk_message:
+                report.append(chunk_message.content) # extract the message
+                result = "".join(report).strip()
+                result = result.replace("\n", "")
+                # res_box.markdown(f'{result}')
+                res_box.success(result)
+
+        output = result
         references = get_refs(docs, output)
     # else:
     #     res = client.text_generation(get_prompt_llama2(docs, query=input_query), max_new_tokens=4000, temperature=0.01, model=model)
     #     output = res
     #     references = get_refs(docs, res)
-    st.write('Response')
-    st.success(output)
-    st.
+    # st.write('Response')
+    # st.success(output)
+    st.markdown("----")
+    st.markdown('**REFERENCES:**')
     st.markdown('References are based on text automatically extracted from climate policy documents. These extracts may contain non-legible characters or disjointed text as an artifact of the extraction procedure')
     st.markdown(references, unsafe_allow_html=True)
 
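For context on the block added above: it follows the chunk-streaming pattern of the legacy (pre-1.0) `openai` Python client, where `ChatCompletion.create(..., stream=True)` returns an iterator of chunks and each chunk's `choices[0]['delta']` carries an incremental piece of the reply. A minimal, self-contained sketch of that pattern, assuming the pre-1.0 `openai` package; the model name and prompt are placeholders and the Streamlit placeholder from the diff is left out:

```python
# Sketch of the streaming pattern used in the hunk above,
# assuming the legacy (pre-1.0) `openai` Python package.
import openai

def stream_chat(prompt: str) -> str:
    """Collect a streamed chat completion into a single string."""
    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",                    # model name as in the diff
        messages=[{"role": "user", "content": prompt}],
        stream=True,                              # yield incremental chunks
    )
    parts = []
    for chunk in response:
        delta = chunk["choices"][0]["delta"]      # incremental payload
        if "content" in delta:                    # role-only chunks carry no text
            parts.append(delta["content"])
            # a UI could re-render "".join(parts) here on every chunk
    return "".join(parts).strip()
```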
@@ -225,7 +239,9 @@ else:
     text = st.text_area('Enter your question in the text box below using natural language or select an example from above:', value=selected_example)
 
     if st.button('Submit'):
+        st.markdown("----")
+        st.markdown('**RESPONSE:**')
+        res_box = st.empty()
         run_query(text, country=country, model_sel=model_sel)
 
 
-
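The second hunk creates the `res_box = st.empty()` placeholder before calling `run_query`, which is what lets `res_box.success(result)` in the first hunk redraw the same UI slot as each chunk arrives. A rough sketch of that placeholder pattern in a plain Streamlit script; `fake_stream` is a made-up stand-in for the streamed completion:

```python
# Sketch of the st.empty() placeholder pattern: one reserved slot is
# overwritten on every update, so streamed text appears to "type" in place.
import time
import streamlit as st

def fake_stream():
    # stand-in for the real streamed OpenAI response
    for word in "This is a streamed answer".split():
        time.sleep(0.1)
        yield word + " "

st.markdown("----")
st.markdown('**RESPONSE:**')
res_box = st.empty()                 # placeholder created before the query runs

report = []
for token in fake_stream():
    report.append(token)
    res_box.success("".join(report).strip())   # overwrite the same slot each time
```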