viewfinder-annn committed
Commit 7bf848e · verified · 1 Parent(s): b8a9aec

Upload app.py

Files changed (1): app.py (+197, -0)

app.py ADDED
@@ -0,0 +1,197 @@
import os
import random
import time

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch

from anyaccomp.inference_utils import Sing2SongInferencePipeline

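# Resolve config and checkpoint paths relative to this file so the demo runs
# correctly regardless of the current working directory.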
base_dir = os.path.dirname(os.path.abspath(__file__))

CFG_PATH = os.path.join(base_dir, "config/flow_matching.json")
CHECKPOINT_PATH = os.path.join(base_dir, "pretrained/flow_matching")
VOCODER_CHECKPOINT_PATH = os.path.join(base_dir, "pretrained/vocoder")
VOCODER_CFG_PATH = os.path.join(base_dir, "config/vocoder.json")
INFER_DST = os.path.join(base_dir, "example/output_gradio")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

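# Output folders: accompaniment-only stems and vocal+accompaniment mixtures.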
os.makedirs(INFER_DST, exist_ok=True)
acc_dst = os.path.join(INFER_DST, "accompaniment")
mixture_dst = os.path.join(INFER_DST, "mixture")
os.makedirs(acc_dst, exist_ok=True)
os.makedirs(mixture_dst, exist_ok=True)
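
# Load the flow-matching model and vocoder once at startup; if loading fails,
# keep the app alive and surface the problem via gr.Error at request time.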
print("Initializing AnyAccomp InferencePipeline...")
try:
    inference_pipeline = Sing2SongInferencePipeline(
        CHECKPOINT_PATH,
        CFG_PATH,
        VOCODER_CHECKPOINT_PATH,
        VOCODER_CFG_PATH,
        device=DEVICE,
    )
    inference_pipeline.sample_rate = 24000
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    inference_pipeline = None


def sing2song_inference(vocal_filepath, n_timesteps, cfg_scale, seed):
    """Generate an accompaniment for an uploaded vocal/instrument clip."""
    if inference_pipeline is None:
        raise gr.Error(
            "Model could not be loaded. Please check paths and environment configuration."
        )

    if vocal_filepath is None:
        raise gr.Error("Please upload a vocal audio file.")

    # A seed of -1 (or an empty field) requests a fresh random seed.
    if seed == -1 or seed is None:
        seed = random.randint(0, 2**32 - 1)

    seed = int(seed)
    print(f"Using seed: {seed}")

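    # Seed every RNG in play so a given seed reproduces the same output.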
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

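    # Basic input validation: the demo accepts clips between 3 and 30 seconds.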
    try:
        duration = librosa.get_duration(path=vocal_filepath)
    except Exception as e:
        raise gr.Error(f"Cannot read audio file or get duration: {e}")
    # Checked outside the try so this gr.Error is not swallowed and re-wrapped
    # by the except clause above.
    if not (3 <= duration <= 30):
        raise gr.Error("Audio duration must be between 3 and 30 seconds.")

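    # Core pipeline: load audio, mel-encode the vocal, run flow-matching
    # sampling, vocode to a waveform, then length-match and mix with the input.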
    try:
        vocal_audio, _ = librosa.load(vocal_filepath, sr=24000, mono=True)
        vocal_tensor = torch.tensor(vocal_audio).unsqueeze(0).to(DEVICE)

        vocal_mel = inference_pipeline.encode_vocal(vocal_tensor)

        # bfloat16 autocast speeds up sampling; disabled automatically on CPU.
        with torch.autocast("cuda", dtype=torch.bfloat16, enabled=(DEVICE == "cuda")):
            mel = inference_pipeline.model.reverse_diffusion(
                vocal_mel=vocal_mel,
                n_timesteps=int(n_timesteps),
                cfg=cfg_scale,
            )

        mel = mel.float()
        wav = inference_pipeline._generate_audio(mel)
        wav = wav.squeeze().detach().cpu().numpy()

        # Match the accompaniment to the vocal's length, then sum into a mix.
        wav = librosa.util.fix_length(data=wav, size=len(vocal_audio))
        mixture_wav = wav + vocal_audio

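        # Timestamped filenames keep repeated runs on the same clip from colliding.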
        timestamp = int(time.time())
        original_filename = os.path.basename(vocal_filepath)
        base_filename = f"{os.path.splitext(original_filename)[0]}_{timestamp}.wav"

        accompaniment_path = os.path.join(acc_dst, base_filename)
        mixture_path = os.path.join(mixture_dst, base_filename)

        sf.write(accompaniment_path, wav, 24000)
        sf.write(mixture_path, mixture_wav, 24000)

        return accompaniment_path, mixture_path, "Status: Complete!"

    except Exception as e:
        # Log the full traceback server-side; show a concise error in the UI.
        import traceback

        traceback.print_exc()
        raise gr.Error(f"An error occurred during processing: {e}")


def randomize_seed():
    return random.randint(0, 2**32 - 1)
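

# UI layout: upload and parameters on the left, generated audio on the right.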
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# AnyAccomp: Generalizable Accompaniment Generation
Upload a 3-30 second vocal or instrument track (.wav or .mp3) and the model will generate an accompaniment for it.
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Upload or Select Audio")
            vocal_input = gr.Audio(
                type="filepath",
                label="Upload Vocal or Instrument Audio",
                sources=["upload", "microphone"],
            )

            example1_path = os.path.join(base_dir, "example/gradio/example1.mp3")
            example2_path = os.path.join(base_dir, "example/gradio/example2.wav")
            example3_path = os.path.join(base_dir, "example/gradio/example3.wav")
            gr.Examples(
                examples=[[example1_path], [example2_path], [example3_path]],
                inputs=[vocal_input],
                label="Or click an example to start",
            )
            gr.Markdown("### 2. Adjust Parameters (Optional)")
            with gr.Accordion("Advanced Settings", open=True):
                n_timesteps_slider = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Inference Steps (n_timesteps)",
                )
                cfg_slider = gr.Slider(
                    minimum=1.0, maximum=10.0, value=3.0, step=0.1, label="CFG Scale"
                )

                with gr.Row():
                    seed_input = gr.Number(
                        value=-1, label="Seed (-1 for random)", precision=0
                    )
                    random_seed_btn = gr.Button("🎲")

        with gr.Column(scale=1):
            gr.Markdown("### 3. Generate and Listen to the Result")

            status_text = gr.Markdown("Status: Ready")

            accompaniment_output = gr.Audio(
                label="Generated Accompaniment", type="filepath"
            )
            mixture_output = gr.Audio(
                label="Mixture (Vocal + Accompaniment)", type="filepath"
            )

            submit_btn = gr.Button("Generate Accompaniment", variant="primary")

    submit_btn.click(
        fn=sing2song_inference,
        inputs=[vocal_input, n_timesteps_slider, cfg_slider, seed_input],
        # The handler's third return value updates the status text.
        outputs=[accompaniment_output, mixture_output, status_text],
    )

    random_seed_btn.click(fn=randomize_seed, inputs=None, outputs=seed_input)

demo.queue()

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=8091)
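
For reference, a minimal client-side sketch of how the launched demo could be driven programmatically. This assumes a recent gradio_client, that the app is reachable at http://localhost:8091/, and that the endpoint name follows Gradio's default of deriving api_name from the handler's function name; my_vocal.wav is a hypothetical 3-30 second input clip.

from gradio_client import Client, handle_file

client = Client("http://localhost:8091/")
# Arguments mirror the click handler's inputs: audio, n_timesteps, CFG scale, seed.
accompaniment, mixture, status = client.predict(
    handle_file("my_vocal.wav"),  # hypothetical input clip (3-30 s)
    50,   # inference steps (n_timesteps)
    3.0,  # CFG scale
    -1,   # seed (-1 = random)
    api_name="/sing2song_inference",
)
print(status, accompaniment, mixture)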