import gradio as gr
import pandas as pd
from openai import OpenAI
import json
from typing import Dict, List, Tuple
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
import re
import time
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
# Constants
AVAILABLE_MODELS = [
"llama-3.3-70b-instruct",
"llama-3.1-70b-instruct",
"llama-3.1-8b-instruct",
"mistral-nemo-instruct-2407",
"qwen2.5-coder-32b-instruct",
"deepseek-r1",
"deepseek-r1-distill-llama-70b"
]
# File and column names
CSV_PATH = "evaluation.csv"
TEXT_COLUMN = "Contribution"
LABEL_COLUMN = "Etat"
DEFAULT_PROMPT = """You are a content moderation assistant for a participatory democracy platform. Your task is to identify and classify potential spam content, with particular attention to black hat SEO backlink attempts.
Key indicators of spam include:
- Hidden or excessive links that seem aimed at manipulating search rankings
- Generic or unrelated content with embedded commercial links
- Repetitive posting of similar content with backlinks
- Content that appears to be aimed at search engines rather than human readers
- Links to questionable or unrelated commercial websites
Classify the following contribution as either SPAM or NOT_SPAM.
Respond with just one word - either SPAM or NOT_SPAM.
Text to classify: {text}"""
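# Illustrative sketch (not called by the app): how the {text} placeholder in DEFAULT_PROMPT
# is filled via str.format before the prompt is sent to the model (see process_single_text).
# The sample contribution string below is an assumption used only for demonstration.
def _example_formatted_prompt() -> str:
    sample_text = "Great post! Check out http://example.com for cheap watches."  # hypothetical contribution
    return DEFAULT_PROMPT.format(text=sample_text)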
def create_client(api_key: str) -> OpenAI:
"""Create OpenAI client with Scaleway configuration."""
return OpenAI(
base_url="https://api.scaleway.ai/fc37904f-3555-4f58-b79c-041615b7001a/v1",
api_key=api_key
)
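# Usage sketch (assumption: the Scaleway secret key is exported as SCW_SECRET_KEY;
# this helper is illustrative and not called by the app itself).
def _client_from_env() -> OpenAI:
    import os
    return create_client(os.environ["SCW_SECRET_KEY"])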
def parse_model_output(output: str) -> str:
    """Parse and normalize model output to the dataset labels ("Spam" / "Pas spam")."""
    cleaned = output.strip().lower()
    # Check negative forms first so "not spam", "not_spam", "not-spam" and "pas spam"
    # never fall through to the spam branch
    if re.search(r'\bnot[\s_-]*spam\b|\bpas[\s_-]*spam\b|\blegitimate\b|\bham\b|\bclean\b', cleaned):
        return "Pas spam"
    if re.search(r'\bspam\b', cleaned):
        return "Spam"
    # Log unexpected responses and default to the non-spam label
    print(f"Warning: unexpected model output: {output}")
    return "Pas spam"
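# Illustrative sketch (not called by the app): how a few typical model replies map onto
# the dataset labels. The sample replies are assumptions, not real model outputs.
def _demo_parse_model_output() -> None:
    for reply in ["SPAM", "NOT_SPAM", "This contribution is not spam.", "unclear answer"]:
        print(f"{reply!r} -> {parse_model_output(reply)}")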
def process_single_text(
text: str,
prompt_template: str,
model: str,
temperature: float,
max_tokens: int,
top_p: float,
api_key: str
) -> Tuple[str, str, float]:
"""Process a single text input through the model and measure response time."""
client = create_client(api_key)
# Format the prompt
formatted_prompt = prompt_template.format(text=text)
start_time = time.time()
try:
response = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": formatted_prompt},
],
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
presence_penalty=0,
stream=False
)
raw_output = response.choices[0].message.content.strip()
parsed_output = parse_model_output(raw_output)
# Calculate response time
response_time = time.time() - start_time
return raw_output, parsed_output, response_time
except Exception as e:
response_time = time.time() - start_time
return f"Error: {str(e)}", "Pas spam", response_time
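# Usage sketch (assumptions: api_key is a valid Scaleway key and the sampling values mirror
# the UI defaults; this helper is illustrative and not called by the app).
def _classify_one(text: str, api_key: str) -> str:
    raw, parsed, seconds = process_single_text(
        text, DEFAULT_PROMPT, AVAILABLE_MODELS[0],
        temperature=0.3, max_tokens=512, top_p=1.0, api_key=api_key
    )
    print(f"raw={raw!r} parsed={parsed!r} time={seconds:.2f}s")
    return parsed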
def evaluate_performance(
df: pd.DataFrame,
predictions: List[str]
) -> Dict[str, float]:
"""Calculate performance metrics."""
y_true = df[LABEL_COLUMN].tolist()
    # Convert string labels to binary (tolerant of case/whitespace in the ground-truth column)
    y_true_binary = [1 if str(label).strip().lower() == "spam" else 0 for label in y_true]
y_pred_binary = [1 if pred == "Spam" else 0 for pred in predictions]
metrics = {
"accuracy": accuracy_score(y_true_binary, y_pred_binary),
"precision": precision_score(y_true_binary, y_pred_binary, zero_division=0),
"recall": recall_score(y_true_binary, y_pred_binary, zero_division=0),
"f1": f1_score(y_true_binary, y_pred_binary, zero_division=0)
}
# Convert any numpy values to Python floats
return {k: float(v) if isinstance(v, (np.floating, np.integer)) else v for k, v in metrics.items()}
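# Illustrative sketch (not called by the app): metrics on a tiny hand-made frame. The label
# values assume the CSV uses "Spam" / "Pas spam" in the Etat column.
def _demo_evaluate_performance() -> Dict[str, float]:
    demo_df = pd.DataFrame({LABEL_COLUMN: ["Spam", "Pas spam", "Spam", "Pas spam"]})
    demo_preds = ["Spam", "Pas spam", "Pas spam", "Pas spam"]
    return evaluate_performance(demo_df, demo_preds)  # e.g. accuracy 0.75, recall 0.5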
def create_metrics_plot(metrics: Dict[str, float]) -> plt.Figure:
"""Create a bar chart visualization of metrics."""
fig, ax = plt.subplots(figsize=(10, 6))
# Extract metrics excluding avg_response_time for performance bar chart
perf_metrics = {k: v for k, v in metrics.items() if k != 'avg_response_time'}
metrics_names = list(perf_metrics.keys())
metrics_values = list(perf_metrics.values())
bars = ax.bar(metrics_names, metrics_values, color='skyblue')
# Add value labels on top of bars
for bar in bars:
height = bar.get_height()
ax.annotate(f'{height:.3f}',
xy=(bar.get_x() + bar.get_width() / 2, height),
xytext=(0, 3), # 3 points vertical offset
textcoords="offset points",
ha='center', va='bottom')
ax.set_ylim(0, 1.0)
ax.set_title('Model Performance Metrics')
ax.set_ylabel('Score')
plt.tight_layout()
return fig
def create_confusion_matrix_plot(df: pd.DataFrame) -> plt.Figure:
"""Create a confusion matrix visualization."""
from sklearn.metrics import confusion_matrix
import seaborn as sns
    # Get true and predicted labels (same binarization as evaluate_performance)
    y_true = [1 if str(label).strip().lower() == "spam" else 0 for label in df[LABEL_COLUMN]]
    y_pred = [1 if pred == "Spam" else 0 for pred in df['model_prediction']]
# Create confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Plot
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
xticklabels=['Not Spam', 'Spam'],
yticklabels=['Not Spam', 'Spam'])
ax.set_title('Confusion Matrix')
ax.set_ylabel('True Label')
ax.set_xlabel('Predicted Label')
plt.tight_layout()
return fig
def process_benchmark(
prompt_template: str,
model: str,
temperature: float,
max_tokens: int,
top_p: float,
api_key: str,
progress=None
) -> Tuple[pd.DataFrame, Dict[str, float], plt.Figure, plt.Figure]:
"""Process benchmark dataset and return results with metrics and visualizations."""
# Read CSV file
df = pd.read_csv(CSV_PATH)
# Process each text
raw_predictions = []
parsed_predictions = []
response_times = []
total = len(df)
for i, text in enumerate(df[TEXT_COLUMN]):
if progress is not None:
progress(i / total, f"Processing {i+1}/{total}")
raw_output, parsed_output, response_time = process_single_text(
text,
prompt_template,
model,
temperature,
max_tokens,
top_p,
api_key
)
raw_predictions.append(raw_output)
parsed_predictions.append(parsed_output)
response_times.append(response_time)
# Add predictions to DataFrame
df['model_raw_output'] = raw_predictions
df['model_prediction'] = parsed_predictions
df['response_time'] = response_times
# Calculate metrics
metrics = evaluate_performance(df, parsed_predictions)
# Add average response time metric
metrics['avg_response_time'] = sum(response_times) / len(response_times)
# Create visualizations
metrics_plot = create_metrics_plot(metrics)
confusion_matrix_plot = create_confusion_matrix_plot(df)
return df, metrics, metrics_plot, confusion_matrix_plot
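# End-to-end usage sketch outside the Gradio UI (assumptions: evaluation.csv is present and
# api_key is a valid Scaleway key; this helper is not called by the app).
def _run_benchmark_headless(api_key: str) -> Dict[str, float]:
    df, metrics, metrics_fig, cm_fig = process_benchmark(
        DEFAULT_PROMPT, AVAILABLE_MODELS[0],
        temperature=0.3, max_tokens=512, top_p=1.0, api_key=api_key
    )
    print(metrics)
    return metrics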
def create_interface():
"""Create Gradio interface with enhanced UI and visualizations."""
with gr.Blocks(theme=gr.themes.Soft()) as interface:
gr.Markdown("# Moderation Model Testing Interface")
with gr.Tabs():
with gr.TabItem("Model Configuration"):
with gr.Row():
with gr.Column():
api_key = gr.Textbox(
label="Scaleway API Key",
placeholder="Enter your API key",
type="password"
)
model = gr.Dropdown(
choices=AVAILABLE_MODELS,
label="Model",
value=AVAILABLE_MODELS[0]
)
prompt = gr.Textbox(
label="Prompt Template",
value=DEFAULT_PROMPT,
lines=5
)
with gr.Column():
temperature = gr.Slider(
minimum=0,
maximum=1,
value=0.3,
label="Temperature"
)
max_tokens = gr.Slider(
minimum=1,
maximum=2048,
value=512,
step=1,
label="Max Tokens"
)
top_p = gr.Slider(
minimum=0,
maximum=1,
value=1,
label="Top P"
)
run_button = gr.Button("Run Benchmark", variant="primary")
with gr.TabItem("Results"):
with gr.Row():
with gr.Column(scale=2):
results_df = gr.Dataframe(
label="Results Table",
headers=[TEXT_COLUMN, LABEL_COLUMN, "Raw Model Output", "Model Prediction", "Response Time (s)"]
)
with gr.Column(scale=1):
metrics_json = gr.JSON(label="Performance Metrics")
with gr.Row():
metrics_plot = gr.Plot(label="Performance Metrics Visualization")
confusion_matrix_vis = gr.Plot(label="Confusion Matrix")
def run_benchmark_fn(
prompt,
model,
temperature,
max_tokens,
top_p,
api_key,
progress=gr.Progress()
):
df, metrics, metrics_vis, confusion_vis = process_benchmark(
prompt,
model,
temperature,
max_tokens,
top_p,
api_key,
progress
)
            # Format the dataframe for display and align column names with the declared table headers
            display_df = df[[TEXT_COLUMN, LABEL_COLUMN, 'model_raw_output', 'model_prediction', 'response_time']].copy()
            display_df.columns = [TEXT_COLUMN, LABEL_COLUMN, "Raw Model Output", "Model Prediction", "Response Time (s)"]
            # Format response time to 3 decimal places
            display_df["Response Time (s)"] = display_df["Response Time (s)"].apply(lambda x: f"{x:.3f}")
return display_df, metrics, metrics_vis, confusion_vis
run_button.click(
run_benchmark_fn,
inputs=[
prompt,
model,
temperature,
max_tokens,
top_p,
api_key
],
outputs=[results_df, metrics_json, metrics_plot, confusion_matrix_vis]
)
return interface
if __name__ == "__main__":
interface = create_interface()
interface.launch()