Xiaomeng1130 commited on
Commit
607a3d1
·
verified ·
1 Parent(s): 5e49d2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
6
 
7
  # ========== 1. Import project modules ==========
8
  try:
 
9
  from stoma_clip import pmc_clip
10
  from stoma_clip.pmc_clip.factory import _rescan_model_configs
11
  from stoma_clip.training.fusion_method import convert_model_to_cls
@@ -27,6 +28,8 @@ NUM_CLASSES = len(LABEL_MAP)
27
  class Args:
28
  def __init__(self):
29
  self.model = "RN50_fusion4"
 
 
30
  self.pretrained = "stoma_clip.pt"
31
  self.num_classes = NUM_CLASSES
32
  self.mlm = True
@@ -42,7 +45,7 @@ PREPROCESS = None
42
  TOKENIZER = None
43
 
44
  def load_model():
45
- """Load model once when Gradio starts."""
46
  global MODEL, PREPROCESS, TOKENIZER
47
  if MODEL is not None:
48
  print("Model already loaded. Returning cached objects.")
@@ -60,8 +63,9 @@ def load_model():
60
  # Move model architecture to GPU/CPU
61
  model.to(args.device).eval()
62
 
63
- # Step 2: Load weights - 使用 args.device 确保加载到 GPU (CUDA)
64
  print(f"3. Loading weights from {args.pretrained} to {args.device}...")
 
65
  state_dict = torch.load(args.pretrained, map_location=args.device)
66
 
67
  print("4. Weights file loaded. Cleaning state dict...")
@@ -155,4 +159,8 @@ iface = gr.Interface(
155
  )
156
 
157
  if __name__ == "__main__":
158
- iface.launch()
 
 
 
 
 
6
 
7
  # ========== 1. Import project modules ==========
8
  try:
9
+ # 尝试导入 stoma_clip 模块(通过 requirements.txt 中的 -e . 安装)
10
  from stoma_clip import pmc_clip
11
  from stoma_clip.pmc_clip.factory import _rescan_model_configs
12
  from stoma_clip.training.fusion_method import convert_model_to_cls
 
28
  class Args:
29
  def __init__(self):
30
  self.model = "RN50_fusion4"
31
+ # 假设 stoma_clip.pt 文件位于应用的根目录(/app),或被您的内部库识别。
32
+ # 确保这个文件是正确的文件名。
33
  self.pretrained = "stoma_clip.pt"
34
  self.num_classes = NUM_CLASSES
35
  self.mlm = True
 
45
  TOKENIZER = None
46
 
47
  def load_model():
48
+ """Load model once when Gradio starts, implementing the singleton pattern."""
49
  global MODEL, PREPROCESS, TOKENIZER
50
  if MODEL is not None:
51
  print("Model already loaded. Returning cached objects.")
 
63
  # Move model architecture to GPU/CPU
64
  model.to(args.device).eval()
65
 
66
+ # Step 2: Load weights - 使用 map_location 确保加载到正确的设备
67
  print(f"3. Loading weights from {args.pretrained} to {args.device}...")
68
+ # 这里的 torch.load 必须依赖于 Dockerfile 预下载或 COPY 进来的文件
69
  state_dict = torch.load(args.pretrained, map_location=args.device)
70
 
71
  print("4. Weights file loaded. Cleaning state dict...")
 
159
  )
160
 
161
  if __name__ == "__main__":
162
+ # 在应用启动时尝试加载模型,如果失败,launch 会抛出异常
163
+ # load_model() # 在 iface.launch() 内部通常会自动触发模型加载,但显式调用可以捕获启动错误
164
+
165
+ # T4 / Docker 环境下使用 0.0.0.0 和默认端口
166
+ iface.launch(server_name="0.0.0.0", server_port=7860)