|
|
import gradio as gr |
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer |
|
|
import torch |
|
|
import os |
|
|
|
|
|
|
|
|
MODEL_NAME = "uer/roberta-base-finetuned-jd-binary-chinese" |
|
|
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
|
|
|
try: |
|
|
print(f"正在加载模型: {MODEL_NAME} 到设备: {DEVICE}") |
|
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) |
|
|
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME) |
|
|
model.to(DEVICE) |
|
|
model.eval() |
|
|
print("模型加载成功。") |
|
|
except Exception as e: |
|
|
print(f"模型加载失败,请检查安装和网络连接: {e}") |
|
|
|
|
|
|
|
|
|
|
|
LABEL_MAP = {0: "消极 (Negative)", 1: "积极 (Positive)"} |
|
|
|
|
|
|
|
|
|
|
|
def classify_text(text): |
|
|
""" |
|
|
接收用户输入的文本,返回分类结果和置信度。 |
|
|
""" |
|
|
if not text: |
|
|
return "请输入需要分类的文本。", None, None, None |
|
|
|
|
|
|
|
|
inputs = tokenizer(text, |
|
|
padding=True, |
|
|
truncation=True, |
|
|
max_length=128, |
|
|
return_tensors='pt').to(DEVICE) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
outputs = model(**inputs) |
|
|
|
|
|
|
|
|
|
|
|
predictions = torch.softmax(outputs.logits, dim=1)[0] |
|
|
|
|
|
|
|
|
predicted_class_id = torch.argmax(predictions).item() |
|
|
predicted_label = LABEL_MAP[predicted_class_id] |
|
|
|
|
|
|
|
|
score_negative = predictions[0].item() |
|
|
score_positive = predictions[1].item() |
|
|
|
|
|
|
|
|
result_text = f"预测类别:**{predicted_label}**" |
|
|
|
|
|
|
|
|
confidence_dict = { |
|
|
"消极 (Negative)": score_negative, |
|
|
"积极 (Positive)": score_positive |
|
|
} |
|
|
|
|
|
return result_text, confidence_dict |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
title = "Hugging Face 中文情感分析演示" |
|
|
description = "使用 uer/roberta-base-finetuned-jd-binary-chinese 模型对输入的中文文本进行积极/消极情感二分类。" |
|
|
examples = [ |
|
|
["这家餐厅的菜味道太棒了,服务员也很热情。"], |
|
|
["我等了两个小时,包裹还没送到,非常生气。"], |
|
|
["我对这款产品不满意,但也不算太差。"] |
|
|
] |
|
|
|
|
|
|
|
|
iface = gr.Interface( |
|
|
fn=classify_text, |
|
|
inputs=gr.Textbox(lines=5, label="输入您的中文文本"), |
|
|
outputs=[ |
|
|
gr.Markdown(label="分类结果"), |
|
|
gr.Label(label="置信度分数", num_top_classes=2) |
|
|
], |
|
|
title=title, |
|
|
description=description, |
|
|
examples=examples |
|
|
) |
|
|
|
|
|
|
|
|
iface.launch() |
|
|
|