Spaces:
Sleeping
Sleeping
File size: 3,626 Bytes
310d6ab c423743 995e13a 310d6ab 995e13a 310d6ab c423743 310d6ab c423743 310d6ab 995e13a 310d6ab c423743 995e13a 310d6ab 995e13a c423743 310d6ab c423743 310d6ab c423743 310d6ab c423743 995e13a c423743 310d6ab 995e13a 310d6ab 995e13a 310d6ab 995e13a 310d6ab c423743 995e13a c423743 310d6ab c423743 995e13a c423743 310d6ab 995e13a c423743 310d6ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import torch
from transformers import BertTokenizer, BertModel
from AI_Model_architecture import BertLSTM_CNN_Classifier
import re
import os
import requests
# ✅ Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ✅ Local cache path for the classifier weights, and where they are fetched from.
model_path = "/tmp/model.pth"
model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"

# (connect timeout, read timeout) in seconds for the weight download. The read
# timeout bounds the wait between received bytes, not the total transfer time.
# Without a timeout requests can block forever on a stalled connection.
_DOWNLOAD_TIMEOUT = (10, 60)


def _download_model(url, path, timeout=_DOWNLOAD_TIMEOUT):
    """Stream the file at *url* to *path*, replacing *path* atomically.

    The body is written to ``path + ".part"`` and renamed onto *path* only
    after it has been fully received. An interrupted download therefore never
    leaves a truncated file at *path*, which the ``os.path.exists`` check
    below would otherwise accept as a complete model on the next start.

    Args:
        url: HTTP(S) URL of the file.
        path: destination file path.
        timeout: passed to ``requests.get`` (seconds, or a (connect, read) tuple).

    Raises:
        FileNotFoundError: the server answered with a status other than 200
            (kept from the original implementation for compatibility).
        requests.RequestException: network failure such as a timeout.
    """
    tmp_path = path + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise FileNotFoundError("❌ 無法下載 model.pth,請檢查網址是否正確")
        try:
            with open(tmp_path, "wb") as f:
                # Write in 1 MiB chunks instead of holding the whole file in memory.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    if chunk:
                        f.write(chunk)
            os.replace(tmp_path, path)
        except BaseException:
            # Remove the partial temp file, then propagate the original error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise


# ✅ Download the weights on first start-up only (when no cached copy exists).
if not os.path.exists(model_path):
    print("📦 下載 model.pth 中...")
    _download_model(model_url, model_path)
    print("✅ 模型下載完成")
def _to_inference_mode(module):
    """Move *module* to the selected ``device``, switch it to eval mode and return it."""
    module.to(device)
    module.eval()
    return module


# ✅ Tokenizer used for both the classifier input and the attention analysis.
tokenizer = BertTokenizer.from_pretrained("ckiplab/bert-base-chinese")

# ✅ Custom classifier: build the architecture, then restore the downloaded weights.
# NOTE(review): torch.load unpickles the file, and unpickling can run arbitrary
# code if the remote checkpoint were tampered with. Consider weights_only=True
# once the checkpoint is confirmed to be a plain state_dict.
model = BertLSTM_CNN_Classifier()
model.load_state_dict(torch.load(model_path, map_location=device))
_to_inference_mode(model)

# ✅ Plain ckiplab BERT with attention outputs enabled. It is only used to pick
# high-attention tokens and is independent of the classifier above.
bert_model = _to_inference_mode(
    BertModel.from_pretrained("ckiplab/bert-base-chinese", output_attentions=True)
)
# ✅ Single-text inference: returns the predicted label and the model's score.
def predict_single_sentence(text: str, max_len=256):
    """Classify one piece of text with the custom classifier.

    Before tokenisation the text is normalised in two passes. First all
    whitespace is removed. Then every character is dropped except CJK
    ideographs U+4E00–U+9FFF, ASCII letters and digits, and the punctuation
    ``。,!?:/.-``. The result is padded or truncated to *max_len* tokens.

    Args:
        text: raw input text.
        max_len: fixed token length passed to the tokenizer.

    Returns:
        ``(label, prob)``. ``prob`` is the scalar the model returns, presumably
        a sigmoid probability of the scam class (confirm against
        BertLSTM_CNN_Classifier). ``label`` is 1 when ``prob > 0.5``, else 0.
    """
    for pattern in (r"\s+", r"[^\u4e00-\u9fffA-Za-z0-9。,!?:/.\-]"):
        text = re.sub(pattern, "", text)

    batch = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=max_len)
    # Positional order expected by the classifier: ids, mask, segment ids.
    tensors = [batch[key].to(device) for key in ("input_ids", "attention_mask", "token_type_ids")]

    with torch.no_grad():
        prob = model(*tensors).item()
    return int(prob > 0.5), prob
# ✅ Pick the tokens that receive the most BERT attention (used as "suspicious" words).
def extract_attention_keywords(text, top_k=5):
    """Return up to *top_k* distinct tokens that receive the most last-layer attention.

    ``bert_model`` is the plain ckiplab BERT, not the classifier, so these
    tokens are a heuristic highlight. They do not explain the classifier's
    decision.

    Steps:
      * Strip whitespace and tokenise, padded or truncated to 128 tokens.
      * Average the attention each key position receives over all heads and
        over the *non-padding* query positions only.
      * Rank positions by that score and skip special tokens ([CLS], [SEP],
        [PAD], [UNK]) and repeats.

    Args:
        text: raw input text.
        top_k: maximum number of tokens to return. ``<= 0`` yields ``[]``.

    Returns:
        list[str] of tokenizer tokens, most-attended first. It holds ``top_k``
        items whenever the text has that many distinct non-special tokens.
    """
    ignored = {"[CLS]", "[SEP]", "[PAD]", "[UNK]"}

    cleaned = re.sub(r"\s+", "", text)
    inputs = tokenizer(cleaned, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    with torch.no_grad():
        outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)

    # outputs.attentions: one tensor per layer, [batch, heads, seq, seq].
    # Last layer, first (only) batch item, mean over heads -> [query, key].
    attn = outputs.attentions[-1][0].mean(dim=0)
    is_real = attention_mask[0].bool()
    # Average only over real query rows. Padding rows still produce attention
    # distributions, and with 128-token padding they would dominate the mean.
    scores = attn[is_real].mean(dim=0)  # attention received per key position

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    keywords = []
    # Rank every position, then filter. Filtering after a top-k cut would
    # silently return fewer than top_k tokens.
    for idx in scores.argsort(descending=True).tolist():
        if len(keywords) >= top_k:
            break
        token = tokens[idx]
        if token in ignored or token in keywords:
            continue
        keywords.append(token)
    return keywords
# ✅ Entry point: combine classification with attention-based keyword highlighting.
def analyze_text(text: str):
    """Classify *text* and attach highlighted keywords.

    Returns:
        dict with keys:
          * ``status``: "詐騙" (scam) when the classifier label is 1, otherwise
            "正常" (normal).
          * ``confidence``: the model probability as a percentage, rounded to
            2 decimal places.
          * ``suspicious_keywords``: the list from extract_attention_keywords.
            If that list is empty, it is a one-item list holding a risk-level
            description instead (high if prob > 0.9, medium if prob > 0.5,
            otherwise low).
    """
    label, prob = predict_single_sentence(text)
    keywords = extract_attention_keywords(text)

    if not keywords:
        # Fallback display: describe the risk level when no keywords were found.
        if prob > 0.9:
            keywords = ["🔴 高風險(極可能是詐騙)"]
        elif prob > 0.5:
            keywords = ["🟡 中風險(可疑)"]
        else:
            keywords = ["🟢 低風險(正常)"]

    return {
        "status": "詐騙" if label == 1 else "正常",
        "confidence": round(prob * 100, 2),
        "suspicious_keywords": keywords,
    }
|