import gradio as gr import torch import torch.nn as nn from transformers import BertTokenizer, BertModel # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = "cpu" names = ['负向', '正向'] # 分词器 tokenizer = BertTokenizer.from_pretrained("bert-base-chinese") class Model(nn.Module): def __init__(self, bert_model): super().__init__() self.bert = bert_model # 全连接,模型输入为768,分类为2 self.fc = nn.Linear(768, 2) # def forward(self, input_ids, attention_mask, token_type_ids): # 使用预训练模型提取特征, 上游任务不参与训练,锁定权重 with torch.no_grad(): # Correctly call the BertModel instance stored in self.bert output = self.bert(input_ids, attention_mask, token_type_ids) # 下游参与训练,二分类任务,获取最新后的状态 output = self.fc(output.last_hidden_state[:, 0]) # softmax激活函数,NV结构,获取特征值dim维度为1 output = output.softmax(dim=1) return output # 加载预训练模型 bert_model = BertModel.from_pretrained("ckiplab/bert-base-chinese").to(device) model = Model(bert_model).to(device) model.load_state_dict(torch.load("params/1bert.pt", map_location=torch.device(device))) # 切换到eval模式 model.eval() def collate_fn(data): sentes = [] sentes.append(data) #编码 data = tokenizer.batch_encode_plus( batch_text_or_text_pairs=sentes, truncation=True, padding="max_length", max_length=350, return_tensors="pt", return_length=True ) input_ids = data["input_ids"] attention_mask = data["attention_mask"] token_type_ids = data["token_type_ids"] return input_ids, attention_mask, token_type_ids def analyze_sentiment(text): input_ids, attention_mask, token_type_ids = collate_fn(text) input_ids = input_ids.to(device) attention_mask = attention_mask.to(device) token_type_ids = token_type_ids.to(device) # 上游不参与训练 with torch.no_grad(): out = model(input_ids, attention_mask, token_type_ids) # 找到每个样本在指定维度上的最大值的索引 out = out.argmax(dim=1) return f"{names[out]}评价", names[out] def create_interface(): """ 创建Gradio界面 """ with gr.Blocks(title="情感分析应用", theme=gr.themes.Soft()) as demo: gr.Markdown("# 🎭 情感分析应用") gr.Markdown("输入文本,AI将分析其情感倾向") with gr.Row(): with gr.Column(scale=2): # 输入区域 text_input = gr.Textbox( label="输入要分析的文本", placeholder="请输入您想要分析情感的文本...", lines=4, max_lines=10 ) # 按钮 with gr.Row(): analyze_btn = gr.Button("🔍 分析情感", variant="primary") clear_btn = gr.Button("🗑️ 清空", variant="secondary") with gr.Column(scale=2): # 输出区域 result_summary = gr.Textbox( label="分析结果", lines=3, interactive=False ) # 情感标签显示 sentiment_label = gr.Label( label="情感分类", ) # 示例文本 gr.Markdown("### 📝 示例文本") examples = gr.Examples( examples=[ ["这个产品真的很棒,我非常满意!"], ["服务态度太差了,完全不推荐"], ["还可以吧,没什么特别的感觉"], ["质量很好,物流也很快,五星好评!"], ["价格太贵了,性价比不高"] ], inputs=text_input, outputs=[result_summary, sentiment_label], fn=analyze_sentiment, cache_examples=False ) # 绑定事件 analyze_btn.click( fn=analyze_sentiment, inputs=text_input, outputs=[result_summary, sentiment_label] ) clear_btn.click( fn=lambda: ("", "", ""), outputs=[text_input, result_summary, sentiment_label] ) # 回车键触发分析 text_input.submit( fn=analyze_sentiment, inputs=text_input, outputs=[result_summary, sentiment_label] ) return demo if __name__ == "__main__": demo = create_interface() demo.launch(debug=True)