import time
from contextlib import suppress
import numpy as np
import torch
from tqdm import tqdm
import datetime
import os
import gc
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
    enable_wrap,
    wrap,
)
from torch.utils.tensorboard import SummaryWriter
import logging

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(message)s',
    datefmt='%m/%d %I:%M:%S',
)


def get_cast_dtype(precision: str):
    cast_dtype = None
    if precision == "bf16":
        cast_dtype = torch.bfloat16
    elif precision == "fp16":
        cast_dtype = torch.float16
    return cast_dtype


def get_autocast(precision):
    if precision == "amp_fp16":
        return lambda: torch.cuda.amp.autocast(dtype=torch.float16)
    elif precision == "amp_bfloat16" or precision == "amp_bf16":
        # amp_bfloat16 is more stable than amp_fp16 for CLIP training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    else:
        return suppress
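
# Note: get_autocast returns a zero-argument factory, so call sites write
# `with autocast():`. For non-AMP precisions it returns contextlib.suppress,
# which called with no arguments is a no-op context manager, so the same call
# site works for every precision setting.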


def get_sync(model, flag):
    if flag:
        return suppress
    else:
        return lambda: model.no_sync()
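
# Note: get_sync is a DDP gradient-accumulation helper: when `flag` is False
# (an accumulation micro-step) it returns model.no_sync, which skips the
# gradient all-reduce. It is currently unused in the loop below (see the
# commented-out `do_sync = get_sync(...)` line).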


def train_one_epoch(
    args,
    model,
    laion_loader,
    pile_loader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    writer: SummaryWriter,
    optim_groups,
    scaler,
    total_laion_token: int,
    total_pile_token: int,
    total_laion_sample: int,
    total_step: int,
):
    world_size = torch.distributed.get_world_size()
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)
    media_token_id = tokenizer("<|#image#|>", add_special_tokens=False)["input_ids"][-1]
    endofmedia_token_id = tokenizer("<|#endofimage#|>", add_special_tokens=False)["input_ids"][-1]
    visual_token_id = tokenizer("<|#visual#|>", add_special_tokens=False)["input_ids"][-1]
    if args.add_box:
        box_token_id = tokenizer("<|#box#|>", add_special_tokens=False)["input_ids"][-1]
        endofobject_token_id = tokenizer("<|#endofobject#|>", add_special_tokens=False)["input_ids"][-1]
        endofattr_token_id = tokenizer("<|#endofattr#|>", add_special_tokens=False)["input_ids"][-1]
    if args.use_format_v2:
        prebox_token_id = tokenizer("<|#prebox#|>", add_special_tokens=False)["input_ids"][-1]
        previsual_token_id = tokenizer("<|#previsual#|>", add_special_tokens=False)["input_ids"][-1]
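    # Note: the special markers above are assumed to be single tokens in the
    # vocabulary; indexing with [-1] defensively takes the last piece in case
    # the tokenizer splits off a prefix (e.g. a leading-space token), so it
    # still yields the marker's own id.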
    if args.rank == 0:
        logging.info(f"train from: {total_step} step")
    model.train()
    # loop through dataloader
    last_logging_step = total_step
    last_save_step = total_step
    for num_steps, (batch_laion, batch_pile) in tqdm(
        enumerate(zip(laion_loader, pile_loader)),
        disable=args.rank != 0 or "SLURM_PROCID" in os.environ,
        total=args.num_steps * args.gradient_accumulation_steps,
        initial=total_step * args.gradient_accumulation_steps,
    ):
        #### LAION FORWARD PASS ####
        images = (
            batch_laion[0]
            .to(device_id, dtype=cast_dtype, non_blocking=True)
            .unsqueeze(1)
            .unsqueeze(1)
        )
        image_nums = batch_laion[1]
        image_start_index_list = batch_laion[2]
        # TODO: OPT model: input_ids is not started with </s> while input_ids2 is?
        input_ids = batch_laion[3].to(device_id, non_blocking=True).long()
        attention_mask = batch_laion[4].to(device_id, dtype=cast_dtype, non_blocking=True)
        added_bbox_list = [x.to(device_id) for x in batch_laion[5]]  # list object
        total_laion_token += int(attention_mask.sum().long()) * world_size
        total_laion_sample += sum(image_nums) * world_size
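        # Positions set to -100 below are excluded from the LM loss: -100 is the
        # default ignore_index of torch.nn.CrossEntropyLoss, so pad, media, and
        # grounding marker tokens never contribute to the caption loss.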
        labels = input_ids.clone()
        if args.add_box:
            labels[input_ids == visual_token_id] = -100
            labels[input_ids == box_token_id] = -100
            labels[input_ids == endofattr_token_id] = -100
            if args.use_format_v2:
                labels[input_ids == previsual_token_id] = -100
                labels[input_ids == prebox_token_id] = -100
                # also mask the token following each prebox/box marker
                labels[torch.roll(input_ids == prebox_token_id, 1)] = -100
                labels[torch.roll(input_ids == box_token_id, 1)] = -100
        labels[:, 0] = -100
        labels[input_ids == tokenizer.pad_token_id] = -100
        labels[input_ids == media_token_id] = -100
        labels[input_ids == endofmedia_token_id] = -100
        labels = labels.to(device_id)  # .to() is not in-place; keep the assignment
        current_laion_num = input_ids.shape[0]
        #### PILE FORWARD PASS ####
        if batch_pile is not None and batch_pile[0] is not None and batch_pile[1] is not None:
            input_ids2 = batch_pile[0].to(device_id, non_blocking=True).long()
            attention_mask2 = batch_pile[1].to(device_id, dtype=cast_dtype, non_blocking=True)
            input_length = input_ids.shape[-1]
            # right-pad the pile batch to the LAION sequence length so the two
            # batches can be concatenated along the batch dimension
            input_ids2 = torch.cat([input_ids2, torch.ones((input_ids2.shape[0], input_length - input_ids2.shape[1]), device=input_ids2.device, dtype=input_ids2.dtype) * tokenizer.pad_token_id], dim=-1)
            attention_mask2 = torch.cat([attention_mask2, torch.zeros((attention_mask2.shape[0], input_length - attention_mask2.shape[1]), device=attention_mask2.device, dtype=attention_mask2.dtype)], dim=-1)
            labels2 = input_ids2.clone()
            labels2[labels2 == tokenizer.pad_token_id] = -100
            labels2[:, 0] = -100
            labels2 = labels2.to(device_id)  # .to() is not in-place; keep the assignment
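            # The text-only pile batch is folded into the LAION batch only every
            # `pile_freq` steps (or every step when pile_freq == 1); the padded
            # pile rows carry no images, hence the zero image counts and empty
            # index lists below.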
            if (num_steps != 0 and num_steps % args.pile_freq == 0) or args.pile_freq == 1:
                image_nums = image_nums + [0] * len(input_ids2)
                image_start_index_list = image_start_index_list + [[]] * len(input_ids2)
                input_ids = torch.cat([input_ids, input_ids2], dim=0)
                attention_mask = torch.cat([attention_mask, attention_mask2], dim=0)
                labels = torch.cat([labels, labels2], dim=0)
                total_pile_token += int(attention_mask2.sum().long()) * world_size
            else:
                del input_ids2
                del attention_mask2
                del labels2
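        # In instruction mode, mask everything up to and including the token
        # after " Answer" (assumed to be the separator, e.g. ":"), so the loss
        # is computed only on the answer tokens.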
        if args.instruct:
            answer_token_id = tokenizer(" Answer").input_ids[0]
            answer_token_loc = (input_ids == answer_token_id).nonzero()
            for batch_idx, idx in answer_token_loc:
                labels[batch_idx][:idx + 2] = -100

        if args.relation and not args.instruct:
            relations = batch_laion[6]
        else:
            relations = None
        if len(added_bbox_list) == 0:
            added_bbox_list = None
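        # With gradient accumulation, parameters are updated only when
        # update_flag is True, i.e. every `gradient_accumulation_steps`
        # micro-steps; the other steps just accumulate gradients via backward().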
        update_flag = (num_steps != 0 and num_steps % args.gradient_accumulation_steps == 0) or args.gradient_accumulation_steps == 1
        # do_sync = get_sync(model, update_flag)
        with autocast():
            # modify:
            # /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/codegen/modeling_codegen.py
            # /gpfs/u/home/LMCG/LMCGljnn/scratch/miniconda3-ppc64le/envs/unified/lib/python3.9/site-packages/transformers/models/opt/modeling_opt.py
            # CrossEntropyLoss(reduction="none")
            outputs = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                image_nums=image_nums,
                image_start_index_list=image_start_index_list,
                added_bbox_list=added_bbox_list,
                add_box=args.add_box,
                relations=relations,
            )
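            # outputs.loss is assumed to be per-token here (see the note above
            # about patching the LM to CrossEntropyLoss(reduction="none"));
            # reshape to (batch, seq) and average each row over its non-zero
            # (supervised) entries to get a per-sample loss.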
            loss_total = outputs.loss.reshape(labels.shape[0], -1)
            loss_sample = loss_total.sum(-1) / (loss_total != 0).sum(-1)
            loss_sample_for_laion = loss_sample[:current_laion_num]
            nan_mask = torch.isnan(loss_sample_for_laion)
            if nan_mask.sum() > 0:
                logging.warning(f"caption NaN: {nan_mask}")
            if nan_mask.sum() == len(loss_sample_for_laion) or not model.valid:
                logging.info("WARNING: skip this caption loss due to some error")
                loss_laion = torch.tensor(0.0).cuda()
            else:
                loss_laion = loss_sample_for_laion[~nan_mask].mean()
            loss_caption = loss_laion
            divided_loss_laion = loss_laion / args.gradient_accumulation_steps
            if current_laion_num != loss_sample.shape[0]:
                loss_pile = loss_sample[current_laion_num:].mean()
            else:
                loss_pile = torch.tensor(0.0).cuda()
            divided_loss_pile = loss_pile / args.gradient_accumulation_steps

            if "detection_losses" in outputs:
                loss_det = outputs["detection_losses"]["loss"]
                loss_iou = outputs["detection_losses"]["loss_iou"]
                loss_obj = outputs["detection_losses"]["loss_obj"]
                loss_cls = outputs["detection_losses"]["loss_cls"]
            else:
                loss_det = torch.tensor(0.0).cuda()
                loss_iou = torch.tensor(0.0).cuda()
                loss_obj = torch.tensor(0.0).cuda()
                loss_cls = torch.tensor(0.0).cuda()

            if "loss_dict" in outputs:
                visual_loss_iou = outputs["loss_dict"][0]["loss_iou"]
                previsual_loss_iou = outputs["loss_dict"][1]["loss_iou"]
                visual_loss_obj = outputs["loss_dict"][0]["loss_obj"]
                previsual_loss_obj = outputs["loss_dict"][1]["loss_obj"]
            else:
                visual_loss_iou = torch.tensor(0.0).cuda()
                previsual_loss_iou = torch.tensor(0.0).cuda()
                visual_loss_obj = torch.tensor(0.0).cuda()
                previsual_loss_obj = torch.tensor(0.0).cuda()

            divided_loss_det = loss_det / args.gradient_accumulation_steps
            loss_rel = outputs.get("rel_loss", torch.tensor(0.0).cuda())
            divided_loss_rel = loss_rel / args.gradient_accumulation_steps
            loss = (
                divided_loss_laion * args.loss_multiplier_laion +
                divided_loss_pile * args.loss_multiplier_pile +
                divided_loss_det * args.loss_multiplier_det +
                divided_loss_rel * args.loss_multiplier_rel
            )
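        # Each component was divided by gradient_accumulation_steps above, so
        # the accumulated gradient averages over the accumulation window rather
        # than summing across micro-steps.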
        scaler.scale(loss).backward()

        # for logging only
        loss = (
            loss_laion * args.loss_multiplier_laion
            + loss_pile * args.loss_multiplier_pile
            + loss_det * args.loss_multiplier_det
            + loss_rel * args.loss_multiplier_rel
        ).detach()
        # step optimizer and log
        if update_flag:
            #### MASK GRADIENTS FOR EMBEDDINGS ####
            # Note (anas): Do not apply weight decay to embeddings as it will break this function.
            # ! not an important point
            # if args.ddp:
            #     def mask_embedding(m):
            #         if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad:
            #             zero_mask = torch.zeros_like(m.weight.grad)
            #             zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
            #             zero_mask[endofmedia_token_id] = torch.ones_like(zero_mask[endofmedia_token_id])
            #             m.weight.grad = m.weight.grad * zero_mask
            #     model.apply(mask_embedding)
            total_step += 1
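            # Unscale before clipping so the max-norm applies to the true
            # gradient magnitudes; scaler.step() then skips the optimizer step
            # if any gradient contains inf/NaN.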
            scaler.unscale_(optimizer)
            if args.ddp:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            else:
                model.clip_grad_norm_(1.0)
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            optimizer.zero_grad()
            # https://github.com/facebookresearch/fairscale/issues/627
            model.zero_grad(set_to_none=True)
        if args.rank == 0 and total_step % args.logging_steps == 0 and total_step != last_logging_step:
            last_logging_step = total_step
            global_step = total_step
            lr = optimizer.param_groups[0]["lr"]
            writer.add_scalar("lr", lr, global_step)
            writer.add_scalar("scale", scaler.get_scale(), global_step)
            writer.add_scalar("loss_groundcaption", loss_laion.item(), global_step)
            writer.add_scalar("loss_laion", loss_caption.item(), global_step)
            writer.add_scalar("loss_pile", loss_pile.item(), global_step)
            writer.add_scalar("loss", loss.item(), global_step)
            writer.add_scalar("loss_det", loss_det.item(), global_step)
            writer.add_scalar("loss_iou", loss_iou.item(), global_step)
            writer.add_scalar("loss_obj", loss_obj.item(), global_step)
            writer.add_scalar("loss_cls", loss_cls.item(), global_step)
            if loss_rel.item() != 0:
                writer.add_scalar("loss_rel", loss_rel.item(), global_step)
            if args.use_format_v2:
                writer.add_scalar("loss_iou_visual", visual_loss_iou.item(), global_step)
                writer.add_scalar("loss_obj_visual", visual_loss_obj.item(), global_step)
                writer.add_scalar("loss_iou_previsual", previsual_loss_iou.item(), global_step)
                writer.add_scalar("loss_obj_previsual", previsual_loss_obj.item(), global_step)
            global_sample_num = total_laion_sample
            writer.add_scalar("loss_groundcaption_vs_sample_num", loss_laion.item(), global_sample_num)
            writer.add_scalar("loss_laion_vs_sample_num", loss_caption.item(), global_sample_num)
            writer.add_scalar("loss_pile_vs_sample_num", loss_pile.item(), global_sample_num)
            writer.add_scalar("loss_vs_sample_num", loss.item(), global_sample_num)
            writer.add_scalar("loss_det_vs_sample_num", loss_det.item(), global_sample_num)
            writer.add_scalar("loss_iou_vs_sample_num", loss_iou.item(), global_sample_num)
            writer.add_scalar("loss_obj_vs_sample_num", loss_obj.item(), global_sample_num)
            if loss_rel.item() != 0:
                writer.add_scalar("loss_rel_vs_sample_num", loss_rel.item(), global_sample_num)
            writer.add_scalar("lr_vs_sample_num", optimizer.param_groups[0]["lr"], global_sample_num)
            writer.add_scalar("loss_groundcaption_vs_token", loss_laion.item(), total_laion_token)
            writer.add_scalar("loss_laion_vs_token", loss_caption.item(), total_laion_token)
            writer.add_scalar("loss_pile_vs_token", loss_pile.item(), total_pile_token)
            writer.add_scalar("loss_det_vs_token", loss_det.item(), total_laion_token)
            writer.add_scalar("loss_iou_vs_token", loss_iou.item(), total_laion_token)
            writer.add_scalar("loss_obj_vs_token", loss_obj.item(), total_laion_token)
            writer.add_scalar("loss_cls_vs_token", loss_cls.item(), total_laion_token)
            if loss_rel.item() != 0:
                writer.add_scalar("loss_rel_vs_token", loss_rel.item(), total_laion_token)
            total_token = total_laion_token + total_pile_token
            writer.add_scalar("sample_num", global_sample_num, global_step)
            writer.add_scalar("total_laion_token", total_laion_token, global_step)
            writer.add_scalar("total_pile_token", total_pile_token, global_step)
            writer.add_scalar("total_token", total_token, global_step)
            logging.info(
                f"[{global_step}][{total_laion_sample}][{total_token}]. total: {loss.item():.3f} // laion: {loss_caption.item():.3f} // pile: {loss_pile.item():.3f} // iou: {loss_iou.item():.4f} // obj: {loss_obj.item():.4f} // previsual_obj: {previsual_loss_obj.item():.4f} // visual_obj: {visual_loss_obj.item():.4f} // previsual_iou: {previsual_loss_iou.item():.4f} // visual_iou: {visual_loss_iou.item():.4f} // lr: {lr:.2e} // scale: {scaler.get_scale()}"
            )
        if total_step % args.save_interval == 0 and total_step != last_save_step:
            last_save_step = total_step
            torch.distributed.barrier()
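            # Under FSDP (the non-DDP branch), FULL_STATE_DICT with
            # offload_to_cpu and rank0_only gathers the full unsharded weights
            # onto rank 0's CPU, so only rank 0 materializes and saves the
            # complete model.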
            if args.ddp:
                cpu_state = model.state_dict()
                # if args.rank == 0:
                #     optimizer_state = optimizer.state_dict()
            else:
                save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
                with FSDP.state_dict_type(
                    model, StateDictType.FULL_STATE_DICT, save_policy
                ):
                    cpu_state = model.state_dict()
                torch.distributed.barrier()
                # https://pytorch.org/docs/1.12/fsdp.html
                # need to pass optim_groups as optim_input
                # optimizer_state = FSDP.full_optim_state_dict(model, optimizer, optim_input=optim_groups)

            if args.rank == 0:
                checkpoint_dict = {
                    "model_state_dict": cpu_state,
                    # "optimizer_state_dict": optimizer_state,
                    "lr_scheduler_state_dict": lr_scheduler.state_dict(),
                    "scaler_state_dict": scaler.state_dict(),
                    "total_pile_token": total_pile_token,
                    "total_laion_token": total_laion_token,
                    "total_laion_sample": total_laion_sample,
                    "total_step": total_step,
                }
                logging.info(f"Saving checkpoint to {args.run_name}/checkpoint_{total_step}.pt")
                torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{total_step}.pt")
                del checkpoint_dict
                if args.delete_previous_checkpoint and total_step - args.save_interval > 0 and (total_step - args.save_interval) % args.skip_delete_pattern != 0:
                    try:
                        os.remove(f"{args.run_name}/checkpoint_{total_step - args.save_interval}.pt")
                    except OSError:  # the previous checkpoint may already be gone
                        pass
            torch.distributed.barrier()


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
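
# Illustrative usage of AverageMeter (not called in this file):
#   meter = AverageMeter()
#   meter.update(loss.item(), n=images.shape[0])  # weighted by batch size
#   meter.avg  # running mean over all updates so far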