# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import os
import time
import logging

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    HfArgumentParser,
    default_data_collator,
)
import wandb
from peft import LoraConfig, get_peft_model

from core.arguments import parse_args
from core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from core.datasets.gpt_dataset import GPTDatasetConfig, GPTDataset

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

_GLOBAL_TOKENIZER = None


def is_dataset_built_on_rank():
    # return (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and mpu.get_tensor_model_parallel_rank() == 0
    return True


def core_gpt_dataset_config_from_args(args):
    return GPTDatasetConfig(
        is_built_on_rank=is_dataset_built_on_rank,
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend=args.data_path,
        blend_per_split=[args.train_data_path, args.valid_data_path, args.test_data_path],
        split=args.split,
        path_to_cache=args.data_cache_path,
        return_document_ids=args.retro_return_doc_ids,
        reset_position_ids=args.reset_position_ids,
        reset_attention_mask=args.reset_attention_mask,
        eod_mask_loss=args.eod_mask_loss,
        eod_id=_GLOBAL_TOKENIZER.vocab[''],  # end-of-document token id looked up in the tokenizer vocabulary
        enable_shuffle=args.enable_shuffle,
    )


def _build_tokenizer(args):
    """Initialize tokenizer."""
    global _GLOBAL_TOKENIZER
    logger.info(f"Loading tokenizer from {args.model_name_or_path}")
    _GLOBAL_TOKENIZER = AutoTokenizer.from_pretrained(
        args.model_name_or_path,
        model_max_length=args.model_max_length,
        padding_side="right",
    )
    return _GLOBAL_TOKENIZER


def build_train_valid_test_datasets(args):
    """Build the train, validation, and test datasets."""
    # Number of train/valid/test samples.
    if args.train_samples:
        train_samples = args.train_samples
    else:
        train_samples = args.train_iters * args.global_batch_size
    eval_iters = (args.train_iters // args.eval_interval + 1) * args.eval_iters
    test_iters = args.eval_iters
    train_val_test_num_samples = [
        train_samples,
        eval_iters * args.global_batch_size,
        test_iters * args.global_batch_size,
    ]

    logger.info("> Building train, validation, and test datasets...")
    try:
        train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
            GPTDataset,
            train_val_test_num_samples,
            core_gpt_dataset_config_from_args(args),
        ).build()
        logger.info("> Finished creating datasets")
        return train_ds, valid_ds, test_ds
    except Exception as e:
        logger.error(f"Failed to build datasets: {e}")
        raise


def _compile_dependencies():
    """Compile dataset C++ code."""
    # Guard get_rank() so the single-GPU path (no initialized process group) does not crash.
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        start_time = time.time()
        logger.info("> Compiling dataset index builder...")
        try:
            from core.datasets.utils import compile_helpers
            compile_helpers()
            logger.info(
                f">>> Done with dataset index builder. "
                f"Compilation time: {time.time() - start_time:.3f} seconds"
            )
        except Exception as e:
            logger.error(f"Failed to compile helpers: {e}")
            raise


def setup_distributed_training():
    """Set up the distributed training environment."""
    try:
        # Initialize process group for distributed training
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        world_size = int(os.environ.get("WORLD_SIZE", "1"))

        if world_size > 1:
            # Multi-GPU setup
            torch.cuda.set_device(local_rank)
            if not torch.distributed.is_initialized():
                torch.distributed.init_process_group(backend="nccl")
            logger.info(f"Distributed training initialized with world size: {world_size}, local rank: {local_rank}")
        else:
            # Single-GPU setup
            logger.info(f"Running on a single GPU (device {local_rank})")
            torch.cuda.set_device(local_rank)

        return local_rank
    except Exception as e:
        logger.error(f"Failed to setup distributed training: {e}")
        raise


def create_and_configure_model(args):
    """Create and configure the model with LoRA."""
    try:
        if args.fp16:
            assert not args.bf16
            args.params_dtype = torch.half
        if args.bf16:
            assert not args.fp16
            args.params_dtype = torch.bfloat16

        logger.info(f"Loading base model from {args.model_name_or_path}")
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            torch_dtype=args.params_dtype,
            cache_dir=args.cache_dir,
        )

        logger.info(f"Configuring LoRA with r={args.lora_r}, alpha={args.lora_alpha}")
        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=args.lora_target_modules,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        logger.info(f"Number of trainable parameters: {trainable_params:,}")

        return model
    except Exception as e:
        logger.error(f"Failed to create and configure model: {e}")
        raise


def main():
    # Setup distributed training
    local_rank = setup_distributed_training()

    # Compile dependencies after initializing the distributed group
    _compile_dependencies()

    # Parse arguments
    args = parse_args()

    # Build tokenizer
    _build_tokenizer(args)

    # Build datasets
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(args)

    # Create and configure model
    model = create_and_configure_model(args)

    # Setup training arguments
    parser = HfArgumentParser(TrainingArguments)
    training_args = parser.parse_dict(args.__dict__, allow_extra_keys=True)[0]

    # Initialize wandb if specified; guard get_rank() when no process group exists
    is_main_process = not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    if args.report_to == "wandb" and is_main_process:
        try:
            wandb.init(
                project=args.wandb_project or "YuE-finetune",
                config=vars(args),
                name=args.run_name,
            )
        except Exception as e:
            logger.warning(f"Failed to initialize wandb: {e}. Continuing without wandb.")

    # Create trainer
    trainer = Trainer(
        model=model,
        tokenizer=_GLOBAL_TOKENIZER,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=valid_ds,
        data_collator=default_data_collator,
    )

    # Start training
    logger.info("Starting training...")
    trainer.train()

    # Save model and tokenizer
    if is_main_process:
        logger.info(f"Saving model to {training_args.output_dir}")
        trainer.save_model(training_args.output_dir)
        _GLOBAL_TOKENIZER.save_pretrained(training_args.output_dir)

    logger.info("Training completed successfully")


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        logger.error(f"Training failed with error: {e}")
        raise
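

# A hypothetical launch sketch, not part of the original script: the entry-point
# filename (train.py) and the exact CLI flag spellings are assumptions derived from
# the attribute names read off `args` above; adjust them to whatever
# core.arguments.parse_args actually defines. Values are placeholders.
#
#   torchrun --nproc_per_node=8 train.py \
#       --model_name_or_path <base_model_path_or_hub_id> \
#       --data_path <tokenized_dataset_prefix> \
#       --lora_r 64 --lora_alpha 32 --lora_dropout 0.05 \
#       --bf16 \
#       --output_dir ./finetune_output \
#       --report_to wandb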