# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import functools
import os

import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    BackwardPrefetch,
    ShardingStrategy,
    FullStateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from safetensors.torch import load_file, save_file

from modeling.bagel.modeling_utils import MLPconnector, TimestepEmbedder, PositionEmbedding
from modeling.bagel.qwen2_navit import (
    Qwen2DecoderLayer,
    Qwen2MoEDecoderLayer,
    Qwen2MoTDecoderLayer,
)
from modeling.bagel.siglip_navit import SiglipEncoderLayer, SiglipVisionTransformer


class FSDPConfig:
    def __init__(
        self,
        sharding_strategy,
        backward_prefetch,
        cpu_offload,
        num_replicate,
        num_shard=8,
    ):
        self.sharding_strategy = sharding_strategy
        self.backward_prefetch = backward_prefetch
        self.cpu_offload = cpu_offload
        self.num_replicate = num_replicate
        self.num_shard = num_shard


def fsdp_wrapper(original_model, fsdp_config, ignored_modules=[]):
    if fsdp_config.sharding_strategy == 'HYBRID_SHARD':
        # HYBRID_SHARD shards parameters within each group of `num_shard` ranks and
        # replicates across the `num_replicate` groups.
        device_mesh = init_device_mesh(
            "cuda",
            mesh_shape=(fsdp_config.num_replicate, fsdp_config.num_shard),
            mesh_dim_names=("replicate", "shard"),
        )
    else:
        device_mesh = None

    return FSDP(
        original_model,
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                Qwen2DecoderLayer,
                Qwen2MoEDecoderLayer,
                Qwen2MoTDecoderLayer,
                SiglipEncoderLayer,
                SiglipVisionTransformer,
                MLPconnector,
                TimestepEmbedder,
                PositionEmbedding,
            },
        ),
        ignored_modules=ignored_modules,
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.bfloat16,
            buffer_dtype=torch.bfloat16,
        ),
        device_id=dist.get_rank() % torch.cuda.device_count(),
        sharding_strategy=ShardingStrategy[fsdp_config.sharding_strategy],
        backward_prefetch=BackwardPrefetch[fsdp_config.backward_prefetch],
        cpu_offload=CPUOffload(offload_params=fsdp_config.cpu_offload),
        device_mesh=device_mesh,
    )
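
# Example (illustrative sketch, not part of the original module): wrapping a model
# with HYBRID_SHARD across replica groups of 8 GPUs each. `bagel_model` is a
# hypothetical module name; the enum names ("HYBRID_SHARD", "BACKWARD_PRE") come
# from torch.distributed.fsdp.
#
#   fsdp_config = FSDPConfig(
#       sharding_strategy="HYBRID_SHARD",
#       backward_prefetch="BACKWARD_PRE",
#       cpu_offload=False,
#       num_replicate=dist.get_world_size() // 8,
#       num_shard=8,
#   )
#   bagel_model = fsdp_wrapper(bagel_model, fsdp_config)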


class FSDPCheckpoint:
    @staticmethod
    def fsdp_save_ckpt(
        ckpt_dir,
        train_steps,
        model,
        ema_model,
        optimizer,
        scheduler,
        data_status,
        logger,
        fsdp_config,
    ):
        save_path = os.path.join(ckpt_dir, f"{train_steps:07d}")
        os.makedirs(save_path, exist_ok=True)
        logger.info(f"Saving checkpoint to {save_path}.")

        # Gather full (unsharded) state dicts onto rank 0 and write them as safetensors.
        if ema_model is not None:
            with FSDP.state_dict_type(
                ema_model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
            ):
                ema_state_dict = ema_model.state_dict()
                if dist.get_rank() == 0:
                    save_file(ema_state_dict, os.path.join(save_path, "ema.safetensors"))

        with FSDP.state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
        ):
            model_state_dict = model.state_dict()
            if dist.get_rank() == 0:
                save_file(model_state_dict, os.path.join(save_path, "model.safetensors"))

        # Optimizer state stays sharded: each shard writes its own file.
        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
            if fsdp_config.sharding_strategy == "FULL_SHARD":
                shard_index = dist.get_rank()
                total_shards = dist.get_world_size()
            elif fsdp_config.sharding_strategy == "HYBRID_SHARD":
                shard_index = dist.get_rank() % fsdp_config.num_shard
                total_shards = fsdp_config.num_shard
            else:
                raise NotImplementedError

            optimizer_save_path = os.path.join(
                save_path, f"optimizer.{shard_index:05d}-of-{total_shards:05d}.pt"
            )
            if fsdp_config.sharding_strategy == "FULL_SHARD":
                torch.save(optimizer.state_dict(), optimizer_save_path)
            elif fsdp_config.sharding_strategy == "HYBRID_SHARD":
                # Shards are replicated across replica groups, so only the first
                # replica group (ranks 0..num_shard-1) writes optimizer files.
                if dist.get_rank() < fsdp_config.num_shard:
                    torch.save(optimizer.state_dict(), optimizer_save_path)
            else:
                raise NotImplementedError

        if dist.get_rank() == 0 and scheduler is not None:
            torch.save(scheduler.state_dict(), os.path.join(save_path, "scheduler.pt"))
        if dist.get_rank() == 0 and data_status is not None:
            torch.save(data_status, os.path.join(save_path, "data_status.pt"))
        dist.barrier()
        return
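    # Example (sketch, not from the original training script): periodic saving from
    # the training loop; `save_every` is a hypothetical interval kept by the caller.
    #
    #   if train_steps % save_every == 0:
    #       FSDPCheckpoint.fsdp_save_ckpt(
    #           ckpt_dir, train_steps, model, ema_model,
    #           optimizer, scheduler, data_status, logger, fsdp_config,
    #       )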
    @staticmethod
    def try_load_ckpt(resume_from, logger, model, ema_model=None, resume_from_ema=False):
        if resume_from is not None and os.path.exists(resume_from):
            logger.info(f"Loading checkpoint from {resume_from}.")
            if resume_from_ema:
                model_state_dict_path = os.path.join(resume_from, "ema.safetensors")
            else:
                model_state_dict_path = os.path.join(resume_from, "model.safetensors")
            model_state_dict = load_file(model_state_dict_path, device="cpu")
            # NOTE: position embeds are fixed sinusoidal embeddings, so we can just pop
            # them off, which makes it easier to adapt to different resolutions.
            model_state_dict.pop('latent_pos_embed.pos_embed')
            model_state_dict.pop('vit_pos_embed.pos_embed')
            msg = model.load_state_dict(model_state_dict, strict=False)
            logger.info(msg)
            del model_state_dict

            if ema_model is not None:
                ema_state_dict_path = os.path.join(resume_from, "ema.safetensors")
                if not os.path.exists(ema_state_dict_path):
                    # No EMA weights in the checkpoint; initialize EMA from the model weights.
                    logger.info(f"Replicating EMA model from {model_state_dict_path}.")
                    ema_state_dict_path = model_state_dict_path
                ema_state_dict = load_file(ema_state_dict_path, device="cpu")
                # NOTE: position embeds are fixed sinusoidal embeddings, so we can just pop
                # them off, which makes it easier to adapt to different resolutions.
                ema_state_dict.pop('latent_pos_embed.pos_embed')
                ema_state_dict.pop('vit_pos_embed.pos_embed')
                msg = ema_model.load_state_dict(ema_state_dict, strict=False)
                logger.info(msg)
                del ema_state_dict
        else:
            logger.info("Training from scratch.")
        return model, ema_model
    @staticmethod
    def try_load_train_state(resume_from, optimizer, scheduler, fsdp_config):
        if resume_from is not None and os.path.exists(resume_from):
            # Each rank loads the optimizer shard it wrote in fsdp_save_ckpt.
            if fsdp_config.sharding_strategy == "FULL_SHARD":
                shard_index = dist.get_rank()
                total_shards = dist.get_world_size()
            elif fsdp_config.sharding_strategy == "HYBRID_SHARD":
                shard_index = dist.get_rank() % fsdp_config.num_shard
                total_shards = fsdp_config.num_shard
            else:
                raise NotImplementedError

            optimizer_state_dict_path = os.path.join(
                resume_from, f"optimizer.{shard_index:05d}-of-{total_shards:05d}.pt"
            )
            optimizer_state_dict = torch.load(
                optimizer_state_dict_path, map_location="cpu", weights_only=True
            )
            optimizer.load_state_dict(optimizer_state_dict)
            del optimizer_state_dict

            scheduler_state_dict_path = os.path.join(resume_from, "scheduler.pt")
            scheduler_state_dict = torch.load(
                scheduler_state_dict_path, map_location="cpu", weights_only=True
            )
            scheduler.load_state_dict(scheduler_state_dict)
            del scheduler_state_dict

            # The checkpoint directory name is the step index at which it was saved.
            train_steps = int(os.path.basename(os.path.normpath(resume_from))) + 1

            """
            data_status = [
                {
                    dataset_name: {
                        worker_id: [parquet_idx, row_group_id, row_idx],
                    },
                },
            ]
            """
            data_status_path = os.path.join(resume_from, "data_status.pt")
            if os.path.exists(data_status_path):
                data_status = torch.load(data_status_path, weights_only=True, map_location="cpu")
                local_rank = dist.get_rank()
                if local_rank < len(data_status):
                    data_status = data_status[local_rank]
                else:
                    data_status = None
            else:
                data_status = None
        else:
            train_steps = 0
            data_status = None
        return optimizer, scheduler, train_steps, data_status
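
# Example (sketch, call sites assumed): resuming from a step directory written by
# fsdp_save_ckpt, e.g. <ckpt_dir>/0005000.
#
#   model, ema_model = FSDPCheckpoint.try_load_ckpt(resume_from, logger, model, ema_model)
#   optimizer, scheduler, train_steps, data_status = FSDPCheckpoint.try_load_train_state(
#       resume_from, optimizer, scheduler, fsdp_config
#   )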


def grad_checkpoint_check_fn(module):
    module_options = (
        Qwen2DecoderLayer,
        SiglipEncoderLayer,
        MLPconnector,
        Qwen2MoEDecoderLayer,
        Qwen2MoTDecoderLayer,
    )
    return isinstance(module, module_options)
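
# The check_fn above selects which submodules get activation checkpointing. A
# possible call site (sketch; the actual training script is not shown here) uses
# PyTorch's helper:
#
#   from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
#       CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper,
#   )
#   wrapper_fn = functools.partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
#   apply_activation_checkpointing(model, checkpoint_wrapper_fn=wrapper_fn, check_fn=grad_checkpoint_check_fn)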


def fsdp_ema_setup(ema_model, fsdp_config, ignored_modules=[]):
    # The EMA copy never receives gradients; freeze it before wrapping with FSDP.
    for param in ema_model.parameters():
        param.requires_grad = False
    ema_model = fsdp_wrapper(ema_model, fsdp_config, ignored_modules=ignored_modules)
    return ema_model


def fsdp_ema_update(ema_model, model, decay=0.9999):
    # Update the EMA weights in place on the flattened FSDP shards, without an
    # all-gather. This relies on private FSDP traversal utilities, so both models
    # must be wrapped with identical FSDP configurations.
    ema_handles = traversal_utils._get_fsdp_handles(ema_model)
    new_handles = traversal_utils._get_fsdp_handles(model)
    assert len(ema_handles) == len(new_handles)

    ema_params = []
    new_params = []
    for ema_handle, new_handle in zip(ema_handles, new_handles):
        if ema_handle.flat_param is not None and new_handle.flat_param.requires_grad:
            ema_params.append(ema_handle.flat_param.data)
            new_params.append(new_handle.flat_param.data.to(dtype=ema_handle.flat_param.dtype))

    # ema = decay * ema + (1 - decay) * new, applied to all flat params at once.
    torch._foreach_mul_(ema_params, decay)
    torch._foreach_add_(ema_params, new_params, alpha=1 - decay)
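
# Example (sketch): a typical call site after each optimizer step; `loss` and
# `optimizer` are assumed to come from the surrounding training loop.
#
#   loss.backward()
#   optimizer.step()
#   optimizer.zero_grad()
#   fsdp_ema_update(ema_model, model, decay=0.9999)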