fix: load model state with device mapping
Browse files
src/chatterbox/mtl_tts.py
CHANGED
|
@@ -176,7 +176,7 @@ class ChatterboxMultilingualTTS:
|
|
| 176 |
|
| 177 |
s3gen = S3Gen()
|
| 178 |
s3gen.load_state_dict(
|
| 179 |
-
torch.load(ckpt_dir / "s3gen.pt", weights_only=True)
|
| 180 |
)
|
| 181 |
s3gen.to(device).eval()
|
| 182 |
|
|
|
|
| 176 |
|
| 177 |
s3gen = S3Gen()
|
| 178 |
s3gen.load_state_dict(
|
| 179 |
+
torch.load(ckpt_dir / "s3gen.pt", weights_only=True, map_location=torch.device(device))
|
| 180 |
)
|
| 181 |
s3gen.to(device).eval()
|
| 182 |
|