playmak3r commited on
Commit
490c02f
·
verified ·
1 Parent(s): c612a94

fix: load model state with device mapping

Browse files
Files changed (1) hide show
  1. src/chatterbox/mtl_tts.py +1 -1
src/chatterbox/mtl_tts.py CHANGED
@@ -176,7 +176,7 @@ class ChatterboxMultilingualTTS:
176
 
177
  s3gen = S3Gen()
178
  s3gen.load_state_dict(
179
- torch.load(ckpt_dir / "s3gen.pt", weights_only=True)
180
  )
181
  s3gen.to(device).eval()
182
 
 
176
 
177
  s3gen = S3Gen()
178
  s3gen.load_state_dict(
179
+ torch.load(ckpt_dir / "s3gen.pt", weights_only=True, map_location=torch.device(device))
180
  )
181
  s3gen.to(device).eval()
182