# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

import copy
import json
import os
import pathlib
import re
import warnings
from dataclasses import dataclass

import torch
import torch.distributed as dist
from accelerate.hooks import add_hook_to_module
from transformers import PretrainedConfig, PreTrainedModel
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled

from llava.train.sequence_parallel.globals import get_pg_manager, get_ulysses_sp_pg


def rprint(*args, **kwargs):
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size > 1 and dist.is_initialized():
        return print(f"[dist-{rank}-of-{world_size}]", *args, **kwargs)
    else:
        return print(*args, **kwargs)


def mprint(*args, **kwargs):
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    if world_size > 1 and dist.is_initialized():
        if rank == 0:
            return print(f"[dist-{rank}-of-{world_size}]", *args, **kwargs)
        else:
            return
    else:
        return print(*args, **kwargs)
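

# Usage sketch (illustrative, not part of the original file): in a multi-rank run, `rprint`
# prints from every rank with a rank prefix, while `mprint` prints only from rank 0.
# `shard_idx` below is a hypothetical variable from the caller.
#
#     rprint("loaded shard", shard_idx)   # every rank: "[dist-R-of-W] loaded shard ..."
#     mprint("starting training")         # rank 0 only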


def is_local(model_name_or_path: str) -> bool:
    return os.path.isdir(model_name_or_path)


def get_checkpoint_path(output_dir: str, checkpoint_prefix: str = "checkpoint") -> tuple[str | None, bool]:
    output_dir = os.path.abspath(output_dir)
    pathlib_dir = pathlib.Path(output_dir)

    if list(pathlib_dir.glob("config.json")):
        # training has been finished
        return output_dir, False
    else:
        try:
            ordering_and_checkpoint_path = []
            glob_checkpoints = [
                str(x) for x in pathlib.Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)
            ]
            for path in glob_checkpoints:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match is not None and regex_match.groups() is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
            checkpoints_sorted = sorted(ordering_and_checkpoint_path)
            return checkpoints_sorted[-1][1], True
        except Exception:
            return None, True
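

# Usage sketch (illustrative; `trainer` and `training_args` come from the surrounding training
# script and are assumed here): resume from the newest checkpoint directory when one exists.
#
#     resume_path, is_checkpoint = get_checkpoint_path(training_args.output_dir)
#     if is_checkpoint and resume_path is not None:
#         trainer.train(resume_from_checkpoint=resume_path)
#     else:
#         trainer.train()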


def prepare_config_for_training(
    config: PretrainedConfig, model_args: dataclass, training_args: dataclass, data_args: dataclass
) -> None:
    config.chat_template = model_args.chat_template

    assert model_args.vision_tower is not None, "requires vision tower"
    assert model_args.speech_tower is not None, "requires speech tower"
    assert model_args.sound_tower is not None, "requires sound tower"

    # set module configurations
    if getattr(config, "llm_cfg", None) is None:
        config.llm_cfg = model_args.model_name_or_path
    if getattr(config, "vision_tower_cfg", None) is None:
        config.vision_tower_cfg = model_args.vision_tower
    if getattr(config, "speech_tower_cfg", None) is None:
        config.speech_tower_cfg = model_args.speech_tower
    if getattr(config, "sound_tower_cfg", None) is None:
        config.sound_tower_cfg = model_args.sound_tower
    if getattr(config, "mm_projector_cfg", None) is None:
        config.mm_projector_cfg = model_args.mm_projector
    if getattr(config, "speech_mm_projector_cfg", None) is None:
        config.speech_mm_projector_cfg = model_args.speech_mm_projector
    if getattr(config, "sound_mm_projector_cfg", None) is None:
        config.sound_mm_projector_cfg = model_args.sound_mm_projector

    # set default dtype
    config.model_dtype = torch.bfloat16 if training_args.bf16 else torch.float16
    config.model_dtype = str(config.model_dtype)

    # set tuning modules
    config.tune_language_model = training_args.tune_language_model
    config.tune_vision_tower = training_args.tune_vision_tower
    config.tune_speech_tower = training_args.tune_speech_tower
    config.tune_sound_tower = training_args.tune_sound_tower
    config.tune_mm_projector = training_args.tune_mm_projector
    config.tune_speech_mm_projector = training_args.tune_speech_mm_projector
    config.tune_sound_mm_projector = training_args.tune_sound_mm_projector

    # set data args
    # Get image_aspect_ratio from the config if it is defined there
    # (case of resuming from a checkpoint) or from data_args
    # (i.e. from the command line when starting a new training run).
    if getattr(data_args, "image_aspect_ratio", None) is not None:
        if getattr(config, "image_aspect_ratio", None) is None:
            config.image_aspect_ratio = data_args.image_aspect_ratio
    elif getattr(config, "image_aspect_ratio", None) is not None:
        data_args.image_aspect_ratio = config.image_aspect_ratio
    else:
        raise ValueError("image_aspect_ratio must be set either in data_args or in the pretrained config")

    if (
        hasattr(training_args, "deepspeed")
        and training_args.deepspeed is not None
        and "mics" in training_args.deepspeed
    ):
        config.deepspeed = training_args.deepspeed

    for key, value in model_args.__dict__.items():
        try:
            value = json.loads(value)
        except (TypeError, ValueError):
            # value is not a JSON string; keep it as-is
            pass
        setattr(config, key, value)
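

# Usage sketch (illustrative; `model_args`, `training_args`, and `data_args` are the parsed
# argument dataclasses from the surrounding training script): the pretrained config is loaded
# first and then populated in place before the multimodal model is constructed.
#
#     config = PretrainedConfig.from_pretrained(model_args.model_name_or_path)
#     prepare_config_for_training(config, model_args, training_args, data_args)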


def vision_resolution_elevation(model: PreTrainedModel, config: PretrainedConfig):
    vision_tower = model.get_vision_tower()
    if vision_tower is not None and "radio" not in vision_tower.__class__.__name__.lower():
        vision_tower._maybe_resize_pos_embeds(
            model=vision_tower.vision_tower,
            image_processor=vision_tower.image_processor,
            resolution=getattr(config, "vision_resolution", -1),
            interpolate_mode=getattr(config, "interpolate_mode", "linear"),
        )


def unit_test_rope_scaling(model: PreTrainedModel, config: PretrainedConfig, training_args: dataclass):
    return False


def calculate_loss_weight(labels, ignore_index=-100):
    # (Qinghao): Weighted loss based on num_active_elements
    # To achieve accurate sequence parallel loss calculation, we need to get
    # the real active_elements of each sequence partition.
    # For data parallelism, the loss remains almost the same (and is also more accurate).
    shift_labels = labels[..., 1:].contiguous()
    shift_labels = shift_labels.view(-1)

    padding_mask = shift_labels.eq(ignore_index)  # IGNORE_INDEX = -100 by default
    num_active_elements = padding_mask.numel() - padding_mask.long().sum()
    # global_active_sum = copy.deepcopy(num_active_elements)
    global_active_sum = num_active_elements.detach().clone()
    dist.all_reduce(global_active_sum)
    loss_weight = num_active_elements / global_active_sum * dist.get_world_size()
    return loss_weight
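

# Usage sketch (illustrative): the weight returned by `calculate_loss_weight` rescales a
# mean-reduced token loss so each rank contributes in proportion to its number of non-ignored
# target tokens. `logits`, `labels`, and `vocab_size` are assumed to come from the model's
# forward pass.
#
#     shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
#     shift_labels = labels[..., 1:].contiguous().view(-1)
#     loss = torch.nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
#     loss = loss * calculate_loss_weight(labels)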


def reshard_hiddne_states_and_labels(hidden_states, labels):
    PROCESS_GROUP_MANAGER = get_pg_manager()
    sp_degree = PROCESS_GROUP_MANAGER.sp_degree
    sp_rank = PROCESS_GROUP_MANAGER.sp_rank
    sp_group = PROCESS_GROUP_MANAGER.ulysses_pg
    from llava.constants import IGNORE_INDEX

    # Get the seq len on different sp ranks
    bs, shard_seqlen = labels.shape
    ulysses_seq_len = [torch.zeros(1, dtype=torch.int64, device=labels.device) for _ in range(sp_degree)]
    dist.barrier(group=sp_group)
    dist.all_gather(ulysses_seq_len, torch.tensor(shard_seqlen, device=labels.device), group=sp_group)
    dist.barrier(group=sp_group)
    global_seq_len = torch.cat(ulysses_seq_len, dim=0)

    # Gather all labels and flatten them
    all_labels = [
        torch.zeros(bs, seq_len, dtype=labels.dtype, device=labels.device).contiguous() for seq_len in ulysses_seq_len
    ]
    dist.all_gather(all_labels, labels.contiguous(), group=sp_group)
    # flatten_global_labels = torch.cat(all_labels, dim=1)[:, 1:].view(-1)
    flatten_global_labels = torch.cat(all_labels, dim=1)[:, 1:].contiguous().view(-1)

    # Get the indices where label != IGNORE_INDEX
    flatten_label_mask = flatten_global_labels.ne(IGNORE_INDEX)
    flatten_effective_label_index = flatten_label_mask.nonzero(as_tuple=True)

    # Pad the effective_label_index if its length is smaller than sp_degree
    if flatten_effective_label_index[0].shape[0] < sp_degree:
        warnings.warn(
            f"The effective label length {flatten_effective_label_index[0].shape[0]} is smaller than sp_degree {sp_degree}, padding the index"
        )
        repeat_num = sp_degree // flatten_effective_label_index[0].shape[0] + 1
    else:
        repeat_num = 1

    # Reconstruct the labels by selecting from the global labels
    effective_global_labels = flatten_global_labels[flatten_effective_label_index]
    if repeat_num > 1:
        effective_global_labels = effective_global_labels.repeat(repeat_num)

    # Global effective sequence length
    global_effective_seq_len = effective_global_labels.shape[0]
    reshard_size = global_effective_seq_len // sp_degree

    # Hyperparameters to reshard the hidden states and labels
    if sp_rank == 0:
        original_start_id = 0
        original_end_id = torch.sum(global_seq_len[: sp_rank + 1]).item()
        start_id = 0
        end_id = reshard_size * (sp_rank + 1)
    elif sp_rank == sp_degree - 1:
        original_start_id = torch.sum(global_seq_len[:sp_rank]).item()
        original_end_id = torch.sum(global_seq_len[: sp_rank + 1]).item()
        start_id = reshard_size * sp_rank
        end_id = global_effective_seq_len
    else:
        original_start_id = torch.sum(global_seq_len[:sp_rank]).item()
        original_end_id = torch.sum(global_seq_len[: sp_rank + 1]).item()
        start_id = reshard_size * sp_rank
        end_id = reshard_size * (sp_rank + 1)

    # Get the local labels
    effective_local_labels = torch.narrow(effective_global_labels, 0, start_id, end_id - start_id)

    # Gather all hidden states and flatten them
    # all_hidden_states = [torch.zeros(bs, seq_len, hidden_states.shape[-1], dtype=hidden_states.dtype, device=hidden_states.device, requires_grad=True).contiguous() for seq_len in ulysses_seq_len]
    all_hidden_states = torch.zeros(
        bs, torch.sum(global_seq_len), hidden_states.shape[-1], dtype=hidden_states.dtype, device=hidden_states.device
    ).contiguous()
    all_hidden_states[:, original_start_id:original_end_id, :] += hidden_states
    dist.barrier(group=sp_group)
    dist.all_reduce(all_hidden_states, group=sp_group)
    dist.barrier(group=sp_group)
    flatten_global_hidden_states = all_hidden_states[:, :-1, :].contiguous().view(-1, hidden_states.shape[-1])

    # Get the local hidden states
    effective_flatten_global_hidden_states = flatten_global_hidden_states[flatten_effective_label_index]
    if repeat_num > 1:
        effective_flatten_global_hidden_states = effective_flatten_global_hidden_states.repeat(repeat_num, 1)
    effective_local_hidden_states = torch.narrow(effective_flatten_global_hidden_states, 0, start_id, end_id - start_id)

    return effective_local_hidden_states, effective_local_labels


def sp_loss_rescale(shift_labels, loss):
    from llava.constants import IGNORE_INDEX

    PROCESS_GROUP_MANAGER = get_pg_manager()
    labels_mask = shift_labels.ne(IGNORE_INDEX)  # IGNORE_INDEX = -100 by default
    num_active_elements = torch.sum(labels_mask)
    global_active_sum = copy.deepcopy(num_active_elements)
    # dist.barrier(group=get_ulysses_sp_pg())
    dist.all_reduce(global_active_sum, group=get_ulysses_sp_pg())
    # print(loss.shape, num_active_elements.shape, global_active_sum.shape)
    loss = loss * num_active_elements / global_active_sum
    dist.all_reduce(loss, group=get_ulysses_sp_pg())
    return loss
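

# Usage sketch (illustrative): in a sequence-parallel forward pass, hidden states and labels
# are first resharded so every SP rank holds roughly the same number of effective targets, the
# local loss is computed, and `sp_loss_rescale` then combines it across the SP group weighted
# by each rank's active-token count. `hidden_states`, `labels`, and `lm_head` are assumed to
# come from the surrounding model code.
#
#     local_hidden, local_labels = reshard_hiddne_states_and_labels(hidden_states, labels)
#     local_logits = lm_head(local_hidden)
#     loss = torch.nn.functional.cross_entropy(local_logits, local_labels, ignore_index=-100)
#     loss = sp_loss_rescale(local_labels, loss)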