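"""
Gradio web app for Chinese sentiment analysis.

A BERT encoder with a fine-tuned binary classification head labels input text
as 负向 (negative) or 正向 (positive).
"""
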
import gradio as gr
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
# Class labels: 负向 = negative, 正向 = positive
names = ['负向', '正向']

# Tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
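# Note: the tokenizer checkpoint (bert-base-chinese) differs from the encoder
# checkpoint loaded below (ckiplab/bert-base-chinese); this presumably relies on
# the two checkpoints sharing the same Chinese BERT vocabulary.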

class Model(nn.Module):
    def __init__(self, bert_model):
        super().__init__()
        self.bert = bert_model
        # Fully connected classification head: 768-dim BERT hidden state -> 2 classes
        self.fc = nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Use the pretrained BERT as a frozen feature extractor: the upstream
        # weights are locked and receive no gradient updates
        with torch.no_grad():
            # Correctly call the BertModel instance stored in self.bert
            output = self.bert(input_ids, attention_mask, token_type_ids)
        # Only the downstream head is trained; classify from the final hidden state of the [CLS] token
        output = self.fc(output.last_hidden_state[:, 0])
        # Softmax over the class dimension (dim=1) to obtain probabilities
        output = output.softmax(dim=1)
        return output

# Load the pretrained encoder and the fine-tuned classifier weights
bert_model = BertModel.from_pretrained("ckiplab/bert-base-chinese").to(device)
model = Model(bert_model).to(device)
model.load_state_dict(torch.load("params/1bert.pt", map_location=torch.device(device)))
# Switch to evaluation mode (disables dropout)
model.eval()

def collate_fn(text):
    # Wrap the single input string in a list so it is encoded as a batch of one
    sentences = [text]

    # Tokenize and encode, padding/truncating to a fixed length of 350 tokens
    encoded = tokenizer.batch_encode_plus(
        batch_text_or_text_pairs=sentences,
        truncation=True,
        padding="max_length",
        max_length=350,
        return_tensors="pt",
        return_length=True
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]
    token_type_ids = encoded["token_type_ids"]
    return input_ids, attention_mask, token_type_ids
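
# For a single input string, collate_fn returns three tensors of shape [1, 350]
# (a batch of one, padded or truncated to max_length).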


def analyze_sentiment(text):
    input_ids, attention_mask, token_type_ids = collate_fn(text)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    token_type_ids = token_type_ids.to(device)
    # No gradient tracking is needed at inference time
    with torch.no_grad():
        out = model(input_ids, attention_mask, token_type_ids)
    # Index of the highest-probability class for this sample
    idx = out.argmax(dim=1).item()
    return f"{names[idx]}评价", names[idx]
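
# Illustrative example (actual output depends on the fine-tuned weights in params/1bert.pt):
#   analyze_sentiment("质量很好,物流也很快") might return ("正向评价", "正向")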

def create_interface():
    """
    创建Gradio界面
    """
    with gr.Blocks(title="情感分析应用", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎭 情感分析应用")
        gr.Markdown("输入文本,AI将分析其情感倾向")
        
        with gr.Row():
            with gr.Column(scale=2):
                # Input area
                text_input = gr.Textbox(
                    label="输入要分析的文本",
                    placeholder="请输入您想要分析情感的文本...",
                    lines=4,
                    max_lines=10
                )
                
                # Buttons
                with gr.Row():
                    analyze_btn = gr.Button("🔍 分析情感", variant="primary")
                    clear_btn = gr.Button("🗑️ 清空", variant="secondary")
            
            with gr.Column(scale=2):
                # Output area
                result_summary = gr.Textbox(
                    label="分析结果",
                    lines=3,
                    interactive=False
                )
                
                # Sentiment label display
                sentiment_label = gr.Label(
                    label="情感分类",
                )
        
        # Example texts
        gr.Markdown("### 📝 示例文本")
        examples = gr.Examples(
            examples=[
                ["这个产品真的很棒,我非常满意!"],
                ["服务态度太差了,完全不推荐"],
                ["还可以吧,没什么特别的感觉"],
                ["质量很好,物流也很快,五星好评!"],
                ["价格太贵了,性价比不高"]
            ],
            inputs=text_input,
            outputs=[result_summary, sentiment_label],
            fn=analyze_sentiment,
            cache_examples=False
        )
        
        # Wire up button events
        analyze_btn.click(
            fn=analyze_sentiment,
            inputs=text_input,
            outputs=[result_summary, sentiment_label]
        )
        
        clear_btn.click(
            fn=lambda: ("", "", ""),
            outputs=[text_input, result_summary, sentiment_label]
        )
        
        # Pressing Enter in the textbox also triggers the analysis
        text_input.submit(
            fn=analyze_sentiment,
            inputs=text_input,
            outputs=[result_summary, sentiment_label]
        )
    
    return demo

if __name__ == "__main__":
    demo = create_interface()
    demo.launch(debug=True)
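    # To make the app reachable from other machines, Gradio launch options such as
    # server_name="0.0.0.0" or share=True could be passed to demo.launch() as well.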