File size: 3,851 Bytes
efadcd3
b08d402
e24b982
40b9204
 
a84bee6
e24b982
40b9204
 
57f33d4
 
40b9204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57f33d4
 
 
 
 
 
 
 
40b9204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0262299
 
ba8d073
0262299
21569c9
 
 
 
 
a84bee6
 
21569c9
1a30aac
21569c9
0262299
 
21569c9
 
 
0262299
 
 
 
21569c9
40b9204
efadcd3
 
7511b3c
efadcd3
2100448
 
 
7511b3c
efadcd3
7511b3c
 
 
 
 
 
 
 
 
efadcd3
7511b3c
 
 
 
 
 
 
 
efadcd3
40b9204
fd3565c
40b9204
4a87731
40b9204
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import logging

import gradio as gr
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware

from bert_explainer import analyze_text, analyze_image



# Create the FastAPI application that serves the JSON API routes; the Gradio
# UI is later mounted onto this same object via gr.mount_gradio_app.
api = FastAPI()

# Allow cross-origin requests from any origin, with any method and any header,
# so browser clients hosted on other domains do not hit CORS errors.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive. Browsers reject a literal "*" for credentialed requests, and
# Starlette is understood to reflect the caller's Origin in that case. Confirm
# this is intended, and restrict allow_origins if cookies/auth are ever added.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
api.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# API route: health check.
@api.get("/health")
def health_check():
    """Return a static payload confirming the service is reachable."""
    return dict(status="ok")

# API route: GET smoke test that runs the text analyzer on a fixed sample.
@api.get("/run/predict_text")
def test_predict_text():
    """Analyze a fixed sample message in "cnn" mode and summarize the result.

    The sample is a short Chinese sentence meaning "this is a test message".
    Response keys: ``status``, ``confidence`` (rendered as "<value>%") and
    ``suspicious_keywords`` (passed through unchanged from ``analyze_text``).
    """
    sample = "這是測試訊息"
    outcome = analyze_text(sample, explain_mode="cnn")
    status = outcome["status"]
    confidence_text = "{}%".format(outcome["confidence"])
    keywords = outcome["suspicious_keywords"]
    return {
        "status": status,
        "confidence": confidence_text,
        "suspicious_keywords": keywords,
    }

# API route: Gradio-style POST endpoint for text analysis.
@api.post("/run/predict_text")
def predict_text_api(payload: dict):
    """Analyze text posted as ``{"data": [text, mode]}``.

    Returns ``{"data": [status, "<confidence>%", "kw1, kw2, ..."]}`` on success.
    On failure returns ``{"error": message}`` with the default status code
    (unchanged from before) and logs the traceback server-side.
    """
    data = payload.get("data")
    # Validate the envelope up front so malformed requests get a readable
    # message instead of a bare KeyError/ValueError string such as "'data'".
    if not isinstance(data, (list, tuple)) or len(data) != 2:
        return {"error": 'payload must be {"data": [text, mode]}'}
    text, mode = data
    try:
        result = analyze_text(text=text, explain_mode=mode)
        return {
            "data": [
                result["status"],
                f'{result["confidence"]}%',
                ", ".join(result["suspicious_keywords"]),
            ]
        }
    except Exception as e:
        # Keep the best-effort JSON error response, but don't lose the traceback.
        logging.getLogger(__name__).exception("predict_text_api failed")
        return {"error": str(e)}

# API route: image analysis (multipart POST).
@api.post("/run/predict_image")
async def predict_image_api(file: UploadFile = File(...), explain_mode: str = Form(...)):
    """Analyze an uploaded image with the requested explain mode.

    Expects multipart form fields ``file`` (the image) and ``explain_mode``.
    Returns ``{"status", "confidence" ("<value>%"), "suspicious_keywords"}`` on
    success. On failure returns ``{"error": message}`` with the default status
    code and logs the traceback server-side.
    """
    try:
        if not file:
            raise ValueError("未上傳圖片")
        if not explain_mode:
            raise ValueError("未指定分析模式")

        img_bytes = await file.read()
        if not img_bytes:
            raise ValueError("圖片內容為空")
        print(f"收到圖片: {file.filename}, 模式: {explain_mode}")
        # analyze_image is synchronous (presumably model inference). Calling it
        # directly inside this async handler would block the event loop and
        # stall every other request, so run it in the worker thread pool.
        result = await run_in_threadpool(analyze_image, img_bytes, explain_mode=explain_mode)

        return {
            "status": result["status"],
            "confidence": f'{result["confidence"]}%',
            "suspicious_keywords": result["suspicious_keywords"],
        }

    except Exception as e:
        # Keep the best-effort JSON error response, but don't lose the traceback.
        logging.getLogger(__name__).exception("predict_image_api failed")
        return {"error": str(e)}

# Gradio UI callbacks
def predict_text(text, mode):
    """Gradio handler: analyze *text* using *mode*.

    Returns a 3-tuple for the three output boxes: the status, the confidence
    rendered as "<value>%", and the suspicious keywords joined with ", ".
    """
    analysis = analyze_text(text=text, explain_mode=mode)
    status = analysis["status"]
    confidence = "{}%".format(analysis["confidence"])
    keywords = ", ".join(analysis["suspicious_keywords"])
    return status, confidence, keywords

def predict_image(file_path, mode):
    """Gradio handler: analyze the image stored at *file_path* using *mode*.

    *file_path* comes from ``gr.Image(type="filepath")``. Gradio passes ``None``
    when nothing was uploaded, and that case is now shown in the UI via
    ``gr.Error`` instead of crashing with a TypeError from ``open``.
    Returns (status, "<confidence>%", suspicious keywords joined with ", ").
    """
    if not file_path:
        raise gr.Error("未上傳圖片")
    # Read and close the file before running the (potentially slow) analysis.
    with open(file_path, "rb") as f:
        img_bytes = f.read()
    result = analyze_image(img_bytes, explain_mode=mode)
    return result["status"], f"{result['confidence']}%", ", ".join(result["suspicious_keywords"])

# Gradio UI with two tabs ("text mode" and "image mode"). Each tab has an input,
# an analysis-mode selector, a submit button and three result boxes (verdict,
# confidence, suspicious keywords). Components are created in display order.
with gr.Blocks() as demo:
    # Text tab: free-text input.
    with gr.Tab("文字模式"):
        text_input = gr.Textbox(lines=3, label="輸入文字")
        text_mode = gr.Radio(["cnn", "bert", "both"], value="cnn", label="分析模式")
        text_btn = gr.Button("提交")
        text_output1, text_output2, text_output3 = (
            gr.Textbox(label="判斷結果"),
            gr.Textbox(label="置信度"),
            gr.Textbox(label="可疑詞彙"),
        )
        text_btn.click(
            fn=predict_text,
            inputs=[text_input, text_mode],
            outputs=[text_output1, text_output2, text_output3],
        )

    # Image tab: the uploaded image is handed to the callback as a file path.
    with gr.Tab("圖片模式"):
        image_input = gr.Image(type="filepath", label="上傳圖片")
        image_mode = gr.Radio(["cnn", "bert", "both"], value="cnn", label="分析模式")
        image_btn = gr.Button("提交")
        image_output1, image_output2, image_output3 = (
            gr.Textbox(label="判斷結果"),
            gr.Textbox(label="置信度"),
            gr.Textbox(label="可疑詞彙"),
        )
        image_btn.click(
            fn=predict_image,
            inputs=[image_input, image_mode],
            outputs=[image_output1, image_output2, image_output3],
        )

# Mount the Gradio Blocks UI onto the FastAPI app at the site root; the
# returned object is the combined ASGI app that uvicorn serves.
app = gr.mount_gradio_app(api, demo, path="/")


def _serve():
    """Run the combined app with uvicorn, listening on all interfaces, port 7860."""
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    _serve()