Spaces:
Sleeping
Sleeping
| import os | |
| # ✅ Hugging Face 建議路徑(防止 cache 錯誤) | |
| os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers" | |
| os.environ["HF_HOME"] = "/tmp/huggingface" | |
| os.environ["TORCH_HOME"] = "/tmp/torch" | |
| from fastapi import FastAPI, HTTPException, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from datetime import datetime | |
| from typing import Optional, List | |
| from bert_explainer import analyze_text as bert_analyze_text | |
| from firebase_admin import credentials, firestore | |
| import firebase_admin | |
| import pytz | |
| import json | |
| import requests | |
| import torch | |
| from PIL import Image | |
| import pytesseract | |
| import io | |
# Create the FastAPI application. The title and description are user-facing
# API metadata and are kept verbatim.
app = FastAPI(
    title="詐騙訊息辨識 API",
    description="使用 BERT 模型分析輸入文字或圖片是否為詐騙內容",
    version="2.0.0",
)

# Cross-origin (CORS) configuration.
# Allowed origins may be narrowed with the optional CORS_ALLOW_ORIGINS
# environment variable (comma-separated list). When it is unset every origin
# is allowed. Credentialed requests are enabled only when explicit origins are
# configured: combining a wildcard origin with allow_credentials=True would
# let any website issue cookie-authenticated requests against this API.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files so the front end can load script.js / style.css.
# NOTE(review): directory="." publishes the entire working directory under
# /static, which includes this application's own source files. Consider moving
# the assets into a dedicated directory and mounting that instead.
app.mount("/static", StaticFiles(directory="."), name="static")
# Home page: serve index.html.
# The handler has to be registered on the app to be reachable; without a route
# decorator it is just an unused coroutine function.
@app.get("/")
async def serve_index():
    """Return the index.html file located in the working directory."""
    return FileResponse("index.html")
# Firebase initialisation.
# The service-account JSON is read from the FIREBASE_CREDENTIALS environment
# variable. `db` is bound to None up front so that the name always exists:
# if initialisation fails, the failure is printed and `db` stays None instead
# of being left undefined (which surfaced later as a NameError).
db = None
try:
    cred_data = os.getenv("FIREBASE_CREDENTIALS")
    if not cred_data:
        raise ValueError("FIREBASE_CREDENTIALS 環境變數未設置")
    # "type" is supplied as a default; a "type" key inside the JSON payload
    # takes precedence because it is unpacked afterwards.
    cred = credentials.Certificate({"type": "service_account", **json.loads(cred_data)})
    firebase_admin.initialize_app(cred)
    db = firestore.client()
except Exception as e:
    # Deliberately best-effort: the app still starts without Firestore.
    print(f"Firebase 初始化錯誤: {e}")
# Download the model checkpoint from the Hugging Face Hub into /tmp.
def load_model_from_hub(
    model_url="https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth",
    model_path="/tmp/model.pth",
    timeout=60,
):
    """Download the model checkpoint and return the local path it was saved to.

    The response is streamed to a temporary ``.part`` file and moved into
    place only after the download completed, so an interrupted download can
    never leave a truncated file at ``model_path``.

    Args:
        model_url: URL of the checkpoint file.
        model_path: Destination path on the local filesystem.
        timeout: Seconds to wait for the connection and for each read before
            ``requests`` gives up (previously the call could hang forever).

    Returns:
        ``model_path``.

    Raises:
        FileNotFoundError: The server answered with a status other than 200.
        requests.RequestException: Network failure or timeout.
    """
    partial_path = f"{model_path}.part"
    try:
        with requests.get(model_url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                raise FileNotFoundError("❌ 無法從 Hugging Face 載入 model.pth")
            with open(partial_path, "wb") as f:
                # Stream in 1 MiB chunks instead of buffering the whole
                # checkpoint in memory.
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        # Atomic rename: the final path only ever holds a complete file.
        os.replace(partial_path, model_path)
    finally:
        # Only present here when the download failed part-way.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return model_path
# Use the checkpoint already present in /tmp, downloading it only when missing.
model_path = (
    "/tmp/model.pth"
    if os.path.exists("/tmp/model.pth")
    else load_model_from_hub()
)

from AI_Model_architecture import BertLSTM_CNN_Classifier

# Build the classifier, restore its weights on CPU and switch to eval mode.
# NOTE(review): torch.load unpickles a file fetched over the network; consider
# passing weights_only=True if the installed torch version supports it.
model = BertLSTM_CNN_Classifier()
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
# Request / response schemas.
class TextAnalysisRequest(BaseModel):
    # Message text to analyse (required).
    text: str
    # Optional caller-supplied user identifier; stored alongside the result.
    user_id: Optional[str] = None
class TextAnalysisResponse(BaseModel):
    # Fields below are copied from the analysis result by the /predict handler.
    status: str
    confidence: float
    suspicious_keywords: List[str]
    # Time at which the analysis was performed.
    analysis_timestamp: datetime
    # Identifier of the stored analysis record.
    text_id: str
# /predict API (text analysis).
# Registered as a POST route; without the decorator the handler is unreachable.
@app.post("/predict", response_model=TextAnalysisResponse)
async def analyze_text_api(request: TextAnalysisRequest):
    """Analyse the submitted text, store the result, and return a summary.

    The text is passed to ``bert_analyze_text``; the result is expected to
    provide the keys ``status``, ``confidence`` and ``suspicious_keywords``.
    A record of the request is written to Firestore in a collection named
    after the current date (Asia/Taipei).

    Raises:
        HTTPException: 500 with the error text when analysis or storage fails.
    """
    try:
        now = datetime.now(pytz.timezone("Asia/Taipei"))
        # Microseconds are part of the id: with second resolution, two
        # requests arriving in the same second shared one document id and the
        # later write silently replaced the earlier record.
        doc_id = now.strftime("%Y%m%dT%H%M%S%f")

        result = bert_analyze_text(request.text)

        record = {
            "text_id": doc_id,
            "text": request.text,
            "user_id": request.user_id,
            "analysis_result": result,
            "timestamp": now.strftime("%Y-%m-%d %H:%M:%S"),
            "type": "text_analysis",
        }
        # One collection per calendar day, one document per request.
        db.collection(now.strftime("%Y%m%d")).document(doc_id).set(record)

        return TextAnalysisResponse(
            status=result["status"],
            confidence=result["confidence"],
            suspicious_keywords=result["suspicious_keywords"],
            analysis_timestamp=now,
            text_id=doc_id,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# /feedback API.
# Registered as a POST route; without the decorator the handler is unreachable.
@app.post("/feedback")
async def save_user_feedback(feedback: dict):
    """Store a user feedback payload in the ``user_feedback`` collection.

    The payload is saved together with a server-side ``timestamp``
    (Asia/Taipei) and ``used_in_training`` set to False; values for those two
    keys supplied by the client are overridden.

    Raises:
        HTTPException: 500 with the error text when the write fails.
    """
    try:
        taipei_now = datetime.now(pytz.timezone("Asia/Taipei"))
        # Build a new dict rather than mutating the argument in place.
        entry = {
            **feedback,
            "timestamp": taipei_now.strftime("%Y-%m-%d %H:%M:%S"),
            "used_in_training": False,
        }
        db.collection("user_feedback").add(entry)
        return {"message": "✅ 已記錄使用者回饋"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
# /analyze-image API (image upload -> OCR -> scam analysis).
# Registered as a POST route; without the decorator the handler is unreachable.
@app.post("/analyze-image")
async def analyze_uploaded_image(file: UploadFile = File(...)):
    """Run OCR on an uploaded image and analyse the extracted text.

    Returns:
        A dict with ``extracted_text`` (the OCR output, stripped) and
        ``analysis_result`` (whatever ``bert_analyze_text`` returned).

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image;
            500 with the error text for any other failure.
    """
    try:
        image_bytes = await file.read()
        try:
            image = Image.open(io.BytesIO(image_bytes))
            # Image.open is lazy; force decoding here so corrupt or truncated
            # uploads are reported as a client error, not a server error.
            image.load()
        except OSError as e:
            raise HTTPException(
                status_code=400,
                detail="❌ 無法讀取上傳的圖片,請確認檔案格式",
            ) from e

        # OCR with Traditional Chinese + English language data.
        extracted_text = pytesseract.image_to_string(image, lang="chi_tra+eng")

        # Feed the recognised text to the scam classifier.
        result = bert_analyze_text(extracted_text)

        return {
            "extracted_text": extracted_text.strip(),
            "analysis_result": result,
        }
    except HTTPException:
        # Keep the deliberate 400 above from being rewrapped as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e