Mountchicken committed on
Commit 60f587b · verified · 1 Parent(s): a8932c9

Update rex_omni/wrapper.py

Files changed (1): rex_omni/wrapper.py +4 -175
rex_omni/wrapper.py CHANGED
@@ -4,194 +4,24 @@
 """
 Main wrapper class for Rex Omni
 """
-import spaces
-import base64
+
 import json
-import math
 import time
-from io import BytesIO
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-import requests
 import torch
 from PIL import Image
+from qwen_vl_utils import process_vision_info, smart_resize
 
 from .parser import convert_boxes_to_normalized_bins, parse_prediction
 from .tasks import TASK_CONFIGS, TaskType, get_keypoint_config, get_task_config
 
-IMAGE_FACTOR = 28
-MIN_PIXELS = 4 * 28 * 28
-MAX_PIXELS = 16384 * 28 * 28
-MAX_RATIO = 200
-
-VIDEO_MIN_PIXELS = 128 * 28 * 28
-VIDEO_MAX_PIXELS = 768 * 28 * 28
-FRAME_FACTOR = 2
-FPS = 2.0
-FPS_MIN_FRAMES = 4
-FPS_MAX_FRAMES = 768
-
-
-def round_by_factor(number: int, factor: int) -> int:
-    """Returns the closest integer to 'number' that is divisible by 'factor'."""
-    return round(number / factor) * factor
-
-
-def ceil_by_factor(number: int, factor: int) -> int:
-    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
-    return math.ceil(number / factor) * factor
-
-
-def floor_by_factor(number: int, factor: int) -> int:
-    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
-    return math.floor(number / factor) * factor
-
-
-def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[dict]:
-    vision_infos = []
-    if isinstance(conversations[0], dict):
-        conversations = [conversations]
-    for conversation in conversations:
-        for message in conversation:
-            if isinstance(message["content"], list):
-                for ele in message["content"]:
-                    if (
-                        "image" in ele
-                        or "image_url" in ele
-                        or "video" in ele
-                        or ele["type"] in ("image", "image_url", "video")
-                    ):
-                        vision_infos.append(ele)
-    return vision_infos
-
-
-def to_rgb(pil_image: Image.Image) -> Image.Image:
-    if pil_image.mode == "RGBA":
-        white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
-        white_background.paste(
-            pil_image, mask=pil_image.split()[3]
-        )  # Use alpha channel as mask
-        return white_background
-    else:
-        return pil_image.convert("RGB")
-
-
-def fetch_image(
-    ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR
-) -> Image.Image:
-    if "image" in ele:
-        image = ele["image"]
-    else:
-        image = ele["image_url"]
-    image_obj = None
-    if isinstance(image, Image.Image):
-        image_obj = image
-    elif image.startswith("http://") or image.startswith("https://"):
-        response = requests.get(image, stream=True)
-        image_obj = Image.open(BytesIO(response.content))
-    elif image.startswith("file://"):
-        image_obj = Image.open(image[7:])
-    elif image.startswith("data:image"):
-        if "base64," in image:
-            _, base64_data = image.split("base64,", 1)
-            data = base64.b64decode(base64_data)
-            image_obj = Image.open(BytesIO(data))
-    else:
-        image_obj = Image.open(image)
-    if image_obj is None:
-        raise ValueError(
-            f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}"
-        )
-    image = to_rgb(image_obj)
-    ## resize
-    if "resized_height" in ele and "resized_width" in ele:
-        resized_height, resized_width = smart_resize(
-            ele["resized_height"],
-            ele["resized_width"],
-            factor=size_factor,
-        )
-    else:
-        width, height = image.size
-        min_pixels = ele.get("min_pixels", MIN_PIXELS)
-        max_pixels = ele.get("max_pixels", MAX_PIXELS)
-        resized_height, resized_width = smart_resize(
-            height,
-            width,
-            factor=size_factor,
-            min_pixels=min_pixels,
-            max_pixels=max_pixels,
-        )
-    image = image.resize((resized_width, resized_height))
-
-    return image
-
-
-def process_vision_info(
-    conversations: list[dict] | list[list[dict]],
-    return_video_kwargs: bool = False,
-) -> tuple[
-    list[Image.Image] | None,
-    list[torch.Tensor | list[Image.Image]] | None,
-    Optional[dict],
-]:
-
-    vision_infos = extract_vision_info(conversations)
-    ## Read images or videos
-    image_inputs = []
-    video_inputs = []
-    video_sample_fps_list = []
-    for vision_info in vision_infos:
-        if "image" in vision_info or "image_url" in vision_info:
-            image_inputs.append(fetch_image(vision_info))
-        else:
-            raise ValueError("image, image_url or video should in content.")
-    if len(image_inputs) == 0:
-        image_inputs = None
-    if len(video_inputs) == 0:
-        video_inputs = None
-    if return_video_kwargs:
-        return image_inputs, video_inputs, {"fps": video_sample_fps_list}
-    return image_inputs, video_inputs
-
-
-def smart_resize(
-    height: int,
-    width: int,
-    factor: int = IMAGE_FACTOR,
-    min_pixels: int = MIN_PIXELS,
-    max_pixels: int = MAX_PIXELS,
-) -> tuple[int, int]:
-    """
-    Rescales the image so that the following conditions are met:
-
-    1. Both dimensions (height and width) are divisible by 'factor'.
-
-    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
-
-    3. The aspect ratio of the image is maintained as closely as possible.
-    """
-    if max(height, width) / min(height, width) > MAX_RATIO:
-        raise ValueError(
-            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
-        )
-    h_bar = max(factor, round_by_factor(height, factor))
-    w_bar = max(factor, round_by_factor(width, factor))
-    if h_bar * w_bar > max_pixels:
-        beta = math.sqrt((height * width) / max_pixels)
-        h_bar = floor_by_factor(height / beta, factor)
-        w_bar = floor_by_factor(width / beta, factor)
-    elif h_bar * w_bar < min_pixels:
-        beta = math.sqrt(min_pixels / (height * width))
-        h_bar = ceil_by_factor(height * beta, factor)
-        w_bar = ceil_by_factor(width * beta, factor)
-    return h_bar, w_bar
-
 
 class RexOmniWrapper:
     """
     High-level wrapper for Rex-Omni
     """
-    @spaces.GPU
+
     def __init__(
         self,
        model_path: str,
@@ -304,8 +134,7 @@ class RexOmniWrapper:
 
         elif self.backend == "transformers":
             import torch
-            from transformers import (AutoProcessor,
-                                      Qwen2_5_VLForConditionalGeneration)
+            from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
 
             # Initialize transformers model
             self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
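
For reference, the helpers deleted above mirror Qwen's own preprocessing utilities, so the new `qwen_vl_utils` import is a like-for-like swap. A minimal sketch of the image path the library now handles; the message content and image size here are illustrative, not taken from this repo:

    from PIL import Image
    from qwen_vl_utils import process_vision_info

    # One user turn carrying an image plus a text instruction, in the
    # conversation format both the deleted code and qwen_vl_utils expect.
    messages = [
        {
            "role": "user",
            "content": [
                {"image": Image.new("RGB", (1000, 700))},  # PIL image, path, or URL
                {"type": "text", "text": "Detect every person."},
            ],
        }
    ]

    # Collects the vision elements and returns resized PIL images; the video
    # slot stays None for image-only conversations.
    image_inputs, video_inputs = process_vision_info(messages)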
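
The resize contract itself is unchanged: `smart_resize` returns the nearest size whose sides are divisible by the patch factor (28 for Qwen2.5-VL) with the pixel count clamped to `[min_pixels, max_pixels]`. A worked check, assuming the library defaults match the constants deleted above (`4 * 28 * 28` and `16384 * 28 * 28`):

    from qwen_vl_utils import smart_resize

    # round(1000 / 28) * 28 = 36 * 28 = 1008 and round(700 / 28) * 28 = 700;
    # 1008 * 700 = 705,600 pixels sits inside [3,136, 12,845,056], so neither
    # the upscale nor the downscale branch fires.
    h, w = smart_resize(1000, 700, factor=28)
    assert (h, w) == (1008, 700)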
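
The reflowed `transformers` import in the second hunk feeds the backend initialization that begins at the end of the diff. A hedged sketch of how such a `from_pretrained` call typically continues; the checkpoint name and dtype/device flags are illustrative, not taken from this repo:

    import torch
    from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

    # Illustrative checkpoint; the wrapper receives the real path via model_path.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2.5-VL-7B-Instruct",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")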