from typing import Optional, Tuple, Union

import torch
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.configuration_clip import CLIPConfig
from transformers.models.clip.modeling_clip import (
    CLIPModel,
    CLIPTextTransformer,
    CLIPOutput,
    _expand_mask,
    _make_causal_mask,
    clip_loss,
)


class CLIPTextTransformerCanReceiveEmbed(CLIPTextTransformer):
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        input_embeds: Optional[torch.Tensor] = None,  # NOTE: precomputed text embeddings (position embeddings included)
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_embeds is None:
            if input_ids is None:
                raise ValueError("You have to specify either input_ids or input_embeds")
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
        else:
            # skip the embedding layer entirely and use the provided embeddings
            hidden_states = input_embeds
            input_shape = torch.Size([hidden_states.size(0), hidden_states.size(1)])

        # CLIP's text model uses a causal mask; prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        if input_ids is not None:
            eos_embedding_pos = input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1)
        else:
            # TODO: is there any exception?
            # without input_ids the eot token cannot be located, so fall back to the last position
            eos_embedding_pos = torch.tensor(
                [input_embeds.size(1) - 1] * input_embeds.size(0), device=last_hidden_state.device
            )
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            eos_embedding_pos,
        ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class CLIPModelCanReceiveTextEmbeds(CLIPModel):
    def __init__(self, config: CLIPConfig):
        super().__init__(config)
        # swap in the text transformer that accepts precomputed embeddings
        self.text_model = CLIPTextTransformerCanReceiveEmbed(config.text_config)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        input_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        only_return_logits_per_text: bool = False,
        no_grad_text: bool = False,
    ) -> Union[Tuple, CLIPOutput]:
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if no_grad_text:
            # keep the text tower frozen for this forward pass
            with torch.no_grad():
                text_outputs = self.text_model(
                    input_ids=input_ids,
                    input_embeds=input_embeds,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                )
        else:
            text_outputs = self.text_model(
                input_ids=input_ids,
                input_embeds=input_embeds,
                attention_mask=attention_mask,
                position_ids=position_ids,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )

        image_embeds = vision_outputs[1]
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = text_outputs[1]
        text_embeds = self.text_projection(text_embeds)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        if only_return_logits_per_text:
            return logits_per_text

        loss = None
        if return_loss:
            loss = clip_loss(logits_per_text)

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
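

# --- Minimal usage sketch (not part of the original module) ---
# Shown to illustrate the two entry points of CLIPModelCanReceiveTextEmbeds: the usual
# input_ids path and the new input_embeds path, where text embeddings are precomputed
# with model.text_model.embeddings so position embeddings are already included, matching
# the assumption in forward(). The checkpoint name and the random image tensor are
# assumptions for illustration only, and the sketch presumes a transformers version that
# still exposes _make_causal_mask/_expand_mask in modeling_clip (as imported above).
if __name__ == "__main__":
    from transformers import CLIPTokenizer

    model = CLIPModelCanReceiveTextEmbeds.from_pretrained("openai/clip-vit-base-patch32")
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    model.eval()

    texts = ["a photo of a cat", "a photo of a dog"]
    tokens = tokenizer(texts, padding=True, return_tensors="pt")
    pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for a processed image batch

    with torch.no_grad():
        # 1) standard path: token ids go through the embedding layer internally
        logits_from_ids = model(
            input_ids=tokens.input_ids,
            attention_mask=tokens.attention_mask,
            pixel_values=pixel_values,
            only_return_logits_per_text=True,
        )

        # 2) embedding path: precompute (and potentially perturb or optimize) the text embeddings
        text_token_embeds = model.text_model.embeddings(input_ids=tokens.input_ids)
        logits_from_embeds = model(
            input_embeds=text_token_embeds,
            attention_mask=tokens.attention_mask,
            pixel_values=pixel_values,
            only_return_logits_per_text=True,
        )

    # Note: without input_ids the pooled output falls back to the last token position,
    # so the two paths can differ when sequences are padded.
    print(logits_from_ids.shape, logits_from_embeds.shape)  # both [num_texts, num_images]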