refactor-image-processing (#16)
- refactor: support urls, fast processor, flash attn check (9624180b21e596eb410896719e1c1708aaed343c)
- refactor: image loading in st wrapper (9ef2e43d97b27bc27da6b71bc68d6160d317da20)
- custom_st.py +80 -42
- modeling_jina_embeddings_v4.py +41 -15
- tokenizer_config.json +1 -1
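
Taken together, these changes let the sentence-transformers wrapper accept image URLs and local file paths alongside raw text and PIL images. A minimal usage sketch (the model id matches the default `model_name_or_path` in the diff; the URL and file path are placeholders, and passing `default_task` via `model_kwargs` assumes sentence-transformers forwards it to the wrapper's `model_args`, as in the repo's documented loading pattern):

from sentence_transformers import SentenceTransformer

# Load via the custom_st.py wrapper; "default_task" is popped from model_args in __init__.
model = SentenceTransformer(
    "jinaai/jina-embeddings-v4",
    trust_remote_code=True,
    model_kwargs={"default_task": "retrieval"},
)

# tokenize() now routes each string: plain strings stay text, http(s) URLs are
# fetched with requests, and existing file paths are opened with PIL.
embeddings = model.encode(
    [
        "Query: how does the v4 processor handle images?",  # text branch
        "https://example.com/figure.png",                   # placeholder URL -> image branch
        "docs/screenshot.jpg",                              # placeholder local path -> image branch
    ]
)
print(embeddings.shape)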
    	
custom_st.py
CHANGED

@@ -1,32 +1,34 @@
+from io import BytesIO
+from pathlib import Path
 from typing import Any, Dict, List, Literal, Optional, Union

+import requests
 import torch
 from PIL import Image
 from torch import nn
-from transformers import AutoConfig, …
+from transformers import AutoConfig, AutoModel, AutoProcessor


 class Transformer(nn.Module):

     save_in_root: bool = True
-
+
     def __init__(
         self,
-        model_name_or_path: str = …
+        model_name_or_path: str = "jinaai/jina-embeddings-v4",
         max_seq_length: Optional[int] = None,
         config_args: Optional[Dict[str, Any]] = None,
         model_args: Optional[Dict[str, Any]] = None,
         tokenizer_args: Optional[Dict[str, Any]] = None,
         cache_dir: Optional[str] = None,
-        backend: Literal[…
+        backend: Literal["torch", "onnx", "openvino"] = "torch",
         **kwargs,
     ) -> None:
         super(Transformer, self).__init__()
-        if backend != …
+        if backend != "torch":
             raise ValueError(
-            f…
+                f"Backend '{backend}' is not supported, please use 'torch' instead"
             )
-
         config_kwargs = config_args or {}
         model_kwargs = model_args or {}
         tokenizer_kwargs = tokenizer_args or {}
@@ -34,9 +36,11 @@ class Transformer(nn.Module):
         self.config = AutoConfig.from_pretrained(
             model_name_or_path, cache_dir=cache_dir, **config_kwargs
         )
-        self.default_task = model_args.pop(…
+        self.default_task = model_args.pop("default_task", None)
         if self.default_task and self.default_task not in self.config.task_names:
-            raise ValueError(…
+            raise ValueError(
+                f"Invalid task: {self.default_task}. Must be one of {self.config.task_names}."
+            )

         self.model = AutoModel.from_pretrained(
             model_name_or_path, config=self.config, cache_dir=cache_dir, **model_kwargs
@@ -45,6 +49,7 @@ class Transformer(nn.Module):
         self.processor = AutoProcessor.from_pretrained(
             model_name_or_path,
             cache_dir=cache_dir,
+            use_fast=True,
             **tokenizer_kwargs,
         )
         self.max_seq_length = max_seq_length or 8192
@@ -55,33 +60,52 @@
         encoding = {}
         text_indices = []
         image_indices = []
-
         for i, text in enumerate(texts):
             if isinstance(text, str):
-                …
+                # Remove Query: or Passage: prefixes when checking for URLs or file paths
+                clean_text = text
+                if text.startswith("Query: "):
+                    clean_text = text[len("Query: ") :]
+                elif text.startswith("Passage: "):
+                    clean_text = text[len("Passage: ") :]
+
+                if clean_text.startswith("http"):
+                    response = requests.get(clean_text)
+                    texts[i] = Image.open(BytesIO(response.content)).convert("RGB")
+                    image_indices.append(i)
+                elif Path(clean_text).is_file():
+                    try:
+                        texts[i] = Image.open(clean_text).convert("RGB")
+                        image_indices.append(i)
+                    except Exception as e:
+                        text_indices.append(i)
+                else:
+                    text_indices.append(i)
             elif isinstance(text, Image.Image):
                 image_indices.append(i)
             else:
-                raise ValueError(f…
-
+                raise ValueError(f"Invalid input type: {type(text)}")
         if text_indices:
             _texts = [texts[i] for i in text_indices]
-            text_features = self.processor.process_texts(…
+            text_features = self.processor.process_texts(
+                _texts, max_length=self.max_seq_length
+            )
             for key, value in text_features.items():
-                encoding[f…
-            encoding[…
-
+                encoding[f"text_{key}"] = value
+            encoding["text_indices"] = text_indices
+
         if image_indices:
             _images = [texts[i] for i in image_indices]
             img_features = self.processor.process_images(_images)
             for key, value in img_features.items():
-                encoding[f…
-            encoding[…
-
+                encoding[f"image_{key}"] = value
+            encoding["image_indices"] = image_indices
+
         return encoding
-

-    def forward(…
+    def forward(
+        self, features: Dict[str, torch.Tensor], task: Optional[str] = None
+    ) -> Dict[str, torch.Tensor]:
         self.model.eval()

         if task is None:
@@ -94,41 +118,55 @@ class Transformer(nn.Module):
             task = self.default_task
         else:
             if task not in self.config.task_names:
-                raise ValueError(…
+                raise ValueError(
+                    f"Invalid task: {task}. Must be one of {self.config.task_names}."
+                )

         device = self.model.device.type
         all_embeddings = []
-
+
         with torch.no_grad():
-            if any(k.startswith(…
-                text_batch = {…
-                …
-                …
+            if any(k.startswith("text_") for k in features.keys()):
+                text_batch = {
+                    k[len("text_") :]: v.to(device)
+                    for k, v in features.items()
+                    if k.startswith("text_") and k != "text_indices"
+                }
+                text_indices = features.get("text_indices", [])
+
                 with torch.autocast(device_type=device):
-                    text_embeddings = self.model(…
+                    text_embeddings = self.model(
+                        **text_batch, task_label=task
+                    ).single_vec_emb
                     if self.config.truncate_dim:
-                        text_embeddings = text_embeddings[:, :self.config.truncate_dim]
-
+                        text_embeddings = text_embeddings[:, : self.config.truncate_dim]
+
                 for i, embedding in enumerate(text_embeddings):
                     all_embeddings.append((text_indices[i], embedding))
-
-            if any(k.startswith(…
-                image_batch = {…
-                …
-                …
+
+            if any(k.startswith("image_") for k in features.keys()):
+                image_batch = {
+                    k[len("image_") :]: v.to(device)
+                    for k, v in features.items()
+                    if k.startswith("image_") and k != "image_indices"
+                }
+                image_indices = features.get("image_indices", [])
+
                 with torch.autocast(device_type=device):
-                    img_embeddings = self.model(…
+                    img_embeddings = self.model(
+                        **image_batch, task_label=task
+                    ).single_vec_emb
                     if self.config.truncate_dim:
-                        img_embeddings = img_embeddings[:, :self.config.truncate_dim]
-
+                        img_embeddings = img_embeddings[:, : self.config.truncate_dim]
+
                 for i, embedding in enumerate(img_embeddings):
                     all_embeddings.append((image_indices[i], embedding))

         if not all_embeddings:
-            raise RuntimeError(…
+            raise RuntimeError("No embeddings were generated")

         all_embeddings.sort(key=lambda x: x[0])  # sort by original index
         combined_embeddings = torch.stack([emb for _, emb in all_embeddings])
-        features[…
-
+        features["sentence_embedding"] = combined_embeddings
+
         return features
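
The contract between tokenize() and forward() is the part worth internalizing: tokenize() prefixes every processor output key with text_ or image_ and records the original positions, and forward() strips the prefixes, embeds each modality separately, then re-sorts by the saved indices so the output order matches the input order. A standalone sketch of that key routing, using toy tensors (the key names come from the diff above; the tensor shapes are invented for illustration):

import torch

# Toy stand-in for what tokenize() returns for a mixed batch:
# items 0 and 2 were text, item 1 was an image.
features = {
    "text_input_ids": torch.randint(0, 100, (2, 8)),
    "text_attention_mask": torch.ones(2, 8, dtype=torch.long),
    "text_indices": [0, 2],
    "image_pixel_values": torch.randn(1, 3, 224, 224),
    "image_indices": [1],
}

# forward() rebuilds per-modality batches by stripping the prefixes
# and dropping the bookkeeping *_indices entries.
text_batch = {
    k[len("text_") :]: v
    for k, v in features.items()
    if k.startswith("text_") and k != "text_indices"
}
image_batch = {
    k[len("image_") :]: v
    for k, v in features.items()
    if k.startswith("image_") and k != "image_indices"
}
print(sorted(text_batch))   # ['attention_mask', 'input_ids']
print(sorted(image_batch))  # ['pixel_values']

# After embedding, (index, embedding) pairs are sorted so the stacked
# sentence_embedding tensor follows the caller's original input order.
pairs = [(1, torch.zeros(4)), (0, torch.ones(4)), (2, torch.ones(4))]
pairs.sort(key=lambda x: x[0])
print([i for i, _ in pairs])  # [0, 1, 2]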
    	
modeling_jina_embeddings_v4.py
CHANGED

@@ -5,20 +5,24 @@ import os
 from dataclasses import dataclass
 from enum import Enum
 from functools import partial
+from io import BytesIO
 from typing import Any, Callable, ClassVar, Dict, List, Optional, Union, cast

 import numpy as np
+import requests
 import torch
 from huggingface_hub import snapshot_download
-from peft import …
+from peft import LoraConfig, PeftModel
 from PIL import Image
 from torch import nn
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import BatchFeature
-from .…
+from transformers.utils import is_flash_attn_2_available
+
 from .configuration_jina_embeddings_v4 import JinaEmbeddingsV4Config
 from .custom_lora_module import MultiAdapterLinear
+from .qwen2_5_vl import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor


 class PromptType(str, Enum):
@@ -140,7 +144,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         self._init_projection_layers(config)
         self.post_init()
         self.processor = JinaEmbeddingsV4Processor.from_pretrained(
-            self.name_or_path, trust_remote_code=True
+            self.name_or_path, trust_remote_code=True, use_fast=True
         )
         self.single_vector_projector_dim = config.single_vector_projector_dim
         self.multi_vector_projector_dim = config.multi_vector_projector_dim
@@ -160,7 +164,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             task (str): The task name. Must be one of ['retrieval', 'text-matching', 'code']
         """
         if task not in self.config.task_names:
-            raise ValueError(…
+            raise ValueError(
+                f"Invalid task: {task}. Must be one of {self.config.task_names}."
+            )
         self._task = task

     def get_last_hidden_states(
@@ -342,7 +348,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         for batch in tqdm(dataloader, desc=desc):
             with torch.no_grad():
                 batch = {k: v.to(self.device) for k, v in batch.items()}
-                with torch.autocast(…
+                with torch.autocast(
+                    device_type=torch.device(self.device).type, dtype=torch.bfloat16
+                ):
                     embeddings = self(**batch, task_label=task_label)
                     if vector_type == "single_vector":
                         embeddings = embeddings.single_vec_emb
@@ -395,7 +403,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             encode_kwargs["truncate_dim"] = truncate_dim

         return encode_kwargs
-
+
     def _validate_task(self, task: Optional[str] = None) -> str:
         if task is None:
             if self.task is None:
@@ -406,7 +414,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             task = self.task
         else:
             if task not in self.config.task_names:
-                raise ValueError(…
+                raise ValueError(
+                    f"Invalid task: {task}. Must be one of {self.config.task_names}."
+                )
         return task

     def encode_texts(
@@ -460,9 +470,23 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):

         return embeddings

+    def _load_images_if_needed(
+        self, images: List[Union[str, Image.Image]]
+    ) -> List[Image.Image]:
+        loaded_images = []
+        for image in images:
+            if isinstance(image, str):
+                if image.startswith("http"):
+                    response = requests.get(image)
+                    image = Image.open(BytesIO(response.content)).convert("RGB")
+                else:
+                    image = Image.open(image).convert("RGB")
+            loaded_images.append(image)
+        return loaded_images
+
     def encode_images(
         self,
-        images: List[Image.Image],
+        images: List[Union[str, Image.Image]],
         task: Optional[str] = None,
         batch_size: int = 8,
         vector_type: Optional[str] = None,
@@ -474,7 +498,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         Encodes a list of images into embeddings.

         Args:
-            images: List of PIL images to encode
+            images: List of PIL images, URLs, or local file paths to encode
             batch_size: Number of images to process at once
             vector_type: Type of embedding vector to generate ('single_vector' or 'multi_vector')
             return_numpy: Whether to return numpy arrays instead of torch tensors
@@ -489,9 +513,9 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         self.processor.image_processor.max_pixels = (
             max_pixels  # change during encoding
         )
-
         encode_kwargs = self._validate_encoding_params(vector_type, truncate_dim)
         task = self._validate_task(task)
+        images = self._load_images_if_needed(images)
         embeddings = self._process_batches(
             data=images,
             processor_fn=self.processor.process_images,
@@ -519,8 +543,10 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
         """
         if "torch_dtype" not in kwargs:
             kwargs["torch_dtype"] = "auto"
-
+
         kwargs["key_mapping"] = super()._checkpoint_conversion_mapping
+        if not is_flash_attn_2_available():
+            kwargs["attn_implementation"] = "sdpa"

         base_model = super().from_pretrained(
             pretrained_model_name_or_path, *args, **kwargs
@@ -547,19 +573,19 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration):
             model_id=adapter_dir,
             config=lora_config,
         )
-
+
         @property
         def task(self):
             return self.model.task
-
+
         @task.setter
         def task(self, value):
             self.model.task = value
-
+
         peft_model.task = property(task.fget, task.fset)
         peft_model.__class__.task = property(
             lambda self: self.model.task,
-            lambda self, value: setattr(self.model, …
+            lambda self, value: setattr(self.model, "task", value),
         )

         return peft_model
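
On the modeling side, the same idea lands in encode_images() through the new _load_images_if_needed() helper, and from_pretrained() now falls back to SDPA attention when flash-attention 2 is unavailable. A rough usage sketch (model id as in the repo; the URL and path below are placeholders):

from PIL import Image
from transformers import AutoModel

# torch_dtype defaults to "auto"; if flash-attn 2 is not installed,
# from_pretrained() now sets attn_implementation="sdpa" on its own.
model = AutoModel.from_pretrained("jinaai/jina-embeddings-v4", trust_remote_code=True)

embeddings = model.encode_images(
    images=[
        "https://example.com/figure.png",   # placeholder URL, fetched with requests
        "docs/screenshot.jpg",              # placeholder path, opened with PIL
        Image.new("RGB", (224, 224)),       # already-loaded PIL images pass through unchanged
    ],
    task="retrieval",
)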
    	
tokenizer_config.json
CHANGED

@@ -202,7 +202,7 @@
   "extra_special_tokens": {},
   "model_max_length": 131072,
   "pad_token": "<|endoftext|>",
-  "processor_class": "…
+  "processor_class": "JinaEmbeddingsV4Processor",
   "split_special_tokens": false,
   "tokenizer_class": "Qwen2Tokenizer",
   "unk_token": null
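
The tokenizer_config.json change points AutoProcessor at the repo's own processor class. A quick check one might run (assumed, not part of the commit):

from transformers import AutoProcessor

# With remote code enabled, the processor_class entry should resolve to the custom
# JinaEmbeddingsV4Processor; use_fast matches the flag added in this commit.
processor = AutoProcessor.from_pretrained(
    "jinaai/jina-embeddings-v4", trust_remote_code=True, use_fast=True
)
print(type(processor).__name__)  # expected: JinaEmbeddingsV4Processor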

