# Page header captured along with this file (not code):
# zhiyucheng's picture
# add files
# 9ed36d5 unverified
from typing import List, Optional, Union, Any, Dict
from PIL import Image
import torch
from transformers.image_processing_base import BatchFeature
from transformers.image_processing_utils_fast import BaseImageProcessorFast, divide_to_patches
from transformers.image_utils import (make_list_of_images, get_image_size,
get_image_type, ImageInput, ImageType, ChannelDimension)
from transformers.utils import TensorType
import torchvision.transforms as T
def get_internvl_target_ratios(
    min_num: int,
    max_num: int,
) -> list[tuple[int, int]]:
    """List every ``(i, j)`` tile grid whose tile count lies in ``[min_num, max_num]``.

    A grid ``(i, j)`` is included when ``min_num <= i * j <= max_num``. The
    result is sorted by ascending tile count ``i * j``. Among grids with the
    same tile count, the order follows the iteration order of the
    intermediate set and is not otherwise specified. If ``min_num > max_num``
    the result is empty.
    """
    grids = set()
    # Insertion order is kept identical to a nested comprehension over
    # (tiles, cols, rows), so the set's internal layout and therefore the
    # relative order of equal-count grids after the stable sort are unchanged.
    for tiles in range(min_num, max_num + 1):
        for cols in range(1, tiles + 1):
            for rows in range(1, tiles + 1):
                if min_num <= cols * rows <= max_num:
                    grids.add((cols, rows))

    def _tile_count(grid):
        # Number of tiles in the grid.
        return grid[0] * grid[1]

    ordered = list(grids)
    ordered.sort(key=_tile_count)
    return ordered
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Choose the tile grid from ``target_ratios`` that best fits an image.

    Each candidate grid ``g`` is scored as the product of two terms:

    * coverage: ``g[0] * g[1] * image_size * image_size / (width * height)``,
      capped at 0.6. Any grid offering at least 0.6x the image's pixel area
      gets the same coverage score.
    * shape match: ``min(t / aspect_ratio, aspect_ratio / t)`` with
      ``t = g[0] / g[1]``. This is 1.0 when the aspect ratios agree and
      shrinks as they diverge.

    The highest-scoring candidate is returned. On a tie, the earliest
    candidate in ``target_ratios`` is kept. If ``target_ratios`` is empty,
    ``(1, 1)`` is returned.

    Raises:
        ZeroDivisionError: if ``width * height`` is zero.
    """
    area = width * height
    winner = (1, 1)
    top_score = float('-inf')
    for candidate in target_ratios:
        cols = candidate[0]
        rows = candidate[1]
        grid_aspect = cols / rows
        coverage = min((cols * rows * image_size * image_size) / area, 0.6)
        shape_match = min(grid_aspect / aspect_ratio, aspect_ratio / grid_aspect)
        score = coverage * shape_match
        # Strict comparison: a later candidate with an equal score does not
        # displace an earlier one.
        if score > top_score:
            top_score = score
            winner = candidate
    return winner
def calculate_targets(
    orig_width: int,
    orig_height: int,
    target_ratios: list[tuple[int, int]],
    image_size: int,
) -> tuple[int, int, int]:
    """Pick a tile grid for an image and derive the size to resize it to.

    Args:
        orig_width: Width of the source image in pixels.
        orig_height: Height of the source image in pixels (must be non-zero).
        target_ratios: Candidate ``(cols, rows)`` grids, passed through to
            ``find_closest_aspect_ratio``.
        image_size: Side length of one square tile in pixels.

    Returns:
        ``(blocks, target_width, target_height)``. ``blocks`` is the tile
        count of the chosen grid (``cols * rows``). The sizes are
        ``image_size * cols`` and ``image_size * rows``.
    """
    grid = find_closest_aspect_ratio(
        orig_width / orig_height,
        target_ratios,
        width=orig_width,
        height=orig_height,
        image_size=image_size,
    )
    cols = grid[0]
    rows = grid[1]
    return cols * rows, image_size * cols, image_size * rows
def dynamic_preprocess(image, image_size=512, max_num_tiles=12, use_thumbnail=True):
    """Split a channel-first image into square tiles laid out on a best-fit grid.

    The grid ``(cols, rows)`` is chosen by ``calculate_targets`` from all
    layouts with between 1 and ``max_num_tiles`` tiles. The image is resized
    to ``cols * image_size`` wide by ``rows * image_size`` high and cut into
    ``image_size`` patches with ``divide_to_patches``.

    Args:
        image: Image in channel-first layout (``get_image_size`` is called
            with ``ChannelDimension.FIRST``). Presumably a ``torch.Tensor`` of
            shape (C, H, W), since it is fed to torchvision ``Resize``;
            TODO confirm against callers.
        image_size: Side length in pixels of each square tile.
        max_num_tiles: Upper bound on the number of grid tiles (excluding
            the thumbnail).
        use_thumbnail: If true and more than one tile was produced, append the
            whole image resized to ``image_size`` x ``image_size``.

    Returns:
        List of tiles as returned by ``divide_to_patches``, followed by the
        thumbnail when one is added.
    """
    orig_height, orig_width = get_image_size(image, channel_dim=ChannelDimension.FIRST)
    target_ratios = get_internvl_target_ratios(1, max_num_tiles)
    blocks, target_width, target_height = calculate_targets(
        orig_width,
        orig_height,
        target_ratios,
        image_size,
    )
    # torchvision's Resize takes ``size`` as (height, width). Passing
    # (width, height) would transpose the chosen grid: a wide image mapped to
    # a 3x1 grid would be stretched into a tall canvas. The tile count, and
    # hence the check below, would be unaffected, so the error would go
    # unnoticed.
    resized_img = T.Resize((target_height, target_width), interpolation=T.InterpolationMode.BICUBIC)(image)
    patches = divide_to_patches(resized_img, image_size)
    # Internal invariant: both canvas sides are exact multiples of image_size,
    # so the patch count must equal the grid's tile count.
    assert len(patches) == blocks
    if use_thumbnail and len(patches) != 1:
        # Low-resolution global view of the whole, un-tiled image.
        thumbnail_img = T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC)(image)
        patches.append(thumbnail_img)
    return patches