scam-detectorv2 / bert_explainer.py
jerrynnms's picture
Update bert_explainer.py
c423743 verified
raw
history blame
3.63 kB
import torch
from transformers import BertTokenizer, BertModel
from AI_Model_architecture import BertLSTM_CNN_Classifier
import re
import os
import requests
# ✅ Pick the GPU when one is available, otherwise fall back to CPU
#    (so the module still runs on CPU-only deployment hosts).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ✅ Local cache location of the classifier weights and the URL to fetch them from.
model_path = "/tmp/model.pth"
model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"

# ✅ First-run download of the weights (skipped when the file already exists).
if not os.path.exists(model_path):
    print("📦 下載 model.pth 中...")
    # Download into a sibling ".part" file and rename it only after the
    # transfer finished. Writing straight to model_path would leave a
    # truncated file behind on an interrupted download, and the existence
    # check above would then treat that broken file as a valid checkpoint.
    partial_path = model_path + ".part"
    # timeout=(connect, read) in seconds: without it a stalled connection
    # blocks module import forever. stream=True avoids holding the whole
    # checkpoint in memory before writing it out.
    with requests.get(model_url, stream=True, timeout=(10, 300)) as response:
        if response.status_code != 200:
            raise FileNotFoundError("❌ 無法下載 model.pth,請檢查網址是否正確")
        with open(partial_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                f.write(chunk)
    # Atomic on the same filesystem: model_path appears only when complete.
    os.replace(partial_path, model_path)
    print("✅ 模型下載完成")

# ✅ Tokenizer shared by the classifier and the attention extractor.
tokenizer = BertTokenizer.from_pretrained("ckiplab/bert-base-chinese")

# ✅ Custom classifier: build it, load the downloaded weights, switch to eval mode.
model = BertLSTM_CNN_Classifier()
# NOTE(review): torch.load unpickles the file, and the file comes from the
# network. Consider passing weights_only=True if the installed torch version
# supports it and the checkpoint is a plain state_dict — confirm before changing.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

# ✅ Plain ckiplab BERT with attention outputs enabled. It is used only to pick
#    high-attention tokens as "suspicious words" and is independent of the
#    classifier above.
bert_model = BertModel.from_pretrained("ckiplab/bert-base-chinese", output_attentions=True)
bert_model.to(device)
bert_model.eval()
# ✅ 單句推論(輸出預測結果與信心值)
def predict_single_sentence(text: str, max_len=256):
    """Classify one piece of text with the custom classifier.

    The text is normalised first: all whitespace is removed, then every
    character outside the whitelist (CJK U+4E00-U+9FFF, ASCII letters and
    digits, and the punctuation listed in the pattern below) is dropped.
    The result is tokenized, truncated/padded to ``max_len`` tokens and fed
    to the module-level ``model`` with gradients disabled.

    Args:
        text: Raw input text.
        max_len: Token length the input is truncated/padded to.

    Returns:
        A ``(label, prob)`` tuple. ``prob`` is the scalar the model emits
        (presumably a scam probability in [0, 1] — confirm against the model
        definition); ``label`` is 1 when ``prob > 0.5`` and 0 otherwise.
    """
    without_spaces = re.sub(r"\s+", "", text)
    normalised = re.sub(r"[^\u4e00-\u9fffA-Za-z0-9。,!?:/.\-]", "", without_spaces)

    batch = tokenizer(
        normalised,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_len,
    )
    # The classifier takes the three tensors positionally, in this order.
    model_inputs = [
        batch[field].to(device)
        for field in ("input_ids", "attention_mask", "token_type_ids")
    ]

    with torch.no_grad():
        score = model(*model_inputs).item()

    return int(score > 0.5), score
# ✅ 擷取 BERT attention 權重最高的詞(作為可疑詞)
def extract_attention_keywords(text, top_k=5):
    """Return the tokens that receive the most attention in BERT's last layer.

    Whitespace is stripped from ``text``, the result is tokenized (truncated
    and padded to 128 tokens) and run through the module-level ``bert_model``.
    The last layer's attention is averaged over heads and over query
    positions, giving one score per token position.

    ``[CLS]``, ``[SEP]`` and ``[PAD]`` are excluded *before* ranking. If they
    were filtered after taking the top-k, they would occupy top-k slots and
    the function would return fewer than ``top_k`` keywords, often none.

    Args:
        text: Raw input text.
        top_k: Maximum number of tokens to return.

    Returns:
        Up to ``top_k`` token strings ordered by descending attention score.
        The list is empty when the text yields no non-special tokens or when
        ``top_k`` is not positive.
    """
    cleaned = re.sub(r"\s+", "", text)
    inputs = tokenizer(cleaned, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)

    # outputs.attentions: [layers][batch, heads, seq, seq].
    # Take the last layer of the single batch item, then average over heads
    # and over the attending positions -> one score per token, shape [seq].
    attn = outputs.attentions[-1][0].mean(dim=0).mean(dim=0)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    special_tokens = {"[CLS]", "[SEP]", "[PAD]"}
    keep = torch.tensor(
        [tok not in special_tokens for tok in tokens],
        dtype=torch.bool,
        device=attn.device,
    )

    # Clamp k: topk raises when k exceeds the number of elements, and we
    # never want to return a masked-out position.
    k = min(top_k, int(keep.sum().item()))
    if k <= 0:
        return []

    # -inf guarantees special positions can never be selected.
    ranked = attn.masked_fill(~keep, float("-inf"))
    top_indices = ranked.topk(k).indices.tolist()
    return [tokens[i] for i in top_indices]
# ✅ 主函式:整合推論與關鍵詞標註
def analyze_text(text: str):
    """Run classification and keyword extraction and bundle the result.

    Args:
        text: Raw input text to analyse.

    Returns:
        A dict with three keys:
          * ``"status"``: "詐騙" when the predicted label is 1, else "正常".
          * ``"confidence"``: the model score as a percentage, rounded to
            two decimals.
          * ``"suspicious_keywords"``: the attention keywords, or a
            one-element list holding the risk description when no keyword
            was extracted.
    """
    predicted_label, score = predict_single_sentence(text)
    keywords = extract_attention_keywords(text)

    # Risk description, used only as a fallback display value. Thresholds
    # are checked from highest to lowest; the first one exceeded wins.
    risk_text = "🟢 低風險(正常)"
    for threshold, description in (
        (0.9, "🔴 高風險(極可能是詐騙)"),
        (0.5, "🟡 中風險(可疑)"),
    ):
        if score > threshold:
            risk_text = description
            break

    result = {}
    result["status"] = "正常" if predicted_label != 1 else "詐騙"
    result["confidence"] = round(score * 100, 2)
    result["suspicious_keywords"] = keywords if keywords else [risk_text]
    return result