zhouxiaoxi's picture
Update app.py
c1b60cd verified
import gradio as gr
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
names = ['负向', '正向']
# 分词器
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
class Model(nn.Module):
def __init__(self, bert_model):
super().__init__()
self.bert = bert_model
# 全连接,模型输入为768,分类为2
self.fc = nn.Linear(768, 2)
#
def forward(self, input_ids, attention_mask, token_type_ids):
# 使用预训练模型提取特征, 上游任务不参与训练,锁定权重
with torch.no_grad():
# Correctly call the BertModel instance stored in self.bert
output = self.bert(input_ids, attention_mask, token_type_ids)
# 下游参与训练,二分类任务,获取最新后的状态
output = self.fc(output.last_hidden_state[:, 0])
# softmax激活函数,NV结构,获取特征值dim维度为1
output = output.softmax(dim=1)
return output
# 加载预训练模型
bert_model = BertModel.from_pretrained("ckiplab/bert-base-chinese").to(device)
model = Model(bert_model).to(device)
model.load_state_dict(torch.load("params/1bert.pt", map_location=torch.device(device)))
# 切换到eval模式
model.eval()
def collate_fn(data):
sentes = []
sentes.append(data)
#编码
data = tokenizer.batch_encode_plus(
batch_text_or_text_pairs=sentes,
truncation=True,
padding="max_length",
max_length=350,
return_tensors="pt",
return_length=True
)
input_ids = data["input_ids"]
attention_mask = data["attention_mask"]
token_type_ids = data["token_type_ids"]
return input_ids, attention_mask, token_type_ids
def analyze_sentiment(text):
input_ids, attention_mask, token_type_ids = collate_fn(text)
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
token_type_ids = token_type_ids.to(device)
# 上游不参与训练
with torch.no_grad():
out = model(input_ids, attention_mask, token_type_ids)
# 找到每个样本在指定维度上的最大值的索引
out = out.argmax(dim=1)
return f"{names[out]}评价", names[out]
def create_interface():
"""
创建Gradio界面
"""
with gr.Blocks(title="情感分析应用", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🎭 情感分析应用")
gr.Markdown("输入文本,AI将分析其情感倾向")
with gr.Row():
with gr.Column(scale=2):
# 输入区域
text_input = gr.Textbox(
label="输入要分析的文本",
placeholder="请输入您想要分析情感的文本...",
lines=4,
max_lines=10
)
# 按钮
with gr.Row():
analyze_btn = gr.Button("🔍 分析情感", variant="primary")
clear_btn = gr.Button("🗑️ 清空", variant="secondary")
with gr.Column(scale=2):
# 输出区域
result_summary = gr.Textbox(
label="分析结果",
lines=3,
interactive=False
)
# 情感标签显示
sentiment_label = gr.Label(
label="情感分类",
)
# 示例文本
gr.Markdown("### 📝 示例文本")
examples = gr.Examples(
examples=[
["这个产品真的很棒,我非常满意!"],
["服务态度太差了,完全不推荐"],
["还可以吧,没什么特别的感觉"],
["质量很好,物流也很快,五星好评!"],
["价格太贵了,性价比不高"]
],
inputs=text_input,
outputs=[result_summary, sentiment_label],
fn=analyze_sentiment,
cache_examples=False
)
# 绑定事件
analyze_btn.click(
fn=analyze_sentiment,
inputs=text_input,
outputs=[result_summary, sentiment_label]
)
clear_btn.click(
fn=lambda: ("", "", ""),
outputs=[text_input, result_summary, sentiment_label]
)
# 回车键触发分析
text_input.submit(
fn=analyze_sentiment,
inputs=text_input,
outputs=[result_summary, sentiment_label]
)
return demo
if __name__ == "__main__":
demo = create_interface()
demo.launch(debug=True)