	Update app.py
app.py  CHANGED
@@ -208,27 +208,36 @@ def infer_batch(ref_audio, ref_text, gen_text_batches, exp_name, remove_silence,
         print(f"Duration: {duration} seconds")
         # inference
         with torch.inference_mode():
-            [lines 211-231 removed; their content is not visible in this diff view]
+            # Ensure all inputs are on the same device as ema_model
+            audio = audio.to(ema_model.device)  # Match ema_model's device
+            final_text_list = [t.to(ema_model.device) if isinstance(t, torch.Tensor) else t for t in final_text_list]
+            generated, _ = ema_model.sample(
+                cond=audio,
+                text=final_text_list,
+                duration=duration,
+                steps=nfe_step,
+                cfg_strength=cfg_strength,
+                sway_sampling_coef=sway_sampling_coef,
+            )
+
+            # Process the generated tensor: keep only the frames past the reference audio
+            generated = generated[:, ref_audio_len:, :]
+            generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
+
+            # Convert to the dtype and device the vocoder expects
+            generated_mel_spec = generated_mel_spec.to(dtype=torch.float16, device=vocos.device)  # Ensure device matches vocos
+            generated_wave = vocos.decode(generated_mel_spec)
+
+            # Adjust wave RMS if needed
+            if rms < target_rms:
+                generated_wave = generated_wave * rms / target_rms
+
+            # Convert to numpy
+            generated_wave = generated_wave.squeeze().cpu().numpy()
+
+            # Append to the output lists
+            generated_waves.append(generated_wave)
+            spectrograms.append(generated_mel_spec[0].cpu().numpy())
         # Ensure generated_mel_spec is in a compatible dtype (e.g., float32) before passing it to numpy
         #        generated_mel_spec = generated_mel_spec.to(dtype=torch.float32)  # Convert to float32 if it's in bfloat16

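
The device-alignment pattern added above can be factored into one helper so every tensor input is moved in a single place. A minimal sketch, not part of the commit (the name to_model_device is hypothetical; it reads the device from the model's parameters, which works for any nn.Module, whereas a bare .device attribute only exists if the model class defines one):

    import torch

    def to_model_device(model: torch.nn.Module, *values):
        # Resolve the target device from the model's first parameter.
        device = next(model.parameters()).device
        # Move tensors; pass non-tensor values (e.g. plain strings) through unchanged.
        return tuple(v.to(device) if isinstance(v, torch.Tensor) else v for v in values)

    # Usage mirroring the diff:
    # audio, = to_model_device(ema_model, audio)
    # final_text_list = list(to_model_device(ema_model, *final_text_list))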
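
On the hard-coded torch.float16 before vocos.decode: Vocos checkpoints typically load with float32 weights, so a float16 mel can raise a dtype-mismatch error unless the vocoder itself was converted to half precision (the pre-existing comment about converting back to float32 before numpy points at the same issue). A defensive sketch that matches both dtype and device to the vocoder's own weights, assuming a Vocos-style module with a decode(mel) method as in the diff:

    import torch

    def decode_mel(vocoder: torch.nn.Module, mel: torch.Tensor) -> torch.Tensor:
        # Derive dtype and device from the vocoder's weights instead of assuming
        # float16; nn.Module has no .device attribute, so read it off a parameter.
        param = next(vocoder.parameters())
        mel = mel.to(device=param.device, dtype=param.dtype)
        with torch.inference_mode():
            return vocoder.decode(mel)  # Vocos-style decode(mel) API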
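
The RMS adjustment undoes the loudness boost applied to the reference audio before inference: if the reference was quieter than target_rms and was scaled up, the generated wave is scaled back down by the same ratio. A hypothetical helper capturing that step, for illustration only:

    import torch

    def restore_rms(wave: torch.Tensor, ref_rms: float, target_rms: float) -> torch.Tensor:
        # Scale back down only if the reference was boosted up to target_rms.
        return wave * ref_rms / target_rms if ref_rms < target_rms else wave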