monsimas committed
Commit c4d21e2 · verified · 1 Parent(s): 9bd9e1e

Update app.py

Files changed (1)
  1. app.py +175 -79
app.py CHANGED
@@ -6,7 +6,12 @@ from typing import Dict, List, Tuple
 from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
 import numpy as np
 import re
+import time
+import matplotlib.pyplot as plt
+import matplotlib
+matplotlib.use('Agg')

+# Constants
 AVAILABLE_MODELS = [
     "llama-3.3-70b-instruct",
     "llama-3.1-70b-instruct",
@@ -17,6 +22,7 @@ AVAILABLE_MODELS = [
     "deepseek-r1-distill-llama-70b"
 ]

+# File and column names
 CSV_PATH = "evaluation.csv"
 TEXT_COLUMN = "Contribution"
 LABEL_COLUMN = "Etat"
@@ -43,24 +49,25 @@ def create_client(api_key: str) -> OpenAI:
     )

 def parse_model_output(output: str) -> str:
-    """Parse and normalize model output to match expected labels."""
-    cleaned = output.strip()
+    """Parse and normalize model output to match expected labels with improved pattern matching."""
+    # Store original output for transparency
+    cleaned = output.strip().lower()

-    if cleaned == "SPAM":
+    # Enhanced pattern matching with regex
+    if re.search(r'\bspam\b', cleaned) and not re.search(r'\bnot\s+spam\b|\bpas\s+spam\b', cleaned):
         return "Spam"
-    elif cleaned == "NOT_SPAM":
-        return "Pas spam"
-
-    cleaned_lower = cleaned.lower().replace('_', ' ')
+    elif re.search(r'\bnot[\s_-]*spam\b|\bpas[\s_-]*spam\b|\blegitimate\b|\bham\b|\bclean\b', cleaned):
+        return "Non spam"

-    if cleaned_lower in ['spam', 'yes', 'true', 'is spam']:
+    # Additional backup checks for specific formats
+    if cleaned == "spam":
         return "Spam"
-    elif 'not spam' in cleaned_lower or cleaned_lower in ['no', 'false', 'clean', 'ham', 'legitimate']:
+    elif cleaned in ["not_spam", "not spam", "pas spam"]:
         return "Pas spam"
-    else:
-        # Log unexpected responses for debugging
-        print(f"Warning: Unexpected model output: {output}")
-        return "Pas spam"  # Default to not spam for unrecognized responses
+
+    # Log unexpected responses and default to not spam
+    print(f"Warning: Unexpected model output: {output}")
+    return "Pas spam"  # Default to not spam for unrecognized responses

 def process_single_text(
     text: str,
@@ -70,12 +77,14 @@ def process_single_text(
     max_tokens: int,
     top_p: float,
     api_key: str
-) -> Tuple[str, str]:
-    """Process a single text input through the model."""
+) -> Tuple[str, str, float]:
+    """Process a single text input through the model and measure response time."""
     client = create_client(api_key)

+    # Format the prompt
     formatted_prompt = prompt_template.format(text=text)

+    start_time = time.time()
     try:
         response = client.chat.completions.create(
             model=model,
@@ -91,9 +100,14 @@
         )
         raw_output = response.choices[0].message.content.strip()
         parsed_output = parse_model_output(raw_output)
-        return raw_output, parsed_output
+
+        # Calculate response time
+        response_time = time.time() - start_time
+
+        return raw_output, parsed_output, response_time
     except Exception as e:
-        return f"Error: {str(e)}", "Pas spam"
+        response_time = time.time() - start_time
+        return f"Error: {str(e)}", "Pas spam", response_time

 def evaluate_performance(
     df: pd.DataFrame,
@@ -116,23 +130,83 @@ def evaluate_performance(
     # Convert any numpy values to Python floats
     return {k: float(v) if isinstance(v, (np.floating, np.integer)) else v for k, v in metrics.items()}

+def create_metrics_plot(metrics: Dict[str, float]) -> plt.Figure:
+    """Create a bar chart visualization of metrics."""
+    fig, ax = plt.subplots(figsize=(10, 6))
+
+    # Extract metrics excluding avg_response_time for performance bar chart
+    perf_metrics = {k: v for k, v in metrics.items() if k != 'avg_response_time'}
+
+    metrics_names = list(perf_metrics.keys())
+    metrics_values = list(perf_metrics.values())
+
+    bars = ax.bar(metrics_names, metrics_values, color='skyblue')
+
+    # Add value labels on top of bars
+    for bar in bars:
+        height = bar.get_height()
+        ax.annotate(f'{height:.3f}',
+                    xy=(bar.get_x() + bar.get_width() / 2, height),
+                    xytext=(0, 3),  # 3 points vertical offset
+                    textcoords="offset points",
+                    ha='center', va='bottom')
+
+    ax.set_ylim(0, 1.0)
+    ax.set_title('Model Performance Metrics')
+    ax.set_ylabel('Score')
+
+    plt.tight_layout()
+    return fig
+
+def create_confusion_matrix_plot(df: pd.DataFrame) -> plt.Figure:
+    """Create a confusion matrix visualization."""
+    from sklearn.metrics import confusion_matrix
+    import seaborn as sns
+
+    # Get true and predicted labels
+    y_true = [1 if label == "Spam" else 0 for label in df[LABEL_COLUMN]]
+    y_pred = [1 if pred == "Spam" else 0 for pred in df['model_prediction']]
+
+    # Create confusion matrix
+    cm = confusion_matrix(y_true, y_pred)
+
+    # Plot
+    fig, ax = plt.subplots(figsize=(8, 6))
+    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
+                xticklabels=['Not Spam', 'Spam'],
+                yticklabels=['Not Spam', 'Spam'])
+
+    ax.set_title('Confusion Matrix')
+    ax.set_ylabel('True Label')
+    ax.set_xlabel('Predicted Label')
+
+    plt.tight_layout()
+    return fig
+
 def process_benchmark(
     prompt_template: str,
     model: str,
     temperature: float,
     max_tokens: int,
     top_p: float,
-    api_key: str
-) -> Tuple[pd.DataFrame, Dict[str, float]]:
-    """Process benchmark dataset and return results with metrics."""
+    api_key: str,
+    progress=None
+) -> Tuple[pd.DataFrame, Dict[str, float], plt.Figure, plt.Figure]:
+    """Process benchmark dataset and return results with metrics and visualizations."""
     # Read CSV file
     df = pd.read_csv(CSV_PATH)

     # Process each text
     raw_predictions = []
     parsed_predictions = []
-    for text in df[TEXT_COLUMN]:
-        raw_output, parsed_output = process_single_text(
+    response_times = []
+
+    total = len(df)
+    for i, text in enumerate(df[TEXT_COLUMN]):
+        if progress is not None:
+            progress(i / total, f"Processing {i+1}/{total}")
+
+        raw_output, parsed_output, response_time = process_single_text(
             text,
             prompt_template,
             model,
@@ -143,70 +217,85 @@ def process_benchmark(
         )
         raw_predictions.append(raw_output)
        parsed_predictions.append(parsed_output)
+        response_times.append(response_time)

     # Add predictions to DataFrame
     df['model_raw_output'] = raw_predictions
     df['model_prediction'] = parsed_predictions
+    df['response_time'] = response_times

     # Calculate metrics
     metrics = evaluate_performance(df, parsed_predictions)

-    return df, metrics
+    # Add average response time metric
+    metrics['avg_response_time'] = sum(response_times) / len(response_times)
+
+    # Create visualizations
+    metrics_plot = create_metrics_plot(metrics)
+    confusion_matrix_plot = create_confusion_matrix_plot(df)
+
+    return df, metrics, metrics_plot, confusion_matrix_plot

 def create_interface():
-    """Create Gradio interface."""
-    with gr.Blocks() as interface:
+    """Create Gradio interface with enhanced UI and visualizations."""
+    with gr.Blocks(theme=gr.themes.Soft()) as interface:
         gr.Markdown("# Moderation Model Testing Interface")

-        with gr.Row():
-            with gr.Column():
-                api_key = gr.Textbox(
-                    label="Scaleway API Key",
-                    placeholder="Enter your API key",
-                    type="password"
-                )
-                model = gr.Dropdown(
-                    choices=AVAILABLE_MODELS,
-                    label="Model",
-                    value=AVAILABLE_MODELS[0]
-                )
-                prompt = gr.Textbox(
-                    label="Prompt Template",
-                    value=DEFAULT_PROMPT,
-                    lines=5
-                )
-
-            with gr.Column():
-                temperature = gr.Slider(
-                    minimum=0,
-                    maximum=1,
-                    value=0.3,
-                    label="Temperature"
-                )
-                max_tokens = gr.Slider(
-                    minimum=1,
-                    maximum=2048,
-                    value=512,
-                    step=1,
-                    label="Max Tokens"
-                )
-                top_p = gr.Slider(
-                    minimum=0,
-                    maximum=1,
-                    value=1,
-                    label="Top P"
-                )
-
-        run_button = gr.Button("Run Benchmark")
-
-        with gr.Row():
-            with gr.Column():
-                results_df = gr.Dataframe(
-                    label="Results",
-                    headers=[TEXT_COLUMN, LABEL_COLUMN, "Raw Model Output", "Model Prediction"]
-                )
-            with gr.Column():
-                metrics_json = gr.JSON(label="Performance Metrics")
+        with gr.Tabs():
+            with gr.TabItem("Model Configuration"):
+                with gr.Row():
+                    with gr.Column():
+                        api_key = gr.Textbox(
+                            label="Scaleway API Key",
+                            placeholder="Enter your API key",
+                            type="password"
+                        )
+                        model = gr.Dropdown(
+                            choices=AVAILABLE_MODELS,
+                            label="Model",
+                            value=AVAILABLE_MODELS[0]
+                        )
+                        prompt = gr.Textbox(
+                            label="Prompt Template",
+                            value=DEFAULT_PROMPT,
+                            lines=5
+                        )
+
+                    with gr.Column():
+                        temperature = gr.Slider(
+                            minimum=0,
+                            maximum=1,
+                            value=0.3,
+                            label="Temperature"
+                        )
+                        max_tokens = gr.Slider(
+                            minimum=1,
+                            maximum=2048,
+                            value=512,
+                            step=1,
+                            label="Max Tokens"
+                        )
+                        top_p = gr.Slider(
+                            minimum=0,
+                            maximum=1,
+                            value=1,
+                            label="Top P"
+                        )
+                run_button = gr.Button("Run Benchmark", variant="primary")
+
+            with gr.TabItem("Results"):
+                with gr.Row():
+                    with gr.Column(scale=2):
+                        results_df = gr.Dataframe(
+                            label="Results Table",
+                            headers=[TEXT_COLUMN, LABEL_COLUMN, "Raw Model Output", "Model Prediction", "Response Time (s)"]
+                        )
+                    with gr.Column(scale=1):
+                        metrics_json = gr.JSON(label="Performance Metrics")
+
+                with gr.Row():
+                    metrics_plot = gr.Plot(label="Performance Metrics Visualization")
+                    confusion_matrix_vis = gr.Plot(label="Confusion Matrix")

         def run_benchmark_fn(
             prompt,
@@ -214,17 +303,24 @@ def create_interface():
             temperature,
             max_tokens,
             top_p,
-            api_key
+            api_key,
+            progress=gr.Progress()
         ):
-            df, metrics = process_benchmark(
+            df, metrics, metrics_vis, confusion_vis = process_benchmark(
                 prompt,
                 model,
                 temperature,
                 max_tokens,
                 top_p,
-                api_key
+                api_key,
+                progress
             )
-            return df, metrics
+            # Format dataframe for display
+            display_df = df[[TEXT_COLUMN, LABEL_COLUMN, 'model_raw_output', 'model_prediction', 'response_time']].copy()
+            # Format response time to 3 decimal places
+            display_df['response_time'] = display_df['response_time'].apply(lambda x: f"{x:.3f}")
+
+            return display_df, metrics, metrics_vis, confusion_vis

         run_button.click(
             run_benchmark_fn,
@@ -236,7 +332,7 @@ def create_interface():
                 top_p,
                 api_key
             ],
-            outputs=[results_df, metrics_json]
+            outputs=[results_df, metrics_json, metrics_plot, confusion_matrix_vis]
         )

    return interface
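
Note: the diff does not include the script's entry point, so the launch step is not shown. Assuming app.py keeps a conventional __main__ guard (an assumption, not part of commit c4d21e2), the Blocks app returned by create_interface() would typically be served along these lines; a minimal sketch, not the committed code:

# Hypothetical entry point, shown only to illustrate how the interface above is usually launched.
if __name__ == "__main__":
    demo = create_interface()   # build the Gradio Blocks app defined in the diff
    demo.launch()               # start the local Gradio server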