import os
import torch
import gradio as gr
from PIL import Image
import numpy as np
import sys
import time

# ========== 1. Import project modules and Model Configuration ==========
try:
    from stoma_clip import pmc_clip
    from stoma_clip.pmc_clip.factory import _rescan_model_configs
    from stoma_clip.training.fusion_method import convert_model_to_cls
    from stoma_clip.training.dataset.utils import encode_mlm
    print("Stoma-CLIP modules imported successfully.")
    sys.stdout.flush()  # Force-flush the output
except ImportError as e:
    print(f"FATAL: Error importing Stoma-CLIP modules: {e}")
    sys.stdout.flush()
    sys.exit(1)
# ========== 2. Model Configuration and Loading ==========
LABEL_MAP = {
    "Irritant dermatitis": 0, "Allergic contact dermatitis": 1, "Mechanical injury": 2,
    "Folliculitis": 3, "Fungal infection": 4, "Skin hyperplasia": 5, "Parastomal varices": 6,
    "Urate crystals": 7, "Cancerous metastasis": 8, "Pyoderma gangrenosum": 9, "Normal": 10
}
REVERSE_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
NUM_CLASSES = len(LABEL_MAP)

class Args:
    def __init__(self):
        self.model = "RN50_fusion4"
        self.pretrained = "stoma_clip.pt"
        self.num_classes = NUM_CLASSES
        self.mlm = True
        self.crop_scale = 0.9
        self.context_length = 77
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        sys.stdout.flush()

args = Args()

MODEL = None
PREPROCESS = None
TOKENIZER = None
def load_model():
    """Load the model once in the main thread during application initialization."""
    global MODEL, PREPROCESS, TOKENIZER
    start_time = time.time()
    if MODEL is not None:
        return MODEL, PREPROCESS, TOKENIZER

    print(f"--- Starting Model Load Process at {time.strftime('%H:%M:%S')} ---")
    sys.stdout.flush()  # Diagnostic point 1
    try:
        # Step 1: Create model and transforms
        print("1. Rescanning model configs and creating architecture...")
        sys.stdout.flush()  # Diagnostic point 2
        _rescan_model_configs()
        model, _, preprocess = pmc_clip.create_model_and_transforms(args)
        model = convert_model_to_cls(model, num_classes=args.num_classes, fusion_method='cross_attention')
        print("2. Model architecture created. Moving to device...")
        sys.stdout.flush()  # Diagnostic point 3

        # Move the model architecture to the GPU/CPU
        model.to(args.device).eval()

        # Step 2: Load weights - make sure stoma_clip.pt was copied completely (check the file size)
        print(f"3. Loading weights from {args.pretrained} to {args.device}...")
        sys.stdout.flush()  # Diagnostic point 4 - key: flush the log before the slow I/O starts

        # Load in float32; converting to half precision afterwards, if the model supports it,
        # could speed things up (see the sketch after this function)
        state_dict = torch.load(args.pretrained, map_location=args.device)
        print("4. Weights file loaded. Cleaning state dict...")
        sys.stdout.flush()  # Diagnostic point 5
        # Strip the "module." prefix left by DataParallel/DDP training
        state_dict_clean = {k.replace("module.", "", 1): v for k, v in state_dict['state_dict'].items()}

        # Step 3: Apply weights
        print("5. Loading state dict into model architecture...")
        sys.stdout.flush()  # Diagnostic point 6
        model.load_state_dict(state_dict_clean)

        # Step 4: Final setup
        tokenizer = model.tokenizer
        MODEL = model
        PREPROCESS = preprocess
        TOKENIZER = tokenizer

        end_time = time.time()
        print(f"✨ Stoma-CLIP Model loaded successfully! Total time: {end_time - start_time:.2f} seconds.")
        sys.stdout.flush()  # Diagnostic point 7
        return MODEL, PREPROCESS, TOKENIZER
    except Exception as e:
        print(f"🔥 Error during model loading: {e}")
        sys.stdout.flush()
        raise RuntimeError(f"Failed to load Stoma-CLIP model: {e}")
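# A minimal half-precision sketch following the comment in load_model() above. This helper
# is illustrative and not called anywhere; it assumes (unverified) that the fused
# Stoma-CLIP model and its cross-attention head behave correctly in fp16. If used, input
# image tensors would also need to be cast with .half() before inference.
def to_half_precision_if_cuda(model, device):
    """Cast the model to float16 on CUDA devices; leave it in float32 on CPU."""
    if device.type == "cuda":
        return model.half()
    return model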
# ========== 3. Inference Function ==========
def predict_stoma_clip(image: Image.Image, caption: str):
    # Make sure the model is loaded at inference time (fallback / lazy loading only)
    try:
        # If loading failed at startup, this retries; it relies on the global MODEL variable
        if MODEL is None:
            model, preprocess, tokenizer = load_model()
        else:
            model, preprocess, tokenizer = MODEL, PREPROCESS, TOKENIZER
    except RuntimeError:
        return "Model Loading Failed (See Logs)", {}

    # Original inference logic
    image = image.convert("RGB")
    device = args.device

    # Move the input image to the GPU/CPU
    image_tensor = preprocess(image).unsqueeze(0).to(device)

    mask_token, pad_token = '[MASK]', '[PAD]'
    vocab = [v for v in tokenizer.get_vocab().keys() if v not in tokenizer.all_special_tokens]
    bert_input, bert_label = encode_mlm(
        caption=caption,
        vocab=vocab,
        mask_token=mask_token,
        pad_token=pad_token,
        ratio=0.0,
        tokenizer=tokenizer,
        args=args,
    )

    with torch.no_grad():
        inputs = {"images": image_tensor, "bert_input": bert_input, "bert_label": bert_label}
        outputs = model(inputs)
        # Move results back to the CPU for numpy conversion
        probs = torch.softmax(outputs, dim=1).cpu().numpy()[0]
        predicted_class_idx = torch.argmax(outputs, dim=1).item()

    predicted_class_name = REVERSE_LABEL_MAP.get(predicted_class_idx, "Unknown")
    probability_distribution = {REVERSE_LABEL_MAP[i]: float(p) for i, p in enumerate(probs)}

    return predicted_class_name, probability_distribution
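# Illustrative direct call (hypothetical, outside Gradio), assuming the demo image below
# exists and load_model() has already run:
#   label, probs = predict_stoma_clip(
#       Image.open("demo/Irritant_dermatitis.jpg"),
#       "Exudate, epidermal breakdown, irregular erythema, pain, confined to contact areas",
#   )
#   print(label, probs[label])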
# ========== 4. Gradio Interface Setup ==========
image_input = gr.Image(type="pil", label="Upload stoma image")
caption_input = gr.Textbox(label="Enter a stoma description (e.g., Exudate, epidermal breakdown, ...)")
predicted_label_output = gr.Textbox(label="Predicted class")
prob_output = gr.Label(label="Class probability distribution")
# Find example paths for the Gradio demo
try:
    example_path_1 = "demo/Irritant_dermatitis.jpg"
    example_path_2 = "demo/Folliculitis.jpg"
    examples_list = []
    if os.path.exists(example_path_1):
        examples_list.append(
            [example_path_1, "Exudate, epidermal breakdown, irregular erythema, pain, confined to contact areas"])
    if os.path.exists(example_path_2):
        examples_list.append([example_path_2, "Erythema, papules, pustules confined to hair follicles"])
except Exception:
    examples_list = []
# Ignore Gradio's allow_flagging warning
iface = gr.Interface(
    fn=predict_stoma_clip,
    inputs=[image_input, caption_input],
    outputs=[predicted_label_output, prob_output],
    title="🧪 Stoma-CLIP Classification API Prototype (Gradio)",
    description="Upload a stoma image and enter a clinical description; the model will predict the most likely skin complication class.",
    examples=examples_list,
    allow_flagging="never"
)
if __name__ == "__main__":
    # --- Key fix: force the model to load before the Gradio launch, moving the blocking I/O to the startup phase ---
    print("Pre-loading model before Gradio launch to prevent runtime timeout...")
    sys.stdout.flush()
    load_model()
    print("Model loaded. Launching Gradio interface...")
    sys.stdout.flush()

    # Launch Gradio
    iface.launch(server_name="0.0.0.0", server_port=7860)