v1.2

- app.py +25 -16
- diffrhythm/config/{diffrhythm-1b.json → config.json} +1 -1
- diffrhythm/infer/infer.py +169 -102
- diffrhythm/infer/infer_utils.py +345 -58
- diffrhythm/model/__pycache__/__init__.cpython-310.pyc +0 -0
- diffrhythm/model/__pycache__/__init__.cpython-312.pyc +0 -0
- diffrhythm/model/__pycache__/cfm.cpython-310.pyc +0 -0
- diffrhythm/model/__pycache__/cfm.cpython-312.pyc +0 -0
- diffrhythm/model/__pycache__/custom_dataset.cpython-310.pyc +0 -0
- diffrhythm/model/__pycache__/custom_dataset_lrc_emb.cpython-310.pyc +0 -0
- diffrhythm/model/__pycache__/dataset.cpython-310.pyc +0 -0
- diffrhythm/model/__pycache__/dit.cpython-310.pyc +0 -0
- diffrhythm/model/__pycache__/modules.cpython-310.pyc +0 -0
- diffrhythm/model/__pycache__/trainer.cpython-310.pyc +0 -0
- diffrhythm/model/__pycache__/utils.cpython-310.pyc +0 -0
- diffrhythm/model/cfm.py +62 -43
- diffrhythm/model/dit.py +40 -25
- diffrhythm/model/modules.py +41 -0
- diffrhythm/model/utils.py +4 -4
- pretrained/eval.py +66 -0
- pretrained/eval.safetensors +3 -0
- pretrained/eval.yaml +6 -0
    	
app.py CHANGED

@@ -27,22 +27,16 @@ from diffrhythm.infer.infer import inference
 
 MAX_SEED = np.iinfo(np.int32).max
 device='cuda'
-cfm, …
+cfm, tokenizer, muq, vae, eval_model, eval_muq = prepare_model(device)
 cfm = torch.compile(cfm)
-cfm_full = torch.compile(cfm_full)
 
-@spaces.GPU
-def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42, randomize_seed=False, steps=32, cfg_strength=4.0, file_type='wav', odeint_method='euler', …
-    …
-    …
-        cfm_model = cfm
-    else:
-        max_frames = 6144
-        cfm_model = cfm_full
+@spaces.GPU
+def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42, randomize_seed=False, steps=32, cfg_strength=4.0, file_type='wav', odeint_method='euler', preference_infer="quality first", edit=False, edit_segments=None, device='cuda'):
+    max_frames = 2048
+    sway_sampling_coef = -1 if steps < 32 else None
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     torch.manual_seed(seed)
-    sway_sampling_coef = -1 if steps < 32 else None
     vocal_flag = False
     try:
         lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)

@@ -53,9 +47,16 @@ def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42,
     except Exception as e:
         raise gr.Error(f"Error: {str(e)}")
     negative_style_prompt = get_negative_style_prompt(device)
-    latent_prompt = get_reference_latent(device, max_frames)
-    …
+    latent_prompt, pred_frames = get_reference_latent(device, max_frames, edit, edit_segments, ref_audio_path, vae)
+
+    if preference_infer == "quality first":
+        batch_infer_num = 5
+    else:
+        batch_infer_num = 1
+    generated_song = inference(cfm_model=cfm,
                               vae_model=vae,
+                              eval_model=eval_model,
+                              eval_muq=eval_muq,
                               cond=latent_prompt,
                               text=lrc_prompt,
                               duration=max_frames,

@@ -68,6 +69,8 @@ def infer_music(lrc, ref_audio_path, text_prompt, current_prompt_type, seed=42,
                               file_type=file_type,
                               vocal_flag=vocal_flag,
                               odeint_method=odeint_method,
+                              pred_frames=pred_frames,
+                              batch_infer_num=batch_infer_num,
                               )
     return generated_song
 

@@ -234,8 +237,8 @@ with gr.Blocks(css=css) as demo:
     - If loading audio result is slow, you can select Output Format as mp3 in Advanced Settings.
 
                     """)
-                Music_Duration = gr.Radio(["95s", "285s"], label="Music Duration", value="95s")
-
+                # Music_Duration = gr.Radio(["95s", "285s"], label="Music Duration", value="95s")
+                preference_infer = gr.Radio(["quality first", "speed first"], label="Preference", value="quality first")
                 lyrics_btn = gr.Button("Generate", variant="primary")
                 audio_output = gr.Audio(label="Audio Result", type="filepath", elem_id="audio_output")
                 with gr.Accordion("Advanced Settings", open=False):

@@ -266,6 +269,12 @@ with gr.Blocks(css=css) as demo:
                                 interactive=True,
                                 elem_id="step_slider"
                             )
+                        edit = gr.Checkbox(label="edit", value=False)
+                        edit_segments = gr.Textbox(
+                            label="Edit Segments",
+                            placeholder="Time segments to edit (in seconds). Format: `[[start1,end1],...]`",
+                            )
+
                         odeint_method = gr.Radio(["euler", "midpoint", "rk4","implicit_adams"], label="ODE Solver", value="euler")
                         file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="wav")
 

@@ -409,7 +418,7 @@ with gr.Blocks(css=css) as demo:
 
     lyrics_btn.click(
         fn=infer_music,
-        inputs=[lrc, audio_prompt, text_prompt, current_prompt_type, seed, randomize_seed, steps, cfg_strength, file_type, odeint_method, …],
+        inputs=[lrc, audio_prompt, text_prompt, current_prompt_type, seed, randomize_seed, steps, cfg_strength, file_type, odeint_method, preference_infer, edit, edit_segments],
         outputs=audio_output
     )
 
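For context, a minimal sketch of driving the updated infer_music entry point directly, outside the Gradio UI. It assumes app.py has already been imported so that cfm, vae, tokenizer and the two evaluation models are loaded by prepare_model exactly as in the diff; the lyric file name, reference audio path and the "audio" prompt-type value are illustrative placeholders, not values taken from the repository.

# Hedged sketch: exercising the new infer_music() signature from app.py (v1.2).
# "song.lrc", "ref.wav" and current_prompt_type="audio" are placeholder assumptions.
with open("song.lrc", encoding="utf-8") as f:
    lrc_text = f.read()

sample_rate, audio_np = infer_music(
    lrc=lrc_text,
    ref_audio_path="ref.wav",
    text_prompt="",
    current_prompt_type="audio",
    seed=42,
    randomize_seed=True,
    steps=32,
    cfg_strength=4.0,
    file_type="wav",                   # with "wav" the function returns (44100, numpy_audio)
    odeint_method="euler",
    preference_infer="quality first",  # generates 5 candidates, keeps the one eval_song selects
    edit=False,                        # set True together with edit_segments to regenerate only parts
    edit_segments=None,
)

In other words, the new Preference radio trades latency for quality: "quality first" sets batch_infer_num to 5 so inference decodes five candidates and hands them to eval_song, while "speed first" keeps a single pass.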
    	
diffrhythm/config/{diffrhythm-1b.json → config.json} RENAMED

@@ -2,7 +2,7 @@
     "model_type": "diffrhythm",
     "model": {
         "dim": 2048,
-        "depth": 16,
+        "depth": 16, 
         "heads": 32,
         "ff_mult": 4,
         "text_dim": 512,
    	
diffrhythm/infer/infer.py CHANGED

@@ -2,82 +2,51 @@ import torch
 import torchaudio
 from einops import rearrange
 import argparse
-import json
 import os
-…
+import time
 import random
+
+import torch
+import torchaudio
 import numpy as np
-import …
+from einops import rearrange
 import io
 import pydub
 
 from diffrhythm.infer.infer_utils import (
-    …
+    decode_audio,
     get_lrc_token,
-    …
+    get_negative_style_prompt,
+    get_reference_latent,
+    get_style_prompt,
     prepare_model,
-    …
+    eval_song,
 )
 
-def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
-    downsampling_ratio = 2048
-    io_channels = 2
-    if not chunked:
-        # default behavior. Decode the entire latent in parallel
-        return vae_model.decode_export(latents)
-    else:
-        # chunked decoding
-        hop_size = chunk_size - overlap
-        total_size = latents.shape[2]
-        batch_size = latents.shape[0]
-        chunks = []
-        i = 0
-        for i in range(0, total_size - chunk_size + 1, hop_size):
-            chunk = latents[:,:,i:i+chunk_size]
-            chunks.append(chunk)
-        if i+chunk_size != total_size:
-            # Final chunk
-            chunk = latents[:,:,-chunk_size:]
-            chunks.append(chunk)
-        chunks = torch.stack(chunks)
-        num_chunks = chunks.shape[0]
-        # samples_per_latent is just the downsampling ratio
-        samples_per_latent = downsampling_ratio
-        # Create an empty waveform, we will populate it with chunks as decode them
-        y_size = total_size * samples_per_latent
-        y_final = torch.zeros((batch_size,io_channels,y_size)).to(latents.device)
-        for i in range(num_chunks):
-            x_chunk = chunks[i,:]
-            # decode the chunk
-            y_chunk = vae_model.decode_export(x_chunk)
-            # figure out where to put the audio along the time domain
-            if i == num_chunks-1:
-                # final chunk always goes at the end
-                t_end = y_size
-                t_start = t_end - y_chunk.shape[2]
-            else:
-                t_start = i * hop_size * samples_per_latent
-                t_end = t_start + chunk_size * samples_per_latent
-            #  remove the edges of the overlaps
-            ol = (overlap//2) * samples_per_latent
-            chunk_start = 0
-            chunk_end = y_chunk.shape[2]
-            if i > 0:
-                # no overlap for the start of the first chunk
-                t_start += ol
-                chunk_start += ol
-            if i < num_chunks-1:
-                # no overlap for the end of the last chunk
-                t_end -= ol
-                chunk_end -= ol
-            # paste the chunked audio into our y_final output audio
-            y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
-        return y_final
-
-def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative_style_prompt, steps, cfg_strength, sway_sampling_coef, start_time, file_type, vocal_flag, odeint_method):
 
+def inference(
+    cfm_model,
+    vae_model,
+    eval_model,
+    eval_muq,
+    cond,
+    text,
+    duration,
+    style_prompt,
+    negative_style_prompt,
+    steps,
+    cfg_strength,
+    sway_sampling_coef,
+    start_time,
+    file_type,
+    vocal_flag,
+    odeint_method,
+    pred_frames,
+    batch_infer_num,
+    chunked=True,
+):
     with torch.inference_mode():
-        …
+        latents, _ = cfm_model.sample(
             cond=cond,
             text=text,
             duration=duration,

@@ -89,17 +58,27 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
             start_time=start_time,
             vocal_flag=vocal_flag,
             odeint_method=odeint_method,
+            latent_pred_segments=pred_frames,
+            batch_infer_num=batch_infer_num
         )
-
-        generated = generated.to(torch.float32)
-        latent = generated.transpose(1, 2) # [b d t]
-        output = decode_audio(latent, vae_model, chunked=False)
 
-        …
-        …
+        outputs = []
+        for latent in latents:
+            latent = latent.to(torch.float32)
+            latent = latent.transpose(1, 2)  # [b d t]
+
+            output = decode_audio(latent, vae_model, chunked=chunked)
+
+            # Rearrange audio batch to a single sequence
+            output = rearrange(output, "b d n -> d (b n)")
+
+            outputs.append(output)
+        if batch_infer_num > 1:
+            generated_song = eval_song(eval_model, eval_muq, outputs)
+        else:
+            generated_song = outputs[0]
+        output_tensor = generated_song.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).cpu()
         output_np = output_tensor.numpy().T.astype(np.float32)
-
         if file_type == 'wav':
             return (44100, output_np)
         else:

@@ -111,52 +90,140 @@ def inference(cfm_model, vae_model, cond, text, duration, style_prompt, negative
             else:
                 song.export(buffer, format="ogg", bitrate="320k")
             return buffer.getvalue()
-
 
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(…
-    …
-    …
-    …
+    parser.add_argument(
+        "--lrc-path",
+        type=str,
+        help="lyrics of target song",
+    )  # lyrics of target song
+    parser.add_argument(
+        "--ref-prompt",
+        type=str,
+        help="reference prompt as style prompt for target song",
+        required=False,
+    )  # reference prompt as style prompt for target song
+    parser.add_argument(
+        "--ref-audio-path",
+        type=str,
+        help="reference audio as style prompt for target song",
+        required=False,
+    )  # reference audio as style prompt for target song
+    parser.add_argument(
+        "--chunked",
+        action="store_true",
+        help="whether to use chunked decoding",
+    )  # whether to use chunked decoding
+    parser.add_argument(
+        "--audio-length",
+        type=int,
+        default=95,
+        choices=[95, 285],
+        help="length of generated song",
+    )  # length of target song
+    parser.add_argument(
+        "--repo-id", type=str, default="ASLP-lab/DiffRhythm-base", help="target model"
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        default="infer/example/output",
+        help="output directory of generated song",
+    )  # output directory of target song
+    parser.add_argument(
+        "--edit",
+        action="store_true",
+        help="whether to open edit mode",
+    )  # edit flag
+    parser.add_argument(
+        "--ref-song",
+        type=str,
+        required=False,
+        help="reference prompt as latent prompt for editing",
+    )  # reference prompt as latent prompt for editing
+    parser.add_argument(
+        "--edit-segments",
+        type=str,
+        required=False,
+        help="edit segments of target song",
+    )  # edit segments of target song
     args = parser.parse_args()
-    …
-    …
-    …
+
+    assert (
+        args.ref_prompt or args.ref_audio_path
+    ), "either ref_prompt or ref_audio_path should be provided"
+    assert not (
+        args.ref_prompt and args.ref_audio_path
+    ), "only one of them should be provided"
+    if args.edit:
+        assert (
+            args.ref_song and args.edit_segments
+        ), "reference song and edit segments should be provided for editing"
+
+    device = "cpu"
+    if torch.cuda.is_available():
+        device = "cuda"
+    elif torch.mps.is_available():
+        device = "mps"
+
     audio_length = args.audio_length
     if audio_length == 95:
         max_frames = 2048
     elif audio_length == 285:
         max_frames = 6144
 
-    cfm, tokenizer, muq, vae = prepare_model(device)
-
-    with open(args.lrc_path, 'r') as f:
-        lrc = f.read()
-    lrc_prompt, start_time = get_lrc_token(lrc, tokenizer, device)
-
-    style_prompt = get_audio_style_prompt(muq, args.ref_audio_path)
-
-    negative_style_prompt = get_negative_style_prompt(device)
-
-    …
-    …
-    generated_song = …
-    …
-    …
-    …
-    …
-    …
-    …
-    …
-                              )
+    cfm, tokenizer, muq, vae, eval_model, eval_muq = prepare_model(max_frames, device, repo_id=args.repo_id)
+
+    if args.lrc_path:
+        with open(args.lrc_path, "r", encoding='utf-8') as f:
+            lrc = f.read()
+    else:
+        lrc = ""
+    lrc_prompt, start_time = get_lrc_token(max_frames, lrc, tokenizer, device)
+
+    if args.ref_audio_path:
+        style_prompt = get_style_prompt(muq, args.ref_audio_path)
+    else:
+        style_prompt = get_style_prompt(muq, prompt=args.ref_prompt)
+
+    negative_style_prompt = get_negative_style_prompt(device)
+
+    latent_prompt, pred_frames = get_reference_latent(device, max_frames, args.edit, args.edit_segments, args.ref_song, vae)
+
+    s_t = time.time()
+    generated_songs = inference(
+        cfm_model=cfm,
+        vae_model=vae,
+        cond=latent_prompt,
+        text=lrc_prompt,
+        duration=max_frames,
+        style_prompt=style_prompt,
+        negative_style_prompt=negative_style_prompt,
+        start_time=start_time,
+        pred_frames=pred_frames,
+        chunked=args.chunked,
+    )
 
+    generated_song = eval_song(eval_model, eval_muq, generated_songs)
+
+    # Peak normalize, clip, convert to int16, and save to file
+    generated_song = (
+        generated_song.to(torch.float32)
+        .div(torch.max(torch.abs(generated_song)))
+        .clamp(-1, 1)
+        .mul(32767)
+        .to(torch.int16)
+        .cpu()
+    )
     e_t = time.time() - s_t
-    print(f"inference cost {e_t} seconds")
-
+    print(f"inference cost {e_t:.2f} seconds")
     output_dir = args.output_dir
     os.makedirs(output_dir, exist_ok=True)
-    …
+
     output_path = os.path.join(output_dir, "output.wav")
     torchaudio.save(output_path, generated_song, sample_rate=44100)
-
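As a usage illustration for the rewritten CLI, the invocations below only use flags defined in the argparse block above and are meant to be run from the repository root. The lyric file, reference song path, prompt text and segment values are placeholders, and the edit-segment string simply follows the `[[start1,end1],...]` format described by the Gradio textbox in app.py.

# Plain 95-second generation driven by a text style prompt (paths and prompt are placeholders)
python diffrhythm/infer/infer.py \
    --lrc-path example/song.lrc \
    --ref-prompt "uplifting pop with female vocals" \
    --audio-length 95 \
    --chunked \
    --output-dir infer/example/output

# Edit mode: the reference song is used as the latent prompt and only the listed
# time ranges (in seconds) are regenerated
python diffrhythm/infer/infer.py \
    --lrc-path example/song.lrc \
    --ref-prompt "uplifting pop with female vocals" \
    --edit \
    --ref-song example/original.wav \
    --edit-segments "[[10.0, 20.0], [40.0, 50.0]]"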
    	
        diffrhythm/infer/infer_utils.py
    CHANGED
    
    | @@ -1,66 +1,308 @@ | |
| 1 | 
             
            import torch
         | 
| 2 | 
             
            import librosa
         | 
|  | |
| 3 | 
             
            import random
         | 
| 4 | 
             
            import json
         | 
| 5 | 
            -
            from muq import MuQMuLan
         | 
| 6 | 
             
            from mutagen.mp3 import MP3
         | 
| 7 | 
             
            import os
         | 
| 8 | 
             
            import numpy as np
         | 
| 9 | 
             
            from huggingface_hub import hf_hub_download
         | 
|  | |
|  | |
|  | |
|  | |
| 10 | 
             
            from diffrhythm.model import DiT, CFM
         | 
| 11 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 12 |  | 
| 13 | 
             
            def prepare_model(device):
         | 
| 14 | 
             
                # prepare cfm model
         | 
| 15 | 
            -
                 | 
| 16 | 
            -
                 | 
| 17 | 
            -
                dit_config_path = "./diffrhythm/config/ | 
| 18 | 
             
                with open(dit_config_path) as f:
         | 
| 19 | 
             
                    model_config = json.load(f)
         | 
| 20 | 
             
                dit_model_cls = DiT
         | 
| 21 | 
             
                cfm = CFM(
         | 
| 22 | 
            -
                            transformer=dit_model_cls(**model_config["model"],  | 
| 23 | 
             
                            num_channels=model_config["model"]['mel_dim'],
         | 
| 24 | 
            -
                            use_style_prompt=True
         | 
| 25 | 
             
                         )
         | 
| 26 | 
             
                cfm = cfm.to(device)
         | 
| 27 | 
             
                cfm = load_checkpoint(cfm, dit_ckpt_path, device=device, use_ema=False)
         | 
| 28 | 
            -
             | 
| 29 | 
            -
                cfm_full = CFM(
         | 
| 30 | 
            -
                            transformer=dit_model_cls(**model_config["model"], use_style_prompt=True, max_pos=6144),
         | 
| 31 | 
            -
                            num_channels=model_config["model"]['mel_dim'],
         | 
| 32 | 
            -
                            use_style_prompt=True
         | 
| 33 | 
            -
                         )
         | 
| 34 | 
            -
                cfm_full = cfm_full.to(device)
         | 
| 35 | 
            -
                cfm_full = load_checkpoint(cfm_full, dit_full_ckpt_path, device=device, use_ema=False)
         | 
| 36 | 
            -
                
         | 
| 37 | 
             
                # prepare tokenizer
         | 
| 38 | 
             
                tokenizer = CNENTokenizer()
         | 
| 39 | 
            -
             | 
| 40 | 
             
                # prepare muq
         | 
| 41 | 
            -
                muq = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large")
         | 
| 42 | 
             
                muq = muq.to(device).eval()
         | 
| 43 | 
            -
             | 
| 44 | 
             
                # prepare vae
         | 
| 45 | 
             
                vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
         | 
| 46 | 
            -
                vae = torch.jit.load(vae_ckpt_path, map_location= | 
|  | |
| 47 |  | 
| 48 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 49 |  | 
| 50 |  | 
| 51 | 
             
            # for song edit, will be added in the future
         | 
| 52 | 
            -
            def get_reference_latent(device, max_frames):
         | 
| 53 | 
            -
                 | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 54 |  | 
| 55 | 
             
            def get_negative_style_prompt(device):
         | 
| 56 | 
             
                file_path = "./src/negative_prompt.npy"
         | 
| 57 | 
             
                vocal_stlye = np.load(file_path)
         | 
| 58 | 
            -
             | 
| 59 | 
            -
                vocal_stlye = torch.from_numpy(vocal_stlye).to(device) | 
| 60 | 
             
                vocal_stlye = vocal_stlye.half()
         | 
| 61 | 
            -
             | 
| 62 | 
             
                return vocal_stlye
         | 
| 63 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 64 | 
             
            def get_audio_style_prompt(model, wav_path):
         | 
| 65 | 
             
                vocal_flag = False
         | 
| 66 | 
             
                mulan = model
         | 
| @@ -85,6 +327,8 @@ def get_audio_style_prompt(model, wav_path): | |
| 85 |  | 
| 86 | 
             
                return audio_emb, vocal_flag
         | 
| 87 |  | 
|  | |
|  | |
| 88 | 
             
            def get_text_style_prompt(model, text_prompt):
         | 
| 89 | 
             
                mulan = model
         | 
| 90 |  | 
| @@ -95,50 +339,88 @@ def get_text_style_prompt(model, text_prompt): | |
| 95 | 
             
                return text_emb
         | 
| 96 |  | 
| 97 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 98 |  | 
| 99 | 
             
            def parse_lyrics(lyrics: str):
         | 
| 100 | 
             
                lyrics_with_time = []
         | 
| 101 | 
             
                lyrics = lyrics.strip()
         | 
| 102 | 
            -
                for line in lyrics.split( | 
| 103 | 
             
                    try:
         | 
| 104 | 
             
                        time, lyric = line[1:9], line[10:]
         | 
| 105 | 
             
                        lyric = lyric.strip()
         | 
| 106 | 
            -
                        mins, secs = time.split( | 
| 107 | 
             
                        secs = int(mins) * 60 + float(secs)
         | 
| 108 | 
             
                        lyrics_with_time.append((secs, lyric))
         | 
| 109 | 
             
                    except:
         | 
| 110 | 
             
                        continue
         | 
| 111 | 
             
                return lyrics_with_time
         | 
| 112 |  | 
| 113 | 
            -
             | 
|  | |
| 114 | 
             
                def __init__(self):
         | 
| 115 | 
            -
                    with open( | 
| 116 | 
            -
                        self.phone2id:dict = json.load(file)[ | 
| 117 | 
            -
                    self.id2phone = {v:k for (k, v) in self.phone2id.items()}
         | 
| 118 | 
             
                    from diffrhythm.g2p.g2p_generation import chn_eng_g2p
         | 
|  | |
| 119 | 
             
                    self.tokenizer = chn_eng_g2p
         | 
|  | |
| 120 | 
             
                def encode(self, text):
         | 
| 121 | 
             
                    phone, token = self.tokenizer(text)
         | 
| 122 | 
            -
                    token = [x+1 for x in token]
         | 
| 123 | 
             
                    return token
         | 
|  | |
| 124 | 
             
                def decode(self, token):
         | 
| 125 | 
            -
                    return "|".join([self.id2phone[x-1] for x in token])
         | 
| 126 | 
            -
             | 
|  | |
| 127 | 
             
            def get_lrc_token(max_frames, text, tokenizer, device):
         | 
| 128 |  | 
| 129 | 
             
                lyrics_shift = 0
         | 
| 130 | 
             
                sampling_rate = 44100
         | 
| 131 | 
             
                downsample_rate = 2048
         | 
| 132 | 
             
                max_secs = max_frames / (sampling_rate / downsample_rate)
         | 
| 133 | 
            -
             | 
| 134 | 
            -
                pad_token_id = 0
         | 
| 135 | 
             
                comma_token_id = 1
         | 
| 136 | 
            -
                period_token_id = 2 | 
| 137 | 
            -
                if text == "":
         | 
| 138 | 
            -
                    return torch.zeros((max_frames,), dtype=torch.long).unsqueeze(0).to(device), torch.tensor(0.).unsqueeze(0).to(device).half()
         | 
| 139 |  | 
| 140 | 
             
                lrc_with_time = parse_lyrics(text)
         | 
| 141 | 
            -
             | 
| 142 | 
             
                modified_lrc_with_time = []
         | 
| 143 | 
             
                for i in range(len(lrc_with_time)):
         | 
| 144 | 
             
                    time, line = lrc_with_time[i]
         | 
| @@ -146,44 +428,49 @@ def get_lrc_token(max_frames, text, tokenizer, device): | |
| 146 | 
             
                    modified_lrc_with_time.append((time, line_token))
         | 
| 147 | 
             
                lrc_with_time = modified_lrc_with_time
         | 
| 148 |  | 
| 149 | 
            -
                lrc_with_time = [ | 
| 150 | 
            -
             | 
| 151 | 
            -
             | 
| 152 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 153 |  | 
| 154 | 
             
                lrc = torch.zeros((max_frames,), dtype=torch.long)
         | 
| 155 |  | 
| 156 | 
             
                tokens_count = 0
         | 
| 157 | 
             
                last_end_pos = 0
         | 
| 158 | 
             
                for time_start, line in lrc_with_time:
         | 
| 159 | 
            -
                    tokens = [ | 
|  | |
|  | |
| 160 | 
             
                    tokens = torch.tensor(tokens, dtype=torch.long)
         | 
| 161 | 
             
                    num_tokens = tokens.shape[0]
         | 
| 162 |  | 
| 163 | 
             
                    gt_frame_start = int(time_start * sampling_rate / downsample_rate)
         | 
| 164 | 
            -
             | 
| 165 | 
            -
                    frame_shift = random.randint(int(lyrics_shift), int(lyrics_shift))
         | 
| 166 | 
            -
             | 
| 167 | 
             
                    frame_start = max(gt_frame_start - frame_shift, last_end_pos)
         | 
| 168 | 
             
                    frame_len = min(num_tokens, max_frames - frame_start)
         | 
| 169 |  | 
| 170 | 
            -
             | 
| 171 | 
            -
             | 
| 172 | 
            -
                    lrc[frame_start:frame_start + frame_len] = tokens[:frame_len]
         | 
| 173 |  | 
| 174 | 
             
                    tokens_count += num_tokens
         | 
| 175 | 
            -
                    last_end_pos = frame_start + frame_len | 
| 176 | 
            -
             | 
| 177 | 
             
                lrc_emb = lrc.unsqueeze(0).to(device)
         | 
| 178 | 
            -
             | 
| 179 | 
             
                normalized_start_time = torch.tensor(normalized_start_time).unsqueeze(0).to(device)
         | 
| 180 | 
             
                normalized_start_time = normalized_start_time.half()
         | 
| 181 | 
            -
             | 
| 182 | 
             
                return lrc_emb, normalized_start_time
         | 
| 183 |  | 
|  | |
| 184 | 
             
            def load_checkpoint(model, ckpt_path, device, use_ema=True):
         | 
| 185 | 
            -
                 | 
| 186 | 
            -
                    model = model.half()
         | 
| 187 |  | 
| 188 | 
             
                ckpt_type = ckpt_path.split(".")[-1]
         | 
| 189 | 
             
                if ckpt_type == "safetensors":
         | 
| @@ -207,4 +494,4 @@ def load_checkpoint(model, ckpt_path, device, use_ema=True): | |
| 207 | 
             
                        checkpoint = {"model_state_dict": checkpoint}
         | 
| 208 | 
             
                    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
         | 
| 209 |  | 
| 210 | 
            -
                return model.to(device)
         | 
|  | |
| 1 | 
             
            import torch
         | 
| 2 | 
             
            import librosa
         | 
| 3 | 
            +
            import torchaudio
         | 
| 4 | 
             
            import random
         | 
| 5 | 
             
            import json
         | 
| 6 | 
            +
            from muq import MuQMuLan, MuQ
         | 
| 7 | 
             
            from mutagen.mp3 import MP3
         | 
| 8 | 
             
            import os
         | 
| 9 | 
             
            import numpy as np
         | 
| 10 | 
             
            from huggingface_hub import hf_hub_download
         | 
| 11 | 
            +
            from hydra.utils import instantiate
         | 
| 12 | 
            +
            from omegaconf import OmegaConf
         | 
| 13 | 
            +
            from safetensors.torch import load_file
         | 
| 14 | 
            +
             | 
| 15 | 
             
            from diffrhythm.model import DiT, CFM
         | 
| 16 |  | 
| 17 | 
            +
            def vae_sample(mean, scale):
         | 
| 18 | 
            +
                stdev = torch.nn.functional.softplus(scale) + 1e-4
         | 
| 19 | 
            +
                var = stdev * stdev
         | 
| 20 | 
            +
                logvar = torch.log(var)
         | 
| 21 | 
            +
                latents = torch.randn_like(mean) * stdev + mean
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                kl = (mean * mean + var - logvar - 1).sum(1).mean()
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                return latents, kl
         | 
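# Minimal sketch (not part of the diff): the reparameterization used by vae_sample,
# run on dummy tensors; the [batch, dim, time] shapes here are assumed for illustration.
import torch
import torch.nn.functional as F

mean = torch.zeros(1, 64, 8)
scale = torch.zeros(1, 64, 8)                      # softplus(0) + 1e-4 ~= 0.6933
stdev = F.softplus(scale) + 1e-4
latents = torch.randn_like(mean) * stdev + mean    # z = mu + sigma * eps
var = stdev * stdev
kl = (mean * mean + var - torch.log(var) - 1).sum(1).mean()
print(latents.shape, kl.item() > 0)                # torch.Size([1, 64, 8]) True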
| 26 | 
            +
             | 
| 27 | 
            +
            def normalize_audio(y, target_dbfs=0):
         | 
| 28 | 
            +
                max_amplitude = torch.max(torch.abs(y))
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                target_amplitude = 10.0**(target_dbfs / 20.0)
         | 
| 31 | 
            +
                scale_factor = target_amplitude / max_amplitude
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                normalized_audio = y * scale_factor
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                return normalized_audio
         | 
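# Worked example of the peak-normalization arithmetic above (values assumed):
# a waveform peaking at 0.25 normalized to -6 dBFS is scaled by
# 10 ** (-6 / 20) / 0.25 ~= 0.5012 / 0.25 ~= 2.005, so the new peak is ~0.5012.
import torch
y = torch.tensor([0.1, -0.25, 0.2])
peak = y.abs().max()
target = 10.0 ** (-6 / 20.0)
print((y * (target / peak)).abs().max())   # tensor(0.5012)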
| 36 | 
            +
             | 
| 37 | 
            +
            def set_audio_channels(audio, target_channels):
         | 
| 38 | 
            +
                if target_channels == 1:
         | 
| 39 | 
            +
                    # Convert to mono
         | 
| 40 | 
            +
                    audio = audio.mean(1, keepdim=True)
         | 
| 41 | 
            +
                elif target_channels == 2:
         | 
| 42 | 
            +
                    # Convert to stereo
         | 
| 43 | 
            +
                    if audio.shape[1] == 1:
         | 
| 44 | 
            +
                        audio = audio.repeat(1, 2, 1)
         | 
| 45 | 
            +
                    elif audio.shape[1] > 2:
         | 
| 46 | 
            +
                        audio = audio[:, :2, :]
         | 
| 47 | 
            +
                return audio
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            class PadCrop(torch.nn.Module):
         | 
| 50 | 
            +
                def __init__(self, n_samples, randomize=True):
         | 
| 51 | 
            +
                    super().__init__()
         | 
| 52 | 
            +
                    self.n_samples = n_samples
         | 
| 53 | 
            +
                    self.randomize = randomize
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                def __call__(self, signal):
         | 
| 56 | 
            +
                    n, s = signal.shape
         | 
| 57 | 
            +
                    start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
         | 
| 58 | 
            +
                    end = start + self.n_samples
         | 
| 59 | 
            +
                    output = signal.new_zeros([n, self.n_samples])
         | 
| 60 | 
            +
                    output[:, :min(s, self.n_samples)] = signal[:, start:end]
         | 
| 61 | 
            +
                    return output
         | 
| 62 | 
            +
             | 
| 63 | 
            +
            def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
         | 
| 64 | 
            +
                
         | 
| 65 | 
            +
                audio = audio.to(device)
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                if in_sr != target_sr:
         | 
| 68 | 
            +
                    resample_tf = T.Resample(in_sr, target_sr).to(device)
         | 
| 69 | 
            +
                    audio = resample_tf(audio)
         | 
| 70 | 
            +
                if target_length is None:
         | 
| 71 | 
            +
                    target_length = audio.shape[-1]
         | 
| 72 | 
            +
                audio = PadCrop(target_length, randomize=False)(audio)
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                # Add batch dimension
         | 
| 75 | 
            +
                if audio.dim() == 1:
         | 
| 76 | 
            +
                    audio = audio.unsqueeze(0).unsqueeze(0)
         | 
| 77 | 
            +
                elif audio.dim() == 2:
         | 
| 78 | 
            +
                    audio = audio.unsqueeze(0)
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                audio = set_audio_channels(audio, target_channels)
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                return audio
         | 
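# Hypothetical usage of prepare_audio (assumes the repo and its dependencies are
# importable): a 3 s mono clip already at the target rate comes back as a
# batched stereo tensor of shape [1, 2, num_samples].
import torch
from diffrhythm.infer.infer_utils import prepare_audio
wav = torch.randn(1, 44100 * 3)
out = prepare_audio(wav, in_sr=44100, target_sr=44100, target_length=None,
                    target_channels=2, device="cpu")
print(out.shape)   # torch.Size([1, 2, 132300])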
| 83 | 
            +
             | 
| 84 | 
            +
            def decode_audio(latents, vae_model, chunked=False, overlap=32, chunk_size=128):
         | 
| 85 | 
            +
                downsampling_ratio = 2048
         | 
| 86 | 
            +
                io_channels = 2
         | 
| 87 | 
            +
                if not chunked:
         | 
| 88 | 
            +
                    return vae_model.decode_export(latents)
         | 
| 89 | 
            +
                else:
         | 
| 90 | 
            +
                    # chunked decoding
         | 
| 91 | 
            +
                    hop_size = chunk_size - overlap
         | 
| 92 | 
            +
                    total_size = latents.shape[2]
         | 
| 93 | 
            +
                    batch_size = latents.shape[0]
         | 
| 94 | 
            +
                    chunks = []
         | 
| 95 | 
            +
                    i = 0
         | 
| 96 | 
            +
                    for i in range(0, total_size - chunk_size + 1, hop_size):
         | 
| 97 | 
            +
                        chunk = latents[:, :, i : i + chunk_size]
         | 
| 98 | 
            +
                        chunks.append(chunk)
         | 
| 99 | 
            +
                    if i + chunk_size != total_size:
         | 
| 100 | 
            +
                        # Final chunk
         | 
| 101 | 
            +
                        chunk = latents[:, :, -chunk_size:]
         | 
| 102 | 
            +
                        chunks.append(chunk)
         | 
| 103 | 
            +
                    chunks = torch.stack(chunks)
         | 
| 104 | 
            +
                    num_chunks = chunks.shape[0]
         | 
| 105 | 
            +
                    # samples_per_latent is just the downsampling ratio
         | 
| 106 | 
            +
                    samples_per_latent = downsampling_ratio
         | 
| 107 | 
            +
                    # Create an empty waveform; we will populate it with chunks as we decode them
         | 
| 108 | 
            +
                    y_size = total_size * samples_per_latent
         | 
| 109 | 
            +
                    y_final = torch.zeros((batch_size, io_channels, y_size)).to(latents.device)
         | 
| 110 | 
            +
                    for i in range(num_chunks):
         | 
| 111 | 
            +
                        x_chunk = chunks[i, :]
         | 
| 112 | 
            +
                        # decode the chunk
         | 
| 113 | 
            +
                        y_chunk = vae_model.decode_export(x_chunk)
         | 
| 114 | 
            +
                        # figure out where to put the audio along the time domain
         | 
| 115 | 
            +
                        if i == num_chunks - 1:
         | 
| 116 | 
            +
                            # final chunk always goes at the end
         | 
| 117 | 
            +
                            t_end = y_size
         | 
| 118 | 
            +
                            t_start = t_end - y_chunk.shape[2]
         | 
| 119 | 
            +
                        else:
         | 
| 120 | 
            +
                            t_start = i * hop_size * samples_per_latent
         | 
| 121 | 
            +
                            t_end = t_start + chunk_size * samples_per_latent
         | 
| 122 | 
            +
                        #  remove the edges of the overlaps
         | 
| 123 | 
            +
                        ol = (overlap // 2) * samples_per_latent
         | 
| 124 | 
            +
                        chunk_start = 0
         | 
| 125 | 
            +
                        chunk_end = y_chunk.shape[2]
         | 
| 126 | 
            +
                        if i > 0:
         | 
| 127 | 
            +
                            # no overlap for the start of the first chunk
         | 
| 128 | 
            +
                            t_start += ol
         | 
| 129 | 
            +
                            chunk_start += ol
         | 
| 130 | 
            +
                        if i < num_chunks - 1:
         | 
| 131 | 
            +
                            # no overlap for the end of the last chunk
         | 
| 132 | 
            +
                            t_end -= ol
         | 
| 133 | 
            +
                            chunk_end -= ol
         | 
| 134 | 
            +
                        # paste the chunked audio into our y_final output audio
         | 
| 135 | 
            +
                        y_final[:, :, t_start:t_end] = y_chunk[:, :, chunk_start:chunk_end]
         | 
| 136 | 
            +
                    return y_final
         | 
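# Sketch of the chunk bookkeeping used above (no VAE involved, sizes assumed):
# a 300-frame latent with chunk_size=128 and overlap=32 (hop 96) is decoded in
# chunks starting at latent frames 0, 96 and a final chunk at 300 - 128 = 172,
# i.e. at sample offsets 0, 96*2048 and 172*2048 after upsampling.
chunk_size, overlap, total_size, samples_per_latent = 128, 32, 300, 2048
hop_size = chunk_size - overlap
starts = list(range(0, total_size - chunk_size + 1, hop_size))
if starts[-1] + chunk_size != total_size:
    starts.append(total_size - chunk_size)
print(starts, [s * samples_per_latent for s in starts])
# [0, 96, 172] [0, 196608, 352256]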
| 137 | 
            +
             | 
| 138 | 
            +
            def encode_audio(audio, vae_model, chunked=False, overlap=32, chunk_size=128):
         | 
| 139 | 
            +
                downsampling_ratio = 2048
         | 
| 140 | 
            +
                latent_dim = 128
         | 
| 141 | 
            +
                if not chunked:
         | 
| 142 | 
            +
                    # default behavior. Encode the entire audio in parallel
         | 
| 143 | 
            +
                    return vae_model.encode_export(audio)
         | 
| 144 | 
            +
                else:
         | 
| 145 | 
            +
                    # CHUNKED ENCODING
         | 
| 146 | 
            +
                    # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
         | 
| 147 | 
            +
                    samples_per_latent = downsampling_ratio
         | 
| 148 | 
            +
                    total_size = audio.shape[2] # in samples
         | 
| 149 | 
            +
                    batch_size = audio.shape[0]
         | 
| 150 | 
            +
                    chunk_size *= samples_per_latent # converting metric in latents to samples
         | 
| 151 | 
            +
                    overlap *= samples_per_latent # converting metric in latents to samples
         | 
| 152 | 
            +
                    hop_size = chunk_size - overlap
         | 
| 153 | 
            +
                    chunks = []
                    i = 0  # mirror decode_audio so the final-chunk check below is defined for very short inputs
         | 
| 154 | 
            +
                    for i in range(0, total_size - chunk_size + 1, hop_size):
         | 
| 155 | 
            +
                        chunk = audio[:,:,i:i+chunk_size]
         | 
| 156 | 
            +
                        chunks.append(chunk)
         | 
| 157 | 
            +
                    if i+chunk_size != total_size:
         | 
| 158 | 
            +
                        # Final chunk
         | 
| 159 | 
            +
                        chunk = audio[:,:,-chunk_size:]
         | 
| 160 | 
            +
                        chunks.append(chunk)
         | 
| 161 | 
            +
                    chunks = torch.stack(chunks)
         | 
| 162 | 
            +
                    num_chunks = chunks.shape[0]
         | 
| 163 | 
            +
                    # Note: y_size might be a different value from the latent length used in diffusion training
         | 
| 164 | 
            +
                    # because we can encode audio of varying lengths
         | 
| 165 | 
            +
                    # However, the audio should've been padded to a multiple of samples_per_latent by now.
         | 
| 166 | 
            +
                    y_size = total_size // samples_per_latent
         | 
| 167 | 
            +
                    # Create an empty latent, we will populate it with chunks as we encode them
         | 
| 168 | 
            +
                    y_final = torch.zeros((batch_size,latent_dim,y_size)).to(audio.device)
         | 
| 169 | 
            +
                    for i in range(num_chunks):
         | 
| 170 | 
            +
                        x_chunk = chunks[i,:]
         | 
| 171 | 
            +
                        # encode the chunk
         | 
| 172 | 
            +
                        y_chunk = vae_model.encode_export(x_chunk)
         | 
| 173 | 
            +
                        # figure out where to put the audio along the time domain
         | 
| 174 | 
            +
                        if i == num_chunks-1:
         | 
| 175 | 
            +
                            # final chunk always goes at the end
         | 
| 176 | 
            +
                            t_end = y_size
         | 
| 177 | 
            +
                            t_start = t_end - y_chunk.shape[2]
         | 
| 178 | 
            +
                        else:
         | 
| 179 | 
            +
                            t_start = i * hop_size // samples_per_latent
         | 
| 180 | 
            +
                            t_end = t_start + chunk_size // samples_per_latent
         | 
| 181 | 
            +
                        #  remove the edges of the overlaps
         | 
| 182 | 
            +
                        ol = overlap//samples_per_latent//2
         | 
| 183 | 
            +
                        chunk_start = 0
         | 
| 184 | 
            +
                        chunk_end = y_chunk.shape[2]
         | 
| 185 | 
            +
                        if i > 0:
         | 
| 186 | 
            +
                            # no overlap for the start of the first chunk
         | 
| 187 | 
            +
                            t_start += ol
         | 
| 188 | 
            +
                            chunk_start += ol
         | 
| 189 | 
            +
                        if i < num_chunks-1:
         | 
| 190 | 
            +
                            # no overlap for the end of the last chunk
         | 
| 191 | 
            +
                            t_end -= ol
         | 
| 192 | 
            +
                            chunk_end -= ol
         | 
| 193 | 
            +
                        # paste the chunked audio into our y_final output audio
         | 
| 194 | 
            +
                        y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
         | 
| 195 | 
            +
                    return y_final
         | 
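# Worked example of the latent/sample unit conversion above (values assumed):
# chunk_size=128 and overlap=32 are given in latent frames and become
# 128 * 2048 = 262144 and 32 * 2048 = 65536 samples, while a 4,188,160-sample
# input maps to 4188160 // 2048 = 2045 latent frames.
samples_per_latent = 2048
print(128 * samples_per_latent, 32 * samples_per_latent, 4188160 // samples_per_latent)
# 262144 65536 2045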
| 196 |  | 
| 197 | 
             
            def prepare_model(device):
         | 
| 198 | 
             
                # prepare cfm model
         | 
| 199 | 
            +
                
         | 
| 200 | 
            +
                dit_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-1_2", filename="cfm_model.pt")
         | 
| 201 | 
            +
                dit_config_path = "./diffrhythm/config/config.json"
         | 
| 202 | 
             
                with open(dit_config_path) as f:
         | 
| 203 | 
             
                    model_config = json.load(f)
         | 
| 204 | 
             
                dit_model_cls = DiT
         | 
| 205 | 
             
                cfm = CFM(
         | 
| 206 | 
            +
                            transformer=dit_model_cls(**model_config["model"], max_frames=2048),
         | 
| 207 | 
             
                            num_channels=model_config["model"]['mel_dim'],
         | 
|  | |
| 208 | 
             
                         )
         | 
| 209 | 
             
                cfm = cfm.to(device)
         | 
| 210 | 
             
                cfm = load_checkpoint(cfm, dit_ckpt_path, device=device, use_ema=False)
         | 
| 211 | 
            +
             | 
| 212 | 
             
                # prepare tokenizer
         | 
| 213 | 
             
                tokenizer = CNENTokenizer()
         | 
| 214 | 
            +
             | 
| 215 | 
             
                # prepare muq
         | 
| 216 | 
            +
                muq = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large", cache_dir="./pretrained")
         | 
| 217 | 
             
                muq = muq.to(device).eval()
         | 
| 218 | 
            +
             | 
| 219 | 
             
                # prepare vae
         | 
| 220 | 
             
                vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
         | 
| 221 | 
            +
                vae = torch.jit.load(vae_ckpt_path, map_location="cpu").to(device)
         | 
| 222 | 
            +
                
         | 
| 223 |  | 
| 224 | 
            +
                # prepare eval model
         | 
| 225 | 
            +
                train_config = OmegaConf.load("./pretrained/eval.yaml")
         | 
| 226 | 
            +
                checkpoint_path = "./pretrained/eval.safetensors"
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                eval_model = instantiate(train_config.generator).to(device).eval()
         | 
| 229 | 
            +
                state_dict = load_file(checkpoint_path, device="cpu")
         | 
| 230 | 
            +
                eval_model.load_state_dict(state_dict)
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                eval_muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
         | 
| 233 | 
            +
                eval_muq = eval_muq.to(device).eval()
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                return cfm, tokenizer, muq, vae, eval_model, eval_muq
         | 
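# Hypothetical call site (assumes the repo, its dependencies and the local config
# and eval files are available); the first call downloads the DiffRhythm and VAE
# checkpoints from the Hugging Face Hub via hf_hub_download.
import torch
from diffrhythm.infer.infer_utils import prepare_model
device = "cuda" if torch.cuda.is_available() else "cpu"
cfm, tokenizer, muq, vae, eval_model, eval_muq = prepare_model(device)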
| 236 |  | 
| 237 |  | 
| 238 | 
             
            # for song edit, will be added in the future
         | 
| 239 | 
            +
            def get_reference_latent(device, max_frames, edit, pred_segments, ref_song, vae_model):
         | 
| 240 | 
            +
                sampling_rate = 44100
         | 
| 241 | 
            +
                downsample_rate = 2048
         | 
| 242 | 
            +
                io_channels = 2
         | 
| 243 | 
            +
                if edit:
         | 
| 244 | 
            +
                    input_audio, in_sr = torchaudio.load(ref_song)
         | 
| 245 | 
            +
                    input_audio = prepare_audio(input_audio, in_sr=in_sr, target_sr=sampling_rate, target_length=None, target_channels=io_channels, device=device)
         | 
| 246 | 
            +
                    input_audio = normalize_audio(input_audio, -6)
         | 
| 247 | 
            +
                    
         | 
| 248 | 
            +
                    with torch.no_grad():
         | 
| 249 | 
            +
                        latent = encode_audio(input_audio, vae_model, chunked=True) # [b d t]
         | 
| 250 | 
            +
                        mean, scale = latent.chunk(2, dim=1)
         | 
| 251 | 
            +
                        prompt, _ = vae_sample(mean, scale)
         | 
| 252 | 
            +
                        prompt = prompt.transpose(1, 2) # [b t d]
         | 
| 253 | 
            +
                    
         | 
| 254 | 
            +
                    pred_segments = json.loads(pred_segments)
         | 
| 256 | 
            +
                    pred_frames = []
         | 
| 257 | 
            +
                    for st, et in pred_segments:
         | 
| 258 | 
            +
                        sf = 0 if st == -1 else int(st * sampling_rate / downsample_rate)
         | 
| 263 | 
            +
                        
         | 
| 264 | 
            +
                        ef = max_frames if et == -1 else int(et * sampling_rate / downsample_rate)
         | 
| 269 | 
            +
                        pred_frames.append((sf, ef))
         | 
| 271 | 
            +
                    return prompt, pred_frames
         | 
| 272 | 
            +
                else:
         | 
| 273 | 
            +
                    prompt = torch.zeros(1, max_frames, 64).to(device)
         | 
| 274 | 
            +
                    pred_frames = [(0, max_frames)]
         | 
| 275 | 
            +
                    return prompt, pred_frames
         | 
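# Worked example of the segment-to-frame mapping above (values assumed):
# with 44100 Hz audio and one latent frame per 2048 samples, -1 means
# "song start" / "song end", and 10 s maps to frame int(10 * 44100 / 2048) = 215.
sampling_rate, downsample_rate, max_frames = 44100, 2048, 2048
pred_segments = [[-1, 10.0], [30.0, -1]]
pred_frames = [(0 if st == -1 else int(st * sampling_rate / downsample_rate),
                max_frames if et == -1 else int(et * sampling_rate / downsample_rate))
               for st, et in pred_segments]
print(pred_frames)   # [(0, 215), (645, 2048)]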
| 276 | 
            +
             | 
| 277 |  | 
| 278 | 
             
            def get_negative_style_prompt(device):
         | 
| 279 | 
             
                file_path = "./src/negative_prompt.npy"
         | 
| 280 | 
             
                vocal_stlye = np.load(file_path)
         | 
| 281 | 
            +
             | 
| 282 | 
            +
                vocal_stlye = torch.from_numpy(vocal_stlye).to(device)  # [1, 512]
         | 
| 283 | 
             
                vocal_stlye = vocal_stlye.half()
         | 
| 284 | 
            +
             | 
| 285 | 
             
                return vocal_stlye
         | 
| 286 |  | 
| 287 | 
            +
            @torch.no_grad()
         | 
| 288 | 
            +
            def eval_song(eval_model, eval_muq, songs):
         | 
| 289 | 
            +
                
         | 
| 290 | 
            +
                resampled_songs = [torchaudio.functional.resample(song.mean(dim=0, keepdim=True), 44100, 24000) for song in songs]
         | 
| 291 | 
            +
                ssl_list = []
         | 
| 292 | 
            +
                for i in range(len(resampled_songs)):
         | 
| 293 | 
            +
                    output = eval_muq(resampled_songs[i], output_hidden_states=True)
         | 
| 294 | 
            +
                    muq_ssl = output["hidden_states"][6]
         | 
| 295 | 
            +
                    ssl_list.append(muq_ssl.squeeze(0))
         | 
| 296 | 
            +
             | 
| 297 | 
            +
                ssl = torch.stack(ssl_list)
         | 
| 298 | 
            +
                scores_g = eval_model(ssl)
         | 
| 299 | 
            +
                score = torch.mean(scores_g, dim=1)
         | 
| 300 | 
            +
                idx = score.argmax(dim=0)
         | 
| 301 | 
            +
                
         | 
| 302 | 
            +
                return songs[idx]
         | 
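# Shape sketch of the best-of-N selection above (dummy scores; in the real code
# scores_g comes from eval_model, and treating its second dimension as per-frame
# scores is an assumption): the candidate with the highest mean score is kept.
import torch
scores_g = torch.tensor([[0.6, 0.7], [0.9, 0.8], [0.4, 0.5]])  # [num_candidates, ...]
score = scores_g.mean(dim=1)       # tensor([0.6500, 0.8500, 0.4500])
print(int(score.argmax(dim=0)))    # 1 -> the second candidate would be returned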
| 303 | 
            +
                
         | 
| 304 | 
            +
             | 
| 305 | 
            +
            @torch.no_grad()
         | 
| 306 | 
             
            def get_audio_style_prompt(model, wav_path):
         | 
| 307 | 
             
                vocal_flag = False
         | 
| 308 | 
             
                mulan = model
         | 
|  | |
| 327 |  | 
| 328 | 
             
                return audio_emb, vocal_flag
         | 
| 329 |  | 
| 330 | 
            +
             | 
| 331 | 
            +
            @torch.no_grad()
         | 
| 332 | 
             
            def get_text_style_prompt(model, text_prompt):
         | 
| 333 | 
             
                mulan = model
         | 
| 334 |  | 
|  | |
| 339 | 
             
                return text_emb
         | 
| 340 |  | 
| 341 |  | 
| 342 | 
            +
            @torch.no_grad()
         | 
| 343 | 
            +
            def get_style_prompt(model, wav_path=None, prompt=None):
         | 
| 344 | 
            +
                mulan = model
         | 
| 345 | 
            +
             | 
| 346 | 
            +
                if prompt is not None:
         | 
| 347 | 
            +
                    return mulan(texts=prompt).half()
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                ext = os.path.splitext(wav_path)[-1].lower()
         | 
| 350 | 
            +
                if ext == ".mp3":
         | 
| 351 | 
            +
                    meta = MP3(wav_path)
         | 
| 352 | 
            +
                    audio_len = meta.info.length
         | 
| 353 | 
            +
                elif ext in [".wav", ".flac"]:
         | 
| 354 | 
            +
                    audio_len = librosa.get_duration(path=wav_path)
         | 
| 355 | 
            +
                else:
         | 
| 356 | 
            +
                    raise ValueError("Unsupported file format: {}".format(ext))
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                if audio_len < 10:
         | 
| 359 | 
            +
                    print(
         | 
| 360 | 
            +
                        f"Warning: The audio file {wav_path} is too short ({audio_len:.2f} seconds). Expected at least 10 seconds."
         | 
| 361 | 
            +
                    )
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                assert audio_len >= 10
         | 
| 364 | 
            +
             | 
| 365 | 
            +
                mid_time = audio_len // 2
         | 
| 366 | 
            +
                start_time = mid_time - 5
         | 
| 367 | 
            +
                wav, _ = librosa.load(wav_path, sr=24000, offset=start_time, duration=10)
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                wav = torch.tensor(wav).unsqueeze(0).to(model.device)
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                with torch.no_grad():
         | 
| 372 | 
            +
                    audio_emb = mulan(wavs=wav)  # [1, 512]
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                audio_emb = audio_emb.half()
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                return audio_emb
         | 
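# The style prompt is computed from a 10 s window centred in the reference audio;
# worked example of the window arithmetic for a hypothetical 3-minute file:
audio_len = 180.0                   # seconds
start_time = audio_len // 2 - 5     # 85.0 -> librosa loads 85 s .. 95 s at 24 kHz
print(start_time, start_time + 10)  # 85.0 95.0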
| 378 |  | 
| 379 | 
             
            def parse_lyrics(lyrics: str):
         | 
| 380 | 
             
                lyrics_with_time = []
         | 
| 381 | 
             
                lyrics = lyrics.strip()
         | 
| 382 | 
            +
                for line in lyrics.split("\n"):
         | 
| 383 | 
             
                    try:
         | 
| 384 | 
             
                        time, lyric = line[1:9], line[10:]
         | 
| 385 | 
             
                        lyric = lyric.strip()
         | 
| 386 | 
            +
                        mins, secs = time.split(":")
         | 
| 387 | 
             
                        secs = int(mins) * 60 + float(secs)
         | 
| 388 | 
             
                        lyrics_with_time.append((secs, lyric))
         | 
| 389 | 
             
                    except:
         | 
| 390 | 
             
                        continue
         | 
| 391 | 
             
                return lyrics_with_time
         | 
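# Hypothetical input showing the expected LRC line format (assumes the repo is
# importable); lines without a leading [mm:ss.xx] timestamp are skipped.
from diffrhythm.infer.infer_utils import parse_lyrics
lrc = "[00:10.00]First line\n[00:17.50]Second line\njust a note"
print(parse_lyrics(lrc))   # [(10.0, 'First line'), (17.5, 'Second line')]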
| 392 |  | 
| 393 | 
            +
             | 
| 394 | 
            +
            class CNENTokenizer:
         | 
| 395 | 
             
                def __init__(self):
         | 
| 396 | 
            +
                    with open("./diffrhythm/g2p/g2p/vocab.json", "r", encoding='utf-8') as file:
         | 
| 397 | 
            +
                        self.phone2id: dict = json.load(file)["vocab"]
         | 
| 398 | 
            +
                    self.id2phone = {v: k for (k, v) in self.phone2id.items()}
         | 
| 399 | 
             
                    from diffrhythm.g2p.g2p_generation import chn_eng_g2p
         | 
| 400 | 
            +
             | 
| 401 | 
             
                    self.tokenizer = chn_eng_g2p
         | 
| 402 | 
            +
             | 
| 403 | 
             
                def encode(self, text):
         | 
| 404 | 
             
                    phone, token = self.tokenizer(text)
         | 
| 405 | 
            +
                    token = [x + 1 for x in token]
         | 
| 406 | 
             
                    return token
         | 
| 407 | 
            +
             | 
| 408 | 
             
                def decode(self, token):
         | 
| 409 | 
            +
                    return "|".join([self.id2phone[x - 1] for x in token])
         | 
| 410 | 
            +
             | 
| 411 | 
            +
             | 
| 412 | 
             
            def get_lrc_token(max_frames, text, tokenizer, device):
         | 
| 413 |  | 
| 414 | 
             
                lyrics_shift = 0
         | 
| 415 | 
             
                sampling_rate = 44100
         | 
| 416 | 
             
                downsample_rate = 2048
         | 
| 417 | 
             
                max_secs = max_frames / (sampling_rate / downsample_rate)
         | 
| 418 | 
            +
             | 
|  | |
| 419 | 
             
                comma_token_id = 1
         | 
| 420 | 
            +
                period_token_id = 2
         | 
|  | |
|  | |
| 421 |  | 
| 422 | 
             
                lrc_with_time = parse_lyrics(text)
         | 
| 423 | 
            +
             | 
| 424 | 
             
                modified_lrc_with_time = []
         | 
| 425 | 
             
                for i in range(len(lrc_with_time)):
         | 
| 426 | 
             
                    time, line = lrc_with_time[i]
         | 
|  | |
| 428 | 
             
                    modified_lrc_with_time.append((time, line_token))
         | 
| 429 | 
             
                lrc_with_time = modified_lrc_with_time
         | 
| 430 |  | 
| 431 | 
            +
                lrc_with_time = [
         | 
| 432 | 
            +
                    (time_start, line)
         | 
| 433 | 
            +
                    for (time_start, line) in lrc_with_time
         | 
| 434 | 
            +
                    if time_start < max_secs
         | 
| 435 | 
            +
                ]
         | 
| 436 | 
            +
                if max_frames == 2048:
         | 
| 437 | 
            +
                    lrc_with_time = lrc_with_time[:-1] if len(lrc_with_time) >= 1 else lrc_with_time
         | 
| 438 | 
            +
             | 
| 439 | 
            +
                normalized_start_time = 0.0
         | 
| 440 |  | 
| 441 | 
             
                lrc = torch.zeros((max_frames,), dtype=torch.long)
         | 
| 442 |  | 
| 443 | 
             
                tokens_count = 0
         | 
| 444 | 
             
                last_end_pos = 0
         | 
| 445 | 
             
                for time_start, line in lrc_with_time:
         | 
| 446 | 
            +
                    tokens = [
         | 
| 447 | 
            +
                        token if token != period_token_id else comma_token_id for token in line
         | 
| 448 | 
            +
                    ] + [period_token_id]
         | 
| 449 | 
             
                    tokens = torch.tensor(tokens, dtype=torch.long)
         | 
| 450 | 
             
                    num_tokens = tokens.shape[0]
         | 
| 451 |  | 
| 452 | 
             
                    gt_frame_start = int(time_start * sampling_rate / downsample_rate)
         | 
| 453 | 
            +
             | 
| 454 | 
            +
                    frame_shift = random.randint(int(-lyrics_shift), int(lyrics_shift))
         | 
| 455 | 
            +
             | 
| 456 | 
             
                    frame_start = max(gt_frame_start - frame_shift, last_end_pos)
         | 
| 457 | 
             
                    frame_len = min(num_tokens, max_frames - frame_start)
         | 
| 458 |  | 
| 459 | 
            +
                    lrc[frame_start : frame_start + frame_len] = tokens[:frame_len]
         | 
|  | |
|  | |
| 460 |  | 
| 461 | 
             
                    tokens_count += num_tokens
         | 
| 462 | 
            +
                    last_end_pos = frame_start + frame_len
         | 
| 463 | 
            +
             | 
| 464 | 
             
                lrc_emb = lrc.unsqueeze(0).to(device)
         | 
| 465 | 
            +
             | 
| 466 | 
             
                normalized_start_time = torch.tensor(normalized_start_time).unsqueeze(0).to(device)
         | 
| 467 | 
             
                normalized_start_time = normalized_start_time.half()
         | 
| 468 | 
            +
             | 
| 469 | 
             
                return lrc_emb, normalized_start_time
         | 
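# Frame-grid arithmetic used above: with 44100 Hz audio and one latent frame per
# 2048 samples there are ~21.5 frames per second, so 2048 frames cover ~95 s of
# lyrics, and a line timestamped at 10 s is written starting at frame 215.
sampling_rate, downsample_rate, max_frames = 44100, 2048, 2048
frames_per_sec = sampling_rate / downsample_rate
print(round(frames_per_sec, 2), round(max_frames / frames_per_sec, 1),
      int(10 * sampling_rate / downsample_rate))   # 21.53 95.1 215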
| 470 |  | 
| 471 | 
            +
             | 
| 472 | 
             
            def load_checkpoint(model, ckpt_path, device, use_ema=True):
         | 
| 473 | 
            +
                model = model.half()
         | 
|  | |
| 474 |  | 
| 475 | 
             
                ckpt_type = ckpt_path.split(".")[-1]
         | 
| 476 | 
             
                if ckpt_type == "safetensors":
         | 
|  | |
| 494 | 
             
                        checkpoint = {"model_state_dict": checkpoint}
         | 
| 495 | 
             
                    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
         | 
| 496 |  | 
| 497 | 
            +
                return model.to(device)
         | 
    	
        diffrhythm/model/__pycache__/__init__.cpython-310.pyc
    DELETED
    
    | Binary file (290 Bytes) | 
|  | 
    	
        diffrhythm/model/__pycache__/__init__.cpython-312.pyc
    DELETED
    
    | Binary file (508 Bytes) | 
|  | 
    	
        diffrhythm/model/__pycache__/cfm.cpython-310.pyc
    DELETED
    
    | Binary file (6.28 kB) | 
|  | 
    	
        diffrhythm/model/__pycache__/cfm.cpython-312.pyc
    DELETED
    
    | Binary file (10.7 kB) | 
|  | 
    	
        diffrhythm/model/__pycache__/custom_dataset.cpython-310.pyc
    DELETED
    
    | Binary file (11.5 kB) | 
|  | 
    	
        diffrhythm/model/__pycache__/custom_dataset_lrc_emb.cpython-310.pyc
    DELETED
    
    | Binary file (10.5 kB) | 
|  | 
    	
        diffrhythm/model/__pycache__/dataset.cpython-310.pyc
    DELETED
    
    | Binary file (8.04 kB) | 
|  | 
    	
        diffrhythm/model/__pycache__/dit.cpython-310.pyc
    DELETED
    
    | Binary file (5.61 kB) | 
|  | 
    	
        diffrhythm/model/__pycache__/modules.cpython-310.pyc
    DELETED
    
    | Binary file (15.9 kB) | 
|  | 
    	
        diffrhythm/model/__pycache__/trainer.cpython-310.pyc
    DELETED
    
    | Binary file (9.13 kB) | 
|  | 
    	
        diffrhythm/model/__pycache__/utils.cpython-310.pyc
    DELETED
    
    | Binary file (6.03 kB) | 
|  | 
    	
        diffrhythm/model/cfm.py
    CHANGED
    
    | @@ -1,10 +1,22 @@ | |
| 1 | 
            -
             | 
| 2 | 
            -
             | 
| 3 | 
            -
             | 
| 4 | 
            -
             | 
| 5 | 
            -
             | 
| 6 | 
            -
             | 
| 7 | 
            -
             | 
| 8 | 
             
            """
         | 
| 9 |  | 
| 10 | 
             
            from __future__ import annotations
         | 
| @@ -19,9 +31,7 @@ from torch.nn.utils.rnn import pad_sequence | |
| 19 |  | 
| 20 | 
             
            from torchdiffeq import odeint
         | 
| 21 |  | 
| 22 | 
            -
            from diffrhythm.model.modules import MelSpec
         | 
| 23 | 
             
            from diffrhythm.model.utils import (
         | 
| 24 | 
            -
                default,
         | 
| 25 | 
             
                exists,
         | 
| 26 | 
             
                list_str_to_idx,
         | 
| 27 | 
             
                list_str_to_tensor,
         | 
| @@ -29,12 +39,25 @@ from diffrhythm.model.utils import ( | |
| 29 | 
             
                mask_from_frac_lengths,
         | 
| 30 | 
             
            )
         | 
| 31 |  | 
| 32 | 
            -
            def custom_mask_from_start_end_indices( | 
| 33 | 
             
                max_seq_len = max_seq_len
         | 
| 34 | 
             
                seq = torch.arange(max_seq_len, device=device).long()
         | 
| 35 | 
            -
             | 
| 36 | 
            -
                 | 
| 37 | 
            -
                 | 
| 38 |  | 
| 39 | 
             
            class CFM(nn.Module):
         | 
| 40 | 
             
                def __init__(
         | 
| @@ -42,7 +65,7 @@ class CFM(nn.Module): | |
| 42 | 
             
                    transformer: nn.Module,
         | 
| 43 | 
             
                    sigma=0.0,
         | 
| 44 | 
             
                    odeint_kwargs: dict = dict(
         | 
| 45 | 
            -
                        method="euler" | 
| 46 | 
             
                    ),
         | 
| 47 | 
             
                    odeint_options: dict = dict(
         | 
| 48 | 
             
                        min_step=0.05
         | 
| @@ -54,7 +77,7 @@ class CFM(nn.Module): | |
| 54 | 
             
                    num_channels=None,
         | 
| 55 | 
             
                    frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
         | 
| 56 | 
             
                    vocab_char_map: dict[str:int] | None = None,
         | 
| 57 | 
            -
                     | 
| 58 | 
             
                ):
         | 
| 59 | 
             
                    super().__init__()
         | 
| 60 |  | 
| @@ -83,8 +106,8 @@ class CFM(nn.Module): | |
| 83 |  | 
| 84 | 
             
                    # vocab map for tokenization
         | 
| 85 | 
             
                    self.vocab_char_map = vocab_char_map
         | 
| 86 | 
            -
             | 
| 87 | 
            -
                    self. | 
| 88 |  | 
| 89 | 
             
                @property
         | 
| 90 | 
             
                def device(self):
         | 
| @@ -112,10 +135,10 @@ class CFM(nn.Module): | |
| 112 | 
             
                    t_inter=0.1,
         | 
| 113 | 
             
                    edit_mask=None,
         | 
| 114 | 
             
                    start_time=None,
         | 
| 115 | 
            -
                     | 
| 116 | 
            -
                    latent_pred_end_frame=2048,
         | 
| 117 | 
             
                    vocal_flag=False,
         | 
| 118 | 
            -
                    odeint_method="euler"
         | 
|  | |
| 119 | 
             
                ):
         | 
| 120 | 
             
                    self.eval()
         | 
| 121 |  | 
| @@ -125,7 +148,6 @@ class CFM(nn.Module): | |
| 125 | 
             
                        cond = cond.half()
         | 
| 126 |  | 
| 127 | 
             
                    # raw wave
         | 
| 128 | 
            -
                    
         | 
| 129 | 
             
                    if cond.shape[1] > duration:
         | 
| 130 | 
             
                        cond = cond[:, :duration, :]
         | 
| 131 |  | 
| @@ -139,7 +161,6 @@ class CFM(nn.Module): | |
| 139 | 
             
                        lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)
         | 
| 140 |  | 
| 141 | 
             
                    # text
         | 
| 142 | 
            -
             | 
| 143 | 
             
                    if isinstance(text, list):
         | 
| 144 | 
             
                        if exists(self.vocab_char_map):
         | 
| 145 | 
             
                            text = list_str_to_idx(text, self.vocab_char_map).to(device)
         | 
| @@ -147,26 +168,18 @@ class CFM(nn.Module): | |
| 147 | 
             
                            text = list_str_to_tensor(text).to(device)
         | 
| 148 | 
             
                        assert text.shape[0] == batch
         | 
| 149 |  | 
| 150 | 
            -
                    if exists(text):
         | 
| 151 | 
            -
                        text_lens = (text != -1).sum(dim=-1)
         | 
| 152 | 
            -
             | 
| 153 | 
            -
             | 
| 154 | 
             
                    # duration
         | 
| 155 | 
             
                    cond_mask = lens_to_mask(lens)
         | 
| 156 | 
             
                    if edit_mask is not None:
         | 
| 157 | 
             
                        cond_mask = cond_mask & edit_mask
         | 
| 158 |  | 
| 159 | 
            -
                     | 
| 160 | 
            -
                     | 
| 161 | 
            -
                    latent_pred_end_frame = torch.tensor([latent_pred_end_frame]).to(cond.device)
         | 
| 162 | 
            -
                    fixed_span_mask = custom_mask_from_start_end_indices(cond_seq_len, latent_pred_start_frame, latent_pred_end_frame, device=cond.device, max_seq_len=duration)
         | 
| 163 | 
            -
             | 
| 164 | 
             
                    fixed_span_mask = fixed_span_mask.unsqueeze(-1)
         | 
| 165 | 
             
                    step_cond = torch.where(fixed_span_mask, torch.zeros_like(cond), cond)
         | 
| 166 |  | 
| 167 | 
             
                    if isinstance(duration, int):
         | 
| 168 | 
            -
                        duration = torch.full(( | 
| 169 | 
            -
             | 
| 170 |  | 
| 171 | 
             
                    duration = duration.clamp(max=max_duration)
         | 
| 172 | 
             
                    max_duration = duration.amax()
         | 
| @@ -175,7 +188,6 @@ class CFM(nn.Module): | |
| 175 | 
             
                    if duplicate_test:
         | 
| 176 | 
             
                        test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)
         | 
| 177 |  | 
| 178 | 
            -
             | 
| 179 | 
             
                    if batch > 1:
         | 
| 180 | 
             
                        mask = lens_to_mask(duration)
         | 
| 181 | 
             
                    else:  # save memory and speed up, as single inference need no mask currently
         | 
| @@ -184,20 +196,27 @@ class CFM(nn.Module): | |
| 184 | 
             
                    # test for no ref audio
         | 
| 185 | 
             
                    if no_ref_audio:
         | 
| 186 | 
             
                        cond = torch.zeros_like(cond)
         | 
|  | |
| 187 |  | 
| 188 | 
             
                    start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
         | 
| 189 | 
             
                    _, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)
         | 
| 190 |  | 
| 191 | 
            -
                    if vocal_flag:
         | 
| 192 | 
            -
                        style_prompt = negative_style_prompt
         | 
| 193 | 
            -
                        negative_style_prompt = torch.zeros_like(style_prompt)
         | 
| 194 | 
            -
                        
         | 
| 195 | 
             
                    text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
         | 
| 196 | 
             
                    text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
         | 
| 197 | 
             
                    step_cond = torch.cat([step_cond, step_cond], 0)
         | 
| 198 | 
             
                    style_prompt = torch.cat([style_prompt, negative_style_prompt], 0)
         | 
| 199 | 
             
                    start_time_embed = torch.cat([start_time_embed, start_time_embed], 0)
         | 
| 200 | 
            -
                        
         | 
| 201 |  | 
| 202 | 
             
                    def fn(t, x):
         | 
| 203 | 
             
                        x = torch.cat([x, x], 0)
         | 
| @@ -228,7 +247,7 @@ class CFM(nn.Module): | |
| 228 | 
             
                        t_start = t_inter
         | 
| 229 | 
             
                        y0 = (1 - t_start) * y0 + t_start * test_cond
         | 
| 230 | 
             
                        steps = int(steps * (1 - t_start))
         | 
| 231 | 
            -
             | 
| 232 | 
             
                    t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
         | 
| 233 | 
             
                    if sway_sampling_coef is not None:
         | 
| 234 | 
             
                        t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
         | 
| @@ -243,6 +262,7 @@ class CFM(nn.Module): | |
| 243 | 
             
                        out = out.permute(0, 2, 1)
         | 
| 244 | 
             
                        out = vocoder(out)
         | 
| 245 |  | 
|  | |
| 246 | 
             
                    return out, trajectory
         | 
| 247 |  | 
| 248 | 
             
                def forward(
         | 
| @@ -267,11 +287,10 @@ class CFM(nn.Module): | |
| 267 |  | 
| 268 | 
             
                    # get a random span to mask out for training conditionally
         | 
| 269 | 
             
                    frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
         | 
| 270 | 
            -
                    rand_span_mask = mask_from_frac_lengths(lens, frac_lengths)
         | 
| 271 |  | 
| 272 | 
             
                    if exists(mask):
         | 
| 273 | 
             
                        rand_span_mask = mask
         | 
| 274 | 
            -
                        # rand_span_mask &= mask
         | 
| 275 |  | 
| 276 | 
             
                    # mel is x1
         | 
| 277 | 
             
                    x1 = inp
         | 
| @@ -301,7 +320,7 @@ class CFM(nn.Module): | |
| 301 | 
             
                    # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
         | 
| 302 | 
             
                    pred = self.transformer(
         | 
| 303 | 
             
                        x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, drop_prompt=drop_prompt,
         | 
| 304 | 
            -
                        style_prompt=style_prompt,  | 
| 305 | 
             
                    )
         | 
| 306 |  | 
| 307 | 
             
                    # flow matching loss
         | 
|  | |
| 1 | 
            +
            # Copyright (c) 2025 ASLP-LAB
         | 
| 2 | 
            +
            #               2025 Ziqian Ning   (ningziqian@mail.nwpu.edu.cn)
         | 
| 3 | 
            +
            #               2025 Huakang Chen  (huakang@mail.nwpu.edu.cn)
         | 
| 4 | 
            +
            #               2025 Guobin Ma     (guobin.ma@mail.nwpu.edu.cn)
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 7 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 8 | 
            +
            # You may obtain a copy of the License at
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 13 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 14 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 15 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 16 | 
            +
            # limitations under the License.
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            """ This implementation is adapted from github repo:
         | 
| 19 | 
            +
                https://github.com/SWivid/F5-TTS.
         | 
| 20 | 
             
            """
         | 
| 21 |  | 
| 22 | 
             
            from __future__ import annotations
         | 
|  | |
| 31 |  | 
| 32 | 
             
            from torchdiffeq import odeint
         | 
| 33 |  | 
|  | |
| 34 | 
             
            from diffrhythm.model.utils import (
         | 
|  | |
| 35 | 
             
                exists,
         | 
| 36 | 
             
                list_str_to_idx,
         | 
| 37 | 
             
                list_str_to_tensor,
         | 
|  | |
| 39 | 
             
                mask_from_frac_lengths,
         | 
| 40 | 
             
            )
         | 
| 41 |  | 
| 42 | 
            +
            def custom_mask_from_start_end_indices(
         | 
| 43 | 
            +
                seq_len: int["b"],  # noqa: F821
         | 
| 44 | 
            +
                latent_pred_segments,
         | 
| 45 | 
            +
                device,
         | 
| 46 | 
            +
                max_seq_len
         | 
| 47 | 
            +
            ):
         | 
| 48 | 
             
                max_seq_len = max_seq_len
         | 
| 49 | 
             
                seq = torch.arange(max_seq_len, device=device).long()
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                res_mask = torch.zeros(max_seq_len, device=device, dtype=torch.bool)
         | 
| 52 | 
            +
                
         | 
| 53 | 
            +
                for start, end in latent_pred_segments:
         | 
| 54 | 
            +
                    start = start.unsqueeze(0)
         | 
| 55 | 
            +
                    end = end.unsqueeze(0)
         | 
| 56 | 
            +
                    start_mask = seq[None, :] >= start[:, None]
         | 
| 57 | 
            +
                    end_mask = seq[None, :] < end[:, None]
         | 
| 58 | 
            +
                    res_mask = res_mask | (start_mask & end_mask)
         | 
| 59 | 
            +
                
         | 
| 60 | 
            +
                return res_mask
         | 
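# Minimal sketch of the span mask built above (dummy sizes, plain 1-D version):
# frames inside any (start, end) segment are True, and step_cond is zeroed there,
# so those frames are re-generated while the rest stay conditioned on the reference.
import torch
max_seq_len = 10
segments = torch.tensor([[2, 5], [7, 9]])
seq = torch.arange(max_seq_len)
mask = torch.zeros(max_seq_len, dtype=torch.bool)
for start, end in segments:
    mask |= (seq >= start) & (seq < end)
print(mask.int().tolist())   # [0, 0, 1, 1, 1, 0, 0, 1, 1, 0]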
| 61 |  | 
| 62 | 
             
            class CFM(nn.Module):
diffrhythm/model/cfm.py (continued; "+" marks lines added in this commit, "..." marks elided context)

     def __init__(
         ...
         transformer: nn.Module,
         sigma=0.0,
         odeint_kwargs: dict = dict(
+            method="euler"
         ),
         odeint_options: dict = dict(
             min_step=0.05
         ...
         num_channels=None,
         frac_lengths_mask: tuple[float, float] = (0.7, 1.0),
         vocab_char_map: dict[str:int] | None = None,
+        max_frames=2048
     ):
         super().__init__()
         ...
         # vocab map for tokenization
         self.vocab_char_map = vocab_char_map
+
+        self.max_frames = max_frames

     @property
     def device(self):
         ...
         t_inter=0.1,
         edit_mask=None,
         start_time=None,
+        latent_pred_segments=None,
         vocal_flag=False,
+        odeint_method="euler",
+        batch_infer_num=5
     ):
         self.eval()
         ...
             cond = cond.half()

         # raw wave
         if cond.shape[1] > duration:
             cond = cond[:, :duration, :]
         ...
             lens = torch.full((batch,), cond_seq_len, device=device, dtype=torch.long)

         # text
         if isinstance(text, list):
             if exists(self.vocab_char_map):
                 text = list_str_to_idx(text, self.vocab_char_map).to(device)
             ...
                 text = list_str_to_tensor(text).to(device)
             assert text.shape[0] == batch

         # duration
         cond_mask = lens_to_mask(lens)
         if edit_mask is not None:
             cond_mask = cond_mask & edit_mask

+        latent_pred_segments = torch.tensor(latent_pred_segments).to(cond.device)
+        fixed_span_mask = custom_mask_from_start_end_indices(cond_seq_len, latent_pred_segments, device=cond.device, max_seq_len=duration)
         fixed_span_mask = fixed_span_mask.unsqueeze(-1)
         step_cond = torch.where(fixed_span_mask, torch.zeros_like(cond), cond)

         if isinstance(duration, int):
+            duration = torch.full((batch_infer_num,), duration, device=device, dtype=torch.long)

         duration = duration.clamp(max=max_duration)
         max_duration = duration.amax()
         ...
         if duplicate_test:
             test_cond = F.pad(cond, (0, 0, cond_seq_len, max_duration - 2 * cond_seq_len), value=0.0)

         if batch > 1:
             mask = lens_to_mask(duration)
         else:  # save memory and speed up, as single inference need no mask currently
         ...
         # test for no ref audio
         if no_ref_audio:
             cond = torch.zeros_like(cond)
+
+        if vocal_flag:
+            style_prompt = negative_style_prompt
+            negative_style_prompt = torch.zeros_like(style_prompt)
+
+        cond = cond.repeat(batch_infer_num, 1, 1)
+        step_cond = step_cond.repeat(batch_infer_num, 1, 1)
+        text = text.repeat(batch_infer_num, 1)
+        style_prompt = style_prompt.repeat(batch_infer_num, 1)
+        negative_style_prompt = negative_style_prompt.repeat(batch_infer_num, 1)
+        start_time = start_time.repeat(batch_infer_num)
+        fixed_span_mask = fixed_span_mask.repeat(batch_infer_num, 1, 1)

         start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
         _, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)

         text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
         text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
         step_cond = torch.cat([step_cond, step_cond], 0)
         style_prompt = torch.cat([style_prompt, negative_style_prompt], 0)
         start_time_embed = torch.cat([start_time_embed, start_time_embed], 0)

         def fn(t, x):
             x = torch.cat([x, x], 0)
         ...
             t_start = t_inter
             y0 = (1 - t_start) * y0 + t_start * test_cond
             steps = int(steps * (1 - t_start))
+
         t = torch.linspace(t_start, 1, steps, device=self.device, dtype=step_cond.dtype)
         if sway_sampling_coef is not None:
             t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)
         ...
             out = out.permute(0, 2, 1)
             out = vocoder(out)

+        out = torch.chunk(out, batch_infer_num, dim=0)
         return out, trajectory

     def forward(
         ...
         # get a random span to mask out for training conditionally
         frac_lengths = torch.zeros((batch,), device=self.device).float().uniform_(*self.frac_lengths_mask)
+        rand_span_mask = mask_from_frac_lengths(lens, frac_lengths, self.max_frames)

         if exists(mask):
             rand_span_mask = mask

         # mel is x1
         x1 = inp
         ...
         # adding mask will use more memory, thus also need to adjust batchsampler with scaled down threshold for long sequences
         pred = self.transformer(
             x=φ, cond=cond, text=text, time=time, drop_audio_cond=drop_audio_cond, drop_text=drop_text, drop_prompt=drop_prompt,
+            style_prompt=style_prompt, start_time=start_time
         )

         # flow matching loss
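The sampling path above now produces batch_infer_num candidates per call: every conditioning tensor is tiled along the batch axis, the positive and negative streams are concatenated for a single classifier-free-guidance pass, and the decoded audio is split back into per-candidate chunks. A minimal sketch of that tensor bookkeeping; the shapes and the stand-in vocoder output are made up, only the repeat/cat/chunk pattern mirrors the diff.

import torch

# Sketch only: shapes are invented, names mirror CFM.sample() above.
batch_infer_num = 5                          # candidates generated per call
cond = torch.randn(1, 2048, 64)              # [batch, frames, latent_channels]
style_prompt = torch.randn(1, 512)
negative_style_prompt = torch.zeros(1, 512)

# 1) tile the conditioning so several candidates are sampled in one pass
cond = cond.repeat(batch_infer_num, 1, 1)                        # [5, 2048, 64]
style_prompt = style_prompt.repeat(batch_infer_num, 1)           # [5, 512]
negative_style_prompt = negative_style_prompt.repeat(batch_infer_num, 1)

# 2) concatenate positive and negative streams so one forward pass serves
#    both branches of classifier-free guidance
style_prompt_cfg = torch.cat([style_prompt, negative_style_prompt], 0)  # [10, 512]

# 3) after decoding, split the audio back into per-candidate chunks
out = torch.randn(batch_infer_num, 2, 44100)                     # stand-in for vocoder output
candidates = torch.chunk(out, batch_infer_num, dim=0)            # tuple of 5 tensors of shape [1, 2, 44100]
print(len(candidates), candidates[0].shape)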
    	
diffrhythm/model/dit.py CHANGED

@@ -1,10 +1,22 @@
+# Copyright (c) 2025 ASLP-LAB
+#               2025 Ziqian Ning   (ningziqian@mail.nwpu.edu.cn)
+#               2025 Huakang Chen  (huakang@mail.nwpu.edu.cn)
+#               2025 Yuepeng Jiang (Jiangyp@mail.nwpu.edu.cn)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" This implementation is adapted from github repo:
+    https://github.com/SWivid/F5-TTS.
 """

 from __future__ import annotations
@@ -12,22 +24,19 @@ from __future__ import annotations
 import torch
 from torch import nn
 import torch
+
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRotaryEmbedding
 from transformers.models.llama import LlamaConfig
-from torch.utils.checkpoint import checkpoint

 from diffrhythm.model.modules import (
     TimestepEmbedding,
     ConvNeXtV2Block,
     ConvPositionEmbedding,
-    DiTBlock,
     AdaLayerNormZero_Final,
     precompute_freqs_cis,
     get_pos_embed_indices,
+    _prepare_decoder_attention_mask,
 )
-# from liger_kernel.transformers import apply_liger_kernel_to_llama
-# apply_liger_kernel_to_llama()

 # Text embedding
 class TextEmbedding(nn.Module):
@@ -77,7 +86,6 @@ class InputEmbedding(nn.Module):
     def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], style_emb, time_emb, drop_audio_cond=False):  # noqa: F722
         if drop_audio_cond:  # cfg for cond audio
             cond = torch.zeros_like(cond)
         style_emb = style_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
         time_emb = time_emb.unsqueeze(1).repeat(1, x.shape[1], 1)
         x = self.proj(torch.cat((x, cond, text_embed, style_emb, time_emb), dim=-1))
@@ -85,9 +93,7 @@ class InputEmbedding(nn.Module):
         return x


-# Transformer backbone using
+# Transformer backbone using Llama blocks
 class DiT(nn.Module):
     def __init__(
         self,
@@ -103,26 +109,25 @@ class DiT(nn.Module):
         text_dim=None,
         conv_layers=0,
         long_skip_connection=False,
-        max_pos=2048,
+        max_frames=2048
     ):
         super().__init__()
+
+        self.max_frames = max_frames

         cond_dim = 512
         self.time_embed = TimestepEmbedding(cond_dim)
         self.start_time_embed = TimestepEmbedding(cond_dim)
         if text_dim is None:
             text_dim = mel_dim
-        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers, max_pos=
+        self.text_embed = TextEmbedding(text_num_embeds, text_dim, conv_layers=conv_layers, max_pos=self.max_frames)
         self.input_embed = InputEmbedding(mel_dim, text_dim, dim, cond_dim=cond_dim)

         self.dim = dim
         self.depth = depth

-        llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu', max_position_embeddings=
+        llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu', max_position_embeddings=self.max_frames)
         llama_config._attn_implementation = 'sdpa'
         self.transformer_blocks = nn.ModuleList(
             [LlamaDecoderLayer(llama_config, layer_idx=i) for i in range(depth)]
         )
@@ -144,7 +149,6 @@ class DiT(nn.Module):
         self.norm_out = AdaLayerNormZero_Final(dim, cond_dim)  # final modulation
         self.proj_out = nn.Linear(dim, mel_dim)

     def forward_timestep_invariant(self, text, seq_len, drop_text, start_time):
         s_t = self.start_time_embed(start_time)
         text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
@@ -187,11 +191,22 @@ class DiT(nn.Module):
         pos_ids = torch.arange(x.shape[1], device=x.device)
         pos_ids = pos_ids.unsqueeze(0).repeat(x.shape[0], 1)
         rotary_embed = self.rotary_emb(x, pos_ids)
+
+        attention_mask = torch.ones(
+            (batch, seq_len),
+            dtype=torch.bool,
+            device=x.device,
+        )
+        attention_mask = _prepare_decoder_attention_mask(
+            attention_mask,
+            (batch, seq_len),
+            x,
+        )

         for i, block in enumerate(self.transformer_blocks):
-            x, *_ = block(x, position_embeddings=rotary_embed)
+            x, *_ = block(x, attention_mask=attention_mask, position_embeddings=rotary_embed)
             if i < self.depth // 2:
-                x = x +
+                x = x + self.text_fusion_linears[i](text_embed)

         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
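Two behavioural points stand out in DiT.forward above: every LlamaDecoderLayer now receives an explicit non-causal attention mask, and only the first half of the blocks add a projected text embedding as a residual. A toy sketch of that fusion pattern, with plain Linear layers standing in for the Llama blocks; every size and stand-in below is hypothetical.

import torch
from torch import nn

# Toy stand-in (not the repo's module): per-layer text fusion in the first
# half of the stack, as in DiT.forward above.
dim, text_dim, depth, seq_len = 32, 16, 4, 8

blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])                         # stand-ins for LlamaDecoderLayer
text_fusion_linears = nn.ModuleList([nn.Linear(text_dim, dim) for _ in range(depth // 2)])  # one projection per fused layer

x = torch.randn(2, seq_len, dim)                # hidden states
text_embed = torch.randn(2, seq_len, text_dim)  # timestep-invariant text features

for i, block in enumerate(blocks):
    x = block(x)                                # the real blocks also take attention_mask / rotary embeddings
    if i < depth // 2:
        x = x + text_fusion_linears[i](text_embed)

print(x.shape)  # torch.Size([2, 8, 32])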
    	
diffrhythm/model/modules.py CHANGED

@@ -609,3 +609,44 @@ class TimestepEmbedding(nn.Module):
         time_hidden = time_hidden.to(timestep.dtype)
         time = self.time_mlp(time_hidden)  # b d
         return time
+
+
+# attention mask related
+
+
+def _prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds):
+    # create noncausal mask
+    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+    combined_attention_mask = None
+
+    def _expand_mask(
+        mask: torch.Tensor, dtype: torch.dtype, tgt_len: int = None
+    ):
+        """
+        Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+        """
+        bsz, src_len = mask.size()
+        tgt_len = tgt_len if tgt_len is not None else src_len
+
+        expanded_mask = (
+            mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+        )
+
+        inverted_mask = 1.0 - expanded_mask
+
+        return inverted_mask.masked_fill(
+            inverted_mask.to(torch.bool), torch.finfo(dtype).min
+        )
+
+    if attention_mask is not None:
+        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        expanded_attn_mask = _expand_mask(
+            attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+        ).to(inputs_embeds.device)
+        combined_attention_mask = (
+            expanded_attn_mask
+            if combined_attention_mask is None
+            else expanded_attn_mask + combined_attention_mask
+        )
+
+    return combined_attention_mask
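The helper returns an additive bias rather than a boolean mask: kept positions contribute 0 and masked positions the most negative representable value, which is the form the SDPA attention inside the Llama blocks consumes. A standalone check of the _expand_mask arithmetic on a toy [1, 4] mask; it re-inlines the same operations instead of importing the module.

import torch

# Standalone check of the mask arithmetic above (re-inlined, not imported):
# a [bsz, seq_len] boolean mask becomes an additive [bsz, 1, tgt, src] bias.
mask = torch.tensor([[1, 1, 1, 0]], dtype=torch.bool)   # last position is padding
dtype = torch.float32

expanded = mask[:, None, None, :].expand(1, 1, 4, 4).to(dtype)
inverted = 1.0 - expanded
bias = inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

print(bias[0, 0, 0])  # kept positions are 0.0, the padded one is roughly -3.4e38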
    	
diffrhythm/model/utils.py CHANGED

@@ -44,15 +44,15 @@ def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]:  # noqa
     return seq[None, :] < t[:, None]


-def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]):  # noqa: F722 F821
-    max_seq_len =
+def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"], max_frames):  # noqa: F722 F821
+    max_seq_len = max_frames
     seq = torch.arange(max_seq_len, device=start.device).long()
     start_mask = seq[None, :] >= start[:, None]
     end_mask = seq[None, :] < end[:, None]
     return start_mask & end_mask


-def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):  # noqa: F722 F821
+def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"], max_frames):  # noqa: F722 F821
     lengths = (frac_lengths * seq_len).long()
     max_start = seq_len - lengths
@@ -60,7 +60,7 @@ def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):  # noqa
     start = (max_start * rand).long().clamp(min=0)
     end = start + lengths

-    return mask_from_start_end_indices(seq_len, start, end)
+    return mask_from_start_end_indices(seq_len, start, end, max_frames)


 def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]:  # noqa: F722
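With max_frames threaded through, the training-time span mask always has a fixed width instead of one inferred from the batch. A standalone sketch of the updated helpers on toy inputs; it follows the hunks above, and the line defining rand falls between the two hunks, so torch.rand_like is an assumption here.

import torch

# Standalone sketch of the updated helpers (same logic as the hunks above).
def mask_from_start_end_indices(seq_len, start, end, max_frames):
    seq = torch.arange(max_frames, device=start.device).long()
    return (seq[None, :] >= start[:, None]) & (seq[None, :] < end[:, None])

def mask_from_frac_lengths(seq_len, frac_lengths, max_frames):
    lengths = (frac_lengths * seq_len).long()
    max_start = seq_len - lengths
    rand = torch.rand_like(frac_lengths)            # assumed: this line sits between the hunks
    start = (max_start * rand).long().clamp(min=0)
    end = start + lengths
    return mask_from_start_end_indices(seq_len, start, end, max_frames)

lens = torch.tensor([1500, 2048])
frac = torch.tensor([0.8, 0.9])
mask = mask_from_frac_lengths(lens, frac, max_frames=2048)
print(mask.shape)  # torch.Size([2, 2048]) -- fixed width, independent of the batch's lengths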
    	
pretrained/eval.py ADDED

@@ -0,0 +1,66 @@
from einops import rearrange
import numpy as np
import torch
import torch.nn as nn


class Generator(nn.Module):

    def __init__(self,
                 in_features,
                 ffd_hidden_size,
                 num_classes,
                 attn_layer_num,
                 ):
        super(Generator, self).__init__()

        self.attn = nn.ModuleList(
            [
                nn.MultiheadAttention(
                    embed_dim=in_features,
                    num_heads=8,
                    dropout=0.2,
                    batch_first=True,
                )
                for _ in range(attn_layer_num)
            ]
        )

        self.ffd = nn.Sequential(
            nn.Linear(in_features, ffd_hidden_size),
            nn.ReLU(),
            nn.Linear(ffd_hidden_size, in_features)
        )

        self.dropout = nn.Dropout(0.2)

        self.fc = nn.Linear(in_features * 2, num_classes)

        self.proj = nn.Tanh()

    def forward(self, ssl_feature, judge_id=None):
        '''
        ssl_feature: [B, T, D]
        output: [B, num_classes]
        '''

        B, T, D = ssl_feature.shape

        ssl_feature = self.ffd(ssl_feature)

        tmp_ssl_feature = ssl_feature

        for attn in self.attn:
            tmp_ssl_feature, _ = attn(tmp_ssl_feature, tmp_ssl_feature, tmp_ssl_feature)

        ssl_feature = self.dropout(torch.concat([torch.mean(tmp_ssl_feature, dim=1), torch.max(ssl_feature, dim=1)[0]], dim=1))  # B, 2D

        x = self.fc(ssl_feature)  # B, num_classes

        x = self.proj(x) * 2.0 + 3

        return x
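A usage sketch for the new evaluation head, assuming the file is importable as pretrained.eval (matching the _target_ in eval.yaml below) and using made-up SSL features.

import torch

# Usage sketch; the import path and the dummy feature shape are assumptions.
from pretrained.eval import Generator

model = Generator(in_features=1024, ffd_hidden_size=4096, num_classes=5, attn_layer_num=4)
model.eval()

ssl_feature = torch.randn(2, 250, 1024)   # [B, T, D] dummy frame-level SSL features
with torch.no_grad():
    scores = model(ssl_feature)           # [B, num_classes]

print(scores.shape)                       # torch.Size([2, 5])
# tanh(x) * 2.0 + 3 keeps every predicted score inside the (1, 5) interval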
    	
pretrained/eval.safetensors ADDED

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:81cbd54af8b103251e425fcbd8f5313975cb742e760c3dae1e10f99969933fd6
size 100792276
    	
pretrained/eval.yaml ADDED

@@ -0,0 +1,6 @@
generator:
  _target_: pretrained.eval.Generator
  in_features: 1024
  ffd_hidden_size: 4096
  num_classes: 5
  attn_layer_num: 4
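The commit does not include the code that wires these three files together. One plausible loading pattern, given the Hydra-style _target_ key and the safetensors checkpoint, is sketched below; treat every call here as an assumption rather than the repo's actual loader.

from omegaconf import OmegaConf
from hydra.utils import instantiate
from safetensors.torch import load_file

# Assumed wiring, not shown in this commit: build the Generator from the YAML
# via its _target_ and load the LFS-tracked safetensors weights.
cfg = OmegaConf.load("pretrained/eval.yaml")
evaluator = instantiate(cfg.generator)                  # -> pretrained.eval.Generator(...)
state_dict = load_file("pretrained/eval.safetensors")   # ~100 MB per the pointer above
evaluator.load_state_dict(state_dict)
evaluator.eval()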
