| # Copyright 2022 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import argparse | |
| import math | |
| from abc import ABC | |
| from functools import partial | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | |
| from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP | |
| from ..optimizer import AcceleratedOptimizer | |
| from ..scheduler import AcceleratedScheduler | |
| from .imports import is_megatron_lm_available, is_transformers_available | |
| from .operations import recursively_apply, send_to_device | |
| if is_transformers_available(): | |
| from transformers.modeling_outputs import ( | |
| CausalLMOutputWithCrossAttentions, | |
| Seq2SeqLMOutput, | |
| SequenceClassifierOutput, | |
| ) | |
| if is_megatron_lm_available(): | |
| from megatron import ( | |
| get_args, | |
| get_num_microbatches, | |
| get_tensorboard_writer, | |
| get_timers, | |
| get_tokenizer, | |
| mpu, | |
| print_rank_0, | |
| print_rank_last, | |
| ) | |
| from megatron.arguments import _add_data_args, _add_validation_args, parse_args, validate_args | |
| from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint, save_checkpoint | |
| from megatron.data.data_samplers import MegatronPretrainingRandomSampler, MegatronPretrainingSampler | |
| from megatron.global_vars import set_global_variables | |
| from megatron.initialize import ( | |
| _compile_dependencies, | |
| _init_autoresume, | |
| _set_random_seed, | |
| set_jit_fusion_options, | |
| write_args_to_tensorboard, | |
| ) | |
| from megatron.model import BertModel, Float16Module, GPTModel, ModelType, T5Model | |
| from megatron.model import DistributedDataParallel as LocalDDP | |
| from megatron.model.classification import Classification | |
| from megatron.optimizer import get_megatron_optimizer | |
| from megatron.schedules import get_forward_backward_func | |
| from megatron.text_generation.communication import broadcast_int_list, broadcast_tensor | |
| from megatron.text_generation.generation import ( | |
| beam_search_and_return_on_first_stage, | |
| generate_tokens_probs_and_return_on_first_stage, | |
| ) | |
| from megatron.tokenizer.tokenizer import _vocab_size_with_padding | |
| from megatron.training import get_model, get_optimizer_param_scheduler, training_log | |
| from megatron.utils import ( | |
| average_losses_across_data_parallel_group, | |
| calc_params_l2_norm, | |
| get_ltor_masks_and_position_ids, | |
| unwrap_model, | |
| ) | |
| # model utilities | |
| def model_provider_func(pre_process=True, post_process=True, add_encoder=True, add_decoder=True): | |
| """Build the model.""" | |
| args = get_args() | |
| mode = "pre-training" if args.pretraining_flag else "fine-tuning" | |
| if args.rank == 0: | |
| print(f"Building {args.model_type_name} model in the {mode} mode.") | |
| print( | |
| "The Megatron LM model weights are initialized at random in `accelerator.prepare`. " | |
| "Please use `accelerator.load_checkpoint` to load a pre-trained checkpoint matching the distributed setup." | |
| ) | |
| if args.model_type_name == "bert": | |
| if args.pretraining_flag: | |
| num_tokentypes = 2 if args.bert_binary_head else 0 | |
| model = BertModel( | |
| num_tokentypes=num_tokentypes, | |
| add_binary_head=args.bert_binary_head, | |
| parallel_output=True, | |
| pre_process=pre_process, | |
| post_process=post_process, | |
| ) | |
| else: | |
| model = Classification( | |
| num_classes=args.num_labels, num_tokentypes=2, pre_process=pre_process, post_process=post_process | |
| ) | |
| elif args.model_type_name == "gpt": | |
| model = GPTModel(num_tokentypes=0, parallel_output=True, pre_process=pre_process, post_process=post_process) | |
| elif args.model_type_name == "t5": | |
| model = T5Model( | |
| num_tokentypes=0, | |
| parallel_output=True, | |
| pre_process=pre_process, | |
| post_process=post_process, | |
| add_encoder=add_encoder, | |
| add_decoder=add_decoder, | |
| ) | |
| else: | |
| raise ValueError(f"Unsupported model type: {args.model_type_name}") | |
| return model | |
| def prepare_model(accelerator): | |
| accelerator.print("Preparing model") | |
| args = get_args() | |
| if accelerator.state.megatron_lm_plugin.custom_prepare_model_function is not None: | |
| if accelerator.state.megatron_lm_plugin.custom_model_provider_function is None: | |
| raise ValueError( | |
| "You must provide a `custom_model_provider_function` when using a `custom_prepare_model_function`." | |
| ) | |
| custom_model_provider_func = accelerator.state.megatron_lm_plugin.custom_model_provider_function | |
| model = accelerator.state.megatron_lm_plugin.custom_prepare_model_function(custom_model_provider_func) | |
| else: | |
| if args.model_type_name in ("bert", "gpt"): | |
| model_type = ModelType.encoder_or_decoder | |
| elif args.model_type_name == "t5": | |
| model_type = ModelType.encoder_and_decoder | |
| if args.pipeline_model_parallel_split_rank is None and args.pipeline_model_parallel_size > 1: | |
| args.pipeline_model_parallel_split_rank = args.pipeline_model_parallel_size // 2 | |
| model = get_model(model_provider_func, model_type) | |
| return model | |
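| # A minimal sketch of how the custom hooks consumed by `prepare_model` could be supplied. It assumes that the | |
| # `MegatronLMPlugin` from `accelerate.utils` accepts `custom_model_provider_function` and | |
| # `custom_prepare_model_function` (the attributes read above); the function bodies are placeholders, not a tested recipe. | |
| # | |
| #     from accelerate import Accelerator | |
| #     from accelerate.utils import MegatronLMPlugin | |
| # | |
| #     def my_model_provider(pre_process=True, post_process=True): | |
| #         # build and return a Megatron-LM module for this pipeline stage | |
| #         ... | |
| # | |
| #     def my_prepare_model(model_provider_func): | |
| #         # e.g. wrap Megatron's `get_model` with a custom `ModelType` | |
| #         return get_model(model_provider_func, ModelType.encoder_or_decoder) | |
| # | |
| #     megatron_lm_plugin = MegatronLMPlugin( | |
| #         custom_model_provider_function=my_model_provider, | |
| #         custom_prepare_model_function=my_prepare_model, | |
| #     ) | |
| #     accelerator = Accelerator(megatron_lm_plugin=megatron_lm_plugin) | |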
| # dataloader utilities | |
| class MegatronLMDummyDataLoader: | |
| """ | |
| Dummy dataloader that holds the Megatron-LM dataset arguments; it is primarily used to follow the conventional training loop while the actual Megatron-LM data iterators are built inside `accelerator.prepare`. | |
| Args: | |
| **dataset_kwargs: Megatron data arguments. | |
| """ | |
| def __init__(self, **dataset_kwargs): | |
| parser = argparse.ArgumentParser() | |
| parser = _add_data_args(parser) | |
| parser = _add_validation_args(parser) | |
| data_args = parser.parse_known_args() | |
| self.dataset_args = vars(data_args[0]) | |
| self.dataset_args.update(dataset_kwargs) | |
| self.dataset_args["megatron_dataset_flag"] = True | |
| def set_megatron_data_args(self): | |
| args = get_args() | |
| for key, value in self.dataset_args.items(): | |
| setattr(args, key, value) | |
| def get_train_valid_test_datasets_provider(self): | |
| def train_valid_test_datasets_provider(train_val_test_num_samples): | |
| """Build train, valid, and test datasets.""" | |
| args = get_args() | |
| dataset_args = { | |
| "data_prefix": args.data_path, | |
| "data_impl": args.data_impl, | |
| "splits_string": args.split, | |
| "train_valid_test_num_samples": train_val_test_num_samples, | |
| "skip_warmup": (not args.mmap_warmup), | |
| "seed": args.seed, | |
| } | |
| if args.model_type_name == "bert": | |
| dataset_args.update( | |
| { | |
| "max_seq_length": args.seq_length, | |
| "masked_lm_prob": args.mask_prob, | |
| "short_seq_prob": args.short_seq_prob, | |
| "binary_head": args.bert_binary_head, | |
| } | |
| ) | |
| elif args.model_type_name == "gpt": | |
| dataset_args.update( | |
| { | |
| "seq_length": args.seq_length, | |
| } | |
| ) | |
| elif args.model_type_name == "t5": | |
| dataset_args.update( | |
| { | |
| "max_seq_length": args.encoder_seq_length, | |
| "max_seq_length_dec": args.decoder_seq_length, | |
| "masked_lm_prob": args.mask_prob, | |
| "short_seq_prob": args.short_seq_prob, | |
| "dataset_type": "t5", | |
| } | |
| ) | |
| else: | |
| raise ValueError(f"Unsupported model type: {args.model_type_name}") | |
| if args.model_type_name == "gpt": | |
| from megatron.data.gpt_dataset import build_train_valid_test_datasets | |
| else: | |
| from megatron.data.dataset_utils import build_train_valid_test_datasets | |
| train_ds, valid_ds, test_ds = build_train_valid_test_datasets(**dataset_args) | |
| return train_ds, valid_ds, test_ds | |
| return train_valid_test_datasets_provider | |
| def build_pretraining_data_loader(self, dataset, consumed_samples): | |
| if dataset is None: | |
| return None | |
| args = get_args() | |
| micro_batch_size = args.micro_batch_size * args.num_micro_batches | |
| # Megatron sampler | |
| if args.dataloader_type == "single": | |
| batch_sampler = MegatronPretrainingSampler( | |
| total_samples=len(dataset), | |
| consumed_samples=consumed_samples, | |
| micro_batch_size=micro_batch_size, | |
| data_parallel_rank=mpu.get_data_parallel_rank(), | |
| data_parallel_size=mpu.get_data_parallel_world_size(), | |
| ) | |
| elif args.dataloader_type == "cyclic": | |
| batch_sampler = MegatronPretrainingRandomSampler( | |
| dataset, | |
| total_samples=len(dataset), | |
| consumed_samples=consumed_samples, | |
| micro_batch_size=micro_batch_size, | |
| data_parallel_rank=mpu.get_data_parallel_rank(), | |
| data_parallel_size=mpu.get_data_parallel_world_size(), | |
| data_sharding=args.data_sharding, | |
| ) | |
| else: | |
| raise Exception(f"{args.dataloader_type} dataloader type is not supported.") | |
| # Torch dataloader. | |
| return torch.utils.data.DataLoader( | |
| dataset, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True | |
| ) | |
| def build_train_valid_test_data_iterators(self): | |
| def cyclic_iter(iter): | |
| while True: | |
| yield from iter | |
| args = get_args() | |
| (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) | |
| print_rank_0("> building train, validation, and test datasets ...") | |
| # Backward compatibility, assume fixed batch size. | |
| if args.iteration > 0 and args.consumed_train_samples == 0: | |
| assert args.train_samples is None, "only backward compatibility support for iteration-based training" | |
| args.consumed_train_samples = args.iteration * args.global_batch_size | |
| if args.iteration > 0 and args.consumed_valid_samples == 0: | |
| if args.train_samples is None: | |
| args.consumed_valid_samples = ( | |
| (args.iteration // args.eval_interval) * args.eval_iters * args.global_batch_size | |
| ) | |
| # Data loader only on rank 0 of each model parallel group. | |
| if mpu.get_tensor_model_parallel_rank() == 0: | |
| # Number of train/valid/test samples. | |
| if args.train_samples: | |
| train_samples = args.train_samples | |
| else: | |
| train_samples = args.train_iters * args.global_batch_size | |
| eval_iters = (args.train_iters // args.eval_interval + 1) * args.eval_iters | |
| test_iters = args.eval_iters | |
| train_val_test_num_samples = [ | |
| train_samples, | |
| eval_iters * args.global_batch_size, | |
| test_iters * args.global_batch_size, | |
| ] | |
| print_rank_0(" > datasets target sizes (minimum size):") | |
| print_rank_0(f" train: {train_val_test_num_samples[0]}") | |
| print_rank_0(f" validation: {train_val_test_num_samples[1]}") | |
| print_rank_0(f" test: {train_val_test_num_samples[2]}") | |
| # Build the datasets. | |
| train_valid_test_datasets_provider = self.get_train_valid_test_datasets_provider() | |
| train_ds, valid_ds, test_ds = train_valid_test_datasets_provider(train_val_test_num_samples) | |
| # Build dataloaders. | |
| train_dataloader = self.build_pretraining_data_loader(train_ds, args.consumed_train_samples) | |
| valid_dataloader = self.build_pretraining_data_loader(valid_ds, args.consumed_valid_samples) | |
| test_dataloader = self.build_pretraining_data_loader(test_ds, 0) | |
| # Flags to know if we need to do training/validation/testing. | |
| do_train = train_dataloader is not None and args.train_iters > 0 | |
| do_valid = valid_dataloader is not None and args.eval_iters > 0 | |
| do_test = test_dataloader is not None and args.eval_iters > 0 | |
| # Need to broadcast the do_train/do_valid/do_test flags to the other tensor-parallel ranks. | |
| flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)]) | |
| else: | |
| flags = torch.cuda.LongTensor([0, 0, 0]) | |
| # Broadcast the flags. | |
| torch.distributed.broadcast( | |
| flags, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() | |
| ) | |
| args.do_train = flags[0].item() | |
| args.do_valid = flags[1].item() | |
| args.do_test = flags[2].item() | |
| # Build iterators. | |
| dl_type = args.dataloader_type | |
| assert dl_type in ["single", "cyclic"] | |
| if train_dataloader is not None: | |
| train_data_iterator = ( | |
| iter(train_dataloader) if dl_type == "single" else iter(cyclic_iter(train_dataloader)) | |
| ) | |
| else: | |
| train_data_iterator = None | |
| if valid_dataloader is not None: | |
| valid_data_iterator = ( | |
| iter(valid_dataloader) if dl_type == "single" else iter(cyclic_iter(valid_dataloader)) | |
| ) | |
| else: | |
| valid_data_iterator = None | |
| if test_dataloader is not None: | |
| test_data_iterator = iter(test_dataloader) if dl_type == "single" else iter(cyclic_iter(test_dataloader)) | |
| else: | |
| test_data_iterator = None | |
| return train_data_iterator, valid_data_iterator, test_data_iterator | |
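| # A rough usage sketch for `MegatronLMDummyDataLoader`, loosely following the Accelerate Megatron-LM docs; the | |
| # dataset kwargs below (`data_path`, `split`, `seq_length`) are illustrative values, not a validated config. | |
| # | |
| #     from accelerate.utils import MegatronLMDummyDataLoader | |
| # | |
| #     megatron_dataloader = MegatronLMDummyDataLoader( | |
| #         data_path=["my-gpt2_text_document"], | |
| #         split="949,50,1", | |
| #         seq_length=1024, | |
| #     ) | |
| #     # passing the same dummy dataloader three times yields the train/valid/test iterators built above | |
| #     model, optimizer, lr_scheduler, train_dl, valid_dl, test_dl = accelerator.prepare( | |
| #         model, optimizer, lr_scheduler, megatron_dataloader, megatron_dataloader, megatron_dataloader | |
| #     ) | |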
| def prepare_data_loader(accelerator, dataloader): | |
| accelerator.print("Preparing dataloader") | |
| args = get_args() | |
| if not args.megatron_dataset_flag: | |
| from ..data_loader import _PYTORCH_DATALOADER_KWARGS, prepare_data_loader | |
| args = get_args() | |
| micro_batch_size = args.micro_batch_size * args.num_micro_batches | |
| kwargs = {k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) for k in _PYTORCH_DATALOADER_KWARGS} | |
| if kwargs["batch_size"] is None: | |
| if isinstance(kwargs["sampler"], torch.utils.data.BatchSampler): | |
| kwargs["sampler"].batch_size = micro_batch_size | |
| else: | |
| del kwargs["sampler"] | |
| del kwargs["shuffle"] | |
| del kwargs["batch_size"] | |
| kwargs["batch_sampler"].batch_size = micro_batch_size | |
| else: | |
| del kwargs["batch_sampler"] | |
| kwargs["batch_size"] = micro_batch_size | |
| dataloader = torch.utils.data.DataLoader(dataloader.dataset, **kwargs) | |
| return prepare_data_loader( | |
| dataloader, | |
| accelerator.device, | |
| num_processes=mpu.get_data_parallel_world_size(), | |
| process_index=mpu.get_data_parallel_rank(), | |
| split_batches=accelerator.split_batches, | |
| put_on_device=True, | |
| rng_types=accelerator.rng_types.copy(), | |
| dispatch_batches=accelerator.dispatch_batches, | |
| ) | |
| else: | |
| if args.consumed_samples is not None: | |
| ( | |
| args.consumed_train_samples, | |
| args.consumed_valid_samples, | |
| args.consumed_test_samples, | |
| ) = args.consumed_samples | |
| else: | |
| args.consumed_train_samples, args.consumed_valid_samples, args.consumed_test_samples = 0, 0, 0 | |
| ( | |
| train_data_iterator, | |
| valid_data_iterator, | |
| test_data_iterator, | |
| ) = dataloader.build_train_valid_test_data_iterators() | |
| return train_data_iterator, valid_data_iterator, test_data_iterator | |
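| # Note on batch sizes in the non-Megatron branch above: each data-parallel rank re-batches the original dataloader | |
| # to `micro_batch_size * num_micro_batches` samples per step. For example, with `micro_batch_size=4`, | |
| # `num_micro_batches=8` and `data_parallel_size=2`, every rank pulls 32 samples per iteration and the global batch | |
| # size seen by the optimizer is 64; `train_step` later splits those 32 samples back into 8 micro batches. | |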
| # optimizer utilities | |
| class MegatronLMOptimizerWrapper(AcceleratedOptimizer): | |
| def __init__(self, optimizer): | |
| super().__init__(optimizer, device_placement=False, scaler=None) | |
| def zero_grad(self, set_to_none=None): | |
| pass # `model(**batch)` already zeroes the gradients internally, so its implementation is not needed here | |
| def step(self): | |
| pass # `model(**batch)` already performs the optimizer step internally, so its implementation is not needed here | |
| def step_was_skipped(self): | |
| """Whether or not the optimizer step was done, or skipped because of gradient overflow.""" | |
| return self.optimizer.skipped_iter | |
| def prepare_optimizer(accelerator, model): | |
| accelerator.print("Preparing optimizer") | |
| args = get_args() | |
| optimizer = get_megatron_optimizer(model, args.no_wd_decay_cond, args.scale_lr_cond, args.lr_mult) | |
| return optimizer | |
| # scheduler utilities | |
| class MegatronLMDummyScheduler: | |
| """ | |
| Dummy scheduler that holds the optimizer and schedule arguments; it is primarily used to follow the conventional | |
| training loop while the actual Megatron-LM param scheduler is created during `accelerator.prepare`. | |
| Args: | |
| optimizer (`torch.optim.optimizer.Optimizer`): | |
| The optimizer to wrap. | |
| total_num_steps (int): | |
| Total number of steps. | |
| warmup_num_steps (int): | |
| Number of steps for warmup. | |
| **kwargs (additional keyword arguments, *optional*): | |
| Other arguments. | |
| """ | |
| def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, **kwargs): | |
| self.optimizer = optimizer | |
| self.total_num_steps = total_num_steps | |
| self.warmup_num_steps = warmup_num_steps | |
| self.kwargs = kwargs | |
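| # A short usage sketch, following the pattern from the Accelerate Megatron-LM docs: the dummy scheduler simply | |
| # records the requested schedule, and the real Megatron-LM param scheduler is created in `prepare_scheduler` below. | |
| # | |
| #     from accelerate.utils import MegatronLMDummyScheduler | |
| # | |
| #     lr_scheduler = MegatronLMDummyScheduler( | |
| #         optimizer=optimizer, | |
| #         total_num_steps=max_train_steps, | |
| #         warmup_num_steps=num_warmup_steps, | |
| #     ) | |
| #     model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) | |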
| class MegatronLMSchedulerWrapper(AcceleratedScheduler): | |
| def __init__(self, scheduler, optimizers): | |
| super().__init__(scheduler, optimizers) | |
| def step(self, *args, **kwargs): | |
| return # `model(**batch)` already steps the scheduler internally, so its implementation is not needed here | |
| def prepare_scheduler(accelerator, optimizer, scheduler): | |
| accelerator.print("Preparing scheduler") | |
| scheduler = get_optimizer_param_scheduler(optimizer) | |
| return scheduler | |
| class AbstractTrainStep(ABC): | |
| """Abstract class for batching, forward pass and loss handler.""" | |
| def __init__(self, name): | |
| super().__init__() | |
| self.name = name | |
| def get_batch_func(self): | |
| pass | |
| def get_forward_step_func(self): | |
| pass | |
| def get_loss_func(self): | |
| pass | |
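| # A skeleton for a user-defined train step, assuming the Megatron-LM plugin exposes `custom_train_step_class` and | |
| # `custom_train_step_kwargs` (they are consumed in `MegatronEngine.__init__` below); the method bodies are left as | |
| # placeholders and only the attributes the engine expects are shown. | |
| # | |
| #     class MyTrainStep(AbstractTrainStep): | |
| #         def __init__(self, args, **kwargs): | |
| #             super().__init__("MyTrainStep") | |
| #             self.get_batch = self.get_batch_func() | |
| #             self.loss_func = self.get_loss_func() | |
| #             self.forward_step = self.get_forward_step_func() | |
| #             self.model_output_class = None  # or a `transformers` ModelOutput class | |
| #         def get_batch_func(self): | |
| #             ... | |
| #         def get_loss_func(self): | |
| #             ... | |
| #         def get_forward_step_func(self): | |
| #             ... | |
| # | |
| #     megatron_lm_plugin = MegatronLMPlugin(custom_train_step_class=MyTrainStep, custom_train_step_kwargs={}) | |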
| class BertTrainStep(AbstractTrainStep): | |
| """ | |
| Bert train step class. | |
| Args: | |
| args (`argparse.Namespace`): Megatron-LM arguments. | |
| """ | |
| def __init__(self, args): | |
| super().__init__("BertTrainStep") | |
| self.get_batch = self.get_batch_func(args.megatron_dataset_flag) | |
| self.loss_func = self.get_loss_func(args.pretraining_flag, args.num_labels) | |
| self.forward_step = self.get_forward_step_func(args.pretraining_flag, args.bert_binary_head) | |
| if not args.model_return_dict: | |
| self.model_output_class = None | |
| else: | |
| self.model_output_class = SequenceClassifierOutput | |
| def get_batch_func(self, megatron_dataset_flag): | |
| def get_batch_megatron(data_iterator): | |
| """Build the batch.""" | |
| # Items and their type. | |
| keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"] | |
| datatype = torch.int64 | |
| # Broadcast data. | |
| if data_iterator is not None: | |
| data = next(data_iterator) | |
| else: | |
| data = None | |
| data_b = mpu.broadcast_data(keys, data, datatype) | |
| # Unpack. | |
| tokens = data_b["text"].long() | |
| types = data_b["types"].long() | |
| sentence_order = data_b["is_random"].long() | |
| loss_mask = data_b["loss_mask"].float() | |
| lm_labels = data_b["labels"].long() | |
| padding_mask = data_b["padding_mask"].long() | |
| return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask | |
| def get_batch_transformer(data_iterator): | |
| """Build the batch.""" | |
| data = next(data_iterator) | |
| data = send_to_device(data, torch.cuda.current_device()) | |
| # Unpack. | |
| tokens = data["input_ids"].long() | |
| padding_mask = data["attention_mask"].long() | |
| if "token_type_ids" in data: | |
| types = data["token_type_ids"].long() | |
| else: | |
| types = None | |
| if "labels" in data: | |
| lm_labels = data["labels"].long() | |
| loss_mask = (data["labels"] != -100).to(torch.float) | |
| else: | |
| lm_labels = None | |
| loss_mask = None | |
| if "next_sentence_label" in data: | |
| sentence_order = data["next_sentence_label"].long() | |
| else: | |
| sentence_order = None | |
| return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask | |
| if megatron_dataset_flag: | |
| return get_batch_megatron | |
| else: | |
| return get_batch_transformer | |
| def get_loss_func(self, pretraining_flag, num_labels): | |
| def loss_func_pretrain(loss_mask, sentence_order, output_tensor): | |
| lm_loss_, sop_logits = output_tensor | |
| lm_loss_ = lm_loss_.float() | |
| loss_mask = loss_mask.float() | |
| lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() | |
| if sop_logits is not None: | |
| sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1) | |
| sop_loss = sop_loss.float() | |
| loss = lm_loss + sop_loss | |
| averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss]) | |
| return loss, {"lm loss": averaged_losses[0], "sop loss": averaged_losses[1]} | |
| else: | |
| loss = lm_loss | |
| averaged_losses = average_losses_across_data_parallel_group([lm_loss]) | |
| return loss, {"lm loss": averaged_losses[0]} | |
| def loss_func_finetune(labels, logits): | |
| if num_labels == 1: | |
| # We are doing regression | |
| loss_fct = MSELoss() | |
| loss = loss_fct(logits.view(-1), labels.view(-1)) | |
| elif num_labels > 1 and (labels.dtype in (torch.long, torch.int)): | |
| loss_fct = CrossEntropyLoss() | |
| loss = loss_fct(logits.view(-1, num_labels), labels.view(-1)) | |
| else: | |
| loss_fct = BCEWithLogitsLoss() | |
| loss = loss_fct(logits, labels) | |
| averaged_losses = average_losses_across_data_parallel_group([loss]) | |
| return loss, {"loss": averaged_losses[0]} | |
| if pretraining_flag: | |
| return loss_func_pretrain | |
| else: | |
| return loss_func_finetune | |
| def get_forward_step_func(self, pretraining_flag, bert_binary_head): | |
| def forward_step(data_iterator, model): | |
| """Forward step.""" | |
| tokens, types, sentence_order, loss_mask, labels, padding_mask = self.get_batch(data_iterator) | |
| if not bert_binary_head: | |
| types = None | |
| # Forward pass through the model. | |
| if pretraining_flag: | |
| output_tensor = model(tokens, padding_mask, tokentype_ids=types, lm_labels=labels) | |
| return output_tensor, partial(self.loss_func, loss_mask, sentence_order) | |
| else: | |
| logits = model(tokens, padding_mask, tokentype_ids=types) | |
| return logits, partial(self.loss_func, labels) | |
| return forward_step | |
| class GPTTrainStep(AbstractTrainStep): | |
| """ | |
| GPT train step class. | |
| Args: | |
| args (`argparse.Namespace`): Megatron-LM arguments. | |
| """ | |
| def __init__(self, args): | |
| super().__init__("GPTTrainStep") | |
| self.get_batch = self.get_batch_func(args.megatron_dataset_flag) | |
| self.loss_func = self.get_loss_func() | |
| self.forward_step = self.get_forward_step_func() | |
| self.eod_token = args.padded_vocab_size - 1 | |
| if args.vocab_file is not None: | |
| tokenizer = get_tokenizer() | |
| self.eod_token = tokenizer.eod | |
| self.reset_position_ids = args.reset_position_ids | |
| self.reset_attention_mask = args.reset_attention_mask | |
| self.eod_mask_loss = args.eod_mask_loss | |
| if not args.model_return_dict: | |
| self.model_output_class = None | |
| else: | |
| self.model_output_class = CausalLMOutputWithCrossAttentions | |
| def get_batch_func(self, megatron_dataset_flag): | |
| def get_batch_megatron(data_iterator): | |
| """Generate a batch""" | |
| # Items and their type. | |
| keys = ["text"] | |
| datatype = torch.int64 | |
| # Broadcast data. | |
| if data_iterator is not None: | |
| data = next(data_iterator) | |
| else: | |
| data = None | |
| data_b = mpu.broadcast_data(keys, data, datatype) | |
| # Unpack. | |
| tokens_ = data_b["text"].long() | |
| labels = tokens_[:, 1:].contiguous() | |
| tokens = tokens_[:, :-1].contiguous() | |
| # Get the masks and position ids. | |
| attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( | |
| tokens, self.eod_token, self.reset_position_ids, self.reset_attention_mask, self.eod_mask_loss | |
| ) | |
| return tokens, labels, loss_mask, attention_mask, position_ids | |
| def get_batch_transformer(data_iterator): | |
| data = next(data_iterator) | |
| data = {"input_ids": data["input_ids"]} | |
| data = send_to_device(data, torch.cuda.current_device()) | |
| tokens_ = data["input_ids"].long() | |
| padding = torch.zeros((tokens_.shape[0], 1), dtype=tokens_.dtype, device=tokens_.device) + self.eod_token | |
| tokens_ = torch.concat([tokens_, padding], dim=1) | |
| labels = tokens_[:, 1:].contiguous() | |
| tokens = tokens_[:, :-1].contiguous() | |
| # Get the masks and position ids. | |
| attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( | |
| tokens, self.eod_token, self.reset_position_ids, self.reset_attention_mask, True | |
| ) | |
| return tokens, labels, loss_mask, attention_mask, position_ids | |
| if megatron_dataset_flag: | |
| return get_batch_megatron | |
| else: | |
| return get_batch_transformer | |
| def get_loss_func(self): | |
| args = get_args() | |
| def loss_func(loss_mask, output_tensor): | |
| if args.return_logits: | |
| losses, logits = output_tensor | |
| else: | |
| losses = output_tensor | |
| losses = losses.float() | |
| loss_mask = loss_mask.view(-1).float() | |
| loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() | |
| # Reduce loss for logging. | |
| averaged_loss = average_losses_across_data_parallel_group([loss]) | |
| output_dict = {"lm loss": averaged_loss[0]} | |
| if args.return_logits: | |
| output_dict.update({"logits": logits}) | |
| return loss, output_dict | |
| return loss_func | |
| def get_forward_step_func(self): | |
| def forward_step(data_iterator, model): | |
| """Forward step.""" | |
| # Get the batch. | |
| tokens, labels, loss_mask, attention_mask, position_ids = self.get_batch(data_iterator) | |
| output_tensor = model(tokens, position_ids, attention_mask, labels=labels) | |
| return output_tensor, partial(self.loss_func, loss_mask) | |
| return forward_step | |
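| # Worked example of the label shift performed in `get_batch_transformer` above: for a single sequence | |
| # `input_ids = [5, 6, 7]` and `eod_token = 0`, the padded sequence is `[5, 6, 7, 0]`, so the model sees | |
| # `tokens = [5, 6, 7]` and is trained against `labels = [6, 7, 0]`, i.e. each position predicts the next token | |
| # and the final position predicts the appended end-of-document token. | |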
| class T5TrainStep(AbstractTrainStep): | |
| """ | |
| T5 train step class. | |
| Args: | |
| args (`argparse.Namespace`): Megatron-LM arguments. | |
| """ | |
| def __init__(self, args): | |
| super().__init__("T5TrainStep") | |
| self.get_batch = self.get_batch_func(args.megatron_dataset_flag) | |
| self.loss_func = self.get_loss_func() | |
| self.forward_step = self.get_forward_step_func() | |
| if not args.model_return_dict: | |
| self.model_output_class = None | |
| else: | |
| self.model_output_class = Seq2SeqLMOutput | |
| @staticmethod | |
| def attn_mask_postprocess(attention_mask): | |
| # We create a 3D attention mask from a 2D tensor mask. | |
| # [b, 1, s] | |
| attention_mask_b1s = attention_mask.unsqueeze(1) | |
| # [b, s, 1] | |
| attention_mask_bs1 = attention_mask.unsqueeze(2) | |
| # [b, s, s] | |
| attention_mask_bss = attention_mask_b1s * attention_mask_bs1 | |
| # Convert attention mask to binary: | |
| extended_attention_mask = attention_mask_bss < 0.5 | |
| return extended_attention_mask | |
| @staticmethod | |
| def get_decoder_mask(seq_length, device): | |
| attention_mask = torch.tril(torch.ones((1, seq_length, seq_length), device=device)) | |
| attention_mask = attention_mask < 0.5 | |
| return attention_mask | |
| @staticmethod | |
| def get_enc_dec_mask(attention_mask, dec_seq_length, device): | |
| batch_size, _ = attention_mask.shape | |
| # We create a 3D attention mask from a 2D tensor mask. | |
| # [b, 1, s] | |
| attention_mask_b1s = attention_mask.unsqueeze(1) | |
| # [b, s, 1] | |
| attention_mask_bs1 = torch.ones((batch_size, dec_seq_length, 1), device=device) | |
| attention_mask_bss = attention_mask_bs1 * attention_mask_b1s | |
| extended_attention_mask = attention_mask_bss < 0.5 | |
| return extended_attention_mask | |
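| # Toy example for `attn_mask_postprocess` above: for a batch of one sequence with `attention_mask = [1, 1, 0]`, | |
| # the outer product gives [[1, 1, 0], [1, 1, 0], [0, 0, 0]], and `< 0.5` turns it into a boolean [b, s, s] mask | |
| # that is True exactly where either the query or the key position is padding, which is the convention Megatron-LM | |
| # uses for positions to be masked out. | |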
| def get_batch_func(self, megatron_dataset_flag): | |
| def get_batch_megatron(data_iterator): | |
| """Build the batch.""" | |
| keys = ["text_enc", "text_dec", "labels", "loss_mask", "enc_mask", "dec_mask", "enc_dec_mask"] | |
| datatype = torch.int64 | |
| # Broadcast data. | |
| if data_iterator is not None: | |
| data = next(data_iterator) | |
| else: | |
| data = None | |
| data_b = mpu.broadcast_data(keys, data, datatype) | |
| # Unpack. | |
| tokens_enc = data_b["text_enc"].long() | |
| tokens_dec = data_b["text_dec"].long() | |
| labels = data_b["labels"].long() | |
| loss_mask = data_b["loss_mask"].float() | |
| enc_mask = data_b["enc_mask"] < 0.5 | |
| dec_mask = data_b["dec_mask"] < 0.5 | |
| enc_dec_mask = data_b["enc_dec_mask"] < 0.5 | |
| return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask | |
| def get_batch_transformer(data_iterator): | |
| """Build the batch.""" | |
| data = next(data_iterator) | |
| data = send_to_device(data, torch.cuda.current_device()) | |
| tokens_enc = data["input_ids"].long() | |
| labels = data["labels"].long() | |
| loss_mask = (labels != -100).to(torch.float) | |
| if "decoder_input_ids" in data: | |
| tokens_dec = data["decoder_input_ids"].long() | |
| else: | |
| tokens_dec = labels.new_zeros(labels.shape, device=labels.device, dtype=torch.long) | |
| tokens_dec[..., 1:] = labels[..., :-1].clone() | |
| tokens_dec[..., 0] = 0 | |
| tokens_dec.masked_fill_(tokens_dec == -100, 0) | |
| enc_mask = T5TrainStep.attn_mask_postprocess(data["attention_mask"].long()) | |
| dec_mask = T5TrainStep.get_decoder_mask(tokens_dec.shape[1], tokens_dec.device) | |
| enc_dec_mask = T5TrainStep.get_enc_dec_mask( | |
| data["attention_mask"].long(), tokens_dec.shape[1], tokens_dec.device | |
| ) | |
| return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask | |
| if megatron_dataset_flag: | |
| return get_batch_megatron | |
| else: | |
| return get_batch_transformer | |
| def get_loss_func(self): | |
| def loss_func(loss_mask, output_tensor): | |
| lm_loss_ = output_tensor.float() | |
| lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() | |
| loss = lm_loss | |
| averaged_losses = average_losses_across_data_parallel_group([lm_loss]) | |
| return loss, {"lm loss": averaged_losses[0]} | |
| return loss_func | |
| def get_forward_step_func(self): | |
| def forward_step(data_iterator, model): | |
| """Forward step.""" | |
| # Get the batch. | |
| tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask = self.get_batch( | |
| data_iterator | |
| ) | |
| # Forward model lm_labels | |
| output_tensor = model( | |
| tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask, tokentype_ids=None, lm_labels=lm_labels | |
| ) | |
| return output_tensor, partial(self.loss_func, loss_mask) | |
| return forward_step | |
| # initialize Megatron-LM setup | |
| def initialize(accelerator, extra_args_provider=None, args_defaults={}): | |
| accelerator.print("Initializing Megatron-LM") | |
| assert torch.cuda.is_available(), "Megatron requires CUDA." | |
| # Parse arguments | |
| args = parse_args(extra_args_provider, ignore_unknown_args=True) | |
| # Set defaults | |
| for key, value in args_defaults.items(): | |
| if getattr(args, key, None) is not None: | |
| if args.rank == 0: | |
| print( | |
| f"WARNING: overriding default arguments for " f"{key}:{getattr(args, key)} with {key}:{value}", | |
| flush=True, | |
| ) | |
| setattr(args, key, value) | |
| if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False): | |
| assert args.load is not None, "--use-checkpoint-args requires --load argument" | |
| load_args_from_checkpoint(args) | |
| validate_args(args) | |
| # set global args, build tokenizer, and set adlr-autoresume, | |
| # tensorboard-writer, and timers. | |
| set_global_variables(args) | |
| # torch.distributed initialization | |
| def finish_mpu_init(): | |
| args = get_args() | |
| # Pytorch distributed. | |
| device_count = torch.cuda.device_count() | |
| args.rank = torch.distributed.get_rank() | |
| args.world_size = torch.distributed.get_world_size() | |
| if device_count > 0: | |
| device = args.rank % device_count | |
| if args.local_rank is not None: | |
| assert args.local_rank == device, "expected local-rank to be the same as rank % device-count." | |
| else: | |
| args.local_rank = device | |
| # Set the tensor model-parallel, pipeline model-parallel, and | |
| # data-parallel communicators. | |
| if mpu.model_parallel_is_initialized(): | |
| print("model parallel is already initialized") | |
| else: | |
| mpu.initialize_model_parallel( | |
| args.tensor_model_parallel_size, | |
| args.pipeline_model_parallel_size, | |
| args.virtual_pipeline_model_parallel_size, | |
| args.pipeline_model_parallel_split_rank, | |
| ) | |
| # Random seeds for reproducibility. | |
| if args.rank == 0: | |
| print(f"> setting random seeds to {args.seed} ...") | |
| _set_random_seed(args.seed, args.data_parallel_random_init) | |
| args = get_args() | |
| # Megatron's MPU is the master. Complete initialization right away. | |
| finish_mpu_init() | |
| # Autoresume. | |
| _init_autoresume() | |
| # Compile dependencies. | |
| _compile_dependencies() | |
| # Set pytorch JIT layer fusion options and warmup JIT functions. | |
| set_jit_fusion_options() | |
| args = get_args() | |
| args.padded_vocab_size = _vocab_size_with_padding(args.orig_vocab_size, args) | |
| if args.model_type_name == "bert" and args.pretraining_flag and args.num_labels == 2: | |
| args.bert_binary_head = True | |
| else: | |
| args.bert_binary_head = False | |
| args.iteration = 0 | |
| class MegatronEngine(torch.nn.Module): | |
| """ | |
| Megatron-LM model wrapper | |
| Args: | |
| accelerator (:class:`~accelerate.Accelerator`): The accelerator object to use. | |
| model: Megatron-LM model | |
| optimizer: Megatron-LM optimizer | |
| lr_scheduler: Megatron-LM lr scheduler | |
| """ | |
| def __init__(self, accelerator, model, optimizer, scheduler): | |
| super().__init__() | |
| self.module = model | |
| self.base_model = model[0] | |
| self.optimizer = optimizer | |
| self.scheduler = scheduler | |
| args = get_args() | |
| if accelerator.state.megatron_lm_plugin.custom_train_step_class is not None: | |
| self.train_step_handler = accelerator.state.megatron_lm_plugin.custom_train_step_class( | |
| args, **accelerator.state.megatron_lm_plugin.custom_train_step_kwargs | |
| ) | |
| elif args.model_type_name == "bert": | |
| self.train_step_handler = BertTrainStep(args) | |
| elif args.model_type_name == "gpt": | |
| self.train_step_handler = GPTTrainStep(args) | |
| elif args.model_type_name == "t5": | |
| self.train_step_handler = T5TrainStep(args) | |
| else: | |
| raise ValueError(f"Unsupported model type: {args.model_type_name}") | |
| self.optimizer.skipped_iter = False | |
| # Tracking loss. | |
| self.total_loss_dict = {} | |
| self.eval_total_loss_dict = {} | |
| self.iteration = 0 | |
| self.report_memory_flag = True | |
| if args.tensorboard_dir is not None: | |
| write_args_to_tensorboard() | |
| def train(self): | |
| for model_module in self.module: | |
| model_module.train() | |
| self.log_eval_results() | |
| def eval(self): | |
| for model_module in self.module: | |
| model_module.eval() | |
| def train_step(self, **batch_data): | |
| """ | |
| Training step for Megatron-LM | |
| Args: | |
| batch_data (:obj:`dict`): The batch data to train on. | |
| """ | |
| args = get_args() | |
| timers = get_timers() | |
| if len(batch_data) > 0: | |
| data_chunks = [] | |
| if args.num_micro_batches > 1: | |
| for i in range(0, args.num_micro_batches): | |
| data_chunks.append( | |
| { | |
| k: v[i * args.micro_batch_size : (i + 1) * args.micro_batch_size] | |
| for k, v in batch_data.items() | |
| } | |
| ) | |
| else: | |
| data_chunks = [batch_data] | |
| if len(self.module) > 1: | |
| batch_data_iterator = ( | |
| [iter(data_chunks) for _ in range(len(self.module))] | |
| if len(batch_data) > 0 | |
| else [None] * len(self.module) | |
| ) | |
| else: | |
| batch_data_iterator = iter(data_chunks) if len(batch_data) > 0 else None | |
| # Set grad to zero. | |
| if args.DDP_impl == "local" and args.use_contiguous_buffers_in_local_ddp: | |
| for partition in self.module: | |
| partition.zero_grad_buffer() | |
| self.optimizer.zero_grad() | |
| # Forward pass. | |
| forward_backward_func = get_forward_backward_func() | |
| losses_reduced = forward_backward_func( | |
| self.train_step_handler.forward_step, | |
| batch_data_iterator, | |
| self.module, | |
| self.optimizer, | |
| None, | |
| forward_only=False, | |
| ) | |
| # Empty unused memory. | |
| if args.empty_unused_memory_level >= 1: | |
| torch.cuda.empty_cache() | |
| # Reduce gradients. | |
| timers("backward-reduce-model-grads").start() | |
| self.optimizer.reduce_model_grads(args, timers) | |
| timers("backward-reduce-model-grads").stop() | |
| # Update parameters. | |
| timers("optimizer").start() | |
| update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(args, timers) | |
| timers("optimizer").stop() | |
| # Gather params. | |
| if update_successful: | |
| timers("backward-gather-model-params").start() | |
| self.optimizer.gather_model_params(args, timers) | |
| timers("backward-gather-model-params").stop() | |
| # Update learning rate. | |
| if update_successful: | |
| if self.scheduler is not None: | |
| increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size | |
| self.scheduler.step(increment=increment) | |
| skipped_iter = 0 | |
| else: | |
| skipped_iter = 1 | |
| self.optimizer.skipped_iter = not update_successful | |
| # Empty unused memory. | |
| if args.empty_unused_memory_level >= 2: | |
| torch.cuda.empty_cache() | |
| args.consumed_train_samples += ( | |
| mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches() | |
| ) | |
| if mpu.is_pipeline_last_stage(ignore_virtual=True): | |
| # Average loss across microbatches. | |
| loss_reduced = {} | |
| for key in losses_reduced[0]: | |
| losses_reduced_for_key = [x[key] for x in losses_reduced] | |
| if len(losses_reduced_for_key[0].shape) == 0: | |
| loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key) | |
| else: | |
| loss_reduced[key] = torch.concat(losses_reduced_for_key) | |
| return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad | |
| return {}, skipped_iter, grad_norm, num_zeros_in_grad | |
| def eval_step(self, **batch_data): | |
| """ | |
| Evaluation step for Megatron-LM | |
| Args: | |
| batch_data (:obj:`dict`): The batch data to evaluate on. | |
| """ | |
| args = get_args() | |
| data_chunks = [] | |
| if args.num_micro_batches > 1: | |
| for i in range(0, args.num_micro_batches): | |
| data_chunks.append( | |
| {k: v[i * args.micro_batch_size : (i + 1) * args.micro_batch_size] for k, v in batch_data.items()} | |
| ) | |
| else: | |
| data_chunks = [batch_data] | |
| if len(self.module) > 1: | |
| batch_data_iterator = [iter(data_chunks) for _ in range(len(self.module))] | |
| else: | |
| batch_data_iterator = iter(data_chunks) | |
| forward_backward_func = get_forward_backward_func() | |
| loss_dicts = forward_backward_func( | |
| self.train_step_handler.forward_step, | |
| batch_data_iterator, | |
| self.module, | |
| optimizer=None, | |
| timers=None, | |
| forward_only=True, | |
| ) | |
| # Empty unused memory | |
| if args.empty_unused_memory_level >= 1: | |
| torch.cuda.empty_cache() | |
| args.consumed_valid_samples += ( | |
| mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches() | |
| ) | |
| if mpu.is_pipeline_last_stage(ignore_virtual=True): | |
| # Average loss across microbatches. | |
| loss_reduced = {} | |
| for key in loss_dicts[0]: | |
| losses_reduced_for_key = [x[key] for x in loss_dicts] | |
| if len(losses_reduced_for_key[0].shape) == 0: | |
| loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key) | |
| else: | |
| loss_reduced[key] = torch.concat(losses_reduced_for_key) | |
| return loss_reduced | |
| else: | |
| return {} | |
| def forward(self, **batch_data): | |
| # During training, we use train_step() | |
| # model(**batch_data) performs the following operations by delegating them to `self.train_step`: | |
| # 1. Prepare **batch_data for Tensor, Pipeline and Model Parallelism | |
| # 2. Set grad to zero. | |
| # 3. forward pass and backward pass using Pipeline Parallelism | |
| # 4. Empty unused memory. | |
| # 5. Reduce gradients. | |
| # 6. Update parameters. | |
| # 7. Gather params when using Distributed Optimizer (Data Parallelism). | |
| # 8. Update learning rate if scheduler is specified. | |
| # 9. Empty unused memory. | |
| # 10. Average loss across microbatches and across DP ranks. | |
| # | |
| # During evaluation, we use eval_step() | |
| args = get_args() | |
| if self.module[0].training: | |
| loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = self.train_step(**batch_data) | |
| self.iteration += 1 | |
| if args.tensorboard_dir is not None: | |
| # Logging. | |
| loss_scale = self.optimizer.get_loss_scale().item() | |
| params_norm = None | |
| if args.log_params_norm: | |
| params_norm = calc_params_l2_norm(self.module) | |
| self.report_memory_flag = training_log( | |
| loss_dict, | |
| self.total_loss_dict, | |
| self.optimizer.param_groups[0]["lr"], | |
| self.iteration, | |
| loss_scale, | |
| self.report_memory_flag, | |
| skipped_iter, | |
| grad_norm, | |
| params_norm, | |
| num_zeros_in_grad, | |
| ) | |
| else: | |
| loss_dict = self.eval_step(**batch_data) | |
| if args.tensorboard_dir is not None: | |
| for key in loss_dict: | |
| self.eval_total_loss_dict[key] = ( | |
| self.eval_total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + loss_dict[key] | |
| ) | |
| self.eval_total_loss_dict[key + "_num_iters"] = self.eval_total_loss_dict.get( | |
| key + "_num_iters", torch.cuda.FloatTensor([0.0]) | |
| ) + torch.cuda.FloatTensor([1.0]) | |
| loss = torch.tensor(0.0, device=args.local_rank) | |
| for key in loss_dict: | |
| if len(loss_dict[key].shape) == 0: | |
| loss += loss_dict[key] | |
| logits = None | |
| if "logits" in loss_dict: | |
| logits = loss_dict["logits"] | |
| # loss = reduce(loss) | |
| if self.train_step_handler.model_output_class is not None: | |
| return self.train_step_handler.model_output_class(loss=loss, logits=logits) | |
| return loss | |
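| # With the wrapper above, the usual Accelerate training loop stays unchanged: `model(**batch)` runs the whole | |
| # Megatron-LM train step, while `accelerator.backward`, `optimizer.step`, `lr_scheduler.step` and | |
| # `optimizer.zero_grad` become no-ops through the wrapper classes defined earlier in this file. A minimal sketch, | |
| # where `outputs.loss` assumes return-dict outputs (otherwise `forward` returns the loss tensor directly): | |
| # | |
| #     for batch in train_dataloader: | |
| #         outputs = model(**batch) | |
| #         loss = outputs.loss | |
| #         accelerator.backward(loss) | |
| #         optimizer.step() | |
| #         lr_scheduler.step() | |
| #         optimizer.zero_grad() | |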
| def log_eval_results(self): | |
| args = get_args() | |
| if args.tensorboard_dir is None or self.iteration == 0: | |
| return | |
| args = get_args() | |
| writer = get_tensorboard_writer() | |
| string = f"validation loss at iteration {self.iteration} | " | |
| for key in self.eval_total_loss_dict: | |
| if key.endswith("_num_iters"): | |
| continue | |
| value = self.eval_total_loss_dict[key] / self.eval_total_loss_dict[key + "_num_iters"] | |
| string += f"{key} value: {value} | " | |
| ppl = math.exp(min(20, value.item())) | |
| if args.pretraining_flag: | |
| string += f"{key} PPL: {ppl} | " | |
| if writer: | |
| writer.add_scalar(f"{key} validation", value.item(), self.iteration) | |
| if args.pretraining_flag: | |
| writer.add_scalar(f"{key} validation ppl", ppl, self.iteration) | |
| length = len(string) + 1 | |
| print_rank_last("-" * length) | |
| print_rank_last(string) | |
| print_rank_last("-" * length) | |
| self.eval_total_loss_dict = {} | |
| def save_checkpoint(self, output_dir): | |
| self.log_eval_results() | |
| args = get_args() | |
| args.save = output_dir | |
| torch.distributed.barrier() | |
| save_checkpoint(self.iteration, self.module, self.optimizer, self.scheduler) | |
| torch.distributed.barrier() | |
| def load_checkpoint(self, input_dir): | |
| args = get_args() | |
| args.load = input_dir | |
| args.consumed_train_samples = 0 | |
| args.consumed_valid_samples = 0 | |
| torch.distributed.barrier() | |
| iteration = load_checkpoint(self.module, self.optimizer, self.scheduler) | |
| torch.distributed.barrier() | |
| self.iteration = iteration | |
| if args.fp16 and self.iteration == 0: | |
| self.optimizer.reload_model_params() | |
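| # These two methods are not usually called directly; following the Accelerate Megatron-LM docs, checkpointing goes | |
| # through `accelerator.save_state(output_dir)` and `accelerator.load_state(input_dir)`, which dispatch to | |
| # `save_checkpoint` / `load_checkpoint` on this wrapper when Megatron-LM is enabled. | |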
| def megatron_generate( | |
| self, | |
| inputs, | |
| attention_mask=None, | |
| max_length=None, | |
| max_new_tokens=None, | |
| num_beams=None, | |
| temperature=None, | |
| top_k=None, | |
| top_p=None, | |
| length_penalty=None, | |
| **kwargs, | |
| ): | |
| """ | |
| Generate method for GPT-based models. This method is used for inference. It supports greedy decoding, beam search | |
| and sampling; refer to the Megatron-LM repo for more details. | |
| Args: | |
| inputs (torch.Tensor): input ids | |
| attention_mask (torch.Tensor, optional): attention mask. Defaults to None. | |
| max_length (int, optional): max length of the generated sequence. Defaults to None. | |
| Either this or max_new_tokens should be provided. | |
| max_new_tokens (int, optional): max number of tokens to be generated. Defaults to None. | |
| Either this or max_length should be provided. | |
| num_beams (int, optional): number of beams to use for beam search. Defaults to None. | |
| temperature (float, optional): temperature for sampling. Defaults to 1.0. | |
| top_k (int, optional): top k tokens to consider for sampling. Defaults to 0. | |
| top_p (float, optional): tokens in top p probability are considered for sampling. Defaults to 0.0. | |
| length_penalty (float, optional): length penalty for beam search. Defaults to None. | |
| kwargs: additional key-value arguments | |
| """ | |
| # checking if required arguments are passed | |
| args = get_args() | |
| if args.model_type_name != "gpt": | |
| raise NotImplementedError("Generate method is not implemented for this model") | |
| if args.data_parallel_size > 1: | |
| raise ValueError("Generate method requires data parallelism to be 1") | |
| if args.sequence_parallel: | |
| raise ValueError("Generate method requires sequence parallelism to be False") | |
| if args.recompute_granularity is not None: | |
| raise ValueError("Checkpoint activations cannot be set for inference") | |
| if args.vocab_file is None: | |
| raise ValueError("Vocab file is required for inference") | |
| # Prepare inputs | |
| if max_length is None and max_new_tokens is None: | |
| raise ValueError("`max_length` or `max_new_tokens` are required for inference") | |
| if temperature is None: | |
| temperature = 1.0 | |
| elif not (0.0 < temperature <= 100.0): | |
| raise ValueError("temperature must be a positive number less than or equal to 100.0") | |
| if top_k is None: | |
| top_k = 0 | |
| elif not (0 <= top_k <= 1000): | |
| raise ValueError("top_k must be a positive number less than or equal to 1000") | |
| if top_p is None: | |
| top_p = 0.0 | |
| elif top_p > 0.0 and top_k > 0.0: | |
| raise ValueError("top_p and top_k sampling cannot be set together") | |
| else: | |
| if not (0.0 <= top_p <= 1.0): | |
| raise ValueError("top_p must be less than or equal to 1.0") | |
| top_p_decay = kwargs.get("top_p_decay", 0.0) | |
| if not (0.0 <= top_p_decay <= 1.0): | |
| raise ValueError("top_p_decay must be less than or equal to 1.0") | |
| top_p_bound = kwargs.get("top_p_bound", 0.0) | |
| if not (0.0 <= top_p_bound <= 1.0): | |
| raise ValueError("top_p_bound must be less than or equal to 1.0") | |
| add_BOS = kwargs.get("add_BOS", False) | |
| if not (isinstance(add_BOS, bool)): | |
| raise ValueError("add_BOS must be a boolean") | |
| beam_width = num_beams | |
| if beam_width is not None: | |
| if not isinstance(beam_width, int): | |
| raise ValueError("beam_width must be an integer") | |
| if beam_width < 1: | |
| raise ValueError("beam_width must be greater than 0") | |
| if inputs.shape[0] > 1: | |
| return "When doing beam_search, batch size must be 1" | |
| tokenizer = get_tokenizer() | |
| stop_token = kwargs.get("stop_token", tokenizer.eod) | |
| if stop_token is not None: | |
| if not isinstance(stop_token, int): | |
| raise ValueError("stop_token must be an integer") | |
| if length_penalty is None: | |
| length_penalty = 1.0 | |
| sizes_list = None | |
| prompts_tokens_tensor = None | |
| prompts_length_tensor = None | |
| if torch.distributed.get_rank() == 0: | |
| # Get the prompts length. | |
| if attention_mask is None: | |
| prompts_length_tensor = torch.cuda.LongTensor([inputs.shape[1]] * inputs.shape[0]) | |
| else: | |
| prompts_length_tensor = attention_mask.sum(axis=-1).cuda() | |
| if max_new_tokens is None: | |
| max_new_tokens = max_length - inputs.shape[1] | |
| if max_new_tokens <= 0: | |
| raise ValueError("max_new_tokens must be greater than 0") | |
| if add_BOS: | |
| max_length = max_new_tokens + inputs.shape[1] + 1 | |
| # making sure that `max_length` is a multiple of 4 to leverage fused kernels | |
| max_length = 4 * math.ceil(max_length / 4) | |
| max_new_tokens = max_length - (inputs.shape[1] + 1) | |
| padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0]) | |
| prompts_tokens_tensor = torch.concat( | |
| [torch.unsqueeze(padding[:, 0], axis=-1), inputs.cuda(), padding], axis=-1 | |
| ) | |
| else: | |
| # making sure that `max_length` is a multiple of 4 to leverage fused kernels | |
| max_length = max_new_tokens + inputs.shape[1] | |
| max_length = 4 * math.ceil(max_length / 4) | |
| max_new_tokens = max_length - inputs.shape[1] | |
| padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0]) | |
| prompts_tokens_tensor = torch.concat([inputs.cuda(), padding], axis=-1) | |
| # We need the sizes of these tensors for the broadcast | |
| sizes_list = [ | |
| prompts_tokens_tensor.size(0), # Batch size | |
| prompts_tokens_tensor.size(1), | |
| ] # Sequence length | |
| # First, broadcast the sizes. | |
| sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=0) | |
| # Now that we have the sizes, we can broadcast the tokens | |
| # and length tensors. | |
| sizes = sizes_tensor.tolist() | |
| context_tokens_tensor = broadcast_tensor(sizes, torch.int64, tensor=prompts_tokens_tensor, rank=0) | |
| context_length_tensor = broadcast_tensor(sizes[0], torch.int64, tensor=prompts_length_tensor, rank=0) | |
| # Run the inference | |
| random_seed = kwargs.get("random_seed", 0) | |
| torch.random.manual_seed(random_seed) | |
| unwrapped_model = unwrap_model(self.base_model, (torchDDP, LocalDDP, Float16Module)) | |
| if beam_width is not None: | |
| tokens, _ = beam_search_and_return_on_first_stage( | |
| unwrapped_model, | |
| context_tokens_tensor, | |
| context_length_tensor, | |
| beam_width, | |
| stop_token=stop_token, | |
| num_return_gen=1, | |
| length_penalty=length_penalty, | |
| ) | |
| else: | |
| tokens, _, _ = generate_tokens_probs_and_return_on_first_stage( | |
| unwrapped_model, | |
| context_tokens_tensor, | |
| context_length_tensor, | |
| return_output_log_probs=False, | |
| top_k=top_k, | |
| top_p=top_p, | |
| top_p_decay=top_p_decay, | |
| top_p_bound=top_p_bound, | |
| temperature=temperature, | |
| use_eod_token_for_early_termination=True, | |
| ) | |
| return tokens | |
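| # A hedged usage sketch for `megatron_generate`, loosely based on the Accelerate Megatron-LM docs; `tokenizer`, | |
| # `prompts` and the generation settings are placeholders. The prepared model is the `MegatronEngine` itself, so the | |
| # method can be called on it directly after `accelerator.prepare`. Note that `top_k` and `top_p` cannot be set together. | |
| # | |
| #     batch = tokenizer(prompts, return_tensors="pt", padding=True) | |
| #     generated_tokens = model.megatron_generate( | |
| #         batch["input_ids"], | |
| #         batch["attention_mask"], | |
| #         max_new_tokens=64, | |
| #         top_k=50, | |
| #         temperature=0.8, | |
| #     ) | |
| #     texts = tokenizer.batch_decode(generated_tokens.cpu().numpy().tolist()) | |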
| # other utilities | |
| def avg_losses_across_data_parallel_group(losses): | |
| """ | |
| Average losses across data parallel group. | |
| Args: | |
| losses (List[Tensor]): List of losses to average across data parallel group. | |
| """ | |
| return average_losses_across_data_parallel_group(losses) | |
| def gather_across_data_parallel_groups(tensor): | |
| """ | |
| Recursively gather tensor in a nested list/tuple/dictionary of tensors from data parallel ranks. | |
| Args: | |
| tensor (nested list/tuple/dictionary of `torch.Tensor`): | |
| The data to gather across data parallel ranks. | |
| """ | |
| def _gpu_gather_one(tensor): | |
| if tensor.ndim == 0: | |
| tensor = tensor.clone()[None] | |
| output_tensors = [ | |
| torch.empty_like(tensor) | |
| for _ in range(torch.distributed.get_world_size(group=mpu.get_data_parallel_group())) | |
| ] | |
| torch.distributed.all_gather(output_tensors, tensor, group=mpu.get_data_parallel_group()) | |
| return torch.cat(output_tensors, dim=0) | |
| return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True) | |
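| # A short sketch of how `gather_across_data_parallel_groups` can be used in an evaluation loop to collect per-rank | |
| # losses before computing a metric; the import path assumes this module lives at `accelerate.utils.megatron_lm`. | |
| # | |
| #     from accelerate.utils.megatron_lm import gather_across_data_parallel_groups | |
| # | |
| #     losses = [] | |
| #     for batch in eval_dataloader: | |
| #         with torch.no_grad(): | |
| #             outputs = model(**batch) | |
| #         losses.append(gather_across_data_parallel_groups(outputs.loss)) | |
| #     eval_loss = torch.cat([l.flatten() for l in losses]).mean() | |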