|
|
from typing import Dict, Any |
|
|
|
|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
|
|
|
|
|
|
def predict(text: str) -> Dict[str, Any]:
    """Score a piece of text for asking-for-PII and giving-PII content.

    The text is tokenized with the module-level ``tokenizer`` and run through
    the module-level ``model`` with gradient tracking disabled. A sigmoid is
    applied to each logit independently, so the two scores are separate
    probabilities and are not normalized to sum to 1.

    Args:
        text: Raw text entered by the user.

    Returns:
        A mapping of display label to float score, suitable for ``gr.Label``:

        * ``{"Asking for PII": p0, "Giving PII": p1}`` on success, where
          ``p0``/``p1`` come from logit index 0 and 1 respectively;
        * ``{"No input provided": 0.0}`` when ``text`` is empty, ``None`` or
          whitespace-only;
        * ``{"Error: <message>": 0.0}`` when tokenization or inference raises.
    """
    if not text or not text.strip():
        return {"No input provided": 0.0}

    try:
        # NOTE(review): padding every request to 512 tokens is kept as-is to
        # preserve the exact model inputs; for single-text inference it is
        # likely unnecessary work -- confirm before relaxing.
        inputs = tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            max_length=512,
            truncation=True,
        )

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            logits = model(**inputs).logits
            # Drop the batch dimension so the result is a flat list of scores.
            probs = torch.sigmoid(logits).squeeze().tolist()

        return {
            "Asking for PII": float(probs[0]),
            "Giving PII": float(probs[1]),
        }
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        # The message goes into the label and the score stays numeric, because
        # the Label output expects label -> float confidence; a string value
        # here would make the error path itself fail while rendering.
        return {f"Error: {e}": 0.0}
|
|
|
|
|
|
|
|
|
|
|
# Sample prompts passed to the interface as clickable examples. Each inner
# list holds the value for the interface's single text input.
examples = [
    [sample]
    for sample in (
        "Do you have the blue app?",
        "I live at 901 Roosevelt St, Redwood City",
    )
]
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Identifier handed to from_pretrained() for both the model and tokenizer.
    model_id = "Roblox/PII-OSS-Private-Not-Public"

    print(f"Loading model: {model_id}")
    try:
        # These are deliberately module-level names: predict() reads `model`
        # and `tokenizer` as globals.
        model = AutoModelForSequenceClassification.from_pretrained(model_id)
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model.eval()  # put the model in evaluation (inference) mode
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Failed to load model: {e}")
        print("If running locally, you may need to login with: huggingface-cli login")
        # Raise SystemExit instead of calling exit(): the bare exit() helper
        # is injected by the `site` module for interactive use and is not
        # guaranteed to exist (e.g. `python -S`, frozen builds).
        raise SystemExit(1) from e

    # Single-textbox UI wired to predict(); results are shown as a label
    # widget listing up to two classes.
    demo = gr.Interface(
        fn=predict,
        inputs=gr.Textbox(
            lines=3,
            placeholder="Enter text to analyze for PII content...",
            label="Input Text",
        ),
        outputs=gr.Label(
            num_top_classes=2,
            label="Classification Results",
        ),
        title="PII Detection Demo",
        description="This model detects whether text is asking for or giving personal information (PII).",
        examples=examples,
        flagging_mode="never",
    )

    demo.launch()