ranranlove's picture
Create app.py
5b67a04 verified
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch
import os
# --- 1. 全局模型加载 (---
MODEL_NAME = "uer/roberta-base-finetuned-jd-binary-chinese"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 加载分词器和模型
try:
print(f"正在加载模型: {MODEL_NAME} 到设备: {DEVICE}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.to(DEVICE)
model.eval() # 将模型设置为评估模式
print("模型加载成功。")
except Exception as e:
print(f"模型加载失败,请检查安装和网络连接: {e}")
# 在实际应用中,您可能需要优雅地退出或提供一个备用方案
# 定义类别映射
LABEL_MAP = {0: "消极 (Negative)", 1: "积极 (Positive)"}
# --- 2. Gradio 核心预测函数 ---
def classify_text(text):
"""
接收用户输入的文本,返回分类结果和置信度。
"""
if not text:
return "请输入需要分类的文本。", None, None, None
# 分词和预处理
inputs = tokenizer(text,
padding=True,
truncation=True,
max_length=128,
return_tensors='pt').to(DEVICE)
# 模型预测
with torch.no_grad():
outputs = model(**inputs)
# 处理输出结果
# logits -> softmax 转换为概率
predictions = torch.softmax(outputs.logits, dim=1)[0] # 取出第一个(也是唯一的)输入结果
# 预测的类别
predicted_class_id = torch.argmax(predictions).item()
predicted_label = LABEL_MAP[predicted_class_id]
# 预测分数
score_negative = predictions[0].item()
score_positive = predictions[1].item()
# 格式化输出文本
result_text = f"预测类别:**{predicted_label}**"
# 格式化置信度字典,用于 Gradio 的 Label 组件
confidence_dict = {
"消极 (Negative)": score_negative,
"积极 (Positive)": score_positive
}
return result_text, confidence_dict
# --- 3. Gradio 接口配置和启动 ---
# 定义演示界面的标题和描述
title = "Hugging Face 中文情感分析演示"
description = "使用 uer/roberta-base-finetuned-jd-binary-chinese 模型对输入的中文文本进行积极/消极情感二分类。"
examples = [
["这家餐厅的菜味道太棒了,服务员也很热情。"],
["我等了两个小时,包裹还没送到,非常生气。"],
["我对这款产品不满意,但也不算太差。"]
]
# 创建 Gradio 接口
iface = gr.Interface(
fn=classify_text, # 绑定上面定义的预测函数
inputs=gr.Textbox(lines=5, label="输入您的中文文本"), # 输入组件
outputs=[
gr.Markdown(label="分类结果"), # 显示最终的预测结果文本
gr.Label(label="置信度分数", num_top_classes=2) # 显示两个类别的概率
],
title=title,
description=description,
examples=examples
)
# 启动 Web 服务
iface.launch()