Spaces:
Running
Running
| #!/usr/bin/env python | |
| """ | |
| This script transforms custom dataset, gathered from Internet into | |
| DeepSpeech-ready .csv file | |
| Use "python3 import_ukrainian.py -h" for help | |
| """ | |
| import csv | |
| import os | |
| import subprocess | |
| import unicodedata | |
| from multiprocessing import Pool | |
| import progressbar | |
| import sox | |
| from deepspeech_training.util.downloader import SIMPLE_BAR | |
| from deepspeech_training.util.importers import ( | |
| get_counter, | |
| get_imported_samples, | |
| get_importers_parser, | |
| get_validate_label, | |
| print_import_report, | |
| ) | |
| from ds_ctcdecoder import Alphabet | |
| import re | |
# Column order of the generated DeepSpeech CSV.
FIELDNAMES = ["wav_filename", "wav_filesize", "transcript"]

# Target audio format for converted clips: 16 kHz, mono.
SAMPLE_RATE = 16_000
CHANNELS = 1
# Clips longer than this many seconds are dropped and counted as "too_long".
MAX_SECS = 10

# Process-wide state. PARAMS is assigned under the __main__ guard; FILTER_OBJ
# and AUDIO_DIR are assigned inside each pool worker by init_worker().
PARAMS = FILTER_OBJ = AUDIO_DIR = None
class LabelFilter:
    """Clean and validate transcripts before they are written to the CSV.

    Args:
        normalize: when true, strip surrounding whitespace and remove
            diacritic (combining) marks, keeping the base characters.
        alphabet: optional alphabet object; a non-empty label it cannot
            encode (``CanEncode`` is false) is replaced by ``None``.
        validate_fun: callable applied to the label; its result becomes the
            label (callers treat a ``None`` result as an invalid label).
    """

    def __init__(self, normalize, alphabet, validate_fun):
        self.normalize = normalize
        self.alphabet = alphabet
        self.validate_fun = validate_fun

    def filter(self, label):
        """Return the cleaned label, or ``None`` if it must be discarded."""
        if self.normalize:
            # Decompose (NFKD) and drop only the combining marks. The former
            # ASCII encode/decode round-trip deleted every non-ASCII character,
            # emptying Cyrillic transcripts and removing the ’ apostrophe.
            decomposed = unicodedata.normalize("NFKD", label.strip())
            label = "".join(
                char for char in decomposed if not unicodedata.combining(char)
            )
        label = self.validate_fun(label)
        if self.alphabet and label and not self.alphabet.CanEncode(label):
            label = None
        return label
def init_worker(params):
    """Pool initializer: prepare the per-process state read by one_sample.

    Sets the module globals AUDIO_DIR (``params.audio_dir`` or, when that is
    empty, ``<params.tsv_dir>/clips``) and FILTER_OBJ (a LabelFilter built
    from ``params.normalize``, ``params.filter_alphabet`` and the validator
    returned by get_validate_label).
    """
    global FILTER_OBJ  # pylint: disable=global-statement
    global AUDIO_DIR  # pylint: disable=global-statement
    AUDIO_DIR = params.audio_dir or os.path.join(params.tsv_dir, "clips")
    label_validator = get_validate_label(params)
    if params.filter_alphabet:
        alphabet = Alphabet(params.filter_alphabet)
    else:
        alphabet = None
    FILTER_OBJ = LabelFilter(params.normalize, alphabet, label_validator)
def one_sample(sample):
    """Convert one clip to 16 kHz mono WAV and classify it for the CSV.

    Args:
        sample: ``(source_path, transcript, sample_id)`` as built by
            _maybe_convert_set; ``sample_id`` becomes the output file name.

    Returns:
        ``(counter, rows)``: per-sample statistics from get_counter() and a
        list that is empty or holds one
        ``(wav_filename, wav_filesize, label, sample_id)`` tuple.

    Relies on AUDIO_DIR and FILTER_OBJ having been set by init_worker.
    """
    source_filename, transcript, sample_id = sample
    # Only extension-less names get ".wav" appended. Appending it to every
    # non-.wav name turned an existing "clip.mp3" into the nonexistent
    # "clip.mp3.wav", so every mp3 sample was counted as failed.
    if not os.path.splitext(source_filename)[1]:
        source_filename += ".wav"
    # Converted clips are stored in AUDIO_DIR under their numeric sample id.
    output_filename = f"{sample_id}.wav"
    output_filepath = os.path.join(AUDIO_DIR, output_filename)
    _maybe_convert_wav(source_filename, output_filepath)

    file_size = -1  # -1 marks "no usable converted file"
    frames = 0
    if os.path.exists(output_filepath):
        file_size = os.path.getsize(output_filepath)
        if file_size != 0:
            try:
                frames = int(
                    subprocess.check_output(
                        ["soxi", "-s", output_filepath], stderr=subprocess.STDOUT
                    )
                )
            except subprocess.CalledProcessError:
                # soxi could not read the file: count it as a failed sample
                # instead of crashing the whole worker pool.
                file_size = -1
                frames = 0
    label = FILTER_OBJ.filter(transcript)

    rows = []
    counter = get_counter()
    if file_size == -1:
        # Excluding samples that failed upon conversion
        counter["failed"] += 1
    elif label is None:
        # Excluding samples that failed on label validation
        counter["invalid_label"] += 1
    elif int(frames / SAMPLE_RATE * 1000 / 10 / 2) < len(str(label)) + 1:
        # Too short to fit the transcript: at least one 20 ms slice per
        # character is required; the extra + 1 was added to filter a surname
        # dataset with very short clips.
        counter["too_short"] += 1
    elif frames / SAMPLE_RATE > MAX_SECS:
        # Excluding very long samples to keep a reasonable batch-size
        counter["too_long"] += 1
    else:
        # This one is good - keep it for the target CSV
        rows.append((output_filename, file_size, label, sample_id))
        counter["imported_time"] += frames
    counter["all"] += 1
    counter["total_time"] += frames
    return counter, rows
def convert_transcript(transcript):
    """Normalise a raw transcript taken from a ``.data`` list.

    ASCII apostrophes become the typographic apostrophe ``’`` and
    surrounding whitespace is trimmed; hyphens are kept as they are.
    """
    apostrophe_map = {ord("'"): "’"}
    return transcript.translate(apostrophe_map).strip()
def _read_transcripts(data_path, dataset_dir):
    """Parse one ``.data`` list into ``{audio_path: transcript}``.

    Every non-blank line is ``<audio path> <transcript>``. A `` wav ``
    separator is read as a ``.wav`` extension, names without any dot get
    ``.wav`` appended, one leading ``/`` is dropped, and the path is joined
    onto ``dataset_dir``. Lines without a space-separated transcript are
    skipped.
    """
    transcripts = {}
    # NOTE(review): the lists are assumed to be UTF-8 (Ukrainian text);
    # relying on the locale default fails on non-UTF-8 systems.
    with open(data_path, mode="r", encoding="utf-8") as data_file:
        for row in data_file:
            if row.isspace():
                continue
            parts = row.replace("\n", "").replace(" wav ", ".wav ").split(" ", 1)
            if len(parts) != 2:
                continue
            file_name, transcript = parts
            if "." not in file_name:
                file_name += ".wav"
            if file_name.startswith("/"):
                # Keep os.path.join from discarding dataset_dir.
                file_name = file_name[1:]
            full_path = os.path.join(dataset_dir, file_name)
            transcripts[full_path] = convert_transcript(transcript)
    return transcripts


def _collect_samples(dataset_dir):
    """Pair every listed audio file with its transcript.

    For each ``.data`` list found under ``dataset_dir``, the folder
    ``os.path.join(os.path.dirname(<list dir>), "wav")`` is walked and files
    whose path appears in the list are kept. Returns
    ``(audio_path, transcript, sample_id)`` tuples with ids starting at 1.
    """
    transcripts = {}
    for subdir, _dirs, files in os.walk(dataset_dir):
        for name in files:
            if not name.endswith(".data"):
                continue
            listed = _read_transcripts(os.path.join(subdir, name), dataset_dir)
            wav_folder = os.path.join(os.path.dirname(subdir), "wav")
            for wav_subdir, _wav_dirs, wav_files in os.walk(wav_folder):
                for wav_file in wav_files:
                    wav_path = os.path.join(wav_subdir, wav_file)
                    if wav_path in listed:
                        transcripts[wav_path] = listed[wav_path]
    # Assign ids in sorted path order. os.walk order depends on the
    # filesystem, and _maybe_convert_wav reuses an existing "<id>.wav", so
    # unstable ids could pair old audio with a new transcript on a rerun.
    return [
        (path, transcripts[path], sample_id)
        for sample_id, path in enumerate(sorted(transcripts), start=1)
    ]


def _write_csv(output_csv, rows, space_after_every_character):
    """Write imported ``rows`` to ``output_csv`` with the FIELDNAMES columns."""
    with open(output_csv, "w", encoding="utf-8", newline="") as output_csv_file:
        print("Writing CSV file for DeepSpeech.py as: ", output_csv)
        writer = csv.DictWriter(output_csv_file, fieldnames=FIELDNAMES)
        writer.writeheader()
        bar = progressbar.ProgressBar(max_value=len(rows), widgets=SIMPLE_BAR)
        for filename, file_size, transcript, _sample_id in bar(rows):
            if space_after_every_character:
                transcript = " ".join(transcript)
            writer.writerow(
                {
                    "wav_filename": filename,
                    "wav_filesize": file_size,
                    "transcript": transcript,
                }
            )


def _maybe_convert_set(dataset_dir, audio_dir, filter_obj, space_after_every_character=None, rows=None):
    """Convert every listed clip under ``dataset_dir`` and write ``<audio_dir>/train.csv``.

    Args:
        dataset_dir: root searched recursively for ``.data`` transcript lists.
        audio_dir: directory that receives ``train.csv``.
        filter_obj: unused; kept for backward compatibility (each worker
            builds its own filter in init_worker).
        space_after_every_character: if truthy, transcripts are written with
            a space between every character.
        rows: optional list that is extended in place with the imported rows.

    Returns:
        The list of imported ``(wav_filename, wav_filesize, transcript,
        sample_id)`` rows.
    """
    samples = _collect_samples(dataset_dir)
    if rows is None:
        rows = []
    counter = get_counter()
    num_samples = len(samples)

    print("Importing dataset files...")
    pool = Pool(initializer=init_worker, initargs=(PARAMS,))
    bar = progressbar.ProgressBar(max_value=num_samples, widgets=SIMPLE_BAR)
    results = pool.imap_unordered(one_sample, samples)
    for i, (sample_counter, sample_rows) in enumerate(results, start=1):
        counter += sample_counter
        rows += sample_rows
        bar.update(i)
    bar.update(num_samples)
    pool.close()
    pool.join()

    imported_samples = get_imported_samples(counter)
    assert counter["all"] == num_samples
    assert len(rows) == imported_samples
    print_import_report(counter, SAMPLE_RATE, MAX_SECS)

    output_csv = os.path.join(os.path.abspath(audio_dir), "train.csv")
    print("Saving new DeepSpeech-formatted CSV file to: ", output_csv)
    _write_csv(output_csv, rows, space_after_every_character)
    return rows
def _preprocess_data(tsv_dir, audio_dir, space_after_every_character=False):
    """Import every ``.data`` list under ``tsv_dir`` and write ``<audio_dir>/train.csv``.

    Returns whatever _maybe_convert_set returns (the imported rows).
    """
    # Pass the flag by keyword: positionally it landed in the unused
    # ``filter_obj`` slot, so --space_after_every_character was ignored.
    return _maybe_convert_set(
        tsv_dir,
        audio_dir,
        filter_obj=None,
        space_after_every_character=space_after_every_character,
    )
def _maybe_convert_wav(mp3_filename, wav_filename):
    """Convert ``mp3_filename`` to a SAMPLE_RATE / CHANNELS WAV at ``wav_filename``.

    ``mp3_filename`` is the source audio path (in this importer it may be a
    .wav or .mp3 file). Nothing is done if ``wav_filename`` already exists.
    Conversion errors are reported on stderr but not raised: a missing
    output is later counted as "failed" by one_sample.
    """
    if os.path.exists(wav_filename):
        return
    transformer = sox.Transformer()
    transformer.convert(samplerate=SAMPLE_RATE, n_channels=CHANNELS)
    try:
        transformer.build(mp3_filename, wav_filename)
    except Exception as err:  # pylint: disable=broad-except
        import sys

        print(f"Failed to convert {mp3_filename}: {err}", file=sys.stderr)
        # Remove any partial output; otherwise later runs would reuse it
        # because existing outputs are never reconverted.
        if os.path.exists(wav_filename):
            os.remove(wav_filename)
def parse_args():
    """Build this importer's command-line parser and parse the arguments."""
    # The description and help texts were copied from the CommonVoice
    # importer and described tsv files / input clips this script never uses.
    parser = get_importers_parser(
        description="Import a custom Ukrainian speech dataset "
        "(.data transcript lists) into a DeepSpeech CSV file"
    )
    parser.add_argument(
        "tsv_dir",
        help="Dataset root directory, searched recursively for .data transcript lists",
    )
    parser.add_argument(
        "--audio_dir",
        help='Output directory for converted WAV clips and train.csv - defaults to "<tsv_dir>/clips"',
    )
    parser.add_argument(
        "--filter_alphabet",
        help="Exclude samples with characters not in provided alphabet",
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="Converts diacritic characters to their base ones",
    )
    parser.add_argument(
        "--space_after_every_character",
        action="store_true",
        help="Separate every character of the transcript with a space in the CSV",
    )
    return parser.parse_args()
def main():
    """Resolve the output directory from the parsed options and run the import."""
    tsv_dir = PARAMS.tsv_dir
    audio_dir = PARAMS.audio_dir or os.path.join(tsv_dir, "clips")
    _preprocess_data(tsv_dir, audio_dir, PARAMS.space_after_every_character)


if __name__ == "__main__":
    # PARAMS must be a module global: _maybe_convert_set forwards it to every
    # pool worker through Pool(initargs=(PARAMS,)).
    PARAMS = parse_args()
    main()