# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import datetime
import logging
import sys
from contextlib import nullcontext
# If your Python version is < 3.7, use the import below instead:
# from contextlib import suppress as nullcontext

import torch

from wenet.utils.common import StepTimer
from wenet.utils.train_utils import (wenet_join, batch_forward, batch_backward,
                                     update_parameter_and_lr, log_per_step,
                                     save_model)


class Executor:

    def __init__(self,
                 global_step: int = 0,
                 device: torch.device = torch.device("cpu")):
        self.step = global_step + 1
        self.train_step_timer = None
        self.cv_step_timer = None
        self.device = device

    def train(self, model, optimizer, scheduler, train_data_loader,
              cv_data_loader, writer, configs, scaler, group_join):
        ''' Train one epoch
        '''
        if self.train_step_timer is None:
            self.train_step_timer = StepTimer(self.step)
        model.train()
        info_dict = copy.deepcopy(configs)
        logging.info('using accumulate grad, new batch size is {} times'
                     ' larger than before'.format(info_dict['accum_grad']))
        # A context manager to be used in conjunction with an instance of
        # torch.nn.parallel.DistributedDataParallel to be able to train
        # with uneven inputs across participating processes.
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model_context = model.join
        else:
            model_context = nullcontext
        with model_context():
            for batch_idx, batch_dict in enumerate(train_data_loader):
                info_dict["tag"] = "TRAIN"
                info_dict["step"] = self.step
                info_dict["batch_idx"] = batch_idx
                if wenet_join(group_join, info_dict):
                    break  # fix by zhaoyi, to facilitate multi-machine training
                if batch_dict["target_lengths"].size(0) == 0:
                    continue
                context = None
                # Disable gradient synchronizations across DDP processes.
                # Within this context, gradients will be accumulated on module
                # variables, which will later be synchronized.
                if info_dict.get("train_engine", "torch_ddp") in [
                        "torch_ddp", "torch_fsdp"
                ] and (batch_idx + 1) % info_dict["accum_grad"] != 0:
                    context = model.no_sync
                # Used for single gpu training and DDP gradient synchronization
                # processes.
                else:
                    context = nullcontext
                with context():
                    info_dict = batch_forward(model, batch_dict, scaler,
                                              info_dict, self.device)
                    info_dict = batch_backward(model, scaler, info_dict)

                info_dict = update_parameter_and_lr(model, optimizer,
                                                    scheduler, scaler,
                                                    info_dict)
                # write training: tensorboard && log
                log_per_step(writer, info_dict, timer=self.train_step_timer)
                # save_interval = info_dict.get('save_interval', sys.maxsize)
                # if (self.step +
                #         1) % save_interval == 0 and self.step != 0 and (
                #             batch_idx + 1) % info_dict["accum_grad"] == 0:
                #     import torch.distributed as dist
                #     # Ensure all ranks start CV at the same time in step mode
                #     dist.barrier()
                #     # loss_dict = self.cv(model, cv_data_loader, configs)
                #     model.train()
                #     info_dict.update({
                #         "tag": "step_{}".format(self.step),
                #         "loss_dict": {'loss': 999, 'acc': 999},
                #         "save_time":
                #             datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'),
                #         "lrs":
                #             [group['lr'] for group in optimizer.param_groups]
                #     })
                #     save_model(model, info_dict)
                #     # write final cv: tensorboard
                #     log_per_step(writer, info_dict)
                #     # Ensure all ranks start Train at the same time in step mode
                #     dist.barrier()
                # Advance the global step only on gradient-accumulation
                # boundaries, i.e. when an optimizer update actually happened.
                self.step += 1 if (batch_idx +
                                   1) % info_dict["accum_grad"] == 0 else 0

    def cv(self, model, cv_data_loader, configs):
        ''' Cross validation on
        '''
        if self.cv_step_timer is None:
            self.cv_step_timer = StepTimer(0.0)
        else:
            self.cv_step_timer.last_iteration = 0.0
        model.eval()
        info_dict = copy.deepcopy(configs)
        num_seen_utts, loss_dict, total_acc = 1, {}, []  # avoid division by 0
        with torch.no_grad():
            for batch_idx, batch_dict in enumerate(cv_data_loader):
                info_dict["tag"] = "CV"
                info_dict["step"] = self.step
                info_dict["batch_idx"] = batch_idx
                info_dict["cv_step"] = batch_idx

                num_utts = batch_dict["target_lengths"].size(0)
                if num_utts == 0:
                    continue

                info_dict = batch_forward(model, batch_dict, None, info_dict,
                                          self.device)
                _dict = info_dict["loss_dict"]

                num_seen_utts += num_utts
                total_acc.append(_dict['th_accuracy'].item(
                ) if _dict.get('th_accuracy', None) is not None else 0.0)

                for loss_name, loss_value in _dict.items():
                    if loss_value is not None and "loss" in loss_name \
                            and torch.isfinite(loss_value):
                        loss_value = loss_value.item()
                        loss_dict[loss_name] = loss_dict.get(loss_name, 0) + \
                            loss_value * num_utts
                # write cv: log
                log_per_step(writer=None,
                             info_dict=info_dict,
                             timer=self.cv_step_timer)
        for loss_name, loss_value in loss_dict.items():
            loss_dict[loss_name] = loss_dict[loss_name] / num_seen_utts
        loss_dict["acc"] = sum(total_acc) / len(total_acc)
        return loss_dict
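

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). The model, optimizer,
# scheduler, data loaders, `writer`, `scaler` and `group_join` are assumed to
# be constructed by the caller (e.g. with the helpers in
# wenet.utils.train_utils); a training script then typically drives the
# Executor once per epoch, roughly like this:
#
#     executor = Executor(global_step=0, device=torch.device('cuda'))
#     for epoch in range(start_epoch, max_epoch):
#         # (re-shuffle / advance the epoch on the dataset or sampler here)
#         executor.train(model, optimizer, scheduler, train_data_loader,
#                        cv_data_loader, writer, configs, scaler, group_join)
#         loss_dict = executor.cv(model, cv_data_loader, configs)
#         logging.info('Epoch {} cv info: {}'.format(epoch, loss_dict))
# ---------------------------------------------------------------------------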