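"""DataHarvest: a Streamlit app that extracts named entities with GLiNER,
runs LDA topic modeling over them, and exports an HTML report and a CSV file."""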
import os
# Set HF_HOME to a writable path before any Hugging Face libraries are imported
# (fix for read-only deployment environments)
os.environ['HF_HOME'] = '/tmp'
import time
import streamlit as st
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import re
import string
# --- In-memory buffer used for the CSV download ---
from io import BytesIO
# --- Stable Scikit-learn LDA Imports ---
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import LatentDirichletAllocation
# ------------------------------
from gliner import GLiNER
# comet_ml is optional: fall back to a no-op stub when it is not installed
try:
    from comet_ml import Experiment
except ImportError:
    class Experiment:
        """No-op stand-in used when comet_ml is unavailable."""
        def __init__(self, **kwargs): pass
        def log_parameter(self, *args): pass
        def log_table(self, *args): pass
        def end(self): pass
# --- Color Map for Highlighting and Network Graph Nodes ---
entity_color_map = {
"person": "#10b981",
"country": "#3b82f6",
"city": "#4ade80",
"organization": "#f59e0b",
"date": "#8b5cf6",
"time": "#ec4899",
"cardinal": "#06b6d4",
"money": "#f43f5e",
"position": "#a855f7",
}
# --- Label Definitions and Category Mapping ---
labels = list(entity_color_map.keys())
category_mapping = {
"People": ["person", "organization", "position"],
"Locations": ["country", "city"],
"Time": ["date", "time"],
"Numbers": ["money", "cardinal"]}
reverse_category_mapping = {label: category for category, label_list in category_mapping.items() for label in label_list}
# --- Utility Functions for Analysis and Plotly ---
def extract_label(node_name):
"""Extracts the label from a node string like 'Text (Label)'."""
match = re.search(r'\(([^)]+)\)$', node_name)
return match.group(1) if match else "Unknown"
def remove_trailing_punctuation(text_string):
"""Removes trailing punctuation from a string."""
return text_string.rstrip(string.punctuation)
def highlight_entities(text, df_entities):
"""Generates HTML to display text with entities highlighted and colored."""
if df_entities.empty:
return text
# Sort entities by start index descending to insert highlights without affecting subsequent indices
entities = df_entities.sort_values(by='start', ascending=False).to_dict('records')
highlighted_text = text
for entity in entities:
start = entity['start']
end = entity['end']
label = entity['label']
entity_text = entity['text']
color = entity_color_map.get(label, '#000000')
# Create a span with background color and tooltip
highlight_html = f'<span style="background-color: {color}; color: white; padding: 2px 4px; border-radius: 3px; cursor: help;" title="{label}">{entity_text}</span>'
# Replace the original text segment with the highlighted HTML
highlighted_text = highlighted_text[:start] + highlight_html + highlighted_text[end:]
# Use a div to mimic the Streamlit input box style for the report
return f'<div style="border: 1px solid #888888; padding: 15px; border-radius: 5px; background-color: #ffffff; font-family: monospace; white-space: pre-wrap; margin-bottom: 20px;">{highlighted_text}</div>'
def perform_topic_modeling(df_entities, num_topics=2, num_top_words=10):
"""
Performs basic Topic Modeling using LDA on the extracted entities,
allowing for n-grams to capture multi-word entities like 'Dr. Emily Carter'.
"""
# 1. Prepare Documents: Use unique entities (they are short, clean documents)
documents = df_entities['text'].unique().tolist()
if len(documents) < 2:
return None
N = min(num_top_words, len(documents))
try:
# 2. Vectorizer: Use TfidfVectorizer, but allow unigrams, bigrams, and trigrams (ngram_range)
# to capture multi-word entities. We keep stop_words='english' for the *components* of the entity.
tfidf_vectorizer = TfidfVectorizer(
max_df=0.95,
min_df=2, # Only consider words/phrases that appear at least twice to find topics
stop_words='english',
ngram_range=(1, 3) # This is the KEY to capturing "Dr. Emily Carter" as a single token (if it appears enough times)
)
tfidf = tfidf_vectorizer.fit_transform(documents)
tfidf_feature_names = tfidf_vectorizer.get_feature_names_out()
# Check if the vocabulary is too small after tokenization/ngram generation
if len(tfidf_feature_names) < num_topics:
# Re-run with min_df=1 if vocab is too small
tfidf_vectorizer = TfidfVectorizer(
max_df=1.0, min_df=1, stop_words='english', ngram_range=(1, 3)
)
tfidf = tfidf_vectorizer.fit_transform(documents)
tfidf_feature_names = tfidf_vectorizer.get_feature_names_out()
if len(tfidf_feature_names) < num_topics:
return None
# 3. LDA Model Fit
lda = LatentDirichletAllocation(
n_components=num_topics, max_iter=5, learning_method='online',
random_state=42, n_jobs=-1
)
lda.fit(tfidf)
# 4. Extract Topic Data
topic_data_list = []
for topic_idx, topic in enumerate(lda.components_):
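            # argsort is ascending, so this reversed slice picks the N highest-weight terms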
top_words_indices = topic.argsort()[:-N - 1:-1]
# These top_words will now include phrases like 'emily carter' or 'european space agency'
top_words = [tfidf_feature_names[i] for i in top_words_indices]
word_weights = [topic[i] for i in top_words_indices]
for word, weight in zip(top_words, word_weights):
topic_data_list.append({
'Topic_ID': f'Topic #{topic_idx + 1}',
'Word': word,
'Weight': weight,
})
return pd.DataFrame(topic_data_list)
    except Exception:
        # Topic modeling is best-effort: return None rather than crash the app
        return None
def create_topic_word_bubbles(df_topic_data):
    """Generates a Plotly bubble chart of the top words across all topics,
    displaying each word directly on its bubble."""
    if df_topic_data.empty:
        return None
    # Rename columns to match the output of perform_topic_modeling
    df_topic_data = df_topic_data.rename(columns={'Topic_ID': 'topic',
                                                  'Word': 'word', 'Weight': 'weight'})
    df_topic_data['x_pos'] = df_topic_data.index  # Use the row index for x-position
fig = px.scatter(
df_topic_data,
x='x_pos',
y='weight',
size='weight',
color='topic',
# Set text to the word
text='word',
hover_name='word',
size_max=40,
title='Topic Word Weights (Bubble Chart)',
color_discrete_sequence=px.colors.qualitative.Bold,
labels={
'x_pos': 'Entity/Word Index',
'weight': 'Word Weight',
'topic': 'Topic ID'
},
custom_data=['word', 'weight', 'topic']
)
fig.update_layout(
xaxis_title="Entity/Word",
yaxis_title="Word Weight",
# Hide x-axis labels since words are now labels
xaxis={'tickangle': -45, 'showgrid': False, 'showticklabels': False, 'zeroline': False, 'showline': False},
yaxis={'showgrid': True},
showlegend=True,
plot_bgcolor='#f9f9f9',
paper_bgcolor='#f9f9f9',
height=600,
margin=dict(t=50, b=100, l=50, r=10),
)
    # Show each word centered on its bubble, in white for contrast
    # against the dark qualitative bubble colors
    fig.update_traces(
        textposition='middle center',
        textfont=dict(color='white', size=10),
hovertemplate='<b>%{customdata[0]}</b><br>Weight: %{customdata[1]:.3f}<extra></extra>',
marker=dict(line=dict(width=1, color='DarkSlateGrey'))
)
return fig
def generate_network_graph(df, raw_text):
    """
    Generates a network graph (node plot) whose edges connect entities
    that co-occur in the same sentence.
    """
entity_counts = df['text'].value_counts().reset_index()
entity_counts.columns = ['text', 'frequency']
unique_entities = df.drop_duplicates(subset=['text', 'label']).merge(entity_counts, on='text')
if unique_entities.shape[0] < 2:
return go.Figure().update_layout(title="Not enough unique entities for a meaningful graph.")
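    # Lay the nodes out evenly on a circle, adding small Gaussian jitter so
    # nodes and labels do not overlap exactly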
num_nodes = len(unique_entities)
thetas = np.linspace(0, 2 * np.pi, num_nodes, endpoint=False)
radius = 10
unique_entities['x'] = radius * np.cos(thetas) + np.random.normal(0, 0.5, num_nodes)
unique_entities['y'] = radius * np.sin(thetas) + np.random.normal(0, 0.5, num_nodes)
pos_map = unique_entities.set_index('text')[['x', 'y']].to_dict('index')
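    # Split the text into sentences (the lookbehinds avoid splitting after
    # abbreviations such as "Dr.") and connect every pair of entities that
    # appear together in the same sentence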
edges = set()
sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s', raw_text)
for sentence in sentences:
entities_in_sentence = []
for entity_text in unique_entities['text'].unique():
if entity_text.lower() in sentence.lower():
entities_in_sentence.append(entity_text)
unique_entities_in_sentence = list(set(entities_in_sentence))
for i in range(len(unique_entities_in_sentence)):
for j in range(i + 1, len(unique_entities_in_sentence)):
node1 = unique_entities_in_sentence[i]
node2 = unique_entities_in_sentence[j]
edge_tuple = tuple(sorted((node1, node2)))
edges.add(edge_tuple)
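    # Flatten edge endpoints into coordinate lists; a None entry tells Plotly
    # to break the line between consecutive edges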
edge_x = []
edge_y = []
for edge in edges:
n1, n2 = edge
if n1 in pos_map and n2 in pos_map:
edge_x.extend([pos_map[n1]['x'], pos_map[n2]['x'], None])
edge_y.extend([pos_map[n1]['y'], pos_map[n2]['y'], None])
fig = go.Figure()
edge_trace = go.Scatter(
x=edge_x, y=edge_y,
line=dict(width=0.5, color='#888'),
hoverinfo='none',
mode='lines',
name='Co-occurrence Edges',
showlegend=False
)
fig.add_trace(edge_trace)
fig.add_trace(go.Scatter(
x=unique_entities['x'],
y=unique_entities['y'],
mode='markers+text',
name='Entities',
text=unique_entities['text'],
textposition="top center",
showlegend=False,
marker=dict(
size=unique_entities['frequency'] * 5 + 10,
color=[entity_color_map.get(label, '#cccccc') for label in unique_entities['label']],
line_width=1,
line_color='black',
opacity=0.9
),
textfont=dict(size=10),
customdata=unique_entities[['label', 'score', 'frequency']],
hovertemplate=(
"<b>%{text}</b><br>" +
"Label: %{customdata[0]}<br>" +
"Score: %{customdata[1]:.2f}<br>" +
"Frequency: %{customdata[2]}<extra></extra>"
)
))
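    # Add one invisible marker per label purely to populate the legend
    # (the node trace above hides its own legend entry)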
legend_traces = []
seen_labels = set()
for index, row in unique_entities.iterrows():
label = row['label']
if label not in seen_labels:
seen_labels.add(label)
color = entity_color_map.get(label, '#cccccc')
legend_traces.append(go.Scatter(
x=[None], y=[None], mode='markers', marker=dict(size=10, color=color), name=f"{label.capitalize()}", showlegend=True
))
for trace in legend_traces:
fig.add_trace(trace)
fig.update_layout(
title='Entity Co-occurrence Network (Edges = Same Sentence)',
showlegend=True,
hovermode='closest',
xaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-15, 15]),
yaxis=dict(showgrid=False, zeroline=False, showticklabels=False, range=[-15, 15]),
plot_bgcolor='#f9f9f9',
paper_bgcolor='#f9f9f9',
margin=dict(t=50, b=10, l=10, r=10),
height=600
)
return fig
# --- CSV Generation ---
def generate_entity_csv(df):
"""
Generates a CSV file of the extracted entities in an in-memory buffer,
including text, label, category, score, start, and end indices.
"""
csv_buffer = BytesIO()
# Select desired columns and write to buffer
df_export = df[['text', 'label', 'category', 'score', 'start', 'end']]
csv_buffer.write(df_export.to_csv(index=False).encode('utf-8'))
csv_buffer.seek(0)
return csv_buffer
# -----------------------------------
# --- HTML Report Generation ---
def generate_html_report(df, text_input, elapsed_time, df_topic_data):
    """
    Generates a full, self-contained HTML report with all analysis results
    and visualizations.
    """
# 1. Generate Visualizations (Plotly HTML)
# 1a. Treemap
fig_treemap = px.treemap(
df,
path=[px.Constant("All Entities"), 'category', 'label', 'text'],
values='score',
color='category',
title="Entity Distribution by Category and Label",
color_discrete_sequence=px.colors.qualitative.Dark24
)
fig_treemap.update_layout(margin=dict(t=50, l=25, r=25, b=25))
treemap_html = fig_treemap.to_html(full_html=False, include_plotlyjs='cdn')
# 1b. Pie Chart
grouped_counts = df['category'].value_counts().reset_index()
grouped_counts.columns = ['Category', 'Count']
    # The Cividis palette avoids red/pink hues
fig_pie = px.pie(grouped_counts, values='Count', names='Category',title='Distribution of Entities by Category',color_discrete_sequence=px.colors.sequential.Cividis)
fig_pie.update_layout(margin=dict(t=50, b=10))
pie_html = fig_pie.to_html(full_html=False, include_plotlyjs='cdn')
# 1c. Bar Chart (Category Count)
fig_bar_category = px.bar(grouped_counts, x='Category', y='Count',color='Category', title='Total Entities per Category',color_discrete_sequence=px.colors.qualitative.Pastel)
fig_bar_category.update_layout(xaxis={'categoryorder': 'total descending'},margin=dict(t=50, b=100))
bar_category_html = fig_bar_category.to_html(full_html=False,include_plotlyjs='cdn')
# 1d. Bar Chart (Most Frequent Entities)
word_counts = df['text'].value_counts().reset_index()
word_counts.columns = ['Entity', 'Count']
repeating_entities = word_counts[word_counts['Count'] > 1].head(10)
bar_freq_html = '<p>No entities appear more than once in the text for visualization.</p>'
if not repeating_entities.empty:
        # The Viridis palette avoids pink/magenta hues
fig_bar_freq = px.bar(repeating_entities, x='Entity', y='Count',color='Entity', title='Top 10 Most Frequent Entities',color_discrete_sequence=px.colors.sequential.Viridis)
fig_bar_freq.update_layout(xaxis={'categoryorder': 'total descending'},margin=dict(t=50, b=100))
bar_freq_html = fig_bar_freq.to_html(full_html=False, include_plotlyjs='cdn')
# 1e. Network Graph HTML
network_fig = generate_network_graph(df, text_input)
network_html = network_fig.to_html(full_html=False, include_plotlyjs='cdn')
# 1f. Topic Charts HTML
topic_charts_html = '<h3>Topic Word Weights (Bubble Chart)</h3>'
if df_topic_data is not None and not df_topic_data.empty:
bubble_figure = create_topic_word_bubbles(df_topic_data)
if bubble_figure:
topic_charts_html += f'<div class="chart-box">{bubble_figure.to_html(full_html=False, include_plotlyjs="cdn", config={"responsive": True})}</div>'
else:
topic_charts_html += '<p style="color: red;">Error: Topic modeling data was available but visualization failed.</p>'
else:
        topic_charts_html += '<div class="chart-box" style="text-align: center; padding: 50px; background-color: #fff; border: 1px dashed #888888;">'
topic_charts_html += '<p><strong>Topic Modeling requires more unique input.</strong></p>'
topic_charts_html += '<p>Please enter text containing at least two unique entities to generate the Topic Bubble Chart.</p>'
topic_charts_html += '</div>'
# 2. Get Highlighted Text
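    # Reuse the app's highlighting HTML, swapping the inline <div> style for the report's CSS class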
highlighted_text_html = highlight_entities(text_input, df).replace("div style", "div class='highlighted-text' style")
# 3. Entity Tables (Pandas to HTML)
entity_table_html = df[['text', 'label', 'score', 'start', 'end', 'category']].to_html(
classes='table table-striped',
index=False
)
# 4. Construct the Final HTML
html_content = f"""<!DOCTYPE html><html lang="en"><head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Entity and Topic Analysis Report</title>
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
<style>
body {{ font-family: 'Inter', sans-serif; margin: 0; padding: 20px; background-color: #f4f4f9; color: #333; }}
.container {{ max-width: 1200px; margin: 0 auto; background-color: #ffffff; padding: 30px; border-radius: 12px; box-shadow: 0 4px 12px rgba(0,0,0,0.1); }}
h1 {{ color: #007bff; border-bottom: 3px solid #007bff; padding-bottom: 10px; margin-top: 0; }}
h2 {{ color: #007bff; margin-top: 30px; border-bottom: 1px solid #ddd; padding-bottom: 5px; }}
h3 {{ color: #555; margin-top: 20px; }}
.metadata {{ background-color: #e6f0ff; padding: 15px; border-radius: 8px; margin-bottom: 20px; font-size: 0.9em; }}
.chart-box {{ background-color: #f9f9f9; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05); min-width: 0; margin-bottom: 20px; }}
table {{ width: 100%; border-collapse: collapse; margin-top: 15px; }}
table th, table td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
table th {{ background-color: #f0f0f0; }}
.highlighted-text {{ border: 1px solid #888888; padding: 15px; border-radius: 5px; background-color: #ffffff; font-family: monospace; white-space: pre-wrap; margin-bottom: 20px; }}
</style></head><body>
<div class="container">
<h1>Entity and Topic Analysis Report</h1>
<div class="metadata">
<p><strong>Generated on:</strong> {time.strftime('%Y-%m-%d')}</p>
<p><strong>Processing Time:</strong> {elapsed_time:.2f} seconds</p>
</div>
<h2>1. Analyzed Text & Extracted Entities</h2>
<h3>Original Text with Highlighted Entities</h3>
<div class="highlighted-text-container">
{highlighted_text_html}
</div>
<h2>2. Full Extracted Entities Table</h2>
{entity_table_html}
<h2>3. Data Visualizations</h2>
<h3>3.1 Entity Distribution Treemap</h3>
<div class="chart-box">{treemap_html}</div>
    <h3>3.2 Comparative Charts: Pie, Category Count, and Frequency (stacked vertically)</h3>
<div class="chart-box">{pie_html}</div>
<div class="chart-box">{bar_category_html}</div>
<div class="chart-box">{bar_freq_html}</div>
<h3>3.3 Entity Relationship Map (Edges = Same Sentence)</h3>
<div class="chart-box">{network_html}</div>
<h2>4. Topic Modelling</h2>
{topic_charts_html}
</div></body></html>
"""
return html_content
# --- Page Configuration and Styling (No Sidebar) ---
st.set_page_config(layout="wide", page_title="NER & Topic Report App")
# --- Conditional Mobile Warning ---
st.markdown(
"""
<style>
/* CSS Media Query: Only show the content inside this selector when the screen width is 600px or less (typical mobile size) */
@media (max-width: 600px) {
#mobile-warning-container {
display: block; /* Show the warning container */
background-color: #ffcccc; /* Light red/pink background */
color: #cc0000; /* Dark red text */
padding: 10px;
border-radius: 5px;
text-align: center;
margin-bottom: 20px;
font-weight: bold;
border: 1px solid #cc0000;
}
}
/* Hide the content by default (for larger screens) */
@media (min-width: 601px) {
#mobile-warning-container {
display: none; /* Hide the warning container on desktop */
}
}
</style>
<div id="mobile-warning-container">
โš ๏ธ **Tip for Mobile Users:** For the best viewing experience of the charts and tables, please switch your browser to **"Desktop Site"** view.
</div>
""",
unsafe_allow_html=True
)
# ----------------------------------
st.markdown(
"""
<style>
    /* Tab label colors for visibility: style the tab label buttons */
[data-testid="stConfigurableTabs"] button {
color: #333333 !important; /* Dark gray for inactive tabs */
background-color: #f0f0f0; /* Light gray background for inactive tabs */
border: 1px solid #cccccc;
}
/* Target the ACTIVE tab label */
[data-testid="stConfigurableTabs"] button[aria-selected="true"] {
color: #FFFFFF !important; /* White text for active tab */
background-color: #007bff; /* Blue background for active tab */
border-bottom: 2px solid #007bff; /* Optional: adds an accent line */
}
    /* Expander header color (the default was overridden to white elsewhere) */
    .streamlit-expanderHeader {
        color: #007bff; /* Blue text for the expander header */
}
</style>
""",
unsafe_allow_html=True
)
st.subheader("Entity and Topic Analysis Report Generator", divider="blue")
st.link_button("by nlpblogs", "https://nlpblogs.com", type="tertiary")
tab1, tab2 = st.tabs(["Embed", "Important Notes"])
with tab1:
with st.expander("Embed"):
st.write("Use the following code to embed the DataHarvest web app on your website. Feel free to adjust the width and height values to fit your page.")
code = '''
<iframe
src="https://aiecosystem-dataharvest.hf.space"
frameborder="0"
width="850"
height="450"
></iframe>
'''
        st.code(code, language="html")  # st.code adds a copy icon, handy for the embed snippet
with tab2:
expander = st.expander("**Important Notes**")
    # Render the notes with st.markdown so no copy-to-clipboard icon appears
    # (unlike st.code), while keeping the bold styling.
    expander.markdown("""
    - **Named Entities:** This DataHarvest web app predicts nine (9) labels: "person", "country", "city", "organization", "date", "time", "cardinal", "money", "position".
    - **Results:** Results are compiled into a single, comprehensive **HTML report** and a **CSV file** for easy download and sharing.
    - **How to Use:** Type or paste your text into the text area below, press Ctrl + Enter, and then click the 'Results' button.
    - **Technical issues:** If your connection times out, please refresh the page or reopen the app's URL.
    """)
st.markdown("For any errors or inquiries, please contact us at [info@nlpblogs.com](mailto:info@nlpblogs.com)")
# --- Comet ML Setup (enabled only when all credentials are present) ---
COMET_API_KEY = os.environ.get("COMET_API_KEY")
COMET_WORKSPACE = os.environ.get("COMET_WORKSPACE")
COMET_PROJECT_NAME = os.environ.get("COMET_PROJECT_NAME")
comet_initialized = bool(COMET_API_KEY and COMET_WORKSPACE and COMET_PROJECT_NAME)
# --- Model Loading ---
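# st.cache_resource loads the GLiNER model once per server process and reuses it
# across Streamlit reruns instead of reloading on every interaction.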
@st.cache_resource
def load_ner_model():
"""Loads the GLiNER model and caches it."""
try:
return GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5", nested_ner=True, num_gen_sequences=2, gen_constraints=labels)
except Exception as e:
st.error(f"Failed to load NER model. Please check your internet connection or model availability: {e}")
st.stop()
model = load_ner_model()
# --- Long Default Sample Text ---
DEFAULT_TEXT = (
"In June 2024, the founder, Dr. Emily Carter, officially announced a new, expansive partnership between "
"TechSolutions Inc. and the European Space Agency (ESA). This strategic alliance represents a significant "
"leap forward for commercial space technology across the entire **European Union**. The agreement, finalized "
"on Monday in Paris, France, focuses specifically on jointly developing the next generation of the 'Astra' "
"software platform. This version of the **Astra** platform is critical for processing and managing the vast amounts of data being sent "
"back from the recent Mars rover mission. This project underscores the ESA's commitment to advancing "
"space capabilities within the **European Union**. The core team, including lead engineer Marcus Davies, will hold "
"their first collaborative workshop in Berlin, Germany, on August 15th. The community response on social "
"media platform X (under the username @TechCEO) was overwhelmingly positive, with many major tech "
"publications, including Wired Magazine, predicting a major impact on the space technology industry by the "
"end of the year, further strengthening the technological standing of the **European Union**. The platform is designed to be compatible with both Windows and Linux operating systems. "
"The initial funding, secured via a Series B round, totaled $50 million. Financial analysts from Morgan Stanley "
"are closely monitoring the impact on TechSolutions Inc.'s Q3 financial reports, expected to be released to the "
"general public by October 1st. The goal is to deploy the **Astra** v2 platform before the next solar eclipse event in 2026."
)
# -----------------------------------
# --- Session State Initialization ---
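# Streamlit reruns this script on every interaction; st.session_state preserves
# the extracted results and UI state across those reruns.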
if 'show_results' not in st.session_state:
st.session_state.show_results = False
if 'last_text' not in st.session_state:
st.session_state.last_text = ""
if 'results_df' not in st.session_state:
st.session_state.results_df = pd.DataFrame()
if 'elapsed_time' not in st.session_state:
st.session_state.elapsed_time = 0.0
if 'topic_results' not in st.session_state:
st.session_state.topic_results = None
if 'my_text_area' not in st.session_state:
st.session_state.my_text_area = DEFAULT_TEXT
# --- Clear Button Callback ---
def clear_text():
"""Clears the text area (sets it to an empty string) and hides results."""
st.session_state['my_text_area'] = ""
st.session_state.show_results = False
st.session_state.last_text = ""
st.session_state.results_df = pd.DataFrame()
st.session_state.elapsed_time = 0.0
st.session_state.topic_results = None
# --- Text Input and Clear Button ---
word_limit = 1000
text = st.text_area(
f"Type or paste your text below (max {word_limit} words), and then press Ctrl + Enter",
height=250,
key='my_text_area',
)
word_count = len(text.split())
st.markdown(f"**Word count:** {word_count}/{word_limit}")
st.button("Clear text", on_click=clear_text)
# --- Results Trigger and Processing ---
if st.button("Results"):
if not text.strip():
st.warning("Please enter some text to extract entities.")
st.session_state.show_results = False
elif word_count > word_limit:
st.warning(f"Your text exceeds the {word_limit} word limit. Please shorten it to continue.")
st.session_state.show_results = False
else:
with st.spinner("Extracting entities and generating report data...", show_time=True):
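            # Recompute only when the input text has changed; otherwise the
            # results cached in session state are reused as-is.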
if text != st.session_state.last_text:
st.session_state.last_text = text
start_time = time.time()
# --- Model Prediction & Dataframe Creation ---
entities = model.predict_entities(text, labels)
df = pd.DataFrame(entities)
if not df.empty:
df['text'] = df['text'].apply(remove_trailing_punctuation)
df['category'] = df['label'].map(reverse_category_mapping)
st.session_state.results_df = df
unique_entity_count = len(df['text'].unique())
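                    # Never request more top words than there are unique entities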
N_TOP_WORDS_TO_USE = min(10, unique_entity_count)
st.session_state.topic_results = perform_topic_modeling(
df,
num_topics=2,
num_top_words=N_TOP_WORDS_TO_USE
)
if comet_initialized:
experiment = Experiment(api_key=COMET_API_KEY, workspace=COMET_WORKSPACE, project_name=COMET_PROJECT_NAME)
experiment.log_parameter("input_text", text)
experiment.log_table("predicted_entities", df)
experiment.end()
else:
st.session_state.results_df = pd.DataFrame()
st.session_state.topic_results = None
end_time = time.time()
st.session_state.elapsed_time = end_time - start_time
st.info(f"Report data generated in **{st.session_state.elapsed_time:.2f} seconds**.")
st.session_state.show_results = True
# --- Display Download Link and Results ---
if st.session_state.show_results:
df = st.session_state.results_df
df_topic_data = st.session_state.topic_results
if df.empty:
st.warning("No entities were found in the provided text.")
else:
st.subheader("Analysis Results", divider="blue")
# 1. Highlighted Text
st.markdown("### 1. Analyzed Text with Highlighted Entities")
st.markdown(highlight_entities(st.session_state.last_text, df), unsafe_allow_html=True)
# 2. Detailed Entity Analysis Tabs
st.markdown("### 2. Detailed Entity Analysis")
tab_category_details, tab_treemap_viz = st.tabs(["๐Ÿ“‘ Entities Grouped by Category", "๐Ÿ—บ๏ธ Treemap Distribution"])
with tab_category_details:
st.markdown("#### Detailed Entities Table (Grouped by Category)")
unique_categories = list(category_mapping.keys())
tabs_category = st.tabs(unique_categories)
for category, tab in zip(unique_categories, tabs_category):
df_category = df[df['category'] == category][['text', 'label', 'score', 'start', 'end']].sort_values(by='score', ascending=False)
with tab:
st.markdown(f"##### {category} Entities ({len(df_category)} total)")
if not df_category.empty:
st.dataframe(
df_category,
use_container_width=True,
column_config={'score': st.column_config.NumberColumn(format="%.4f")}
)
else:
st.info(f"No entities of category **{category}** were found in the text.")
            with st.expander("See Glossary of tags"):
                st.write('''
                - **text**: entity extracted from your text data
                - **label**: label (tag) assigned to the extracted entity
                - **score**: confidence score for the assigned label
                - **start**: character index where the entity starts
                - **end**: character index where the entity ends
                ''')
with tab_treemap_viz:
st.markdown("#### Treemap: Entity Distribution")
fig_treemap = px.treemap(
df,
path=[px.Constant("All Entities"), 'category', 'label', 'text'],
values='score',
color='category',
color_discrete_sequence=px.colors.qualitative.Dark24
)
fig_treemap.update_layout(margin=dict(t=10, l=10, r=10, b=10))
st.plotly_chart(fig_treemap, use_container_width=True)
# 3. Comparative Charts
st.markdown("---")
st.markdown("### 3. Comparative Charts")
col1, col2, col3 = st.columns(3)
grouped_counts = df['category'].value_counts().reset_index()
grouped_counts.columns = ['Category', 'Count']
with col1: # Pie Chart
            # The Cividis palette avoids red/pink hues
fig_pie = px.pie(grouped_counts, values='Count', names='Category',title='Distribution of Entities by Category',color_discrete_sequence=px.colors.sequential.Cividis)
fig_pie.update_layout(margin=dict(t=30, b=10, l=10, r=10), height=350)
st.plotly_chart(fig_pie, use_container_width=True)
with col2: # Bar Chart (Category Count)
fig_bar_category = px.bar(grouped_counts, x='Category', y='Count',color='Category', title='Total Entities per Category',color_discrete_sequence=px.colors.qualitative.Pastel)
fig_bar_category.update_layout(xaxis={'categoryorder': 'total descending'},margin=dict(t=30, b=10, l=10, r=10), height=350)
st.plotly_chart(fig_bar_category, use_container_width=True)
with col3: # Bar Chart (Most Frequent Entities)
word_counts = df['text'].value_counts().reset_index()
word_counts.columns = ['Entity', 'Count']
repeating_entities = word_counts[word_counts['Count'] > 1].head(10)
if not repeating_entities.empty:
                # The Viridis palette avoids pink/magenta hues
fig_bar_freq = px.bar(repeating_entities, x='Entity', y='Count',color='Entity', title='Top 10 Most Frequent Entities',color_discrete_sequence=px.colors.sequential.Viridis)
fig_bar_freq.update_layout(xaxis={'categoryorder': 'total descending'},margin=dict(t=30, b=10, l=10, r=10), height=350)
st.plotly_chart(fig_bar_freq, use_container_width=True)
else:
st.info("No entities repeat for frequency chart.")
st.markdown("---")
st.markdown("### 4. Entity Relationship Map")
network_fig = generate_network_graph(df, st.session_state.last_text)
st.plotly_chart(network_fig, use_container_width=True)
st.markdown("---")
st.markdown("### 5. Topic Modelling Analysis")
if df_topic_data is not None and not df_topic_data.empty:
bubble_figure = create_topic_word_bubbles(df_topic_data)
if bubble_figure:
st.plotly_chart(bubble_figure, use_container_width=True)
else:
st.error("Error generating Topic Word Bubble Chart.")
else:
st.info("Topic modeling requires more unique input (at least two unique entities).")
# --- Report Download ---
st.markdown("---")
st.markdown("### Download Full Report Artifacts")
        # 1. HTML Report Download
html_report = generate_html_report(df, st.session_state.last_text, st.session_state.elapsed_time, df_topic_data)
st.download_button(
label="Download Comprehensive HTML Report",
data=html_report,
file_name="ner_topic_report.html",
mime="text/html",
type="primary"
)
        # 2. CSV Data Download
csv_buffer = generate_entity_csv(df)
st.download_button(
label="Download Extracted Entities (CSV)",
data=csv_buffer,
file_name="extracted_entities.csv",
mime="text/csv",
type="secondary"
)