Update models/unet.py
Browse files- models/unet.py +1 -1
models/unet.py
CHANGED
|
@@ -25,7 +25,7 @@ class CustomLayerNorm(nn.LayerNorm):
|
|
| 25 |
def replace_layer_norm(model):
|
| 26 |
for name, module in model.named_children():
|
| 27 |
if isinstance(module, nn.LayerNorm):
|
| 28 |
-
setattr(model, name, CustomLayerNorm(module.normalized_shape, elementwise_affine=module.elementwise_affine)
|
| 29 |
else:
|
| 30 |
replace_layer_norm(module) # Recursively apply to all submodules
|
| 31 |
|
|
|
|
| 25 |
def replace_layer_norm(model):
|
| 26 |
for name, module in model.named_children():
|
| 27 |
if isinstance(module, nn.LayerNorm):
|
| 28 |
+
setattr(model, name, CustomLayerNorm(module.normalized_shape, elementwise_affine=module.elementwise_affine).cuda())
|
| 29 |
else:
|
| 30 |
replace_layer_norm(module) # Recursively apply to all submodules
|
| 31 |
|