Commit
·
b4d4d63
1
Parent(s):
3dc00eb
Update app.py
Browse files
app.py
CHANGED
|
@@ -155,6 +155,14 @@ def generate_audio(text_prompt, audio_length_in_s=10.0, play_steps_in_s=2.0):
|
|
| 155 |
max_new_tokens = int(frame_rate * audio_length_in_s)
|
| 156 |
play_steps = int(frame_rate * play_steps_in_s)
|
| 157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)
|
| 159 |
|
| 160 |
generation_kwargs = dict(
|
|
|
|
| 155 |
max_new_tokens = int(frame_rate * audio_length_in_s)
|
| 156 |
play_steps = int(frame_rate * play_steps_in_s)
|
| 157 |
|
| 158 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
| 159 |
+
|
| 160 |
+
if device != model.device:
|
| 161 |
+
model.to(device)
|
| 162 |
+
|
| 163 |
+
if device == "cuda:0":
|
| 164 |
+
model.to(device).half();
|
| 165 |
+
|
| 166 |
streamer = MusicgenStreamer(model, device=device, play_steps=play_steps)
|
| 167 |
|
| 168 |
generation_kwargs = dict(
|