jeevitha-app commited on
Commit
0567c9f
·
verified ·
1 Parent(s): 255e4d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -55
app.py CHANGED
@@ -1,84 +1,114 @@
 
 
 
 
 
1
  import gradio as gr
2
  import joblib
3
- import shap
4
  import numpy as np
 
5
  import matplotlib.pyplot as plt
6
- import tempfile
7
  import os
8
 
9
- # ---------------------------------------------------------
10
- # Load both models and vectorizers
11
- # ---------------------------------------------------------
12
- english_model = joblib.load("best_model.pkl")
13
- english_vec = joblib.load("tfidf_vectorizer.pkl")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- persian_model = joblib.load("logistic_regression.pkl")
16
- persian_vec = joblib.load("tfidf_vectorizer_persian.pkl")
 
 
 
 
 
 
 
17
 
18
- class_names = ["Negative", "Neutral", "Positive"]
19
 
20
- # ---------------------------------------------------------
21
- # Prediction + Interpretability Function
22
- # ---------------------------------------------------------
23
  def predict_sentiment(text, language):
24
  if not text.strip():
25
- return "Please enter text!", None
26
 
27
  if language == "English":
28
- model, vec = english_model, english_vec
29
  else:
30
- model, vec = persian_model, persian_vec
31
 
32
- X = vec.transform([text])
33
- probs = model.predict_proba(X)[0]
34
  pred_idx = np.argmax(probs)
35
- label = class_names[pred_idx]
36
-
37
- # --- SHAP interpretability ---
38
- explainer = shap.LinearExplainer(model, vec.transform([text]))
39
- shap_vals = explainer(X)
40
- shap_values = shap_vals.values[0][:, pred_idx]
41
- feature_names = vec.get_feature_names_out()
42
-
43
- top_idx = np.argsort(-abs(shap_values))[:10]
44
- tokens = [feature_names[i] for i in top_idx]
45
- impacts = [shap_values[i] for i in top_idx]
46
-
47
- # Save temporary bar chart
48
- fig, ax = plt.subplots(figsize=(6, 3))
49
- colors = ["crimson" if v > 0 else "steelblue" for v in impacts]
50
- ax.barh(tokens, impacts, color=colors)
51
- ax.invert_yaxis()
52
- ax.set_title(f"Top Words driving {label} prediction")
53
- tmp_path = tempfile.mktemp(suffix=".png")
54
- plt.tight_layout()
55
- plt.savefig(tmp_path)
56
- plt.close(fig)
57
 
 
 
 
 
58
  explanation = f"""
59
- **Predicted Sentiment:** {label}\n
60
- **Confidence:** {probs[pred_idx]:.2f}\n
61
- **Top Influential Words:**\n
62
- {', '.join(tokens)}
63
  """
64
- return explanation, tmp_path
65
 
66
- # ---------------------------------------------------------
67
- # Gradio UI
68
- # ---------------------------------------------------------
69
- iface = gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  fn=predict_sentiment,
71
  inputs=[
72
  gr.Textbox(lines=3, label="Enter comment"),
73
- gr.Radio(["English", "Persian"], label="Choose Dataset/Language")
74
  ],
75
  outputs=[
76
- gr.Markdown(label="Prediction + Interpretation"),
77
  gr.Image(label="Top Word Contributions")
78
  ],
79
- title="🌍 Multi-Lingual Sentiment Analysis (English + Persian)",
80
- description="Select a language, type a comment, and see both the prediction and SHAP interpretability."
 
81
  )
82
 
83
- iface.launch()
84
-
 
1
+ # ============================================================
2
+ # 🌍 Multi-Lingual Sentiment Analysis (English + Persian)
3
+ # With SHAP Interpretability
4
+ # ============================================================
5
+
6
  import gradio as gr
7
  import joblib
 
8
  import numpy as np
9
+ import shap
10
  import matplotlib.pyplot as plt
 
11
  import os
12
 
13
+ # ------------------------------------------------------------
14
+ # 1️⃣ Load Pretrained Models and Vectorizers
15
+ # ------------------------------------------------------------
16
+ english_model = joblib.load("english_model.pkl")
17
+ english_vectorizer = joblib.load("english_vectorizer.pkl")
18
+
19
+ persian_model = joblib.load("persian_model.pkl")
20
+ persian_vectorizer = joblib.load("persian_vectorizer.pkl")
21
+
22
+ # Define class labels
23
+ english_labels = ["Negative", "Neutral", "Positive"]
24
+ persian_labels = ["منفی", "خنثی", "مثبت"]
25
+
26
+ # ------------------------------------------------------------
27
+ # 2️⃣ SHAP Visualization Function
28
+ # ------------------------------------------------------------
29
+ def get_shap_plot(model, vectorizer, text, class_index, class_name):
30
+ X_input = vectorizer.transform([text])
31
+ explainer = shap.Explainer(model, vectorizer.transform([" ".join(text.split()[:50])]))
32
+ shap_values = explainer(X_input)
33
+ shap_for_class = shap_values.values[0][:, class_index]
34
+ feature_names = np.array(vectorizer.get_feature_names_out())
35
+
36
+ top_idx = np.argsort(-np.abs(shap_for_class))[:10]
37
+ top_words = feature_names[top_idx]
38
+ top_impacts = shap_for_class[top_idx]
39
 
40
+ plt.figure(figsize=(6, 3))
41
+ colors = ["crimson" if v > 0 else "steelblue" for v in top_impacts]
42
+ plt.barh(top_words, top_impacts, color=colors)
43
+ plt.title(f"Top Words driving {class_name} prediction")
44
+ plt.xlabel("SHAP Value (Impact)")
45
+ plt.gca().invert_yaxis()
46
+ plt.tight_layout()
47
+ plt.savefig("shap_plot.png", bbox_inches='tight')
48
+ plt.close()
49
 
50
+ return top_words.tolist(), "shap_plot.png"
51
 
52
+ # ------------------------------------------------------------
53
+ # 3️⃣ Prediction + Interpretability Function
54
+ # ------------------------------------------------------------
55
  def predict_sentiment(text, language):
56
  if not text.strip():
57
+ return "Please enter a comment.", None
58
 
59
  if language == "English":
60
+ model, vectorizer, labels = english_model, english_vectorizer, english_labels
61
  else:
62
+ model, vectorizer, labels = persian_model, persian_vectorizer, persian_labels
63
 
64
+ X_input = vectorizer.transform([text])
65
+ probs = model.predict_proba(X_input)[0]
66
  pred_idx = np.argmax(probs)
67
+ pred_class = labels[pred_idx]
68
+ conf = probs[pred_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ # SHAP interpretation
71
+ top_words, shap_plot = get_shap_plot(model, vectorizer, text, pred_idx, pred_class)
72
+
73
+ # Final output
74
  explanation = f"""
75
+ **Predicted Sentiment:** {pred_class}
76
+ **Confidence:** {conf:.2f}
77
+ **Top Influential Words:** {', '.join(top_words)}
 
78
  """
79
+ return explanation, shap_plot
80
 
81
+ # ------------------------------------------------------------
82
+ # 4️⃣ Gradio Interface
83
+ # ------------------------------------------------------------
84
+ title = "🌐 Multi-Lingual Sentiment Analysis (English + Persian)"
85
+ description = """
86
+ Select a language, type a comment, and see both the sentiment prediction and SHAP interpretability.
87
+ """
88
+
89
+ examples = [
90
+ ["I love this product! Highly recommend.", "English"],
91
+ ["Worst experience ever, totally disappointed.", "English"],
92
+ ["The service was okay, nothing special.", "English"],
93
+ ["این محصول فوق‌العاده است", "Persian"],
94
+ ["تجربه‌ی بدی بود، ناراضی‌ام", "Persian"],
95
+ ["کیفیتش متوسط بود", "Persian"]
96
+ ]
97
+
98
+ demo = gr.Interface(
99
  fn=predict_sentiment,
100
  inputs=[
101
  gr.Textbox(lines=3, label="Enter comment"),
102
+ gr.Radio(["English", "Persian"], label="Choose Dataset/Language", value="English")
103
  ],
104
  outputs=[
105
+ gr.Markdown(label="Prediction & Explanation"),
106
  gr.Image(label="Top Word Contributions")
107
  ],
108
+ title=title,
109
+ description=description,
110
+ examples=examples,
111
  )
112
 
113
+ if __name__ == "__main__":
114
+ demo.launch()