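"""Gradio app for Chinese sentiment analysis: a frozen BERT encoder plus a trained
linear classification head labels input text as 负向 (negative) or 正向 (positive)."""
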
import gradio as gr
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
# Class labels: 负向 = negative, 正向 = positive
names = ['负向', '正向']
# Tokenizer for the Chinese BERT checkpoint
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

class Model(nn.Module):
    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model
        # Fully connected head: BERT hidden size 768 -> 2 classes
        self.fc = nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Use the pretrained model as a frozen feature extractor; the upstream
        # weights are not trained, so no gradient graph is built here
        with torch.no_grad():
            # Call the BertModel instance stored in self.bert
            output = self.bert(input_ids, attention_mask, token_type_ids)
        # Only the downstream head is trained: binary classification on the
        # final hidden state of the [CLS] token
        output = self.fc(output.last_hidden_state[:, 0])
        # Softmax over the class dimension (dim=1 of the [batch, classes] output)
        output = output.softmax(dim=1)
        return output
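
# Note: "freezing" here happens via torch.no_grad() inside forward(), so no gradients
# flow back into the BERT backbone during training even though its parameters keep
# requires_grad=True; only the linear head self.fc is actually updated.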

# Load the pretrained BERT backbone
bert_model = BertModel.from_pretrained("ckiplab/bert-base-chinese").to(device)
model = Model(bert_model).to(device)
# Load the fine-tuned classifier weights
model.load_state_dict(torch.load("params/1bert.pt", map_location=torch.device(device)))
# Switch to eval mode for inference
model.eval()

def collate_fn(data):
    # Wrap the single input string in a list for batch encoding
    sentes = [data]
    # Encode: truncate/pad to a fixed length and return PyTorch tensors
    data = tokenizer.batch_encode_plus(
        batch_text_or_text_pairs=sentes,
        truncation=True,
        padding="max_length",
        max_length=350,
        return_tensors="pt",
        return_length=True
    )
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    token_type_ids = data["token_type_ids"]
    return input_ids, attention_mask, token_type_ids
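
# Note: with padding="max_length" and a single input string, each tensor returned by
# collate_fn has shape [1, 350].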

def analyze_sentiment(text):
    input_ids, attention_mask, token_type_ids = collate_fn(text)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    token_type_ids = token_type_ids.to(device)
    # Inference only; no gradients needed
    with torch.no_grad():
        out = model(input_ids, attention_mask, token_type_ids)
    # Index of the highest-probability class for the single sample
    out = out.argmax(dim=1).item()
    return f"{names[out]}评价", names[out]
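
# Direct-call sketch (assumes the checkpoint params/1bert.pt is present; the actual
# label depends on the trained weights):
#   summary, label = analyze_sentiment("质量很好,物流也很快,五星好评!")
#   print(summary, label)  # e.g. "正向评价 正向"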

def create_interface():
    """
    Create the Gradio interface.
    """
    with gr.Blocks(title="情感分析应用", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎭 情感分析应用")
        gr.Markdown("输入文本,AI将分析其情感倾向")
        with gr.Row():
            with gr.Column(scale=2):
                # Input area
                text_input = gr.Textbox(
                    label="输入要分析的文本",
                    placeholder="请输入您想要分析情感的文本...",
                    lines=4,
                    max_lines=10
                )
                # Buttons
                with gr.Row():
                    analyze_btn = gr.Button("🔍 分析情感", variant="primary")
                    clear_btn = gr.Button("🗑️ 清空", variant="secondary")
            with gr.Column(scale=2):
                # Output area
                result_summary = gr.Textbox(
                    label="分析结果",
                    lines=3,
                    interactive=False
                )
                # Sentiment label display
                sentiment_label = gr.Label(
                    label="情感分类",
                )
        # Example texts
        gr.Markdown("### 📝 示例文本")
        examples = gr.Examples(
            examples=[
                ["这个产品真的很棒,我非常满意!"],
                ["服务态度太差了,完全不推荐"],
                ["还可以吧,没什么特别的感觉"],
                ["质量很好,物流也很快,五星好评!"],
                ["价格太贵了,性价比不高"]
            ],
            inputs=text_input,
            outputs=[result_summary, sentiment_label],
            fn=analyze_sentiment,
            cache_examples=False
        )
        # Wire up events
        analyze_btn.click(
            fn=analyze_sentiment,
            inputs=text_input,
            outputs=[result_summary, sentiment_label]
        )
        clear_btn.click(
            fn=lambda: ("", "", None),
            outputs=[text_input, result_summary, sentiment_label]
        )
        # Trigger analysis when Enter is pressed in the textbox
        text_input.submit(
            fn=analyze_sentiment,
            inputs=text_input,
            outputs=[result_summary, sentiment_label]
        )
    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(debug=True)