AIEcosystem committed on
Commit
efb6584
·
verified ·
1 Parent(s): 793094e

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +102 -124
src/streamlit_app.py CHANGED
@@ -14,36 +14,23 @@ from typing import Optional
14
  from gliner import GLiNER
15
  from comet_ml import Experiment
16
 
17
-
18
-
19
  # --- Page Configuration and UI Elements ---
20
  st.set_page_config(layout="wide", page_title="Named Entity Recognition App")
 
 
 
 
 
21
  st.subheader("DataHarvest", divider="violet")
22
  st.link_button("by nlpblogs", "https://nlpblogs.com", type="tertiary")
23
  st.markdown(':rainbow[**Supported Languages: English**]')
24
  expander = st.expander("**Important notes**")
25
- expander.write("""**Named Entities:** This DataHarvest web app predicts nine (9) labels: "person", "country", "city", "organization", "date", "time", "cardinal", "money", "position"
26
-
27
- Results are presented in easy-to-read tables, visualized in an interactive tree map, pie chart and bar chart, and are available for download along with a Glossary of tags.
28
-
29
- **How to Use:** Type or paste your text into the text area below, then press Ctrl + Enter. Click the 'Results' button to extract and tag entities in your text data.
30
-
31
- **Usage Limits:** You can request results unlimited times for one (1) month.
32
-
33
- **Technical issues:** If your connection times out, please refresh the page or reopen the app's URL.
34
-
35
- For any errors or inquiries, please contact us at info@nlpblogs.com""")
36
 
37
  with st.sidebar:
38
  st.write("Use the following code to embed the DataHarvest web app on your website. Feel free to adjust the width and height values to fit your page.")
39
  code = '''
40
- <iframe
41
- src="https://aiecosystem-dataharvest.hf.space"
42
- frameborder="0"
43
- width="850"
44
- height="450"
45
- ></iframe>
46
-
47
  '''
48
  st.code(code, language="html")
49
  st.text("")
@@ -62,8 +49,6 @@ if not comet_initialized:
62
 
63
  # --- Label Definitions ---
64
  labels = ["person", "country", "city", "organization", "date", "time", "cardinal", "money", "position"]
65
- # Corrected mapping dictionary
66
- # Create a mapping dictionary for labels to categories
67
  category_mapping = {
68
  "People": ["person", "organization", "position"],
69
  "Locations": ["country", "city"],
@@ -73,16 +58,23 @@ category_mapping = {
73
  # --- Model Loading ---
74
  @st.cache_resource
75
  def load_ner_model():
76
- """Loads the GLiNER model and caches it."""
77
  try:
78
  return GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5", nested_ner=True, num_gen_sequences=2, gen_constraints= labels)
79
  except Exception as e:
80
  st.error(f"Failed to load NER model. Please check your internet connection or model availability: {e}")
81
  st.stop()
82
  model = load_ner_model()
83
- # Flatten the mapping to a single dictionary
84
  reverse_category_mapping = {label: category for category, label_list in category_mapping.items() for label in label_list}
85
 
 
 
 
 
 
 
 
 
 
86
  # --- Text Input and Clear Button ---
87
  word_limit = 200
88
  text = st.text_area(f"Type or paste your text below (max {word_limit} words), and then press Ctrl + Enter", height=250, key='my_text_area')
@@ -90,8 +82,10 @@ word_count = len(text.split())
90
  st.markdown(f"**Word count:** {word_count}/{word_limit}")
91
 
92
  def clear_text():
93
- """Clears the text area."""
94
  st.session_state['my_text_area'] = ""
 
 
95
 
96
  def remove_punctuation(text):
97
  """Removes punctuation from a string."""
@@ -104,118 +98,102 @@ st.button("Clear text", on_click=clear_text)
104
  if st.button("Results"):
105
  if not text.strip():
106
  st.warning("Please enter some text to extract entities.")
 
107
  elif word_count > word_limit:
108
  st.warning(f"Your text exceeds the {word_limit} word limit. Please shorten it to continue.")
 
109
  else:
 
 
110
  start_time = time.time()
111
- # Call the new function to remove punctuation from the input text
112
- cleaned_text = remove_punctuation(text)
113
  with st.spinner("Extracting entities...", show_time=True):
114
- # Use the cleaned text for prediction
115
  entities = model.predict_entities(cleaned_text, labels)
116
  df = pd.DataFrame(entities)
 
117
  if not df.empty:
118
  df['category'] = df['label'].map(reverse_category_mapping)
119
  if comet_initialized:
120
- experiment = Experiment(
121
- api_key=COMET_API_KEY,
122
- workspace=COMET_WORKSPACE,
123
- project_name=COMET_PROJECT_NAME,
124
- )
125
  experiment.log_parameter("input_text", text)
126
  experiment.log_table("predicted_entities", df)
127
- st.subheader("Grouped Entities by Category", divider = "violet")
128
- # Create tabs for each category
129
- category_names = sorted(list(category_mapping.keys()))
130
- category_tabs = st.tabs(category_names)
131
- for i, category_name in enumerate(category_names):
132
- with category_tabs[i]:
133
- df_category_filtered = df[df['category'] == category_name]
134
- if not df_category_filtered.empty:
135
- st.dataframe(df_category_filtered.drop(columns=['category']), use_container_width=True)
136
- else:
137
- st.info(f"No entities found for the '{category_name}' category.")
138
- with st.expander("See Glossary of tags"):
139
- st.write('''
140
- - **text**: ['entity extracted from your text data']
141
- - **score**: ['accuracy score; how accurately a tag has been assigned to a given entity']
142
- - **label**: ['label (tag) assigned to a given extracted entity']
143
- - **start**: ['index of the start of the corresponding entity']
144
- - **end**: ['index of the end of the corresponding entity']
145
- ''')
146
- st.divider()
147
- # Tree map
148
- st.subheader("Tree map", divider = "violet")
149
- fig_treemap = px.treemap(df, path=[px.Constant("all"), 'category', 'label', 'text'], values='score', color='category')
150
- fig_treemap.update_layout(margin=dict(t=50, l=25, r=25, b=25))
151
- st.plotly_chart(fig_treemap)
152
- # Pie and Bar charts
153
- grouped_counts = df['category'].value_counts().reset_index()
154
- grouped_counts.columns = ['category', 'count']
155
- col1, col2 = st.columns(2)
156
- with col1:
157
- st.subheader("Pie chart", divider = "violet")
158
- fig_pie = px.pie(grouped_counts, values='count', names='category', hover_data=['count'], labels={'count': 'count'}, title='Percentage of predicted categories')
159
- fig_pie.update_traces(textposition='inside', textinfo='percent+label')
160
- fig_pie.update_layout(
161
- )
162
- st.plotly_chart(fig_pie)
163
- with col2:
164
- st.subheader("Bar chart", divider = "violet")
165
- fig_bar = px.bar(grouped_counts, x="count", y="category", color="category", text_auto=True, title='Occurrences of predicted categories')
166
- fig_bar.update_layout( # Changed from fig_pie to fig_bar
167
- )
168
- st.plotly_chart(fig_bar)
169
- # Most Frequent Entities
170
- st.subheader("Most Frequent Entities", divider="violet")
171
- word_counts = df['text'].value_counts().reset_index()
172
- word_counts.columns = ['Entity', 'Count']
173
- repeating_entities = word_counts[word_counts['Count'] > 1]
174
- if not repeating_entities.empty:
175
- st.dataframe(repeating_entities, use_container_width=True)
176
- fig_repeating_bar = px.bar(repeating_entities, x='Entity', y='Count', color='Entity')
177
- fig_repeating_bar.update_layout(xaxis={'categoryorder': 'total descending'},
178
- )
179
- st.plotly_chart(fig_repeating_bar)
180
- else:
181
- st.warning("No entities were found that occur more than once.")
182
- # Download Section
183
- st.divider()
184
- dfa = pd.DataFrame(
185
- data={
186
- 'Column Name': ['text', 'label', 'score', 'start', 'end'],
187
- 'Description': [
188
- 'entity extracted from your text data',
189
- 'label (tag) assigned to a given extracted entity',
190
- 'accuracy score; how accurately a tag has been assigned to a given entity',
191
- 'index of the start of the corresponding entity',
192
- 'index of the end of the corresponding entity',
193
- ]
194
- }
195
- )
196
- buf = io.BytesIO()
197
- with zipfile.ZipFile(buf, "w") as myzip:
198
- myzip.writestr("Summary of the results.csv", df.to_csv(index=False))
199
- myzip.writestr("Glossary of tags.csv", dfa.to_csv(index=False))
200
- with stylable_container(
201
- key="download_button",
202
- css_styles="""button { background-color: red; border: 1px solid black; padding: 5px; color: white; }""",
203
- ):
204
- st.download_button(
205
- label="Download results and glossary (zip)",
206
- data=buf.getvalue(),
207
- file_name="nlpblogs_results.zip",
208
- mime="application/zip",
209
- )
210
- if comet_initialized:
211
- experiment.log_figure(figure=fig_treemap, figure_name="entity_treemap_categories")
212
  experiment.end()
213
 
214
- # Corrected placement for time calculation and display
215
  end_time = time.time()
216
  elapsed_time = end_time - start_time
217
- st.text("")
218
- st.text("")
219
- st.info(f"Results processed in **{elapsed_time:.2f} seconds**.")
220
- else: # If df is empty
221
- st.warning("No entities were found in the provided text.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  from gliner import GLiNER
15
  from comet_ml import Experiment
16
 
 
 
17
  # --- Page Configuration and UI Elements ---
18
  st.set_page_config(layout="wide", page_title="Named Entity Recognition App")
19
+ st.markdown("""
20
+ <style>
21
+ /* ... (Your CSS Styles) ... */
22
+ </style>
23
+ """, unsafe_allow_html=True)
24
  st.subheader("DataHarvest", divider="violet")
25
  st.link_button("by nlpblogs", "https://nlpblogs.com", type="tertiary")
26
  st.markdown(':rainbow[**Supported Languages: English**]')
27
  expander = st.expander("**Important notes**")
28
+ expander.write("""**Named Entities:** This DataHarvest web app predicts nine (9) labels: "person", "country", "city", "organization", "date", "time", "cardinal", "money", "position"Results are presented in easy-to-read tables, visualized in an interactive tree map, pie chart and bar chart, and are available for download along with a Glossary of tags.**How to Use:** Type or paste your text into the text area below, then press Ctrl + Enter. Click the 'Results' button to extract and tag entities in your text data.**Usage Limits:** You can request results unlimited times for one (1) month.**Technical issues:** If your connection times out, please refresh the page or reopen the app's URL. For any errors or inquiries, please contact us at info@nlpblogs.com""")
 
 
 
 
 
 
 
 
 
 
29
 
30
  with st.sidebar:
31
  st.write("Use the following code to embed the DataHarvest web app on your website. Feel free to adjust the width and height values to fit your page.")
32
  code = '''
33
+ <iframe src="https://aiecosystem-dataharvest.hf.space" frameborder="0" width="850" height="450" ></iframe>
 
 
 
 
 
 
34
  '''
35
  st.code(code, language="html")
36
  st.text("")
 
49
 
50
  # --- Label Definitions ---
51
  labels = ["person", "country", "city", "organization", "date", "time", "cardinal", "money", "position"]
 
 
52
  category_mapping = {
53
  "People": ["person", "organization", "position"],
54
  "Locations": ["country", "city"],
 
58
  # --- Model Loading ---
59
  @st.cache_resource
60
  def load_ner_model():
 
61
  try:
62
  return GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5", nested_ner=True, num_gen_sequences=2, gen_constraints= labels)
63
  except Exception as e:
64
  st.error(f"Failed to load NER model. Please check your internet connection or model availability: {e}")
65
  st.stop()
66
  model = load_ner_model()
 
67
  reverse_category_mapping = {label: category for category, label_list in category_mapping.items() for label in label_list}
68
 
69
+ # --- Session State Initialization ---
70
+ # This is the key fix. We use session state to control what is displayed.
71
+ if 'show_results' not in st.session_state:
72
+ st.session_state.show_results = False
73
+ if 'last_text' not in st.session_state:
74
+ st.session_state.last_text = ""
75
+ if 'results_df' not in st.session_state:
76
+ st.session_state.results_df = pd.DataFrame()
77
+
78
  # --- Text Input and Clear Button ---
79
  word_limit = 200
80
  text = st.text_area(f"Type or paste your text below (max {word_limit} words), and then press Ctrl + Enter", height=250, key='my_text_area')
 
82
  st.markdown(f"**Word count:** {word_count}/{word_limit}")
83
 
84
  def clear_text():
85
+ """Clears the text area and hides results."""
86
  st.session_state['my_text_area'] = ""
87
+ st.session_state.show_results = False
88
+ st.session_state.last_text = ""
89
 
90
  def remove_punctuation(text):
91
  """Removes punctuation from a string."""
 
98
  if st.button("Results"):
99
  if not text.strip():
100
  st.warning("Please enter some text to extract entities.")
101
+ st.session_state.show_results = False
102
  elif word_count > word_limit:
103
  st.warning(f"Your text exceeds the {word_limit} word limit. Please shorten it to continue.")
104
+ st.session_state.show_results = False
105
  else:
106
+ st.session_state.show_results = True
107
+ st.session_state.last_text = text
108
  start_time = time.time()
 
 
109
  with st.spinner("Extracting entities...", show_time=True):
110
+ cleaned_text = remove_punctuation(text)
111
  entities = model.predict_entities(cleaned_text, labels)
112
  df = pd.DataFrame(entities)
113
+ st.session_state.results_df = df
114
  if not df.empty:
115
  df['category'] = df['label'].map(reverse_category_mapping)
116
  if comet_initialized:
117
+ experiment = Experiment(api_key=COMET_API_KEY, workspace=COMET_WORKSPACE, project_name=COMET_PROJECT_NAME)
 
 
 
 
118
  experiment.log_parameter("input_text", text)
119
  experiment.log_table("predicted_entities", df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  experiment.end()
121
 
 
122
  end_time = time.time()
123
  elapsed_time = end_time - start_time
124
+ st.session_state.elapsed_time = elapsed_time
125
+
126
+ # Display results if the state variable is True
127
+ if st.session_state.show_results:
128
+ df = st.session_state.results_df
129
+ if not df.empty:
130
+ st.subheader("Grouped Entities by Category", divider = "violet")
131
+ category_names = sorted(list(category_mapping.keys()))
132
+ category_tabs = st.tabs(category_names)
133
+ for i, category_name in enumerate(category_names):
134
+ with category_tabs[i]:
135
+ df_category_filtered = df[df['category'] == category_name]
136
+ if not df_category_filtered.empty:
137
+ st.dataframe(df_category_filtered.drop(columns=['category']), use_container_width=True)
138
+ else:
139
+ st.info(f"No entities found for the '{category_name}' category.")
140
+
141
+ with st.expander("See Glossary of tags"):
142
+ st.write('''
143
+ - **text**: ['entity extracted from your text data']
144
+ - **score**: ['accuracy score; how accurately a tag has been assigned to a given entity']
145
+ - **label**: ['label (tag) assigned to a given extracted entity']
146
+ - **start**: ['index of the start of the corresponding entity']
147
+ - **end**: ['index of the end of the corresponding entity']
148
+ ''')
149
+ st.divider()
150
+
151
+ # Tree map
152
+ st.subheader("Tree map", divider = "violet")
153
+ fig_treemap = px.treemap(df, path=[px.Constant("all"), 'category', 'label', 'text'], values='score', color='category')
154
+ fig_treemap.update_layout(margin=dict(t=50, l=25, r=25, b=25))
155
+ st.plotly_chart(fig_treemap)
156
+
157
+ # Pie and Bar charts
158
+ grouped_counts = df['category'].value_counts().reset_index()
159
+ grouped_counts.columns = ['category', 'count']
160
+ col1, col2 = st.columns(2)
161
+ with col1:
162
+ st.subheader("Pie chart", divider = "violet")
163
+ fig_pie = px.pie(grouped_counts, values='count', names='category', hover_data=['count'], labels={'count': 'count'}, title='Percentage of predicted categories')
164
+ fig_pie.update_traces(textposition='inside', textinfo='percent+label')
165
+ st.plotly_chart(fig_pie)
166
+ with col2:
167
+ st.subheader("Bar chart", divider = "violet")
168
+ fig_bar = px.bar(grouped_counts, x="count", y="category", color="category", text_auto=True, title='Occurrences of predicted categories')
169
+ st.plotly_chart(fig_bar)
170
+
171
+ # Most Frequent Entities
172
+ st.subheader("Most Frequent Entities", divider="violet")
173
+ word_counts = df['text'].value_counts().reset_index()
174
+ word_counts.columns = ['Entity', 'Count']
175
+ repeating_entities = word_counts[word_counts['Count'] > 1]
176
+ if not repeating_entities.empty:
177
+ st.dataframe(repeating_entities, use_container_width=True)
178
+ fig_repeating_bar = px.bar(repeating_entities, x='Entity', y='Count', color='Entity')
179
+ fig_repeating_bar.update_layout(xaxis={'categoryorder': 'total descending'})
180
+ st.plotly_chart(fig_repeating_bar)
181
+ else:
182
+ st.warning("No entities were found that occur more than once.")
183
+
184
+ # Download Section
185
+ st.divider()
186
+ dfa = pd.DataFrame(data={'Column Name': ['text', 'label', 'score', 'start', 'end'],
187
+ 'Description': ['entity extracted from your text data', 'label (tag) assigned to a given extracted entity', 'accuracy score; how accurately a tag has been assigned to a given entity', 'index of the start of the corresponding entity', 'index of the end of the corresponding entity']})
188
+ buf = io.BytesIO()
189
+ with zipfile.ZipFile(buf, "w") as myzip:
190
+ myzip.writestr("Summary of the results.csv", df.to_csv(index=False))
191
+ myzip.writestr("Glossary of tags.csv", dfa.to_csv(index=False))
192
+ with stylable_container(key="download_button", css_styles="""button { background-color: red; border: 1px solid black; padding: 5px; color: white; }""",):
193
+ st.download_button(label="Download results and glossary (zip)", data=buf.getvalue(), file_name="nlpblogs_results.zip", mime="application/zip")
194
+
195
+ st.text("")
196
+ st.text("")
197
+ st.info(f"Results processed in **{st.session_state.elapsed_time:.2f} seconds**.")
198
+ else: # If df is empty after the button click
199
+ st.warning("No entities were found in the provided text.")