jeevitha-app commited on
Commit
dbfb384
·
verified ·
1 Parent(s): 1ef9d51

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -210
app.py CHANGED
@@ -1,221 +1,84 @@
1
- # app.py
2
- # Gradio app: English + Persian sentiment with SHAP-based interpretability and word highlighting
3
-
4
  import joblib
5
- import numpy as np
6
- import pandas as pd
7
  import shap
 
8
  import matplotlib.pyplot as plt
9
- import io
10
- import base64
11
- import html
12
- from typing import Tuple, Dict, List
13
- import math
14
-
15
- import gradio as gr
16
-
17
- # --------- Load models (replace filenames if you used different names) ----------
18
- ENG_MODEL_PATH = "best_model.pkl"
19
- ENG_VEC_PATH = "tfidf_vectorizer.pkl"
20
- PER_MODEL_PATH = "logistic_regression.pkl"
21
- PER_VEC_PATH = "tfidf_vectorizer_persian.pkl"
22
 
23
- eng_model = joblib.load(ENG_MODEL_PATH)
24
- eng_vectorizer = joblib.load(ENG_VEC_PATH)
 
 
 
25
 
26
- per_model = joblib.load(PER_MODEL_PATH)
27
- per_vectorizer = joblib.load(PER_VEC_PATH)
28
 
29
- CLASS_NAMES_EN = ["Negative", "Neutral", "Positive"]
30
- CLASS_NAMES_PER = ["منفی", "خنثی", "مثبت"]
31
 
32
- # --------- Utility: create bar data for gradio BarPlot ----------
33
- def probs_to_bar(probs: List[float], lang: str):
34
- names = CLASS_NAMES_EN if lang == "English" else CLASS_NAMES_PER
35
- return {names[i]: float(probs[i]) for i in range(len(probs))}
 
 
36
 
37
- # --------- Utility: create HTML highlight from SHAP values ----------
38
- def make_html_highlight(original_text: str,
39
- feature_names: np.ndarray,
40
- shap_values_feature: np.ndarray,
41
- vectorizer_vocab: dict,
42
- max_display: int = 30) -> str:
43
- """
44
- Simple token-level highlighting:
45
- - Tokenize by whitespace (preserves original punctuation).
46
- - For each token, attempt to map token.lower() to the vectorizer vocab;
47
- if found, get SHAP impact for that feature name.
48
- - Color red for positive contribution, blue for negative.
49
- Returns an HTML-safe string.
50
- """
51
- # Build mapping word -> shap value if present in vocabulary
52
- # vectorizer_vocab maps token -> idx in feature_names
53
- token_to_shap = {}
54
- for idx, fname in enumerate(feature_names):
55
- # Often fname is the token/ngram itself
56
- token_to_shap[fname] = shap_values_feature[idx]
57
-
58
- # Tokenize (simple)
59
- tokens = original_text.split()
60
- # Compute max magnitude for scaling opacity
61
- mags = []
62
- for t in tokens:
63
- key = t.lower()
64
- val = None
65
- # Try several common variants: exact, lower, strip punctuation from ends
66
- if key in vectorizer_vocab:
67
- val = shap_values_feature[vectorizer_vocab[key]]
68
- else:
69
- key2 = ''.join(ch for ch in key if ch.isalnum())
70
- if key2 in vectorizer_vocab:
71
- val = shap_values_feature[vectorizer_vocab[key2]]
72
- mags.append(abs(val) if val is not None else 0.0)
73
- max_mag = max(mags) if mags else 1.0
74
- if max_mag == 0:
75
- max_mag = 1.0
76
-
77
- # Build HTML with span coloring
78
- html_tokens = []
79
- for t in tokens:
80
- display = html.escape(t)
81
- key = t.lower()
82
- val = None
83
- if key in vectorizer_vocab:
84
- val = shap_values_feature[vectorizer_vocab[key]]
85
- else:
86
- key2 = ''.join(ch for ch in key if ch.isalnum())
87
- if key2 in vectorizer_vocab:
88
- val = shap_values_feature[vectorizer_vocab[key2]]
89
- if val is None or abs(val) < 1e-6:
90
- html_tokens.append(f"<span style='padding:2px'>{display}</span>")
91
- else:
92
- sign = "pos" if val > 0 else "neg"
93
- mag = min(1.0, abs(val) / max_mag) # scale 0..1
94
- opacity = 0.15 + 0.85 * mag # avoid fully transparent
95
- color = f"rgba(220,20,60,{opacity})" if sign == "pos" else f"rgba(30,144,255,{opacity})"
96
- border = "1px solid rgba(0,0,0,0.04)"
97
- html_tokens.append(
98
- f"<span style='background:{color};padding:2px;margin:1px;border-radius:4px;display:inline-block;{border}'>"
99
- f"{display}</span>"
100
- )
101
-
102
- highlighted_html = "<div style='line-height:1.6;font-size:16px'>" + " ".join(html_tokens) + "</div>"
103
- return highlighted_html
104
-
105
- # --------- Core function: predict + interpret ----------
106
- def explain_and_predict(text: str, language: str):
107
- text = text or ""
108
  if language == "English":
109
- model = eng_model
110
- vectorizer = eng_vectorizer
111
- class_names = CLASS_NAMES_EN
112
  else:
113
- model = per_model
114
- vectorizer = per_vectorizer
115
- class_names = CLASS_NAMES_PER
116
-
117
- if text.strip() == "":
118
- return "⚠️ Please enter text.", {}, {"Word": [], "SHAP Impact": []}, "<i>No input</i>"
119
-
120
- # vectorize
121
- vec = vectorizer.transform([text])
122
- probs = model.predict_proba(vec)[0]
123
- pred_class = int(np.argmax(probs))
124
- label = class_names[pred_class]
125
- confidence = float(probs[pred_class])
126
-
127
- # Build SHAP explainer on a small background (use small subset via dummy background)
128
- # NOTE: building explainer can be slow; in Spaces you can build once at import
129
- # For robustness we build a simple LinearExplainer on vector space
130
- # Use small dense sample from training if available - here use vectorizer vocabulary size fallback
131
- # Convert to dense for LinearExplainer
132
- try:
133
- # Use a small background of zeros (cheap) — LinearExplainer can accept arrays
134
- background = np.zeros((1, vec.shape[1]))
135
- explainer = shap.LinearExplainer(model, background, feature_names=vectorizer.get_feature_names_out())
136
- # compute shap on the numeric vector
137
- vec_dense = vec.toarray()
138
- shap_vals = explainer(vec_dense) # returns shap.Explanation
139
- except Exception:
140
- # fallback: use PermutationExplainer on numeric input (slower)
141
- explainer = shap.Explainer(model.predict_proba, vec)
142
- shap_vals = explainer(vec)
143
-
144
- # shap_vals.values shape: (n_outputs, n_features) OR Explanation with values (n_features, n_classes)
145
- # Normalize to feature vector for chosen class
146
- # shap_vals may be multi-output: shap_vals.values => (n_samples, n_features, n_classes) or similar
147
- try:
148
- # preferred shape: shap_vals.values -> (1, n_features, n_classes)
149
- values = shap_vals.values # ND array
150
- if values.ndim == 3:
151
- # pick sample 0, class pred_class
152
- shap_per_feature = values[0, :, pred_class]
153
- elif values.ndim == 2:
154
- # shape (n_samples, n_features) for single class models — take sample 0
155
- shap_per_feature = values[0, :]
156
- else:
157
- # try to flatten
158
- shap_per_feature = np.ravel(values)[0:vec.shape[1]]
159
- except Exception:
160
- # Last resort: try shap_vals[0].values
161
- try:
162
- shap_per_feature = shap_vals[0].values[:, pred_class]
163
- except Exception:
164
- shap_per_feature = np.zeros(vec.shape[1])
165
-
166
- # Feature names & vocab
167
- feature_names = np.array(vectorizer.get_feature_names_out())
168
- vocab = {k: v for k, v in (getattr(vectorizer, "vocabulary_", {})).items()}
169
-
170
- # Build top contributing words list (pairs)
171
- # shap_per_feature length must match len(feature_names)
172
- if len(shap_per_feature) != len(feature_names):
173
- # try to align by vectorizer.vocabulary_
174
- full_shap = np.zeros(len(feature_names))
175
- # if shap_per_feature smaller, attempt to use indices from vocab
176
- min_len = min(len(shap_per_feature), len(full_shap))
177
- full_shap[:min_len] = shap_per_feature[:min_len]
178
- shap_per_feature = full_shap
179
-
180
- # Top positive and negative features
181
- n = 10
182
- idx_sorted = np.argsort(-np.abs(shap_per_feature))
183
- top_idx = idx_sorted[:n]
184
- top_words = feature_names[top_idx].tolist()
185
- top_contribs = shap_per_feature[top_idx].tolist()
186
-
187
- # Build word table for display
188
- word_table = {"Word": top_words, "SHAP Impact": top_contribs}
189
-
190
- # Build highlight HTML (token-level approx using unigram mapping)
191
- highlight_html = make_html_highlight(text, feature_names, shap_per_feature, vocab)
192
-
193
- # Return: label string, probabilities dict, table dict, html highlight
194
- return f"🎯 **{label}** (confidence: {confidence:.2f})", probs_to_bar(probs.tolist(), language), word_table, highlight_html
195
-
196
-
197
- # --------- Gradio UI build ----------
198
- with gr.Blocks() as demo:
199
- gr.Markdown("## 🌍 Multilingual Sentiment Analysis (English 🇬🇧 & Persian 🇮🇷) — Interpretable")
200
- with gr.Row():
201
- language = gr.Radio(["English", "Persian"], value="English", label="Choose language")
202
- text_input = gr.Textbox(lines=4, placeholder="Type comment here...", label="Input text")
203
- with gr.Row():
204
- btn = gr.Button("Analyze")
205
- with gr.Row():
206
- pred_out = gr.Markdown()
207
- with gr.Row():
208
- bar = gr.BarPlot(label="Class probabilities")
209
- table = gr.Dataframe(headers=["Word", "SHAP Impact"], label="Top contributing words")
210
- with gr.Row():
211
- html_out = gr.HTML(label="Word-level Highlight (red = pushes toward prediction, blue = pushes away)")
212
-
213
- def run(text, lang):
214
- label, probs, word_table, html_highlight = explain_and_predict(text, lang)
215
- # format outputs for gradio
216
- return label, probs, pd.DataFrame(word_table), html_highlight
217
-
218
- btn.click(fn=run, inputs=[text_input, language], outputs=[pred_out, bar, table, html_out])
219
 
220
- if __name__ == "__main__":
221
- demo.launch(server_name="0.0.0.0", share=True)
 
1
+ import gradio as gr
 
 
2
  import joblib
 
 
3
  import shap
4
+ import numpy as np
5
  import matplotlib.pyplot as plt
6
+ import tempfile
7
+ import os
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # ---------------------------------------------------------
10
+ # Load both models and vectorizers
11
+ # ---------------------------------------------------------
12
+ english_model = joblib.load("models/english_model.pkl")
13
+ english_vec = joblib.load("models/english_vectorizer.pkl")
14
 
15
+ persian_model = joblib.load("models/persian_model.pkl")
16
+ persian_vec = joblib.load("models/persian_vectorizer.pkl")
17
 
18
+ class_names = ["Negative", "Neutral", "Positive"]
 
19
 
20
+ # ---------------------------------------------------------
21
+ # Prediction + Interpretability Function
22
+ # ---------------------------------------------------------
23
+ def predict_sentiment(text, language):
24
+ if not text.strip():
25
+ return "Please enter text!", None
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  if language == "English":
28
+ model, vec = english_model, english_vec
 
 
29
  else:
30
+ model, vec = persian_model, persian_vec
31
+
32
+ X = vec.transform([text])
33
+ probs = model.predict_proba(X)[0]
34
+ pred_idx = np.argmax(probs)
35
+ label = class_names[pred_idx]
36
+
37
+ # --- SHAP interpretability ---
38
+ explainer = shap.LinearExplainer(model, vec.transform([text]))
39
+ shap_vals = explainer(X)
40
+ shap_values = shap_vals.values[0][:, pred_idx]
41
+ feature_names = vec.get_feature_names_out()
42
+
43
+ top_idx = np.argsort(-abs(shap_values))[:10]
44
+ tokens = [feature_names[i] for i in top_idx]
45
+ impacts = [shap_values[i] for i in top_idx]
46
+
47
+ # Save temporary bar chart
48
+ fig, ax = plt.subplots(figsize=(6, 3))
49
+ colors = ["crimson" if v > 0 else "steelblue" for v in impacts]
50
+ ax.barh(tokens, impacts, color=colors)
51
+ ax.invert_yaxis()
52
+ ax.set_title(f"Top Words driving {label} prediction")
53
+ tmp_path = tempfile.mktemp(suffix=".png")
54
+ plt.tight_layout()
55
+ plt.savefig(tmp_path)
56
+ plt.close(fig)
57
+
58
+ explanation = f"""
59
+ **Predicted Sentiment:** {label}\n
60
+ **Confidence:** {probs[pred_idx]:.2f}\n
61
+ **Top Influential Words:**\n
62
+ {', '.join(tokens)}
63
+ """
64
+ return explanation, tmp_path
65
+
66
+ # ---------------------------------------------------------
67
+ # Gradio UI
68
+ # ---------------------------------------------------------
69
+ iface = gr.Interface(
70
+ fn=predict_sentiment,
71
+ inputs=[
72
+ gr.Textbox(lines=3, label="Enter comment"),
73
+ gr.Radio(["English", "Persian"], label="Choose Dataset/Language")
74
+ ],
75
+ outputs=[
76
+ gr.Markdown(label="Prediction + Interpretation"),
77
+ gr.Image(label="Top Word Contributions")
78
+ ],
79
+ title="🌍 Multi-Lingual Sentiment Analysis (English + Persian)",
80
+ description="Select a language, type a comment, and see both the prediction and SHAP interpretability."
81
+ )
82
+
83
+ iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84