Spaces:
Runtime error
Runtime error
| import random | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from .helpers import GatedCrossAttentionBlock | |
| from .utils import getattr_recursive, setattr_recursive | |
class FlamingoLayer(nn.Module):
    """Wrap a decoder layer so visual features are written into its input first.

    Before each forward pass the layer must be conditioned with
    ``condition_media_locations`` and, optionally, ``condition_vis_x``. When
    visual features are present they overwrite slices of ``hidden_states`` in
    place, and the wrapped ``decoder_layer`` is then called unchanged.
    """

    def __init__(self, decoder_layer):
        super().__init__()
        self.decoder_layer = decoder_layer
        # Conditioning state; populated by the condition_* methods.
        # None means "not conditioned".
        self.vis_x = None
        self.image_nums = None
        self.image_start_index_list = None
        self.num_beams = None
        self.visual_tokens = None
        self.data_list = None
        self.media_locations = None
        self.attend_previous = None
        self.add_visual_token = False
        self.input_ids = None

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return self.vis_x is not None

    # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
    def condition_vis_x(self, vis_x, image_nums=None, image_start_index_list=None, num_beams=None, visual_tokens=None, data_list=None):
        """Store the visual inputs consumed by the next forward pass.

        Args:
            vis_x: visual features indexed by image along dim 0. ``shape[-2]``
                is the number of sequence positions each image overwrites.
                Assumed (num_images, tokens_per_image, hidden_dim) — confirm.
            image_nums: number of images belonging to each batch row.
            image_start_index_list: for each batch row, the sequence positions
                where that row's images are written, in order.
            num_beams: stored only; not read by this class.
            visual_tokens: extra embeddings, either 1-D (one position) or
                multi-row (several consecutive positions).
            data_list: ``(row, last_position)`` targets, one per
                ``visual_tokens`` entry.

        Also resets ``input_ids`` to None.
        """
        self.vis_x = vis_x
        self.image_nums = image_nums
        self.image_start_index_list = image_start_index_list
        self.num_beams = num_beams
        self.visual_tokens = visual_tokens
        self.data_list = data_list
        self.input_ids = None

    def condition_media_locations(self, media_locations):
        """Store the media-location mask; forward() refuses to run while it is None."""
        self.media_locations = media_locations

    def condition_attend_previous(self, attend_previous):
        """Store the attend-previous flag (not read by this class's forward)."""
        self.attend_previous = attend_previous

    @staticmethod
    def _cache_is_empty(decoder_layer_kwargs):
        """Return True when a KV-cache kwarg is passed and set to None.

        At eval time this marks the first (full-prompt) step, the only step on
        which visual features are injected.
        NOTE(review): if the wrapped model passes neither ``past_key_value``
        nor ``layer_past``, nothing is injected at eval time — confirm the
        kwarg name used by the underlying decoder.
        """
        return any(
            key in decoder_layer_kwargs and decoder_layer_kwargs[key] is None
            for key in ("past_key_value", "layer_past")
        )

    def _inject_visual_features(self, hidden_states):
        """Overwrite slices of ``hidden_states`` in place with visual features.

        ``hidden_states`` is indexed as [batch_row, position, ...] — presumably
        (batch, seq_len, hidden_dim).
        """
        single_length = self.vis_x.shape[-2]
        # Prefix sums turn per-row image counts into [begin, end) index
        # ranges into vis_x.
        boundaries = [0] + np.cumsum(self.image_nums).tolist()
        ranges = zip(boundaries[:-1], boundaries[1:], self.image_start_index_list)
        for row, (image_idx, image_end, start_indices) in enumerate(ranges):
            for index in start_indices:
                # Extra start positions beyond this row's image count are ignored.
                if image_idx < image_end:
                    hidden_states[row, index:index + single_length] = self.vis_x[image_idx]
                    image_idx += 1

        if self.visual_tokens is not None and len(self.visual_tokens) != 0:
            for i, (x, y) in enumerate(self.data_list):
                token = self.visual_tokens[i]
                if len(token.shape) > 1:
                    # Multi-row token: fill the positions ending at y (inclusive).
                    hidden_states[x, y + 1 - token.shape[0]:y + 1] = token
                else:
                    hidden_states[x, y] = token

    def forward(
        self,
        hidden_states,  # alignment with hugging face name
        attention_mask=None,
        **decoder_layer_kwargs,
    ):
        """Inject visual features (if conditioned) and run the wrapped layer.

        Raises:
            ValueError: if ``condition_media_locations`` was not called with a
                non-None value.
        """
        if self.media_locations is None:
            raise ValueError("media_locations must be conditioned before forward pass")
        if self.vis_x is not None:
            # Training: always inject. Eval: inject only when the KV cache is
            # empty, because later steps reuse cached states.
            if self.training or self._cache_is_empty(decoder_layer_kwargs):
                self._inject_visual_features(hidden_states)
        hidden_states = self.decoder_layer(
            hidden_states, attention_mask=attention_mask, **decoder_layer_kwargs
        )
        return hidden_states
class FlamingoLMMixin(nn.Module):
    """
    Mixin to add cross-attention layers to a language model.

    ``init_flamingo`` wraps every decoder layer in a ``FlamingoLayer``.
    ``forward`` then conditions those wrapped layers on the media-token
    positions of the current ``input_ids`` before delegating to the next
    class in the MRO.
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        """Record the attribute path of the decoder layer list.

        The path is resolved with ``getattr_recursive``/``setattr_recursive``.
        """
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        """Return the decoder layer list at ``decoder_layers_attr_name``."""
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        """Replace the decoder layer list at ``decoder_layers_attr_name``."""
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def init_flamingo(
        self,
        media_token_id,
        use_media_placement_augmentation,
    ):
        """
        Initialize Flamingo by wrapping each decoder layer in a FlamingoLayer.

        Args:
            media_token_id: token id whose positions in ``input_ids`` form the
                media-location mask passed to each layer.
            use_media_placement_augmentation: if True, ``forward`` sets
                ``attend_previous`` randomly (p=0.5) on every call; otherwise
                it is always True.

        Requires ``set_decoder_layers_attr_name`` to have been called first.
        """
        self._set_decoder_layers(
            nn.ModuleList(
                [FlamingoLayer(decoder_layer) for decoder_layer in self._get_decoder_layers()]
            )
        )
        self.media_token_id = media_token_id
        self.use_media_placement_augmentation = use_media_placement_augmentation
        self.initialized_flamingo = True

    def forward(self, *args, **kwargs):
        """Condition the Flamingo layers on the media locations, then call the base forward().

        ``input_ids`` is taken from ``kwargs`` if present, else from the first
        positional argument.

        Raises:
            ValueError: if ``init_flamingo`` has not been called.
        """
        # getattr with a default: before init_flamingo the attribute does not
        # exist, and nn.Module would otherwise raise AttributeError here.
        if not getattr(self, "initialized_flamingo", False):
            raise ValueError(
                "Flamingo layers are not initialized. Please call `init_flamingo` first."
            )
        input_ids = kwargs["input_ids"] if "input_ids" in kwargs else args[0]
        media_locations = input_ids == self.media_token_id
        attend_previous = (
            (random.random() < 0.5) if self.use_media_placement_augmentation else True
        )
        # Condition exactly the layers init_flamingo wrapped. This also works
        # for architectures that expose their decoder layers under other names.
        for layer in self._get_decoder_layers():
            layer.condition_media_locations(media_locations)
            layer.condition_attend_previous(attend_previous)
        return super().forward(*args, **kwargs)  # Call the other parent's forward method

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(layer.is_conditioned() for layer in self._get_decoder_layers())

    def clear_conditioned_layers(self):
        """Reset vis_x, media locations and attend_previous to None on every decoder layer."""
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
            layer.condition_media_locations(None)
            layer.condition_attend_previous(None)