AIEcosystem committed on
Commit
0b88ebc
·
verified ·
1 Parent(s): 1267c7f

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +35 -52
src/streamlit_app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  os.environ['HF_HOME'] = '/tmp'
3
 
4
-
5
  import time
6
  import streamlit as st
7
  import pandas as pd
@@ -16,6 +16,9 @@ from gliner import GLiNER
16
  from comet_ml import Experiment
17
  import hashlib
18
 
 
 
 
19
  st.markdown(
20
  """
21
  <style>
@@ -87,7 +90,6 @@ with st.sidebar:
87
  st.link_button("AI Web App Builder", " https://nlpblogs.com/custom-web-app-development/", type="primary")
88
 
89
  # --- Comet ML Setup ---
90
- os.environ['HF_HOME'] = '/tmp'
91
  COMET_API_KEY = os.environ.get("COMET_API_KEY")
92
  COMET_WORKSPACE = os.environ.get("COMET_WORKSPACE")
93
  COMET_PROJECT_NAME = os.environ.get("COMET_PROJECT_NAME")
@@ -104,7 +106,7 @@ category_mapping = {
104
  "Contact Information": ["Email", "Phone_number", "Street_address", "City", "Country"],
105
  "Personal Details": ["Date_of_birth", "Marital_status", "Person"],
106
  "Employment Status": ["Full_time", "Part_time", "Contract", "Terminated", "Retired"],
107
- "Employment Information" : ["Job_title", "Date", "Organization", "Role"],
108
  "Performance": ["Performance_score"],
109
  "Attendance": ["Leave_of_absence"],
110
  "Benefits": ["Retirement_plan", "Bonus", "Stock_options", "Health_insurance"],
@@ -112,7 +114,7 @@ category_mapping = {
112
  "Deductions": ["Tax", "Deductions"],
113
  "Recruitment & Sourcing": ["Interview_type", "Applicant", "Referral", "Job_board", "Recruiter"],
114
  "Legal & Compliance": ["Offer_letter", "Agreement"],
115
- "Professional_Development": [ "Certification", "Skill"]
116
  }
117
 
118
  # --- Model Loading ---
@@ -120,7 +122,7 @@ category_mapping = {
120
  def load_ner_model():
121
  """Loads the GLiNER model and caches it."""
122
  try:
123
- return GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5", nested_ner=True, num_gen_sequences=2, gen_constraints= labels)
124
  except Exception as e:
125
  st.error(f"Failed to load NER model. Please check your internet connection or model availability: {e}")
126
  st.stop()
@@ -134,8 +136,13 @@ reverse_category_mapping = {label: category for category, label_list in category
134
  text = st.text_area("Type or paste your text below, and then press Ctrl + Enter", height=250, key='my_text_area')
135
 
136
  def clear_text():
137
- """Clears the text area."""
138
  st.session_state['my_text_area'] = ""
 
 
 
 
 
139
 
140
  st.button("Clear text", on_click=clear_text)
141
 
@@ -150,6 +157,7 @@ if st.button("Results"):
150
  df = pd.DataFrame(entities)
151
  if not df.empty:
152
  df['category'] = df['label'].map(reverse_category_mapping)
 
153
  if comet_initialized:
154
  experiment = Experiment(
155
  api_key=COMET_API_KEY,
@@ -158,12 +166,10 @@ if st.button("Results"):
158
  )
159
  experiment.log_parameter("input_text", text)
160
  experiment.log_table("predicted_entities", df)
161
-
162
  st.subheader("Grouped Entities by Category", divider="green")
163
- # Create tabs for each category
164
  category_names = sorted(list(category_mapping.keys()))
165
  category_tabs = st.tabs(category_names)
166
-
167
  for i, category_name in enumerate(category_names):
168
  with category_tabs[i]:
169
  df_category_filtered = df[df['category'] == category_name]
@@ -181,50 +187,39 @@ if st.button("Results"):
181
  - **start**: ['index of the start of the corresponding entity']
182
  - **end**: ['index of the end of the corresponding entity']
183
  ''')
184
- st.divider()
185
-
186
- # Tree map
187
- st.subheader("Tree map", divider="green")
188
- fig_treemap = px.treemap(df, path=[px.Constant("all"), 'category', 'label', 'text'], values='score', color='category')
189
- fig_treemap.update_layout(margin=dict(t=50, l=25, r=25, b=25), paper_bgcolor='#F5FFFA', plot_bgcolor='#F5FFFA')
190
- st.plotly_chart(fig_treemap)
191
-
192
- if comet_initialized:
193
- experiment.log_figure(figure=fig_treemap, figure_name="entity_treemap_categories")
194
- experiment.end()
195
 
196
- else: # If df is empty from the first extraction
197
  st.warning("No entities were found in the provided text.")
 
 
 
198
 
199
- end_time = time.time()
200
- elapsed_time = end_time - start_time
201
- st.text("")
202
- st.text("")
203
- st.info(f"Results processed in **{elapsed_time:.2f} seconds**.")
 
 
204
 
205
- # --- Question Answering Section (Moved outside the "Results" button) ---
206
- # --- Model Loading and Caching ---
207
  @st.cache_resource
208
  def load_gliner_model():
209
- """
210
- Initializes and caches the GLiNER model.
211
- This ensures the model is only loaded once, improving performance.
212
- """
213
  try:
214
  return GLiNER.from_pretrained("knowledgator/gliner-multitask-v1.0", device="cpu")
215
  except Exception as e:
216
  st.error(f"Error loading the GLiNER model: {e}")
217
  st.stop()
218
 
219
- # Load the model
220
- model = load_gliner_model()
221
  st.subheader("Question-Answering", divider="violet")
222
- # Replaced two columns with a single text input
223
- question_input = st.text_input("Ask wh-questions. **Wh-questions begin with what, when, where, who, whom, which, whose, why and how. We use them to ask for specific information.**")
224
 
225
  if 'user_labels' not in st.session_state:
226
  st.session_state.user_labels = []
227
 
 
 
228
  if st.button("Add Question"):
229
  if question_input:
230
  if question_input not in st.session_state.user_labels:
@@ -238,23 +233,19 @@ if st.button("Add Question"):
238
  st.markdown("---")
239
  st.subheader("Record of Questions", divider="violet")
240
  if st.session_state.user_labels:
241
- # Use enumerate to create a unique key for each item
242
  for i, label in enumerate(st.session_state.user_labels):
243
  col_list, col_delete = st.columns([0.9, 0.1])
244
  with col_list:
245
  st.write(f"- {label}", key=f"label_{i}")
246
  with col_delete:
247
- # Create a unique key for each button using the index
248
  if st.button("Delete", key=f"delete_{i}"):
249
- # Remove the label at the specific index
250
  st.session_state.user_labels.pop(i)
251
- # Rerun to update the UI
252
  st.rerun()
253
  else:
254
  st.info("No questions defined yet. Use the input above to add one.")
255
 
256
  st.divider()
257
- # --- Main Processing Logic ---
258
  if st.button("Extract Answers"):
259
  if not text.strip():
260
  st.warning("Please enter some text to analyze.")
@@ -269,10 +260,11 @@ if st.button("Extract Answers"):
269
  )
270
  experiment.log_parameter("input_text_length", len(text))
271
  experiment.log_parameter("defined_labels", st.session_state.user_labels)
 
272
  start_time = time.time()
273
  with st.spinner("Analyzing text...", show_time=True):
274
  try:
275
- entities = model.predict_entities(text, st.session_state.user_labels)
276
  end_time = time.time()
277
  elapsed_time = end_time - start_time
278
  st.info(f"Processing took **{elapsed_time:.2f} seconds**.")
@@ -314,16 +306,7 @@ if st.button("Extract Answers"):
314
  file_name="nlpblogs_results.zip",
315
  mime="application/zip",
316
  )
317
-
318
- if comet_initialized:
319
- # Assuming fig_treemap is still defined from the main NER run
320
- # If not, you might need to re-generate it or handle the case where it's not available.
321
- try:
322
- experiment.log_figure(figure=fig_treemap, figure_name="entity_treemap_categories")
323
- except NameError:
324
- pass # Or handle this gracefully
325
- experiment.end()
326
- else: # If df is empty
327
  st.warning("No answers were found for the provided questions.")
328
  except Exception as e:
329
  st.error(f"An error occurred during answer extraction: {e}")
 
1
  import os
2
  os.environ['HF_HOME'] = '/tmp'
3
 
4
+ import os
5
  import time
6
  import streamlit as st
7
  import pandas as pd
 
16
  from comet_ml import Experiment
17
  import hashlib
18
 
19
+ # Set up environment variables
20
+ os.environ['HF_HOME'] = '/tmp'
21
+
22
  st.markdown(
23
  """
24
  <style>
 
90
  st.link_button("AI Web App Builder", " https://nlpblogs.com/custom-web-app-development/", type="primary")
91
 
92
  # --- Comet ML Setup ---
 
93
  COMET_API_KEY = os.environ.get("COMET_API_KEY")
94
  COMET_WORKSPACE = os.environ.get("COMET_WORKSPACE")
95
  COMET_PROJECT_NAME = os.environ.get("COMET_PROJECT_NAME")
 
106
  "Contact Information": ["Email", "Phone_number", "Street_address", "City", "Country"],
107
  "Personal Details": ["Date_of_birth", "Marital_status", "Person"],
108
  "Employment Status": ["Full_time", "Part_time", "Contract", "Terminated", "Retired"],
109
+ "Employment Information": ["Job_title", "Date", "Organization", "Role"],
110
  "Performance": ["Performance_score"],
111
  "Attendance": ["Leave_of_absence"],
112
  "Benefits": ["Retirement_plan", "Bonus", "Stock_options", "Health_insurance"],
 
114
  "Deductions": ["Tax", "Deductions"],
115
  "Recruitment & Sourcing": ["Interview_type", "Applicant", "Referral", "Job_board", "Recruiter"],
116
  "Legal & Compliance": ["Offer_letter", "Agreement"],
117
+ "Professional_Development": ["Certification", "Skill"]
118
  }
119
 
120
  # --- Model Loading ---
 
122
  def load_ner_model():
123
  """Loads the GLiNER model and caches it."""
124
  try:
125
+ return GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5", nested_ner=True, num_gen_sequences=2, gen_constraints=labels)
126
  except Exception as e:
127
  st.error(f"Failed to load NER model. Please check your internet connection or model availability: {e}")
128
  st.stop()
 
136
  text = st.text_area("Type or paste your text below, and then press Ctrl + Enter", height=250, key='my_text_area')
137
 
138
  def clear_text():
139
+ """Clears the text area and session state."""
140
  st.session_state['my_text_area'] = ""
141
+ # Clear stored results
142
+ if 'df' in st.session_state:
143
+ del st.session_state.df
144
+ if 'fig_treemap' in st.session_state:
145
+ del st.session_state.fig_treemap
146
 
147
  st.button("Clear text", on_click=clear_text)
148
 
 
157
  df = pd.DataFrame(entities)
158
  if not df.empty:
159
  df['category'] = df['label'].map(reverse_category_mapping)
160
+ st.session_state.df = df # Store df in session state
161
  if comet_initialized:
162
  experiment = Experiment(
163
  api_key=COMET_API_KEY,
 
166
  )
167
  experiment.log_parameter("input_text", text)
168
  experiment.log_table("predicted_entities", df)
169
+
170
  st.subheader("Grouped Entities by Category", divider="green")
 
171
  category_names = sorted(list(category_mapping.keys()))
172
  category_tabs = st.tabs(category_names)
 
173
  for i, category_name in enumerate(category_names):
174
  with category_tabs[i]:
175
  df_category_filtered = df[df['category'] == category_name]
 
187
  - **start**: ['index of the start of the corresponding entity']
188
  - **end**: ['index of the end of the corresponding entity']
189
  ''')
 
 
 
 
 
 
 
 
 
 
 
190
 
191
+ else:
192
  st.warning("No entities were found in the provided text.")
193
+ # Clear session state if no results found
194
+ if 'df' in st.session_state:
195
+ del st.session_state.df
196
 
197
+ # --- Treemap Display Section ---
198
+ if 'df' in st.session_state and not st.session_state.df.empty:
199
+ st.divider()
200
+ st.subheader("Tree map", divider="green")
201
+ fig_treemap = px.treemap(st.session_state.df, path=[px.Constant("all"), 'category', 'label', 'text'], values='score', color='category')
202
+ fig_treemap.update_layout(margin=dict(t=50, l=25, r=25, b=25), paper_bgcolor='#F5FFFA', plot_bgcolor='#F5FFFA')
203
+ st.plotly_chart(fig_treemap)
204
 
205
+ # --- Question Answering Section ---
 
206
  @st.cache_resource
207
  def load_gliner_model():
208
+ """Initializes and caches the GLiNER model for QA."""
 
 
 
209
  try:
210
  return GLiNER.from_pretrained("knowledgator/gliner-multitask-v1.0", device="cpu")
211
  except Exception as e:
212
  st.error(f"Error loading the GLiNER model: {e}")
213
  st.stop()
214
 
215
+ qa_model = load_gliner_model()
 
216
  st.subheader("Question-Answering", divider="violet")
 
 
217
 
218
  if 'user_labels' not in st.session_state:
219
  st.session_state.user_labels = []
220
 
221
+ question_input = st.text_input("Ask wh-questions. **Wh-questions begin with what, when, where, who, whom, which, whose, why and how. We use them to ask for specific information.**")
222
+
223
  if st.button("Add Question"):
224
  if question_input:
225
  if question_input not in st.session_state.user_labels:
 
233
  st.markdown("---")
234
  st.subheader("Record of Questions", divider="violet")
235
  if st.session_state.user_labels:
 
236
  for i, label in enumerate(st.session_state.user_labels):
237
  col_list, col_delete = st.columns([0.9, 0.1])
238
  with col_list:
239
  st.write(f"- {label}", key=f"label_{i}")
240
  with col_delete:
 
241
  if st.button("Delete", key=f"delete_{i}"):
 
242
  st.session_state.user_labels.pop(i)
 
243
  st.rerun()
244
  else:
245
  st.info("No questions defined yet. Use the input above to add one.")
246
 
247
  st.divider()
248
+
249
  if st.button("Extract Answers"):
250
  if not text.strip():
251
  st.warning("Please enter some text to analyze.")
 
260
  )
261
  experiment.log_parameter("input_text_length", len(text))
262
  experiment.log_parameter("defined_labels", st.session_state.user_labels)
263
+
264
  start_time = time.time()
265
  with st.spinner("Analyzing text...", show_time=True):
266
  try:
267
+ entities = qa_model.predict_entities(text, st.session_state.user_labels)
268
  end_time = time.time()
269
  elapsed_time = end_time - start_time
270
  st.info(f"Processing took **{elapsed_time:.2f} seconds**.")
 
306
  file_name="nlpblogs_results.zip",
307
  mime="application/zip",
308
  )
309
+ else:
 
 
 
 
 
 
 
 
 
310
  st.warning("No answers were found for the provided questions.")
311
  except Exception as e:
312
  st.error(f"An error occurred during answer extraction: {e}")