scam-detectorv2 / app.py
jerrynnms's picture
Update app.py
1632676 verified
import os

# ─────── Redirect every library cache to a writable directory ───────
# This runs before torch / transformers are imported further down, in case
# those libraries read these variables when they are first imported.
_CACHE_DIR = "/tmp/.cache"
for _cache_env_var in (
    "XDG_CACHE_HOME",
    "HF_HOME",
    "TRANSFORMERS_CACHE",
    "TORCH_HOME",
    "HF_DATASETS_CACHE",
):
    os.environ[_cache_env_var] = _CACHE_DIR
os.makedirs(_CACHE_DIR, exist_ok=True)
# ─────── Point pytesseract at the Tesseract OCR executable ───────
import pytesseract

_TESSERACT_CMD = "/usr/bin/tesseract"
pytesseract.pytesseract.tesseract_cmd = _TESSERACT_CMD
import io
import json
import requests
import torch
import pytz
import cv2
import numpy as np
from PIL import Image
from datetime import datetime
from typing import Optional, List
from fastapi import FastAPI, HTTPException, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from pydantic import BaseModel
from firebase_admin import credentials, firestore
import firebase_admin
from AI_Model_architecture import BertLSTM_CNN_Classifier
from bert_explainer import analyze_text as bert_analyze_text
app = FastAPI(
    title="詐騙訊息辨識 API",
    description="使用 BERT 模型與 OCR 圖像前處理,辨識文字並做詐騙判斷",
    version="1.0.0",
)

# CORS is fully open: any origin, method and header is accepted, with credentials.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# NOTE(review): this serves every file in the working directory under /static,
# presumably including the Python sources themselves. Consider mounting a
# dedicated assets directory instead.
app.mount("/static", StaticFiles(directory="."), name="static")
@app.get("/", response_class=FileResponse)
async def serve_index():
    """Return the front-end page, ``index.html``, from the working directory."""
    return FileResponse(path="index.html")
# ──────────────────────────────────────────────────────────────────────────
# Firebase initialisation.
# ``db`` is bound up front so the name always exists. If initialisation fails
# below, it stays None, and the endpoints' best-effort Firestore writes fail
# inside their own try/except instead of hitting an unbound global.
db = None
try:
    cred_data = os.getenv("FIREBASE_CREDENTIALS")
    if not cred_data:
        raise ValueError("FIREBASE_CREDENTIALS 未設置")
    # The env var holds the service-account JSON. "type" is only a default:
    # a "type" key inside the JSON overrides it (later keys win in the unpack).
    firebase_cred = credentials.Certificate({"type": "service_account", **json.loads(cred_data)})
    firebase_admin.initialize_app(firebase_cred)
    db = firestore.client()
except Exception as e:
    # Boot without Firestore rather than refusing to start; only persistence is lost.
    print(f"Firebase 初始化錯誤: {e}")
# ──────────────────────────────────────────────────────────────────────────
# ──────────────────────────────────────────────────────────────────────────
# Load the BERT+LSTM+CNN model, downloading the weights on first start.
model_path = "/tmp/model.pth"
model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"


def _download_model(url: str, dest: str, timeout: float = 60.0) -> None:
    """Download ``url`` to ``dest`` without ever leaving a partial file at ``dest``.

    The body is streamed in 1 MiB chunks to ``dest + ".part"``. It is renamed
    onto ``dest`` only after it has been fully written. An interrupted download
    therefore cannot leave a truncated ``dest`` that the ``os.path.exists``
    check below would accept on the next start.

    Args:
        url: file to fetch.
        dest: final path of the downloaded file.
        timeout: seconds to wait for the connection and between received bytes.
            This is not a cap on the total download time.

    Raises:
        FileNotFoundError: the server did not answer with HTTP 200.
    """
    tmp_path = dest + ".part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise FileNotFoundError("❌ 無法從 Hugging Face 下載 model.pth")
        with open(tmp_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    # Atomic on POSIX when both paths are on the same filesystem (same directory here).
    os.replace(tmp_path, dest)
    print("✅ 模型下載完成")


if not os.path.exists(model_path):
    _download_model(model_url, model_path)

model = BertLSTM_CNN_Classifier()
# NOTE(review): torch.load unpickles the file, and this file comes from the
# network. If the checkpoint is a plain state_dict, consider
# torch.load(..., weights_only=True).
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
# ──────────────────────────────────────────────────────────────────────────
# ──────────────────────────────────────────────────────────────────────────
# Pydantic 定義
class TextAnalysisRequest(BaseModel):
    """Request body for ``POST /predict``."""

    text: str  # message text to analyse
    user_id: Optional[str] = None  # optional caller id, stored with the Firestore record
class TextAnalysisResponse(BaseModel):
    """Response body for ``POST /predict``."""

    status: str  # verdict label, taken from the analyzer result's "status"
    confidence: float  # analyzer result's "confidence"; scale not visible here — TODO confirm
    suspicious_keywords: List[str]  # analyzer result's "suspicious_keywords"
    analysis_timestamp: datetime  # Asia/Taipei time at which the request was analysed
    text_id: str  # Firestore document id under which the record was stored
# ──────────────────────────────────────────────────────────────────────────
@app.post("/predict", response_model=TextAnalysisResponse)
async def analyze_text_api(request: TextAnalysisRequest):
    """Classify plain text with the BERT model and return suspicious keywords.

    The request and its analysis result are also stored, best effort, in a
    per-day Firestore collection named ``YYYYMMDD`` (Asia/Taipei date).

    Returns:
        TextAnalysisResponse built from the analyzer's ``status``,
        ``confidence`` and ``suspicious_keywords`` fields. It also carries the
        analysis time and the id of the stored record.

    Raises:
        HTTPException(500): analysis, or building the response, failed.
    """
    import uuid

    try:
        tz = pytz.timezone("Asia/Taipei")
        now = datetime.now(tz)
        # A second-resolution timestamp alone is not unique. Two requests in
        # the same second would share an id, and ``.set()`` would silently
        # overwrite the first record. A short random suffix keeps ids unique,
        # while the timestamp prefix keeps them chronologically sortable.
        doc_id = f"{now:%Y%m%dT%H%M%S}-{uuid.uuid4().hex[:8]}"
        date_str = now.strftime("%Y-%m-%d %H:%M:%S")
        collection = now.strftime("%Y%m%d")

        result = bert_analyze_text(request.text)

        record = {
            "text_id": doc_id,
            "text": request.text,
            "user_id": request.user_id,
            "analysis_result": result,
            "timestamp": date_str,
            "type": "text_analysis",
        }
        try:
            db.collection(collection).document(doc_id).set(record)
        except Exception as e:
            # Best effort: a storage failure must not cost the user the
            # analysis result, but it should not be invisible either.
            print(f"Firestore 寫入失敗 (text_id={doc_id}): {e}")

        return TextAnalysisResponse(
            status=result["status"],
            confidence=result["confidence"],
            suspicious_keywords=result["suspicious_keywords"],
            analysis_timestamp=now,
            text_id=doc_id,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/feedback")
async def save_user_feedback(feedback: dict):
    """Store a user's feedback payload in the ``user_feedback`` Firestore collection.

    The payload is stored as received, plus two fields:

    - ``timestamp``: the current Asia/Taipei time.
    - ``used_in_training``: always False.

    Both are assigned after the payload arrives, so a client-supplied value for
    either key is overwritten.

    Storage is best effort: the success message is returned even if the
    Firestore write fails. That failure is logged.

    Raises:
        HTTPException(500): an unexpected error outside the Firestore write.
    """
    try:
        tz = pytz.timezone("Asia/Taipei")
        feedback["timestamp"] = datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
        feedback["used_in_training"] = False
        try:
            db.collection("user_feedback").add(feedback)
        except Exception as e:
            print(f"Firestore 寫入失敗 (user_feedback): {e}")
        return {"message": "✅ 已記錄使用者回饋"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# ──────────────────────────────────────────────────────────────────────────
# 強化 OCR 前處理 + 附帶 Debug 圖輸出
def preprocess_image_for_ocr(
    pil_image: Image.Image,
    *,
    debug_dir: Optional[str] = "/tmp",
) -> Image.Image:
    """Prepare an image for Tesseract OCR.

    Pipeline:
      1. PIL image -> NumPy array in OpenCV's BGR channel order.
      2. Grayscale + CLAHE (local contrast enhancement).
      3. Build an orange mask in HSV (H 5-20, S and V >= 100). Paint the masked
         pixels white in the enhanced grayscale image, removing an orange
         background.
      4. Fixed-threshold (200) inverse binarisation.
      5. 3x bicubic upscale, then a 5x5 Gaussian blur.

    Args:
        pil_image: input image in any mode PIL can convert to RGB.
        debug_dir: directory that receives the intermediate snapshots
            ``debug_clahe.png``, ``debug_mask_orange.png``,
            ``debug_filtered.png``, ``debug_thresh.png`` and
            ``debug_processed.png``. Defaults to ``/tmp``, as before. Pass
            ``None`` to skip writing them, e.g. to avoid disk I/O on every
            request.

    Returns:
        The processed single-channel image.
    """

    def save_debug(stage: np.ndarray, name: str) -> None:
        # Snapshot one stage as <debug_dir>/debug_<name>.png. The filenames are
        # fixed, so concurrent requests overwrite each other's snapshots; they
        # are only for manual inspection.
        if debug_dir is not None:
            Image.fromarray(stage).save(os.path.join(debug_dir, f"debug_{name}.png"))

    # 1. PIL -> NumPy; reversing the last axis turns RGB into BGR.
    img = np.array(pil_image.convert("RGB"))[:, :, ::-1]

    # 2. Grayscale + CLAHE.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)
    save_debug(enhanced, "clahe")

    # 3. Orange mask (OpenCV 8-bit hue runs 0-179, so 5-20 is the orange band).
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    lower_orange = np.array([5, 100, 100])
    upper_orange = np.array([20, 255, 255])
    mask_orange = cv2.inRange(hsv, lower_orange, upper_orange)
    save_debug(mask_orange, "mask_orange")

    # Masked pixels become white (255); all others keep their enhanced gray value.
    filtered = enhanced.copy()
    filtered[mask_orange > 0] = 255
    save_debug(filtered, "filtered")

    # 4. Inverse binarisation: pixels > 200 -> 0, all others -> 255.
    # NOTE(review): assuming dark text on a light background, this yields light
    # text on a dark background. Tesseract's docs recommend dark-on-light for
    # the 4.x LSTM engine; worth A/B testing plain THRESH_BINARY.
    _, thresh = cv2.threshold(filtered, 200, 255, cv2.THRESH_BINARY_INV)
    save_debug(thresh, "thresh")

    # 5. Upscale 3x, then blur to soften the upscaled edges.
    scaled = cv2.resize(thresh, None, fx=3.0, fy=3.0, interpolation=cv2.INTER_CUBIC)
    smoothed = cv2.GaussianBlur(scaled, (5, 5), 0)
    save_debug(smoothed, "processed")

    return Image.fromarray(smoothed)
# ──────────────────────────────────────────────────────────────────────────
@app.post("/analyze-image")
async def analyze_uploaded_image(file: UploadFile = File(...)):
    """Run OCR on an uploaded image and classify the extracted text.

    Steps:
      1. Read the upload and decode it with PIL.
      2. Pre-process it with ``preprocess_image_for_ocr``, which also writes
         debug snapshots.
      3. Extract text with Tesseract (Traditional Chinese + English).
      4. If any text was found, classify it with ``bert_analyze_text``.

    Returns:
        JSON with ``extracted_text`` and ``analysis_result``. When OCR finds
        no text, ``analysis_result`` is a fixed "無法辨識" placeholder.

    Raises:
        HTTPException(400): the upload is empty or is not a readable image.
        HTTPException(500): any other failure during OCR or analysis.
    """
    print("🔍 [DEBUG] 收到 analyze-image,檔名 =", file.filename)
    try:
        image_bytes = await file.read()
        print("🔍 [DEBUG] 圖片 bytes 長度 =", len(image_bytes))

        # A bad upload is the client's fault: report it as 400, not as a server error.
        if not image_bytes:
            raise HTTPException(status_code=400, detail="圖片辨識失敗:上傳的檔案是空的")
        try:
            pil_img = Image.open(io.BytesIO(image_bytes))
            # Image.open is lazy. Force a full decode here so corrupt or
            # truncated files are caught now, not deep inside preprocessing.
            pil_img.load()
        except OSError as e:
            # Also covers PIL's UnidentifiedImageError, an OSError subclass.
            raise HTTPException(
                status_code=400, detail=f"圖片辨識失敗:無法讀取圖片檔案 ({e})"
            ) from e
        print("🔍 [DEBUG] PIL 成功開啟圖片,格式 =", pil_img.format, "大小 =", pil_img.size)

        processed_image = preprocess_image_for_ocr(pil_img)

        # Per Tesseract's CLI docs: --oem 3 = default engine selection,
        # --psm 6 = assume a single uniform block of text.
        custom_config = r"-l chi_tra+eng --oem 3 --psm 6"
        extracted_text = pytesseract.image_to_string(
            processed_image,
            config=custom_config,
        ).strip()
        print("🔍 [DEBUG] Tesseract 擷取文字 =", repr(extracted_text))

        if not extracted_text:
            return JSONResponse({
                "extracted_text": "",
                "analysis_result": {
                    "status": "無法辨識",
                    "confidence": 0.0,
                    "suspicious_keywords": ["無法擷取分析結果"],
                },
            })

        result = bert_analyze_text(extracted_text)
        return JSONResponse({
            "extracted_text": extracted_text,
            "analysis_result": result,
        })
    except HTTPException:
        # Already carries the intended status code; don't re-wrap it as a 500.
        raise
    except Exception as e:
        import traceback

        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"圖片辨識失敗:{str(e)}") from e
# ──────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Local entry point; the PORT environment variable overrides the default 7860.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))