IFMedTechdemo commited on
Commit
1338ad7
·
verified ·
1 Parent(s): 3cc5be7

Fix UNet architecture: add norm='batch' parameter and use strict=False for state_dict loading

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -29,7 +29,7 @@ def load_model():
29
  filename="model.pt"
30
  )
31
 
32
- # Initialize UNet architecture
33
  model = UNet(
34
  spatial_dims=3,
35
  in_channels=1,
@@ -37,11 +37,12 @@ def load_model():
37
  channels=(16, 32, 64, 128, 256),
38
  strides=(2, 2, 2, 2),
39
  num_res_units=2,
 
40
  )
41
 
42
- # Load weights
43
  checkpoint = torch.load(model_path, map_location=device)
44
- model.load_state_dict(checkpoint)
45
  model.to(device)
46
  model.eval()
47
  return model
@@ -105,7 +106,7 @@ def segment_spleen(input_file):
105
  overlay[seg_slice == 1] = [255, 0, 0] # Red for spleen
106
 
107
  return overlay, output_path, "Segmentation completed successfully!"
108
-
109
  except Exception as e:
110
  return None, None, f"Error: {str(e)}"
111
 
 
29
  filename="model.pt"
30
  )
31
 
32
+ # Initialize UNet architecture with exact parameters from inference.json
33
  model = UNet(
34
  spatial_dims=3,
35
  in_channels=1,
 
37
  channels=(16, 32, 64, 128, 256),
38
  strides=(2, 2, 2, 2),
39
  num_res_units=2,
40
+ norm="batch", # Added: batch normalization as specified in inference.json
41
  )
42
 
43
+ # Load weights with strict=False to handle minor key mismatches
44
  checkpoint = torch.load(model_path, map_location=device)
45
+ model.load_state_dict(checkpoint, strict=False)
46
  model.to(device)
47
  model.eval()
48
  return model
 
106
  overlay[seg_slice == 1] = [255, 0, 0] # Red for spleen
107
 
108
  return overlay, output_path, "Segmentation completed successfully!"
109
+
110
  except Exception as e:
111
  return None, None, f"Error: {str(e)}"
112