import itertools
import os
import re
from collections import defaultdict
from functools import partial
from multiprocessing import Pool
from pathlib import Path

import click
import numpy as np
from loguru import logger
from tqdm import tqdm

from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
from tools.file import load_filelist

# To avoid CPU overload
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"


def task_generator_folder(root: Path, text_extension: str | tuple[str, ...]):
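    """Walk `root` for *.npy semantic-token files, group them by parent
    directory, and yield one (speaker, [(file, texts)], "folder") task per
    group; transcripts are read from sibling files carrying `text_extension`."""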
    files = list(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))
    files = sorted(files)

    grouped_files = defaultdict(list)
    for file in tqdm(files, desc=f"Grouping {root}"):
        p = str(file.parent)
        speaker = file.parent.name

        try:
            if isinstance(text_extension, str):
                texts = [file.with_suffix(text_extension).read_text(encoding="utf-8")]
            else:
                texts = [
                    file.with_suffix(ext).read_text(encoding="utf-8")
                    for ext in text_extension
                ]
        except Exception as e:
            logger.error(f"Failed to read text {file}: {e}")
            continue

        grouped_files[p].append((speaker, file, texts))

    logger.info(
        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
    )

    for group in grouped_files.values():
        subset = [(file, texts) for _, file, texts in group]
        yield group[0][0], subset, "folder"


def task_generator_filelist(filelist):
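    """Group filelist entries by speaker and yield one
    (speaker, [(file, [text])], "filelist") task per speaker; the third
    filelist column is unused here."""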
    grouped_files = defaultdict(list)
    for filename, speaker, _, text in load_filelist(filelist):
        grouped_files[speaker].append((Path(filename), [text]))

    logger.info(f"Found {len(grouped_files)} groups in {filelist}")
    for speaker, values in grouped_files.items():
        yield speaker, values, "filelist"


def run_task(task):
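    """Build one packed TextData proto: clean each transcript, load the
    matching .npy semantic tokens, and wrap every (texts, tokens) pair in a
    Sentence message."""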
    name, subset, source = task

    # Parse the files
    sentences = []
    for file, texts in subset:
        np_file = file.with_suffix(".npy")
        if not np_file.exists():
            logger.warning(f"Can't find {np_file}")
            continue

        new_texts = []
        for text in texts:
            # Simple cleaning: replace { xxx } and < xxx > with a space
            text = re.sub(r"\{.*?\}", " ", text)
            text = re.sub(r"<.*?>", " ", text)
            text = re.sub(r"\s+", " ", text)
            new_texts.append(text)

        try:
            semantics = np.load(np_file)
        except Exception as e:
            logger.error(f"Failed to parse {file}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()
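        # One Semantics message per row; the array is presumably shaped
        # (num_codebooks, seq_len), so each row holds one codebook's tokens.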
        sentences.append(
            Sentence(
                texts=new_texts,
                semantics=[Semantics(values=s) for s in semantics],
            )
        )

    # Pack the sentences
    return pack_pb_stream(
        TextData(
            source=source,
            name=name,
            sentences=sentences,
        )
    )


# Click wiring restored so `main()` can parse its arguments when invoked
# below; the option defaults are assumed, not taken from the original.
@click.command()
@click.option("--input", type=click.Path(path_type=Path), required=True, multiple=True)
@click.option("--output", type=click.Path(path_type=Path), required=True)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, multiple=True, default=[".txt"])
@click.option("--shard-size", type=int, default=10, help="Maximum shard size in MB")
def main(input, output, num_workers, text_extension, shard_size):
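    """Fan the tasks out to a worker pool and concatenate the packed protos
    into numbered .protos shard files of roughly `shard_size` MB each."""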
    generator_fns = []

    for f in input:
        assert f.exists(), f"{f} not found"

        if f.is_dir():
            generator_fn = task_generator_folder(f, text_extension)
        else:
            generator_fn = task_generator_filelist(f)

        generator_fns.append(generator_fn)

    generator_fn = itertools.chain(*generator_fns)
    output.mkdir(parents=True, exist_ok=True)

    dataset_fp = None
    tar_idx = 0
    written_size = 0
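    # Stream results as workers finish; rotate to a new shard file once the
    # current one grows past shard_size MB.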
    with Pool(num_workers) as p:
        for result in tqdm(p.imap_unordered(run_task, generator_fn)):
            if dataset_fp is None:
                dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")

            dataset_fp.write(result)
            written_size += len(result)

            if written_size > shard_size * 1024 * 1024:
                logger.info(f"Finished writing {tar_idx + 1} shards to {output}")
                dataset_fp.close()
                dataset_fp = None
                written_size = 0
                tar_idx += 1

    if dataset_fp is not None:
        dataset_fp.close()
        logger.info(f"Finished writing {tar_idx + 1} shards to {output}")


if __name__ == "__main__":
    main()
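
# Example invocation (script name and values are illustrative):
#   python build_dataset.py \
#       --input data/quantized-dataset \
#       --output data/protos \
#       --num-workers 16 \
#       --text-extension .txt \
#       --shard-size 10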