map location=cpu
Browse files- modules.py +1 -1
modules.py
CHANGED
|
@@ -381,7 +381,7 @@ class CLAP(nn.Module):
|
|
| 381 |
|
| 382 |
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 383 |
|
| 384 |
-
state_dict = torch.load(ckpt_path)["model"]
|
| 385 |
self.load_state_dict(self.clean_state_dict(state_dict))
|
| 386 |
print("Loaded pretrained CLAP checkpoint.")
|
| 387 |
|
|
|
|
| 381 |
|
| 382 |
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 383 |
|
| 384 |
+
state_dict = torch.load(ckpt_path, map_location="cpu")["model"]
|
| 385 |
self.load_state_dict(self.clean_state_dict(state_dict))
|
| 386 |
print("Loaded pretrained CLAP checkpoint.")
|
| 387 |
|