# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import io
import os
import random
import re
import sys
from collections import Counter
from typing import Dict

import cv2
import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import transformers
from decord import VideoReader
from internvl.conversation import get_conv_template
from PIL import Image
from torch.utils.data import ConcatDataset, WeightedRandomSampler
from torchvision.transforms.functional import InterpolationMode
from transformers.trainer_pt_utils import LabelSmoother

from .constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD,
                        IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN,
                        SIGLIP_MEAN, SIGLIP_STD)

try:
    from petrel_client.client import Client
    from petrel_client.common.config import Config
except ImportError:
    print('petrel_client is not installed. If you read data locally instead of from ceph, ignore it.')

IGNORE_TOKEN_ID = LabelSmoother.ignore_index

def calculate_ngram_repetition(text, n):
    words = text.split()
    ngrams = [tuple(words[i:i + n]) for i in range(len(words) - n + 1)]
    ngram_counts = Counter(ngrams)
    total_ngrams = len(ngrams)
    repeated_ngrams = sum(1 for count in ngram_counts.values() if count > 1)
    return repeated_ngrams / total_ngrams if total_ngrams > 0 else 0

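# Example (sketch, hypothetical strings): text with no repeated n-grams gives 0.0; answers
# that duplicate long spans push the ratio up, which check_conversations_repetition below
# rejects once it exceeds repeat_threshold.
# >>> calculate_ngram_repetition('the quick brown fox jumps over the lazy dog', n=3)
# 0.0
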
def check_conversations_repetition(conversations, repeat_threshold=0.4, ngram=10):
    for conversation in conversations:
        if conversation['from'] == 'gpt':
            model_answer = conversation['value']
            repeat_ratio = calculate_ngram_repetition(model_answer, ngram)
            if repeat_ratio > repeat_threshold:
                raise Exception

def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
    if sample in ['rand', 'middle']:  # uniform sampling
        acc_samples = min(num_frames, vlen)
        # split the video into `acc_samples` intervals, and sample from each interval.
        intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
            except:
                frame_indices = np.random.permutation(vlen)[:acc_samples]
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError
        if len(frame_indices) < num_frames:  # padded with last frame
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif 'fps' in sample:  # fps0.5, sequentially sample frames at 0.5 fps
        output_fps = float(sample[3:])
        duration = float(vlen) / input_fps
        delta = 1 / output_fps  # gap between frames, this is also the clip length each frame represents
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int)
        frame_indices = [e for e in frame_indices if e < vlen]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
            # frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
    else:
        raise ValueError
    return frame_indices

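# Example (sketch): deterministic 'middle' sampling picks the center frame of each of
# `num_frames` equal intervals, e.g. for a hypothetical 240-frame video:
# >>> get_frame_indices(num_frames=8, vlen=240, sample='middle')
# [14, 44, 74, 104, 134, 164, 194, 224]
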
def read_frames_gif(
        video_path, num_frames, sample='rand', fix_start=None,
        client=None, min_num_frames=4
):
    if 's3://' in video_path:
        video_bytes = client.get(video_path)
        gif = imageio.get_reader(io.BytesIO(video_bytes))
    else:
        gif = imageio.get_reader(video_path)
    vlen = len(gif)
    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
    frame_indices = get_frame_indices(
        t_num_frames, vlen, sample=sample, fix_start=fix_start
    )
    frames = []
    for index, frame in enumerate(gif):
        if index in frame_indices:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB).astype(np.uint8)
            frame = Image.fromarray(frame)
            frames.append(frame)
    return frames

def read_frames_decord(
        video_path, num_frames, sample='rand', fix_start=None,
        client=None, clip=None, min_num_frames=4
):
    if 's3://' in video_path:
        video_bytes = client.get(video_path)
        video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
    else:
        video_reader = VideoReader(video_path, num_threads=1)
    vlen = len(video_reader)
    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)
    if clip:
        start, end = clip
        duration = end - start
        vlen = int(duration * fps)
        start_index = int(start * fps)
    # t_num_frames = min(max(int(duration * sample_fps), min_num_frames), num_frames)
    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
    frame_indices = get_frame_indices(
        t_num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps
    )
    if clip:
        frame_indices = [f + start_index for f in frame_indices]
    frames = video_reader.get_batch(frame_indices).asnumpy()  # (T, H, W, C), np.uint8
    frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
    return frames

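# Example (sketch, hypothetical local path): sample between `min_num_frames` and
# `num_frames` frames from the 10s-20s clip of an mp4; frames come back as PIL images.
# >>> frames = read_frames_decord('demo.mp4', num_frames=8, sample='rand', clip=(10, 20))
# >>> frames[0].size  # (width, height)
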
def extract_frame_number(filename):
    # Extract the numeric part from the filename using regular expressions
    match = re.search(r'_(\d+).jpg$', filename)
    return int(match.group(1)) if match else -1


def sort_frames(frame_paths):
    # Extract filenames from each path and sort by their numeric part
    return sorted(frame_paths, key=lambda x: extract_frame_number(os.path.basename(x)))

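# Example: frames named `<prefix>_<index>.jpg` sort numerically, so `_10` no longer
# lands before `_2` as it would with plain lexicographic sorting.
# >>> sort_frames(['clip_10.jpg', 'clip_2.jpg', 'clip_1.jpg'])
# ['clip_1.jpg', 'clip_2.jpg', 'clip_10.jpg']
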
def read_frames_folder(
        video_path, num_frames, sample='rand', fix_start=None,
        client=None, clip=None, min_num_frames=4
):
    if 's3://' in video_path:
        image_list = sort_frames(client.list(video_path))
        frames = []
        for image in image_list:
            fp = os.path.join(video_path, image)
            frame = Image.open(io.BytesIO(client.get(fp)))
            frames.append(frame)
    else:
        image_list = sort_frames(list(os.listdir(video_path)))
        frames = []
        for image in image_list:
            fp = os.path.join(video_path, image)
            frame = Image.open(fp).convert('RGB')
            frames.append(frame)
    vlen = len(frames)

    t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
    if vlen > t_num_frames:
        frame_indices = get_frame_indices(
            t_num_frames, vlen, sample=sample, fix_start=fix_start
        )
        frames = [frames[i] for i in frame_indices]
    return frames

class WeightedConcatDataset(ConcatDataset):
    def __init__(self, datasets, weights):
        super().__init__(datasets)
        self.weights = torch.DoubleTensor(weights)
        self.total_size = sum(len(d) for d in datasets)
        self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True)

    def __iter__(self):
        return iter(self.sampler)

    def __len__(self):
        return self.total_size

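# Example (sketch, hypothetical datasets ds_a and ds_b): concatenate two datasets and
# expose a WeightedRandomSampler built from the given per-dataset weights.
# >>> ds = WeightedConcatDataset([ds_a, ds_b], weights=[0.7, 0.3])
# >>> len(ds) == len(ds_a) + len(ds_b)
# True
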
def pil_loader(img_str):
    buff = io.BytesIO(img_str)
    img = Image.open(buff)
    return img.convert('RGB')

class TCSLoader(object):

    def __init__(self, conf_path, sc_config_key='sensecore'):
        print(f'[TCSLoader] config_path: {conf_path}')
        print('--> before Client(conf_path)')
        self.client = Client(conf_path)
        self.sc_config_key = sc_config_key
        print('--> after Client(conf_path)')

    def __call__(self, fn, image_type='image', max_num_frames=-1, min_num_frames=8, sample='rand', clip=None):
        if image_type == 'image':
            img_value_str = self.client.get(fn)
            img = pil_loader(img_value_str)
            return img
        elif image_type == 'video':
            if fn.endswith('/'):
                frames = read_frames_folder(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
                                            client=self.client, sample=sample)
            elif fn.endswith('.gif'):
                frames = read_frames_gif(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
                                         client=self.client, sample=sample)
            else:
                frames = read_frames_decord(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
                                            client=self.client, sample=sample, clip=clip)
            return frames

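# Example (sketch, hypothetical petrel config path and s3 URIs; requires petrel_client):
# >>> loader = TCSLoader('~/petreloss.conf')
# >>> img = loader('s3://bucket/images/0001.jpg', image_type='image')
# >>> frames = loader('s3://bucket/videos/0001.mp4', image_type='video', max_num_frames=8)
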
def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

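# Example (sketch): pad a 640x480 image to a 640x640 square filled with the ImageNet
# mean color, which is how the pad2square branch of build_transform uses this helper.
# >>> img = Image.new('RGB', (640, 480))
# >>> expand2square(img, tuple(int(x * 255) for x in IMAGENET_MEAN)).size
# (640, 640)
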
def simulate_jpeg_degradation(quality):
    def jpeg_degrade(img):
        with io.BytesIO() as output:
            img.convert('RGB').save(output, format='JPEG', quality=quality)
            output.seek(0)  # Move the reading cursor to the start of the stream
            img_jpeg = Image.open(output).copy()  # Use .copy() to make sure the image is loaded in memory
        return img_jpeg
    return jpeg_degrade


# Define the JPEG compression quality range, pre-create all JPEG compression functions
qualities = list(range(75, 101))
jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities}

def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'):
    if normalize_type == 'imagenet':
        MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    elif normalize_type == 'clip':
        MEAN, STD = CLIP_MEAN, CLIP_STD
    elif normalize_type == 'siglip':
        MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
    else:
        raise NotImplementedError
    if is_train:  # use data augmentation
        transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=MEAN, std=STD)
        ])
    else:
        if pad2square is False:  # now we use this transform function by default
            transform = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=MEAN, std=STD)
            ])
        else:
            transform = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))),
                T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=MEAN, std=STD)
            ])

    return transform

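# Example (sketch): the default evaluation transform for 448x448 inputs; applying it to
# a PIL image yields a normalized float tensor of shape (3, 448, 448).
# >>> transform = build_transform(is_train=False, input_size=448)
# >>> transform(Image.new('RGB', (800, 600))).shape
# torch.Size([3, 448, 448])
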
def preprocess(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1] + ': '
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == '':
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                instruction_len -= 1

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                cur_len -= 1

        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            logger.info(tokenizer.decode(z))
            exit()

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )

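# Example (sketch): a single-image, single-turn sample in the format `preprocess` expects;
# `template_name` and `tokenizer` are assumed to be defined elsewhere, and 256 is a
# hypothetical number of IMG_CONTEXT tokens for the image.
# >>> sources = [[
# ...     {'from': 'human', 'value': '<image>\nDescribe this image.'},
# ...     {'from': 'gpt', 'value': 'A cat sleeping on a sofa.'},
# ... ]]
# >>> batch = preprocess(template_name, sources, tokenizer, num_image_token_list=[256])
# >>> batch.keys()
# dict_keys(['input_ids', 'labels', 'attention_mask'])
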
def preprocess_mpt(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1]  # <|im_end|><|im_start|>assistant\n
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep)
        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt
        cur_len = 0
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            turn_len = len(tokenizer(turn).input_ids) + 1

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            instruction_len = len(tokenizer(parts[0]).input_ids)

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
            # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
            # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
            cur_len += turn_len

        target[cur_len:] = IGNORE_TOKEN_ID

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )

def preprocess_phi3(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    tokenizer.padding_side = 'right'
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1]  # <|end|>\n<|assistant|>
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())

        turns = conversation.split(conv.sep)
        re_turns = [conv.sep.join(turns[:3])]  # system + user + gpt
        for conv_idx in range(3, len(turns), 2):
            re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))  # user + gpt
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
        target[target == endoftext_id] = IGNORE_TOKEN_ID

        for i, turn in enumerate(re_turns):
            if turn == '':
                break
            if i == 0:
                turn_len = len(tokenizer(turn).input_ids)
            else:
                turn_len = len(tokenizer(turn).input_ids) - 1

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if i == 0:
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1
            else:
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
            # print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
            # print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
            # print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
            cur_len += turn_len

        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            print(repr(tokenizer.decode(z)))

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(
                    f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
                    f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
                )
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )

def preprocess_internlm(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    conv = get_conv_template(template_name)
    roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]['from']] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence['from']]
            assert role == conv.roles[j % 2], f'{i}'
            sentence['value'] = sentence['value'].strip()
            conv.append_message(role, sentence['value'])
        conversations.append(conv.get_prompt())

    if not text_only:
        new_conversations = []
        for conversation in conversations:
            for i in range(num_image):
                image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
                conversation = conversation.replace('<image>', image_tokens, 1)
            new_conversations.append(conversation)
        conversations = new_conversations

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors='pt',
        padding=False if group_by_length or use_packed_ds else 'max_length',
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())  # for InternLM, pad_token_id == eos_token_id
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID  # <s>
        parts = conversation.split(conv.roles[1])  # [UNUSED_TOKEN_146]assistant\n
        info = parts[0] + conv.roles[1]
        temp_len = len(tokenizer(info).input_ids) - 1  # exclude the tokenizer's <s>
        target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
        cur_len = cur_len + temp_len

        for index in range(1, len(parts) - 1):
            info = parts[index]
            part1, part2 = info.split(conv.roles[0])
            temp_len = len(tokenizer(part1).input_ids) - 1
            cur_len = cur_len + temp_len
            part = conv.roles[0] + part2 + conv.roles[1]
            temp_len = len(tokenizer(part).input_ids) - 1
            target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
            cur_len = cur_len + temp_len
        last_info = parts[-1]
        temp_len = len(tokenizer(last_info).input_ids) - 1
        cur_len = cur_len + temp_len

        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            print(repr(tokenizer.decode(z)))

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
                sys.stdout.flush()

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )

def preprocess_internvl2_5(
        template_name,
        sources,
        tokenizer: transformers.PreTrainedTokenizer,
        num_image_token_list: list,
        text_only: bool = False,
        group_by_length: bool = False,
        use_packed_ds: bool = False,
        ds_name: str = None,
        num_image: int = 1
) -> Dict:
    assert len(sources) == 1, 'process only the first conversations'
    conversations = sources[0]

    if conversations[0]['from'] == 'system':
        system_prompt = conversations[0]['value']
        conversations = conversations[1:]  # remove system prompt
    else:
        conv = get_conv_template(template_name)
        system_prompt = conv.system_message
        # system_prompt = None

    if not text_only:
        new_conversations = []
        current_image_idx = 0
        for conversation in conversations:
            if conversation['from'] == 'human':
                image_cnt = conversation['value'].count('<image>')
                for i in range(image_cnt):
                    if current_image_idx == num_image:
                        break
                    image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[current_image_idx]}{IMG_END_TOKEN}'
                    conversation['value'] = conversation['value'].replace('<image>', image_tokens, 1)
                    current_image_idx += 1
            new_conversations.append(conversation)
        conversations = new_conversations
        assert current_image_idx == num_image, f'{current_image_idx} != {num_image}'

    batches, roles = [], []
    if system_prompt is not None:
        batches.append(f'<|im_start|>system\n{system_prompt}<|im_end|>\n')
        roles.append('system')
    for conversation in conversations:
        if conversation['from'] == 'human':
            batches.append(f'<|im_start|>user\n{conversation["value"]}<|im_end|>\n')
            roles.append('human')
        elif conversation['from'] == 'gpt':
            batches.append(f'<|im_start|>assistant\n{conversation["value"]}<|im_end|>\n')
            roles.append('gpt')
        else:
            raise NotImplementedError

    add_bos_token = getattr(tokenizer, 'add_bos_token', False)
    if add_bos_token:  # for InternLM series
        batches[0] = tokenizer.bos_token + batches[0]

    # Tokenize conversations
    input_ids = tokenizer(
        batches,
        return_tensors='np',
        padding=False,
        max_length=tokenizer.model_max_length,
        truncation=False,
    ).input_ids

    if add_bos_token:  # for InternLM series
        input_ids = [item[1:] for item in input_ids]

    final_input_ids, final_targets = [], []
    ignore_ids = tokenizer('<|im_start|>assistant\n', return_tensors='np').input_ids[0]
    ignore_len = ignore_ids.shape[0] - 1 if add_bos_token else ignore_ids.shape[0]
    for role, input_id in zip(roles, input_ids):
        final_input_ids.append(input_id)
        if role == 'system' or role == 'human':
            final_targets.append(np.full(input_id.shape, IGNORE_TOKEN_ID))  # ignore
        elif role == 'gpt':
            target = input_id.copy()
            target[:ignore_len] = IGNORE_TOKEN_ID  # ignore loss for `<|im_start|>assistant\n`
            target[-1:] = IGNORE_TOKEN_ID  # ignore loss for `\n`
            final_targets.append(target)
        else:
            raise NotImplementedError
    input_ids = torch.tensor(np.concatenate(final_input_ids))[:tokenizer.model_max_length]
    targets = torch.tensor(np.concatenate(final_targets))[:tokenizer.model_max_length]

    padding = False if group_by_length or use_packed_ds else True
    if padding:
        current_length = input_ids.size(0)
        padding_length = tokenizer.model_max_length - current_length
        input_ids = F.pad(input_ids, (0, padding_length), value=tokenizer.pad_token_id)
        targets = F.pad(targets, (0, padding_length), value=IGNORE_TOKEN_ID)

    input_ids = input_ids.unsqueeze(0)
    targets = targets.unsqueeze(0)

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )

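# Example (sketch): unlike the split-and-count variants above, this path tokenizes each
# message separately and masks whole non-assistant messages, so it does not need the
# tokenization-mismatch warning. `template_name` and `tokenizer` are assumed to be
# defined elsewhere; 256 is a hypothetical number of IMG_CONTEXT tokens.
# >>> sources = [[
# ...     {'from': 'human', 'value': '<image>\nWhat is in the picture?'},
# ...     {'from': 'gpt', 'value': 'A red bicycle.'},
# ... ]]
# >>> batch = preprocess_internvl2_5(template_name, sources, tokenizer,
# ...                                num_image_token_list=[256], num_image=1)
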
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
    return best_ratio

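# Example: with the candidate grids used by dynamic_preprocess (max_num=6), an 800x1200
# portrait image (aspect ratio 2/3) maps to the (2, 3) grid, i.e. 2 columns x 3 rows.
# >>> ratios = sorted({(i, j) for n in range(1, 7) for i in range(1, n + 1)
# ...                  for j in range(1, n + 1) if 1 <= i * j <= 6}, key=lambda x: x[0] * x[1])
# >>> find_closest_aspect_ratio(800 / 1200, ratios, 800, 1200, image_size=448)
# (2, 3)
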
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, return_ratio=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    if return_ratio:
        return processed_images, target_aspect_ratio
    return processed_images

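# Example (sketch): tile a hypothetical 800x1200 image into 448x448 crops plus a global
# thumbnail; each tile is then fed through build_transform to produce pixel_values.
# >>> image = Image.new('RGB', (800, 1200))
# >>> tiles = dynamic_preprocess(image, max_num=6, image_size=448, use_thumbnail=True)
# >>> len(tiles)  # (2, 3) grid -> 6 tiles + 1 thumbnail
# 7
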
def dynamic_preprocess_mask(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    # import pdb
    length, orig_height, orig_width = image.shape
    aspect_ratio = orig_width / orig_height

    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
    # print(target_aspect_ratio)

    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # resize the masks
    tensor_images = image.unsqueeze(1)  # add a channel dimension so each mask is single-channel: (N, 1, H, W)
    # pdb.set_trace()
    resized_images = F.interpolate(tensor_images, size=(target_height, target_width), mode='bilinear', align_corners=False)  # e.g. (1792, 1344)
    resized_images = resized_images > 0
    # print(resized_images.shape)

    # then crop tiles from the resized masks, mirroring the PIL-based path above
    processed_images = []
    for i in range(blocks):
        top = (i // (target_width // image_size)) * image_size
        left = (i % (target_width // image_size)) * image_size
        bottom = top + image_size
        right = left + image_size
        # crop via tensor slicing; the ellipsis keeps the channel dimension
        split_img = resized_images[..., top:bottom, left:right]
        processed_images.append(split_img)
        # plt.imshow(split_img.sum(0).squeeze())
        # plt.savefig(f'/workdir/guantongkun/12490719/eef5a3b245897c9f4335463fb12fed35/work_dirs/{i}_mask.jpg', dpi=600)
        # pdb.set_trace()

    # Any post-processing of the tile list can go here; below we squeeze out the
    # singleton channel so each tile is (N, image_size, image_size).
    processed_images = [img.squeeze(1) for img in processed_images]
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = F.interpolate(tensor_images, size=(image_size, image_size), mode='bilinear', align_corners=False).squeeze(1)
        thumbnail_img = thumbnail_img > 0
        # Image.fromarray(thumbnail_img.cpu().numpy().astype(np.uint8))
        processed_images.append(thumbnail_img)
    return processed_images