Maximilian Werk
		
	committed on
		
		
					Commit 
							
							·
						
						cf456d3
	
1
								Parent(s):
							
							b7707d5
								
feat: reduced default noise of the model
Browse files
    	
        configuration_jina_embeddings_v4.py
    CHANGED
    
    | @@ -2,6 +2,7 @@ from transformers.models.qwen2_5_vl import Qwen2_5_VLConfig | |
| 2 |  | 
| 3 | 
             
            from typing import Optional
         | 
| 4 |  | 
|  | |
| 5 | 
             
            class JinaEmbeddingsV4Config(Qwen2_5_VLConfig):
         | 
| 6 | 
             
                """
         | 
| 7 | 
             
                Configuration for the JinaEmbeddingsV4 model.
         | 
| @@ -12,10 +13,11 @@ class JinaEmbeddingsV4Config(Qwen2_5_VLConfig): | |
| 12 | 
             
                    single_vector_pool_strategy: str = "mean",
         | 
| 13 | 
             
                    multi_vector_projector_dim: int = 128,
         | 
| 14 | 
             
                    pretrained_peft_model_name_or_path: Optional[str] = None,
         | 
|  | |
| 15 | 
             
                    **kwargs,
         | 
| 16 | 
             
                ):
         | 
| 17 | 
             
                    super().__init__(**kwargs)
         | 
| 18 | 
             
                    self.single_vector_pool_strategy = single_vector_pool_strategy
         | 
| 19 | 
             
                    self.multi_vector_projector_dim = multi_vector_projector_dim
         | 
| 20 | 
             
                    self.pretrained_peft_model_name_or_path = pretrained_peft_model_name_or_path
         | 
| 21 | 
            -
             | 
|  | |
| 2 |  | 
| 3 | 
             
            from typing import Optional
         | 
| 4 |  | 
| 5 | 
            +
             | 
| 6 | 
             
            class JinaEmbeddingsV4Config(Qwen2_5_VLConfig):
         | 
| 7 | 
             
                """
         | 
| 8 | 
             
                Configuration for the JinaEmbeddingsV4 model.
         | 
|  | |
| 13 | 
             
                    single_vector_pool_strategy: str = "mean",
         | 
| 14 | 
             
                    multi_vector_projector_dim: int = 128,
         | 
| 15 | 
             
                    pretrained_peft_model_name_or_path: Optional[str] = None,
         | 
| 16 | 
            +
                    verbosity: int = 0,
         | 
| 17 | 
             
                    **kwargs,
         | 
| 18 | 
             
                ):
         | 
| 19 | 
             
                    super().__init__(**kwargs)
         | 
| 20 | 
             
                    self.single_vector_pool_strategy = single_vector_pool_strategy
         | 
| 21 | 
             
                    self.multi_vector_projector_dim = multi_vector_projector_dim
         | 
| 22 | 
             
                    self.pretrained_peft_model_name_or_path = pretrained_peft_model_name_or_path
         | 
| 23 | 
            +
                    self.verbosity = verbosity
         | 
    	
        modeling_jina_embeddings_v4.py
    CHANGED
    
    | @@ -146,6 +146,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration): | |
| 146 | 
             
                        self.name_or_path, trust_remote_code=True, use_fast=True
         | 
| 147 | 
             
                    )
         | 
| 148 | 
             
                    self.multi_vector_projector_dim = config.multi_vector_projector_dim
         | 
|  | |
| 149 | 
             
                    self._task = None
         | 
| 150 |  | 
| 151 | 
             
                @property
         | 
| @@ -335,7 +336,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration): | |
| 335 | 
             
                        assert not return_numpy, "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
         | 
| 336 | 
             
                    results = []
         | 
| 337 | 
             
                    self.eval()
         | 
| 338 | 
            -
                    for batch in tqdm(dataloader, desc=desc):
         | 
| 339 | 
             
                        with torch.no_grad():
         | 
| 340 | 
             
                            batch = {k: v.to(self.device) for k, v in batch.items()}
         | 
| 341 | 
             
                            with torch.autocast(
         | 
| @@ -349,7 +350,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration): | |
| 349 | 
             
                                        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
         | 
| 350 | 
             
                                else:
         | 
| 351 | 
             
                                    embeddings = embeddings.multi_vec_emb
         | 
| 352 | 
            -
             | 
| 353 | 
             
                                if return_multivector and not return_numpy:
         | 
| 354 | 
             
                                    valid_tokens = batch["attention_mask"].bool()
         | 
| 355 | 
             
                                    embeddings = [
         | 
| @@ -453,7 +454,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration): | |
| 453 | 
             
                        if return_numpy:
         | 
| 454 | 
             
                            print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`")
         | 
| 455 | 
             
                        return_numpy = False
         | 
| 456 | 
            -
             | 
| 457 | 
             
                    if isinstance(texts, str):
         | 
| 458 | 
             
                        texts = [texts]
         | 
| 459 |  | 
| @@ -468,7 +469,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration): | |
| 468 | 
             
                        **encode_kwargs,
         | 
| 469 | 
             
                    )
         | 
| 470 |  | 
| 471 | 
            -
                    return embeddings if return_list else embeddings[0] | 
| 472 |  | 
| 473 | 
             
                def _load_images_if_needed(
         | 
| 474 | 
             
                    self, images: List[Union[str, Image.Image]]
         | 
| @@ -515,7 +516,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration): | |
| 515 | 
             
                        )
         | 
| 516 | 
             
                    encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim)
         | 
| 517 | 
             
                    task = self._validate_task(task)
         | 
| 518 | 
            -
             | 
| 519 | 
             
                    return_list = isinstance(images, list)
         | 
| 520 |  | 
| 521 | 
             
                    # If return_multivector is True and encoding multiple images, ignore return_numpy
         | 
| @@ -527,7 +528,7 @@ class JinaEmbeddingsV4Model(Qwen2_5_VLForConditionalGeneration): | |
| 527 | 
             
                    # Convert single image to list
         | 
| 528 | 
             
                    if isinstance(images, (str, Image.Image)):
         | 
| 529 | 
             
                        images = [images]
         | 
| 530 | 
            -
             | 
| 531 | 
             
                    images = self._load_images_if_needed(images)
         | 
| 532 | 
             
                    embeddings = self._process_batches(
         | 
| 533 | 
             
                        data=images,
         | 
|  | |
| 146 | 
             
                        self.name_or_path, trust_remote_code=True, use_fast=True
         | 
| 147 | 
             
                    )
         | 
| 148 | 
             
                    self.multi_vector_projector_dim = config.multi_vector_projector_dim
         | 
| 149 | 
            +
                    self.verbosity = config.verbosity
         | 
| 150 | 
             
                    self._task = None
         | 
| 151 |  | 
| 152 | 
             
                @property
         | 
|  | |
| 336 | 
             
                        assert not return_numpy, "`return_numpy` is not supported when `return_multivector=True` and more than one data is encoded"
         | 
| 337 | 
             
                    results = []
         | 
| 338 | 
             
                    self.eval()
         | 
| 339 | 
            +
                    for batch in tqdm(dataloader, desc=desc, disable=self.verbosity == 0):
         | 
| 340 | 
             
                        with torch.no_grad():
         | 
| 341 | 
             
                            batch = {k: v.to(self.device) for k, v in batch.items()}
         | 
| 342 | 
             
                            with torch.autocast(
         | 
|  | |
| 350 | 
             
                                        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
         | 
| 351 | 
             
                                else:
         | 
| 352 | 
             
                                    embeddings = embeddings.multi_vec_emb
         | 
| 353 | 
            +
             | 
| 354 | 
             
                                if return_multivector and not return_numpy:
         | 
| 355 | 
             
                                    valid_tokens = batch["attention_mask"].bool()
         | 
| 356 | 
             
                                    embeddings = [
         | 
|  | |
| 454 | 
             
                        if return_numpy:
         | 
| 455 | 
             
                            print("Warning: `return_numpy` is ignored when `return_multivector=True` and `len(texts) > 1`")
         | 
| 456 | 
             
                        return_numpy = False
         | 
| 457 | 
            +
             | 
| 458 | 
             
                    if isinstance(texts, str):
         | 
| 459 | 
             
                        texts = [texts]
         | 
| 460 |  | 
|  | |
| 469 | 
             
                        **encode_kwargs,
         | 
| 470 | 
             
                    )
         | 
| 471 |  | 
| 472 | 
            +
                    return embeddings if return_list else embeddings[0]
         | 
| 473 |  | 
| 474 | 
             
                def _load_images_if_needed(
         | 
| 475 | 
             
                    self, images: List[Union[str, Image.Image]]
         | 
|  | |
| 516 | 
             
                        )
         | 
| 517 | 
             
                    encode_kwargs = self._validate_encoding_params(truncate_dim=truncate_dim)
         | 
| 518 | 
             
                    task = self._validate_task(task)
         | 
| 519 | 
            +
             | 
| 520 | 
             
                    return_list = isinstance(images, list)
         | 
| 521 |  | 
| 522 | 
             
                    # If return_multivector is True and encoding multiple images, ignore return_numpy
         | 
|  | |
| 528 | 
             
                    # Convert single image to list
         | 
| 529 | 
             
                    if isinstance(images, (str, Image.Image)):
         | 
| 530 | 
             
                        images = [images]
         | 
| 531 | 
            +
             | 
| 532 | 
             
                    images = self._load_images_if_needed(images)
         | 
| 533 | 
             
                    embeddings = self._process_batches(
         | 
| 534 | 
             
                        data=images,
         | 
