# Copyright (c) 2025 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example Usage: see README.md"""
import argparse
import json
import os
import random
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime

import numpy as np
import onnxruntime
import s3tokenizer
import torch
import torch.distributed as dist
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm

from flashcosyvoice.config import Config, CosyVoice2LLMConfig, SamplingParams
from flashcosyvoice.cosyvoice2 import CosyVoice2
from flashcosyvoice.utils.audio import mel_spectrogram


def set_all_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def save_file_async(
    wav, prompt_speech_tokens, generated_speech_tokens,
    info, timing_stats
):
    """Save generated audio and its metadata JSON; return the audio duration in seconds."""
    try:
        os.makedirs(os.path.dirname(info['wav']), exist_ok=True)
        if wav is not None:
            wav = wav.cpu()
            torchaudio.save(info['wav'], wav, 24000)
            duration = wav.shape[-1] / 24000.0
            # RTF = per-sample processing time (dataloading + model inference,
            # amortized over the batch) divided by the duration of the generated audio.
            rtf = ((timing_stats['dataloader_time'] + timing_stats['model_inference_time']) / timing_stats['batch_size']) / duration
            timing_stats['rtf'] = rtf
        else:
            duration = 0.0
        info['timing_stats'] = timing_stats
        info['prompt_speech_tokens'] = prompt_speech_tokens
        info['generated_speech_tokens'] = generated_speech_tokens
        with open(info['wav'].replace('.wav', '.json'), "w") as f:
            json.dump(info, f, ensure_ascii=False, indent=4)
        return duration
    except Exception as e:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [ERROR] - Error saving audio {info.get('key', 'unknown')}: {e}")
        return 0.0
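

# The sidecar JSON written by `save_file_async` mirrors the input jsonl entry plus
# the fields attached above; roughly (a sketch inferred from the code, not a spec):
#   {"key": ..., "prompt_text": ..., "text": ..., "prompt_wav": ..., "wav": ...,
#    "timing_stats": {"dataloader_time": ..., "model_inference_time": ..., "rtf": ..., ...},
#    "prompt_speech_tokens": [...], "generated_speech_tokens": [...]}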


class AudioDataset(Dataset):

    def __init__(self, text_norm, text_tokenizer, data_list, model_config: Config):
        self.datas = []
        self.text_norm = text_norm
        self.model_config = model_config
        """Example data_list (jsonl, one sample per line):
        ```
        {"key": "uttid_1", "prompt_text": "你好,我是小明。", "text": "你好,我是小红。", "prompt_wav": "/mnt/data/audio/00000000.wav", "wav": "/mnt/data/audio_synthetic/uttid_1.wav"}
        {"key": "uttid_2", "prompt_text": "你好,我是小红。", "text": "你好,我是小明。", "prompt_wav": "/mnt/data/audio/00000001.wav", "wav": "/mnt/data/audio_synthetic/uttid_2.wav"}
        ```
        Note:
        - `key` is the unique key of this sample.
        - `prompt_text` is the text used for the prompt.
        - `text` is the text used for generating real audio.
        - `prompt_wav` is the audio used for the prompt.
        - `wav` is the path where the generated audio will be saved (we highly recommend pre-defining the save path before running the script).
        """
        missing = 0
        with open(data_list, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        total_lines = len(lines)
        if torch.distributed.get_node_local_rank() == 0:
            iterator = tqdm(lines, desc='Loading data')
        else:
            iterator = lines
        for line in iterator:
            data = json.loads(line.strip())
            valid = True
            for k in ['key', 'prompt_text', 'text', 'prompt_wav']:
                if k not in data:
                    valid = False
                    break
                if data[k] is None:
                    valid = False
                    break
            # Only check the prompt wav on disk if the required fields are present.
            if valid and not os.path.exists(data['prompt_wav']):
                valid = False
            if valid:
                self.datas.append(data)
            else:
                missing += 1
        if torch.distributed.get_node_local_rank() == 0:
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f'[{timestamp}] - [INFO] - Loaded {total_lines} lines, skipped {missing} invalid lines, total valid lines == {len(self.datas)}.')
        self.text_tokenizer = text_tokenizer
        option = onnxruntime.SessionOptions()
        option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
        option.intra_op_num_threads = 1
        self.spk_model = onnxruntime.InferenceSession(f"{self.model_config.model}/campplus.onnx", sess_options=option,
                                                      providers=["CPUExecutionProvider"])

    def __len__(self):
        return len(self.datas)

    def __getitem__(self, idx):
        data = self.datas[idx]
        try:
            # 1. feature for s3tokenizer
            audio = s3tokenizer.load_audio(data['prompt_wav'], sr=16000)  # [T]
            log_mel = s3tokenizer.log_mel_spectrogram(audio)  # [num_mels, T]
            # 2. feature for speaker embedding
            spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000)
            spk_feat = spk_feat - spk_feat.mean(dim=0, keepdim=True)
            spk_emb = self.spk_model.run(
                None, {self.spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}
            )[0].flatten().tolist()
            # 3. feature for flow
            audio, sample_rate = torchaudio.load(data['prompt_wav'], backend='soundfile')
            audio = audio.mean(dim=0, keepdim=True)  # [1, T]
            if sample_rate != 24000:
                audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio)
            mel = mel_spectrogram(audio).transpose(1, 2).squeeze(0)  # [T, num_mels]
            mel_len = mel.shape[0]
            # 4. feature for llm
            if self.text_norm is not None:
                prompt_texts = [i["text"] for i in json.loads(self.text_norm.do_voicegen_frd(data['prompt_text'].strip()))["sentences"]]
                prompt_text = ''.join(prompt_texts)
                texts = [i["text"] for i in json.loads(self.text_norm.do_voicegen_frd(data['text'].strip()))["sentences"]]
                text = ''.join(texts)
            else:
                prompt_text = data['prompt_text']
                text = data['text']
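            # Text token ids are offset by `speech_vocab_size + 2` below so that they
            # occupy a range disjoint from the speech token ids in the model's combined
            # vocabulary (the exact role of the two extra slots is defined by the model
            # config; this comment is an inference from the offset, not a spec).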
            prompt_text_ids = self.text_tokenizer.encode(prompt_text)
            prompt_text_ids = [i + self.model_config.hf_config.speech_vocab_size + 2 for i in prompt_text_ids]
            text_ids = self.text_tokenizer.encode(text)
            text_ids = [i + self.model_config.hf_config.speech_vocab_size + 2 for i in text_ids]
            item = {
                "prompt_text_tokens": prompt_text_ids, "text_tokens": text_ids,
                "spk_emb": spk_emb, "mel": mel, "mel_len": mel_len, "log_mel": log_mel, "info": data,
                "min_tokens": len(text_ids) * self.model_config.min_token_text_ratio,
                "max_tokens": len(text_ids) * self.model_config.max_token_text_ratio,
            }
        except Exception as e:
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [WARNING] - Error processing data item {data.get('key', idx)}: {e}")
            return None
        return item


def collate_fn(batch):
    prompt_mels_for_llm = [item["log_mel"] for item in batch if item is not None]
    prompt_mels_for_llm, prompt_mels_lens_for_llm = s3tokenizer.padding(prompt_mels_for_llm)  # [B, num_mels=128, T]
    prompt_text_tokens_for_llm = [item["prompt_text_tokens"] for item in batch if item is not None]
    text_tokens_for_llm = [item["text_tokens"] for item in batch if item is not None]
    prompt_mels_for_flow = [item["mel"] for item in batch if item is not None]
    prompt_mels_for_flow = torch.nn.utils.rnn.pad_sequence(prompt_mels_for_flow, batch_first=True, padding_value=0)  # [B, T', num_mels=80]
    prompt_mels_lens_for_flow = [item["mel_len"] for item in batch if item is not None]
    prompt_mels_lens_for_flow = torch.tensor(prompt_mels_lens_for_flow)
    spk_emb_for_flow = [item["spk_emb"] for item in batch if item is not None]
    spk_emb_for_flow = torch.tensor(spk_emb_for_flow)
    sampling_params = [SamplingParams(min_tokens=item["min_tokens"], max_tokens=item["max_tokens"], use_ras=True) for item in batch if item is not None]
    infos = [item["info"] for item in batch if item is not None]
    return {
        "prompt_mels_for_llm": prompt_mels_for_llm,
        "prompt_mels_lens_for_llm": prompt_mels_lens_for_llm,
        "prompt_text_tokens_for_llm": prompt_text_tokens_for_llm,
        "text_tokens_for_llm": text_tokens_for_llm,
        "prompt_mels_for_flow": prompt_mels_for_flow,
        "prompt_mels_lens_for_flow": prompt_mels_lens_for_flow,
        "spk_emb_for_flow": spk_emb_for_flow,
        "sampling_params": sampling_params,
        "infos": infos,
    }
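

# NOTE: `AudioDataset.__getitem__` returns None for samples that fail to load and
# `collate_fn` drops them, so a collated batch can be smaller than
# `--batch_size_dataloader` or even empty (the main loop skips empty batches).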


def init_distributed():
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
    tqdm.write(f'[{timestamp}] - [INFO] - Inference on multiple gpus, this gpu {local_rank}, rank {rank}, world_size {world_size}')
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
    return world_size, local_rank, rank
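

# WORLD_SIZE / LOCAL_RANK / RANK are expected to be set by the launcher (e.g. torchrun);
# with the defaults above a bare `python` invocation falls back to a single-process
# configuration, though `dist.init_process_group` still expects the usual env:// variables
# (MASTER_ADDR / MASTER_PORT) to be present.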


def get_args():
    parser = argparse.ArgumentParser(description='FlashCosyVoice')
    parser.add_argument('--model_path',
                        required=True,
                        type=str,
                        help='model path')
    parser.add_argument('--data_list',
                        required=True,
                        type=str,
                        help='data list')
    parser.add_argument('--batch_size_dataloader',
                        required=True,
                        type=int,
                        help='batch size (per-device) for dataloading')
    parser.add_argument('--batch_size_flow',
                        required=True,
                        type=int,
                        help='batch size (per-device) for flow-matching')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='workers for dataloader')
    parser.add_argument('--prefetch',
                        type=int,
                        default=5,
                        help='prefetch for dataloader')
    parser.add_argument('--enable_tn',
                        action='store_true',
                        help='enable text normalization')
    parser.add_argument('--only_llm',
                        action='store_true',
                        help='only generate speech tokens from llm')
    parser.add_argument('--fp16_flow',
                        action='store_true',
                        help='enable fp16 flow')
    parser.add_argument('--seed',
                        type=int,
                        default=1986,
                        help='random seed for generation')
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    if args.enable_tn:
        # Check python version, if == 3.10, use ttsfrd
        if sys.version_info.major == 3 and sys.version_info.minor == 10:
            # Check if ttsfrd is installed
            try:
                import ttsfrd
                from cosyvoice_ttsfrd import get_resource_path
            except ImportError as e:
                raise ImportError("ttsfrd is not installed, please install it first, see `https://github.com/xingchensong/CosyVoice-ttsfrd` for installation guide.") from e
            text_norm = ttsfrd.TtsFrontendEngine()
            text_norm.initialize(get_resource_path())
            text_norm.set_lang_type('pinyinvg')
        else:
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [WARNING] - Only Python 3.10 is supported for ttsfrd, see `https://github.com/xingchensong/CosyVoice-ttsfrd` for more info. Setting enable_tn to False...")
            # TODO: maybe we should use wetext if python version is not 3.10?
            args.enable_tn = False
            text_norm = None
    else:
        text_norm = None
    assert torch.cuda.is_available()
    world_size, local_rank, rank = init_distributed()
    config = Config(model=args.model_path, enforce_eager=True, tensor_parallel_size=1,
                    max_num_seqs=args.batch_size_dataloader,
                    hf_config=CosyVoice2LLMConfig(fp16_flow=args.fp16_flow), rank=local_rank)
    model = CosyVoice2(config)
    set_all_random_seed(args.seed)
    dataset = AudioDataset(text_norm, model.llm.tokenizer, args.data_list, config)
    sampler = DistributedSampler(dataset,
                                 num_replicas=world_size,
                                 rank=rank)
    dataloader = DataLoader(dataset, batch_size=args.batch_size_dataloader, num_workers=args.num_workers, pin_memory=True,
                            sampler=sampler, shuffle=False, prefetch_factor=args.prefetch, collate_fn=collate_fn)
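    # DistributedSampler shards the dataset across ranks, so each rank iterates over
    # roughly len(dataset) / world_size samples; the progress bar on local_rank 0
    # scales its per-batch updates by world_size to approximate global progress.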
    total_steps = len(dataset)
    if local_rank == 0:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [INFO] - {args}")
        progress_bar = tqdm(total=total_steps, desc="Processing samples", unit="wav",
                            position=0, leave=True, dynamic_ncols=True)
    cpu_counts = os.cpu_count()
    # Keep at least one worker even on machines with fewer than 8 CPU cores.
    executor = ThreadPoolExecutor(max_workers=min(args.batch_size_dataloader, max(1, cpu_counts // 8)))
    pending_futures = []
    dataloader_iter = iter(dataloader)
    succeed_duration = 0.01  # avoid division by zero
    start_time = time.time()
    estimated_total_wavs = 0
    succeed_wavs = 0
    failed_wavs = 0
    last_print_time = start_time
    while True:
        try:
            dataloader_start = time.time()
            batch = next(dataloader_iter)
            dataloader_time = time.time() - dataloader_start
            if len(batch['infos']) == 0:
                timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
                tqdm.write(f"[{timestamp}] - [WARNING] - rank {rank} of {world_size}: No valid batch found, skipping this batch...")
                continue
            model_start = time.time()
            results_dict, timing_stats = model(**batch, batch_size_flow=args.batch_size_flow,
                                               only_llm=args.only_llm)
            model_time = time.time() - model_start
            estimated_total_wavs += len(results_dict['generated_wavs'])
            timing_stats['dataloader_time'] = dataloader_time
            timing_stats['model_inference_time'] = model_time
            if args.only_llm:
                results_dict['generated_wavs'] = [None] * len(results_dict['prompt_speech_tokens'])
            for i in range(len(results_dict['generated_wavs'])):
                future = executor.submit(
                    save_file_async, results_dict['generated_wavs'][i],
                    results_dict['prompt_speech_tokens'][i],
                    results_dict['generated_speech_tokens'][i],
                    batch['infos'][i].copy(), timing_stats.copy()
                )
                pending_futures.append(future)
            completed_futures = []
            for future in pending_futures:
                if future.done():
                    try:
                        duration = future.result()
                        succeed_duration += duration
                        succeed_wavs += 1
                    except Exception as e:
                        failed_wavs += 1
                        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
                        tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in async save task: {e}")
                    completed_futures.append(future)
            for future in completed_futures:
                pending_futures.remove(future)
            if local_rank == 0:
                update_n = world_size * len(batch["prompt_text_tokens_for_llm"])
                if progress_bar.n + update_n > progress_bar.total:
                    progress_bar.update(progress_bar.total - progress_bar.n)
                else:
                    progress_bar.update(update_n)
            current_time = time.time()
            if current_time - last_print_time >= 120 and not args.only_llm:
                elapsed_time = current_time - start_time
                avg_duration = succeed_duration / succeed_wavs if succeed_wavs > 0 else 0
                estimated_total_duration = avg_duration * estimated_total_wavs
                current_rtf = elapsed_time / estimated_total_duration if estimated_total_duration > 0.01 else 0
                timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
                tqdm.write(f"[{timestamp}] - [INFO] - rank {rank} of {world_size}: Estimated total wavs: {estimated_total_wavs} ({estimated_total_wavs - succeed_wavs} pending to save), Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Estimated total duration: {estimated_total_duration:.2f}s ({estimated_total_duration / 3600:.2f} h), Estimated RTF: {current_rtf:.5f}, Elapsed time: {elapsed_time:.2f}s")  # noqa
                last_print_time = current_time
        except StopIteration:
            break
        except Exception as e:
            failed_wavs += 1
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in main loop: {e}")
            continue
    total_time = time.time() - start_time
    if local_rank == 0:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [INFO] - Waiting for {len(pending_futures)} pending save tasks to complete...")
    for future in pending_futures:
        try:
            duration = future.result(timeout=60)
            succeed_duration += duration
            succeed_wavs += 1
        except Exception as e:
            failed_wavs += 1
            timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
            tqdm.write(f"[{timestamp}] - [ERROR] - rank {rank} of {world_size}: Error in final async save task: {e}")
    executor.shutdown(wait=True)
    if local_rank == 0:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [INFO] - All async save tasks completed.")
        progress_bar.close()
    if not args.only_llm:
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
        tqdm.write(f"[{timestamp}] - [INFO] - rank {rank} of {world_size}: Final Report - Succeed wavs: {succeed_wavs}, Failed wavs: {failed_wavs}, Total duration: {succeed_duration:.2f}s ({succeed_duration / 3600:.2f} h), RTF: {total_time / succeed_duration:.5f}")  # noqa
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()