# yeliudev's picture
# Add files
# f880dff
# Copyright (c) 2025 Ye Liu. Licensed under the BSD-3-Clause License.
import base64
import copy
import math
import os
import warnings
from io import BytesIO
from typing import Optional
import cv2
import decord
import nncore
import numpy as np
import requests
import torch
import torchvision.transforms.functional as T
from PIL import Image
from pycocotools.mask import decode, frPyObjects, merge
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from unipixel.constants import IGNORE_INDEX
from unipixel.conversation import get_conv
# Default multiple that resized image/frame side lengths are snapped to (see smart_resize).
IMAGE_FACTOR = 28
# Default lower/upper bounds on the pixel count of a resized image.
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
# smart_resize rejects inputs whose long-side / short-side ratio exceeds this value.
MAX_RATIO = 200
# Default per-frame pixel bounds used when resizing decoded videos (see fetch_video).
VIDEO_MIN_PIXELS = 128 * 28 * 28
VIDEO_MAX_PIXELS = 768 * 28 * 28
# Sampled frame counts are kept as multiples of this factor.
FRAME_FACTOR = 2
# Default sampling rate (frames per second) when the video config gives neither `fps` nor `nframes`.
FPS = 2.0
# Default bounds on the number of sampled frames when sampling by fps.
FPS_MIN_FRAMES = 4
FPS_MAX_FRAMES = 768
# Set the maximum number of video token inputs.
# Here, 128K represents the maximum number of input tokens for the VLLM model.
# Remember to adjust it according to your own configuration.
# NOTE(review): the override is read from the `VIDEO_MAX_PIXELS` environment variable even though
# it populates VIDEO_TOTAL_PIXELS (not the VIDEO_MAX_PIXELS constant above) - confirm this is intended.
VIDEO_TOTAL_PIXELS = int(float(os.environ.get('VIDEO_MAX_PIXELS', 128000 * 28 * 28 * 0.9)))
def round_by_factor(number: int, factor: int) -> int:
    """Snap 'number' to the nearest multiple of 'factor'.

    Ties are resolved by Python's built-in ``round`` (round-half-to-even on the quotient).
    """
    multiples = round(number / factor)
    return multiples * factor
def ceil_by_factor(number: int, factor: int) -> int:
    """Round 'number' up to the next multiple of 'factor' (unchanged if already a multiple)."""
    multiples = math.ceil(number / factor)
    return multiples * factor
def floor_by_factor(number: int, factor: int) -> int:
    """Round 'number' down to the previous multiple of 'factor' (unchanged if already a multiple)."""
    multiples = math.floor(number / factor)
    return multiples * factor
def smart_resize(height: int,
                 width: int,
                 factor: int = IMAGE_FACTOR,
                 min_pixels: int = MIN_PIXELS,
                 max_pixels: int = MAX_PIXELS) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:
    1. Both dimensions (height and width) are divisible by 'factor'.
    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
    3. The aspect ratio of the image is maintained as closely as possible.

    Neither returned dimension is ever smaller than 'factor'; for extremely elongated inputs
    this guarantee takes precedence over the 'max_pixels' bound.

    Args:
        height (int): the original height.
        width (int): the original width.
        factor (int): the value both output dimensions must be divisible by.
        min_pixels (int): the lower bound on 'h_bar * w_bar'.
        max_pixels (int): the upper bound on 'h_bar * w_bar'.
    Raises:
        ValueError: if the long side / short side ratio exceeds MAX_RATIO.
    Returns:
        tuple[int, int]: the resized (height, width).
    """
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}")
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    # The min_pixels check runs first so that the max_pixels check below has the final say
    if h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        # Flooring a very short side can reach 0, which is not a valid image dimension;
        # keep at least one 'factor'-sized step on each side.
        h_bar = max(factor, floor_by_factor(height / beta, factor))
        w_bar = max(factor, floor_by_factor(width / beta, factor))
    return h_bar, w_bar
def to_rgb(pil_image: Image.Image) -> Image.Image:
    """Convert an image to RGB mode, compositing RGBA images onto a white background.

    Args:
        pil_image (Image.Image): the input image in any mode.
    Returns:
        Image.Image: a new RGB image.
    """
    if pil_image.mode != 'RGBA':
        return pil_image.convert("RGB")
    canvas = Image.new("RGB", pil_image.size, (255, 255, 255))
    alpha_band = pil_image.split()[3]
    # The alpha band acts as the paste mask, so transparent pixels stay white
    canvas.paste(pil_image, mask=alpha_band)
    return canvas
def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
    """Load an image described by 'ele', convert it to RGB and resize it.

    Args:
        ele (dict): the image configuration. The source is read from the `image` key, or from
            `image_url` when `image` is absent, and may be a PIL image, an "http://" / "https://"
            url, a "file://" path, a "data:image" base64 string or a local path.
            Optional keys:
            - resized_height / resized_width: explicit target size (both must be present).
            - min_pixels / max_pixels: pixel bounds used when no explicit size is given.
        size_factor (int): the value both output dimensions must be divisible by.
    Raises:
        ValueError: if no image could be obtained from the input.
    Returns:
        Image.Image: the resized RGB image.
    """
    source = ele["image"] if "image" in ele else ele["image_url"]

    loaded = None
    if isinstance(source, Image.Image):
        loaded = source
    elif source.startswith(("http://", "https://")):
        # fix memory leak issue while using BytesIO
        with requests.get(source, stream=True) as response:
            response.raise_for_status()
            with BytesIO(response.content) as buffer:
                loaded = copy.deepcopy(Image.open(buffer))
    elif source.startswith("file://"):
        loaded = Image.open(source[7:])
    elif source.startswith("data:image"):
        # A data URI without a base64 payload leaves 'loaded' as None and fails below
        if "base64," in source:
            payload = source.split("base64,", 1)[1]
            # fix memory leak issue while using BytesIO
            with BytesIO(base64.b64decode(payload)) as buffer:
                loaded = copy.deepcopy(Image.open(buffer))
    else:
        loaded = Image.open(source)

    if loaded is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {source}")

    rgb_image = to_rgb(loaded)
    if "resized_height" in ele and "resized_width" in ele:
        target_height, target_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=size_factor,
        )
    else:
        # PIL reports size as (width, height)
        src_width, src_height = rgb_image.size
        target_height, target_width = smart_resize(
            src_height,
            src_width,
            factor=size_factor,
            min_pixels=ele.get("min_pixels", MIN_PIXELS),
            max_pixels=ele.get("max_pixels", MAX_PIXELS),
        )
    return rgb_image.resize((target_width, target_height))
def smart_nframes(
    ele: dict,
    total_frames: int,
    video_fps: int | float,
) -> int:
    """Decide how many frames of a video are fed to the model.

    Args:
        ele (dict): the video configuration. Either `fps` or `nframes` may be given, not both:
            - nframes: the number of frames to extract (rounded to a multiple of FRAME_FACTOR).
            - fps: the sampling rate used to derive the frame count (defaults to FPS).
            - min_frames: the lower bound on the frame count, only used on the fps path.
            - max_frames: the upper bound on the frame count, only used on the fps path.
        total_frames (int): the number of frames available in the source video.
        video_fps (int | float): the frame rate of the source video.
    Raises:
        ValueError: if the result falls outside [FRAME_FACTOR, total_frames].
    Returns:
        int: the number of frames to sample, a multiple of FRAME_FACTOR.
    """
    assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
    if "nframes" in ele:
        nframes = round_by_factor(ele["nframes"], FRAME_FACTOR)
    else:
        target_fps = ele.get("fps", FPS)
        lower = ceil_by_factor(ele.get("min_frames", FPS_MIN_FRAMES), FRAME_FACTOR)
        upper = floor_by_factor(ele.get("max_frames", min(FPS_MAX_FRAMES, total_frames)), FRAME_FACTOR)
        # Frame count implied by the video duration at the requested sampling rate
        wanted = total_frames / video_fps * target_fps
        bounded = min(max(wanted, lower), upper, total_frames)
        nframes = floor_by_factor(bounded, FRAME_FACTOR)
    if nframes < FRAME_FACTOR or nframes > total_frames:
        raise ValueError(f"nframes should in interval [{FRAME_FACTOR}, {total_frames}], but got {nframes}.")
    return nframes
def calculate_video_frame_range(
    ele: dict,
    total_frames: int,
    video_fps: float,
) -> tuple[int, int, int]:
    """
    Translate an optional time window (in seconds) into inclusive frame indices.

    Args:
        ele (dict): A dictionary that may contain 'video_start' and 'video_end' (in seconds).
            Each value is clamped to [0, video duration]; a missing key means "from the first
            frame" / "to the last frame" respectively.
        total_frames (int): Total number of frames in the video.
        video_fps (float): Frames per second of the video.
    Returns:
        tuple: (start_frame, end_frame, frame_count), where end_frame is inclusive.
    Raises:
        ValueError: If video_fps or total_frames is not positive, or if the resulting
            start frame is not strictly before the end frame.
    """
    if video_fps <= 0:
        raise ValueError("video_fps must be a positive number")
    if total_frames <= 0:
        raise ValueError("total_frames must be a positive integer")

    start_sec = ele.get("video_start", None)
    end_sec = ele.get("video_end", None)
    last_frame = total_frames - 1
    if start_sec is None and end_sec is None:
        return 0, last_frame, total_frames

    max_duration = total_frames / video_fps

    def _clamp(seconds):
        # Restrict a timestamp to the playable range of the video
        return max(0.0, min(seconds, max_duration))

    # The *_shown values are only used in the error message below
    if start_sec is None:
        start_frame, start_shown = 0, 0
    else:
        start_shown = _clamp(start_sec)
        start_frame = math.ceil(start_shown * video_fps)

    if end_sec is None:
        end_frame, end_shown = last_frame, max_duration
    else:
        end_shown = _clamp(end_sec)
        end_frame = min(math.floor(end_shown * video_fps), last_frame)

    if start_frame >= end_frame:
        raise ValueError(
            f"Invalid time range: Start frame {start_frame} (at {start_shown}s) "
            f"exceeds end frame {end_frame} (at {end_shown}s). "
            f"Video duration: {max_duration:.2f}s ({total_frames} frames @ {video_fps}fps)")
    return start_frame, end_frame, end_frame - start_frame + 1
def _read_video_decord(ele: dict, ) -> tuple[torch.Tensor, float]:
    """read video using decord.VideoReader

    Args:
        ele (dict): a dict contains the configuration of video.
            support keys:
            - video: the path of video, passed to decord.VideoReader as is.
            - video_start: the start time of video.
            - video_end: the end time of video.
            - num_threads: forwarded to decord.VideoReader (defaults to 0).
            - fps / nframes / min_frames / max_frames: see smart_nframes.
    Returns:
        tuple[torch.Tensor, float]: the sampled frames permuted to (T, C, H, W) and the
        effective sampling rate (sampled frames / frames in the selected range * source fps).
    """
    # Note: this switches decord's output bridge globally, not just for this reader
    decord.bridge.set_bridge("torch")
    vr = decord.VideoReader(ele["video"], num_threads=ele.get('num_threads', 0))
    video_fps = vr.get_avg_fps()
    # From here on the frame count refers to the selected [start_frame, end_frame] range only
    start_frame, end_frame, range_frames = calculate_video_frame_range(
        ele,
        len(vr),
        video_fps,
    )
    nframes = smart_nframes(ele, total_frames=range_frames, video_fps=video_fps)
    # Evenly spaced frame indices across the selected range, both ends included
    idx = torch.linspace(start_frame, end_frame, nframes).round().long().tolist()
    video = vr.get_batch(idx).permute(0, 3, 1, 2)  # Convert to TCHW format
    sample_fps = nframes / max(range_frames, 1e-6) * video_fps
    return video, sample_fps
def fetch_video(ele: dict,
                image_factor: int = IMAGE_FACTOR,
                return_video_sample_fps: bool = False,
                sanity_check=False) -> torch.Tensor | list[Image.Image]:
    """Load a video either from a path (decoded with decord) or from a list of frame images.

    Args:
        ele (dict): the video configuration. `video` is either a string handed to
            _read_video_decord, or a list/tuple of image sources handed to fetch_image one by one
            together with the remaining keys of 'ele' (except `type` and `video`).
            Optional keys on the decoded path: min_pixels, total_pixels, max_pixels,
            resized_height, resized_width.
        image_factor (int): the value both output dimensions must be divisible by.
        return_video_sample_fps (bool): whether to also return the sampling rate.
        sanity_check (bool): on the decoded path, raise if every value of the video is zero.
    Raises:
        ValueError: if 'sanity_check' is set and the decoded video contains only zeros.
    Returns:
        The resized float video tensor, or the list of frame images (padded by repeating the
        last frame up to a multiple of FRAME_FACTOR). When 'return_video_sample_fps' is set, a
        (video, fps) tuple is returned instead; for frame lists the fps is the `fps` entry of
        'ele', or 2.0 when absent.
    """
    if not isinstance(ele["video"], str):
        assert isinstance(ele["video"], (list, tuple))
        frame_cfg = ele.copy()
        frame_cfg.pop("type", None)
        frame_cfg.pop("video", None)
        images = [
            fetch_image({
                "image": frame,
                **frame_cfg
            }, size_factor=image_factor) for frame in ele["video"]
        ]
        padded_len = ceil_by_factor(len(images), FRAME_FACTOR)
        if len(images) < padded_len:
            # Repeat the last frame so the frame count is a multiple of FRAME_FACTOR
            images.extend([images[-1]] * (padded_len - len(images)))
        if return_video_sample_fps:
            return images, frame_cfg.pop("fps", 2.0)
        return images

    video, sample_fps = _read_video_decord(ele)
    nframes, _, height, width = video.shape
    min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)
    total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)
    # Per-frame budget: share the total budget across frames, but never go below ~min_pixels
    budget = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * FRAME_FACTOR), int(min_pixels * 1.05))
    # A user-supplied max_pixels may only tighten the budget, never loosen it
    max_pixels = min(ele.get("max_pixels", budget), budget)
    if "resized_height" in ele and "resized_width" in ele:
        target_height, target_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=image_factor,
        )
    else:
        target_height, target_width = smart_resize(
            height,
            width,
            factor=image_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    resized = transforms.functional.resize(
        video,
        [target_height, target_width],
        interpolation=InterpolationMode.BICUBIC,
        antialias=True,
    )
    video = resized.float()
    if sanity_check and (video == 0).all():
        raise ValueError("video '{}' contains all zeros".format(ele["video"]))
    if return_video_sample_fps:
        return video, sample_fps
    return video
def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
vision_infos = []
if isinstance(conversations[0], dict):
conversations = [conversations]
for conversation in conversations:
for message in conversation:
if isinstance(message["content"], list):
for ele in message["content"]:
if ("image" in ele or "image_url" in ele or "video" in ele
or ele.get("type", "") in ("image", "image_url", "video")):
vision_infos.append(ele)
return vision_infos
def process_vision_info(
    conversations: list[dict] | list[list[dict]],
    return_video_kwargs: bool = False,
    sanity_check=False
) -> tuple[list[Image.Image] | None, list[torch.Tensor | list[Image.Image]] | None, Optional[dict]]:
    """Load every image and video referenced in the given conversation(s).

    Args:
        conversations: a single conversation or a list of conversations.
        return_video_kwargs (bool): whether to append a dict holding the per-video sampling rates.
        sanity_check (bool): forwarded to fetch_video.
    Raises:
        ValueError: if a vision element carries none of `image`, `image_url` or `video`.
    Returns:
        (image_inputs, video_inputs), each replaced by None when empty. When
        'return_video_kwargs' is set, a third item {'fps': [...]} is appended.
    """
    image_inputs, video_inputs, video_sample_fps_list = [], [], []
    # Read images or videos
    for vision_info in extract_vision_info(conversations):
        if "image" in vision_info or "image_url" in vision_info:
            image_inputs.append(fetch_image(vision_info))
            continue
        if "video" not in vision_info:
            # Elements selected only by their `type` field end up here
            raise ValueError("image, image_url or video should in content.")
        video_input, video_sample_fps = fetch_video(
            vision_info, return_video_sample_fps=True, sanity_check=sanity_check)
        video_sample_fps_list.append(video_sample_fps)
        video_inputs.append(video_input)
    # Empty modalities are reported as None rather than as empty lists
    image_inputs = image_inputs or None
    video_inputs = video_inputs or None
    if return_video_kwargs:
        return image_inputs, video_inputs, {'fps': video_sample_fps_list}
    return image_inputs, video_inputs
def resize(mask, size):
    """Resize a single mask to 'size' with torchvision's functional resize.

    Two leading singleton dimensions are added before resizing and stripped again afterwards.
    (Assumes 'mask' is a 2-D (H, W) tensor - TODO confirm against callers.)
    """
    batched = mask.unsqueeze(0).unsqueeze(0)
    resized = T.resize(batched, size)
    return resized[0, 0]
def _load_label_map(label, frame_size):
    """Return a label map as an array: read it from a path, or fill -1 for a missing frame."""
    if isinstance(label, str):
        # Close the file handle deterministically instead of waiting for garbage collection
        with Image.open(label) as img:
            return np.array(img)
    if label is None:
        return np.full(frame_size, -1)
    return label


def _masks_from_label_maps(sample, frame_size, inds):
    """Case 1 ('image'): one label map per frame, objects selected by their ids."""
    # Each frame is loaded once and shared by all objects
    labels = [_load_label_map(sample['masks'][i], frame_size) for i in inds]
    masks = []
    for obj_oids in sample['oids']:
        obj_masks = [torch.from_numpy(sum([label == oid for oid in obj_oids])).float() for label in labels]
        masks.append(obj_masks)
    return masks


def _masks_from_separate_label_maps(sample, frame_size, inds):
    """Case 2 ('image_sep'): one list of per-frame maps per object, foreground marked as 255."""
    masks = []
    for raw_obj_masks in sample['masks']:
        obj_masks = []
        for i in inds:
            label = _load_label_map(raw_obj_masks[i], frame_size)
            obj_masks.append(torch.from_numpy(label == 255).float())
        masks.append(obj_masks)
    return masks


def _decode_rle_region(rle, sample, frame_size):
    """Decode one RLE region into a mask resized to 'frame_size' (0 for a missing region)."""
    if rle is None:
        return 0
    if isinstance(rle, list):
        rles = frPyObjects(rle, sample.get('height', frame_size[0]), sample.get('width', frame_size[1]))
        return resize(torch.from_numpy(decode(merge(rles))).float(), frame_size)
    if isinstance(rle, dict):
        # A list under 'counts' is converted with frPyObjects before decoding
        if isinstance(rle['counts'], list):
            rle = frPyObjects(rle, *rle['size'])
        return resize(torch.from_numpy(decode(rle)).float(), frame_size)
    raise TypeError(f'unknown rle mask: {rle}')


def _masks_from_rle(sample, frame_size, inds):
    """Case 3 ('rle'): lists of multi-region RLE masks, given inline or as a file path."""
    raw_masks = nncore.load(sample['masks']) if isinstance(sample['masks'], str) else sample['masks']
    masks = []
    for raw_obj_masks in raw_masks:
        obj_masks = []
        for i in inds:
            mask = torch.zeros(frame_size)
            for rle in raw_obj_masks[i]:
                mask += _decode_rle_region(rle, sample, frame_size)
            # Overlapping regions are merged into a binary mask
            obj_masks.append((mask > 0).float())
        masks.append(obj_masks)
    return masks


def _draw_polygon_mask(shapes, frame_size):
    """Rasterize the polygons of one object in one frame, largest area first."""
    contours = [np.array([shape['points']], dtype=np.int32) for shape in shapes]
    # step 1: sort shapes
    areas = []
    for contour in contours:
        tmp = np.zeros(frame_size, dtype=np.uint8)
        cv2.polylines(tmp, contour, True, 1, 1)
        cv2.fillPoly(tmp, contour, 1)
        areas.append(tmp.sum())
    order = list(np.argsort(areas)[::-1].astype(np.int32))
    # step 2: draw masks (smaller shapes are drawn later and therefore overwrite larger ones)
    mask = np.zeros(frame_size, dtype=np.uint8)
    for j in order:
        shape = shapes[j]
        assert shape['label'] in ('target', 'ignore'), shape
        # NOTE(review): the canvas is uint8, so whether -1 survives the cv2 draw calls depends on
        # OpenCV's scalar conversion - confirm 'ignore' regions end up with the intended value.
        label = 1 if shape['label'] == 'target' else -1  # replacing 255 with -1 here
        cv2.polylines(mask, contours[j], True, label, 1)
        cv2.fillPoly(mask, contours[j], label)
    return torch.from_numpy(mask).float()


def _masks_from_polygons(sample, frame_size, inds):
    """Case 4 ('polygon'): lists of labelled polygons per object and frame."""
    return [[_draw_polygon_mask(raw_obj_masks[i], frame_size) for i in inds] for raw_obj_masks in sample['masks']]


def _masks_from_vicas(sample, frame_size, inds):
    """Case 5 ('vicas'): one RLE file per object."""
    masks = []
    for obj_rle_path in sample['masks']:
        obj_rles = nncore.load(obj_rle_path)
        obj_masks = []
        for i in inds:
            mask = torch.zeros(frame_size)
            for rle in obj_rles[i]:
                if rle is not None:
                    mask += resize(torch.from_numpy(decode(rle)).float(), frame_size)
            obj_masks.append((mask > 0).float())
        masks.append(obj_masks)
    return masks


def _masks_from_sav(sample, frame_size, inds):
    """Case 6 ('sav'): a single object looked up by sample['qid'] in the 'masklet' annotations."""
    annos = nncore.load(sample['masks'])['masklet']
    obj_idx = int(sample['qid'])
    return [[resize(torch.from_numpy(decode(annos[i][obj_idx])).float(), frame_size) for i in inds]]


def process_masks(sample, frame_size, inds):
    """Build per-object, per-frame mask tensors for a sample.

    Args:
        sample (dict): the annotation record. `mask_type` selects the format of `masks`:
            'image', 'image_sep', 'rle', 'polygon', 'vicas' or 'sav'. Depending on the type,
            `oids`, `height`, `width` and `qid` are read as well.
        frame_size: the (height, width) every mask is created at or resized to.
        inds: the frame indices to build masks for.
    Raises:
        TypeError: if `mask_type` (or an RLE entry) is not recognized.
    Returns:
        list[list[torch.Tensor]]: masks[obj][k] is the float mask of object 'obj' at frame inds[k].
    """
    mask_type = sample['mask_type']
    if mask_type == 'image':
        return _masks_from_label_maps(sample, frame_size, inds)
    if mask_type == 'image_sep':
        return _masks_from_separate_label_maps(sample, frame_size, inds)
    if mask_type == 'rle':
        return _masks_from_rle(sample, frame_size, inds)
    if mask_type == 'polygon':
        return _masks_from_polygons(sample, frame_size, inds)
    if mask_type == 'vicas':
        return _masks_from_vicas(sample, frame_size, inds)
    if mask_type == 'sav':
        return _masks_from_sav(sample, frame_size, inds)
    raise TypeError(f"unknown mask type: {sample['mask_type']}")
def build_obj_to_frame_idx(label_mask, batch_mode):
    """Build the (frame_idx, video_idx) lookup tensor for every object slot in 'label_mask'.

    With 'batch_mode' set there is a single step holding one [0, v] pair per object for every
    index v along dim 0. Otherwise there is one step per index t along dim 0, each holding one
    [t, 0] pair per object. The number of objects is taken from dim 1.

    Returns:
        torch.Tensor: an int tensor of shape (num_steps, num_pairs_per_step, 2).
    """
    if batch_mode:
        steps = [[
            torch.IntTensor([0, video_idx])
            for video_idx in range(label_mask.size(0))
            for _ in range(label_mask.size(1))
        ]]
    else:
        steps = [[torch.IntTensor([frame_idx, 0])
                  for _ in range(label_mask.size(1))]
                 for frame_idx in range(label_mask.size(0))]
    return torch.stack([torch.stack(pairs) for pairs in steps])
def preprocess_chatml(input_ids, text, tokenizer):
    """Create training labels for a ChatML-formatted sample.

    The text is split into rounds on the conversation's first separator. Within each round,
    the tokens up to and including the last occurrence of "separator + second role" are set to
    IGNORE_INDEX; the remaining tokens keep their ids. A warning is emitted when the summed
    round lengths do not match the length of 'input_ids'.

    Args:
        input_ids: the token ids of 'text'; cloned, not modified in place.
        text (str): the full conversation text.
        tokenizer: a callable accepting ``return_length=True`` and exposing ``.length``.
    Returns:
        The label tensor, same size as 'input_ids'.
    """
    conv = get_conv('chatml')
    end_mark = conv.seps[0]
    pieces = [chunk + end_mark for chunk in text.split(end_mark)]
    assert (len(pieces) % 2 == 0) == (conv.system is not None)
    assert pieces[-1] == end_mark
    # The text ends with a separator, so the final piece is just the re-appended separator
    pieces = pieces[:-1]
    if conv.system is not None:
        # The first round also carries the system message, hence three pieces instead of two
        turns = [''.join(pieces[:3])]
        turns += [''.join(pieces[k:k + 2]) for k in range(3, len(pieces), 2)]
    else:
        turns = [''.join(pieces[k:k + 2]) for k in range(0, len(pieces), 2)]

    labels = input_ids.clone()
    reply_mark = conv.seps[0] + conv.roles[1]
    offset = 0
    for turn in turns:
        if not turn:
            break
        # Everything before the reply (including the reply marker itself) is not supervised
        prompt = reply_mark.join(turn.split(reply_mark)[:-1]) + reply_mark
        turn_len = tokenizer(turn, return_length=True).length[0]
        prompt_len = tokenizer(prompt, return_length=True).length[0]
        labels[offset:offset + prompt_len] = IGNORE_INDEX
        offset += turn_len
    if labels.size(0) != offset:
        warnings.warn(f'Tokenization mismatch: {labels.size(0)} and {offset}')
    return labels
def preprocess(input_ids, text, tokenizer, conv_type):
    """Dispatch label creation to the handler of the given conversation type.

    Raises:
        ValueError: if 'conv_type' is anything other than 'chatml'.
    """
    if conv_type != 'chatml':
        raise ValueError(f'unknown conversation type: {conv_type}')
    return preprocess_chatml(input_ids, text, tokenizer)