Khushi Dahiya committed
Commit f568365 · 1 Parent(s): af4f55f

updating predict to handle batch processing of concurrent requests

Files changed (2)
  1. demos/melodyflow_app.py +314 -50
  2. requirements.txt +1 -0
demos/melodyflow_app.py CHANGED
@@ -16,6 +16,11 @@ import time
import typing as tp
import warnings
import base64
+ import asyncio
+ import threading
+ from concurrent.futures import ThreadPoolExecutor
+ from queue import Queue, Empty
+ import uuid

import torch
import gradio as gr
@@ -26,6 +31,11 @@ from audiocraft.models import MelodyFlow


MODEL = None # Last used model
+ MODEL_LOCK = threading.Lock() # Thread lock for model access
+ REQUEST_QUEUE = Queue() # Queue for batch processing
+ BATCH_PROCESSOR = None # Background batch processor
+ BATCH_SIZE = 4 # Maximum batch size for concurrent processing
+ BATCH_TIMEOUT = 2.0 # Maximum wait time to form a batch (seconds)
SPACE_ID = os.environ.get('SPACE_ID', '')
MODEL_PREFIX = os.environ.get('MODEL_PREFIX', 'facebook/')
IS_HF_SPACE = (MODEL_PREFIX + "MelodyFlow") in SPACE_ID
@@ -68,6 +78,220 @@ class FileCleaner:
file_cleaner = FileCleaner()


+ class RequestBatch:
+     """Represents a batch of requests to process together"""
+     def __init__(self):
+         self.requests = []
+         self.futures = []
+         self.created_at = time.time()
+
+     def add_request(self, request_data, future):
+         self.requests.append(request_data)
+         self.futures.append(future)
+
+     def is_full(self):
+         return len(self.requests) >= BATCH_SIZE
+
+     def is_expired(self):
+         return time.time() - self.created_at > BATCH_TIMEOUT
+
+     def should_process(self):
+         return self.is_full() or self.is_expired() or len(self.requests) > 0
+
+
+ class BatchProcessor:
+     """Handles batched processing of requests"""
+     def __init__(self):
+         self.current_batch = RequestBatch()
+         self.processing = False
+         self.stop_event = threading.Event()
+
+     def start(self):
+         """Start the background batch processing thread"""
+         self.thread = threading.Thread(target=self._process_loop, daemon=True)
+         self.thread.start()
+
+     def stop(self):
+         """Stop the background batch processing"""
+         self.stop_event.set()
+
+     def add_request(self, request_data):
+         """Add a request to the batch and return a future for the result"""
+         from concurrent.futures import Future
+         future = Future()
+
+         # Add to current batch
+         self.current_batch.add_request(request_data, future)
+
+         # Signal that we have a new request
+         REQUEST_QUEUE.put("new_request")
+
+         return future
+
+     def _process_loop(self):
+         """Main processing loop that runs in background thread"""
+         while not self.stop_event.is_set():
+             try:
+                 # Wait for a signal or timeout
+                 REQUEST_QUEUE.get(timeout=0.5)
+
+                 # Check if we should process current batch
+                 if self.current_batch.should_process() and not self.processing:
+                     self._process_current_batch()
+
+             except Empty:
+                 # Timeout - check if we have an expired batch
+                 if self.current_batch.should_process() and not self.processing:
+                     self._process_current_batch()
+                 continue
+             except Exception as e:
+                 print(f"Error in batch processing loop: {e}")
+
+     @spaces.GPU(duration=45) # Increased duration for batch processing
+     def _process_current_batch(self):
+         """Process the current batch of requests"""
+         if len(self.current_batch.requests) == 0:
+             return
+
+         self.processing = True
+         batch = self.current_batch
+         self.current_batch = RequestBatch() # Start new batch
+
+         try:
+             # Extract batch data
+             texts = []
+             melodies = []
+             params_list = []
+
+             for request_data in batch.requests:
+                 texts.append(request_data['text'])
+                 melodies.append(request_data['melody'])
+                 params_list.append({
+                     'solver': request_data['solver'],
+                     'steps': request_data['steps'],
+                     'target_flowstep': request_data['target_flowstep'],
+                     'regularize': request_data['regularize'],
+                     'regularization_strength': request_data['regularization_strength'],
+                     'duration': request_data['duration'],
+                     'model': request_data['model']
+                 })
+
+             # Load model if needed (use the first request's model)
+             model_version = params_list[0]['model']
+             load_model(model_version)
+
+             # Process batch with unified parameters (use first request's params)
+             params = params_list[0]
+             results = _do_predictions_batch(
+                 texts=texts,
+                 melodies=melodies,
+                 solver=params['solver'],
+                 steps=params['steps'],
+                 target_flowstep=params['target_flowstep'],
+                 regularize=params['regularize'],
+                 regularization_strength=params['regularization_strength'],
+                 duration=params['duration'],
+                 progress=False
+             )
+
+             # Set results for each future
+             for i, future in enumerate(batch.futures):
+                 if i < len(results):
+                     future.set_result(results[i])
+                 else:
+                     future.set_exception(Exception("Batch processing failed"))
+
+         except Exception as e:
+             # Set exception for all futures in batch
+             for future in batch.futures:
+                 future.set_exception(e)
+         finally:
+             self.processing = False
+
+
+ def _do_predictions_batch(texts, melodies, solver, steps, target_flowstep,
+                           regularize, regularization_strength, duration, progress=False):
+     """Process a batch of predictions efficiently"""
+     with MODEL_LOCK:
+         MODEL.set_generation_params(solver=solver, steps=steps, duration=duration)
+         MODEL.set_editing_params(
+             solver=solver,
+             steps=steps,
+             target_flowstep=target_flowstep,
+             regularize=regularize,
+             lambda_kl=regularization_strength
+         )
+
+         print(f"Processing batch: {len(texts)} requests")
+         be = time.time()
+
+         processed_melodies = []
+         target_sr = 48000
+         target_ac = 2
+
+         for melody in melodies:
+             if melody is None:
+                 processed_melodies.append(None)
+             else:
+                 melody, sr = audio_read(melody)
+                 if melody.dim() == 2:
+                     melody = melody[None]
+                 if melody.shape[-1] > int(sr * MODEL.duration):
+                     melody = melody[..., :int(sr * MODEL.duration)]
+                 melody = convert_audio(melody, sr, target_sr, target_ac)
+                 melody = MODEL.encode_audio(melody.to(MODEL.device))
+                 processed_melodies.append(melody)
+
+         try:
+             # Process all requests in the batch together
+             if any(m is not None for m in processed_melodies):
+                 # For editing mode, process each request individually due to melody constraints
+                 outputs_list = []
+                 for i, (text, melody) in enumerate(zip(texts, processed_melodies)):
+                     if melody is not None:
+                         output = MODEL.edit(
+                             prompt_tokens=melody.repeat(1, 1, 1),
+                             descriptions=[text],
+                             src_descriptions=[""],
+                             progress=progress,
+                             return_tokens=False,
+                         )
+                     else:
+                         output = MODEL.generate([text], progress=progress, return_tokens=False)
+                     outputs_list.append(output[0])
+                 outputs = torch.stack(outputs_list)
+             else:
+                 # For generation mode, we can batch all requests
+                 outputs = MODEL.generate(texts, progress=progress, return_tokens=False)
+
+         except RuntimeError as e:
+             raise gr.Error("Error while generating " + e.args[0])
+
+         outputs = outputs.detach().cpu().float()
+         results = []
+
+         for output in outputs:
+             with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+                 audio_write(
+                     file.name, output, MODEL.sample_rate, strategy="loudness",
+                     loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
+
+                 # Read and encode audio
+                 with open(file.name, 'rb') as f:
+                     audio_bytes = f.read()
+                 audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
+
+                 results.append({
+                     "audio": audio_base64,
+                     "format": "wav"
+                 })
+
+             file_cleaner.add(file.name)
+
+         print(f"Batch finished: {len(texts)} requests in {time.time() - be:.2f}s")
+         return results
+
+
def make_waveform(*args, **kwargs):
    # Further remove some warnings.
    be = time.time()
@@ -80,14 +304,16 @@ def make_waveform(*args, **kwargs):

def load_model(version=(MODEL_PREFIX + "melodyflow-t24-30secs")):
    global MODEL
-     print("Loading model", version)
-     if MODEL is None or MODEL.name != version:
-         # Clear PyTorch CUDA cache and delete model
-         del MODEL
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-         MODEL = None # in case loading would crash
-         MODEL = MelodyFlow.get_pretrained(version)
+     with MODEL_LOCK:
+         print("Loading model", version)
+         if MODEL is None or MODEL.name != version:
+             # Clear PyTorch CUDA cache and delete model
+             del MODEL
+             if torch.cuda.is_available():
+                 torch.cuda.empty_cache()
+             MODEL = None # in case loading would crash
+             MODEL = MelodyFlow.get_pretrained(version)
+             print(f"Model {version} loaded successfully")


def _do_predictions(texts,
@@ -153,24 +379,32 @@ def _do_predictions(texts,
    return out_wavs


- @spaces.GPU(duration=30)
def predict(model, text,
-             solver, steps, target_flowstep,
-             regularize,
-             regularization_strength,
-             duration,
-             melody=None,
-             model_path=None,
-             progress=gr.Progress()):
+             solver, steps, target_flowstep,
+             regularize,
+             regularization_strength,
+             duration,
+             melody=None,
+             model_path=None,
+             progress=gr.Progress()):
+     """Non-blocking predict function that uses batch processing"""
+
    if melody is not None:
        if solver == MIDPOINT:
            steps = steps//2
        else:
            steps = steps//5

-     global INTERRUPTING
+     global INTERRUPTING, BATCH_PROCESSOR
    INTERRUPTING = False
-     progress(0, desc="Loading model...")
+
+     # Initialize batch processor if not already running
+     if BATCH_PROCESSOR is None:
+         BATCH_PROCESSOR = BatchProcessor()
+         BATCH_PROCESSOR.start()
+
+     progress(0, desc="Queuing request...")
+
    if model_path:
        model_path = model_path.strip()
        if not Path(model_path).exists():
@@ -180,40 +414,51 @@ def predict(model, text,
                        "state_dict.bin and compression_state_dict_.bin.")
        model = model_path

-     load_model(model)
-
-     max_generated = 0
-
-     def _progress(generated, to_generate):
-         nonlocal max_generated
-         max_generated = max(generated, max_generated)
-         progress((min(max_generated, to_generate), to_generate))
+     # Prepare request data
+     request_data = {
+         'text': text,
+         'melody': melody,
+         'solver': solver,
+         'steps': steps,
+         'target_flowstep': target_flowstep,
+         'regularize': regularize,
+         'regularization_strength': regularization_strength,
+         'duration': duration,
+         'model': model,
+         'request_id': str(uuid.uuid4())
+     }
+
+     # Add to batch processor
+     future = BATCH_PROCESSOR.add_request(request_data)
+
+     progress(0.3, desc="Waiting for GPU...")
+
+     # Wait for result with progress updates
+     max_wait = 60 # Maximum wait time in seconds
+     wait_start = time.time()
+
+     while not future.done():
+         elapsed = time.time() - wait_start
+         if elapsed > max_wait:
+             raise gr.Error("Request timeout")
+
+         # Update progress based on wait time
+         progress_val = min(0.9, 0.3 + (elapsed / max_wait) * 0.6)
+         progress(progress_val, desc="Processing...")
+
        if INTERRUPTING:
            raise gr.Error("Interrupted.")
-     MODEL.set_custom_progress_callback(_progress)
-
-     wavs = _do_predictions(
-         [text] * N_REPEATS, [melody],
-         solver=solver,
-         steps=steps,
-         target_flowstep=target_flowstep,
-         regularize=regularize,
-         regularization_strength=regularization_strength,
-         duration=duration,
-         progress=True,)
-
-     # Read the audio file and convert to base64
-     wav_path = wavs[0]
-     with open(wav_path, 'rb') as f:
-         audio_bytes = f.read()
+
+         time.sleep(0.1)

-     audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
+     progress(1.0, desc="Complete!")

-     # Return as a dictionary with base64 data
-     return {
-         "audio": audio_base64,
-         "format": "wav"
-     }
+     # Get result
+     try:
+         result = future.result()
+         return result
+     except Exception as e:
+         raise gr.Error(f"Generation failed: {str(e)}")


def toggle_audio_src(choice):
@@ -353,7 +598,11 @@ def ui_local(launch_kwargs):
        """
    )

-     interface.queue().launch(**launch_kwargs)
+     interface.queue(
+         concurrency_count=8, # Allow up to 8 concurrent requests
+         max_size=50, # Queue up to 50 requests
+         api_open=True # Enable API access
+     ).launch(**launch_kwargs)

def ui_hf(launch_kwargs):
    with gr.Blocks() as interface:
@@ -470,7 +719,19 @@ def ui_hf(launch_kwargs):
        for more details.
        """)

-     interface.queue().launch(**launch_kwargs)
+     interface.queue(
+         concurrency_count=8, # Allow up to 8 concurrent requests
+         max_size=50, # Queue up to 50 requests
+         api_open=True # Enable API access
+     ).launch(**launch_kwargs)
+
+
+ def cleanup():
+     """Cleanup function for graceful shutdown"""
+     global BATCH_PROCESSOR
+     if BATCH_PROCESSOR:
+         BATCH_PROCESSOR.stop()
+     print("Cleanup completed")


if __name__ == "__main__":
@@ -514,6 +775,9 @@ if __name__ == "__main__":
    if args.share:
        launch_kwargs['share'] = args.share

+     import atexit
+     atexit.register(cleanup)
+
    logging.basicConfig(level=logging.INFO, stream=sys.stderr)

    # Show the interface
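
Note that with this change predict() no longer returns a file path: it returns the generated audio as a base64 payload of the form {"audio": <base64 string>, "format": "wav"}. A minimal client-side sketch for decoding such a response follows; the save_response_audio helper and the output.wav path are illustrative and not part of the commit.

import base64

def save_response_audio(result: dict, path: str = "output.wav") -> str:
    # Reverse the b64encode performed in _do_predictions_batch and write a .wav file.
    audio_bytes = base64.b64decode(result["audio"])
    with open(path, "wb") as f:
        f.write(audio_bytes)
    return path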
requirements.txt CHANGED
@@ -26,3 +26,4 @@ torchvision
torchtext
pesq
pystoi
+ spaces
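
The BATCH_SIZE / BATCH_TIMEOUT constants and the Future returned by BatchProcessor.add_request implement a general accumulate-and-flush pattern: callers enqueue a request and block on a Future, while a background thread flushes everything pending once the batch is large enough or the oldest entry has waited long enough. Below is a self-contained sketch of that pattern with a dummy batch_worker standing in for the GPU call; all names are illustrative, not the app's actual classes.

import threading
import time
from concurrent.futures import Future
from queue import Queue, Empty

BATCH_SIZE = 4       # flush once this many requests are pending
BATCH_TIMEOUT = 2.0  # or once the oldest pending request has waited this long

def batch_worker(prompts):
    # Stand-in for the real GPU call: handle all queued prompts in one pass.
    return [f"result for {p}" for p in prompts]

pending = []                     # (prompt, Future) pairs waiting to be flushed
wakeup = Queue()                 # signals the flush loop that a request arrived
pending_lock = threading.Lock()

def submit(prompt):
    fut = Future()
    with pending_lock:
        pending.append((prompt, fut))
    wakeup.put(None)
    return fut

def flush_loop():
    oldest = None
    while True:
        try:
            wakeup.get(timeout=0.5)
        except Empty:
            pass
        with pending_lock:
            if not pending:
                oldest = None
                continue
            if oldest is None:
                oldest = time.time()
            if len(pending) < BATCH_SIZE and time.time() - oldest < BATCH_TIMEOUT:
                continue
            batch = pending[:]
            pending.clear()
            oldest = None
        results = batch_worker([p for p, _ in batch])
        for (_, fut), res in zip(batch, results):
            fut.set_result(res)

threading.Thread(target=flush_loop, daemon=True).start()
futures = [submit(f"prompt {i}") for i in range(6)]
print([f.result(timeout=10) for f in futures])

In the commit itself, _process_current_batch plays the role of batch_worker, and predict() waits on the returned Future (polling future.done() with a timeout) much as the final line here blocks on f.result().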