jerrynnms commited on
Commit
c423743
·
verified ·
1 Parent(s): 995e13a

Update bert_explainer.py

Browse files
Files changed (1) hide show
  1. bert_explainer.py +41 -14
bert_explainer.py CHANGED
@@ -1,18 +1,18 @@
1
  import torch
 
2
  from AI_Model_architecture import BertLSTM_CNN_Classifier
3
- from transformers import BertTokenizer
4
  import re
5
  import os
6
  import requests
7
 
8
- # ✅ 使用 CPU 模式(如果你只部署在 Hugging Face)
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
- # ✅ Hugging Face 建議路徑(防止 cache 錯誤)
12
  model_path = "/tmp/model.pth"
13
  model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"
14
 
15
- # ✅ 快取模型檔(僅首次下載)
16
  if not os.path.exists(model_path):
17
  print("📦 下載 model.pth 中...")
18
  response = requests.get(model_url)
@@ -21,20 +21,26 @@ if not os.path.exists(model_path):
21
  f.write(response.content)
22
  print("✅ 模型下載完成")
23
  else:
24
- raise FileNotFoundError("❌ 無法下載 model.pth,請檢查網址")
25
 
26
- # ✅ 全域快取模型與 tokenizer
 
 
 
27
  model = BertLSTM_CNN_Classifier()
28
  model.load_state_dict(torch.load(model_path, map_location=device))
29
  model.to(device)
30
  model.eval()
31
 
32
- tokenizer = BertTokenizer.from_pretrained("ckiplab/bert-base-chinese")
 
 
 
33
 
34
- # ✅ 預測單句文字
35
  def predict_single_sentence(text: str, max_len=256):
36
- text = re.sub(r"\s+", "", text) # 移除空白
37
- text = re.sub(r"[^\u4e00-\u9fffA-Za-z0-9。,!?:/.\-]", "", text) # 清洗非標點與文字
38
 
39
  encoded = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=max_len)
40
  input_ids = encoded["input_ids"].to(device)
@@ -48,11 +54,32 @@ def predict_single_sentence(text: str, max_len=256):
48
 
49
  return label, prob
50
 
51
- # ✅ 封裝為 API 可用格式
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  def analyze_text(text: str):
53
  label, prob = predict_single_sentence(text)
54
  prob_percent = round(prob * 100, 2)
 
55
 
 
56
  if prob > 0.9:
57
  risk = "🔴 高風險(極可能是詐騙)"
58
  elif prob > 0.5:
@@ -60,11 +87,11 @@ def analyze_text(text: str):
60
  else:
61
  risk = "🟢 低風險(正常)"
62
 
63
- status = "詐騙" if label == 1 else "正常"
 
64
 
65
  return {
66
  "status": status,
67
  "confidence": prob_percent,
68
- "suspicious_keywords": [risk] # 這裡之後可進一步做關鍵字標註
69
  }
70
-
 
1
  import torch
2
+ from transformers import BertTokenizer, BertModel
3
  from AI_Model_architecture import BertLSTM_CNN_Classifier
 
4
  import re
5
  import os
6
  import requests
7
 
8
+ # ✅ 使用 CPU 模式(部署環境通用)
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
+ # ✅ 模型權重與儲存位置
12
  model_path = "/tmp/model.pth"
13
  model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"
14
 
15
+ # ✅ 初次下載模型(如果不存在)
16
  if not os.path.exists(model_path):
17
  print("📦 下載 model.pth 中...")
18
  response = requests.get(model_url)
 
21
  f.write(response.content)
22
  print("✅ 模型下載完成")
23
  else:
24
+ raise FileNotFoundError("❌ 無法下載 model.pth,請檢查網址是否正確")
25
 
26
+ # ✅ 初始化 tokenizer
27
+ tokenizer = BertTokenizer.from_pretrained("ckiplab/bert-base-chinese")
28
+
29
+ # ✅ 初始化自訂分類模型
30
  model = BertLSTM_CNN_Classifier()
31
  model.load_state_dict(torch.load(model_path, map_location=device))
32
  model.to(device)
33
  model.eval()
34
 
35
+ # 初始化 ckiplab BERT 模型,用於抽取 attention 可疑詞(與分類模型無關)
36
+ bert_model = BertModel.from_pretrained("ckiplab/bert-base-chinese", output_attentions=True)
37
+ bert_model.to(device)
38
+ bert_model.eval()
39
 
40
+ # ✅ 單句推論(輸出預測結果與信心值)
41
  def predict_single_sentence(text: str, max_len=256):
42
+ text = re.sub(r"\s+", "", text)
43
+ text = re.sub(r"[^\u4e00-\u9fffA-Za-z0-9。,!?:/.\-]", "", text)
44
 
45
  encoded = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=max_len)
46
  input_ids = encoded["input_ids"].to(device)
 
54
 
55
  return label, prob
56
 
57
+ # ✅ 擷取 BERT attention 權重最高的詞(作為可疑詞)
58
+ def extract_attention_keywords(text, top_k=5):
59
+ cleaned = re.sub(r"\s+", "", text)
60
+ inputs = tokenizer(cleaned, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
61
+ input_ids = inputs["input_ids"].to(device)
62
+ attention_mask = inputs["attention_mask"].to(device)
63
+
64
+ with torch.no_grad():
65
+ outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
66
+ attentions = outputs.attentions # [layers][batch, heads, seq, seq]
67
+
68
+ # 取最後一層,平均所有 head 與所有 attention 給 token 的權重
69
+ attn = attentions[-1][0].mean(dim=0).mean(dim=0) # [seq]
70
+ tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
71
+ top_indices = attn.topk(top_k).indices.tolist()
72
+ top_tokens = [tokens[i] for i in top_indices if tokens[i] not in ["[CLS]", "[SEP]", "[PAD]"]]
73
+
74
+ return top_tokens
75
+
76
+ # ✅ 主函式:整合推論與關鍵詞標註
77
  def analyze_text(text: str):
78
  label, prob = predict_single_sentence(text)
79
  prob_percent = round(prob * 100, 2)
80
+ status = "詐騙" if label == 1 else "正常"
81
 
82
+ # 風險說明(僅作為備用顯示)
83
  if prob > 0.9:
84
  risk = "🔴 高風險(極可能是詐騙)"
85
  elif prob > 0.5:
 
87
  else:
88
  risk = "🟢 低風險(正常)"
89
 
90
+ # 自動抽取可疑詞
91
+ attention_keywords = extract_attention_keywords(text)
92
 
93
  return {
94
  "status": status,
95
  "confidence": prob_percent,
96
+ "suspicious_keywords": attention_keywords or [risk]
97
  }