Spaces:
Runtime error
Runtime error
File size: 7,440 Bytes
3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 3d67bd4 75059a8 607a3d1 75059a8 607a3d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import os
import torch
import gradio as gr
from PIL import Image
import numpy as np
import sys
import time
# ========== 1. Import project modules and Model Configuration ==========
# The app cannot do anything without the Stoma-CLIP package, so a failed
# import is reported on stdout and the process exits with status 1.
try:
    from stoma_clip import pmc_clip
    from stoma_clip.pmc_clip.factory import _rescan_model_configs
    from stoma_clip.training.fusion_method import convert_model_to_cls
    from stoma_clip.training.dataset.utils import encode_mlm
except ImportError as import_error:
    print(f"FATAL: Error importing Stoma-CLIP modules: {import_error}", flush=True)
    sys.exit(1)
else:
    # Flush immediately so the message shows up in container logs right away.
    print("Stoma-CLIP modules imported successfully.", flush=True)
# ========== 2. Model Configuration and Loading ==========
# Diagnosis class names listed in label order: a name's position in this
# tuple is the integer index used to decode the classifier's output.
_STOMA_CLASS_NAMES = (
    "Irritant dermatitis",
    "Allergic contact dermatitis",
    "Mechanical injury",
    "Folliculitis",
    "Fungal infection",
    "Skin hyperplasia",
    "Parastomal varices",
    "Urate crystals",
    "Cancerous metastasis",
    "Pyoderma gangrenosum",
    "Normal",
)
# Class name -> class index.
LABEL_MAP = {name: index for index, name in enumerate(_STOMA_CLASS_NAMES)}
# Class index -> class name (inverse of LABEL_MAP).
REVERSE_LABEL_MAP = {index: name for name, index in LABEL_MAP.items()}
NUM_CLASSES = len(LABEL_MAP)
class Args:
    """Fixed run configuration handed to the Stoma-CLIP model factory and loader.

    Creating an instance picks CUDA when ``torch.cuda.is_available()`` is true
    (CPU otherwise) and prints the selected device.
    """

    def __init__(self):
        use_cuda = torch.cuda.is_available()

        # Architecture name and checkpoint path used when building/loading the model.
        self.model = "RN50_fusion4"
        self.pretrained = "stoma_clip.pt"
        # Size of the classification head; one output per entry in LABEL_MAP.
        self.num_classes = NUM_CLASSES
        self.mlm = True
        self.crop_scale = 0.9
        self.context_length = 77
        self.device = torch.device("cuda" if use_cuda else "cpu")

        print(f"Using device: {self.device}", flush=True)
# Shared configuration consumed by load_model() and predict_stoma_clip().
args = Args()
# Cache populated by load_model(); None means the model has not been loaded yet.
MODEL = PREPROCESS = TOKENIZER = None
def _extract_state_dict(checkpoint):
    """Return the model weights from a loaded checkpoint with a leading "module." removed.

    ``checkpoint`` may either wrap the weights under a ``"state_dict"`` key or be
    the state dict itself. Only a *leading* ``"module."`` is stripped from each
    key (presumably the prefix added by a DataParallel-style wrapper at save
    time); ``"module."`` appearing later in a key is left untouched.
    """
    weights = checkpoint.get("state_dict", checkpoint)
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in weights.items()
    }


def load_model():
    """Build the Stoma-CLIP classifier, load its weights and cache the result.

    On success the model (in eval mode, on ``args.device``), its preprocessing
    transform and its tokenizer are stored in the module-level ``MODEL``,
    ``PREPROCESS`` and ``TOKENIZER``. Later calls return the cached objects
    without reloading.

    Returns:
        tuple: ``(model, preprocess, tokenizer)``.

    Raises:
        RuntimeError: if any step of building or loading the model fails. The
            underlying exception is chained as ``__cause__``.
    """
    global MODEL, PREPROCESS, TOKENIZER
    if MODEL is not None:
        return MODEL, PREPROCESS, TOKENIZER

    start_time = time.time()
    print(f"--- Starting Model Load Process at {time.strftime('%H:%M:%S')} ---", flush=True)
    try:
        # Step 1: create the backbone + transforms, then attach the classification head.
        print("1. Rescanning model configs and creating architecture...", flush=True)
        _rescan_model_configs()
        model, _, preprocess = pmc_clip.create_model_and_transforms(args)
        model = convert_model_to_cls(model, num_classes=args.num_classes, fusion_method='cross_attention')
        print("2. Model architecture created. Moving to device...", flush=True)
        model.to(args.device).eval()

        # Step 2: read the checkpoint. Logs are flushed first so progress is
        # visible before the potentially slow file I/O (stoma_clip.pt must be
        # fully present, not a truncated copy).
        print(f"3. Loading weights from {args.pretrained} to {args.device}...", flush=True)
        # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True;
        # if this checkpoint pickles non-tensor objects, loading may fail here.
        # Confirm against the installed torch version before changing the call.
        checkpoint = torch.load(args.pretrained, map_location=args.device)
        print("4. Weights file loaded. Cleaning state dict...", flush=True)
        state_dict_clean = _extract_state_dict(checkpoint)

        # Step 3: apply the weights.
        print("5. Loading state dict into model architecture...", flush=True)
        model.load_state_dict(state_dict_clean)

        # Step 4: publish to the cache only after every step succeeded; the
        # right-hand side is evaluated fully before any global is assigned.
        MODEL, PREPROCESS, TOKENIZER = model, preprocess, model.tokenizer
    except Exception as e:
        print(f"🔥 Error during model loading: {e}", flush=True)
        raise RuntimeError(f"Failed to load Stoma-CLIP model: {e}") from e

    end_time = time.time()
    print(f"✨ Stoma-CLIP Model loaded successfully! Total time: {end_time - start_time:.2f} seconds.", flush=True)
    return MODEL, PREPROCESS, TOKENIZER
# ========== 3. Inference Function ==========
def predict_stoma_clip(image: "Image.Image", caption: str):
    """Classify a stoma photo together with its clinical text description.

    Args:
        image: the uploaded PIL image, or ``None`` if nothing was uploaded.
        caption: free-text clinical description passed through ``encode_mlm``.

    Returns:
        tuple: ``(predicted_class_name, {class_name: probability})``. If no image
        is supplied or the model cannot be loaded, a status message and an empty
        dict are returned instead.
    """
    # An empty image input presumably reaches the handler as None; stop before
    # any model work instead of failing on ``None.convert``.
    if image is None:
        return "No Image Provided (Please Upload an Image)", {}

    # Fallback / lazy load: normally the model is preloaded at startup, but if
    # that did not happen (or failed) another attempt is made here.
    try:
        if MODEL is None:
            model, preprocess, tokenizer = load_model()
        else:
            model, preprocess, tokenizer = MODEL, PREPROCESS, TOKENIZER
    except RuntimeError:
        return "Model Loading Failed (See Logs)", {}

    device = args.device
    image = image.convert("RGB")
    # Add a batch dimension and move the image to the model's device.
    image_tensor = preprocess(image).unsqueeze(0).to(device)

    mask_token, pad_token = '[MASK]', '[PAD]'
    # Build the special-token set once: ``all_special_tokens`` may be a computed
    # property (it is on Hugging Face tokenizers) and list membership is linear,
    # so evaluating it per vocab entry made this loop needlessly quadratic.
    special_tokens = set(tokenizer.all_special_tokens)
    vocab = [token for token in tokenizer.get_vocab().keys() if token not in special_tokens]
    # ratio=0.0 presumably disables random masking at inference time — confirm
    # against encode_mlm.
    bert_input, bert_label = encode_mlm(
        caption=caption,
        vocab=vocab,
        mask_token=mask_token,
        pad_token=pad_token,
        ratio=0.0,
        tokenizer=tokenizer,
        args=args,
    )
    # NOTE(review): bert_input/bert_label are not moved to `device` here; this
    # assumes encode_mlm or the model handles device placement — verify on GPU.
    with torch.no_grad():
        inputs = {"images": image_tensor, "bert_input": bert_input, "bert_label": bert_label}
        outputs = model(inputs)
        # Move results back to the CPU for the numpy conversion.
        probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
        predicted_class_idx = torch.argmax(outputs, dim=1).item()
    predicted_class_name = REVERSE_LABEL_MAP.get(predicted_class_idx, "Unknown")
    probability_distribution = {REVERSE_LABEL_MAP[i]: float(p) for i, p in enumerate(probs)}
    return predicted_class_name, probability_distribution
# ========== 4. Gradio Interface Setup ==========
# Inputs: the stoma photo (handed to the handler as a PIL image) and the
# free-text clinical description.
image_input = gr.Image(label="上传造口图片", type="pil")
caption_input = gr.Textbox(
    label="输入造口描述文本 (例如: Exudate, epidermal breakdown, ...)",
)
# Outputs: the predicted class name and the per-class probability distribution.
predicted_label_output = gr.Textbox(label="预测类别")
prob_output = gr.Label(label="类别概率分布")
# Example inputs for the Gradio demo. Each example is offered independently,
# and only if its image file exists on disk.
def _existing_examples(candidates):
    """Return ``[image_path, caption]`` rows for candidates whose image file exists.

    Args:
        candidates: iterable of ``(image_path, caption)`` pairs.

    Returns:
        list: one ``[image_path, caption]`` list per existing image, in the
        candidates' order (one value per Interface input).
    """
    # os.path.exists reports problems as False rather than raising, so no
    # try/except is needed around it.
    return [[path, text] for path, text in candidates if os.path.exists(path)]


example_path_1 = "demo/Irritant_dermatitis.jpg"
example_path_2 = "demo/Folliculitis.jpg"
examples_list = _existing_examples([
    (example_path_1, "Exudate, epidermal breakdown, irregular erythema, pain, confined to contact areas"),
    (example_path_2, "Erythema, papules, pustules confined to hair follicles"),
])
# Assemble the web UI around predict_stoma_clip. Gradio's deprecation warning
# for `allow_flagging` is deliberately ignored.
# NOTE(review): newer Gradio releases may expect `flagging_mode` instead of
# `allow_flagging` — confirm against the installed Gradio version.
iface = gr.Interface(
    fn=predict_stoma_clip,
    title="🧪 Stoma-CLIP 分类 API 原型 (Gradio)",
    description="请上传造口图片并输入临床描述,模型将预测最可能的皮肤并发症类别。",
    inputs=[image_input, caption_input],
    outputs=[predicted_label_output, prob_output],
    examples=examples_list,
    allow_flagging="never",
)
if __name__ == "__main__":
    # Load the model eagerly so the slow checkpoint I/O happens during startup
    # rather than inside the first request, where it could hit a timeout.
    print("Pre-loading model before Gradio launch to prevent runtime timeout...", flush=True)
    load_model()
    print("Model loaded. Launching Gradio interface...", flush=True)

    # Serve on all interfaces on port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)
|