Spaces:
Runtime error
Runtime error
Fix loading without cuda
Browse files
marble.py
CHANGED
|
@@ -59,7 +59,10 @@ def setup_control_mlp(
|
|
| 59 |
|
| 60 |
net = control_mlp(features)
|
| 61 |
net.load_state_dict(
|
| 62 |
-
torch.load(
|
|
|
|
|
|
|
|
|
|
| 63 |
)
|
| 64 |
net.to(device, dtype=dtype)
|
| 65 |
net.eval()
|
|
|
|
| 59 |
|
| 60 |
net = control_mlp(features)
|
| 61 |
net.load_state_dict(
|
| 62 |
+
torch.load(
|
| 63 |
+
os.path.join(file_dir, f"model_weights/{material_parameter}.pt"),
|
| 64 |
+
map_location=device
|
| 65 |
+
)
|
| 66 |
)
|
| 67 |
net.to(device, dtype=dtype)
|
| 68 |
net.eval()
|