File size: 7,440 Bytes
3d67bd4
 
 
 
 
75059a8
 
3d67bd4
75059a8
3d67bd4
 
 
 
 
 
75059a8
3d67bd4
75059a8
 
 
3d67bd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75059a8
3d67bd4
 
 
 
 
 
 
75059a8
3d67bd4
75059a8
 
3d67bd4
 
 
75059a8
 
 
3d67bd4
 
75059a8
 
 
3d67bd4
 
 
 
75059a8
3d67bd4
 
 
 
75059a8
3d67bd4
75059a8
 
 
3d67bd4
 
 
75059a8
 
3d67bd4
 
 
 
75059a8
 
3d67bd4
 
 
 
 
 
 
 
75059a8
 
 
 
3d67bd4
 
 
 
75059a8
3d67bd4
 
 
 
75059a8
3d67bd4
75059a8
 
 
 
 
 
3d67bd4
 
75059a8
 
3d67bd4
 
75059a8
3d67bd4
 
75059a8
3d67bd4
 
75059a8
3d67bd4
 
 
 
 
 
 
 
 
75059a8
3d67bd4
 
 
 
 
 
75059a8
3d67bd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75059a8
 
 
607a3d1
75059a8
 
 
 
 
 
607a3d1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import os
import torch
import gradio as gr
from PIL import Image
import numpy as np
import sys
import time

# ========== 1. Import project modules ==========
# The app cannot work without the Stoma-CLIP package: a failed import is
# reported and the process exits with status 1. Output is flushed immediately
# so the message is not left sitting in a buffer.
try:
    from stoma_clip import pmc_clip
    from stoma_clip.pmc_clip.factory import _rescan_model_configs
    from stoma_clip.training.fusion_method import convert_model_to_cls
    from stoma_clip.training.dataset.utils import encode_mlm
except ImportError as import_error:
    print(f"FATAL: Error importing Stoma-CLIP modules: {import_error}", flush=True)
    sys.exit(1)
else:
    print("Stoma-CLIP modules imported successfully.", flush=True)

# ========== 2. Model Configuration and Loading ==========
# Class name -> class index. The index of each name is its position in the
# tuple below, so the order of the names must not be changed.
LABEL_MAP = {
    class_name: class_index
    for class_index, class_name in enumerate((
        "Irritant dermatitis",
        "Allergic contact dermatitis",
        "Mechanical injury",
        "Folliculitis",
        "Fungal infection",
        "Skin hyperplasia",
        "Parastomal varices",
        "Urate crystals",
        "Cancerous metastasis",
        "Pyoderma gangrenosum",
        "Normal",
    ))
}
# Class index -> class name, used to turn model outputs back into labels.
REVERSE_LABEL_MAP = dict(zip(LABEL_MAP.values(), LABEL_MAP.keys()))
NUM_CLASSES = len(LABEL_MAP)

class Args:
    """Settings object passed to the model factory and to ``encode_mlm``.

    Constructing an instance also prints the selected torch device.
    """

    def __init__(self):
        # Use the GPU when PyTorch can see one, otherwise fall back to CPU.
        device_name = "cuda" if torch.cuda.is_available() else "cpu"

        self.model = "RN50_fusion4"
        self.pretrained = "stoma_clip.pt"  # checkpoint path read by load_model()
        self.num_classes = NUM_CLASSES
        self.mlm = True
        self.crop_scale = 0.9
        self.context_length = 77
        self.device = torch.device(device_name)
        print(f"Using device: {self.device}", flush=True)


# Single shared configuration instance used by the rest of the module.
args = Args()

MODEL = None
PREPROCESS = None
TOKENIZER = None

def load_model():
    """Load model once in the main thread during application initialization."""
    global MODEL, PREPROCESS, TOKENIZER
    
    start_time = time.time()
    if MODEL is not None:
        return MODEL, PREPROCESS, TOKENIZER
        
    print(f"--- Starting Model Load Process at {time.strftime('%H:%M:%S')} ---")
    sys.stdout.flush() # 诊断点 1
    
    try:
        # Step 1: Create model and transforms
        print("1. Rescanning model configs and creating architecture...")
        sys.stdout.flush() # 诊断点 2
        
        _rescan_model_configs()
        model, _, preprocess = pmc_clip.create_model_and_transforms(args)
        model = convert_model_to_cls(model, num_classes=args.num_classes, fusion_method='cross_attention')
        print("2. Model architecture created. Moving to device...")
        sys.stdout.flush() # 诊断点 3
        
        # Move model architecture to GPU/CPU
        model.to(args.device).eval()
        
        # Step 2: Load weights - 必须确保 stoma_clip.pt 文件大小合理或复制完整
        print(f"3. Loading weights from {args.pretrained} to {args.device}...")
        sys.stdout.flush() # 诊断点 4 - 关键点:在执行耗时 I/O 前确保日志已输出
        
        # 强制使用 Float32 加载,然后转换为半精度,如果模型支持的话,有助于加速传输
        state_dict = torch.load(args.pretrained, map_location=args.device) 
        
        print("4. Weights file loaded. Cleaning state dict...")
        sys.stdout.flush() # 诊断点 5
        
        state_dict_clean = {k.replace("module.", "", 1): v for k, v in state_dict['state_dict'].items()}
        
        # Step 3: Apply weights
        print("5. Loading state dict into model architecture...")
        sys.stdout.flush() # 诊断点 6
        
        model.load_state_dict(state_dict_clean)
        
        # Step 4: Final setup
        tokenizer = model.tokenizer
        MODEL = model
        PREPROCESS = preprocess
        TOKENIZER = tokenizer
        
        end_time = time.time()
        print(f"✨ Stoma-CLIP Model loaded successfully! Total time: {end_time - start_time:.2f} seconds.")
        sys.stdout.flush() # 诊断点 7
        
        return MODEL, PREPROCESS, TOKENIZER
        
    except Exception as e:
        print(f"🔥 Error during model loading: {e}")
        sys.stdout.flush()
        raise RuntimeError(f"Failed to load Stoma-CLIP model: {e}")

# ========== 3. Inference Function ==========
def predict_stoma_clip(image: Image.Image, caption: str):
    """Classify a stoma photograph together with its clinical description.

    Args:
        image: PIL image from the Gradio image input. ``None`` (nothing
            uploaded) is reported to the user instead of raising.
        caption: Free-text clinical description; ``None`` is treated as an
            empty string.

    Returns:
        A tuple ``(predicted_class_name, probabilities)`` where
        ``probabilities`` maps every class name to its softmax probability.
        When no image was supplied or the model cannot be loaded, the first
        element is an explanatory message and the second is an empty dict.
    """
    # Gradio passes None when the user submits without uploading an image.
    if image is None:
        return "No image provided", {}
    caption = caption or ""

    # The model is normally pre-loaded at startup; loading here is only a
    # lazy fallback for when that did not happen.
    try:
        if MODEL is None:
            model, preprocess, tokenizer = load_model()
        else:
            model, preprocess, tokenizer = MODEL, PREPROCESS, TOKENIZER
    except RuntimeError:
        return "Model Loading Failed (See Logs)", {}

    image = image.convert("RGB")
    device = args.device

    # Preprocess, add a batch dimension and move the image to the model device.
    image_tensor = preprocess(image).unsqueeze(0).to(device)

    mask_token, pad_token = '[MASK]', '[PAD]'
    # A set gives O(1) membership tests while filtering the full vocabulary.
    special_tokens = set(tokenizer.all_special_tokens)
    vocab = [token for token in tokenizer.get_vocab() if token not in special_tokens]

    # ratio=0.0 presumably disables masking at inference -- confirm against
    # encode_mlm.
    bert_input, bert_label = encode_mlm(
        caption=caption,
        vocab=vocab,
        mask_token=mask_token,
        pad_token=pad_token,
        ratio=0.0,
        tokenizer=tokenizer,
        args=args,
    )

    with torch.no_grad():
        # NOTE(review): bert_input/bert_label are passed on exactly as
        # encode_mlm returns them; only the image is moved to the device here.
        inputs = {"images": image_tensor, "bert_input": bert_input, "bert_label": bert_label}
        outputs = model(inputs)
        # Bring the probabilities back to the CPU for the numpy conversion.
        probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
        predicted_class_idx = torch.argmax(outputs, dim=1).item()

    predicted_class_name = REVERSE_LABEL_MAP.get(predicted_class_idx, "Unknown")
    probability_distribution = {REVERSE_LABEL_MAP[i]: float(p) for i, p in enumerate(probs)}
    return predicted_class_name, probability_distribution

# ========== 4. Gradio Interface Setup ==========
# Input widgets: the stoma photo (handed to the callback as a PIL image) and
# the free-text clinical description.
image_input = gr.Image(label="上传造口图片", type="pil")
caption_input = gr.Textbox(
    label="输入造口描述文本 (例如: Exudate, epidermal breakdown, ...)",
)

# Output widgets: the predicted class name and the per-class probabilities.
predicted_label_output = gr.Textbox(label="预测类别")
prob_output = gr.Label(label="类别概率分布")

# Build the demo examples from whichever bundled images are present on disk.
# Each candidate is checked independently, so both examples are offered when
# both files exist (the second is no longer skipped whenever the first exists).
example_path_1 = "demo/Irritant_dermatitis.jpg"
example_path_2 = "demo/Folliculitis.jpg"
_EXAMPLE_CANDIDATES = (
    (example_path_1, "Exudate, epidermal breakdown, irregular erythema, pain, confined to contact areas"),
    (example_path_2, "Erythema, papules, pustules confined to hair follicles"),
)
examples_list = [
    [path, caption]
    for path, caption in _EXAMPLE_CANDIDATES
    if os.path.exists(path)
]

# Wire the prediction callback to the widgets. Flagging is switched off via
# allow_flagging; any Gradio warning about that argument can be ignored.
iface = gr.Interface(
    title="🧪 Stoma-CLIP 分类 API 原型 (Gradio)",
    description="请上传造口图片并输入临床描述,模型将预测最可能的皮肤并发症类别。",
    fn=predict_stoma_clip,
    inputs=[image_input, caption_input],
    outputs=[predicted_label_output, prob_output],
    examples=examples_list,
    allow_flagging="never",
)

if __name__ == "__main__":
    # Load the weights before the server starts so the slow checkpoint I/O
    # happens during startup rather than inside the first request.
    print("Pre-loading model before Gradio launch to prevent runtime timeout...", flush=True)
    load_model()
    print("Model loaded. Launching Gradio interface...", flush=True)

    # Serve on all network interfaces, port 7860.
    iface.launch(server_port=7860, server_name="0.0.0.0")