import os
import torch
import gradio as gr
from PIL import Image
import numpy as np
import sys
import time

# ========== 1. Import project modules and Model Configuration ==========
try:
    from stoma_clip import pmc_clip
    from stoma_clip.pmc_clip.factory import _rescan_model_configs
    from stoma_clip.training.fusion_method import convert_model_to_cls
    from stoma_clip.training.dataset.utils import encode_mlm
    print("Stoma-CLIP modules imported successfully.")
    sys.stdout.flush()  # force the log output to flush immediately
except ImportError as e:
    print(f"FATAL: Error importing Stoma-CLIP modules: {e}")
    sys.stdout.flush()
    sys.exit(1)

# ========== 2. Model Configuration and Loading ==========
LABEL_MAP = {
    "Irritant dermatitis": 0,
    "Allergic contact dermatitis": 1,
    "Mechanical injury": 2,
    "Folliculitis": 3,
    "Fungal infection": 4,
    "Skin hyperplasia": 5,
    "Parastomal varices": 6,
    "Urate crystals": 7,
    "Cancerous metastasis": 8,
    "Pyoderma gangrenosum": 9,
    "Normal": 10,
}
REVERSE_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
NUM_CLASSES = len(LABEL_MAP)


class Args:
    """Inference-time configuration mirroring the argparse namespace expected by stoma_clip."""

    def __init__(self):
        self.model = "RN50_fusion4"
        self.pretrained = "stoma_clip.pt"
        self.num_classes = NUM_CLASSES
        self.mlm = True
        self.crop_scale = 0.9
        self.context_length = 77
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        sys.stdout.flush()


args = Args()

MODEL = None
PREPROCESS = None
TOKENIZER = None


def load_model():
    """Load the model once in the main thread during application initialization."""
    global MODEL, PREPROCESS, TOKENIZER
    start_time = time.time()
    if MODEL is not None:
        return MODEL, PREPROCESS, TOKENIZER

    print(f"--- Starting Model Load Process at {time.strftime('%H:%M:%S')} ---")
    sys.stdout.flush()  # diagnostic point 1

    try:
        # Step 1: Create model and transforms
        print("1. Rescanning model configs and creating architecture...")
        sys.stdout.flush()  # diagnostic point 2
        _rescan_model_configs()
        model, _, preprocess = pmc_clip.create_model_and_transforms(args)
        model = convert_model_to_cls(model, num_classes=args.num_classes, fusion_method='cross_attention')

        print("2. Model architecture created. Moving to device...")
        sys.stdout.flush()  # diagnostic point 3
        # Move the model architecture to GPU/CPU
        model.to(args.device).eval()

        # Step 2: Load weights - make sure stoma_clip.pt was copied completely and has a plausible size
        print(f"3. Loading weights from {args.pretrained} to {args.device}...")
        sys.stdout.flush()  # diagnostic point 4 - key point: flush the log before the slow I/O starts
        # Load the checkpoint directly onto the target device
        state_dict = torch.load(args.pretrained, map_location=args.device)

        print("4. Weights file loaded. Cleaning state dict...")
        sys.stdout.flush()  # diagnostic point 5
        # Strip the "module." prefix (left by DistributedDataParallel) from every key
        state_dict_clean = {k.replace("module.", "", 1): v for k, v in state_dict['state_dict'].items()}

        # Step 3: Apply weights
        print("5. Loading state dict into model architecture...")
        sys.stdout.flush()  # diagnostic point 6
        model.load_state_dict(state_dict_clean)

        # Step 4: Final setup
        tokenizer = model.tokenizer
        MODEL = model
        PREPROCESS = preprocess
        TOKENIZER = tokenizer

        end_time = time.time()
        print(f"✨ Stoma-CLIP Model loaded successfully! Total time: {end_time - start_time:.2f} seconds.")
        sys.stdout.flush()  # diagnostic point 7
        return MODEL, PREPROCESS, TOKENIZER
    except Exception as e:
        print(f"🔥 Error during model loading: {e}")
        sys.stdout.flush()
        raise RuntimeError(f"Failed to load Stoma-CLIP model: {e}")


# ========== 3. Inference Function ==========
def predict_stoma_clip(image: Image.Image, caption: str):
    # Make sure the model is available at inference time (fallback / lazy loading only)
    try:
        # If loading at startup failed, retry here; otherwise reuse the global MODEL
        if MODEL is None:
            model, preprocess, tokenizer = load_model()
        else:
            model, preprocess, tokenizer = MODEL, PREPROCESS, TOKENIZER
    except RuntimeError:
        return "Model Loading Failed (See Logs)", {}

    # The original inference logic is unchanged from here on
    image = image.convert("RGB")
    device = args.device

    # Move the input image to the GPU/CPU device
    image_tensor = preprocess(image).unsqueeze(0).to(device)

    mask_token, pad_token = '[MASK]', '[PAD]'
    vocab = [v for v in tokenizer.get_vocab().keys() if v not in tokenizer.all_special_tokens]
    bert_input, bert_label = encode_mlm(
        caption=caption,
        vocab=vocab,
        mask_token=mask_token,
        pad_token=pad_token,
        ratio=0.0,
        tokenizer=tokenizer,
        args=args,
    )

    with torch.no_grad():
        inputs = {"images": image_tensor, "bert_input": bert_input, "bert_label": bert_label}
        outputs = model(inputs)
        # Move the results back to the CPU for numpy conversion
        probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
        predicted_class_idx = torch.argmax(outputs, dim=1).item()

    predicted_class_name = REVERSE_LABEL_MAP.get(predicted_class_idx, "Unknown")
    probability_distribution = {REVERSE_LABEL_MAP[i]: float(p) for i, p in enumerate(probs)}

    return predicted_class_name, probability_distribution


# ========== 4. Gradio Interface Setup ==========
image_input = gr.Image(type="pil", label="Upload a stoma image")
caption_input = gr.Textbox(label="Enter a stoma description (e.g. Exudate, epidermal breakdown, ...)")
predicted_label_output = gr.Textbox(label="Predicted class")
prob_output = gr.Label(label="Class probability distribution")

# Find example paths for the Gradio demo (no change to the inference logic)
try:
    example_path_1 = "demo/Irritant_dermatitis.jpg"
    example_path_2 = "demo/Folliculitis.jpg"
    examples_list = []
    if os.path.exists(example_path_1):
        examples_list.append(
            [example_path_1, "Exudate, epidermal breakdown, irregular erythema, pain, confined to contact areas"])
    elif os.path.exists(example_path_2):
        examples_list.append([example_path_2, "Erythema, papules, pustules confined to hair follicles"])
except Exception:
    examples_list = []

# Ignore Gradio's allow_flagging deprecation warning
iface = gr.Interface(
    fn=predict_stoma_clip,
    inputs=[image_input, caption_input],
    outputs=[predicted_label_output, prob_output],
    title="🧪 Stoma-CLIP Classification API Prototype (Gradio)",
    description="Upload a stoma image and enter a clinical description; the model will predict the most likely skin complication class.",
    examples=examples_list,
    allow_flagging="never"
)

if __name__ == "__main__":
    # --- Key fix: force the model to load before Gradio launches, moving the blocking I/O into the startup phase ---
    print("Pre-loading model before Gradio launch to prevent runtime timeout...")
    sys.stdout.flush()
    load_model()
    print("Model loaded. Launching Gradio interface...")
    sys.stdout.flush()
    # Launch Gradio
    iface.launch(server_name="0.0.0.0", server_port=7860)
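
# Example client call (a minimal sketch, kept as comments so it is not executed with the app):
# once the server above is running, the endpoint can be queried programmatically with
# gradio_client. The call shape (handle_file, api_name="/predict") assumes a recent
# gradio_client release and the default endpoint name generated by gr.Interface;
# adjust for your installed version. The image path and caption are illustrative only.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://localhost:7860/")
#   label, probs = client.predict(
#       handle_file("demo/Irritant_dermatitis.jpg"),
#       "Exudate, epidermal breakdown, irregular erythema, pain, confined to contact areas",
#       api_name="/predict",
#   )
#   print(label, probs)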