Adding app.py for CPU inference
app.py
ADDED
@@ -0,0 +1,106 @@
import gradio as gr
import json
import torch
import wavio
from tqdm import tqdm
from huggingface_hub import snapshot_download
from models import AudioDiffusion, DDPMScheduler
from audioldm.audio.stft import TacotronSTFT
from audioldm.variational_autoencoder import AutoencoderKL
from gradio import Markdown


class Tango:
    def __init__(self, name="declare-lab/tango", device="cpu"):

        path = snapshot_download(repo_id=name)

        vae_config = json.load(open("{}/vae_config.json".format(path)))
        stft_config = json.load(open("{}/stft_config.json".format(path)))
        main_config = json.load(open("{}/main_config.json".format(path)))

        self.vae = AutoencoderKL(**vae_config).to(device)
        self.stft = TacotronSTFT(**stft_config).to(device)
        self.model = AudioDiffusion(**main_config).to(device)

        vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location=device)
        stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location=device)
        main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location=device)

        self.vae.load_state_dict(vae_weights)
        self.stft.load_state_dict(stft_weights)
        self.model.load_state_dict(main_weights)

        print("Successfully loaded checkpoint from:", name)

        self.vae.eval()
        self.stft.eval()
        self.model.eval()

        self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder="scheduler")

    def chunks(self, lst, n):
        """ Yield successive n-sized chunks from a list. """
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
        """ Generate audio for a single prompt string. """
        with torch.no_grad():
            latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
            mel = self.vae.decode_first_stage(latents)
            wave = self.vae.decode_to_waveform(mel)
        return wave[0]

    def generate_for_batch(self, prompts, steps=200, guidance=3, samples=1, batch_size=8, disable_progress=True):
        """ Generate audio for a list of prompt strings. """
        outputs = []
        for k in tqdm(range(0, len(prompts), batch_size)):
            batch = prompts[k: k + batch_size]
            with torch.no_grad():
                latents = self.model.inference(batch, self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
                mel = self.vae.decode_first_stage(latents)
                wave = self.vae.decode_to_waveform(mel)
                outputs += [item for item in wave]
        if samples == 1:
            return outputs
        else:
            return list(self.chunks(outputs, samples))


# Initialize Tango model
tango = Tango()


def gradio_generate(prompt):
    output_wave = tango.generate(prompt)

    # Save the output_wave as a temporary WAV file
    output_filename = "temp_output.wav"
    wavio.write(output_filename, output_wave, rate=22050, sampwidth=2)

    return output_filename


# Description text shown on the demo page
description_text = '''
TANGO is a latent diffusion model (LDM) for text-to-audio (TTA) generation. TANGO can generate realistic audio, including human sounds, animal sounds, natural and artificial sounds, and sound effects, from textual prompts. We use the frozen instruction-tuned LLM Flan-T5 as the text encoder and train a UNet-based diffusion model for audio generation. We perform comparably to current state-of-the-art models for TTA across both objective and subjective metrics, despite training the LDM on a 63 times smaller dataset. We release our model, training and inference code, and pre-trained checkpoints for the research community.
'''

# Define Gradio input and output components
input_text = gr.inputs.Textbox(lines=2, label="Prompt")
output_audio = gr.outputs.Audio(label="Generated Audio", type="filepath")

# Create Gradio interface
gr_interface = gr.Interface(
    fn=gradio_generate,
    inputs=input_text,
    outputs=[output_audio],
    title="Tango Audio Generator",
    description="Generate audio using Tango model by providing a text prompt.",
    allow_flagging=False,
    examples=[
        ["A Dog Barking"],
        ["A loud thunderstorm"],
    ],
)

# Launch Gradio app
gr_interface.launch()
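For reference, a minimal sketch of how the Tango class added above could be driven outside the Gradio UI, e.g. for offline CPU batch generation. This is not part of app.py: the import path, prompt texts, and output file names are illustrative assumptions; it reuses only the methods and the 22050 Hz / 16-bit write settings that appear in the file itself.

# Hypothetical offline usage of the Tango class defined in app.py (illustrative only)
import wavio

from app import Tango  # assumes app.py is importable from the working directory

tango = Tango(device="cpu")  # CPU inference, as in the Space

prompts = [
    "A Dog Barking",
    "A loud thunderstorm",
]

# With samples=1, generate_for_batch returns one waveform per prompt
waves = tango.generate_for_batch(prompts, steps=100, batch_size=2)

for i, wave in enumerate(waves):
    # Mirror the sample rate and sample width used by gradio_generate above
    wavio.write("output_{}.wav".format(i), wave, rate=22050, sampwidth=2)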