scam-detectorv2 / bert_explainer.py
jerrynnms's picture
Update bert_explainer.py
24fd407 verified
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import io
import re
import requests
import torch
import jieba
import numpy as np
import cv2
import pytesseract
from PIL import Image
from transformers import BertTokenizer, BertModel
from AI_Model_architecture import BertLSTM_CNN_Classifier
# ─────────────────────────────────────────────────────────────────────────────
# 1. Device 設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 2. 下載並載入自訂分類模型(BertLSTM_CNN_Classifier)
model_path = "/tmp/model.pth"
model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"
if not os.path.exists(model_path):
print("📦 下載 model.pth 中...")
response = requests.get(model_url)
if response.status_code == 200:
with open(model_path, "wb") as f:
f.write(response.content)
print("✅ 模型下載完成")
else:
raise FileNotFoundError("❌ 無法下載 model.pth")
# 3. 初始化 tokenizer 與自訂分類模型
tokenizer = BertTokenizer.from_pretrained("ckiplab/bert-base-chinese")
model = BertLSTM_CNN_Classifier()
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()
# 4. 初始化原始 BERT 模型(供 attention 使用)
bert_model = BertModel.from_pretrained("ckiplab/bert-base-chinese", output_attentions=True)
bert_model.to(device)
bert_model.eval()
# ─────────────────────────────────────────────────────────────────────────────
# ─────────────────────────────────────────────────────────────────────────────
# 5. 預測單句文字函式
def predict_single_sentence(text: str, max_len=256):
# 5.1. 簡單清洗:移除空白、保留中英文和部分標點
text = re.sub(r"\s+", "", text)
text = re.sub(r"[^\u4e00-\u9fffA-Za-z0-9。,!?:/.\-]", "", text)
# 5.2. Tokenize 並轉成 Tensor
encoded = tokenizer(text, return_tensors="pt", truncation=True,
padding="max_length", max_length=max_len)
input_ids = encoded["input_ids"].to(device)
attention_mask = encoded["attention_mask"].to(device)
token_type_ids = encoded["token_type_ids"].to(device)
# 5.3. 模型推論
with torch.no_grad():
output = model(input_ids, attention_mask, token_type_ids)
prob = output.item()
label = int(prob > 0.5)
return label, prob
# 6. 抽取高 attention token 並轉換為自然語意詞句
def extract_attention_keywords(text, top_k=5):
# 6.1. 清洗文字(去除空白)
cleaned = re.sub(r"\s+", "", text)
# 6.2. Tokenize 但只需要 attention,不需要分類模型
encoded = tokenizer(cleaned, return_tensors="pt", truncation=True,
padding="max_length", max_length=128)
input_ids = encoded["input_ids"].to(device)
attention_mask = encoded["attention_mask"].to(device)
# 6.3. 將文字丟給原始 BERT 取最後一層 attention
with torch.no_grad():
outputs = bert_model(input_ids=input_ids, attention_mask=attention_mask)
attentions = outputs.attentions # tuple: 每層 transformer block 的 attention
# 6.4. 取最末層 attention,對所有 head、所有 token 均值 → 一維向量 (seq_len)
attn = attentions[-1][0].mean(dim=0).mean(dim=0) # shape: (seq_len,)
# 6.5. 取得該句所有 token,排除特殊 token
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
top_indices = attn.topk(top_k).indices.tolist()
top_tokens = [tokens[i] for i in top_indices if tokens[i] not in ["[CLS]", "[SEP]", "[PAD]"]]
# 6.6. 用 jieba 切詞,將高 attention 的 token 映射回中文詞組
words = list(jieba.cut(text))
suspicious = []
for word in words:
if len(word.strip()) < 2:
continue
for token in top_tokens:
if token in word and word not in suspicious:
suspicious.append(word)
break
# 6.7. 回傳 top_k 個「可疑詞」;若都沒有映射出詞,就直接回 top_tokens
return suspicious[:top_k] if suspicious else top_tokens[:top_k]
# ─────────────────────────────────────────────────────────────────────────────
# ─────────────────────────────────────────────────────────────────────────────
# 7. 文字分析主函式:回傳完整結構
def analyze_text(text: str):
"""
輸入一段文字(純文字),回傳:
{
"status": "詐騙" / "正常",
"confidence": float(百分比),
"suspicious_keywords": [已擷取詞列表]
}
"""
label, prob = predict_single_sentence(text)
prob_percent = round(prob * 100, 2)
status = "詐騙" if label == 1 else "正常"
suspicious = extract_attention_keywords(text)
return {
"status": status,
"confidence": prob_percent,
"suspicious_keywords": suspicious or ["(模型未聚焦可疑詞)"]
}
# ─────────────────────────────────────────────────────────────────────────────
# ─────────────────────────────────────────────────────────────────────────────
# 以下新增:OCR 前處理+圖片分析相關函式
# 8. 前處理:將圖片做灰階→CLAHE→HSV過濾→二值化→放大→模糊,回傳可供 pytesseract 的 PIL.Image
def preprocess_for_pytesseract(pil_image: Image.Image) -> Image.Image:
"""
將 PIL Image 做以下前處理,回傳「黑底白字」的 PIL Image,供 pytesseract 使用:
1. PIL→NumPy (RGB→BGR)
2. 轉灰階 + CLAHE(對比度增強)
3. HSV 色彩過濾 (示範過濾「橘色」海報底色)
4. 固定阈值反向二值化 (深色文字→白,其他→黑)
5. 放大2倍 + GaussianBlur 模糊
最後再把 NumPy 陣列轉回 PIL Image 回傳。
"""
# 8.1. PIL→NumPy (RGB to BGR)
img_bgr = np.array(pil_image.convert("RGB"))[:, :, ::-1]
# 8.2. 轉灰階
gray = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
# 8.3. CLAHE (對比度限制自適應直方圖均衡)
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
enhanced = clahe.apply(gray)
# 8.4. HSV 色彩過濾 (此範例針對橘色底色)
hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
lower_orange = np.array([5, 100, 100])
upper_orange = np.array([20, 255, 255])
mask_orange = cv2.inRange(hsv, lower_orange, upper_orange)
filtered = enhanced.copy()
filtered[mask_orange > 0] = 255 # 將橘色背景設為白
# 8.5. 固定阈值反向二值化 (深色文字→白,背景→黑)
_, thresh = cv2.threshold(filtered, 200, 255, cv2.THRESH_BINARY_INV)
# 8.6. 放大2倍 & GaussianBlur 平滑
scaled = cv2.resize(thresh, None, fx=2.0, fy=2.0, interpolation=cv2.INTER_CUBIC)
smoothed = cv2.GaussianBlur(scaled, (3, 3), 0)
# 8.7. 將 NumPy (黑底白字) 轉回 PIL Image
return Image.fromarray(smoothed)
# 9. 圖片分析:OCR 擷取文字 → BERT 分析
def analyze_image(file_bytes, explain_mode="cnn"):
"""
輸入圖片 bytes,回傳:
{
"status": "詐騙"/"正常"/"無法辨識文字",
"confidence": float,
"suspicious_keywords": [詞列表]
}
流程:
1. bytes → PIL Image
2. 影像前處理 → 得到黑底白字 PIL Image
3. pytesseract 讀取前處理後影像 → 擷取文字
4. 若讀不到文字 → 回傳「無法辨識」
否則 → 呼叫 analyze_text 做 BERT 分析
"""
# 9.1. bytes → PIL Image
image = Image.open(io.BytesIO(file_bytes))
# 9.2. 前處理:取得 PIL (黑底白字)
processed_img = preprocess_for_pytesseract(image)
# 【可選 Debug】儲存前處理後的影像供檢查
# processed_img.save("/tmp/debug_processed.png")
# 9.3. pytesseract OCR 讀取前處理後影像
# 設定 Tesseract 執行檔路徑(在 Space 上通常已經是 /usr/bin/tesseract)
pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"
custom_config = r"-l chi_tra+eng --oem 3 --psm 6"
extracted_text = pytesseract.image_to_string(processed_img, config=custom_config).strip()
# 9.4. 如果沒擷取到任何文字,回傳「無法辨識」
if not extracted_text:
return {
"status": "無法辨識文字",
"confidence": 0.0,
"suspicious_keywords": ["圖片中無可辨識的中文英文"]
}
# 9.5. 如果擷取到文字,就直接呼叫 analyze_text 做 BERT 分析
return analyze_text(extracted_text, explain_mode=explain_mode)
# ─────────────────────────────────────────────────────────────────────────────