Spaces:
Runtime error
Runtime error
[fix] model load dtype
Browse files — pipeline_ace_step.py (+2 −2)
pipeline_ace_step.py
CHANGED
@@ -143,7 +143,7 @@ class ACEStepPipeline:
         self.music_dcae = MusicDCAE(dcae_checkpoint_path=dcae_checkpoint_path, vocoder_checkpoint_path=vocoder_checkpoint_path)
         self.music_dcae.to(device).eval().to(self.dtype)

-        self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path)
+        self.ace_step_transformer = ACEStepTransformer2DModel.from_pretrained(ace_step_checkpoint_path, torch_dtype=self.dtype)
         self.ace_step_transformer.to(device).eval().to(self.dtype)

         lang_segment = LangSegment()
@@ -158,7 +158,7 @@ class ACEStepPipeline:
         ])
         self.lang_segment = lang_segment
         self.lyric_tokenizer = VoiceBpeTokenizer()
-        text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path).eval()
+        text_encoder_model = UMT5EncoderModel.from_pretrained(text_encoder_checkpoint_path, torch_dtype=self.dtype).eval()
         text_encoder_model = text_encoder_model.to(device).to(self.dtype)
         text_encoder_model.requires_grad_(False)
         self.text_encoder_model = text_encoder_model