Spaces:
Sleeping
Sleeping
import os

# ─── Fix cache write-permission problems ───
# Point every cache directory used by the ML/HF libraries at one writable
# location under /tmp. This runs before torch / transformers / the model
# modules are imported further down, presumably because those libraries read
# these variables at import time — keep it at the top of the file.
_CACHE_DIR = "/tmp/.cache"
os.environ.update(
    dict.fromkeys(
        (
            "XDG_CACHE_HOME",
            "HF_HOME",
            "TRANSFORMERS_CACHE",
            "TORCH_HOME",
            "HF_DATASETS_CACHE",
        ),
        _CACHE_DIR,
    )
)
os.makedirs(_CACHE_DIR, exist_ok=True)

# ─── Explicit path of the Tesseract OCR executable ───
import pytesseract

pytesseract.pytesseract.tesseract_cmd = "/usr/bin/tesseract"
| import io | |
| import json | |
| import requests | |
| import torch | |
| import pytz | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from datetime import datetime | |
| from typing import Optional, List | |
| from fastapi import FastAPI, HTTPException, File, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from pydantic import BaseModel | |
| from firebase_admin import credentials, firestore | |
| import firebase_admin | |
| from AI_Model_architecture import BertLSTM_CNN_Classifier | |
| from bert_explainer import analyze_text as bert_analyze_text | |
# ─── FastAPI application object ───
app = FastAPI(
    title="詐騙訊息辨識 API",
    description="使用 BERT 模型與 OCR 圖像前處理,辨識文字並做詐騙判斷",
    version="1.0.0",
)

# Cross-origin policy: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed requests; consider listing the real front-end origins.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# Serve files from the process working directory under /static.
# NOTE(review): directory="." exposes every file in the working directory
# (likely including this source file) — consider a dedicated static folder.
app.mount("/static", StaticFiles(directory="."), name="static")
# NOTE(review): no route decorator is visible on this handler in this view; as
# shown it is never registered on `app` — presumably it should be `@app.get("/")`.
async def serve_index():
    """Return the front-end page ``index.html``, resolved against the working directory."""
    page = "index.html"
    return FileResponse(page)
# ──────────────────────────────────────────────────────────────────────────
# Firebase / Firestore initialisation
# `db` is bound to None up front so that, when initialisation fails, the
# module-level name still exists (previously it stayed undefined and every
# later use raised NameError). Code using it should treat None as "no storage".
db = None
try:
    cred_data = os.getenv("FIREBASE_CREDENTIALS")
    if not cred_data:
        raise ValueError("FIREBASE_CREDENTIALS 未設置")
    # The env var holds the service-account JSON; "type" is supplied as a
    # default and is overridden if the JSON payload already contains it.
    firebase_cred = credentials.Certificate({"type": "service_account", **json.loads(cred_data)})
    firebase_admin.initialize_app(firebase_cred)
    db = firestore.client()
except Exception as e:
    # Best-effort: the API keeps running without persistence.
    print(f"Firebase 初始化錯誤: {e}")
# ──────────────────────────────────────────────────────────────────────────
# ──────────────────────────────────────────────────────────────────────────
# Load the BERT+LSTM+CNN classifier weights (downloaded once, cached in /tmp)
model_path = "/tmp/model.pth"
model_url = "https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth"
if not os.path.exists(model_path):
    # Download into a temporary file and rename only after it is complete.
    # Writing straight to model_path could leave a truncated file behind if
    # the process dies mid-download, and the exists() check above would then
    # skip the download on every later start.
    tmp_path = model_path + ".part"
    # timeout=(connect, read) so startup cannot hang forever on a stalled
    # connection; stream=True avoids holding the whole file in memory.
    with requests.get(model_url, stream=True, timeout=(10, 120)) as response:
        if response.status_code != 200:
            raise FileNotFoundError("❌ 無法從 Hugging Face 下載 model.pth")
        with open(tmp_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    # Atomic on POSIX: model_path is either absent or complete.
    os.replace(tmp_path, model_path)
    print("✅ 模型下載完成")
model = BertLSTM_CNN_Classifier()
# NOTE(review): torch.load unpickles a file fetched over the network; if the
# installed torch supports it, consider torch.load(..., weights_only=True).
# NOTE(review): `model` is not referenced by the handlers visible in this file
# (they call bert_analyze_text) — confirm it is still needed here.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
# ──────────────────────────────────────────────────────────────────────────
# ──────────────────────────────────────────────────────────────────────────
# Pydantic request / response models
# Request body for the plain-text analysis handler. The field names and types
# below form the public JSON schema of the API.
class TextAnalysisRequest(BaseModel):
    # Message text to analyse.
    text: str
    # Optional caller-supplied user id; copied into the stored analysis record.
    user_id: Optional[str] = None
# Response body returned by the plain-text analysis handler.
class TextAnalysisResponse(BaseModel):
    # Verdict label, copied from the analysis result's "status" key.
    status: str
    # Model confidence, copied from the analysis result's "confidence" key.
    confidence: float
    # Keywords the analyser flagged, from the result's "suspicious_keywords".
    suspicious_keywords: List[str]
    # When the analysis ran (timezone-aware, Asia/Taipei).
    analysis_timestamp: datetime
    # Id of the Firestore document the analysis record was written to.
    text_id: str
# ──────────────────────────────────────────────────────────────────────────
# NOTE(review): no route decorator is visible on this handler in this view; as
# shown it is not registered on `app` — confirm how it is exposed.
async def analyze_text_api(request: TextAnalysisRequest):
    """Classify plain text with the BERT model and return its suspicious keywords.

    The request text is analysed by ``bert_analyze_text``. The text,
    ``user_id`` and raw result are then written, best-effort, to a Firestore
    collection named after the current Asia/Taipei date (``YYYYMMDD``).

    Returns:
        TextAnalysisResponse built from the result's ``status``,
        ``confidence`` and ``suspicious_keywords``, plus the analysis time
        and the Firestore document id.

    Raises:
        HTTPException: status 500 with the error text if the analysis fails
            or its result lacks one of the expected keys.
    """
    try:
        tz = pytz.timezone("Asia/Taipei")
        now = datetime.now(tz)
        # Microseconds are part of the id: with second resolution, two
        # requests in the same second got the same id and `.set()` silently
        # overwrote the earlier record.
        doc_id = now.strftime("%Y%m%dT%H%M%S%f")
        date_str = now.strftime("%Y-%m-%d %H:%M:%S")
        collection = now.strftime("%Y%m%d")
        # NOTE(review): synchronous call inside an async handler — it blocks
        # the event loop while it runs.
        result = bert_analyze_text(request.text)
        record = {
            "text_id": doc_id,
            "text": request.text,
            "user_id": request.user_id,
            "analysis_result": result,
            "timestamp": date_str,
            "type": "text_analysis",
        }
        try:
            db.collection(collection).document(doc_id).set(record)
        except Exception as db_err:
            # Persistence is best-effort and must not affect the response,
            # but the failure is logged instead of being silently dropped.
            print(f"Firestore 寫入失敗 (text_analysis): {db_err}")
        return TextAnalysisResponse(
            status=result["status"],
            confidence=result["confidence"],
            suspicious_keywords=result["suspicious_keywords"],
            analysis_timestamp=now,
            text_id=doc_id,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): no route decorator is visible on this handler in this view; as
# shown it is not registered on `app` — confirm how it is exposed.
async def save_user_feedback(feedback: dict):
    """Store a user-feedback payload in the Firestore ``user_feedback`` collection.

    The stored record is the payload plus a ``timestamp`` string (Asia/Taipei
    local time) and ``used_in_training=False``. Both keys override any
    same-named keys sent by the client. The caller's dict is not modified.

    The Firestore write is best-effort: on failure the error is logged and
    the success message is still returned.
    NOTE(review): the client is therefore told the feedback was recorded even
    when the write failed — decide whether that should be an error instead.

    Returns:
        ``{"message": ...}`` acknowledgement.

    Raises:
        HTTPException: status 500 if building the record fails.
    """
    try:
        tz = pytz.timezone("Asia/Taipei")
        timestamp_str = datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
        # Build a new record instead of mutating the argument in place.
        record = {**feedback, "timestamp": timestamp_str, "used_in_training": False}
        try:
            db.collection("user_feedback").add(record)
        except Exception as db_err:
            # Logged instead of silently swallowed.
            print(f"Firestore 寫入失敗 (user_feedback): {db_err}")
        return {"message": "✅ 已記錄使用者回饋"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# ──────────────────────────────────────────────────────────────────────────
# Enhanced OCR preprocessing, with optional debug-image output


def _save_debug_image(array: np.ndarray, name: str, debug_dir: Optional[str]) -> None:
    """Write ``array`` to ``<debug_dir>/debug_<name>.png``; do nothing if debug_dir is None."""
    if debug_dir is not None:
        Image.fromarray(array).save(os.path.join(debug_dir, f"debug_{name}.png"))


def preprocess_image_for_ocr(
    pil_image: Image.Image,
    *,
    threshold: int = 200,
    scale: float = 3.0,
    invert: bool = True,
    debug_dir: Optional[str] = "/tmp",
) -> Image.Image:
    """Prepare an image for Tesseract OCR.

    Pipeline:
      1. PIL image -> NumPy BGR array.
      2. Grayscale + CLAHE (contrast-limited adaptive histogram equalisation).
      3. Pixels inside an HSV "orange" range are painted white (255) in the
         enhanced grayscale image, to suppress orange backgrounds.
      4. Fixed-level binarisation at ``threshold``.
      5. Bicubic upscaling by ``scale``, then a 5x5 Gaussian blur.

    Args:
        pil_image: Source image in any mode PIL can convert to RGB.
        threshold: Gray level used for binarisation (default 200).
        scale: Upscaling factor applied after binarisation (default 3.0).
        invert: True (default) uses THRESH_BINARY_INV. Pixels at or below
            ``threshold`` (e.g. dark text) become white and brighter pixels
            become black. The Tesseract 4+ docs recommend dark text on a light
            background, so ``invert=False`` may OCR better.
            NOTE(review): worth evaluating.
        debug_dir: Directory for intermediate ``debug_*.png`` images; None
            disables them. Defaults to "/tmp", where the files are overwritten
            on every call, so concurrent requests clobber each other's output.

    Returns:
        The processed single-channel (mode "L") PIL image.
    """
    # 1. convert("RGB") normalises palette / alpha / grayscale inputs;
    #    [:, :, ::-1] reorders RGB -> BGR, the channel order OpenCV expects.
    img = np.array(pil_image.convert("RGB"))[:, :, ::-1]

    # 2. Grayscale + CLAHE
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)
    _save_debug_image(enhanced, "clahe", debug_dir)

    # 3. Orange mask. OpenCV stores 8-bit hue as degrees / 2, so H 5-20 covers
    #    roughly 10-40 degrees; saturation/value >= 100 skips washed-out pixels.
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    lower_orange = np.array([5, 100, 100])
    upper_orange = np.array([20, 255, 255])
    mask_orange = cv2.inRange(hsv, lower_orange, upper_orange)
    _save_debug_image(mask_orange, "mask_orange", debug_dir)

    # Paint masked (orange) pixels white, keep the rest of the grayscale image.
    filtered = enhanced.copy()
    filtered[mask_orange > 0] = 255
    _save_debug_image(filtered, "filtered", debug_dir)

    # 4. Fixed-level binarisation
    mode = cv2.THRESH_BINARY_INV if invert else cv2.THRESH_BINARY
    _, thresh = cv2.threshold(filtered, threshold, 255, mode)
    _save_debug_image(thresh, "thresh", debug_dir)

    # 5. Upscale, then smooth the jagged edges of the binary glyphs
    scaled = cv2.resize(thresh, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
    smoothed = cv2.GaussianBlur(scaled, (5, 5), 0)
    _save_debug_image(smoothed, "processed", debug_dir)
    return Image.fromarray(smoothed)
# ──────────────────────────────────────────────────────────────────────────
# NOTE(review): no route decorator is visible on this handler in this view; the
# debug message suggests it is meant to be served as "analyze-image" — confirm
# it is registered on `app`.
async def analyze_uploaded_image(file: UploadFile = File(...)):
    """OCR an uploaded image and run the scam classifier on the extracted text.

    Steps:
      1. Read the upload and decode it with PIL (fully, so corrupt files fail early).
      2. Preprocess with ``preprocess_image_for_ocr``, which also writes debug images.
      3. Extract text with Tesseract (``chi_tra+eng``, ``--oem 3 --psm 6``).
      4. If any text was found, classify it with ``bert_analyze_text``.

    Returns:
        JSONResponse with ``extracted_text`` and ``analysis_result``. When OCR
        finds no text, ``analysis_result`` is a fixed "無法辨識" placeholder
        with confidence 0.0.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image (a
            client error); 500 for any other failure (OCR, classification, ...).
    """
    # 1) Log receipt of the file
    print("🔍 [DEBUG] 收到 analyze-image,檔名 =", file.filename)
    try:
        # 2) Read the bytes and decode them into a PIL image
        image_bytes = await file.read()
        print("🔍 [DEBUG] 圖片 bytes 長度 =", len(image_bytes))
        try:
            pil_img = Image.open(io.BytesIO(image_bytes))
            # Image.open is lazy; load() forces a full decode so truncated or
            # corrupt uploads are rejected here instead of failing later as a 500.
            pil_img.load()
        except OSError as decode_err:
            # PIL's UnidentifiedImageError (unknown format, empty upload)
            # subclasses OSError.
            raise HTTPException(
                status_code=400, detail=f"無法讀取圖片:{decode_err}"
            ) from decode_err
        print("🔍 [DEBUG] PIL 成功開啟圖片,格式 =", pil_img.format, "大小 =", pil_img.size)

        # 3) Preprocessing (also writes debug images)
        processed_image = preprocess_image_for_ocr(pil_img)

        # 4) Tesseract OCR: Traditional Chinese + English, LSTM/legacy engine
        #    auto-select, page treated as one uniform block of text.
        custom_config = r"-l chi_tra+eng --oem 3 --psm 6"
        extracted_text = pytesseract.image_to_string(
            processed_image,
            config=custom_config,
        ).strip()
        print("🔍 [DEBUG] Tesseract 擷取文字 =", repr(extracted_text))

        # 5) Nothing recognised: return a fixed placeholder result
        if not extracted_text:
            return JSONResponse({
                "extracted_text": "",
                "analysis_result": {
                    "status": "無法辨識",
                    "confidence": 0.0,
                    "suspicious_keywords": ["無法擷取分析結果"],
                },
            })

        # 6) Classify the extracted text
        result = bert_analyze_text(extracted_text)
        return JSONResponse({
            "extracted_text": extracted_text,
            "analysis_result": result,
        })
    except HTTPException:
        # Errors raised deliberately above keep their own status code.
        raise
    except Exception as e:
        # Print the full stack trace for debugging
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"圖片辨識失敗:{str(e)}")
# ──────────────────────────────────────────────────────────────────────────
# Local entry point: serve `app` with uvicorn on all interfaces.
if __name__ == "__main__":
    import uvicorn

    # The listening port comes from $PORT when set, otherwise 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 7860)))