import gradio as gr
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
names = ['负向', '正向']  # class labels: [negative, positive]
# Tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
class Model(nn.Module):
    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model
        # Fully connected head: BERT hidden size 768 -> 2 classes
        self.fc = nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Extract features with the pretrained model; the upstream BERT weights
        # are frozen and do not take part in training
        with torch.no_grad():
            # Call the BertModel instance stored in self.bert
            output = self.bert(input_ids, attention_mask, token_type_ids)
        # Only the downstream head is trained: binary classification on the
        # final hidden state of the [CLS] token
        output = self.fc(output.last_hidden_state[:, 0])
        # Softmax over the class dimension (dim=1 of the (N, 2) logits)
        output = output.softmax(dim=1)
        return output
# Load the pretrained model
bert_model = BertModel.from_pretrained("ckiplab/bert-base-chinese").to(device)
model = Model(bert_model).to(device)
model.load_state_dict(torch.load("params/1bert.pt", map_location=torch.device(device)))
# Switch to eval mode
model.eval()
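# Note: the checkpoint is expected to contain the state_dict of the full Model
# (the BERT weights plus the fc head), since load_state_dict is strict by default.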
def collate_fn(data):
    sentes = [data]
    # Encode the single input text
    data = tokenizer.batch_encode_plus(
        batch_text_or_text_pairs=sentes,
        truncation=True,
        padding="max_length",
        max_length=350,
        return_tensors="pt",
        return_length=True
    )
    input_ids = data["input_ids"]
    attention_mask = data["attention_mask"]
    token_type_ids = data["token_type_ids"]
    return input_ids, attention_mask, token_type_ids
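# For reference: each returned tensor has shape (1, 350), since the single
# input text is wrapped in a batch of one and padded/truncated to max_length=350.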
def analyze_sentiment(text):
    input_ids, attention_mask, token_type_ids = collate_fn(text)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    token_type_ids = token_type_ids.to(device)
    # No gradients are needed for inference
    with torch.no_grad():
        out = model(input_ids, attention_mask, token_type_ids)
    # Index of the highest-probability class for the (single) sample
    idx = out.argmax(dim=1).item()
    return f"{names[idx]}评价", names[idx]
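# Example call (a sketch; the actual output depends on the trained checkpoint,
# assuming class index 1 maps to '正向'):
#   analyze_sentiment("质量很好,物流也很快")  # -> ("正向评价", "正向")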
def create_interface():
    """
    Build the Gradio interface.
    """
    with gr.Blocks(title="情感分析应用", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎭 情感分析应用")
        gr.Markdown("输入文本,AI将分析其情感倾向")

        with gr.Row():
            with gr.Column(scale=2):
                # Input area
                text_input = gr.Textbox(
                    label="输入要分析的文本",
                    placeholder="请输入您想要分析情感的文本...",
                    lines=4,
                    max_lines=10
                )
                # Buttons
                with gr.Row():
                    analyze_btn = gr.Button("🔍 分析情感", variant="primary")
                    clear_btn = gr.Button("🗑️ 清空", variant="secondary")

            with gr.Column(scale=2):
                # Output area
                result_summary = gr.Textbox(
                    label="分析结果",
                    lines=3,
                    interactive=False
                )
                # Sentiment label display
                sentiment_label = gr.Label(
                    label="情感分类",
                )

        # Example texts
        gr.Markdown("### 📝 示例文本")
        examples = gr.Examples(
            examples=[
                ["这个产品真的很棒,我非常满意!"],
                ["服务态度太差了,完全不推荐"],
                ["还可以吧,没什么特别的感觉"],
                ["质量很好,物流也很快,五星好评!"],
                ["价格太贵了,性价比不高"]
            ],
            inputs=text_input,
            outputs=[result_summary, sentiment_label],
            fn=analyze_sentiment,
            cache_examples=False
        )

        # Wire up events
        analyze_btn.click(
            fn=analyze_sentiment,
            inputs=text_input,
            outputs=[result_summary, sentiment_label]
        )
        clear_btn.click(
            fn=lambda: ("", "", ""),
            outputs=[text_input, result_summary, sentiment_label]
        )
        # Trigger analysis when Enter is pressed in the textbox
        text_input.submit(
            fn=analyze_sentiment,
            inputs=text_input,
            outputs=[result_summary, sentiment_label]
        )

    return demo
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(debug=True)