# scam-detectorv2 / app.py
# Last updated by jerrynnms in commit bdc5ad6 ("Update app.py").
import os
# ✅ Cache locations recommended for Hugging Face (prevents cache errors).
# These are set before the imports below run, so every library that reads
# them sees the /tmp locations.
os.environ.update(
    {
        "TRANSFORMERS_CACHE": "/tmp/transformers",
        "HF_HOME": "/tmp/huggingface",
        "TORCH_HOME": "/tmp/torch",
    }
)
import io
import json
import uuid
from datetime import datetime
from typing import Optional, List

import firebase_admin
import pytesseract
import pytz
import requests
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from firebase_admin import credentials, firestore
from PIL import Image
from pydantic import BaseModel

from bert_explainer import analyze_text as bert_analyze_text
# ✅ Create the FastAPI application.
app = FastAPI(
    version="2.0.0",
    title="詐騙訊息辨識 API",
    description="使用 BERT 模型分析輸入文字或圖片是否為詐騙內容",
)

# ✅ Cross-origin (CORS) handling: every origin, method and header is allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site issue credentialed requests. Nothing in this file uses cookies or auth
# headers, so consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)

# ✅ Serve static assets (script.js / style.css) under /static.
# NOTE(review): directory="." publishes the entire working directory under
# /static (presumably including this app.py); consider moving the front-end
# assets into a dedicated folder.
app.mount("/static", StaticFiles(directory="."), name="static")
# ✅ Home page: the front-end entry point.
@app.get("/", response_class=FileResponse)
async def serve_index():
    """Return ``index.html`` (resolved relative to the working directory)."""
    index_page = "index.html"
    return FileResponse(index_page)
# ✅ Firebase initialisation (best-effort: the app still starts without it).
# Bind `db` up front so the module-level name always exists. If initialisation
# fails it stays None, and callers can test `db is None` instead of the name
# being left undefined (NameError / ImportError for importers).
db = None
try:
    # The env var holds the service-account JSON. "type" is only a default;
    # a "type" key present in the JSON overrides it.
    cred_data = os.getenv("FIREBASE_CREDENTIALS")
    if not cred_data:
        raise ValueError("FIREBASE_CREDENTIALS 環境變數未設置")
    cred = credentials.Certificate({"type": "service_account", **json.loads(cred_data)})
    firebase_admin.initialize_app(cred)
    db = firestore.client()
except Exception as e:
    # Deliberately non-fatal: only report the error. Endpoints that write to
    # Firestore will then fail with a 500.
    print(f"Firebase 初始化錯誤: {e}")
# ✅ Download the model weights from the Hugging Face Hub (stored under /tmp).
def load_model_from_hub(
    model_url="https://huggingface.co/jerrynnms/scam-model/resolve/main/model.pth",
    model_path="/tmp/model.pth",
    timeout=60,
):
    """Download ``model_url`` to ``model_path`` and return ``model_path``.

    The response is streamed to disk in chunks instead of being buffered in
    memory as a whole. It is written to ``<model_path>.part`` first and then
    moved into place with ``os.replace``. An interrupted download therefore
    never leaves a truncated file at ``model_path``. This matters because the
    module-level loader skips the download whenever that path already exists.

    Args:
        model_url: URL of the checkpoint file.
        model_path: Destination path on local disk.
        timeout: Seconds ``requests`` waits for the connection and for each
            read from the server; ``None`` waits forever.

    Returns:
        ``model_path``.

    Raises:
        FileNotFoundError: If the server does not answer with HTTP 200.
        requests.RequestException: On network errors or timeouts.
    """
    partial_path = f"{model_path}.part"
    try:
        with requests.get(model_url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                raise FileNotFoundError("❌ 無法從 Hugging Face 載入 model.pth")
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        # Atomic on the same filesystem: readers see either no file or a complete one.
        os.replace(partial_path, model_path)
    finally:
        # Remove any leftover partial download; after a successful replace it
        # no longer exists, so this is a no-op.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    return model_path
# ✅ Load the classifier, downloading the checkpoint to /tmp on first start.
model_path = "/tmp/model.pth"
if not os.path.exists(model_path):
    # No cached checkpoint yet: fetch it from the Hub.
    model_path = load_model_from_hub()

from AI_Model_architecture import BertLSTM_CNN_Classifier

model = BertLSTM_CNN_Classifier()
# NOTE(review): torch.load unpickles the downloaded file, which can execute
# arbitrary code if the checkpoint is ever tampered with. If the installed
# torch is >= 1.13 and the file is a plain state_dict, pass weights_only=True.
# NOTE(review): nothing else in this file references `model` (the endpoints
# call bert_analyze_text); confirm whether bert_explainer loads its own copy,
# in which case this load only costs startup time and memory.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()  # switch to evaluation (inference) mode
# ✅ Request / response schemas.
# (Field documentation is kept in comments rather than docstrings so the
# generated OpenAPI schema is unchanged.)
class TextAnalysisRequest(BaseModel):
    # Text to be classified.
    text: str
    # Optional caller identifier; /predict stores it with the analysis record.
    user_id: Optional[str] = None


class TextAnalysisResponse(BaseModel):
    # Label taken from the "status" entry of the model result.
    status: str
    # Score taken from the "confidence" entry of the model result.
    confidence: float
    # Keywords taken from the "suspicious_keywords" entry of the model result.
    suspicious_keywords: List[str]
    # Moment the analysis ran (Asia/Taipei timezone in /predict).
    analysis_timestamp: datetime
    # Id of the Firestore document that holds the full analysis record.
    text_id: str
# ✅ /predict API (text analysis)
@app.post("/predict", response_model=TextAnalysisResponse)
async def analyze_text_api(request: TextAnalysisRequest):
    """Analyse ``request.text`` for scam content and store the result.

    The full record is written to a Firestore collection named after the
    Asia/Taipei date (``YYYYMMDD``) under a unique ``text_id``.

    Raises:
        HTTPException: 500 with the error text if analysis or storage fails.
    """
    try:
        now = datetime.now(pytz.timezone("Asia/Taipei"))
        # A second-resolution timestamp alone is not unique. Two requests in the
        # same second would share a document id, and .set() would silently
        # overwrite the earlier record. The random suffix keeps ids unique while
        # preserving the sortable timestamp prefix.
        doc_id = f"{now:%Y%m%dT%H%M%S}_{uuid.uuid4().hex[:8]}"
        collection = now.strftime("%Y%m%d")

        result = bert_analyze_text(request.text)

        record = {
            "text_id": doc_id,
            "text": request.text,
            "user_id": request.user_id,
            "analysis_result": result,
            "timestamp": now.strftime("%Y-%m-%d %H:%M:%S"),
            "type": "text_analysis",
        }
        db.collection(collection).document(doc_id).set(record)

        return TextAnalysisResponse(
            status=result["status"],
            confidence=result["confidence"],
            suspicious_keywords=result["suspicious_keywords"],
            analysis_timestamp=now,
            text_id=doc_id,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# ✅ /feedback API
@app.post("/feedback")
async def save_user_feedback(feedback: dict):
    """Store a user-feedback payload in the ``user_feedback`` collection.

    The payload is stored as received, plus a ``timestamp`` (Asia/Taipei,
    ``%Y-%m-%d %H:%M:%S``) and ``used_in_training=False``.

    Raises:
        HTTPException: 500 with the error text if stamping or storing fails.
    """
    try:
        taipei_now = datetime.now(pytz.timezone("Asia/Taipei"))
        feedback.update(
            timestamp=taipei_now.strftime("%Y-%m-%d %H:%M:%S"),
            used_in_training=False,
        )
        db.collection("user_feedback").add(feedback)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return {"message": "✅ 已記錄使用者回饋"}
# ✅ /analyze-image API (upload image -> OCR -> scam classification)
@app.post("/analyze-image")
async def analyze_uploaded_image(file: UploadFile = File(...)):
    """OCR an uploaded image and classify the extracted text.

    Returns a dict with the whitespace-stripped ``extracted_text`` and the
    model's ``analysis_result``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 for any other failure (e.g. OCR or model errors).
    """
    try:
        image_bytes = await file.read()
        try:
            image = Image.open(io.BytesIO(image_bytes))
            # Image.open is lazy. Decode now so that empty, corrupt or
            # truncated uploads become a client error (400) rather than a 500
            # raised later from inside the OCR call.
            image.load()
        except (OSError, Image.DecompressionBombError) as e:
            raise HTTPException(status_code=400, detail=f"無法讀取圖片: {e}") from e

        with image:  # release the decoded image once OCR is done
            # OCR the image using Traditional Chinese + English language data.
            extracted_text = pytesseract.image_to_string(image, lang="chi_tra+eng")

        # Classify the OCR text with the model.
        result = bert_analyze_text(extracted_text)
        return {
            "extracted_text": extracted_text.strip(),
            "analysis_result": result,
        }
    except HTTPException:
        # Propagate the deliberate 400 instead of rewrapping it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))