"""Gradio demo that scores a comment against six toxicity labels.

Loads a sequence-classification model and its tokenizer from ``saved_model``
and exposes a single text box whose input is scored by :func:`classify`.
"""

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import gradio as gr

# Load the saved model and tokenizer from the local directory.
model_dir = "saved_model"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir)

# All 6 labels (Jigsaw-style multi-label toxic comment classification).
# NOTE(review): the order is assumed to match the order of the model's output
# logits -- confirm against the training configuration.
labels = [
    "toxic",
    "severe_toxic",
    "obscene",
    "threat",
    "insult",
    "identity_hate",
]


def classify(text):
    """Return a per-label score for *text*.

    The text is tokenized (with truncation), run through the model without
    gradient tracking, and each logit of the first output row is passed
    through a sigmoid independently, so the scores are per-label and do not
    sum to 1.

    Args:
        text: The comment to score.

    Returns:
        A dict mapping every name in ``labels`` to a float in [0, 1].

    Raises:
        IndexError: If the model emits fewer logits than there are labels.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Sigmoid (not softmax) because the task is multi-label; [0] selects the
    # row for the single input text.
    probs = torch.sigmoid(outputs.logits)[0]
    return {label: float(probs[i]) for i, label in enumerate(labels)}


# Gradio interface. Built at module level so it can be imported or reused;
# the server itself is only started when the file is run as a script.
demo = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(placeholder="Enter your comment..."),
    outputs=gr.Label(num_top_classes=6),
    title="Toxic Comment Classifier",
)

if __name__ == "__main__":
    demo.launch()