AIEcosystem committed on
Commit
1267c7f
·
verified ·
1 Parent(s): 9675e78

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +130 -120
src/streamlit_app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  os.environ['HF_HOME'] = '/tmp'
3
 
 
4
  import time
5
  import streamlit as st
6
  import pandas as pd
@@ -13,6 +14,7 @@ from streamlit_extras.stylable_container import stylable_container
13
  from typing import Optional
14
  from gliner import GLiNER
15
  from comet_ml import Experiment
 
16
 
17
  st.markdown(
18
  """
@@ -187,126 +189,9 @@ if st.button("Results"):
187
  fig_treemap.update_layout(margin=dict(t=50, l=25, r=25, b=25), paper_bgcolor='#F5FFFA', plot_bgcolor='#F5FFFA')
188
  st.plotly_chart(fig_treemap)
189
 
190
- # --- Model Loading and Caching ---
191
- @st.cache_resource
192
- def load_gliner_model():
193
- """
194
- Initializes and caches the GLiNER model.
195
- This ensures the model is only loaded once, improving performance.
196
- """
197
- try:
198
- return GLiNER.from_pretrained("knowledgator/gliner-multitask-v1.0", device="cpu")
199
- except Exception as e:
200
- st.error(f"Error loading the GLiNER model: {e}")
201
- st.stop()
202
-
203
- # Load the model
204
- model = load_gliner_model()
205
- st.subheader("Question-Answering", divider="violet")
206
- # Replaced two columns with a single text input
207
- question_input = st.text_input("Ask wh-questions. **Wh-questions begin with what, when, where, who, whom, which, whose, why and how. We use them to ask for specific information.**")
208
-
209
- if 'user_labels' not in st.session_state:
210
- st.session_state.user_labels = []
211
-
212
- if st.button("Add Question"):
213
- if question_input:
214
- if question_input not in st.session_state.user_labels:
215
- st.session_state.user_labels.append(question_input)
216
- st.success(f"Added question: {question_input}")
217
- else:
218
- st.warning("This question has already been added.")
219
- else:
220
- st.warning("Please enter a question.")
221
- st.markdown("---")
222
- st.subheader("Record of Questions", divider="violet")
223
-
224
- if st.session_state.user_labels:
225
- # Use enumerate to create a unique key for each item
226
- for i, label in enumerate(st.session_state.user_labels):
227
- col_list, col_delete = st.columns([0.9, 0.1])
228
- with col_list:
229
- st.write(f"- {label}", key=f"label_{i}")
230
- with col_delete:
231
- # Create a unique key for each button using the index
232
- if st.button("Delete", key=f"delete_{i}"):
233
- # Remove the label at the specific index
234
- st.session_state.user_labels.pop(i)
235
- # Rerun to update the UI
236
- st.rerun()
237
- else:
238
- st.info("No questions defined yet. Use the input above to add one.")
239
-
240
- st.divider()
241
-
242
- # --- Main Processing Logic ---
243
- if st.button("Extract Answers"):
244
- if not text.strip():
245
- st.warning("Please enter some text to analyze.")
246
- elif not st.session_state.user_labels:
247
- st.warning("Please define at least one question.")
248
- else:
249
- if comet_initialized:
250
- experiment = Experiment(
251
- api_key=COMET_API_KEY,
252
- workspace=COMET_WORKSPACE,
253
- project_name=COMET_PROJECT_NAME
254
- )
255
- experiment.log_parameter("input_text_length", len(text))
256
- experiment.log_parameter("defined_labels", st.session_state.user_labels)
257
- start_time = time.time()
258
- with st.spinner("Analyzing text...", show_time=True):
259
- try:
260
- entities = model.predict_entities(text, st.session_state.user_labels)
261
- end_time = time.time()
262
- elapsed_time = end_time - start_time
263
- st.info(f"Processing took **{elapsed_time:.2f} seconds**.")
264
-
265
- if entities:
266
- df1 = pd.DataFrame(entities)
267
- df2 = df1[['label', 'text', 'score']]
268
- df = df2.rename(columns={'label': 'question', 'text': 'answer'})
269
-
270
- st.subheader("Extracted Answers", divider="violet")
271
- st.dataframe(df, use_container_width=True)
272
- st.divider()
273
-
274
- dfa = pd.DataFrame(
275
- data={
276
- 'Column Name': ['text', 'label', 'score', 'start', 'end', 'category'],
277
- 'Description': [
278
- 'entity extracted from your text data',
279
- 'label (tag) assigned to a given extracted entity',
280
- 'accuracy score; how accurately a tag has been assigned to a given entity',
281
- 'index of the start of the corresponding entity',
282
- 'index of the end of the corresponding entity',
283
- 'the broader category the entity belongs to',
284
- ]
285
- }
286
- )
287
- buf = io.BytesIO()
288
- with zipfile.ZipFile(buf, "w") as myzip:
289
- myzip.writestr("Summary of the results.csv", df.to_csv(index=False))
290
- myzip.writestr("Glossary of tags.csv", dfa.to_csv(index=False))
291
-
292
- with stylable_container(
293
- key="download_button",
294
- css_styles="""button { background-color: red; border: 1px solid black; padding: 5px; color: white; }""",
295
- ):
296
- st.download_button(
297
- label="Download results and glossary (zip)",
298
- data=buf.getvalue(),
299
- file_name="nlpblogs_results.zip",
300
- mime="application/zip",
301
- )
302
-
303
- if comet_initialized:
304
- experiment.log_figure(figure=fig_treemap, figure_name="entity_treemap_categories")
305
- experiment.end()
306
- else: # If df is empty
307
- st.warning("No entities were found in the provided text.")
308
- except Exception as e:
309
- st.error(f"An error occurred during entity extraction: {e}")
310
 
311
  else: # If df is empty from the first extraction
312
  st.warning("No entities were found in the provided text.")
@@ -317,3 +202,128 @@ if st.button("Results"):
317
  st.text("")
318
  st.info(f"Results processed in **{elapsed_time:.2f} seconds**.")
319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  os.environ['HF_HOME'] = '/tmp'
3
 
4
+
5
  import time
6
  import streamlit as st
7
  import pandas as pd
 
14
  from typing import Optional
15
  from gliner import GLiNER
16
  from comet_ml import Experiment
17
+ import hashlib
18
 
19
  st.markdown(
20
  """
 
189
  fig_treemap.update_layout(margin=dict(t=50, l=25, r=25, b=25), paper_bgcolor='#F5FFFA', plot_bgcolor='#F5FFFA')
190
  st.plotly_chart(fig_treemap)
191
 
192
+ if comet_initialized:
193
+ experiment.log_figure(figure=fig_treemap, figure_name="entity_treemap_categories")
194
+ experiment.end()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
  else: # If df is empty from the first extraction
197
  st.warning("No entities were found in the provided text.")
 
202
  st.text("")
203
  st.info(f"Results processed in **{elapsed_time:.2f} seconds**.")
204
 
205
+ # --- Question Answering Section (Moved outside the "Results" button) ---
206
+ # --- Model Loading and Caching ---
207
+ @st.cache_resource
208
+ def load_gliner_model():
209
+ """
210
+ Initializes and caches the GLiNER model.
211
+ This ensures the model is only loaded once, improving performance.
212
+ """
213
+ try:
214
+ return GLiNER.from_pretrained("knowledgator/gliner-multitask-v1.0", device="cpu")
215
+ except Exception as e:
216
+ st.error(f"Error loading the GLiNER model: {e}")
217
+ st.stop()
218
+
219
+ # Load the model
220
+ model = load_gliner_model()
221
+ st.subheader("Question-Answering", divider="violet")
222
+ # Replaced two columns with a single text input
223
+ question_input = st.text_input("Ask wh-questions. **Wh-questions begin with what, when, where, who, whom, which, whose, why and how. We use them to ask for specific information.**")
224
+
225
+ if 'user_labels' not in st.session_state:
226
+ st.session_state.user_labels = []
227
+
228
+ if st.button("Add Question"):
229
+ if question_input:
230
+ if question_input not in st.session_state.user_labels:
231
+ st.session_state.user_labels.append(question_input)
232
+ st.success(f"Added question: {question_input}")
233
+ else:
234
+ st.warning("This question has already been added.")
235
+ else:
236
+ st.warning("Please enter a question.")
237
+
238
+ st.markdown("---")
239
+ st.subheader("Record of Questions", divider="violet")
240
+ if st.session_state.user_labels:
241
+ # Use enumerate to create a unique key for each item
242
+ for i, label in enumerate(st.session_state.user_labels):
243
+ col_list, col_delete = st.columns([0.9, 0.1])
244
+ with col_list:
245
+ st.write(f"- {label}", key=f"label_{i}")
246
+ with col_delete:
247
+ # Create a unique key for each button using the index
248
+ if st.button("Delete", key=f"delete_{i}"):
249
+ # Remove the label at the specific index
250
+ st.session_state.user_labels.pop(i)
251
+ # Rerun to update the UI
252
+ st.rerun()
253
+ else:
254
+ st.info("No questions defined yet. Use the input above to add one.")
255
+
256
+ st.divider()
257
+ # --- Main Processing Logic ---
258
+ if st.button("Extract Answers"):
259
+ if not text.strip():
260
+ st.warning("Please enter some text to analyze.")
261
+ elif not st.session_state.user_labels:
262
+ st.warning("Please define at least one question.")
263
+ else:
264
+ if comet_initialized:
265
+ experiment = Experiment(
266
+ api_key=COMET_API_KEY,
267
+ workspace=COMET_WORKSPACE,
268
+ project_name=COMET_PROJECT_NAME
269
+ )
270
+ experiment.log_parameter("input_text_length", len(text))
271
+ experiment.log_parameter("defined_labels", st.session_state.user_labels)
272
+ start_time = time.time()
273
+ with st.spinner("Analyzing text...", show_time=True):
274
+ try:
275
+ entities = model.predict_entities(text, st.session_state.user_labels)
276
+ end_time = time.time()
277
+ elapsed_time = end_time - start_time
278
+ st.info(f"Processing took **{elapsed_time:.2f} seconds**.")
279
+
280
+ if entities:
281
+ df1 = pd.DataFrame(entities)
282
+ df2 = df1[['label', 'text', 'score']]
283
+ df = df2.rename(columns={'label': 'question', 'text': 'answer'})
284
+
285
+ st.subheader("Extracted Answers", divider="violet")
286
+ st.dataframe(df, use_container_width=True)
287
+ st.divider()
288
+
289
+ dfa = pd.DataFrame(
290
+ data={
291
+ 'Column Name': ['text', 'label', 'score', 'start', 'end', 'category'],
292
+ 'Description': [
293
+ 'entity extracted from your text data',
294
+ 'label (tag) assigned to a given extracted entity',
295
+ 'accuracy score; how accurately a tag has been assigned to a given entity',
296
+ 'index of the start of the corresponding entity',
297
+ 'index of the end of the corresponding entity',
298
+ 'the broader category the entity belongs to',
299
+ ]
300
+ }
301
+ )
302
+ buf = io.BytesIO()
303
+ with zipfile.ZipFile(buf, "w") as myzip:
304
+ myzip.writestr("Summary of the results.csv", df.to_csv(index=False))
305
+ myzip.writestr("Glossary of tags.csv", dfa.to_csv(index=False))
306
+
307
+ with stylable_container(
308
+ key="download_button",
309
+ css_styles="""button { background-color: red; border: 1px solid black; padding: 5px; color: white; }""",
310
+ ):
311
+ st.download_button(
312
+ label="Download results and glossary (zip)",
313
+ data=buf.getvalue(),
314
+ file_name="nlpblogs_results.zip",
315
+ mime="application/zip",
316
+ )
317
+
318
+ if comet_initialized:
319
+ # Assuming fig_treemap is still defined from the main NER run
320
+ # If not, you might need to re-generate it or handle the case where it's not available.
321
+ try:
322
+ experiment.log_figure(figure=fig_treemap, figure_name="entity_treemap_categories")
323
+ except NameError:
324
+ pass # Or handle this gracefully
325
+ experiment.end()
326
+ else: # If df is empty
327
+ st.warning("No answers were found for the provided questions.")
328
+ except Exception as e:
329
+ st.error(f"An error occurred during answer extraction: {e}")