# coding=utf-8
# Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processor class for Emu3."""

import re
from functools import partial
from typing import List, Optional, Sequence, Union

import torch
from PIL import Image

from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import logging

from .utils_emu3 import Emu3PrefixConstrainedLogitsHelper

logger = logging.get_logger(__name__)


class Emu3Processor(ProcessorMixin):
    r"""
    Constructs an Emu3 processor which wraps an Emu3 image processor, an Emu3 vision VQ model and an Emu3 tokenizer
    into a single processor.

    [`Emu3Processor`] offers all the functionalities of [`Emu3VisionVQModel`] and [`Emu3Tokenizer`]. See
    [`~Emu3Processor.__call__`], [`~Emu3Processor.decode`], [`~Emu3Processor.vision_encode`] and
    [`~Emu3Processor.vision_decode`] for more information.

    Args:
        image_processor ([`Emu3VisionVQImageProcessor`]):
            The image processor is a required input.
        vision_tokenizer ([`Emu3VisionVQModel`]):
            The vision tokenizer is a required input.
        tokenizer ([`Emu3Tokenizer`]):
            The tokenizer is a required input.
        prefix_template (`str`, *optional*):
            The prefix template for image tokens.
        visual_template (`Tuple[str, ...]`, *optional*):
            The visual token template for image tokens.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["vision_tokenizer", "prefix_template", "visual_template"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"
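
    # A minimal construction sketch. The repository ids and the `trust_remote_code`
    # flag below are illustrative assumptions, not guaranteed by this file:
    #
    #     from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
    #
    #     tokenizer = AutoTokenizer.from_pretrained("BAAI/Emu3-Gen", trust_remote_code=True)
    #     image_processor = AutoImageProcessor.from_pretrained("BAAI/Emu3-VisionTokenizer", trust_remote_code=True)
    #     vision_tokenizer = AutoModel.from_pretrained("BAAI/Emu3-VisionTokenizer", trust_remote_code=True).eval()
    #     processor = Emu3Processor(image_processor, vision_tokenizer, tokenizer)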

    def __init__(
        self,
        image_processor=None,
        vision_tokenizer=None,
        tokenizer=None,
        chat_template="You are a helpful assistant. USER: {image_prompt}{text_prompt}. ASSISTANT:",
        prefix_template="{H}*{W}",
        visual_template=("<|visual token {token_id:0>6d}|>", r"<\|visual token (\d+)\|>"),
        **kwargs,
    ):
        assert vision_tokenizer is not None, "vision tokenizer cannot be None"
        self.vision_tokenizer = vision_tokenizer
        self.prefix_template = prefix_template
        self.visual_template = visual_template
        super().__init__(image_processor, tokenizer, chat_template=chat_template)
        self.const_helper = self.build_const_helper()

    def __call__(
        self,
        text: Optional[Union[TextInput, PreTokenizedInput]] = None,
        image: Optional[Union[Image.Image, List[Image.Image]]] = None,
        *,
        mode: str = "G",
        ratio: str = "1:1",
        image_area: int = 518400,
        **kwargs,
    ) -> BatchFeature:
        """
        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
        and `kwargs` arguments to Emu3Tokenizer's [`~Emu3Tokenizer.__call__`] to encode the text.
        To prepare the image(s), this method forwards the `image` argument to
        Emu3VisionVQImageProcessor's [`~Emu3VisionVQImageProcessor.__call__`] and Emu3VisionVQModel's
        [`~Emu3VisionVQModel.encode`] if `image` is not `None`. Please refer to the docstrings of the above two
        methods for more information.

        Args:
            text (`str` or `List[str]`):
                The sequence or batch of sequences to be encoded. Each sequence is a string.
            image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
                The image or batch of images to be prepared. Each image is a PIL image.
            mode (`str`, *optional*, one of `G` or `U`):
                Task mode: `G` for generation and `U` for understanding.
            ratio (`str`, *optional*):
                The width:height ratio of the image to generate.
            image_area (`int`, *optional*):
                Target image area (in pixels) used to calculate the generated image height and width.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
                - `'pt'`: Return PyTorch `torch.Tensor` objects.
                - `'np'`: Return NumPy `np.ndarray` objects.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
            - **input_ids** -- List of token ids to be fed to a model.
            - **image_size** -- List of image sizes of the input or generated images.
        """
        assert mode in ('G', 'U'), "mode must be 'G' or 'U'."

        if isinstance(text, str):
            text = [text]

        if not isinstance(text[0], str):
            raise ValueError("`text` must be string or list of string")

        image_inputs = None
        if mode == 'G':
            if image is not None:
                raise ValueError("You have to specify only `text` in generation mode")

            if len(text) > 1:
                raise ValueError("`text` can only be `str` in generation mode")
        else:
            if image is None:
                raise ValueError("Invalid input image. Please provide exactly one PIL.Image.Image per text.")

            if not isinstance(image, Sequence) and not isinstance(image, Image.Image):
                raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].")

            if isinstance(image, Sequence) and not isinstance(image[0], Image.Image):
                raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].")

            image_inputs = self.image_processor(image, return_tensors="pt")["pixel_values"]
            logger.debug("pixel_values shape: %s", image_inputs.shape)
            image_inputs = image_inputs.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
            image_tokens = self.vision_tokenizer.encode(image_inputs)

            if len(text) != len(image_tokens):
                raise ValueError("number of images must match number of text prompts")
        prompt_list, size_list = [], []
        for idx, text_prompt in enumerate(text):
            prompt = self.tokenizer.bos_token
            if mode == 'U':
                h, w = image_tokens[idx].shape
                imgstr = self.to_imgstr(image_tokens[idx])
                image_prompt = (
                    self.tokenizer.boi_token +
                    self.prefix_template.format(H=h, W=w) +
                    self.tokenizer.img_token +
                    imgstr +
                    self.tokenizer.eol_token +
                    self.tokenizer.eof_token +
                    self.tokenizer.eoi_token
                )
                prompt += self.chat_template.format(image_prompt=image_prompt, text_prompt=text_prompt)
            else:
                h, w = self.calculate_generate_size(ratio, image_area, self.vision_tokenizer.spatial_scale_factor)
                image_prompt = (
                    self.tokenizer.boi_token +
                    self.prefix_template.format(H=h, W=w) +
                    self.tokenizer.img_token
                )
                prompt += (text_prompt + image_prompt)

            prompt_list.append(prompt)
            size_list.append([h, w])

        text_inputs = self.tokenizer(prompt_list, **kwargs)
        return BatchFeature(data={**text_inputs, "image_size": size_list}, tensor_type=kwargs.get("return_tensors"))
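
    # Usage sketch for the two modes (the sampling step itself is omitted; extra kwargs
    # such as `padding` are forwarded to the text tokenizer; `pil_image` is a placeholder):
    #
    #     # generation ("G"): text only, the target grid size is appended to the prompt
    #     inputs = processor(text="a cat on the moon", mode="G", ratio="1:1",
    #                        image_area=518400, return_tensors="pt", padding="longest")
    #
    #     # understanding ("U"): one PIL image per prompt, encoded to visual tokens
    #     inputs = processor(text="describe this image", image=pil_image, mode="U",
    #                        return_tensors="pt", padding="longest")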

    def batch_decode(self, *args, **kwargs):
        docs = self.tokenizer.batch_decode(*args, **kwargs)
        return [self.multimodal_decode(d) for d in docs]

    def decode(self, *args, **kwargs):
        doc = self.tokenizer.decode(*args, **kwargs)
        return self.multimodal_decode(doc)

    def vision_encode(self, *args, **kwargs):
        return self.vision_tokenizer.encode(*args, **kwargs)

    def vision_decode(self, *args, **kwargs):
        return self.vision_tokenizer.decode(*args, **kwargs)

    def multimodal_decode(self, doc):
        multimodal_output = []
        pattern = rf'({re.escape(self.tokenizer.boi_token)}.*?{re.escape(self.tokenizer.eoi_token)})'
        chunks = re.split(pattern, doc)
        for c in chunks:
            if len(c) == 0:
                continue

            if self.tokenizer.boi_token in c:
                image = []
                image_rows = re.split(re.escape(self.tokenizer.eol_token), c)
                for r in image_rows:
                    token_ids = re.findall(self.visual_template[1], r)
                    if len(token_ids) > 0:
                        row_token = [int(m) for m in token_ids]
                        image.append(row_token)
                image = torch.tensor(image, dtype=torch.long, device=self.vision_tokenizer.device)
                image = self.vision_tokenizer.decode(image[None]).float()
                image = self.image_processor.postprocess(image)["pixel_values"][0]
                multimodal_output.append(image)
            else:
                multimodal_output.append(c)

        return multimodal_output if len(multimodal_output) > 1 else multimodal_output[0]
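
    # Decoding sketch: given ids produced by `model.generate(...)`, `decode` splits the
    # detokenized text on the boi/eoi markers and turns each visual-token span back into
    # an image via the vision tokenizer. The assumption that `postprocess` yields PIL
    # images is illustrative, not guaranteed by this file:
    #
    #     out = processor.decode(generated_ids[0])
    #     for item in (out if isinstance(out, list) else [out]):
    #         if isinstance(item, Image.Image):
    #             item.save("generated.png")
    #         else:
    #             print(item)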

    @property
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

    def to_imgstr(self, image_tokens):
        image_tokens = image_tokens.cpu().numpy().tolist()
        image_token_str = [
            [
                self.visual_template[0].format(token_id=token_id)
                for token_id in token_row
            ]
            for token_row in image_tokens
        ]
        image_row_str = ["".join(token_row) for token_row in image_token_str]
        imgstr = self.tokenizer.eol_token.join(image_row_str)
        return imgstr
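
    # For a 2x2 token grid [[5, 7], [9, 11]] the serialized string looks like the sketch
    # below, with "<EOL>" standing in for the tokenizer's actual eol_token:
    #
    #     <|visual token 000005|><|visual token 000007|><EOL><|visual token 000009|><|visual token 000011|>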

    def calculate_generate_size(self, ratio, image_area, spatial_scale_factor):
        w, h = map(int, ratio.split(":"))
        current_area = h * w
        target_ratio = (image_area / current_area) ** 0.5

        th = int(round(h * target_ratio / spatial_scale_factor))
        tw = int(round(w * target_ratio / spatial_scale_factor))
        return th, tw
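
    # Worked example (spatial_scale_factor=8 is assumed here, not read from a config):
    # ratio "1:1" with image_area 518400 gives target_ratio = sqrt(518400 / 1) = 720, so
    # the token grid is round(720 / 8) = 90 rows by 90 columns, i.e. a 720x720 image in
    # pixel space once the VQ decoder upsamples by the same factor.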

    def build_const_helper(self):
        (
            img_token,
            eoi_token,
            eos_token,
            eol_token,
            eof_token,
            pad_token,
            vis_start,
            vis_end,
        ) = self.tokenizer.encode([
            self.tokenizer.img_token,
            self.tokenizer.eoi_token,
            self.tokenizer.eos_token,
            self.tokenizer.eol_token,
            self.tokenizer.eof_token,
            self.tokenizer.pad_token,
            self.visual_template[0].format(token_id=0),
            self.visual_template[0].format(token_id=self.vision_tokenizer.config.codebook_size - 1),
        ])

        const_helper = partial(
            Emu3PrefixConstrainedLogitsHelper,
            img_token=img_token,
            eoi_token=eoi_token,
            eos_token=eos_token,
            eol_token=eol_token,
            eof_token=eof_token,
            pad_token=pad_token,
            visual_tokens=list(range(vis_start, vis_end + 1)),
        )
        return const_helper

    def build_prefix_constrained_fn(self, height, width):
        helper = self.const_helper(height=height, width=width)
        return helper
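

# The prefix-constrained helper is intended to be plugged into `generate` as
# `prefix_allowed_tokens_fn`, so that during image generation only visual tokens and the
# structural tokens (eol/eof/eoi/eos) can be sampled at each position. A hedged sketch,
# assuming `model` is the Emu3 causal LM and `inputs` comes from a mode="G" call:
#
#     h, w = inputs["image_size"][0]
#     out = model.generate(
#         inputs["input_ids"].to(model.device),
#         prefix_allowed_tokens_fn=processor.build_prefix_constrained_fn(h, w),
#     )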