MonaHamid commited on
Commit
f8271bc
·
verified ·
1 Parent(s): b4206c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -11
app.py CHANGED
@@ -1,15 +1,17 @@
1
- import mlflow.pyfunc
 
2
  import gradio as gr
3
 
4
- # Load the model from MLflow registry
5
- model = mlflow.pyfunc.load_model("models:/bert-toxic-classifier/1")
 
6
 
7
  def classify(text):
8
- return model.predict([text])[0]
9
-
10
- gr.Interface(
11
- fn=classify,
12
- inputs=gr.Textbox(label="Enter your comment"),
13
- outputs=gr.Textbox(label="Prediction"),
14
- title="Toxic Comment Classifier"
15
- ).launch()
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import torch
3
  import gradio as gr
4
 
5
+ model_dir = "saved_model"
6
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
7
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir)
8
 
9
  def classify(text):
10
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
11
+ outputs = model(**inputs)
12
+ probs = torch.softmax(outputs.logits, dim=1)
13
+ labels = ["non-toxic", "toxic"] # Adjust if needed
14
+ return {labels[i]: float(probs[0][i]) for i in range(len(labels))}
15
+
16
+ gr.Interface(fn=classify, inputs="text", outputs="label").launch()
17
+