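"""Gradio app for benchmarking LLM-based spam moderation prompts on a labeled CSV
dataset, using Scaleway's OpenAI-compatible chat completions API."""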
import gradio as gr
import pandas as pd
from openai import OpenAI
from typing import Dict, List, Tuple
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
import re
import time
import matplotlib.pyplot as plt
import matplotlib
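
# Use a non-interactive backend so figures render without a display (e.g., on a server).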
matplotlib.use('Agg')

# Constants
AVAILABLE_MODELS = [
    "llama-3.3-70b-instruct",
    "llama-3.1-70b-instruct",
    "llama-3.1-8b-instruct",
    "mistral-nemo-instruct-2407",
    "qwen2.5-coder-32b-instruct",
    "deepseek-r1",
    "deepseek-r1-distill-llama-70b"
]

# File and column names
CSV_PATH = "evaluation.csv"
TEXT_COLUMN = "Contribution"
LABEL_COLUMN = "Etat"

DEFAULT_PROMPT = """You are a content moderation assistant for a participatory democracy platform. Your task is to identify and classify potential spam content, with particular attention to black hat SEO backlink attempts.

Key indicators of spam include:
- Hidden or excessive links that seem aimed at manipulating search rankings
- Generic or unrelated content with embedded commercial links
- Repetitive posting of similar content with backlinks
- Content that appears to be aimed at search engines rather than human readers
- Links to questionable or unrelated commercial websites

Classify the following contribution as either SPAM or NOT_SPAM.
Respond with just one word - either SPAM or NOT_SPAM.

Text to classify: {text}"""
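
# Note: str.format() is applied to this template, so {text} must be the only brace
# placeholder; literal braces in a custom prompt would need doubling ({{ }}).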

def create_client(api_key: str) -> OpenAI:
    """Create OpenAI client with Scaleway configuration."""
    return OpenAI(
        base_url="https://api.scaleway.ai/fc37904f-3555-4f58-b79c-041615b7001a/v1",
        api_key=api_key
    )

def parse_model_output(output: str) -> str:
    """Parse and normalize model output to the expected labels ("Spam" / "Pas spam")."""
    cleaned = output.strip().lower()
    # Pattern matching with regex; the negation check accepts space, underscore, or
    # hyphen separators so variants like "not-spam" are not misread as spam.
    if re.search(r'\bspam\b', cleaned) and not re.search(r'\bnot[\s_-]+spam\b|\bpas[\s_-]+spam\b', cleaned):
        return "Spam"
    elif re.search(r'\bnot[\s_-]*spam\b|\bpas[\s_-]*spam\b|\blegitimate\b|\bham\b|\bclean\b', cleaned):
        return "Pas spam"
    # Additional backup checks for specific formats
    if cleaned == "spam":
        return "Spam"
    elif cleaned in ["not_spam", "not spam", "pas spam"]:
        return "Pas spam"
    # Log unexpected responses and default to not spam
    print(f"Warning: Unexpected model output: {output}")
    return "Pas spam"  # Default to not spam for unrecognized responses

def process_single_text(
    text: str,
    prompt_template: str,
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    api_key: str
) -> Tuple[str, str, float]:
    """Process a single text input through the model and measure response time."""
    client = create_client(api_key)
    # Format the prompt
    formatted_prompt = prompt_template.format(text=text)
    start_time = time.time()
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": formatted_prompt},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            presence_penalty=0,
            stream=False
        )
        raw_output = response.choices[0].message.content.strip()
        parsed_output = parse_model_output(raw_output)
        # Calculate response time
        response_time = time.time() - start_time
        return raw_output, parsed_output, response_time
    except Exception as e:
        response_time = time.time() - start_time
        return f"Error: {str(e)}", "Pas spam", response_time

def evaluate_performance(
    df: pd.DataFrame,
    predictions: List[str]
) -> Dict[str, float]:
    """Calculate performance metrics."""
    y_true = df[LABEL_COLUMN].tolist()
    # Convert string labels to binary
    y_true_binary = [1 if label == "Spam" else 0 for label in y_true]
    y_pred_binary = [1 if pred == "Spam" else 0 for pred in predictions]
    metrics = {
        "accuracy": accuracy_score(y_true_binary, y_pred_binary),
        "precision": precision_score(y_true_binary, y_pred_binary, zero_division=0),
        "recall": recall_score(y_true_binary, y_pred_binary, zero_division=0),
        "f1": f1_score(y_true_binary, y_pred_binary, zero_division=0)
    }
    # Convert any numpy values to Python floats
    return {k: float(v) if isinstance(v, (np.floating, np.integer)) else v for k, v in metrics.items()}
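
# Precision/recall/F1 treat "Spam" as the positive class; zero_division=0 keeps the
# metrics defined when a run predicts no spam at all.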

def create_metrics_plot(metrics: Dict[str, float]) -> plt.Figure:
    """Create a bar chart visualization of metrics."""
    fig, ax = plt.subplots(figsize=(10, 6))
    # Exclude avg_response_time, which is not on the 0-1 scale of the other metrics
    perf_metrics = {k: v for k, v in metrics.items() if k != 'avg_response_time'}
    metrics_names = list(perf_metrics.keys())
    metrics_values = list(perf_metrics.values())
    bars = ax.bar(metrics_names, metrics_values, color='skyblue')
    # Add value labels on top of bars
    for bar in bars:
        height = bar.get_height()
        ax.annotate(f'{height:.3f}',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom')
    ax.set_ylim(0, 1.0)
    ax.set_title('Model Performance Metrics')
    ax.set_ylabel('Score')
    plt.tight_layout()
    return fig

def create_confusion_matrix_plot(df: pd.DataFrame) -> plt.Figure:
    """Create a confusion matrix visualization."""
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    # Get true and predicted labels
    y_true = [1 if label == "Spam" else 0 for label in df[LABEL_COLUMN]]
    y_pred = [1 if pred == "Spam" else 0 for pred in df['model_prediction']]
    # Create confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    # Plot
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                xticklabels=['Not Spam', 'Spam'],
                yticklabels=['Not Spam', 'Spam'])
    ax.set_title('Confusion Matrix')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    plt.tight_layout()
    return fig
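
# seaborn is imported lazily above; it still needs to be installed (e.g., listed in
# the Space's requirements.txt, if that is how this app is deployed).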

def process_benchmark(
    prompt_template: str,
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    api_key: str,
    progress=None
) -> Tuple[pd.DataFrame, Dict[str, float], plt.Figure, plt.Figure]:
    """Process benchmark dataset and return results with metrics and visualizations."""
    # Read CSV file
    df = pd.read_csv(CSV_PATH)
    # Process each text
    raw_predictions = []
    parsed_predictions = []
    response_times = []
    total = len(df)
    for i, text in enumerate(df[TEXT_COLUMN]):
        if progress is not None:
            progress(i / total, f"Processing {i+1}/{total}")
        raw_output, parsed_output, response_time = process_single_text(
            text,
            prompt_template,
            model,
            temperature,
            max_tokens,
            top_p,
            api_key
        )
        raw_predictions.append(raw_output)
        parsed_predictions.append(parsed_output)
        response_times.append(response_time)
    # Add predictions to DataFrame
    df['model_raw_output'] = raw_predictions
    df['model_prediction'] = parsed_predictions
    df['response_time'] = response_times
    # Calculate metrics
    metrics = evaluate_performance(df, parsed_predictions)
    # Add average response time metric
    metrics['avg_response_time'] = sum(response_times) / len(response_times)
    # Create visualizations
    metrics_plot = create_metrics_plot(metrics)
    confusion_matrix_plot = create_confusion_matrix_plot(df)
    return df, metrics, metrics_plot, confusion_matrix_plot
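
# Rows are scored sequentially, so a run takes roughly sum(response_times); for a
# large CSV, issuing the requests from a thread pool could cut wall-clock time, at
# the cost of complicating the progress callback.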

def create_interface():
    """Create Gradio interface with enhanced UI and visualizations."""
    with gr.Blocks(theme=gr.themes.Soft()) as interface:
        gr.Markdown("# Moderation Model Testing Interface")
        with gr.Tabs():
            with gr.TabItem("Model Configuration"):
                with gr.Row():
                    with gr.Column():
                        api_key = gr.Textbox(
                            label="Scaleway API Key",
                            placeholder="Enter your API key",
                            type="password"
                        )
                        model = gr.Dropdown(
                            choices=AVAILABLE_MODELS,
                            label="Model",
                            value=AVAILABLE_MODELS[0]
                        )
                        prompt = gr.Textbox(
                            label="Prompt Template",
                            value=DEFAULT_PROMPT,
                            lines=5
                        )
                    with gr.Column():
                        temperature = gr.Slider(
                            minimum=0,
                            maximum=1,
                            value=0.3,
                            label="Temperature"
                        )
                        max_tokens = gr.Slider(
                            minimum=1,
                            maximum=2048,
                            value=512,
                            step=1,
                            label="Max Tokens"
                        )
                        top_p = gr.Slider(
                            minimum=0,
                            maximum=1,
                            value=1,
                            label="Top P"
                        )
                run_button = gr.Button("Run Benchmark", variant="primary")
            with gr.TabItem("Results"):
                with gr.Row():
                    with gr.Column(scale=2):
                        results_df = gr.Dataframe(
                            label="Results Table",
                            headers=[TEXT_COLUMN, LABEL_COLUMN, "Raw Model Output", "Model Prediction", "Response Time (s)"]
                        )
                    with gr.Column(scale=1):
                        metrics_json = gr.JSON(label="Performance Metrics")
                with gr.Row():
                    metrics_plot = gr.Plot(label="Performance Metrics Visualization")
                    confusion_matrix_vis = gr.Plot(label="Confusion Matrix")
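
        # Wrapper invoked by the button: runs the benchmark, then shapes the
        # DataFrame and metrics for the Results tab components above.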
        def run_benchmark_fn(
            prompt,
            model,
            temperature,
            max_tokens,
            top_p,
            api_key,
            progress=gr.Progress()
        ):
            df, metrics, metrics_vis, confusion_vis = process_benchmark(
                prompt,
                model,
                temperature,
                max_tokens,
                top_p,
                api_key,
                progress
            )
            # Format dataframe for display
            display_df = df[[TEXT_COLUMN, LABEL_COLUMN, 'model_raw_output', 'model_prediction', 'response_time']].copy()
            # Format response time to 3 decimal places
            display_df['response_time'] = display_df['response_time'].apply(lambda x: f"{x:.3f}")
            return display_df, metrics, metrics_vis, confusion_vis

        run_button.click(
            run_benchmark_fn,
            inputs=[
                prompt,
                model,
                temperature,
                max_tokens,
                top_p,
                api_key
            ],
            outputs=[results_df, metrics_json, metrics_plot, confusion_matrix_vis]
        )
    return interface

if __name__ == "__main__":
    interface = create_interface()
    interface.launch()