Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	| import argparse | |
| import json | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from random import sample, shuffle | |
| from typing import Optional | |
| from tqdm import tqdm | |
| from config import get_config | |
| from style_bert_vits2.logging import logger | |
| from style_bert_vits2.nlp import clean_text | |
| from style_bert_vits2.nlp.japanese import pyopenjtalk_worker | |
| from style_bert_vits2.nlp.japanese.user_dict import update_dict | |
| from style_bert_vits2.utils.stdout_wrapper import SAFE_STDOUT | |
| # このプロセスからはワーカーを起動して辞書を使いたいので、ここで初期化 | |
| pyopenjtalk_worker.initialize_worker() | |
| # dict_data/ 以下の辞書データを pyopenjtalk に適用 | |
| update_dict() | |
| preprocess_text_config = get_config().preprocess_text_config | |
| # Count lines for tqdm | |
| def count_lines(file_path: Path): | |
| with file_path.open("r", encoding="utf-8") as file: | |
| return sum(1 for _ in file) | |
| def write_error_log(error_log_path: Path, line: str, error: Exception): | |
| with error_log_path.open("a", encoding="utf-8") as error_log: | |
| error_log.write(f"{line.strip()}\n{error}\n\n") | |
| def process_line( | |
| line: str, | |
| transcription_path: Path, | |
| correct_path: bool, | |
| use_jp_extra: bool, | |
| yomi_error: str, | |
| ): | |
| splitted_line = line.strip().split("|") | |
| if len(splitted_line) != 4: | |
| raise ValueError(f"Invalid line format: {line.strip()}") | |
| utt, spk, language, text = splitted_line | |
| norm_text, phones, tones, word2ph = clean_text( | |
| text=text, | |
| language=language, # type: ignore | |
| use_jp_extra=use_jp_extra, | |
| raise_yomi_error=(yomi_error != "use"), | |
| ) | |
| if correct_path: | |
| utt = str(transcription_path.parent / "wavs" / utt) | |
| return "{}|{}|{}|{}|{}|{}|{}\n".format( | |
| utt, | |
| spk, | |
| language, | |
| norm_text, | |
| " ".join(phones), | |
| " ".join([str(i) for i in tones]), | |
| " ".join([str(i) for i in word2ph]), | |
| ) | |
| def preprocess( | |
| transcription_path: Path, | |
| cleaned_path: Optional[Path], | |
| train_path: Path, | |
| val_path: Path, | |
| config_path: Path, | |
| val_per_lang: int, | |
| max_val_total: int, | |
| # clean: bool, | |
| use_jp_extra: bool, | |
| yomi_error: str, | |
| correct_path: bool, | |
| ): | |
| assert yomi_error in ["raise", "skip", "use"] | |
| if cleaned_path == "" or cleaned_path is None: | |
| cleaned_path = transcription_path.with_name( | |
| transcription_path.name + ".cleaned" | |
| ) | |
| error_log_path = transcription_path.parent / "text_error.log" | |
| if error_log_path.exists(): | |
| error_log_path.unlink() | |
| error_count = 0 | |
| total_lines = count_lines(transcription_path) | |
| # transcription_path から 1行ずつ読み込んで文章処理して cleaned_path に書き込む | |
| with ( | |
| transcription_path.open("r", encoding="utf-8") as trans_file, | |
| cleaned_path.open("w", encoding="utf-8") as out_file, | |
| ): | |
| for line in tqdm(trans_file, file=SAFE_STDOUT, total=total_lines): | |
| try: | |
| processed_line = process_line( | |
| line, | |
| transcription_path, | |
| correct_path, | |
| use_jp_extra, | |
| yomi_error, | |
| ) | |
| out_file.write(processed_line) | |
| except Exception as e: | |
| logger.error( | |
| f"An error occurred at line:\n{line.strip()}\n{e}", encoding="utf-8" | |
| ) | |
| write_error_log(error_log_path, line, e) | |
| error_count += 1 | |
| transcription_path = cleaned_path | |
| # 各話者ごとのlineの辞書 | |
| spk_utt_map: dict[str, list[str]] = defaultdict(list) | |
| # 話者からIDへの写像 | |
| spk_id_map: dict[str, int] = {} | |
| # 話者ID | |
| current_sid: int = 0 | |
| # 音源ファイルのチェックや、spk_id_mapの作成 | |
| with transcription_path.open("r", encoding="utf-8") as f: | |
| audio_paths: set[str] = set() | |
| count_same = 0 | |
| count_not_found = 0 | |
| for line in f.readlines(): | |
| utt, spk = line.strip().split("|")[:2] | |
| if utt in audio_paths: | |
| logger.warning(f"Same audio file appears multiple times: {utt}") | |
| count_same += 1 | |
| continue | |
| if not Path(utt).is_file(): | |
| logger.warning(f"Audio not found: {utt}") | |
| count_not_found += 1 | |
| continue | |
| audio_paths.add(utt) | |
| spk_utt_map[spk].append(line) | |
| # 新しい話者が出てきたら話者IDを割り当て、current_sidを1増やす | |
| if spk not in spk_id_map: | |
| spk_id_map[spk] = current_sid | |
| current_sid += 1 | |
| if count_same > 0 or count_not_found > 0: | |
| logger.warning( | |
| f"Total repeated audios: {count_same}, Total number of audio not found: {count_not_found}" | |
| ) | |
| train_list: list[str] = [] | |
| val_list: list[str] = [] | |
| # 各話者ごとに発話リストを処理 | |
| for spk, utts in spk_utt_map.items(): | |
| if val_per_lang == 0: | |
| train_list.extend(utts) | |
| continue | |
| # ランダムにval_per_lang個のインデックスを選択 | |
| val_indices = set(sample(range(len(utts)), val_per_lang)) | |
| # 元の順序を保ちながらリストを分割 | |
| for index, utt in enumerate(utts): | |
| if index in val_indices: | |
| val_list.append(utt) | |
| else: | |
| train_list.append(utt) | |
| # バリデーションリストのサイズ調整 | |
| if len(val_list) > max_val_total: | |
| extra_val = val_list[max_val_total:] | |
| val_list = val_list[:max_val_total] | |
| # 余剰のバリデーション発話をトレーニングリストに追加(元の順序を保持) | |
| train_list.extend(extra_val) | |
| with train_path.open("w", encoding="utf-8") as f: | |
| for line in train_list: | |
| f.write(line) | |
| with val_path.open("w", encoding="utf-8") as f: | |
| for line in val_list: | |
| f.write(line) | |
| with config_path.open("r", encoding="utf-8") as f: | |
| json_config = json.load(f) | |
| json_config["data"]["spk2id"] = spk_id_map | |
| json_config["data"]["n_speakers"] = len(spk_id_map) | |
| with config_path.open("w", encoding="utf-8") as f: | |
| json.dump(json_config, f, indent=2, ensure_ascii=False) | |
| if error_count > 0: | |
| if yomi_error == "skip": | |
| logger.warning( | |
| f"An error occurred in {error_count} lines. Proceed with lines without errors. Please check {error_log_path} for details." | |
| ) | |
| else: | |
| # yom_error == "raise"と"use"の場合。 | |
| # "use"の場合は、そもそもyomi_error = Falseで処理しているので、 | |
| # ここが実行されるのは他の例外のときなので、エラーをraiseする。 | |
| logger.error( | |
| f"An error occurred in {error_count} lines. Please check {error_log_path} for details." | |
| ) | |
| raise Exception( | |
| f"An error occurred in {error_count} lines. Please check `Data/you_model_name/text_error.log` file for details." | |
| ) | |
| # 何故か{error_log_path}をraiseすると文字コードエラーが起きるので上のように書いている | |
| else: | |
| logger.info( | |
| "Training set and validation set generation from texts is complete!" | |
| ) | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument( | |
| "--transcription-path", default=preprocess_text_config.transcription_path | |
| ) | |
| parser.add_argument("--cleaned-path", default=preprocess_text_config.cleaned_path) | |
| parser.add_argument("--train-path", default=preprocess_text_config.train_path) | |
| parser.add_argument("--val-path", default=preprocess_text_config.val_path) | |
| parser.add_argument("--config-path", default=preprocess_text_config.config_path) | |
| # 「話者ごと」のバリデーションデータ数、言語ごとではない! | |
| # 元のコードや設定ファイルでval_per_langとなっていたので名前をそのままにしている | |
| parser.add_argument( | |
| "--val-per-lang", | |
| default=preprocess_text_config.val_per_lang, | |
| help="Number of validation data per SPEAKER, not per language (due to compatibility with the original code).", | |
| ) | |
| parser.add_argument("--max-val-total", default=preprocess_text_config.max_val_total) | |
| parser.add_argument("--use_jp_extra", action="store_true") | |
| parser.add_argument("--yomi_error", default="raise") | |
| parser.add_argument("--correct_path", action="store_true") | |
| args = parser.parse_args() | |
| transcription_path = Path(args.transcription_path) | |
| cleaned_path = Path(args.cleaned_path) if args.cleaned_path else None | |
| train_path = Path(args.train_path) | |
| val_path = Path(args.val_path) | |
| config_path = Path(args.config_path) | |
| val_per_lang = int(args.val_per_lang) | |
| max_val_total = int(args.max_val_total) | |
| use_jp_extra: bool = args.use_jp_extra | |
| yomi_error: str = args.yomi_error | |
| correct_path: bool = args.correct_path | |
| preprocess( | |
| transcription_path=transcription_path, | |
| cleaned_path=cleaned_path, | |
| train_path=train_path, | |
| val_path=val_path, | |
| config_path=config_path, | |
| val_per_lang=val_per_lang, | |
| max_val_total=max_val_total, | |
| use_jp_extra=use_jp_extra, | |
| yomi_error=yomi_error, | |
| correct_path=correct_path, | |
| ) | |
