# scam-detector / bert_explainer.py
import torch
from AI_Model_architecture import BertLSTM_CNN_Classifier
from transformers import BertTokenizer
import re
import os
import requests
# ✅ Select device: use GPU if available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ✅ Path recommended for Hugging Face Spaces (prevents cache errors)
model_path = "/tmp/model.pth"
model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"
# ✅ Cache the model weights locally (downloaded only on the first run)
if not os.path.exists(model_path):
    print("📦 Downloading model.pth ...")
    response = requests.get(model_url)
    if response.status_code == 200:
        with open(model_path, "wb") as f:
            f.write(response.content)
        print("✅ Model download complete")
    else:
        raise FileNotFoundError("❌ Unable to download model.pth; please check the URL")

# ✅ Cache the model and tokenizer globally (loaded once per process)
model = BertLSTM_CNN_Classifier()
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()
tokenizer = BertTokenizer.from_pretrained("ckiplab/bert-base-chinese")

# ✅ Predict a single sentence and return (label, probability)
def predict_single_sentence(text: str, max_len=256):
    text = re.sub(r"\s+", "", text)  # Remove whitespace
    text = re.sub(r"[^\u4e00-\u9fffA-Za-z0-9。,!?:/.\-]", "", text)  # Keep only CJK characters, alphanumerics and common punctuation
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=max_len)
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)
    token_type_ids = encoded["token_type_ids"].to(device)
    with torch.no_grad():
        output = model(input_ids, attention_mask, token_type_ids)
    prob = output.item()
    label = int(prob > 0.5)
    return label, prob
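
# Illustrative usage (hypothetical input sentence; the actual probability depends
# on the trained weights in model.pth):
#   label, prob = predict_single_sentence("您的帳戶已被凍結,請立即點擊連結解鎖")
#   label is 1 when prob > 0.5 (flagged as scam), otherwise 0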

# ✅ Wrap the prediction in an API-friendly response format
def analyze_text(text: str):
    label, prob = predict_single_sentence(text)
    prob_percent = round(prob * 100, 2)
    if prob > 0.9:
        risk = "🔴 高風險(極可能是詐騙)"  # high risk (very likely a scam)
    elif prob > 0.5:
        risk = "🟡 中風險(可疑)"  # medium risk (suspicious)
    else:
        risk = "🟢 低風險(正常)"  # low risk (normal)
    status = "詐騙" if label == 1 else "正常"  # "scam" / "normal"
    return {
        "status": status,
        "confidence": prob_percent,
        "suspicious_keywords": [risk],  # placeholder; keyword-level tagging can be added later
    }
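
# Minimal usage sketch (assumption: analyze_text is normally called from an API
# layer; this guard is only a quick local smoke test with a hypothetical sample).
if __name__ == "__main__":
    sample = "恭喜中獎!請點擊連結領取獎金"
    result = analyze_text(sample)
    print(result)  # e.g. {"status": "...", "confidence": ..., "suspicious_keywords": [...]}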