# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import bisect
import copy
import logging
from collections import defaultdict
from typing import List, Union

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset, get_worker_info
from transformers.trainer_pt_utils import LabelSmoother

from .constants import IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN

IGNORE_TOKEN_ID = LabelSmoother.ignore_index
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


class PackedDataset(IterableDataset):
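    """Pack samples from several weighted datasets into dense multimodal sequences.

    Samples are drawn from ``datasets`` in proportion to ``dataset_weight`` and
    merged into buffers. A buffer is yielded once it holds exactly
    ``num_images_expected`` images, once its token count reaches
    ``max_packed_tokens``, or once more than ``max_buffer_size`` buffers are
    pending; with ``strict_mode`` enabled, under-filled buffers are padded with
    blank images before being yielded.
    """
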
    def __init__(
        self,
        tokenizer,
        data_rank,
        data_world_size,
        datasets: List,
        dataset_weight: List[int] = None,
        num_images_expected: int = 6,
        max_packed_tokens: int = 32768,
        max_buffer_size: int = 100,
        log_freq: int = 1000000,
        strict_mode: bool = False,
        debug_mode: bool = False,
        replacement: bool = True,
        allow_overflow: bool = True,
        allow_empty_data: bool = False,
        allow_deduplicated_ds_name: bool = False,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.data_rank = data_rank
        self.data_world_size = data_world_size
        self.datasets = datasets
        self.num_images_expected = num_images_expected
        self.max_buffer_size = max_buffer_size
        self.log_freq = log_freq
        self.strict_mode = strict_mode
        self.debug_mode = debug_mode
        self.replacement = replacement
        self.allow_overflow = allow_overflow
        self.allow_empty_data = allow_empty_data
        self.max_packed_tokens = max_packed_tokens

        self.img_start_token_id = self.tokenizer.convert_tokens_to_ids(IMG_START_TOKEN)
        self.img_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
        self.img_end_token_id = self.tokenizer.convert_tokens_to_ids(IMG_END_TOKEN)

        assert self.img_start_token_id != self.tokenizer.unk_token_id
        assert self.img_token_id != self.tokenizer.unk_token_id
        assert self.img_end_token_id != self.tokenizer.unk_token_id

        if dataset_weight is None:
            dataset_weight = [1] * len(datasets)
        self.dataset_type = [d.dataset_type for d in self.datasets]

        self.datasets_orig = datasets
        self.dataset_weight_orig = [w / sum(dataset_weight) for w in dataset_weight]

        self.datasets = [ds for ds in self.datasets_orig]
        self.dataset_weight = [w for w in self.dataset_weight_orig]

        # lazy init
        self.worker_id = None
        self.worker_state_key = None
        self.dataset_iter_list = None
        self._state_dict = {
            'sample_info': {d.ds_name: 0 for d in self.datasets},
        }

        self.worker_custom_infos = None

        ds_name_list = [d.ds_name for d in self.datasets]
        if not allow_deduplicated_ds_name:
            assert len(ds_name_list) == len(set(ds_name_list)), f'duplicated ds_name found: {ds_name_list}'

        for ds in self.datasets:
            if ds.max_num_images > self.num_images_expected:
                logger.warning(f'{ds.max_num_images=} of {ds.ds_name} is larger than {self.num_images_expected=}')
                ds.max_num_images = num_images_expected

            if ds.max_tokens > self.max_packed_tokens:
                logger.warning(f'{ds.max_tokens=} of {ds.ds_name} is larger than {self.max_packed_tokens=}')
                ds.max_tokens = self.max_packed_tokens

            self._state_dict[ds.ds_name] = {}

        if get_rank() == 0:
            logger.info(
                f'Loaded dataset to pack: {ds_name_list}, '
                f'{self.num_images_expected=}, {self.max_packed_tokens=}, '
                f'{self.replacement=}, {self.allow_overflow=}',
            )

            temp = []
            for ds, ds_w in zip(self.datasets, self.dataset_weight):
                temp.append(f'{ds.ds_name:<25}: {ds_w*100:.2f}%')
            temp = '\n'.join(temp)
            logger.info(
                f'Sampling prob for each dataset:\n{temp}'
            )

        if self.allow_empty_data:
            logger.warning('allow_empty_data is enabled, note that empty data may be generated!')

    def load_state_dict(self, state_dict, custom_infos=None):
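        """Restore per-dataset progress and worker ``custom_infos`` from a checkpointed
        ``state_dict``; datasets missing from it start from scratch."""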
        self.worker_custom_infos = custom_infos

        self._state_dict.update(state_dict)
        for ds in self.datasets:
            if ds.ds_name in self._state_dict:
                ds.load_state_dict(self._state_dict[ds.ds_name])
                logger.info(f'{ds.ds_name=} is resumed.')
            else:
                logger.warning(f'{ds.ds_name=} is not resumed.')

    def _should_log(self):
        worker_id = 0 if get_worker_info() is None else get_worker_info().id
        num_workers = 1 if get_worker_info() is None else get_worker_info().num_workers

        worker_id = num_workers * get_rank() + worker_id
        num_workers = num_workers * get_world_size()

        return worker_id == 0

    def next_data(self, current_dataset_idx):
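        """Fetch the next sample from the dataset at ``current_dataset_idx``.

        Exhausted datasets are restarted when ``self.replacement`` is set; otherwise
        (or when restarting fails) they are dropped and another dataset is picked at
        random. Raises ``StopIteration`` once no dataset is left.
        """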
        while True:
            try:
                current_sample = next(self.dataset_iter_list[current_dataset_idx])
                break  # Exit loop if successful
            except StopIteration:
                if self.replacement:
                    # logger.info(f'[Worker id {self.worker_id}] Dataset {self.datasets[current_dataset_idx].ds_name} is exhausted, restart it.')
                    try:
                        self.dataset_iter_list[current_dataset_idx] = iter(self.datasets[current_dataset_idx])
                        current_sample = next(self.dataset_iter_list[current_dataset_idx])
                        break
                    except Exception:
                        # logger.error(f'{self.worker_id=} Fail to get any data from {self.datasets[current_dataset_idx].ds_name}! length={len(self.datasets)}')
                        self.datasets.pop(current_dataset_idx)
                        self.dataset_iter_list.pop(current_dataset_idx)
                        self.dataset_weight.pop(current_dataset_idx)

                        if len(self.datasets) == 0:
                            raise StopIteration
                        current_dataset_idx = np.random.choice(len(self.datasets))
                else:
                    # logger.error(f'{self.worker_id=} Fail to get any data from {self.datasets[current_dataset_idx].ds_name}! length={len(self.datasets)}')
                    self.datasets.pop(current_dataset_idx)
                    self.dataset_iter_list.pop(current_dataset_idx)
                    self.dataset_weight.pop(current_dataset_idx)

                    if len(self.datasets) == 0:
                        raise StopIteration
                    current_dataset_idx = np.random.choice(len(self.datasets))
            except Exception:
                logger.error('Unexpected error!')
                if len(self.datasets) == 0:
                    raise StopIteration
                current_dataset_idx = np.random.choice(len(self.datasets))

        current_ds_name = self.datasets[current_dataset_idx].ds_name
        current_sample['type_ids'] = torch.zeros_like(current_sample['input_ids']) + current_dataset_idx

        if self.worker_state_key not in self._state_dict[current_ds_name]:
            self._state_dict[current_ds_name][self.worker_state_key] = {}

        meta_info = current_sample.pop('meta_info', {})
        self._state_dict[current_ds_name][self.worker_state_key].update(**meta_info)
        self._state_dict['sample_info'][self.datasets[current_dataset_idx].ds_name] += 1
        return current_sample

    def find_buffer(self, buffer_list, new_sample):
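        """Pop and return the first buffer that can absorb ``new_sample`` within the
        image and token budgets, or ``None`` if no buffer fits. When ``allow_overflow``
        is set and at least ``max_buffer_size // 2`` buffers are pending, the token
        budget may be exceeded as long as the image budget is respected."""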
        # NOTE: use `bisect` to search might be faster
        find = False
        find_idx = -1
        num_images_current = new_sample['pixel_values'].size(0)
        for buffer_idx, buffer in enumerate(buffer_list):
            num_images_buffer = buffer['pixel_values'].size(0)
            if num_images_buffer + num_images_current <= self.num_images_expected:
                num_merged_tokens = new_sample['input_ids'].size(0) + buffer['input_ids'].size(0)

                if num_merged_tokens <= self.max_packed_tokens:
                    find = True
                    find_idx = buffer_idx
                    break

                if self.allow_overflow and len(buffer_list) >= self.max_buffer_size // 2:
                    find = True
                    find_idx = buffer_idx

        if find:
            return buffer_list.pop(find_idx)
        return None

    def update_buffer(self, buffer, new_sample):
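        """Concatenate ``new_sample`` onto ``buffer`` (or start a new buffer), tagging
        its tokens with an incremented ``data_index`` so that the original sample
        boundaries can be recovered at collate time."""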
        if buffer is None:
            new_sample['data_index'] = torch.zeros_like(new_sample['input_ids'])
            return new_sample

        new_sample['data_index'] = torch.ones_like(new_sample['input_ids']) + buffer['data_index'][-1].item()

        assert buffer.keys() == new_sample.keys()
        for k in buffer:
            buffer[k] = torch.cat([buffer[k], new_sample[k]])

        return buffer

    @staticmethod
    def check_valid(sample_to_check, min_active_tokens_ratio=1/256):
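        """Return True if more than ``min_active_tokens_ratio`` of the labels in
        ``sample_to_check`` are supervised (i.e. not ``IGNORE_TOKEN_ID``)."""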
        num_ignore_tokens = (sample_to_check['labels'] == IGNORE_TOKEN_ID).sum()
        num_tokens = sample_to_check['labels'].numel()
        return (1 - num_ignore_tokens / num_tokens) > min_active_tokens_ratio

    @staticmethod
    def split_buffer(buffer, max_tokens, img_start_token_id, img_token_id, img_end_token_id):
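        """Split a packed buffer whose length exceeds ``max_tokens`` into chunks of at
        most ``max_tokens`` tokens, cutting only at image boundaries so that each
        chunk keeps its image tokens aligned with its ``pixel_values``. Left-hand
        chunks are kept only if they contain at least one image and pass
        ``check_valid``; the trailing remainder is kept if it fits within
        ``max_tokens`` and passes ``check_valid``."""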
        if buffer['input_ids'].size(0) <= max_tokens:
            return [buffer]

        def _image_is_splitted(input_ids, cut_idx):
            is_image_start = input_ids[cut_idx].item() == img_start_token_id
            is_image_token = input_ids[cut_idx].item() == img_token_id
            is_image_end = input_ids[cut_idx].item() == img_end_token_id
            return is_image_start or is_image_token or is_image_end

        def _split(sample_to_split, left_idx, right_idx, left_img_idx, right_img_idx):
            assert (right_idx is None) == (right_img_idx is None)

            left_sample = {}
            right_sample = {} if right_idx is not None else None
            for k in sample_to_split:
                if k in ['input_ids', 'labels', 'attention_mask', 'position_ids', 'data_index', 'type_ids']:
                    left_sample[k] = sample_to_split[k][:left_idx]
                    if right_sample is not None:
                        right_sample[k] = sample_to_split[k][right_idx:]
                elif k in ['pixel_values', 'image_flags']:
                    left_sample[k] = sample_to_split[k][:left_img_idx]
                    if right_sample is not None:
                        right_sample[k] = sample_to_split[k][right_img_idx:]
                else:
                    raise NotImplementedError(f'find unsupported keys: {k} from {sample_to_split.keys()}')
            return left_sample, right_sample

        splitted_buffer = []
        while buffer['input_ids'].size(0) > max_tokens:
            img_start_idx_list = (buffer['input_ids'] == img_start_token_id).nonzero().squeeze(1).tolist()
            img_end_idx_list = (buffer['input_ids'] == img_end_token_id).nonzero().squeeze(1).tolist()
            assert len(img_start_idx_list) == len(img_end_idx_list)

            if _image_is_splitted(buffer['input_ids'], max_tokens):
                cut_idx = bisect.bisect_left(img_start_idx_list, max_tokens)
                if buffer['input_ids'][max_tokens] == img_start_token_id:
                    assert max_tokens == img_start_idx_list[cut_idx]
                    cut_left_idx = img_start_idx_list[cut_idx]
                    cut_left_img_idx = cut_idx
                else:
                    cut_left_idx = img_start_idx_list[cut_idx - 1]
                    cut_left_img_idx = cut_idx - 1
                cut_right_idx = cut_left_idx
                cut_right_img_idx = cut_left_img_idx
            else:
                cut_img_idx = bisect.bisect(img_start_idx_list, max_tokens)
                if cut_img_idx < len(img_start_idx_list):
                    cut_right_idx = img_start_idx_list[cut_img_idx]
                    cut_right_img_idx = cut_img_idx
                else:
                    cut_right_idx = None
                    cut_right_img_idx = None

                cut_left_idx = max_tokens
                cut_left_img_idx = cut_right_img_idx if cut_right_img_idx is not None else buffer['pixel_values'].size(0)

            left, right = _split(
                sample_to_split=buffer,
                left_idx=cut_left_idx,
                left_img_idx=cut_left_img_idx,
                right_idx=cut_right_idx,
                right_img_idx=cut_right_img_idx,
            )

            assert (left['input_ids'] == img_end_token_id).sum() == (left['input_ids'] == img_start_token_id).sum() == left['pixel_values'].size(0)
            if right is not None:
                assert (right['input_ids'] == img_end_token_id).sum() == (right['input_ids'] == img_start_token_id).sum() == right['pixel_values'].size(0)

            if left['pixel_values'].size(0) >= 1 and PackedDataset.check_valid(left):
                splitted_buffer.append(left)

            if right is None or right['pixel_values'].size(0) == 0:
                break

            buffer = right
            if buffer['input_ids'].size(0) <= max_tokens and PackedDataset.check_valid(buffer):
                splitted_buffer.append(buffer)
                break

        logger.debug(
            f'split a sample into {len(splitted_buffer)} samples, '
            f'current max_tokens={max_tokens}'
        )
        return splitted_buffer

    def update_buffer_list(self, buffer_list, buffer_max_len_list, buffer):
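        """Split ``buffer`` at the token limit, append full-length chunks to
        ``buffer_max_len_list``, and insert the remaining chunks into ``buffer_list``,
        which is kept sorted by image count in descending order."""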
        # NOTE: in-place operation
        splitted_buffer = PackedDataset.split_buffer(
            buffer=buffer,
            max_tokens=self.max_packed_tokens,
            img_start_token_id=self.img_start_token_id,
            img_token_id=self.img_token_id,
            img_end_token_id=self.img_end_token_id,
        )

        for each_buffer in splitted_buffer:
            if each_buffer['pixel_values'].size(0) > self.num_images_expected:
                logger.error(
                    f"Find a sample with {each_buffer['pixel_values'].size(0)} images, "
                    f'which exceeds {self.num_images_expected}'
                )
                continue

            if each_buffer['input_ids'].size(0) >= self.max_packed_tokens:
                assert each_buffer['input_ids'].size(0) == self.max_packed_tokens
                buffer_max_len_list.append(each_buffer)
                continue

            find_idx = len(buffer_list)
            num_images_new_sample = each_buffer['pixel_values'].size(0)
            for buffer_idx in range(len(buffer_list)):
                if buffer_list[buffer_idx]['pixel_values'].size(0) < num_images_new_sample:
                    find_idx = buffer_idx
                    break
            buffer_list.insert(find_idx, each_buffer)

        for i in range(1, len(buffer_list)):
            assert buffer_list[i-1]['pixel_values'].size(0) >= buffer_list[i]['pixel_values'].size(0)

        return buffer_list, buffer_max_len_list

    def pad_buffer(self, buffer):
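        """Pad ``pixel_values`` with blank (all-zero) images, flagged 0 in
        ``image_flags``, until the buffer holds exactly ``num_images_expected`` images."""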
        if buffer['pixel_values'].size(0) == self.num_images_expected:
            return buffer

        num_pad_images = self.num_images_expected - buffer['pixel_values'].size(0)
        pad_images = torch.stack([
            torch.zeros_like(buffer['pixel_values'][0])
            for _ in range(num_pad_images)
        ])
        pad_image_flags = torch.tensor([0] * num_pad_images, dtype=torch.long)

        buffer['pixel_values'] = torch.cat([buffer['pixel_values'], pad_images])
        buffer['image_flags'] = torch.cat([buffer['image_flags'], pad_image_flags])

        return buffer

    def postprocess_buffer(self, buffer, custom_infos=None):
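        """Attach the worker's resume state (and, optionally, ``custom_infos``) to the
        buffer so that data loading can be checkpointed and resumed."""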
        buffer['worker_state_key'] = self.worker_state_key
        buffer['worker_state_dict'] = self._state_dict
        if custom_infos is not None:
            buffer['custom_infos'] = {self.worker_state_key: copy.deepcopy(custom_infos)}
        return buffer

    def print_log(self, iter_idx, buffer_list):
        if iter_idx % self.log_freq != 0:
            return

        if self._should_log():
            logger.info(
                f"{iter_idx=}, {len(buffer_list)=}, {self._state_dict['sample_info']}"
            )

    def __iter__(self):
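        """Main packing loop executed independently by every dataloader worker.

        Each worker derives a global worker id from its rank and local worker id,
        re-seeds its RNG with it, samples datasets according to ``dataset_weight``,
        and merges samples into buffers. Buffers are yielded when they reach
        ``num_images_expected`` images, hit ``max_packed_tokens`` tokens, or when
        ``buffer_list`` grows beyond ``max_buffer_size``; remaining buffers are
        flushed once every dataset is exhausted.
        """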
        iter_idx = 0
        buffer_list = []
        buffer_max_len_list = []

        if self._should_log():
            logger.info(f'Begin to iter, {len(buffer_list)=}')

        worker_id = 0 if get_worker_info() is None else get_worker_info().id
        num_workers = 1 if get_worker_info() is None else get_worker_info().num_workers

        worker_id = num_workers * self.data_rank + worker_id
        num_workers = num_workers * self.data_world_size

        rng = np.random.default_rng(seed=worker_id)

        # reset states of each dataset
        self.worker_id = worker_id
        self.worker_state_key = f'work_state_{self.worker_id}'
        self.datasets = [d for d in self.datasets_orig]
        self.dataset_weight = [w for w in self.dataset_weight_orig]
        self.dataset_iter_list = [iter(d) for d in self.datasets]

        for ds in self.datasets:
            # if not isinstance(ds, (ImageTextPairDataset, InterleavedDataset)):
            ds.worker_id = worker_id
            ds.worker_state_key = f'work_state_{self.worker_id}'
            ds.num_workers = num_workers
            if self._should_log() and worker_id == 0:
                logger.info(f'set worker_id and num_workers of {ds.__class__.__name__} {ds.ds_name}')

        if self.worker_custom_infos is not None and self.worker_state_key in self.worker_custom_infos:
            custom_infos = self.worker_custom_infos[self.worker_state_key]
            # buffer list
            if 'buffer_list' in custom_infos and isinstance(custom_infos['buffer_list'], list):
                buffer_list = custom_infos['buffer_list']
                if self._should_log() and worker_id == 0:
                    logger.info(f'[{self.worker_state_key}] load buffer list --> {len(buffer_list)=}')
            # other infos

            # reset
            self.worker_custom_infos = None

        logger.debug(
            f'{self.__class__.__name__} Rank {self.data_rank} '
            f'Worker {worker_id} begin to load data'
        )

        while True:
            self.dataset_weight = [w / sum(self.dataset_weight) for w in self.dataset_weight]
            current_dataset_idx = rng.choice(len(self.dataset_iter_list), p=self.dataset_weight)

            try:
                current_sample = self.next_data(current_dataset_idx)
            except Exception:
                logger.info(f'All datasets are exhausted, begin to empty the buffer_list ({len(buffer_list)=})')
                while len(buffer_list) > 0:
                    if self.strict_mode:
                        yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0)))
                    else:
                        yield self.postprocess_buffer(buffer_list.pop(0))
                logger.info(f'buffer_list is empty! ({len(buffer_list)=})')
                return

            buffer = self.find_buffer(buffer_list, current_sample)
            buffer = self.update_buffer(buffer, current_sample)
            buffer_list, buffer_max_len_list = self.update_buffer_list(buffer_list, buffer_max_len_list, buffer)

            while len(buffer_max_len_list) > 0:
                if buffer_max_len_list[0]['pixel_values'].size(0) != self.num_images_expected:
                    logger.debug(
                        f'num tokens of a buffer exceed {self.max_packed_tokens=}, '
                        f"yield a sample with {buffer_max_len_list[0]['pixel_values'].size(0)} images"
                    )
                if self.strict_mode and buffer_max_len_list[0]['pixel_values'].size(0) != self.num_images_expected:
                    # buffer_max_len_list.pop(0)
                    yield self.postprocess_buffer(self.pad_buffer(buffer_max_len_list.pop(0)), {'buffer_list': buffer_list})
                else:
                    yield self.postprocess_buffer(buffer_max_len_list.pop(0), {'buffer_list': buffer_list})

            while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) > self.num_images_expected:
                logger.error(
                    f"num images of a buffer ({buffer_list[0]['pixel_values'].size(0)}) "
                    f'is larger than num_images_expected ({self.num_images_expected})'
                )
                buffer_list.pop(0)

            while len(buffer_list) > 0 and buffer_list[0]['pixel_values'].size(0) == self.num_images_expected:
                if self.debug_mode:
                    debug_data = self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})
                    while True:
                        yield debug_data.copy()

                yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})

            while len(buffer_list) > self.max_buffer_size:
                logger.debug(
                    f'Failed to pack data to exactly {self.num_images_expected} images, '
                    f"yield a data sample with {buffer_list[0]['pixel_values'].size(0)} images."
                )
                if self.strict_mode:
                    yield self.postprocess_buffer(self.pad_buffer(buffer_list.pop(0)), {'buffer_list': buffer_list})
                else:
                    yield self.postprocess_buffer(buffer_list.pop(0), {'buffer_list': buffer_list})

            self.print_log(iter_idx=iter_idx, buffer_list=buffer_list)
            iter_idx += 1

    @staticmethod
    def get_cu_seqlens_and_indexes(
        data_index: torch.LongTensor,  # (seq_len,)
        input_ids: torch.LongTensor,  # (seq_len,)
        labels: torch.LongTensor,  # (seq_len,)
        len2weight: callable,
    ):
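        """For one packed sequence, return the cumulative sequence lengths
        (``cu_seqlens``), the position of each token within its original sample
        (``indexes``), and a per-token loss weight obtained by applying
        ``len2weight`` to each sample's number of supervised tokens."""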
        indexes = []
        cu_seqlens = [0]
        loss_weight = []

        start = data_index.min()
        end = data_index.max() + 1
        for i in range(start, end):
            num_tokens = (data_index == i).sum().item()
            indexes.extend(list(range(num_tokens)))
            cu_seqlens.append(cu_seqlens[-1] + num_tokens)
            assert num_tokens > 0

            curr_data_index = data_index[cu_seqlens[-2]:cu_seqlens[-2]+num_tokens]
            assert (curr_data_index == i).all(), data_index

            curr_labels = labels[cu_seqlens[-2]:cu_seqlens[-2]+num_tokens]
            num_effective_tokens = (curr_labels != IGNORE_TOKEN_ID).sum().item()
            loss_weight.extend([len2weight(num_effective_tokens)] * num_tokens)

        assert len(indexes) == data_index.size(0), f'{len(indexes)=}, {data_index.size(0)=}'

        loss_weight = torch.tensor(loss_weight, dtype=torch.float32)
        return cu_seqlens, indexes, loss_weight


WARNING_CNT = defaultdict(int)


def packed_collate_fn(
    features,
    data_collator,
    len2weight: callable,
    max_item_length: int,
    micro_num: int = 1,
    loss_reduction_all_gather: bool = False,
    pad_id: int = 0,
):
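    """Collate packed samples into a micro-batch.

    Pads ``features`` with label-masked copies up to ``micro_num``, computes per-sample
    ``cu_seqlens`` (stored in ``attention_mask``) and per-token ``loss_weight``, records
    packing statistics, and delegates tensor padding to ``data_collator``.
    """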
    if not isinstance(features, list):
        features = [features]

    if len(features) > micro_num:
        raise NotImplementedError(f'{len(features)=} > {micro_num=}')

    if len(features) < micro_num and WARNING_CNT['micro_num_warning'] < 5:
        logger.warning(
            f'{len(features)=} < {micro_num=}, '
            f'the features will be padded to satisfy micro_num requirement'
        )
        WARNING_CNT['micro_num_warning'] += 1

    # ensure that the len(features) is equal to the required micro_num
    num_features = len(features)
    while len(features) < micro_num:
        features.append(copy.deepcopy(features[0]))
        features[-1]['labels'] = torch.full_like(features[-1]['labels'], IGNORE_TOKEN_ID)

    indexes = []
    cu_seqlens = []
    cu_num_images_list = [0]

    worker_state_key_list = []
    worker_state_dict_list = []
    worker_state_custom_infos_list = []

    batch_lens = [feat['input_ids'].shape for feat in features]
    max_item_length = max_item_length or max(batch_lens)[0]

    num_samples = 0
    num_padding_tokens = 0
    for feat_idx, feat in enumerate(features):
        data_index = feat.pop('data_index')
        curr_cu_seqlens, curr_indexes, curr_loss_weight = PackedDataset.get_cu_seqlens_and_indexes(
            data_index=data_index,
            input_ids=feat['input_ids'],
            labels=feat['labels'],
            len2weight=len2weight,
        )

        feat['loss_weight'] = curr_loss_weight

        if feat_idx < num_features:
            num_samples += len(curr_cu_seqlens) - 1

        if curr_cu_seqlens[-1] < max_item_length:
            curr_cu_seqlens.append(max_item_length)
            curr_indexes.extend(list(range(max_item_length - curr_cu_seqlens[-2])))

        indexes.append(torch.tensor(curr_indexes, dtype=torch.long))
        cu_seqlens.append(torch.tensor(curr_cu_seqlens, dtype=torch.int32))

        worker_state_key_list.append(feat.pop('worker_state_key'))
        worker_state_dict_list.append(feat.pop('worker_state_dict'))
        worker_state_custom_infos_list.append(feat.pop('custom_infos', None))

        num_padding_tokens += (max_item_length - feat['input_ids'].size(0))
        cu_num_images_list.append(cu_num_images_list[-1] + feat['pixel_values'].size(0))

    batch = data_collator(features=features, max_item_length=max_item_length, pad_id=pad_id)
    # convert it to list in case it is converted into bf16
    batch['loss_weight'] = torch.where(batch['labels'] == IGNORE_TOKEN_ID, 0, batch['loss_weight']).tolist()
    batch['attention_mask'] = torch.stack(cu_seqlens)
    batch['loss_reduction_all_gather'] = loss_reduction_all_gather
    batch['statistics'] = torch.tensor(
        [
            num_samples,
            num_padding_tokens,
            batch['image_flags'].numel() - batch['image_flags'].sum().item(),
        ],
        dtype=torch.long,
    )
    batch.pop('type_ids')
    return batch
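

# ---------------------------------------------------------------------------
# Hypothetical usage sketch (names such as `train_datasets`, `tokenizer`, and
# `concat_pad_data_collator` are assumptions, not part of this module): the
# packed dataset is typically consumed through a DataLoader whose collate_fn
# is a partial of `packed_collate_fn`, roughly as follows:
#
#   from functools import partial
#   from torch.utils.data import DataLoader
#
#   packed_ds = PackedDataset(
#       tokenizer=tokenizer,
#       data_rank=get_rank(),
#       data_world_size=get_world_size(),
#       datasets=train_datasets,          # objects exposing ds_name, dataset_type,
#                                         # max_num_images, max_tokens, ...
#       dataset_weight=[len(d) for d in train_datasets],
#       num_images_expected=6,
#       max_packed_tokens=32768,
#   )
#   loader = DataLoader(
#       packed_ds,
#       batch_size=1,                     # each item is already a packed sequence
#       collate_fn=partial(
#           packed_collate_fn,
#           data_collator=concat_pad_data_collator,  # any collator accepting
#                                                    # (features, max_item_length, pad_id)
#           len2weight=lambda n: 1.0,                # uniform per-sample loss weighting
#           max_item_length=32768,
#           micro_num=1,
#           pad_id=tokenizer.pad_token_id,
#       ),
#   )
# ---------------------------------------------------------------------------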