Update audio_foundation_models.py

audio_foundation_models.py  (+418, -192)  CHANGED
@@ -1,5 +1,6 @@
 import sys
 import os
+
 sys.path.append(os.path.dirname(os.path.realpath(__file__)))
 sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
 sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'NeuralSeq'))
@@ -53,6 +54,7 @@ from target_sound_detection.src.models import event_labels
 from target_sound_detection.src.utils import median_filter, decode_with_timestamps
 import clip

+
 def prompts(name, description):
     def decorator(func):
         func.name = name
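Note: the prompts decorator kept in this hunk only attaches name and description attributes to a tool method so the agent can enumerate its tools. Below is a minimal self-contained sketch of how it is used; the example tool and strings are illustrative, and the func.description assignment is assumed to mirror the visible func.name line.

def prompts(name, description):
    # Attach tool metadata to the decorated function and return it unchanged.
    def decorator(func):
        func.name = name
        func.description = description
        return func
    return decorator


class EchoTool:
    @prompts(name="Echo Text",
             description="Returns the input text unchanged. The input should be a string.")
    def inference(self, text):
        return text


tool = EchoTool()
# An agent-side registry can read these attributes to build its tool list.
print(tool.inference.name, ":", tool.inference.description)
print(tool.inference("hello"))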
@@ -61,10 +63,11 @@ def prompts(name, description):

     return decorator

+
 def initialize_model(config, ckpt, device):
     config = OmegaConf.load(config)
     model = instantiate_from_config(config.model)
-    model.load_state_dict(torch.load(ckpt,map_location='cpu')["state_dict"], strict=False)
+    model.load_state_dict(torch.load(ckpt, map_location='cpu')["state_dict"], strict=False)

     model = model.to(device)
     model.cond_stage_model.to(model.device)
@@ -72,29 +75,48 @@ def initialize_model(config, ckpt, device):
     sampler = DDIMSampler(model)
     return sampler

+
 def initialize_model_inpaint(config, ckpt):
     config = OmegaConf.load(config)
     model = instantiate_from_config(config.model)
-    model.load_state_dict(torch.load(ckpt,map_location='cpu')["state_dict"], strict=False)
+    model.load_state_dict(torch.load(ckpt, map_location='cpu')["state_dict"], strict=False)
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     model = model.to(device)
-    print(model.device,device,model.cond_stage_model.device)
+    print(model.device, device, model.cond_stage_model.device)
     sampler = DDIMSampler(model)
     return sampler
-
-
+
+
+def select_best_audio(prompt, wav_list):
     text_embeddings = clap_model.get_text_embeddings([prompt])
     score_list = []
     for data in wav_list:
-        sr,wav = data
-        audio_embeddings = clap_model.get_audio_embeddings([(torch.FloatTensor(wav),sr)], resample=True)
-        score = clap_model.compute_similarity(audio_embeddings, text_embeddings,
         score_list.append(score)
     max_index = np.array(score_list).argmax()
-    print(score_list,max_index)
     return wav_list[max_index]


 class T2I:
     def __init__(self, device):
         print("Initializing T2I to %s" % device)
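Note: the removed ranking lines above are truncated in this view. As a hedged reconstruction of what select_best_audio does, the sketch below scores every (sample_rate, waveform) candidate against the prompt with CLAP embeddings and keeps the best match; passing clap_model as an argument and the simplified compute_similarity call are assumptions, since the exact call is cut off.

import numpy as np
import torch


def select_best_audio(prompt, wav_list, clap_model):
    # wav_list holds (sample_rate, waveform) tuples, e.g. several vocoder outputs.
    text_embeddings = clap_model.get_text_embeddings([prompt])
    scores = []
    for sr, wav in wav_list:
        audio_embeddings = clap_model.get_audio_embeddings(
            [(torch.FloatTensor(wav), sr)], resample=True)
        # Text-audio similarity for this candidate; higher means a better match.
        score = clap_model.compute_similarity(audio_embeddings, text_embeddings)
        scores.append(float(score))
    # Keep the candidate whose audio embedding best matches the prompt.
    return wav_list[int(np.argmax(scores))]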
@@ -102,14 +124,14 @@ class T2I:
         self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
         self.text_refine_tokenizer = AutoTokenizer.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
         self.text_refine_model = AutoModelForCausalLM.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
-        self.text_refine_gpt2_pipe = pipeline("text-generation", model=self.text_refine_model,
         self.pipe.to(device)

     @prompts(name="Generate Image From User Input Text",
              description="useful when you want to generate an image from a user input text and save it to a file. "
                          "like: generate an image of an object or something, or generate an image that includes some objects. "
                          "The input to this tool should be a string, representing the text used to generate image. ")
-
     def inference(self, text):
         image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
         refined_text = self.text_refine_gpt2_pipe(text)[0]["generated_text"]
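Note: T2I first expands the user prompt with the MagicPrompt GPT-2 pipeline and only then renders it with Stable Diffusion. A hedged sketch of that two-step flow, assuming a CUDA device with enough memory; the max_length value is illustrative and the truncated pipeline arguments in the hunk are not reproduced.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from diffusers import StableDiffusionPipeline

device = "cuda"  # assumed; fp16 weights are awkward on CPU

# Prompt refiner: a GPT-2 model fine-tuned to write Stable Diffusion prompts.
refine_tokenizer = AutoTokenizer.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
refine_model = AutoModelForCausalLM.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
refiner = pipeline("text-generation", model=refine_model, tokenizer=refine_tokenizer)

# Image generator.
sd = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5",
                                             torch_dtype=torch.float16).to(device)

text = "a cat playing a violin"
refined_text = refiner(text, max_length=64)[0]["generated_text"]
image = sd(refined_text).images[0]
image.save("example.png")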
@@ -119,58 +141,60 @@ class T2I:
         print(f"Processed T2I.run, text: {text}, image_filename: {image_filename}")
         return image_filename

 class ImageCaptioning:
     def __init__(self, device):
         print("Initializing ImageCaptioning to %s" % device)
         self.device = device
         self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(
-

     @prompts(name="Remove Something From The Photo",
              description="useful when you want to remove and object or something from the photo "
                          "from its description or location. "
                          "The input to this tool should be a comma separated string of two, "
                          "representing the image_path and the object need to be removed. ")
-
     def inference(self, image_path):
         inputs = self.processor(Image.open(image_path), return_tensors="pt").to(self.device)
         out = self.model.generate(**inputs)
         captions = self.processor.decode(out[0], skip_special_tokens=True)
         return captions

 class T2A:
     def __init__(self, device):
         print("Initializing Make-An-Audio to %s" % device)
         self.device = device
-        self.sampler = initialize_model('text_to_audio/Make_An_Audio/configs/text-to-audio/txt2audio_args.yaml',
-
-

-    def txt2audio(self, text, seed
         SAMPLE_RATE = 16000
         prng = np.random.RandomState(seed)
         start_code = prng.randn(n_samples, self.sampler.model.first_stage_model.embed_dim, H // 8, W // 8)
         start_code = torch.from_numpy(start_code).to(device=self.device, dtype=torch.float32)
         uc = self.sampler.model.get_learned_conditioning(n_samples * [""])
         c = self.sampler.model.get_learned_conditioning(n_samples * [text])
-        shape = [self.sampler.model.first_stage_model.embed_dim, H//8, W//8]  # (z_dim, 80//2^x, 848//2^x)
-        samples_ddim, _ = self.sampler.sample(S
-
-
-
-
-
-
-

         x_samples_ddim = self.sampler.model.decode_first_stage(samples_ddim)
-        x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)

         wav_list = []
-        for idx,spec in enumerate(x_samples_ddim):
             wav = self.vocoder.vocode(spec)
-            wav_list.append((SAMPLE_RATE,wav))
         best_wav = select_best_audio(text, wav_list)
         return best_wav

@@ -179,56 +203,57 @@ class T2A:
                          "from a user input text and it saved it to a file."
                          "The input to this tool should be a string, "
                          "representing the text used to generate audio.")
-
-
-        melbins,mel_len = 80,624
         with torch.no_grad():
             result = self.txt2audio(
-                text
-                H
-                W
             )
         audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
-        soundfile.write(audio_filename, result[1], samplerate
         print(f"Processed T2I.run, text: {text}, audio_filename: {audio_filename}")
         return audio_filename

 class I2A:
     def __init__(self, device):
         print("Initializing Make-An-Audio-Image to %s" % device)
         self.device = device
-        self.sampler = initialize_model('text_to_audio/Make_An_Audio/configs/img_to_audio/img2audio_args.yaml',
-

-
-    def img2audio(self, image, seed = 55, scale = 3, ddim_steps = 100, W = 624, H = 80):
         SAMPLE_RATE = 16000
-        n_samples = 1
         prng = np.random.RandomState(seed)
         start_code = prng.randn(n_samples, self.sampler.model.first_stage_model.embed_dim, H // 8, W // 8)
         start_code = torch.from_numpy(start_code).to(device=self.device, dtype=torch.float32)
         uc = self.sampler.model.get_learned_conditioning(n_samples * [""])
-        #image = Image.fromarray(image)
         image = Image.open(image)
         image = self.sampler.model.cond_stage_model.preprocess(image).unsqueeze(0)
         image_embedding = self.sampler.model.cond_stage_model.forward_img(image)
         c = image_embedding.repeat(n_samples, 1, 1)
-        shape = [self.sampler.model.first_stage_model.embed_dim, H//8, W//8]  # (z_dim, 80//2^x, 848//2^x)
         samples_ddim, _ = self.sampler.sample(S=ddim_steps,
-
-
-
-
-
-
-

         x_samples_ddim = self.sampler.model.decode_first_stage(samples_ddim)
-        x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
         wav_list = []
-        for idx,spec in enumerate(x_samples_ddim):
             wav = self.vocoder.vocode(spec)
-            wav_list.append((SAMPLE_RATE,wav))
         best_wav = wav_list[0]
         return best_wav

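Note: T2A and I2A end the same way, each decoded mel spectrogram goes through the vocoder and one (sample_rate, waveform) pair is written to disk at 16 kHz. A runnable sketch of that tail end with a stand-in vocoder; the real one comes from Make-An-Audio, and the selection step can be select_best_audio instead of taking the first candidate.

import os
import uuid

import numpy as np
import soundfile

SAMPLE_RATE = 16000


def mels_to_wavfile(mel_specs, vocoder, out_dir="audio"):
    # mel_specs: iterable of mel spectrograms decoded from the latent diffusion model.
    wav_list = []
    for spec in mel_specs:
        wav = vocoder.vocode(spec)           # mel -> waveform (numpy array)
        wav_list.append((SAMPLE_RATE, wav))
    sr, wav = wav_list[0]                    # or rank candidates with select_best_audio
    os.makedirs(out_dir, exist_ok=True)
    audio_filename = os.path.join(out_dir, str(uuid.uuid4())[0:8] + ".wav")
    soundfile.write(audio_filename, wav, samplerate=sr)
    return audio_filename


class DummyVocoder:
    # Stand-in so the sketch runs without the Make-An-Audio checkpoint.
    def vocode(self, spec):
        return np.random.uniform(-0.1, 0.1, SAMPLE_RATE).astype(np.float32)


print(mels_to_wavfile([np.zeros((80, 624))], DummyVocoder()))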
@@ -237,44 +262,44 @@ class I2A:
                          "based on an image. "
                          "The input to this tool should be a string, "
                          "representing the image_path. ")
-
-
-        melbins,mel_len = 80,624
         with torch.no_grad():
             result = self.img2audio(
                 image=image,
-                H=melbins,
                 W=mel_len
             )
         audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
-        soundfile.write(audio_filename, result[1], samplerate
         print(f"Processed I2a.run, image_filename: {image}, audio_filename: {audio_filename}")
         return audio_filename

 class TTS:
     def __init__(self, device=None):
         self.model = TTSInference(device)
-
     @prompts(name="Synthesize Speech Given the User Input Text",
              description="useful for when you want to convert a user input text into speech audio it saved it to a file."
                          "The input to this tool should be a string, "
                          "representing the text used to be converted to speech.")
-
     def inference(self, text):
         inp = {"text": text}
         out = self.model.infer_once(inp)
         audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
-        soundfile.write(audio_filename, out, samplerate
         return audio_filename

 class T2S:
-    def __init__(self, device=
         if device is None:
             device = 'cuda' if torch.cuda.is_available() else 'cpu'
         print("Initializing DiffSinger to %s" % device)
         self.device = device
         self.exp_name = 'checkpoints/0831_opencpop_ds1000'
-        self.config= 'NeuralSeq/egs/egs_bases/svs/midi/e2e/opencpop/ds1000.yaml'
         self.set_model_hparams()
         self.pipe = DiffSingerE2EInfer(self.hp, device)
         self.default_inp = {
@@ -283,7 +308,6 @@ class T2S:
             'notes_duration': '0.113740 | 0.329060 | 0.287950 | 0.133480 | 0.150900 | 0.484730 | 0.242010 | 0.180820 | 0.343570 | 0.152050 | 0.266720 | 0.280310 | 0.633300 | 0.444590'
         }

-
     def set_model_hparams(self):
         set_hparams(config=self.config, exp_name=self.exp_name, print_hparams=False)
         self.hp = hp
@@ -296,7 +320,6 @@ class T2S:
                          "Or Like: Generate a piece of singing voice. Text is xxx, note is xxx, duration is xxx."
                          "The input to this tool should be a comma seperated string of three, "
                          "representing text, note and duration sequence since User Input Text, Note and Duration Sequence are all provided. ")
-
     def inference(self, inputs):
         self.set_model_hparams()
         val = inputs.split(",")
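Note: T2S.inference receives one comma separated string and splits it back into the text, note and duration fields DiffSinger expects, falling back to the bundled default when the request is malformed. A hedged sketch of that parsing step; the key names are assumed to match the default_inp dict, and the exact fallback rule in the new revision is not visible in this diff.

def parse_t2s_input(inputs, default_inp):
    # inputs: "text, note sequence, duration sequence" as a single comma separated string.
    key = ['text', 'notes', 'notes_duration']
    val = inputs.split(",")
    if len(val) < len(key):
        # Malformed request: fall back to the bundled default example.
        return dict(default_inp)
    return {k: v.strip() for k, v in zip(key, val)}


# Illustrative default; the real one is the Opencpop example stored in self.default_inp.
default_inp = {'text': 'placeholder', 'notes': 'D4 | rest', 'notes_duration': '0.5 | 0.5'}
print(parse_t2s_input("AP hello AP, D4 | rest, 0.3 | 0.2", default_inp))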
@@ -314,6 +337,7 @@ class T2S:
         print(f"Processed T2S.run, audio_filename: {audio_filename}")
         return audio_filename

 class TTS_OOD:
     def __init__(self, device):
         if device is None:
@@ -340,8 +364,7 @@ class TTS_OOD:
                          "(e.g., timbre, emotion, and prosody) derived from a reference custom voice. "
                          "Like: Generate a speech with style transferred from this voice. The text is xxx., or speak using the voice of this audio. The text is xxx."
                          "The input to this tool should be a comma seperated string of two, "
-                         "representing reference audio path and input text. "
-
     def inference(self, inputs):
         self.set_model_hparams()
         key = ['ref_audio', 'text']
@@ -354,147 +377,154 @@ class TTS_OOD:
         print(
             f"Processed GenerSpeech.run. Input text:{val[1]}. Input reference audio: {val[0]}. Output Audio_filename: {audio_filename}")
         return audio_filename
-
 class Inpaint:
     def __init__(self, device):
         print("Initializing Make-An-Audio-inpaint to %s" % device)
         self.device = device
-        self.sampler = initialize_model_inpaint('text_to_audio/Make_An_Audio/configs/inpaint/txt2audio_args.yaml',
-
         self.cmap_transform = matplotlib.cm.viridis

     def make_batch_sd(self, mel, mask, num_samples=1):

-        mel = torch.from_numpy(mel)[None,None
-        mask = torch.from_numpy(mask)[None,None
         masked_mel = (1 - mask) * mel

         mel = mel * 2 - 1
         mask = mask * 2 - 1
-        masked_mel = masked_mel * 2 -1

         batch = {
-
-
-
         }
         return batch
     def gen_mel(self, input_audio_path):
         SAMPLE_RATE = 16000
         sr, ori_wav = wavfile.read(input_audio_path)
         print("gen_mel")
-        print(sr,ori_wav.shape,ori_wav)
         ori_wav = ori_wav.astype(np.float32, order='C') / 32768.0
-        if len(ori_wav.shape)==2
-            ori_wav = librosa.to_mono(
-
-

-        mel_len,hop_size = 848,256
         input_len = mel_len * hop_size
         if len(ori_wav) < input_len:
-            input_wav = np.pad(ori_wav,(0,mel_len*hop_size),constant_values=0)
         else:
             input_wav = ori_wav[:input_len]
-
         mel = TRANSFORMS_16000(input_wav)
         return mel
     def gen_mel_audio(self, input_audio):
         SAMPLE_RATE = 16000
-        sr,ori_wav = input_audio
         print("gen_mel_audio")
-        print(sr,ori_wav.shape,ori_wav)

         ori_wav = ori_wav.astype(np.float32, order='C') / 32768.0
-        if len(ori_wav.shape)==2
-            ori_wav = librosa.to_mono(
-
-

-        mel_len,hop_size = 848,256
         input_len = mel_len * hop_size
         if len(ori_wav) < input_len:
-            input_wav = np.pad(ori_wav,(0,mel_len*hop_size),constant_values=0)
         else:
             input_wav = ori_wav[:input_len]
         mel = TRANSFORMS_16000(input_wav)
         return mel
     def inpaint(self, batch, seed, ddim_steps, num_samples=1, W=512, H=512):
         model = self.sampler.model
-
         prng = np.random.RandomState(seed)
         start_code = prng.randn(num_samples, model.first_stage_model.embed_dim, H // 8, W // 8)
         start_code = torch.from_numpy(start_code).to(device=self.device, dtype=torch.float32)

         c = model.get_first_stage_encoding(model.encode_first_stage(batch["masked_mel"]))
         cc = torch.nn.functional.interpolate(batch["mask"],
-
-        c = torch.cat((c, cc), dim=1)

-        shape = (c.shape[1]-1,)+c.shape[2:]
         samples_ddim, _ = self.sampler.sample(S=ddim_steps,
-
-
-
-
         x_samples_ddim = model.decode_first_stage(samples_ddim)

-
-
-
-
-
-        inpainted = (1-mask)*mel+mask*predicted_mel
         inpainted = inpainted.cpu().numpy().squeeze()
         inapint_wav = self.vocoder.vocode(inpainted)

         return inpainted, inapint_wav
-
         SAMPLE_RATE = 16000
         torch.set_grad_enabled(False)
         mel_img = Image.open(mel_and_mask['image'])
         mask_img = Image.open(mel_and_mask["mask"])
-        show_mel = np.array(mel_img.convert("L"))/255
-        mask = np.array(mask_img.convert("L"))/255
-        mel_bins,mel_len = 80,848
-        input_mel = self.gen_mel_audio(input_audio)[
-        mask = np.pad(mask,((0,0),(0,mel_len-mask.shape[1])),mode='constant',constant_values=0)
-        print(mask.shape,input_mel.shape)
         with torch.no_grad():
-            batch = self.make_batch_sd(input_mel,mask,num_samples=1)
-            inpainted,gen_wav = self.inpaint(
                 batch=batch,
                 seed=seed,
                 ddim_steps=ddim_steps,
                 num_samples=1,
                 H=mel_bins, W=mel_len
             )
-        inpainted = inpainted[
         color_mel = self.cmap_transform(inpainted)
         input_len = int(input_audio[1].shape[0] * SAMPLE_RATE / input_audio[0])
         gen_wav = (gen_wav * 32768).astype(np.int16)[:input_len]
-        image = Image.fromarray((color_mel*255).astype(np.uint8))
         image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
         image.save(image_filename)
         audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
-        soundfile.write(audio_filename, gen_wav, samplerate
         return image_filename, audio_filename

     @prompts(name="Audio Inpainting",
              description="useful for when you want to inpaint a mel spectrum of an audio and predict this audio, "
                          "this tool will generate a mel spectrum and you can inpaint it, receives audio_path as input. "
                          "The input to this tool should be a string, "
-                         "representing the audio_path. "
-
     def inference(self, input_audio_path):
         crop_len = 500
-        crop_mel = self.gen_mel(input_audio_path)[
         color_mel = self.cmap_transform(crop_mel)
-        image = Image.fromarray((color_mel*255).astype(np.uint8))
         image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
         image.save(image_filename)
         return image_filename
-
 class ASR:
     def __init__(self, device):
         print("Initializing Whisper to %s" % device)
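Note: make_batch_sd rescales the mel and the mask from [0, 1] to [-1, 1] and keeps a masked copy of the mel for conditioning. A self-contained sketch with random arrays standing in for a real spectrogram; the "mel" key and the repeat call are assumptions, since those batch lines are cut off in the hunk.

import numpy as np
import torch


def make_batch_sd(mel, mask, num_samples=1):
    # mel, mask: 2-D numpy arrays in [0, 1]; add batch and channel dimensions.
    mel = torch.from_numpy(mel)[None, None]
    mask = torch.from_numpy(mask)[None, None]
    masked_mel = (1 - mask) * mel            # zero out the region to be inpainted

    # Shift everything to the [-1, 1] range the diffusion model expects.
    mel = mel * 2 - 1
    mask = mask * 2 - 1
    masked_mel = masked_mel * 2 - 1

    return {
        "mel": mel.repeat(num_samples, 1, 1, 1),
        "mask": mask.repeat(num_samples, 1, 1, 1),
        "masked_mel": masked_mel.repeat(num_samples, 1, 1, 1),
    }


batch = make_batch_sd(np.random.rand(80, 848).astype(np.float32),
                      np.zeros((80, 848), dtype=np.float32))
print({k: tuple(v.shape) for k, v in batch.items()})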
@@ -505,8 +535,7 @@ class ASR:
              description="useful for when you want to know the text corresponding to a human speech, "
                          "receives audio_path as input. "
                          "The input to this tool should be a string, "
-                         "representing the audio_path. "
-
     def inference(self, audio_path):
         audio = whisper.load_audio(audio_path)
         audio = whisper.pad_or_trim(audio)
@@ -516,6 +545,11 @@ class ASR:
         result = whisper.decode(self.model, mel, options)
         return result.text

 class A2T:
     def __init__(self, device):
         print("Initializing Audio-To-Text Model to %s" % device)
@@ -526,13 +560,13 @@ class A2T:
              description="useful for when you want to describe an audio in text, "
                          "receives audio_path as input. "
                          "The input to this tool should be a string, "
-                         "representing the audio_path. "
-
     def inference(self, audio_path):
         audio = whisper.load_audio(audio_path)
         caption_text = self.model(audio)
         return caption_text[0]

 class SoundDetection:
     def __init__(self, device):
         self.device = device
@@ -548,9 +582,9 @@ class SoundDetection:
         self.labels = detection_config.labels
         self.frames_per_second = self.sample_rate // self.hop_size
         # Model = eval(self.model_type)
-        self.model = PVT(sample_rate=self.sample_rate, window_size=self.window_size,
-
-
         checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
         self.model.load_state_dict(checkpoint['model'])
         self.model.to(device)
@@ -559,12 +593,11 @@ class SoundDetection:
              description="useful for when you want to know what event in the audio and the sound event start or end time, it will return an image "
                          "receives audio_path as input. "
                          "The input to this tool should be a string, "
-                         "representing the audio_path. "
-
     def inference(self, audio_path):
         # Forward
         (waveform, _) = librosa.core.load(audio_path, sr=self.sample_rate, mono=True)
-        waveform = waveform[None, :]
         waveform = torch.from_numpy(waveform)
         waveform = waveform.to(self.device)
         # Forward
@@ -579,11 +612,11 @@ class SoundDetection:
         import matplotlib.pyplot as plt
         sorted_indexes = np.argsort(np.max(framewise_output, axis=0))[::-1]
         top_k = 10  # Show top results
-        top_result_mat = framewise_output[:, sorted_indexes[0
         """(time_steps, top_k)"""
-        # Plot result
-        stft = librosa.core.stft(y=waveform[0].data.cpu().numpy(), n_fft=self.window_size,
-
         frames_num = stft.shape[-1]
         fig, axs = plt.subplots(2, 1, sharex=True, figsize=(10, 4))
         axs[0].matshow(np.log(np.abs(stft)), origin='lower', aspect='auto', cmap='jet')
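Note: the detection plot keeps only the ten classes with the highest peak framewise probability. A small sketch of that selection with a random prediction matrix; the class count and label names are illustrative.

import numpy as np

framewise_output = np.random.rand(1000, 527)        # (time_steps, classes)
labels = [f"class_{i}" for i in range(527)]          # illustrative label list

# Rank classes by their maximum probability over time, highest first.
sorted_indexes = np.argsort(np.max(framewise_output, axis=0))[::-1]

top_k = 10
top_result_mat = framewise_output[:, sorted_indexes[0:top_k]]   # (time_steps, top_k)
top_labels = np.array(labels)[sorted_indexes[0:top_k]]

print(top_result_mat.shape)
print(top_labels)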
@@ -593,7 +626,7 @@ class SoundDetection:
         axs[1].xaxis.set_ticks(np.arange(0, frames_num, self.frames_per_second))
         axs[1].xaxis.set_ticklabels(np.arange(0, frames_num / self.frames_per_second))
         axs[1].yaxis.set_ticks(np.arange(0, top_k))
-        axs[1].yaxis.set_ticklabels(np.array(self.labels)[sorted_indexes[0
         axs[1].yaxis.grid(color='k', linestyle='solid', linewidth=0.3, alpha=0.3)
         axs[1].set_xlabel('Seconds')
         axs[1].xaxis.set_ticks_position('bottom')
@@ -602,6 +635,7 @@ class SoundDetection:
         plt.savefig(image_filename)
         return image_filename

 class SoundExtraction:
     def __init__(self, device):
         self.device = device
@@ -617,25 +651,24 @@ class SoundExtraction:
              description="useful for when you extract target sound from a mixture audio, you can describe the target sound by text, "
                          "receives audio_path and text as input. "
                          "The input to this tool should be a comma seperated string of two, "
-                         "representing mixture audio path and input text."
-
     def inference(self, inputs):
-        #key = ['ref_audio', 'text']
         val = inputs.split(",")
-        audio_path = val[0]
         text = val[1]
         waveform = load_wav(audio_path)
-        waveform = torch.tensor(waveform).transpose(1,0)
         mixed_mag, mixed_phase = self.stft.transform(waveform)
         text_query = ['[CLS] ' + text]
-        mixed_mag = mixed_mag.transpose(2,1).unsqueeze(0).to(self.device)
         est_mask = self.model(mixed_mag, text_query)
-        est_mag = est_mask * mixed_mag
-        est_mag = est_mag.squeeze(1)
-        est_mag = est_mag.permute(0, 2, 1)
         est_wav = self.stft.inverse(est_mag.cpu().detach(), mixed_phase)
-        est_wav = est_wav.squeeze(0).squeeze(0).numpy()
-        #est_path = f'output/est{i}.wav'
         audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
         print('audio_filename ', audio_filename)
         save_wav(est_wav, audio_filename)
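Note: SoundExtraction predicts a magnitude mask from the mixture spectrogram and the text query, multiplies it into the mixture magnitude, and inverts the STFT with the mixture phase. A toy sketch of the masking step only; the tensors, shapes and the sigmoid mask below are stand-ins, not the repository's STFT wrapper or model.

import torch

# Toy shapes: (batch, frames, freq_bins) for the magnitude after the transpose above.
mixed_mag = torch.rand(1, 200, 513)
mixed_phase = torch.rand(1, 513, 200)

# The separation model outputs a mask with the same shape as its magnitude input.
est_mask = torch.sigmoid(torch.randn_like(mixed_mag))

# Element-wise masking keeps the target source's energy and suppresses the rest.
est_mag = est_mask * mixed_mag            # (1, frames, freq_bins)
est_mag = est_mag.permute(0, 2, 1)        # back to (1, freq_bins, frames) for the inverse STFT

# In the real code: est_wav = self.stft.inverse(est_mag.cpu().detach(), mixed_phase)
print(est_mag.shape, mixed_phase.shape)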
@@ -652,9 +685,9 @@ class Binaural:
                              'mono2binaural/useful_ckpts/m2b/tx_positions4.txt',
                              'mono2binaural/useful_ckpts/m2b/tx_positions5.txt']
         self.net = BinauralNetwork(view_dim=7,
-
-
-
         self.net.load_from_file(self.model_file)
         self.sr = 48000

@@ -662,33 +695,32 @@ class Binaural:
              description="useful for when you want to transfer your mono audio into binaural audio, "
                          "receives audio_path as input. "
                          "The input to this tool should be a string, "
-                         "representing the audio_path. "
-
     def inference(self, audio_path):
-        mono, sr
         mono = torch.from_numpy(mono)
         mono = mono.unsqueeze(0)
         import numpy as np
         import random
-        rand_int = random.randint(0,4)
         view = np.loadtxt(self.position_file[rand_int]).transpose().astype(np.float32)
         view = torch.from_numpy(view)
         if not view.shape[-1] * 400 == mono.shape[-1]:
-            mono = mono[
-            if view.shape[1]*400 > mono.shape[1]:
-                m_a = view.shape[1] - mono.shape[-1]//400
-                rand_st = random.randint(0,m_a)
-                view = view[:,m_a:m_a+(mono.shape[-1]//400)]
         # binauralize and save output
         self.net.eval().to(self.device)
         mono, view = mono.to(self.device), view.to(self.device)
         chunk_size = 48000  # forward in chunks of 1s
-        rec_field =
         rec_field -= rec_field % 400  # make sure rec_field is a multiple of 400 to match audio and view frequencies
         chunks = [
             {
-                "mono": mono[:, max(0, i-rec_field):i+chunk_size],
-                "view": view[:, max(0, i-rec_field)//400:(i+chunk_size)//400]
             }
             for i in range(0, mono.shape[-1], chunk_size)
         ]
@@ -698,18 +730,19 @@ class Binaural:
                 view = chunk["view"].unsqueeze(0)
                 binaural = self.net(mono, view).squeeze(0)
                 if i > 0:
-                    binaural = binaural[:, -(mono.shape[-1]-rec_field):]
                 chunk["binaural"] = binaural
         binaural = torch.cat([chunk["binaural"] for chunk in chunks], dim=-1)
         binaural = torch.clamp(binaural, min=-1, max=1).cpu()
-        #binaural = chunked_forwarding(net, mono, view)
         audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
         import torchaudio
         torchaudio.save(audio_filename, binaural, sr)
-        #soundfile.write(audio_filename, binaural, samplerate = 48000)
         print(f"Processed Binaural.run, audio_filename: {audio_filename}")
         return audio_filename

 class TargetSoundDetection:
     def __init__(self, device):
         self.device = device
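Note: the binaural network runs over 1 s chunks with extra left context equal to its receptive field, and that warm-up context is trimmed from every chunk after the first before concatenation. A self-contained sketch with a dummy network; the receptive-field value is illustrative, while the chunking and the 400:1 audio-to-view ratio follow the code above.

import torch

chunk_size = 48000                 # forward in chunks of 1 s at 48 kHz
rec_field = 1600                   # illustrative receptive field, already a multiple of 400
mono = torch.randn(1, 5 * chunk_size)
view = torch.randn(7, mono.shape[-1] // 400)   # one 7-dim view vector per 400 audio samples


def net(mono_chunk, view_chunk):
    # Dummy "binauralization": duplicate the mono chunk into two channels.
    return mono_chunk.repeat(1, 2, 1)


chunks = [
    {
        "mono": mono[:, max(0, i - rec_field):i + chunk_size],
        "view": view[:, max(0, i - rec_field) // 400:(i + chunk_size) // 400],
    }
    for i in range(0, mono.shape[-1], chunk_size)
]

outputs = []
for i, chunk in zip(range(0, mono.shape[-1], chunk_size), chunks):
    binaural = net(chunk["mono"].unsqueeze(0), chunk["view"].unsqueeze(0)).squeeze(0)
    if i > 0:
        # Drop the warm-up context so the chunks butt together without overlap.
        binaural = binaural[:, -chunk_size:]
    outputs.append(binaural)

binaural = torch.clamp(torch.cat(outputs, dim=-1), min=-1, max=1)
print(binaural.shape)   # torch.Size([2, 240000])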
@@ -722,18 +755,23 @@ class TargetSoundDetection:
         self.EPS = np.spacing(1)
         self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
         self.event_labels = event_labels
-        self.id_to_event =
-        config = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/run_config.pth',
         config_parameters = dict(config)
         config_parameters['tao'] = 0.6
         if 'thres' not in config_parameters.keys():
             config_parameters['thres'] = 0.5
         if 'time_resolution' not in config_parameters.keys():
             config_parameters['time_resolution'] = 125
-        model_parameters = torch.load(
-
         self.model = getattr(tsd_models, config_parameters['model'])(config_parameters,
-
         self.model.load_state_dict(model_parameters)
         self.model = self.model.to(self.device).eval()
         self.re_embeds = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/text_emb.pth')
@@ -743,18 +781,18 @@ class TargetSoundDetection:
         import soundfile as sf
         y, sr = sf.read(fname, dtype='float32')
         print('y ', y.shape)
-        ti = y.shape[0]/sr
         if y.ndim > 1:
             y = y.mean(1)
         y = librosa.resample(y, sr, 22050)
         lms_feature = np.log(librosa.feature.melspectrogram(y, **self.MEL_ARGS) + self.EPS).T
-        return lms_feature,ti
-
     def build_clip(self, text):
-        text = clip.tokenize(text).to(self.device)
         text_features = self.clip_model.encode_text(text)
         return text_features
-
     def cal_similarity(self, target, retrievals):
         ans = []
         for name in retrievals.keys():
@@ -767,41 +805,229 @@ class TargetSoundDetection:
              description="useful for when you want to know when the target sound event in the audio happens. You can use language descriptions to instruct the model, "
                          "receives text description and audio_path as input. "
                          "The input to this tool should be a comma seperated string of two, "
-                         "representing audio path and the text description. "
-
     def inference(self, inputs):
         audio_path, text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
-        target_emb = self.build_clip(text)
         idx = self.cal_similarity(target_emb, self.re_embeds)
         target_event = self.id_to_event[idx]
         embedding = self.ref_mel[target_event]
         embedding = torch.from_numpy(embedding)
         embedding = embedding.unsqueeze(0).to(self.device).float()
-        inputs,ti = self.extract_feature(audio_path)
         inputs = torch.from_numpy(inputs)
         inputs = inputs.unsqueeze(0).to(self.device).float()
         decision, decision_up, logit = self.model(inputs, embedding)
         pred = decision_up.detach().cpu().numpy()
-        pred = pred[
         frame_num = decision_up.shape[1]
         time_ratio = ti / frame_num
         filtered_pred = median_filter(pred, window_size=1, threshold=0.5)
         time_predictions = []
         for index_k in range(filtered_pred.shape[0]):
             decoded_pred = []
-            decoded_pred_ = decode_with_timestamps(target_event, filtered_pred[index_k
-            if len(decoded_pred_) == 0:
                 decoded_pred_.append((target_event, 0, 0))
             decoded_pred.append(decoded_pred_)
-            for num_batch in range(len(decoded_pred)):
                 cur_pred = pred[num_batch]
                 # Save each frame output, for later visualization
-                label_prediction = decoded_pred[num_batch]
                 for event_label, onset, offset in label_prediction:
                     time_predictions.append({
-                        'onset': onset*time_ratio,
-                        'offset': offset*time_ratio,})
         ans = ''
-        for i,item in enumerate(time_predictions):
-            ans = ans + 'segment' + str(i+1) + ' start_time: ' + str(item['onset']) + '  end_time: ' + str(
-
@@ ... @@
import sys
import os
+
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'NeuralSeq'))
@@ ... @@
from target_sound_detection.src.utils import median_filter, decode_with_timestamps
import clip

+
def prompts(name, description):
    def decorator(func):
        func.name = name
@@ ... @@

    return decorator

+
def initialize_model(config, ckpt, device):
    config = OmegaConf.load(config)
    model = instantiate_from_config(config.model)
+    model.load_state_dict(torch.load(ckpt, map_location='cpu')["state_dict"], strict=False)

    model = model.to(device)
    model.cond_stage_model.to(model.device)
@@ ... @@
    sampler = DDIMSampler(model)
    return sampler

+
def initialize_model_inpaint(config, ckpt):
    config = OmegaConf.load(config)
    model = instantiate_from_config(config.model)
+    model.load_state_dict(torch.load(ckpt, map_location='cpu')["state_dict"], strict=False)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model = model.to(device)
+    print(model.device, device, model.cond_stage_model.device)
    sampler = DDIMSampler(model)
    return sampler
+
+
+def select_best_audio(prompt, wav_list):
+    clap_model = CLAPWrapper('text_to_audio/Make_An_Audio/useful_ckpts/CLAP/CLAP_weights_2022.pth',
+                             'text_to_audio/Make_An_Audio/useful_ckpts/CLAP/config.yml',
+                             use_cuda=torch.cuda.is_available())
    text_embeddings = clap_model.get_text_embeddings([prompt])
    score_list = []
    for data in wav_list:
+        sr, wav = data
+        audio_embeddings = clap_model.get_audio_embeddings([(torch.FloatTensor(wav), sr)], resample=True)
+        score = clap_model.compute_similarity(audio_embeddings, text_embeddings,
+                                              use_logit_scale=False).squeeze().cpu().numpy()
        score_list.append(score)
    max_index = np.array(score_list).argmax()
+    print(score_list, max_index)
    return wav_list[max_index]


+def merge_audio(audio_path_1, audio_path_2):
+    merged_signal = []
+    sr_1, signal_1 = wavfile.read(audio_path_1)
+    sr_2, signal_2 = wavfile.read(audio_path_2)
+    merged_signal.append(signal_1)
+    merged_signal.append(signal_2)
+    merged_signal = np.hstack(merged_signal)
+    merged_signal = np.asarray(merged_signal, dtype=np.int16)
+    audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
+    wavfile.write(audio_filename, sr_1, merged_signal)
+    return audio_filename
+
+
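A minimal usage sketch for the two helpers above. Running select_best_audio requires the CLAP checkpoints referenced in its body; the prompt, the 16 kHz sample rate, and the clip paths passed to merge_audio are placeholders for illustration, not values taken from this diff.

import numpy as np

# select_best_audio expects (sample_rate, waveform) candidates and returns the one whose
# CLAP audio embedding is most similar to the text prompt.
sr = 16000
candidates = [(sr, np.random.randn(2 * sr).astype(np.float32)) for _ in range(3)]
best_sr, best_wav = select_best_audio("a dog barking", candidates)

# merge_audio concatenates two existing wav files and writes the result under 'audio/'.
merged_path = merge_audio("audio/clip_a.wav", "audio/clip_b.wav")  # hypothetical paths
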
class T2I:
    def __init__(self, device):
        print("Initializing T2I to %s" % device)
@@ ... @@
        self.pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
        self.text_refine_tokenizer = AutoTokenizer.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
        self.text_refine_model = AutoModelForCausalLM.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
+        self.text_refine_gpt2_pipe = pipeline("text-generation", model=self.text_refine_model,
+                                              tokenizer=self.text_refine_tokenizer, device=self.device)
        self.pipe.to(device)

    @prompts(name="Generate Image From User Input Text",
             description="useful when you want to generate an image from a user input text and save it to a file. "
                         "like: generate an image of an object or something, or generate an image that includes some objects. "
                         "The input to this tool should be a string, representing the text used to generate the image. ")
    def inference(self, text):
        image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
        refined_text = self.text_refine_gpt2_pipe(text)[0]["generated_text"]
@@ ... @@
        print(f"Processed T2I.run, text: {text}, image_filename: {image_filename}")
        return image_filename

+
class ImageCaptioning:
    def __init__(self, device):
        print("Initializing ImageCaptioning to %s" % device)
        self.device = device
        self.processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+        self.model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(
+            self.device)

    @prompts(name="Get Photo Description",
             description="useful when you want to know what is inside the photo, "
                         "receives image_path as input. "
                         "The input to this tool should be a string, "
                         "representing the image_path. ")
    def inference(self, image_path):
        inputs = self.processor(Image.open(image_path), return_tensors="pt").to(self.device)
        out = self.model.generate(**inputs)
        captions = self.processor.decode(out[0], skip_special_tokens=True)
        return captions

+
class T2A:
    def __init__(self, device):
        print("Initializing Make-An-Audio to %s" % device)
        self.device = device
+        self.sampler = initialize_model('text_to_audio/Make_An_Audio/configs/text-to-audio/txt2audio_args.yaml',
+                                        'text_to_audio/Make_An_Audio/useful_ckpts/ta40multi_epoch=000085.ckpt',
+                                        device=device)
+        self.vocoder = VocoderBigVGAN('text_to_audio/Make_An_Audio/vocoder/logs/bigv16k53w', device=device)

+    def txt2audio(self, text, seed=55, scale=1.5, ddim_steps=100, n_samples=3, W=624, H=80):
        SAMPLE_RATE = 16000
        prng = np.random.RandomState(seed)
        start_code = prng.randn(n_samples, self.sampler.model.first_stage_model.embed_dim, H // 8, W // 8)
        start_code = torch.from_numpy(start_code).to(device=self.device, dtype=torch.float32)
        uc = self.sampler.model.get_learned_conditioning(n_samples * [""])
        c = self.sampler.model.get_learned_conditioning(n_samples * [text])
+        shape = [self.sampler.model.first_stage_model.embed_dim, H // 8, W // 8]  # (z_dim, 80//2^x, 848//2^x)
+        samples_ddim, _ = self.sampler.sample(S=ddim_steps,
+                                              conditioning=c,
+                                              batch_size=n_samples,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=uc,
+                                              x_T=start_code)

        x_samples_ddim = self.sampler.model.decode_first_stage(samples_ddim)
+        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)  # [0, 1]

        wav_list = []
+        for idx, spec in enumerate(x_samples_ddim):
            wav = self.vocoder.vocode(spec)
+            wav_list.append((SAMPLE_RATE, wav))
        best_wav = select_best_audio(text, wav_list)
        return best_wav

@@ ... @@
                         "from a user input text and save it to a file. "
                         "The input to this tool should be a string, "
                         "representing the text used to generate audio.")
+    def inference(self, text, seed=55, scale=1.5, ddim_steps=100, n_samples=3, W=624, H=80):
+        melbins, mel_len = 80, 624
        with torch.no_grad():
            result = self.txt2audio(
+                text=text,
+                H=melbins,
+                W=mel_len
            )
        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
+        soundfile.write(audio_filename, result[1], samplerate=16000)
        print(f"Processed T2A.run, text: {text}, audio_filename: {audio_filename}")
        return audio_filename

+
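A rough usage sketch of how these tool classes are meant to be called. Each __init__ loads the checkpoints referenced above, so this only runs in a fully prepared workspace; the prompt string is hypothetical.

import torch

t2a = T2A(device='cuda' if torch.cuda.is_available() else 'cpu')
audio_path = t2a.inference("a dog barking in the rain")  # writes a wav under 'audio/' and returns its path
print(t2a.inference.name, t2a.inference.description)     # metadata attached by the @prompts decorator
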
class I2A:
    def __init__(self, device):
        print("Initializing Make-An-Audio-Image to %s" % device)
        self.device = device
+        self.sampler = initialize_model('text_to_audio/Make_An_Audio/configs/img_to_audio/img2audio_args.yaml',
+                                        'text_to_audio/Make_An_Audio/useful_ckpts/ta54_epoch=000216.ckpt',
+                                        device=device)
+        self.vocoder = VocoderBigVGAN('text_to_audio/Make_An_Audio/vocoder/logs/bigv16k53w', device=device)

+    def img2audio(self, image, seed=55, scale=3, ddim_steps=100, W=624, H=80):
        SAMPLE_RATE = 16000
+        n_samples = 1  # only support 1 sample
        prng = np.random.RandomState(seed)
        start_code = prng.randn(n_samples, self.sampler.model.first_stage_model.embed_dim, H // 8, W // 8)
        start_code = torch.from_numpy(start_code).to(device=self.device, dtype=torch.float32)
        uc = self.sampler.model.get_learned_conditioning(n_samples * [""])
+        # image = Image.fromarray(image)
        image = Image.open(image)
        image = self.sampler.model.cond_stage_model.preprocess(image).unsqueeze(0)
        image_embedding = self.sampler.model.cond_stage_model.forward_img(image)
        c = image_embedding.repeat(n_samples, 1, 1)
+        shape = [self.sampler.model.first_stage_model.embed_dim, H // 8, W // 8]  # (z_dim, 80//2^x, 848//2^x)
        samples_ddim, _ = self.sampler.sample(S=ddim_steps,
+                                              conditioning=c,
+                                              batch_size=n_samples,
+                                              shape=shape,
+                                              verbose=False,
+                                              unconditional_guidance_scale=scale,
+                                              unconditional_conditioning=uc,
+                                              x_T=start_code)

        x_samples_ddim = self.sampler.model.decode_first_stage(samples_ddim)
+        x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)  # [0, 1]
        wav_list = []
+        for idx, spec in enumerate(x_samples_ddim):
            wav = self.vocoder.vocode(spec)
+            wav_list.append((SAMPLE_RATE, wav))
        best_wav = wav_list[0]
        return best_wav

@@ ... @@
                         "based on an image. "
                         "The input to this tool should be a string, "
                         "representing the image_path. ")
+    def inference(self, image, seed=55, scale=3, ddim_steps=100, W=624, H=80):
+        melbins, mel_len = 80, 624
        with torch.no_grad():
            result = self.img2audio(
                image=image,
+                H=melbins,
                W=mel_len
            )
        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
+        soundfile.write(audio_filename, result[1], samplerate=16000)
        print(f"Processed I2A.run, image_filename: {image}, audio_filename: {audio_filename}")
        return audio_filename

+
class TTS:
    def __init__(self, device=None):
        self.model = TTSInference(device)
+
    @prompts(name="Synthesize Speech Given the User Input Text",
             description="useful for when you want to convert a user input text into speech audio and save it to a file. "
                         "The input to this tool should be a string, "
                         "representing the text to be converted to speech.")
    def inference(self, text):
        inp = {"text": text}
        out = self.model.infer_once(inp)
        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
+        soundfile.write(audio_filename, out, samplerate=22050)
        return audio_filename

+
class T2S:
+    def __init__(self, device=None):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print("Initializing DiffSinger to %s" % device)
        self.device = device
        self.exp_name = 'checkpoints/0831_opencpop_ds1000'
+        self.config = 'NeuralSeq/egs/egs_bases/svs/midi/e2e/opencpop/ds1000.yaml'
        self.set_model_hparams()
        self.pipe = DiffSingerE2EInfer(self.hp, device)
        self.default_inp = {
@@ ... @@
            'notes_duration': '0.113740 | 0.329060 | 0.287950 | 0.133480 | 0.150900 | 0.484730 | 0.242010 | 0.180820 | 0.343570 | 0.152050 | 0.266720 | 0.280310 | 0.633300 | 0.444590'
        }

    def set_model_hparams(self):
        set_hparams(config=self.config, exp_name=self.exp_name, print_hparams=False)
        self.hp = hp
@@ ... @@
                         "Or like: Generate a piece of singing voice. Text is xxx, note is xxx, duration is xxx."
                         "The input to this tool should be a comma separated string of three, "
                         "representing text, note and duration sequence, since User Input Text, Note and Duration Sequence are all provided. ")
    def inference(self, inputs):
        self.set_model_hparams()
        val = inputs.split(",")
@@ ... @@
        print(f"Processed T2S.run, audio_filename: {audio_filename}")
        return audio_filename

+
class TTS_OOD:
    def __init__(self, device):
        if device is None:
@@ ... @@
                         "(e.g., timbre, emotion, and prosody) derived from a reference custom voice. "
                         "Like: Generate a speech with style transferred from this voice. The text is xxx., or speak using the voice of this audio. The text is xxx."
                         "The input to this tool should be a comma separated string of two, "
+                         "representing reference audio path and input text. ")
    def inference(self, inputs):
        self.set_model_hparams()
        key = ['ref_audio', 'text']
@@ ... @@
        print(
            f"Processed GenerSpeech.run. Input text:{val[1]}. Input reference audio: {val[0]}. Output Audio_filename: {audio_filename}")
        return audio_filename
+
+
class Inpaint:
    def __init__(self, device):
        print("Initializing Make-An-Audio-inpaint to %s" % device)
        self.device = device
+        self.sampler = initialize_model_inpaint('text_to_audio/Make_An_Audio/configs/inpaint/txt2audio_args.yaml',
+                                                'text_to_audio/Make_An_Audio/useful_ckpts/inpaint7_epoch00047.ckpt')
+        self.vocoder = VocoderBigVGAN('text_to_audio/Make_An_Audio/vocoder/logs/bigv16k53w', device=device)
        self.cmap_transform = matplotlib.cm.viridis

    def make_batch_sd(self, mel, mask, num_samples=1):

+        mel = torch.from_numpy(mel)[None, None, ...].to(dtype=torch.float32)
+        mask = torch.from_numpy(mask)[None, None, ...].to(dtype=torch.float32)
        masked_mel = (1 - mask) * mel

        mel = mel * 2 - 1
        mask = mask * 2 - 1
+        masked_mel = masked_mel * 2 - 1

        batch = {
+            "mel": repeat(mel.to(device=self.device), "1 ... -> n ...", n=num_samples),
+            "mask": repeat(mask.to(device=self.device), "1 ... -> n ...", n=num_samples),
+            "masked_mel": repeat(masked_mel.to(device=self.device), "1 ... -> n ...", n=num_samples),
        }
        return batch
+
    def gen_mel(self, input_audio_path):
        SAMPLE_RATE = 16000
        sr, ori_wav = wavfile.read(input_audio_path)
        print("gen_mel")
+        print(sr, ori_wav.shape, ori_wav)
        ori_wav = ori_wav.astype(np.float32, order='C') / 32768.0
+        if len(ori_wav.shape) == 2:  # stereo
+            ori_wav = librosa.to_mono(
+                ori_wav.T)  # gradio load wav shape could be (wav_len,2) but librosa expects (2,wav_len)
+        print(sr, ori_wav.shape, ori_wav)
+        ori_wav = librosa.resample(ori_wav, orig_sr=sr, target_sr=SAMPLE_RATE)

+        mel_len, hop_size = 848, 256
        input_len = mel_len * hop_size
        if len(ori_wav) < input_len:
+            input_wav = np.pad(ori_wav, (0, mel_len * hop_size), constant_values=0)
        else:
            input_wav = ori_wav[:input_len]
+
        mel = TRANSFORMS_16000(input_wav)
        return mel
+
    def gen_mel_audio(self, input_audio):
        SAMPLE_RATE = 16000
+        sr, ori_wav = input_audio
        print("gen_mel_audio")
+        print(sr, ori_wav.shape, ori_wav)

        ori_wav = ori_wav.astype(np.float32, order='C') / 32768.0
+        if len(ori_wav.shape) == 2:  # stereo
+            ori_wav = librosa.to_mono(
+                ori_wav.T)  # gradio load wav shape could be (wav_len,2) but librosa expects (2,wav_len)
+        print(sr, ori_wav.shape, ori_wav)
+        ori_wav = librosa.resample(ori_wav, orig_sr=sr, target_sr=SAMPLE_RATE)

+        mel_len, hop_size = 848, 256
        input_len = mel_len * hop_size
        if len(ori_wav) < input_len:
+            input_wav = np.pad(ori_wav, (0, mel_len * hop_size), constant_values=0)
        else:
            input_wav = ori_wav[:input_len]
        mel = TRANSFORMS_16000(input_wav)
        return mel
+
    def inpaint(self, batch, seed, ddim_steps, num_samples=1, W=512, H=512):
        model = self.sampler.model
+
        prng = np.random.RandomState(seed)
        start_code = prng.randn(num_samples, model.first_stage_model.embed_dim, H // 8, W // 8)
        start_code = torch.from_numpy(start_code).to(device=self.device, dtype=torch.float32)

        c = model.get_first_stage_encoding(model.encode_first_stage(batch["masked_mel"]))
        cc = torch.nn.functional.interpolate(batch["mask"],
+                                             size=c.shape[-2:])
+        c = torch.cat((c, cc), dim=1)  # (b,c+1,h,w) 1 is mask

+        shape = (c.shape[1] - 1,) + c.shape[2:]
        samples_ddim, _ = self.sampler.sample(S=ddim_steps,
+                                              conditioning=c,
+                                              batch_size=c.shape[0],
+                                              shape=shape,
+                                              verbose=False)
        x_samples_ddim = model.decode_first_stage(samples_ddim)

+        mask = batch["mask"]  # [-1,1]
+        mel = torch.clamp((batch["mel"] + 1.0) / 2.0, min=0.0, max=1.0)
+        mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
+        predicted_mel = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+        inpainted = (1 - mask) * mel + mask * predicted_mel
        inpainted = inpainted.cpu().numpy().squeeze()
        inapint_wav = self.vocoder.vocode(inpainted)

        return inpainted, inapint_wav
+
+    def predict(self, input_audio, mel_and_mask, seed=55, ddim_steps=100):
        SAMPLE_RATE = 16000
        torch.set_grad_enabled(False)
        mel_img = Image.open(mel_and_mask['image'])
        mask_img = Image.open(mel_and_mask["mask"])
+        show_mel = np.array(mel_img.convert("L")) / 255
+        mask = np.array(mask_img.convert("L")) / 255
+        mel_bins, mel_len = 80, 848
+        input_mel = self.gen_mel_audio(input_audio)[:, :mel_len]
+        mask = np.pad(mask, ((0, 0), (0, mel_len - mask.shape[1])), mode='constant', constant_values=0)
+        print(mask.shape, input_mel.shape)
        with torch.no_grad():
+            batch = self.make_batch_sd(input_mel, mask, num_samples=1)
+            inpainted, gen_wav = self.inpaint(
                batch=batch,
                seed=seed,
                ddim_steps=ddim_steps,
                num_samples=1,
                H=mel_bins, W=mel_len
            )
+        inpainted = inpainted[:, :show_mel.shape[1]]
        color_mel = self.cmap_transform(inpainted)
        input_len = int(input_audio[1].shape[0] * SAMPLE_RATE / input_audio[0])
        gen_wav = (gen_wav * 32768).astype(np.int16)[:input_len]
+        image = Image.fromarray((color_mel * 255).astype(np.uint8))
        image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
        image.save(image_filename)
        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
+        soundfile.write(audio_filename, gen_wav, samplerate=16000)
        return image_filename, audio_filename

    @prompts(name="Audio Inpainting",
             description="useful for when you want to inpaint a mel spectrum of an audio and predict this audio, "
                         "this tool will generate a mel spectrum and you can inpaint it, receives audio_path as input. "
                         "The input to this tool should be a string, "
+                         "representing the audio_path. ")
    def inference(self, input_audio_path):
        crop_len = 500
+        crop_mel = self.gen_mel(input_audio_path)[:, :crop_len]
        color_mel = self.cmap_transform(crop_mel)
+        image = Image.fromarray((color_mel * 255).astype(np.uint8))
        image_filename = os.path.join('image', str(uuid.uuid4())[0:8] + ".png")
        image.save(image_filename)
        return image_filename
+
+
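A sketch of the inputs Inpaint.predict expects. All file paths are hypothetical, and the audio is assumed to be the (sample_rate, int16 waveform) tuple a Gradio audio component produces:

import torch
from scipy.io import wavfile

inpainter = Inpaint(device='cuda' if torch.cuda.is_available() else 'cpu')
sr, wav = wavfile.read('audio/input.wav')          # int16 waveform to inpaint
mel_and_mask = {'image': 'image/mel.png',          # mel image previously produced by inference()
                'mask': 'image/mask.png'}          # user-drawn mask over that image
image_path, audio_path = inpainter.predict((sr, wav), mel_and_mask)
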
class ASR:
    def __init__(self, device):
        print("Initializing Whisper to %s" % device)
@@ ... @@
             description="useful for when you want to know the text corresponding to a human speech, "
                         "receives audio_path as input. "
                         "The input to this tool should be a string, "
+                         "representing the audio_path. ")
    def inference(self, audio_path):
        audio = whisper.load_audio(audio_path)
        audio = whisper.pad_or_trim(audio)
@@ ... @@
        result = whisper.decode(self.model, mel, options)
        return result.text

+    def translate_english(self, audio_path):
+        audio = self.model.transcribe(audio_path, language='English')
+        return audio['text']
+
+
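A short sketch of the two ASR entry points above (the audio path is hypothetical):

import torch

asr = ASR(device='cuda' if torch.cuda.is_available() else 'cpu')
text = asr.inference('audio/speech.wav')             # decode in the language Whisper detects
english = asr.translate_english('audio/speech.wav')  # transcribe with the language forced to English
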
class A2T:
    def __init__(self, device):
        print("Initializing Audio-To-Text Model to %s" % device)
@@ ... @@
             description="useful for when you want to describe an audio in text, "
                         "receives audio_path as input. "
                         "The input to this tool should be a string, "
+                         "representing the audio_path. ")
    def inference(self, audio_path):
        audio = whisper.load_audio(audio_path)
        caption_text = self.model(audio)
        return caption_text[0]

+
class SoundDetection:
    def __init__(self, device):
        self.device = device
@@ ... @@
        self.labels = detection_config.labels
        self.frames_per_second = self.sample_rate // self.hop_size
        # Model = eval(self.model_type)
+        self.model = PVT(sample_rate=self.sample_rate, window_size=self.window_size,
+                         hop_size=self.hop_size, mel_bins=self.mel_bins, fmin=self.fmin, fmax=self.fmax,
+                         classes_num=self.classes_num)
        checkpoint = torch.load(self.checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model'])
        self.model.to(device)
@@ ... @@
             description="useful for when you want to know what events are in an audio and when each sound event starts or ends, it will return an image, "
                         "receives audio_path as input. "
                         "The input to this tool should be a string, "
+                         "representing the audio_path. ")
    def inference(self, audio_path):
        # Forward
        (waveform, _) = librosa.core.load(audio_path, sr=self.sample_rate, mono=True)
+        waveform = waveform[None, :]  # (1, audio_length)
        waveform = torch.from_numpy(waveform)
        waveform = waveform.to(self.device)
        # Forward
@@ ... @@
        import matplotlib.pyplot as plt
        sorted_indexes = np.argsort(np.max(framewise_output, axis=0))[::-1]
        top_k = 10  # Show top results
+        top_result_mat = framewise_output[:, sorted_indexes[0: top_k]]
        """(time_steps, top_k)"""
+        # Plot result
+        stft = librosa.core.stft(y=waveform[0].data.cpu().numpy(), n_fft=self.window_size,
+                                 hop_length=self.hop_size, window='hann', center=True)
        frames_num = stft.shape[-1]
        fig, axs = plt.subplots(2, 1, sharex=True, figsize=(10, 4))
        axs[0].matshow(np.log(np.abs(stft)), origin='lower', aspect='auto', cmap='jet')
@@ ... @@
        axs[1].xaxis.set_ticks(np.arange(0, frames_num, self.frames_per_second))
        axs[1].xaxis.set_ticklabels(np.arange(0, frames_num / self.frames_per_second))
        axs[1].yaxis.set_ticks(np.arange(0, top_k))
+        axs[1].yaxis.set_ticklabels(np.array(self.labels)[sorted_indexes[0: top_k]])
        axs[1].yaxis.grid(color='k', linestyle='solid', linewidth=0.3, alpha=0.3)
        axs[1].set_xlabel('Seconds')
        axs[1].xaxis.set_ticks_position('bottom')
@@ ... @@
        plt.savefig(image_filename)
        return image_filename

+
class SoundExtraction:
    def __init__(self, device):
        self.device = device
@@ ... @@
             description="useful for when you want to extract a target sound from a mixture audio, you can describe the target sound by text, "
                         "receives audio_path and text as input. "
                         "The input to this tool should be a comma separated string of two, "
+                         "representing the mixture audio path and the input text.")
    def inference(self, inputs):
+        # key = ['ref_audio', 'text']
        val = inputs.split(",")
+        audio_path = val[0]  # audio_path, text
        text = val[1]
        waveform = load_wav(audio_path)
+        waveform = torch.tensor(waveform).transpose(1, 0)
        mixed_mag, mixed_phase = self.stft.transform(waveform)
        text_query = ['[CLS] ' + text]
+        mixed_mag = mixed_mag.transpose(2, 1).unsqueeze(0).to(self.device)
        est_mask = self.model(mixed_mag, text_query)
+        est_mag = est_mask * mixed_mag
+        est_mag = est_mag.squeeze(1)
+        est_mag = est_mag.permute(0, 2, 1)
        est_wav = self.stft.inverse(est_mag.cpu().detach(), mixed_phase)
+        est_wav = est_wav.squeeze(0).squeeze(0).numpy()
+        # est_path = f'output/est{i}.wav'
        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
        print('audio_filename ', audio_filename)
        save_wav(est_wav, audio_filename)
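Several tools in this file, including SoundExtraction, take a single comma-separated string rather than separate arguments; a sketch of the expected call shape (the path and description are hypothetical):

import torch

extractor = SoundExtraction(device='cuda' if torch.cuda.is_available() else 'cpu')
# "<mixture audio path>,<target sound description>", as the @prompts description specifies
extracted_path = extractor.inference('audio/mixture.wav, a man speaking over traffic noise')
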
@@ ... @@
                              'mono2binaural/useful_ckpts/m2b/tx_positions4.txt',
                              'mono2binaural/useful_ckpts/m2b/tx_positions5.txt']
        self.net = BinauralNetwork(view_dim=7,
+                                   warpnet_layers=4,
+                                   warpnet_channels=64,
+                                   )
        self.net.load_from_file(self.model_file)
        self.sr = 48000

@@ ... @@
             description="useful for when you want to transfer your mono audio into binaural audio, "
                         "receives audio_path as input. "
                         "The input to this tool should be a string, "
+                         "representing the audio_path. ")
    def inference(self, audio_path):
+        mono, sr = librosa.load(path=audio_path, sr=self.sr, mono=True)
        mono = torch.from_numpy(mono)
        mono = mono.unsqueeze(0)
        import numpy as np
        import random
+        rand_int = random.randint(0, 4)
        view = np.loadtxt(self.position_file[rand_int]).transpose().astype(np.float32)
        view = torch.from_numpy(view)
        if not view.shape[-1] * 400 == mono.shape[-1]:
+            mono = mono[:, :(mono.shape[-1] // 400) * 400]
+            if view.shape[1] * 400 > mono.shape[1]:
+                m_a = view.shape[1] - mono.shape[-1] // 400
+                rand_st = random.randint(0, m_a)
+                view = view[:, m_a:m_a + (mono.shape[-1] // 400)]
        # binauralize and save output
        self.net.eval().to(self.device)
        mono, view = mono.to(self.device), view.to(self.device)
        chunk_size = 48000  # forward in chunks of 1s
+        rec_field = 1000  # add 1000 samples as "safe bet" since warping has undefined rec. field
        rec_field -= rec_field % 400  # make sure rec_field is a multiple of 400 to match audio and view frequencies
        chunks = [
            {
+                "mono": mono[:, max(0, i - rec_field):i + chunk_size],
+                "view": view[:, max(0, i - rec_field) // 400:(i + chunk_size) // 400]
            }
            for i in range(0, mono.shape[-1], chunk_size)
        ]
@@ ... @@
                view = chunk["view"].unsqueeze(0)
                binaural = self.net(mono, view).squeeze(0)
                if i > 0:
+                    binaural = binaural[:, -(mono.shape[-1] - rec_field):]
                chunk["binaural"] = binaural
        binaural = torch.cat([chunk["binaural"] for chunk in chunks], dim=-1)
        binaural = torch.clamp(binaural, min=-1, max=1).cpu()
+        # binaural = chunked_forwarding(net, mono, view)
        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
        import torchaudio
        torchaudio.save(audio_filename, binaural, sr)
+        # soundfile.write(audio_filename, binaural, samplerate=48000)
        print(f"Processed Binaural.run, audio_filename: {audio_filename}")
        return audio_filename

+
| 746 | 
             
            class TargetSoundDetection:
         | 
| 747 | 
             
                def __init__(self, device):
         | 
| 748 | 
             
                    self.device = device
         | 
|  | |
| 755 | 
             
                    self.EPS = np.spacing(1)
         | 
| 756 | 
             
                    self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
         | 
| 757 | 
             
                    self.event_labels = event_labels
         | 
| 758 | 
            +
                    self.id_to_event = {i: label for i, label in enumerate(self.event_labels)}
         | 
| 759 | 
            +
                    config = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/run_config.pth',
         | 
| 760 | 
            +
                                        map_location='cpu')
         | 
| 761 | 
             
                    config_parameters = dict(config)
         | 
| 762 | 
             
                    config_parameters['tao'] = 0.6
         | 
| 763 | 
             
                    if 'thres' not in config_parameters.keys():
         | 
| 764 | 
             
                        config_parameters['thres'] = 0.5
         | 
| 765 | 
             
                    if 'time_resolution' not in config_parameters.keys():
         | 
| 766 | 
             
                        config_parameters['time_resolution'] = 125
         | 
| 767 | 
            +
                    model_parameters = torch.load(
         | 
| 768 | 
            +
                        'audio_detection/target_sound_detection/useful_ckpts/tsd/run_model_7_loss=-0.0724.pt'
         | 
| 769 | 
            +
                        , map_location=lambda storage, loc: storage)  # load parameter
         | 
| 770 | 
             
                    self.model = getattr(tsd_models, config_parameters['model'])(config_parameters,
         | 
| 771 | 
            +
                                                                                 inputdim=64, outputdim=2,
         | 
| 772 | 
            +
                                                                                 time_resolution=config_parameters[
         | 
| 773 | 
            +
                                                                                     'time_resolution'],
         | 
| 774 | 
            +
                                                                                 **config_parameters['model_args'])
         | 
| 775 | 
             
                    self.model.load_state_dict(model_parameters)
         | 
| 776 | 
             
                    self.model = self.model.to(self.device).eval()
         | 
| 777 | 
             
                    self.re_embeds = torch.load('audio_detection/target_sound_detection/useful_ckpts/tsd/text_emb.pth')
         | 
|  | |
| 781 | 
             
                    import soundfile as sf
         | 
| 782 | 
             
                    y, sr = sf.read(fname, dtype='float32')
         | 
| 783 | 
             
                    print('y ', y.shape)
         | 
| 784 | 
            +
                    ti = y.shape[0] / sr
         | 
| 785 | 
             
                    if y.ndim > 1:
         | 
| 786 | 
             
                        y = y.mean(1)
         | 
| 787 | 
             
                    y = librosa.resample(y, sr, 22050)
         | 
| 788 | 
             
                    lms_feature = np.log(librosa.feature.melspectrogram(y, **self.MEL_ARGS) + self.EPS).T
         | 
| 789 | 
            +
                    return lms_feature, ti
         | 
| 790 | 
            +
             | 
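
    # Illustration (not part of the original file): extract_feature returns the clip
    # duration `ti` in seconds alongside the log-mel features so that frame-level model
    # outputs can be mapped back to wall-clock time. A hypothetical helper showing the
    # same arithmetic as `time_ratio = ti / frame_num` in inference below:
    @staticmethod
    def _frame_to_seconds_sketch(frame_idx, ti, frame_num):
        # e.g. a 10 s clip decoded into 250 output frames gives 0.04 s per frame,
        # so frame index 50 corresponds to 2.0 s.
        return frame_idx * (ti / frame_num)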

    def build_clip(self, text):
        text = clip.tokenize(text).to(self.device)  # e.g. ["a diagram with dog", "a dog", "a cat"]
        text_features = self.clip_model.encode_text(text)
        return text_features

    def cal_similarity(self, target, retrievals):
        ans = []
        for name in retrievals.keys():
            ...

    @prompts(...,
             description="useful for when you want to know when the target sound event in the audio happens. You can use language descriptions to instruct the model, "
                         "receives text description and audio_path as input. "
                         "The input to this tool should be a comma separated string of two, "
                         "representing audio path and the text description. ")
    def inference(self, inputs):
        audio_path, text = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
        target_emb = self.build_clip(text)  # torch type
        idx = self.cal_similarity(target_emb, self.re_embeds)
        target_event = self.id_to_event[idx]
        embedding = self.ref_mel[target_event]
        embedding = torch.from_numpy(embedding)
        embedding = embedding.unsqueeze(0).to(self.device).float()
        inputs, ti = self.extract_feature(audio_path)
        inputs = torch.from_numpy(inputs)
        inputs = inputs.unsqueeze(0).to(self.device).float()
        decision, decision_up, logit = self.model(inputs, embedding)
        pred = decision_up.detach().cpu().numpy()
        pred = pred[:, :, 0]
        frame_num = decision_up.shape[1]
        time_ratio = ti / frame_num
        filtered_pred = median_filter(pred, window_size=1, threshold=0.5)
        time_predictions = []
        for index_k in range(filtered_pred.shape[0]):
            decoded_pred = []
            decoded_pred_ = decode_with_timestamps(target_event, filtered_pred[index_k, :])
            if len(decoded_pred_) == 0:  # no positive segment was detected
                decoded_pred_.append((target_event, 0, 0))
            decoded_pred.append(decoded_pred_)
            for num_batch in range(len(decoded_pred)):  # batch_size is 1 at test time
                cur_pred = pred[num_batch]
                # Save each frame output, for later visualization
                label_prediction = decoded_pred[num_batch]  # frame predictions
                for event_label, onset, offset in label_prediction:
                    time_predictions.append({
                        'onset': onset * time_ratio,
                        'offset': offset * time_ratio, })
        ans = ''
        for i, item in enumerate(time_predictions):
            ans = ans + 'segment' + str(i + 1) + ' start_time: ' + str(item['onset']) + '  end_time: ' + str(
                item['offset']) + '\t'
        return ans
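
def _target_sound_detection_demo():
    # Illustrative usage sketch (not part of the original file): the tool takes one
    # comma-separated string of "<audio_path>, <event description>" and returns the
    # detected segments as text. The path and prompt below are hypothetical, and the
    # TSD and CLIP checkpoints must be available locally.
    tsd = TargetSoundDetection(device="cuda")
    return tsd.inference("audio/3a5b2c1d.wav, a dog barking")
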
            +
            class Speech_Enh_SC:
         | 
| 849 | 
            +
                """Speech Enhancement or Separation in single-channel
         | 
| 850 | 
            +
                Example usage:
         | 
| 851 | 
            +
                    enh_model = Speech_Enh_SS("cuda")
         | 
| 852 | 
            +
                    enh_wav = enh_model.inference("./test_chime4_audio_M05_440C0213_PED_REAL.wav")
         | 
| 853 | 
            +
                """
         | 
| 854 | 
            +
             | 
| 855 | 
            +
                def __init__(self, device="cuda", model_name="espnet/Wangyou_Zhang_chime4_enh_train_enh_conv_tasnet_raw"):
         | 
| 856 | 
            +
                    self.model_name = model_name
         | 
| 857 | 
            +
                    self.device = device
         | 
| 858 | 
            +
                    print("Initializing ESPnet Enh to %s" % device)
         | 
| 859 | 
            +
                    self._initialize_model()
         | 
| 860 | 
            +
             | 
| 861 | 
            +
                def _initialize_model(self):
         | 
| 862 | 
            +
                    from espnet_model_zoo.downloader import ModelDownloader
         | 
| 863 | 
            +
                    from espnet2.bin.enh_inference import SeparateSpeech
         | 
| 864 | 
            +
             | 
| 865 | 
            +
                    d = ModelDownloader()
         | 
| 866 | 
            +
             | 
| 867 | 
            +
                    cfg = d.download_and_unpack(self.model_name)
         | 
| 868 | 
            +
                    self.separate_speech = SeparateSpeech(
         | 
| 869 | 
            +
                        train_config=cfg["train_config"],
         | 
| 870 | 
            +
                        model_file=cfg["model_file"],
         | 
| 871 | 
            +
                        # for segment-wise process on long speech
         | 
| 872 | 
            +
                        segment_size=2.4,
         | 
| 873 | 
            +
                        hop_size=0.8,
         | 
| 874 | 
            +
                        normalize_segment_scale=False,
         | 
| 875 | 
            +
                        show_progressbar=True,
         | 
| 876 | 
            +
                        ref_channel=None,
         | 
| 877 | 
            +
                        normalize_output_wav=True,
         | 
| 878 | 
            +
                        device=self.device,
         | 
| 879 | 
            +
                    )
         | 
| 880 | 
            +
             | 
| 881 | 
            +
                @prompts(name="Speech Enhancement In Single-Channel",
         | 
| 882 | 
            +
                         description="useful for when you want to enhance the quality of the speech signal by reducing background noise (single-channel), "
         | 
| 883 | 
            +
                                     "receives audio_path as input."
         | 
| 884 | 
            +
                                     "The input to this tool should be a string, "
         | 
| 885 | 
            +
                                     "representing the audio_path. ")
         | 
| 886 | 
            +
                def inference(self, speech_path, ref_channel=0):
         | 
| 887 | 
            +
                    speech, sr = soundfile.read(speech_path)
         | 
| 888 | 
            +
                    speech = speech[:, ref_channel]
         | 
| 889 | 
            +
                    enh_speech = self.separate_speech(speech[None, ...], fs=sr)
         | 
| 890 | 
            +
                    audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
         | 
| 891 | 
            +
                    soundfile.write(audio_filename, enh_speech[0].squeeze(), samplerate=sr)
         | 
| 892 | 
            +
                    return audio_filename
         | 
| 893 | 
            +
             | 
| 894 | 
            +
             | 
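
def _speech_enh_sc_demo():
    # Illustrative usage sketch (not part of the original file): enhance a noisy (possibly
    # multi-channel) recording and get back the path of the enhanced wav. The input path
    # is hypothetical; the ESPnet model is downloaded on first use.
    enh_model = Speech_Enh_SC(device="cuda")
    return enh_model.inference("audio/noisy_example.wav", ref_channel=0)
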

class Speech_SS:
    def __init__(self, device="cuda", model_name="lichenda/wsj0_2mix_skim_noncausal"):
        self.model_name = model_name
        self.device = device
        print("Initializing ESPnet SS to %s" % device)
        self._initialize_model()

    def _initialize_model(self):
        from espnet_model_zoo.downloader import ModelDownloader
        from espnet2.bin.enh_inference import SeparateSpeech

        d = ModelDownloader()
        cfg = d.download_and_unpack(self.model_name)
        self.separate_speech = SeparateSpeech(
            train_config=cfg["train_config"],
            model_file=cfg["model_file"],
            # for segment-wise processing of long speech
            segment_size=2.4,
            hop_size=0.8,
            normalize_segment_scale=False,
            show_progressbar=True,
            ref_channel=None,
            normalize_output_wav=True,
            device=self.device,
        )

    @prompts(name="Speech Separation",
             description="useful for when you want to separate each speech from the speech mixture, "
                         "receives audio_path as input. "
                         "The input to this tool should be a string, "
                         "representing the audio_path. ")
    def inference(self, speech_path):
        speech, sr = soundfile.read(speech_path)
        enh_speech = self.separate_speech(speech[None, ...], fs=sr)
        audio_filename = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
        if len(enh_speech) == 1:
            soundfile.write(audio_filename, enh_speech[0].squeeze(), samplerate=sr)
        else:
            audio_filename_1 = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
            soundfile.write(audio_filename_1, enh_speech[0].squeeze(), samplerate=sr)
            audio_filename_2 = os.path.join('audio', str(uuid.uuid4())[0:8] + ".wav")
            soundfile.write(audio_filename_2, enh_speech[1].squeeze(), samplerate=sr)
            audio_filename = merge_audio(audio_filename_1, audio_filename_2)
        return audio_filename
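
def _speech_ss_demo():
    # Illustrative usage sketch (not part of the original file): separate a two-speaker
    # mixture. When two sources come back they are written to separate files and merged
    # into one result via merge_audio, whose path is returned. The input path is hypothetical.
    ss_model = Speech_SS(device="cuda")
    return ss_model.inference("audio/two_speaker_mixture.wav")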