Upload 2 files (#2)
Browse files- Upload 2 files (b2049a6a169ec39e4d5e8095d4dbcb51d2913040)
Co-authored-by: Elodie Gauthier <gauthelo@users.noreply.huggingface.co>
- ASR_FLEURS-swahili_hf.yaml +190 -0
- SB_ASR_FLEURS_finetuning.ipynb +689 -0
    	
        ASR_FLEURS-swahili_hf.yaml
    ADDED
    
    | @@ -0,0 +1,190 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Seed needs to be set at top of yaml, before objects with parameters are made
         | 
| 2 | 
            +
            seed: 1987
         | 
| 3 | 
            +
            __set_seed: !apply:torch.manual_seed [!ref <seed>]
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            lang_csv: Swahili
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            output_folder: !ref results/finetune_hubert_ASR_char/<seed>/<lang_csv>
         | 
| 8 | 
            +
            output_wer_folder: !ref <output_folder>/
         | 
| 9 | 
            +
            save_folder: !ref <output_folder>/save
         | 
| 10 | 
            +
            train_log: !ref <output_folder>/train_log.txt
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            # huggingface format 
         | 
| 13 | 
            +
            hubert_hub: Orange/SSA-HuBERT-base-5k
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            hubert_folder: !ref <save_folder>/hubert_checkpoint
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            # Data files
         | 
| 18 | 
            +
            data_folder: !ref PATH_TO_YOUR_FOLDER/data_speechbrain/<lang_csv>
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            ckpt_interval_minutes: 10 # save checkpoint every N min
         | 
| 21 | 
            +
            train_csv: !ref <data_folder>/train.csv
         | 
| 22 | 
            +
            valid_csv: !ref <data_folder>/validation.csv 
         | 
| 23 | 
            +
            test_csv:
         | 
| 24 | 
            +
                    - !ref <data_folder>/test.csv
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            ####################### Training Parameters ####################################
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            number_of_epochs: 10
         | 
| 29 | 
            +
            lr: 0.1
         | 
| 30 | 
            +
            lr_hubert: 0.000005
         | 
| 31 | 
            +
            sorting: ascending
         | 
| 32 | 
            +
            precision: fp32 # bf16, fp16 or fp32
         | 
| 33 | 
            +
            sample_rate: 16000
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            # skip audio file longer than
         | 
| 36 | 
            +
            avoid_if_longer_than: 60
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            batch_size: 2
         | 
| 39 | 
            +
            test_batch_size: 2
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            # Dataloader options
         | 
| 42 | 
            +
            train_dataloader_opts:
         | 
| 43 | 
            +
               batch_size: !ref <batch_size>
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            valid_dataloader_opts:
         | 
| 46 | 
            +
               batch_size: !ref <batch_size>
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            test_dataloader_opts:
         | 
| 49 | 
            +
               batch_size: !ref <test_batch_size>
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            ####################### Model Parameters #######################################
         | 
| 52 | 
            +
            activation: !name:torch.nn.LeakyReLU
         | 
| 53 | 
            +
            dnn_layers: 2
         | 
| 54 | 
            +
            dnn_neurons: 1024
         | 
| 55 | 
            +
            freeze_hubert: False
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            # Outputs
         | 
| 58 | 
            +
            output_neurons: 66  # BPE size, index(blank/eos/bos) = 0
         | 
| 59 | 
            +
            blank_index: 0
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            #
         | 
| 62 | 
            +
            # Functions and classes
         | 
| 63 | 
            +
            #
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            label_encoder: !new:speechbrain.dataio.encoder.CTCTextEncoder
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter
         | 
| 68 | 
            +
               limit: !ref <number_of_epochs>
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            hubert: !new:speechbrain.lobes.models.huggingface_transformers.hubert.HuBERT
         | 
| 71 | 
            +
               source: !ref <hubert_hub>
         | 
| 72 | 
            +
               output_norm: True
         | 
| 73 | 
            +
               freeze: !ref <freeze_hubert>
         | 
| 74 | 
            +
               save_path: !ref <hubert_folder>
         | 
| 75 | 
            +
               
         | 
| 76 | 
            +
            top_lin: !new:speechbrain.lobes.models.VanillaNN.VanillaNN
         | 
| 77 | 
            +
               input_shape: [null, null, 768] # 768 == output of hubert base model
         | 
| 78 | 
            +
               activation: !ref <activation>
         | 
| 79 | 
            +
               dnn_blocks: !ref <dnn_layers>
         | 
| 80 | 
            +
               dnn_neurons: !ref <dnn_neurons>
         | 
| 81 | 
            +
             | 
| 82 | 
            +
            ctc_lin: !new:speechbrain.nnet.linear.Linear
         | 
| 83 | 
            +
               input_size: !ref <dnn_neurons>
         | 
| 84 | 
            +
               n_neurons: !ref <output_neurons>
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            log_softmax: !new:speechbrain.nnet.activations.Softmax
         | 
| 87 | 
            +
               apply_log: True
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
         | 
| 90 | 
            +
               blank_index: !ref <blank_index>
         | 
| 91 | 
            +
             | 
| 92 | 
            +
            modules:
         | 
| 93 | 
            +
               hubert: !ref <hubert>
         | 
| 94 | 
            +
               top_lin: !ref <top_lin>
         | 
| 95 | 
            +
               ctc_lin: !ref <ctc_lin>
         | 
| 96 | 
            +
             | 
| 97 | 
            +
            model: !new:torch.nn.ModuleList
         | 
| 98 | 
            +
               - [!ref <top_lin>, !ref <ctc_lin>]
         | 
| 99 | 
            +
             | 
| 100 | 
            +
            model_opt_class: !name:torch.optim.Adadelta
         | 
| 101 | 
            +
               lr: !ref <lr>
         | 
| 102 | 
            +
               rho: 0.95
         | 
| 103 | 
            +
               eps: 1.e-8
         | 
| 104 | 
            +
             | 
| 105 | 
            +
            hubert_opt_class: !name:torch.optim.Adam
         | 
| 106 | 
            +
               lr: !ref <lr_hubert>
         | 
| 107 | 
            +
             | 
| 108 | 
            +
            lr_annealing_model: !new:speechbrain.nnet.schedulers.NewBobScheduler
         | 
| 109 | 
            +
               initial_value: !ref <lr>
         | 
| 110 | 
            +
               improvement_threshold: 0.0025
         | 
| 111 | 
            +
               annealing_factor: 0.8
         | 
| 112 | 
            +
               patient: 0
         | 
| 113 | 
            +
             | 
| 114 | 
            +
            lr_annealing_hubert: !new:speechbrain.nnet.schedulers.NewBobScheduler
         | 
| 115 | 
            +
               initial_value: !ref <lr_hubert>
         | 
| 116 | 
            +
               improvement_threshold: 0.0025
         | 
| 117 | 
            +
               annealing_factor: 0.9
         | 
| 118 | 
            +
               patient: 0
         | 
| 119 | 
            +
             | 
| 120 | 
            +
            ############################## Augmentations ###################################
         | 
| 121 | 
            +
             | 
| 122 | 
            +
            # Speed perturbation
         | 
| 123 | 
            +
            speed_perturb: !new:speechbrain.augment.time_domain.SpeedPerturb
         | 
| 124 | 
            +
               orig_freq: !ref <sample_rate>
         | 
| 125 | 
            +
               speeds: [95, 100, 105]
         | 
| 126 | 
            +
             | 
| 127 | 
            +
            # Frequency drop: randomly drops a number of frequency bands to zero.
         | 
| 128 | 
            +
            drop_freq: !new:speechbrain.augment.time_domain.DropFreq
         | 
| 129 | 
            +
               drop_freq_low: 0
         | 
| 130 | 
            +
               drop_freq_high: 1
         | 
| 131 | 
            +
               drop_freq_count_low: 1
         | 
| 132 | 
            +
               drop_freq_count_high: 3
         | 
| 133 | 
            +
               drop_freq_width: 0.05
         | 
| 134 | 
            +
             | 
| 135 | 
            +
            # Time drop: randomly drops a number of temporal chunks.
         | 
| 136 | 
            +
            drop_chunk: !new:speechbrain.augment.time_domain.DropChunk
         | 
| 137 | 
            +
               drop_length_low: 1000
         | 
| 138 | 
            +
               drop_length_high: 2000
         | 
| 139 | 
            +
               drop_count_low: 1
         | 
| 140 | 
            +
               drop_count_high: 5
         | 
| 141 | 
            +
             | 
| 142 | 
            +
            # Augmenter: Combines previously defined augmentations to perform data augmentation
         | 
| 143 | 
            +
            wav_augment: !new:speechbrain.augment.augmenter.Augmenter
         | 
| 144 | 
            +
               concat_original: True
         | 
| 145 | 
            +
               min_augmentations: 4
         | 
| 146 | 
            +
               max_augmentations: 4
         | 
| 147 | 
            +
               augment_prob: 1.0
         | 
| 148 | 
            +
               augmentations: [
         | 
| 149 | 
            +
                  !ref <speed_perturb>,
         | 
| 150 | 
            +
                  !ref <drop_freq>,
         | 
| 151 | 
            +
                  !ref <drop_chunk>]
         | 
| 152 | 
            +
             | 
| 153 | 
            +
            ############################## Decoding ########################################
         | 
| 154 | 
            +
             | 
| 155 | 
            +
            # Decoding parameters
         | 
| 156 | 
            +
            test_beam_search:
         | 
| 157 | 
            +
               beam_size: 143
         | 
| 158 | 
            +
               topk: 1
         | 
| 159 | 
            +
               blank_index: !ref <blank_index>
         | 
| 160 | 
            +
               space_token: ' ' # make sure this is the same as the one used in the tokenizer
         | 
| 161 | 
            +
               beam_prune_logp: -12.0
         | 
| 162 | 
            +
               token_prune_min_logp: -1.20
         | 
| 163 | 
            +
               prune_history: True
         | 
| 164 | 
            +
               alpha: 0.8
         | 
| 165 | 
            +
               beta: 1.2
         | 
| 166 | 
            +
               # can be downloaded from here https://www.openslr.org/11/ or trained with kenLM
         | 
| 167 | 
            +
               # It can either be a .bin or .arpa ; note: .arpa is much slower at loading
         | 
| 168 | 
            +
               # If you don't want to use an LM, comment it out or set it to null
         | 
| 169 | 
            +
               kenlm_model_path: null
         | 
| 170 | 
            +
             | 
| 171 | 
            +
            ############################## Logging and Pretrainer ##########################
         | 
| 172 | 
            +
             | 
| 173 | 
            +
            checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
         | 
| 174 | 
            +
               checkpoints_dir: !ref <save_folder>
         | 
| 175 | 
            +
               recoverables:
         | 
| 176 | 
            +
                  hubert: !ref <hubert>
         | 
| 177 | 
            +
                  model: !ref <model>
         | 
| 178 | 
            +
                  scheduler_model: !ref <lr_annealing_model>
         | 
| 179 | 
            +
                  scheduler_hubert: !ref <lr_annealing_hubert>
         | 
| 180 | 
            +
                  counter: !ref <epoch_counter>
         | 
| 181 | 
            +
                  tokenizer: !ref <label_encoder>
         | 
| 182 | 
            +
             | 
| 183 | 
            +
            train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
         | 
| 184 | 
            +
               save_file: !ref <train_log>
         | 
| 185 | 
            +
             | 
| 186 | 
            +
            error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
         | 
| 187 | 
            +
             | 
| 188 | 
            +
            cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
         | 
| 189 | 
            +
               split_tokens: True
         | 
| 190 | 
            +
             | 
    	
        SB_ASR_FLEURS_finetuning.ipynb
    ADDED
    
    | @@ -0,0 +1,689 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
             "cells": [
         | 
| 3 | 
            +
              {
         | 
| 4 | 
            +
               "cell_type": "markdown",
         | 
| 5 | 
            +
               "id": "49b85514-0fb6-49c6-be76-259bfeb638c6",
         | 
| 6 | 
            +
               "metadata": {},
         | 
| 7 | 
            +
               "source": [
         | 
| 8 | 
            +
                "# Introduction\n",
         | 
| 9 | 
            +
                "N'hésitez pas à nous contacter en cas de questions : antoine.caubriere@orange.com & elodie.gauthier@orange.com\n",
         | 
| 10 | 
            +
                "\n",
         | 
| 11 | 
            +
                "Pensez à modifier l'ensemble des PATH dans le fichier de configuration ASR_FLEURSswahili_hf.yaml et dans le code python ci-dessous (PATH_TO_YOUR_FOLDER).\n",
         | 
| 12 | 
            +
                "\n",
         | 
| 13 | 
            +
                "Dans le cas d'un changement de corpus (autre sous partie de FLEURS / vos propres jeux de données), pensez à modifier la taille de la couche de sortie du modèle : ASR_swahili_hf.yaml/output_neurons\n"
         | 
| 14 | 
            +
               ]
         | 
| 15 | 
            +
              },
         | 
| 16 | 
            +
              {
         | 
| 17 | 
            +
               "cell_type": "markdown",
         | 
| 18 | 
            +
               "id": "e62faa86-911a-48ce-82bc-8a34e13ffbc4",
         | 
| 19 | 
            +
               "metadata": {},
         | 
| 20 | 
            +
               "source": [
         | 
| 21 | 
            +
                "# Préparation des données FLEURS"
         | 
| 22 | 
            +
               ]
         | 
| 23 | 
            +
              },
         | 
| 24 | 
            +
              {
         | 
| 25 | 
            +
               "cell_type": "markdown",
         | 
| 26 | 
            +
               "id": "c6ccf4a5-cad1-4632-8954-f4e454ff3540",
         | 
| 27 | 
            +
               "metadata": {},
         | 
| 28 | 
            +
               "source": [
         | 
| 29 | 
            +
                "### 1. Installation des dépendances"
         | 
| 30 | 
            +
               ]
         | 
| 31 | 
            +
              },
         | 
| 32 | 
            +
              {
         | 
| 33 | 
            +
               "cell_type": "code",
         | 
| 34 | 
            +
               "execution_count": null,
         | 
| 35 | 
            +
               "id": "7bb8b44e-826f-4f13-b128-eebbd18dedc5",
         | 
| 36 | 
            +
               "metadata": {
         | 
| 37 | 
            +
                "jupyter": {
         | 
| 38 | 
            +
                 "source_hidden": true
         | 
| 39 | 
            +
                }
         | 
| 40 | 
            +
               },
         | 
| 41 | 
            +
               "outputs": [],
         | 
| 42 | 
            +
               "source": [
         | 
| 43 | 
            +
                "pip install datasets librosa soundfile"
         | 
| 44 | 
            +
               ]
         | 
| 45 | 
            +
              },
         | 
| 46 | 
            +
              {
         | 
| 47 | 
            +
               "cell_type": "markdown",
         | 
| 48 | 
            +
               "id": "016d7646-bcca-4422-8b28-9d12d4b86c8f",
         | 
| 49 | 
            +
               "metadata": {},
         | 
| 50 | 
            +
               "source": [
         | 
| 51 | 
            +
                "### 2. Téléchargement et formatage du dataset"
         | 
| 52 | 
            +
               ]
         | 
| 53 | 
            +
              },
         | 
| 54 | 
            +
              {
         | 
| 55 | 
            +
               "cell_type": "code",
         | 
| 56 | 
            +
               "execution_count": null,
         | 
| 57 | 
            +
               "id": "da273973-05ee-4de5-830e-34d7f2220353",
         | 
| 58 | 
            +
               "metadata": {},
         | 
| 59 | 
            +
               "outputs": [],
         | 
| 60 | 
            +
               "source": [
         | 
| 61 | 
            +
                "from datasets import load_dataset\n",
         | 
| 62 | 
            +
                "from pathlib import Path\n",
         | 
| 63 | 
            +
                "from collections import OrderedDict\n",
         | 
| 64 | 
            +
                "from tqdm import tqdm\n",
         | 
| 65 | 
            +
                "import shutil\n",
         | 
| 66 | 
            +
                "import os\n",
         | 
| 67 | 
            +
                "\n",
         | 
| 68 | 
            +
                "dataset_write_base = \"PATH_TO_YOUR_FOLDER/data_speechbrain/\"\n",
         | 
| 69 | 
            +
                "cache_dir = \"PATH_TO_YOUR_FOLDER/data_huggingface/\"\n",
         | 
| 70 | 
            +
                "\n",
         | 
| 71 | 
            +
                "if os.path.isdir(cache_dir):\n",
         | 
| 72 | 
            +
                "    print(\"rm -rf \"+cache_dir)\n",
         | 
| 73 | 
            +
                "    os.system(\"rm -rf \"+cache_dir)\n",
         | 
| 74 | 
            +
                "\n",
         | 
| 75 | 
            +
                "if os.path.isdir(dataset_write_base):\n",
         | 
| 76 | 
            +
                "    print(\"rm -rf \"+dataset_write_base)\n",
         | 
| 77 | 
            +
                "    os.system(\"rm -rf \"+dataset_write_base)\n",
         | 
| 78 | 
            +
                "\n",
         | 
| 79 | 
            +
                "# **************************************\n",
         | 
| 80 | 
            +
                "# choix des langues à extraire de FLEURS\n",
         | 
| 81 | 
            +
                "# **************************************\n",
         | 
| 82 | 
            +
                "lang_dict = OrderedDict([\n",
         | 
| 83 | 
            +
                "    #(\"Afrikaans\",\"af_za\"),\n",
         | 
| 84 | 
            +
                "    #(\"Amharic\", \"am_et\"),\n",
         | 
| 85 | 
            +
                "    #(\"Fula\", \"ff_sn\"),\n",
         | 
| 86 | 
            +
                "    #(\"Ganda\", \"lg_ug\"),\n",
         | 
| 87 | 
            +
                "    #(\"Hausa\", \"ha_ng\"),\n",
         | 
| 88 | 
            +
                "    #(\"Igbo\", \"ig_ng\"),\n",
         | 
| 89 | 
            +
                "    #(\"Kamba\", \"kam_ke\"),\n",
         | 
| 90 | 
            +
                "    #(\"Lingala\", \"ln_cd\"),\n",
         | 
| 91 | 
            +
                "    #(\"Luo\", \"luo_ke\"),\n",
         | 
| 92 | 
            +
                "    #(\"Northern-Sotho\", \"nso_za\"),\n",
         | 
| 93 | 
            +
                "    #(\"Nyanja\", \"ny_mw\"),\n",
         | 
| 94 | 
            +
                "    #(\"Oromo\", \"om_et\"),\n",
         | 
| 95 | 
            +
                "    #(\"Shona\", \"sn_zw\"),\n",
         | 
| 96 | 
            +
                "    #(\"Somali\", \"so_so\"),\n",
         | 
| 97 | 
            +
                "    (\"Swahili\", \"sw_ke\"),\n",
         | 
| 98 | 
            +
                "    #(\"Umbundu\", \"umb_ao\"),\n",
         | 
| 99 | 
            +
                "    #(\"Wolof\", \"wo_sn\"), \n",
         | 
| 100 | 
            +
                "    #(\"Xhosa\", \"xh_za\"), \n",
         | 
| 101 | 
            +
                "    #(\"Yoruba\", \"yo_ng\"), \n",
         | 
| 102 | 
            +
                "    #(\"Zulu\", \"zu_za\")\n",
         | 
| 103 | 
            +
                "    ])\n",
         | 
| 104 | 
            +
                "\n",
         | 
| 105 | 
            +
                "# ********************************\n",
         | 
| 106 | 
            +
                "# choix des sous-parties à traiter\n",
         | 
| 107 | 
            +
                "# ********************************\n",
         | 
| 108 | 
            +
                "datasets = [\"train\",\"test\",\"validation\"]\n",
         | 
| 109 | 
            +
                "\n",
         | 
| 110 | 
            +
                "for lang in lang_dict:\n",
         | 
| 111 | 
            +
                "    print(\"Prepare --->\", lang)\n",
         | 
| 112 | 
            +
                "    \n",
         | 
| 113 | 
            +
                "    # ********************************\n",
         | 
| 114 | 
            +
                "    # Download FLEURS from huggingface\n",
         | 
| 115 | 
            +
                "    # ********************************\n",
         | 
| 116 | 
            +
                "    fleurs_asr = load_dataset(\"google/fleurs\", lang_dict[lang],cache_dir=cache_dir, trust_remote_code=True)\n",
         | 
| 117 | 
            +
                "\n",
         | 
| 118 | 
            +
                "    for subparts in datasets:\n",
         | 
| 119 | 
            +
                "        \n",
         | 
| 120 | 
            +
                "        used_ID = []\n",
         | 
| 121 | 
            +
                "        Path(dataset_write_base+\"/\"+lang+\"/wavs/\"+subparts).mkdir(parents=True, exist_ok=True)\n",
         | 
| 122 | 
            +
                "        \n",
         | 
| 123 | 
            +
                "        # csv header\n",
         | 
| 124 | 
            +
                "        f = open(dataset_write_base+\"/\"+lang+\"/\"+subparts+\".csv\", \"w\")\n",
         | 
| 125 | 
            +
                "        f.write(\"ID,duration,wav,spk_id,wrd\\n\")\n",
         | 
| 126 | 
            +
                "\n",
         | 
| 127 | 
            +
                "        for uid in tqdm(range(len(fleurs_asr[subparts]))):\n",
         | 
| 128 | 
            +
                "\n",
         | 
| 129 | 
            +
                "            # ***************\n",
         | 
| 130 | 
            +
                "            # format CSV line\n",
         | 
| 131 | 
            +
                "            # ***************\n",
         | 
| 132 | 
            +
                "            text_id = lang+\"_\"+str(fleurs_asr[subparts][uid][\"id\"])\n",
         | 
| 133 | 
            +
                "            \n",
         | 
| 134 | 
            +
                "            # some ID are duplicated (same speaker, same transcription BUT different recording)\n",
         | 
| 135 | 
            +
                "            while(text_id in used_ID):\n",
         | 
| 136 | 
            +
                "                text_id += \"_bis\"\n",
         | 
| 137 | 
            +
                "            used_ID.append(text_id)\n",
         | 
| 138 | 
            +
                "\n",
         | 
| 139 | 
            +
                "            duration = \"{:.3f}\".format(round(float(fleurs_asr[subparts][uid][\"num_samples\"])/float(fleurs_asr[subparts][uid][\"audio\"][\"sampling_rate\"]),3))\n",
         | 
| 140 | 
            +
                "            wav_path = \"/\".join([dataset_write_base, lang, \"wavs\",subparts, fleurs_asr[subparts][uid][\"audio\"][\"path\"].split('/')[-1]])\n",
         | 
| 141 | 
            +
                "            spk_id = \"spk_\" + text_id\n",
         | 
| 142 | 
            +
                "            # AC : \"pseudo-normalisation\" de cas marginaux -- TODO mieux\n",
         | 
| 143 | 
            +
                "            wrd = fleurs_asr[subparts][uid][\"transcription\"].replace(',','').replace('$',' $ ').replace('\"','').replace('”','').replace('  ',' ')\n",
         | 
| 144 | 
            +
                "\n",
         | 
| 145 | 
            +
                "            # **************\n",
         | 
| 146 | 
            +
                "            # write CSV line\n",
         | 
| 147 | 
            +
                "            # **************\n",
         | 
| 148 | 
            +
                "            f.write(text_id+\",\"+duration+\",\"+wav_path+\",\"+spk_id+\",\"+wrd+\"\\n\") \n",
         | 
| 149 | 
            +
                "\n",
         | 
| 150 | 
            +
                "            # *******************\n",
         | 
| 151 | 
            +
                "            # Move wav from cache\n",
         | 
| 152 | 
            +
                "            # *******************\n",
         | 
| 153 | 
            +
                "            previous_path = \"/\".join(fleurs_asr[subparts][uid][\"path\"].split('/')[:-1]) + \"/\" + fleurs_asr[subparts][uid][\"audio\"][\"path\"]\n",
         | 
| 154 | 
            +
                "            new_path = \"/\".join([dataset_write_base,lang,\"wavs\",subparts,fleurs_asr[subparts][uid][\"audio\"][\"path\"].split('/')[-1]])\n",
         | 
| 155 | 
            +
                "            shutil.move(previous_path,new_path)\n",
         | 
| 156 | 
            +
                "        \n",
         | 
| 157 | 
            +
                "        f.close()\n",
         | 
| 158 | 
            +
                "    print(\"--->\", lang, \"done\")"
         | 
| 159 | 
            +
               ]
         | 
| 160 | 
            +
              },
         | 
| 161 | 
            +
              {
         | 
| 162 | 
            +
               "cell_type": "markdown",
         | 
| 163 | 
            +
               "id": "4c32e369-f0f9-4695-8c9a-aa3a9de7bf7b",
         | 
| 164 | 
            +
               "metadata": {},
         | 
| 165 | 
            +
               "source": [
         | 
| 166 | 
            +
                "# Recette ASR"
         | 
| 167 | 
            +
               ]
         | 
| 168 | 
            +
              },
         | 
| 169 | 
            +
              {
         | 
| 170 | 
            +
               "cell_type": "markdown",
         | 
| 171 | 
            +
               "id": "77fb2c55-3f8c-4f34-81f0-ad48a632e010",
         | 
| 172 | 
            +
               "metadata": {
         | 
| 173 | 
            +
                "jp-MarkdownHeadingCollapsed": true
         | 
| 174 | 
            +
               },
         | 
| 175 | 
            +
               "source": [
         | 
| 176 | 
            +
                "## 1. Installation des dépendances"
         | 
| 177 | 
            +
               ]
         | 
| 178 | 
            +
              },
         | 
| 179 | 
            +
              {
         | 
| 180 | 
            +
               "cell_type": "code",
         | 
| 181 | 
            +
               "execution_count": null,
         | 
| 182 | 
            +
               "id": "fbe25635-e765-480c-8416-c48a31ee6140",
         | 
| 183 | 
            +
               "metadata": {},
         | 
| 184 | 
            +
               "outputs": [],
         | 
| 185 | 
            +
               "source": [
         | 
| 186 | 
            +
                "pip install torch==2.2.2 torchaudio==2.2.2 torchvision==0.17.2 speechbrain transformers jdc"
         | 
| 187 | 
            +
               ]
         | 
| 188 | 
            +
              },
         | 
| 189 | 
            +
              {
         | 
| 190 | 
            +
               "cell_type": "markdown",
         | 
| 191 | 
            +
               "id": "6acf1f8c-2cf3-4c9c-8a45-e2580ecbee27",
         | 
| 192 | 
            +
               "metadata": {},
         | 
| 193 | 
            +
               "source": [
         | 
| 194 | 
            +
                "## 2. Mise en place de la recette Speechbrain -- class Brain"
         | 
| 195 | 
            +
               ]
         | 
| 196 | 
            +
              },
         | 
| 197 | 
            +
              {
         | 
| 198 | 
            +
               "cell_type": "markdown",
         | 
| 199 | 
            +
               "id": "d5e8884d-3542-40ff-a454-597078fcf97c",
         | 
| 200 | 
            +
               "metadata": {},
         | 
| 201 | 
            +
               "source": [
         | 
| 202 | 
            +
                "### 2.1 Imports & logger"
         | 
| 203 | 
            +
               ]
         | 
| 204 | 
            +
              },
         | 
| 205 | 
            +
              {
         | 
| 206 | 
            +
               "cell_type": "code",
         | 
| 207 | 
            +
               "execution_count": null,
         | 
| 208 | 
            +
               "id": "6c677f9f-6abe-423f-b4dd-fdf5ded357cd",
         | 
| 209 | 
            +
               "metadata": {},
         | 
| 210 | 
            +
               "outputs": [],
         | 
| 211 | 
            +
               "source": [
         | 
| 212 | 
            +
                "import logging\n",
         | 
| 213 | 
            +
                "import os\n",
         | 
| 214 | 
            +
                "import sys\n",
         | 
| 215 | 
            +
                "from pathlib import Path\n",
         | 
| 216 | 
            +
                "\n",
         | 
| 217 | 
            +
                "import torch\n",
         | 
| 218 | 
            +
                "from hyperpyyaml import load_hyperpyyaml\n",
         | 
| 219 | 
            +
                "\n",
         | 
| 220 | 
            +
                "import speechbrain as sb\n",
         | 
| 221 | 
            +
                "from speechbrain.utils.distributed import if_main_process, run_on_main\n",
         | 
| 222 | 
            +
                "\n",
         | 
| 223 | 
            +
                "import jdc\n",
         | 
| 224 | 
            +
                "\n",
         | 
| 225 | 
            +
                "logger = logging.getLogger(__name__)"
         | 
| 226 | 
            +
               ]
         | 
| 227 | 
            +
              },
         | 
| 228 | 
            +
              {
         | 
| 229 | 
            +
               "cell_type": "markdown",
         | 
| 230 | 
            +
               "id": "9698bb92-16ad-4b61-8938-c74b62ee93b2",
         | 
| 231 | 
            +
               "metadata": {},
         | 
| 232 | 
            +
               "source": [
         | 
| 233 | 
            +
                "### 2.2 Création de notre classe héritant de la classe brain"
         | 
| 234 | 
            +
               ]
         | 
| 235 | 
            +
              },
         | 
| 236 | 
            +
              {
         | 
| 237 | 
            +
               "cell_type": "code",
         | 
| 238 | 
            +
               "execution_count": null,
         | 
| 239 | 
            +
               "id": "7c7cd624-6249-449b-8ee9-d4a73b7b3301",
         | 
| 240 | 
            +
               "metadata": {},
         | 
| 241 | 
            +
               "outputs": [],
         | 
| 242 | 
            +
               "source": [
         | 
| 243 | 
            +
                "# Define training procedure\n",
         | 
| 244 | 
            +
                "class MY_SSA_ASR(sb.Brain):\n",
         | 
| 245 | 
            +
                "    print(\"\")\n",
         | 
| 246 | 
            +
                "    # define here"
         | 
| 247 | 
            +
               ]
         | 
| 248 | 
            +
              },
         | 
| 249 | 
            +
              {
         | 
| 250 | 
            +
               "cell_type": "markdown",
         | 
| 251 | 
            +
               "id": "ecf31c9c-15dd-4428-aa10-b3cc5e127f0d",
         | 
| 252 | 
            +
               "metadata": {},
         | 
| 253 | 
            +
               "source": [
         | 
| 254 | 
            +
                "### 2.3 Définition de la fonction forward "
         | 
| 255 | 
            +
               ]
         | 
| 256 | 
            +
              },
         | 
| 257 | 
            +
              {
         | 
| 258 | 
            +
               "cell_type": "code",
         | 
| 259 | 
            +
               "execution_count": null,
         | 
| 260 | 
            +
               "id": "4368b488-b9d8-49ff-8ce3-78a12d46be83",
         | 
| 261 | 
            +
               "metadata": {},
         | 
| 262 | 
            +
               "outputs": [],
         | 
| 263 | 
            +
               "source": [
         | 
| 264 | 
            +
                "%%add_to MY_SSA_ASR\n",
         | 
| 265 | 
            +
                "def compute_forward(self, batch, stage):\n",
         | 
| 266 | 
            +
                "    \"\"\"Forward computations from the waveform batches to the output probabilities.\"\"\"\n",
         | 
| 267 | 
            +
                "    batch = batch.to(self.device)\n",
         | 
| 268 | 
            +
                "    wavs, wav_lens = batch.sig\n",
         | 
| 269 | 
            +
                "    wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)\n",
         | 
| 270 | 
            +
                "\n",
         | 
| 271 | 
            +
                "    # Downsample the inputs if specified\n",
         | 
| 272 | 
            +
                "    if hasattr(self.modules, \"downsampler\"):\n",
         | 
| 273 | 
            +
                "        wavs = self.modules.downsampler(wavs)\n",
         | 
| 274 | 
            +
                "\n",
         | 
| 275 | 
            +
                "    # Add waveform augmentation if specified.\n",
         | 
| 276 | 
            +
                "    if stage == sb.Stage.TRAIN and hasattr(self.hparams, \"wav_augment\"):\n",
         | 
| 277 | 
            +
                "        wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)\n",
         | 
| 278 | 
            +
                "\n",
         | 
| 279 | 
            +
                "    # Forward pass\n",
         | 
| 280 | 
            +
                "    feats = self.modules.hubert(wavs, wav_lens)\n",
         | 
| 281 | 
            +
                "    x = self.modules.top_lin(feats)\n",
         | 
| 282 | 
            +
                "\n",
         | 
| 283 | 
            +
                "    # Compute outputs\n",
         | 
| 284 | 
            +
                "    logits = self.modules.ctc_lin(x)\n",
         | 
| 285 | 
            +
                "    p_ctc = self.hparams.log_softmax(logits)\n",
         | 
| 286 | 
            +
                "\n",
         | 
| 287 | 
            +
                "\n",
         | 
| 288 | 
            +
                "    p_tokens = None\n",
         | 
| 289 | 
            +
                "    if stage == sb.Stage.VALID:\n",
         | 
| 290 | 
            +
                "        p_tokens = sb.decoders.ctc_greedy_decode(p_ctc, wav_lens, blank_id=self.hparams.blank_index)\n",
         | 
| 291 | 
            +
                "\n",
         | 
| 292 | 
            +
                "    elif stage == sb.Stage.TEST:\n",
         | 
| 293 | 
            +
                "        p_tokens = test_searcher(p_ctc, wav_lens)\n",
         | 
| 294 | 
            +
                "\n",
         | 
| 295 | 
            +
                "        candidates = []\n",
         | 
| 296 | 
            +
                "        scores = []\n",
         | 
| 297 | 
            +
                "\n",
         | 
| 298 | 
            +
                "        for batch in p_tokens:\n",
         | 
| 299 | 
            +
                "            candidates.append([hyp.text for hyp in batch])\n",
         | 
| 300 | 
            +
                "            scores.append([hyp.score for hyp in batch])\n",
         | 
| 301 | 
            +
                "\n",
         | 
| 302 | 
            +
                "        if hasattr(self.hparams, \"rescorer\"):\n",
         | 
| 303 | 
            +
                "            p_tokens, _ = self.hparams.rescorer.rescore(candidates, scores)\n",
         | 
| 304 | 
            +
                "\n",
         | 
| 305 | 
            +
                "    return p_ctc, wav_lens, p_tokens\n"
         | 
| 306 | 
            +
               ]
         | 
| 307 | 
            +
              },
         | 
| 308 | 
            +
              {
         | 
| 309 | 
            +
               "cell_type": "markdown",
         | 
| 310 | 
            +
               "id": "f0052b79-5a27-4c4c-8601-7ab064e8c951",
         | 
| 311 | 
            +
               "metadata": {},
         | 
| 312 | 
            +
               "source": [
         | 
| 313 | 
            +
                "### 2.4 Définition de la fonction objectives"
         | 
| 314 | 
            +
               ]
         | 
| 315 | 
            +
              },
         | 
| 316 | 
            +
              {
         | 
| 317 | 
            +
               "cell_type": "code",
         | 
| 318 | 
            +
               "execution_count": null,
         | 
| 319 | 
            +
               "id": "3608aee8-c9c3-4e34-98bc-667513fa7f7b",
         | 
| 320 | 
            +
               "metadata": {},
         | 
| 321 | 
            +
               "outputs": [],
         | 
| 322 | 
            +
               "source": [
         | 
| 323 | 
            +
                "%%add_to MY_SSA_ASR\n",
         | 
| 324 | 
            +
                "def compute_objectives(self, predictions, batch, stage):\n",
         | 
| 325 | 
            +
                "    \"\"\"Computes the loss (CTC+NLL) given predictions and targets.\"\"\"\n",
         | 
| 326 | 
            +
                "\n",
         | 
| 327 | 
            +
                "    p_ctc, wav_lens, predicted_tokens = predictions\n",
         | 
| 328 | 
            +
                "\n",
         | 
| 329 | 
            +
                "    ids = batch.id\n",
         | 
| 330 | 
            +
                "    tokens, tokens_lens = batch.tokens\n",
         | 
| 331 | 
            +
                "\n",
         | 
| 332 | 
            +
                "    # Labels must be extended if parallel augmentation or concatenated\n",
         | 
| 333 | 
            +
                "    # augmentation was performed on the input (increasing the time dimension)\n",
         | 
| 334 | 
            +
                "    if stage == sb.Stage.TRAIN and hasattr(self.hparams, \"wav_augment\"):\n",
         | 
| 335 | 
            +
                "        (tokens, tokens_lens) = self.hparams.wav_augment.replicate_multiple_labels(tokens, tokens_lens)\n",
         | 
| 336 | 
            +
                "\n",
         | 
| 337 | 
            +
                "\n",
         | 
| 338 | 
            +
                "\n",
         | 
| 339 | 
            +
                "    # Compute loss\n",
         | 
| 340 | 
            +
                "    loss = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)\n",
         | 
| 341 | 
            +
                "\n",
         | 
| 342 | 
            +
                "    if stage == sb.Stage.VALID:\n",
         | 
| 343 | 
            +
                "        # Decode token terms to words\n",
         | 
| 344 | 
            +
                "        predicted_words = [\"\".join(self.tokenizer.decode_ndim(utt_seq)).split(\" \") for utt_seq in predicted_tokens]\n",
         | 
| 345 | 
            +
                "        \n",
         | 
| 346 | 
            +
                "    elif stage == sb.Stage.TEST:\n",
         | 
| 347 | 
            +
                "        predicted_words = [hyp[0].text.split(\" \") for hyp in predicted_tokens]\n",
         | 
| 348 | 
            +
                "\n",
         | 
| 349 | 
            +
                "    if stage != sb.Stage.TRAIN:\n",
         | 
| 350 | 
            +
                "        target_words = [wrd.split(\" \") for wrd in batch.wrd]\n",
         | 
| 351 | 
            +
                "        self.wer_metric.append(ids, predicted_words, target_words)\n",
         | 
| 352 | 
            +
                "        self.cer_metric.append(ids, predicted_words, target_words)\n",
         | 
| 353 | 
            +
                "\n",
         | 
| 354 | 
            +
                "    return loss\n"
         | 
| 355 | 
            +
               ]
         | 
| 356 | 
            +
              },
         | 
| 357 | 
            +
              {
         | 
| 358 | 
            +
               "cell_type": "markdown",
         | 
| 359 | 
            +
               "id": "9a514c50-89ad-41cb-882a-23daf829a538",
         | 
| 360 | 
            +
               "metadata": {},
         | 
| 361 | 
            +
               "source": [
         | 
| 362 | 
            +
                "### 2.5 définition du comportement au début d'un \"stage\""
         | 
| 363 | 
            +
               ]
         | 
| 364 | 
            +
              },
         | 
| 365 | 
            +
              {
         | 
| 366 | 
            +
               "cell_type": "code",
         | 
| 367 | 
            +
               "execution_count": null,
         | 
| 368 | 
            +
               "id": "609814ce-3ef0-4818-a70f-cadc293c9dd2",
         | 
| 369 | 
            +
               "metadata": {},
         | 
| 370 | 
            +
               "outputs": [],
         | 
| 371 | 
            +
               "source": [
         | 
| 372 | 
            +
                "%%add_to MY_SSA_ASR\n",
         | 
| 373 | 
            +
                "# stage gestion\n",
         | 
| 374 | 
            +
                "def on_stage_start(self, stage, epoch):\n",
         | 
| 375 | 
            +
                "    \"\"\"Gets called at the beginning of each epoch\"\"\"\n",
         | 
| 376 | 
            +
                "    if stage != sb.Stage.TRAIN:\n",
         | 
| 377 | 
            +
                "        self.cer_metric = self.hparams.cer_computer()\n",
         | 
| 378 | 
            +
                "        self.wer_metric = self.hparams.error_rate_computer()\n",
         | 
| 379 | 
            +
                "\n",
         | 
| 380 | 
            +
                "    if stage == sb.Stage.TEST:\n",
         | 
| 381 | 
            +
                "        if hasattr(self.hparams, \"rescorer\"):\n",
         | 
| 382 | 
            +
                "            self.hparams.rescorer.move_rescorers_to_device()\n",
         | 
| 383 | 
            +
                "\n"
         | 
| 384 | 
            +
               ]
         | 
| 385 | 
            +
              },
         | 
| 386 | 
            +
              {
         | 
| 387 | 
            +
               "cell_type": "markdown",
         | 
| 388 | 
            +
               "id": "55929209-c94a-4f8b-8f2e-9dd5d9de8be9",
         | 
| 389 | 
            +
               "metadata": {},
         | 
| 390 | 
            +
               "source": [
         | 
| 391 | 
            +
                "### 2.6 définition du comportement à la fin d'un \"stage\""
         | 
| 392 | 
            +
               ]
         | 
| 393 | 
            +
              },
         | 
| 394 | 
            +
              {
         | 
| 395 | 
            +
               "cell_type": "code",
         | 
| 396 | 
            +
               "execution_count": null,
         | 
| 397 | 
            +
               "id": "8f297542-10d5-47bf-9938-c141f5a99ab8",
         | 
| 398 | 
            +
               "metadata": {},
         | 
| 399 | 
            +
               "outputs": [],
         | 
| 400 | 
            +
               "source": [
         | 
| 401 | 
            +
                "%%add_to MY_SSA_ASR\n",
         | 
| 402 | 
            +
                "def on_stage_end(self, stage, stage_loss, epoch):\n",
         | 
| 403 | 
            +
                "    \"\"\"Gets called at the end of an epoch.\"\"\"\n",
         | 
| 404 | 
            +
                "    # Compute/store important stats\n",
         | 
| 405 | 
            +
                "    stage_stats = {\"loss\": stage_loss}\n",
         | 
| 406 | 
            +
                "    if stage == sb.Stage.TRAIN:\n",
         | 
| 407 | 
            +
                "        self.train_stats = stage_stats\n",
         | 
| 408 | 
            +
                "    else:\n",
         | 
| 409 | 
            +
                "        stage_stats[\"CER\"] = self.cer_metric.summarize(\"error_rate\")\n",
         | 
| 410 | 
            +
                "        stage_stats[\"WER\"] = self.wer_metric.summarize(\"error_rate\")\n",
         | 
| 411 | 
            +
                "\n",
         | 
| 412 | 
            +
                "    # Perform end-of-iteration things, like annealing, logging, etc.\n",
         | 
| 413 | 
            +
                "    if stage == sb.Stage.VALID:\n",
         | 
| 414 | 
            +
                "        # *******************************\n",
         | 
| 415 | 
            +
                "        # Anneal and update Learning Rate\n",
         | 
| 416 | 
            +
                "        # *******************************\n",
         | 
| 417 | 
            +
                "        old_lr_model, new_lr_model = self.hparams.lr_annealing_model(stage_stats[\"loss\"])\n",
         | 
| 418 | 
            +
                "        old_lr_hubert, new_lr_hubert = self.hparams.lr_annealing_hubert(stage_stats[\"loss\"])\n",
         | 
| 419 | 
            +
                "        sb.nnet.schedulers.update_learning_rate(self.model_optimizer, new_lr_model)\n",
         | 
| 420 | 
            +
                "        sb.nnet.schedulers.update_learning_rate(self.hubert_optimizer, new_lr_hubert)\n",
         | 
| 421 | 
            +
                "\n",
         | 
| 422 | 
            +
                "        # *****************\n",
         | 
| 423 | 
            +
                "        # Logs informations\n",
         | 
| 424 | 
            +
                "        # *****************\n",
         | 
| 425 | 
            +
                "        self.hparams.train_logger.log_stats(stats_meta={\"epoch\": epoch, \"lr_model\": old_lr_model, \"lr_hubert\": old_lr_hubert}, train_stats=self.train_stats, valid_stats=stage_stats)\n",
         | 
| 426 | 
            +
                "\n",
         | 
| 427 | 
            +
                "        # ***************\n",
         | 
| 428 | 
            +
                "        # Save checkpoint\n",
         | 
| 429 | 
            +
                "        # ***************\n",
         | 
| 430 | 
            +
                "        self.checkpointer.save_and_keep_only(meta={\"WER\": stage_stats[\"WER\"]},min_keys=[\"WER\"])\n",
         | 
| 431 | 
            +
                "\n",
         | 
| 432 | 
            +
                "    elif stage == sb.Stage.TEST:\n",
         | 
| 433 | 
            +
                "        self.hparams.train_logger.log_stats(stats_meta={\"Epoch loaded\": self.hparams.epoch_counter.current},test_stats=stage_stats)\n",
         | 
| 434 | 
            +
                "        if if_main_process():\n",
         | 
| 435 | 
            +
                "            with open(self.hparams.test_wer_file, \"w\") as w:\n",
         | 
| 436 | 
            +
                "                self.wer_metric.write_stats(w)\n"
         | 
| 437 | 
            +
               ]
         | 
| 438 | 
            +
              },
         | 
| 439 | 
            +
              {
         | 
| 440 | 
            +
               "cell_type": "markdown",
         | 
| 441 | 
            +
               "id": "0c656457-6b61-4316-8199-70021f92babf",
         | 
| 442 | 
            +
               "metadata": {},
         | 
| 443 | 
            +
               "source": [
         | 
| 444 | 
            +
                "### 2.7 définition de l'initialisation des optimizers"
         | 
| 445 | 
            +
               ]
         | 
| 446 | 
            +
              },
         | 
| 447 | 
            +
              {
         | 
| 448 | 
            +
               "cell_type": "code",
         | 
| 449 | 
            +
               "execution_count": null,
         | 
| 450 | 
            +
               "id": "da8d9cb5-c5ad-4e78-83d3-e129e138a741",
         | 
| 451 | 
            +
               "metadata": {},
         | 
| 452 | 
            +
               "outputs": [],
         | 
| 453 | 
            +
               "source": [
         | 
| 454 | 
            +
                "%%add_to MY_SSA_ASR\n",
         | 
| 455 | 
            +
                "def init_optimizers(self):\n",
         | 
| 456 | 
            +
                "    \"Initializes the hubert optimizer and model optimizer\"\n",
         | 
| 457 | 
            +
                "    self.hubert_optimizer = self.hparams.hubert_opt_class(self.modules.hubert.parameters())\n",
         | 
| 458 | 
            +
                "    self.model_optimizer = self.hparams.model_opt_class(self.hparams.model.parameters())\n",
         | 
| 459 | 
            +
                "\n",
         | 
| 460 | 
            +
                "    # save the optimizers in a dictionary\n",
         | 
| 461 | 
            +
                "    # the key will be used in `freeze_optimizers()`\n",
         | 
| 462 | 
            +
                "    self.optimizers_dict = {\"model_optimizer\": self.model_optimizer}\n",
         | 
| 463 | 
            +
                "    if not self.hparams.freeze_hubert:\n",
         | 
| 464 | 
            +
                "        self.optimizers_dict[\"hubert_optimizer\"] = self.hubert_optimizer\n",
         | 
| 465 | 
            +
                "\n",
         | 
| 466 | 
            +
                "    if self.checkpointer is not None:\n",
         | 
| 467 | 
            +
                "        self.checkpointer.add_recoverable(\"hubert_opt\", self.hubert_optimizer)\n",
         | 
| 468 | 
            +
                "        self.checkpointer.add_recoverable(\"model_opt\", self.model_optimizer)\n"
         | 
| 469 | 
            +
               ]
         | 
| 470 | 
            +
              },
         | 
| 471 | 
            +
              {
         | 
| 472 | 
            +
               "cell_type": "markdown",
         | 
| 473 | 
            +
               "id": "cf2e730c-2faa-41f2-b98d-e5fbb2305cc2",
         | 
| 474 | 
            +
               "metadata": {},
         | 
| 475 | 
            +
               "source": [
         | 
| 476 | 
            +
                "## 3 Définition de la lecture des datasets"
         | 
| 477 | 
            +
               ]
         | 
| 478 | 
            +
              },
         | 
| 479 | 
            +
              {
         | 
| 480 | 
            +
               "cell_type": "code",
         | 
| 481 | 
            +
               "execution_count": null,
         | 
| 482 | 
            +
               "id": "c5e667f7-6269-4b49-88bb-5e431762c8fe",
         | 
| 483 | 
            +
               "metadata": {},
         | 
| 484 | 
            +
               "outputs": [],
         | 
| 485 | 
            +
               "source": [
         | 
| 486 | 
            +
                "def dataio_prepare(hparams):\n",
         | 
| 487 | 
            +
                "    \"\"\"This function prepares the datasets to be used in the brain class.\n",
         | 
| 488 | 
            +
                "    It also defines the data processing pipeline through user-defined functions.\n",
         | 
| 489 | 
            +
                "    \"\"\"\n",
         | 
| 490 | 
            +
                "\n",
         | 
| 491 | 
            +
                "    # **************\n",
         | 
| 492 | 
            +
                "    # Load CSV files\n",
         | 
| 493 | 
            +
                "    # **************\n",
         | 
| 494 | 
            +
                "    data_folder = hparams[\"data_folder\"]\n",
         | 
| 495 | 
            +
                "\n",
         | 
| 496 | 
            +
                "    train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=hparams[\"train_csv\"],replacements={\"data_root\": data_folder})\n",
         | 
| 497 | 
            +
                "    # we sort training data to speed up training and get better results.\n",
         | 
| 498 | 
            +
                "    train_data = train_data.filtered_sorted(sort_key=\"duration\")\n",
         | 
| 499 | 
            +
                "    hparams[\"train_dataloader_opts\"][\"shuffle\"] = False # when sorting do not shuffle in dataloader ! otherwise is pointless\n",
         | 
| 500 | 
            +
                "\n",
         | 
| 501 | 
            +
                "    valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=hparams[\"valid_csv\"],replacements={\"data_root\": data_folder})\n",
         | 
| 502 | 
            +
                "    valid_data = valid_data.filtered_sorted(sort_key=\"duration\")\n",
         | 
| 503 | 
            +
                "\n",
         | 
| 504 | 
            +
                "    # test is separate\n",
         | 
| 505 | 
            +
                "    test_datasets = {}\n",
         | 
| 506 | 
            +
                "    for csv_file in hparams[\"test_csv\"]:\n",
         | 
| 507 | 
            +
                "        name = Path(csv_file).stem\n",
         | 
| 508 | 
            +
                "        test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path=csv_file, replacements={\"data_root\": data_folder})\n",
         | 
| 509 | 
            +
                "        test_datasets[name] = test_datasets[name].filtered_sorted(sort_key=\"duration\")\n",
         | 
| 510 | 
            +
                "\n",
         | 
| 511 | 
            +
                "    datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]\n",
         | 
| 512 | 
            +
                "\n",
         | 
| 513 | 
            +
                "    # *************************\n",
         | 
| 514 | 
            +
                "    # 2. Define audio pipeline:\n",
         | 
| 515 | 
            +
                "    # *************************\n",
         | 
| 516 | 
            +
                "    @sb.utils.data_pipeline.takes(\"wav\")\n",
         | 
| 517 | 
            +
                "    @sb.utils.data_pipeline.provides(\"sig\")\n",
         | 
| 518 | 
            +
                "    def audio_pipeline(wav):\n",
         | 
| 519 | 
            +
                "        sig = sb.dataio.dataio.read_audio(wav)\n",
         | 
| 520 | 
            +
                "        return sig\n",
         | 
| 521 | 
            +
                "\n",
         | 
| 522 | 
            +
                "    sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)\n",
         | 
| 523 | 
            +
                "\n",
         | 
| 524 | 
            +
                "    # ************************\n",
         | 
| 525 | 
            +
                "    # 3. Define text pipeline:\n",
         | 
| 526 | 
            +
                "    # ************************\n",
         | 
| 527 | 
            +
                "    label_encoder = sb.dataio.encoder.CTCTextEncoder()\n",
         | 
| 528 | 
            +
                "    \n",
         | 
| 529 | 
            +
                "    @sb.utils.data_pipeline.takes(\"wrd\")\n",
         | 
| 530 | 
            +
                "    @sb.utils.data_pipeline.provides(\"wrd\", \"char_list\", \"tokens_list\", \"tokens\")\n",
         | 
| 531 | 
            +
                "    def text_pipeline(wrd):\n",
         | 
| 532 | 
            +
                "        yield wrd\n",
         | 
| 533 | 
            +
                "        char_list = list(wrd)\n",
         | 
| 534 | 
            +
                "        yield char_list\n",
         | 
| 535 | 
            +
                "        tokens_list = label_encoder.encode_sequence(char_list)\n",
         | 
| 536 | 
            +
                "        yield tokens_list\n",
         | 
| 537 | 
            +
                "        tokens = torch.LongTensor(tokens_list)\n",
         | 
| 538 | 
            +
                "        yield tokens\n",
         | 
| 539 | 
            +
                "\n",
         | 
| 540 | 
            +
                "    sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)\n",
         | 
| 541 | 
            +
                "\n",
         | 
| 542 | 
            +
                "\n",
         | 
| 543 | 
            +
                "    # *******************************\n",
         | 
| 544 | 
            +
                "    # 4. Create or load label encoder\n",
         | 
| 545 | 
            +
                "    # *******************************\n",
         | 
| 546 | 
            +
                "    lab_enc_file = os.path.join(hparams[\"save_folder\"], \"label_encoder.txt\")\n",
         | 
| 547 | 
            +
                "    special_labels = {\"blank_label\": hparams[\"blank_index\"]}\n",
         | 
| 548 | 
            +
                "    label_encoder.add_unk()\n",
         | 
| 549 | 
            +
                "    label_encoder.load_or_create(path=lab_enc_file, from_didatasets=[train_data], output_key=\"char_list\", special_labels=special_labels, sequence_input=True)\n",
         | 
| 550 | 
            +
                "\n",
         | 
| 551 | 
            +
                "    # **************\n",
         | 
| 552 | 
            +
                "    # 5. Set output:\n",
         | 
| 553 | 
            +
                "    # **************\n",
         | 
| 554 | 
            +
                "    sb.dataio.dataset.set_output_keys(datasets,[\"id\", \"sig\", \"wrd\", \"char_list\", \"tokens\"],)\n",
         | 
| 555 | 
            +
                "\n",
         | 
| 556 | 
            +
                "    return train_data, valid_data, test_datasets, label_encoder\n"
         | 
| 557 | 
            +
               ]
         | 
| 558 | 
            +
              },
         | 
| 559 | 
            +
              {
         | 
| 560 | 
            +
               "cell_type": "markdown",
         | 
| 561 | 
            +
               "id": "e97c4f20-6951-4d12-8e17-9eb818a52bb1",
         | 
| 562 | 
            +
               "metadata": {},
         | 
| 563 | 
            +
               "source": [
         | 
| 564 | 
            +
                "## 4. Utilisation de la recette Créée"
         | 
| 565 | 
            +
               ]
         | 
| 566 | 
            +
              },
         | 
| 567 | 
            +
              {
         | 
| 568 | 
            +
               "cell_type": "markdown",
         | 
| 569 | 
            +
               "id": "76b72148-6bd0-48bd-ad40-cb6f8bfd34c0",
         | 
| 570 | 
            +
               "metadata": {},
         | 
| 571 | 
            +
               "source": [
         | 
| 572 | 
            +
                "### 4.1 Préparation au lancement"
         | 
| 573 | 
            +
               ]
         | 
| 574 | 
            +
              },
         | 
| 575 | 
            +
              {
         | 
| 576 | 
            +
               "cell_type": "code",
         | 
| 577 | 
            +
               "execution_count": null,
         | 
| 578 | 
            +
               "id": "d47ec39a-5562-4a63-8243-656c9235b7a2",
         | 
| 579 | 
            +
               "metadata": {},
         | 
| 580 | 
            +
               "outputs": [],
         | 
| 581 | 
            +
               "source": [
         | 
| 582 | 
            +
                "hparams_file, run_opts, overrides = sb.parse_arguments([\"PATH_TO_YOUR_FOLDER/ASR_FLEURS-swahili_hf.yaml\"])\n",
         | 
| 583 | 
            +
                "# create ddp_group with the right communication protocol\n",
         | 
| 584 | 
            +
                "sb.utils.distributed.ddp_init_group(run_opts)\n",
         | 
| 585 | 
            +
                "\n",
         | 
| 586 | 
            +
                "# ***********************************\n",
         | 
| 587 | 
            +
                "# Chargement du fichier de paramètres\n",
         | 
| 588 | 
            +
                "# ***********************************\n",
         | 
| 589 | 
            +
                "with open(hparams_file) as fin:\n",
         | 
| 590 | 
            +
                "    hparams = load_hyperpyyaml(fin, overrides)\n",
         | 
| 591 | 
            +
                "\n",
         | 
| 592 | 
            +
                "# ***************************\n",
         | 
| 593 | 
            +
                "# Create experiment directory\n",
         | 
| 594 | 
            +
                "# ***************************\n",
         | 
| 595 | 
            +
                "sb.create_experiment_directory(experiment_directory=hparams[\"output_folder\"], hyperparams_to_save=hparams_file, overrides=overrides)\n",
         | 
| 596 | 
            +
                "\n",
         | 
| 597 | 
            +
                "# ***************************\n",
         | 
| 598 | 
            +
                "# Create the datasets objects\n",
         | 
| 599 | 
            +
                "# ***************************\n",
         | 
| 600 | 
            +
                "train_data, valid_data, test_datasets, label_encoder = dataio_prepare(hparams)\n",
         | 
| 601 | 
            +
                "\n",
         | 
| 602 | 
            +
                "# **********************\n",
         | 
| 603 | 
            +
                "# Trainer initialization\n",
         | 
| 604 | 
            +
                "# **********************\n",
         | 
| 605 | 
            +
                "asr_brain = MY_SSA_ASR(modules=hparams[\"modules\"], hparams=hparams, run_opts=run_opts, checkpointer=hparams[\"checkpointer\"])\n",
         | 
| 606 | 
            +
                "asr_brain.tokenizer = label_encoder"
         | 
| 607 | 
            +
               ]
         | 
| 608 | 
            +
              },
         | 
| 609 | 
            +
              {
         | 
| 610 | 
            +
               "cell_type": "markdown",
         | 
| 611 | 
            +
               "id": "62ae72eb-416c-4ef0-9348-d02bbc268fbd",
         | 
| 612 | 
            +
               "metadata": {},
         | 
| 613 | 
            +
               "source": [
         | 
| 614 | 
            +
                "### 4.2 Apprentissage du modèle"
         | 
| 615 | 
            +
               ]
         | 
| 616 | 
            +
              },
         | 
| 617 | 
            +
              {
         | 
| 618 | 
            +
               "cell_type": "code",
         | 
| 619 | 
            +
               "execution_count": null,
         | 
| 620 | 
            +
               "id": "d3dd30ee-89c0-40ea-a9d2-0e2b9d8c8686",
         | 
| 621 | 
            +
               "metadata": {},
         | 
| 622 | 
            +
               "outputs": [],
         | 
| 623 | 
            +
               "source": [
         | 
| 624 | 
            +
                "# ********\n",
         | 
| 625 | 
            +
                "# Training\n",
         | 
| 626 | 
            +
                "# ********\n",
         | 
| 627 | 
            +
                "asr_brain.fit(asr_brain.hparams.epoch_counter, \n",
         | 
| 628 | 
            +
                "              train_data, valid_data, \n",
         | 
| 629 | 
            +
                "              train_loader_kwargs=hparams[\"train_dataloader_opts\"], \n",
         | 
| 630 | 
            +
                "              valid_loader_kwargs=hparams[\"valid_dataloader_opts\"],\n",
         | 
| 631 | 
            +
                "             )\n",
         | 
| 632 | 
            +
                "\n"
         | 
| 633 | 
            +
               ]
         | 
| 634 | 
            +
              },
         | 
| 635 | 
            +
              {
         | 
| 636 | 
            +
               "cell_type": "markdown",
         | 
| 637 | 
            +
               "id": "1b55af4c-c544-45ff-8435-58226218328f",
         | 
| 638 | 
            +
               "metadata": {},
         | 
| 639 | 
            +
               "source": [
         | 
| 640 | 
            +
                "### 4.3 Test du Modèle"
         | 
| 641 | 
            +
               ]
         | 
| 642 | 
            +
              },
         | 
| 643 | 
            +
              {
         | 
| 644 | 
            +
               "cell_type": "code",
         | 
| 645 | 
            +
               "execution_count": null,
         | 
| 646 | 
            +
               "id": "9cef9011-1a3e-43a4-ab16-8cfb2b57dbd9",
         | 
| 647 | 
            +
               "metadata": {},
         | 
| 648 | 
            +
               "outputs": [],
         | 
| 649 | 
            +
               "source": [
         | 
| 650 | 
            +
                "# *******\n",
         | 
| 651 | 
            +
                "# Testing\n",
         | 
| 652 | 
            +
                "# *******\n",
         | 
| 653 | 
            +
                "if not os.path.exists(hparams[\"output_wer_folder\"]):\n",
         | 
| 654 | 
            +
                "    os.makedirs(hparams[\"output_wer_folder\"])\n",
         | 
| 655 | 
            +
                "\n",
         | 
| 656 | 
            +
                "from speechbrain.decoders.ctc import CTCBeamSearcher\n",
         | 
| 657 | 
            +
                "\n",
         | 
| 658 | 
            +
                "ind2lab = label_encoder.ind2lab\n",
         | 
| 659 | 
            +
                "vocab_list = [ind2lab[x] for x in range(len(ind2lab))]\n",
         | 
| 660 | 
            +
                "test_searcher = CTCBeamSearcher(**hparams[\"test_beam_search\"], vocab_list=vocab_list)\n",
         | 
| 661 | 
            +
                "\n",
         | 
| 662 | 
            +
                "for k in test_datasets.keys():  # Allow multiple evaluation throught list of test sets\n",
         | 
| 663 | 
            +
                "    asr_brain.hparams.test_wer_file = os.path.join(hparams[\"output_wer_folder\"], f\"wer_{k}.txt\")\n",
         | 
| 664 | 
            +
                "    asr_brain.evaluate(test_datasets[k], test_loader_kwargs=hparams[\"test_dataloader_opts\"], min_key=\"WER\")\n"
         | 
| 665 | 
            +
               ]
         | 
| 666 | 
            +
              }
         | 
| 667 | 
            +
             ],
         | 
| 668 | 
            +
             "metadata": {
         | 
| 669 | 
            +
              "kernelspec": {
         | 
| 670 | 
            +
               "display_name": "Python 3 (ipykernel)",
         | 
| 671 | 
            +
               "language": "python",
         | 
| 672 | 
            +
               "name": "python3"
         | 
| 673 | 
            +
              },
         | 
| 674 | 
            +
              "language_info": {
         | 
| 675 | 
            +
               "codemirror_mode": {
         | 
| 676 | 
            +
                "name": "ipython",
         | 
| 677 | 
            +
                "version": 3
         | 
| 678 | 
            +
               },
         | 
| 679 | 
            +
               "file_extension": ".py",
         | 
| 680 | 
            +
               "mimetype": "text/x-python",
         | 
| 681 | 
            +
               "name": "python",
         | 
| 682 | 
            +
               "nbconvert_exporter": "python",
         | 
| 683 | 
            +
               "pygments_lexer": "ipython3",
         | 
| 684 | 
            +
               "version": "3.10.14"
         | 
| 685 | 
            +
              }
         | 
| 686 | 
            +
             },
         | 
| 687 | 
            +
             "nbformat": 4,
         | 
| 688 | 
            +
             "nbformat_minor": 5
         | 
| 689 | 
            +
            }
         | 

