# Stoma-clip-api / app.py
import os
import torch
import gradio as gr
from PIL import Image
import numpy as np
import sys
import time
# ========== 1. Import project modules ==========
try:
from stoma_clip import pmc_clip
from stoma_clip.pmc_clip.factory import _rescan_model_configs
from stoma_clip.training.fusion_method import convert_model_to_cls
from stoma_clip.training.dataset.utils import encode_mlm
print("Stoma-CLIP modules imported successfully.")
    sys.stdout.flush()  # force flush so output appears immediately
except ImportError as e:
print(f"FATAL: Error importing Stoma-CLIP modules: {e}")
sys.stdout.flush()
sys.exit(1)
# ========== 2. Model Configuration and Loading ==========
LABEL_MAP = {
"Irritant dermatitis": 0, "Allergic contact dermatitis": 1, "Mechanical injury": 2,
"Folliculitis": 3, "Fungal infection": 4, "Skin hyperplasia": 5, "Parastomal varices": 6,
"Urate crystals": 7, "Cancerous metastasis": 8, "Pyoderma gangrenosum": 9, "Normal": 10
}
REVERSE_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
NUM_CLASSES = len(LABEL_MAP)
class Args:
def __init__(self):
self.model = "RN50_fusion4"
self.pretrained = "stoma_clip.pt"
self.num_classes = NUM_CLASSES
self.mlm = True
self.crop_scale = 0.9
self.context_length = 77
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {self.device}")
sys.stdout.flush()
args = Args()
MODEL = None
PREPROCESS = None
TOKENIZER = None
def load_model():
"""Load model once in the main thread during application initialization."""
global MODEL, PREPROCESS, TOKENIZER
start_time = time.time()
if MODEL is not None:
return MODEL, PREPROCESS, TOKENIZER
print(f"--- Starting Model Load Process at {time.strftime('%H:%M:%S')} ---")
    sys.stdout.flush()  # diagnostic point 1
try:
# Step 1: Create model and transforms
print("1. Rescanning model configs and creating architecture...")
        sys.stdout.flush()  # diagnostic point 2
_rescan_model_configs()
model, _, preprocess = pmc_clip.create_model_and_transforms(args)
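        # pmc_clip appears to follow the open_clip-style convention of returning
        # (model, train_transform, eval_preprocess); the training transform is unused here.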
model = convert_model_to_cls(model, num_classes=args.num_classes, fusion_method='cross_attention')
print("2. Model architecture created. Moving to device...")
        sys.stdout.flush()  # diagnostic point 3
# Move model architecture to GPU/CPU
model.to(args.device).eval()
        # Step 2: Load weights - make sure stoma_clip.pt was copied completely and has a reasonable file size
print(f"3. Loading weights from {args.pretrained} to {args.device}...")
        sys.stdout.flush()  # diagnostic point 4 - key: flush before the slow I/O so the log is visible
        # Load the checkpoint directly onto the target device; weights keep their saved dtype
state_dict = torch.load(args.pretrained, map_location=args.device)
print("4. Weights file loaded. Cleaning state dict...")
        sys.stdout.flush()  # diagnostic point 5
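        # Checkpoints saved from a DistributedDataParallel-wrapped model prefix every key
        # with "module."; strip one leading prefix so the keys match the plain model.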
state_dict_clean = {k.replace("module.", "", 1): v for k, v in state_dict['state_dict'].items()}
# Step 3: Apply weights
print("5. Loading state dict into model architecture...")
        sys.stdout.flush()  # diagnostic point 6
model.load_state_dict(state_dict_clean)
# Step 4: Final setup
tokenizer = model.tokenizer
MODEL = model
PREPROCESS = preprocess
TOKENIZER = tokenizer
end_time = time.time()
print(f"✨ Stoma-CLIP Model loaded successfully! Total time: {end_time - start_time:.2f} seconds.")
        sys.stdout.flush()  # diagnostic point 7
return MODEL, PREPROCESS, TOKENIZER
except Exception as e:
print(f"🔥 Error during model loading: {e}")
sys.stdout.flush()
raise RuntimeError(f"Failed to load Stoma-CLIP model: {e}")
# ========== 3. Inference Function ==========
def predict_stoma_clip(image: Image.Image, caption: str):
    # Ensure the model is loaded at inference time (fallback / lazy loading only)
try:
        # If loading failed at startup, this retries; otherwise the global MODEL is reused
if MODEL is None:
model, preprocess, tokenizer = load_model()
else:
model, preprocess, tokenizer = MODEL, PREPROCESS, TOKENIZER
except RuntimeError:
return "Model Loading Failed (See Logs)", {}
    # The original inference logic follows unchanged
image = image.convert("RGB")
device = args.device
    # Move the input image to the GPU/CPU device
image_tensor = preprocess(image).unsqueeze(0).to(device)
mask_token, pad_token = '[MASK]', '[PAD]'
vocab = [v for v in tokenizer.get_vocab().keys() if v not in tokenizer.all_special_tokens]
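    # encode_mlm tokenizes the caption to the fixed context length; with ratio=0.0,
    # no tokens should be masked, so the full caption reaches the model at inference time.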
bert_input, bert_label = encode_mlm(
caption=caption,
vocab=vocab,
mask_token=mask_token,
pad_token=pad_token,
ratio=0.0,
tokenizer=tokenizer,
args=args,
)
with torch.no_grad():
inputs = {"images": image_tensor, "bert_input": bert_input, "bert_label": bert_label}
outputs = model(inputs)
        # Move results back to the CPU for numpy conversion
probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
predicted_class_idx = torch.argmax(outputs, dim=1).item()
predicted_class_name = REVERSE_LABEL_MAP.get(predicted_class_idx, "Unknown")
probability_distribution = {REVERSE_LABEL_MAP[i]: float(p) for i, p in enumerate(probs)}
return predicted_class_name, probability_distribution
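# Minimal local usage sketch (run after load_model(); uses the demo image referenced below):
#   label, probs = predict_stoma_clip(
#       Image.open("demo/Irritant_dermatitis.jpg"),
#       "Exudate, epidermal breakdown, irregular erythema, pain, confined to contact areas",
#   )
#   print(label, probs)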
# ========== 4. Gradio Interface Setup ==========
image_input = gr.Image(type="pil", label="Upload a stoma image")
caption_input = gr.Textbox(label="Enter a stoma description (e.g. Exudate, epidermal breakdown, ...)")
predicted_label_output = gr.Textbox(label="Predicted class")
prob_output = gr.Label(label="Class probability distribution")
# Find example paths for the Gradio demo
try:
example_path_1 = "demo/Irritant_dermatitis.jpg"
example_path_2 = "demo/Folliculitis.jpg"
examples_list = []
if os.path.exists(example_path_1):
examples_list.append(
[example_path_1, "Exudate, epidermal breakdown, irregular erythema, pain, confined to contact areas"])
    if os.path.exists(example_path_2):
examples_list.append([example_path_2, "Erythema, papules, pustules confined to hair follicles"])
except Exception:
examples_list = []
# Ignore Gradio's allow_flagging deprecation warning
iface = gr.Interface(
fn=predict_stoma_clip,
inputs=[image_input, caption_input],
outputs=[predicted_label_output, prob_output],
    title="🧪 Stoma-CLIP Classification API Prototype (Gradio)",
    description="Upload a stoma image and enter a clinical description; the model predicts the most likely skin complication class.",
examples=examples_list,
allow_flagging="never"
)
if __name__ == "__main__":
    # --- Key fix: load the model before Gradio launch so blocking I/O happens at startup ---
print("Pre-loading model before Gradio launch to prevent runtime timeout...")
sys.stdout.flush()
load_model()
print("Model loaded. Launching Gradio interface...")
sys.stdout.flush()
    # Launch Gradio
iface.launch(server_name="0.0.0.0", server_port=7860)
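
# Remote usage sketch: a non-authoritative example assuming a recent gradio_client (with
# handle_file), the default "/predict" endpoint exposed by gr.Interface, and a locally
# running server:
#   from gradio_client import Client, handle_file
#   client = Client("http://localhost:7860")
#   label, probs = client.predict(
#       handle_file("demo/Irritant_dermatitis.jpg"),
#       "Exudate, epidermal breakdown, irregular erythema, pain, confined to contact areas",
#       api_name="/predict",
#   )
#   print(label, probs)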