import gradio as gr
import pandas as pd
from openai import OpenAI
from typing import Dict, List, Tuple
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
import re
import time
import matplotlib
matplotlib.use('Agg')  # Select a non-interactive backend before importing pyplot
import matplotlib.pyplot as plt

# Constants
AVAILABLE_MODELS = [
    "llama-3.3-70b-instruct",
    "llama-3.1-70b-instruct",
    "llama-3.1-8b-instruct",
    "mistral-nemo-instruct-2407",
    "qwen2.5-coder-32b-instruct",
    "deepseek-r1",
    "deepseek-r1-distill-llama-70b"
]

# File and column names
CSV_PATH = "evaluation.csv"
TEXT_COLUMN = "Contribution"
LABEL_COLUMN = "Etat"

DEFAULT_PROMPT = """You are a content moderation assistant for a participatory democracy platform. Your task is to identify and classify potential spam content, with particular attention to black hat SEO backlink attempts.

Key indicators of spam include:
- Hidden or excessive links that seem aimed at manipulating search rankings
- Generic or unrelated content with embedded commercial links
- Repetitive posting of similar content with backlinks
- Content that appears to be aimed at search engines rather than human readers
- Links to questionable or unrelated commercial websites

Classify the following contribution as either SPAM or NOT_SPAM. Respond with just one word - either SPAM or NOT_SPAM.

Text to classify: {text}"""


def create_client(api_key: str) -> OpenAI:
    """Create OpenAI client with Scaleway configuration."""
    return OpenAI(
        base_url="https://api.scaleway.ai/fc37904f-3555-4f58-b79c-041615b7001a/v1",
        api_key=api_key
    )


def parse_model_output(output: str) -> str:
    """Parse and normalize model output to the dataset labels ("Spam" / "Pas spam")."""
    # Normalize for case-insensitive pattern matching
    cleaned = output.strip().lower()

    # Enhanced pattern matching with regex
    if re.search(r'\bspam\b', cleaned) and not re.search(r'\bnot\s+spam\b|\bpas\s+spam\b', cleaned):
        return "Spam"
    elif re.search(r'\bnot[\s_-]*spam\b|\bpas[\s_-]*spam\b|\blegitimate\b|\bham\b|\bclean\b', cleaned):
        return "Pas spam"

    # Additional backup checks for specific formats
    if cleaned == "spam":
        return "Spam"
    elif cleaned in ["not_spam", "not spam", "pas spam"]:
        return "Pas spam"

    # Log unexpected responses and default to not spam
    print(f"Warning: Unexpected model output: {output}")
    return "Pas spam"  # Default to not spam for unrecognized responses
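
# A few illustrative mappings for parse_model_output, traced from the regexes
# above (a sketch of expected behavior, not an exhaustive specification):
#   "SPAM"                   -> "Spam"
#   "NOT_SPAM"               -> "Pas spam"
#   "This is not spam."      -> "Pas spam"
#   "Looks legitimate to me" -> "Pas spam"
#   "Unsure"                 -> warning printed, defaults to "Pas spam"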


def process_single_text(
    text: str,
    prompt_template: str,
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    api_key: str
) -> Tuple[str, str, float]:
    """Process a single text input through the model and measure response time."""
    client = create_client(api_key)

    # Format the prompt
    formatted_prompt = prompt_template.format(text=text)

    start_time = time.time()
    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": formatted_prompt},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            presence_penalty=0,
            stream=False
        )
        raw_output = response.choices[0].message.content.strip()
        parsed_output = parse_model_output(raw_output)

        # Calculate response time
        response_time = time.time() - start_time
        return raw_output, parsed_output, response_time
    except Exception as e:
        response_time = time.time() - start_time
        return f"Error: {str(e)}", "Pas spam", response_time


def evaluate_performance(
    df: pd.DataFrame,
    predictions: List[str]
) -> Dict[str, float]:
    """Calculate performance metrics."""
    y_true = df[LABEL_COLUMN].tolist()

    # Convert string labels to binary
    y_true_binary = [1 if label == "Spam" else 0 for label in y_true]
    y_pred_binary = [1 if pred == "Spam" else 0 for pred in predictions]

    metrics = {
        "accuracy": accuracy_score(y_true_binary, y_pred_binary),
        "precision": precision_score(y_true_binary, y_pred_binary, zero_division=0),
        "recall": recall_score(y_true_binary, y_pred_binary, zero_division=0),
        "f1": f1_score(y_true_binary, y_pred_binary, zero_division=0)
    }

    # Convert any numpy values to Python floats
    return {k: float(v) if isinstance(v, (np.floating, np.integer)) else v for k, v in metrics.items()}


def create_metrics_plot(metrics: Dict[str, float]) -> plt.Figure:
    """Create a bar chart visualization of metrics."""
    fig, ax = plt.subplots(figsize=(10, 6))

    # Extract metrics excluding avg_response_time for performance bar chart
    perf_metrics = {k: v for k, v in metrics.items() if k != 'avg_response_time'}
    metrics_names = list(perf_metrics.keys())
    metrics_values = list(perf_metrics.values())

    bars = ax.bar(metrics_names, metrics_values, color='skyblue')

    # Add value labels on top of bars
    for bar in bars:
        height = bar.get_height()
        ax.annotate(f'{height:.3f}',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom')

    ax.set_ylim(0, 1.0)
    ax.set_title('Model Performance Metrics')
    ax.set_ylabel('Score')
    plt.tight_layout()

    return fig


def create_confusion_matrix_plot(df: pd.DataFrame) -> plt.Figure:
    """Create a confusion matrix visualization."""
    from sklearn.metrics import confusion_matrix
    import seaborn as sns

    # Get true and predicted labels
    y_true = [1 if label == "Spam" else 0 for label in df[LABEL_COLUMN]]
    y_pred = [1 if pred == "Spam" else 0 for pred in df['model_prediction']]

    # Create confusion matrix
    cm = confusion_matrix(y_true, y_pred)

    # Plot
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                xticklabels=['Not Spam', 'Spam'],
                yticklabels=['Not Spam', 'Spam'])
    ax.set_title('Confusion Matrix')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    plt.tight_layout()

    return fig


def process_benchmark(
    prompt_template: str,
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    api_key: str,
    progress=None
) -> Tuple[pd.DataFrame, Dict[str, float], plt.Figure, plt.Figure]:
    """Process benchmark dataset and return results with metrics and visualizations."""
    # Read CSV file
    df = pd.read_csv(CSV_PATH)

    # Process each text
    raw_predictions = []
    parsed_predictions = []
    response_times = []
    total = len(df)

    for i, text in enumerate(df[TEXT_COLUMN]):
        if progress is not None:
            progress(i / total, f"Processing {i+1}/{total}")
        raw_output, parsed_output, response_time = process_single_text(
            text, prompt_template, model, temperature, max_tokens, top_p, api_key
        )
        raw_predictions.append(raw_output)
        parsed_predictions.append(parsed_output)
        response_times.append(response_time)

    # Add predictions to DataFrame
    df['model_raw_output'] = raw_predictions
    df['model_prediction'] = parsed_predictions
    df['response_time'] = response_times

    # Calculate metrics
    metrics = evaluate_performance(df, parsed_predictions)

    # Add average response time metric
    metrics['avg_response_time'] = sum(response_times) / len(response_times)

    # Create visualizations
    metrics_plot = create_metrics_plot(metrics)
    confusion_matrix_plot = create_confusion_matrix_plot(df)

    return df, metrics, metrics_plot, confusion_matrix_plot
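
# Expected benchmark data (illustrative sketch only; the real evaluation.csv is
# not included here). process_benchmark() reads CSV_PATH and expects the text
# column TEXT_COLUMN ("Contribution") and the gold-label column LABEL_COLUMN
# ("Etat"), where "Spam" marks the positive class. The non-spam value
# "Pas spam" below is an assumption based on parse_model_output's labels.
#
# Contribution,Etat
# "I support the new bike lane project in my neighborhood.","Pas spam"
# "Cheap watches for sale here http://example.com/watches","Spam"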


def create_interface():
    """Create Gradio interface with enhanced UI and visualizations."""
    with gr.Blocks(theme=gr.themes.Soft()) as interface:
        gr.Markdown("# Moderation Model Testing Interface")

        with gr.Tabs():
            with gr.TabItem("Model Configuration"):
                with gr.Row():
                    with gr.Column():
                        api_key = gr.Textbox(
                            label="Scaleway API Key",
                            placeholder="Enter your API key",
                            type="password"
                        )
                        model = gr.Dropdown(
                            choices=AVAILABLE_MODELS,
                            label="Model",
                            value=AVAILABLE_MODELS[0]
                        )
                        prompt = gr.Textbox(
                            label="Prompt Template",
                            value=DEFAULT_PROMPT,
                            lines=5
                        )
                    with gr.Column():
                        temperature = gr.Slider(
                            minimum=0, maximum=1, value=0.3, label="Temperature"
                        )
                        max_tokens = gr.Slider(
                            minimum=1, maximum=2048, value=512, step=1, label="Max Tokens"
                        )
                        top_p = gr.Slider(
                            minimum=0, maximum=1, value=1, label="Top P"
                        )

                run_button = gr.Button("Run Benchmark", variant="primary")

            with gr.TabItem("Results"):
                with gr.Row():
                    with gr.Column(scale=2):
                        results_df = gr.Dataframe(
                            label="Results Table",
                            headers=[TEXT_COLUMN, LABEL_COLUMN, "Raw Model Output",
                                     "Model Prediction", "Response Time (s)"]
                        )
                    with gr.Column(scale=1):
                        metrics_json = gr.JSON(label="Performance Metrics")

                with gr.Row():
                    metrics_plot = gr.Plot(label="Performance Metrics Visualization")
                    confusion_matrix_vis = gr.Plot(label="Confusion Matrix")

        def run_benchmark_fn(
            prompt, model, temperature, max_tokens, top_p, api_key,
            progress=gr.Progress()
        ):
            df, metrics, metrics_vis, confusion_vis = process_benchmark(
                prompt, model, temperature, max_tokens, top_p, api_key, progress
            )

            # Format dataframe for display
            display_df = df[[TEXT_COLUMN, LABEL_COLUMN, 'model_raw_output',
                             'model_prediction', 'response_time']].copy()
            # Format response time to 3 decimal places
            display_df['response_time'] = display_df['response_time'].apply(lambda x: f"{x:.3f}")

            return display_df, metrics, metrics_vis, confusion_vis

        run_button.click(
            run_benchmark_fn,
            inputs=[prompt, model, temperature, max_tokens, top_p, api_key],
            outputs=[results_df, metrics_json, metrics_plot, confusion_matrix_vis]
        )

    return interface


if __name__ == "__main__":
    interface = create_interface()
    interface.launch()
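
# Runtime dependencies implied by the imports above (a sketch; exact versions
# are not pinned here): gradio, pandas, openai, scikit-learn, numpy,
# matplotlib, seaborn. The OpenAI-compatible client targets a Scaleway-hosted
# endpoint, so a valid Scaleway API key must be entered in the UI before
# running a benchmark.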