```python
import torch


class CLIPTextModel_(torch.nn.Module):
    """#### The CLIPTextModel_ module."""

    def __init__(
        self,
        config_dict: dict,
        dtype: torch.dtype,
        device: torch.device,
        operations: object,
    ):
        """#### Initialize the CLIPTextModel_ module.

        #### Args:
            - `config_dict` (dict): The configuration dictionary.
            - `dtype` (torch.dtype): The data type.
            - `device` (torch.device): The device to use.
            - `operations` (object): The operations object.
        """
        num_layers = config_dict["num_hidden_layers"]
        embed_dim = config_dict["hidden_size"]
        heads = config_dict["num_attention_heads"]
        intermediate_size = config_dict["intermediate_size"]
        intermediate_activation = config_dict["hidden_act"]
        num_positions = config_dict["max_position_embeddings"]
        self.eos_token_id = config_dict["eos_token_id"]
        super().__init__()
        from modules.clip.Clip import CLIPEmbeddings, CLIPEncoder

        self.embeddings = CLIPEmbeddings(
            embed_dim,
            num_positions=num_positions,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.encoder = CLIPEncoder(
            num_layers,
            embed_dim,
            heads,
            intermediate_size,
            intermediate_activation,
            dtype,
            device,
            operations,
        )
        self.final_layer_norm = operations.LayerNorm(
            embed_dim, dtype=dtype, device=device
        )

    def forward(
        self,
        input_tokens: torch.Tensor,
        attention_mask: torch.Tensor = None,
        intermediate_output: int = None,
        final_layer_norm_intermediate: bool = True,
        dtype: torch.dtype = torch.float32,
    ) -> tuple:
        """#### Forward pass for the CLIPTextModel_ module.

        #### Args:
            - `input_tokens` (torch.Tensor): The input tokens.
            - `attention_mask` (torch.Tensor, optional): The attention mask. Defaults to None.
            - `intermediate_output` (int, optional): The intermediate output layer. Defaults to None.
            - `final_layer_norm_intermediate` (bool, optional): Whether to apply the final layer normalization to the intermediate output. Defaults to True.
            - `dtype` (torch.dtype, optional): The dtype used for the embeddings. Defaults to torch.float32.

        #### Returns:
            - `tuple`: The output tensor, the intermediate output tensor, and the pooled output tensor.
        """
        x = self.embeddings(input_tokens, dtype=dtype)
        mask = None
        if attention_mask is not None:
            # Build an additive padding mask: padded positions become -inf.
            mask = 1.0 - attention_mask.to(x.dtype).reshape(
                (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
            ).expand(
                attention_mask.shape[0],
                1,
                attention_mask.shape[-1],
                attention_mask.shape[-1],
            )
            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

        # Causal mask: each position may only attend to itself and earlier positions.
        causal_mask = (
            torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device)
            .fill_(float("-inf"))
            .triu_(1)
        )
        if mask is not None:
            mask += causal_mask
        else:
            mask = causal_mask

        x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
        x = self.final_layer_norm(x)
        if i is not None and final_layer_norm_intermediate:
            i = self.final_layer_norm(i)

        # Pool the hidden state at the first EOS token of each sequence.
        pooled_output = x[
            torch.arange(x.shape[0], device=x.device),
            (
                torch.round(input_tokens).to(dtype=torch.int, device=x.device)
                == self.eos_token_id
            )
            .int()
            .argmax(dim=-1),
        ]
        return x, i, pooled_output
```
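The two non-obvious steps in `forward` are the additive causal mask and the EOS-based pooling. Here is a minimal, self-contained sketch of both in plain `torch`; the token ids, tensor sizes, and random hidden states are illustrative assumptions, not values taken from the module above.

```python
import torch

# Hypothetical values for illustration only.
eos_token_id = 49407
tokens = torch.tensor([[49406, 320, 1125, 49407, 49407]])  # (batch=1, seq_len=5)
hidden = torch.randn(1, 5, 8)                              # stand-in hidden states

# Additive causal mask: -inf above the diagonal blocks attention to future tokens.
seq_len = tokens.shape[1]
causal_mask = torch.full((seq_len, seq_len), float("-inf")).triu_(1)

# Pooling: pick the hidden state at the first EOS position of each sequence,
# mirroring the indexing expression used in forward() above.
eos_pos = (tokens == eos_token_id).int().argmax(dim=-1)
pooled = hidden[torch.arange(hidden.shape[0]), eos_pos]    # shape (1, 8)
```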
```python
class CLIPTextModel(torch.nn.Module):
    """#### The CLIPTextModel module."""

    def __init__(
        self,
        config_dict: dict,
        dtype: torch.dtype,
        device: torch.device,
        operations: object,
    ):
        """#### Initialize the CLIPTextModel module.

        #### Args:
            - `config_dict` (dict): The configuration dictionary.
            - `dtype` (torch.dtype): The data type.
            - `device` (torch.device): The device to use.
            - `operations` (object): The operations object.
        """
        super().__init__()
        self.num_layers = config_dict["num_hidden_layers"]
        self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
        embed_dim = config_dict["hidden_size"]
        self.text_projection = operations.Linear(
            embed_dim, embed_dim, bias=False, dtype=dtype, device=device
        )
        self.dtype = dtype

    def get_input_embeddings(self) -> torch.nn.Embedding:
        """#### Get the input embeddings.

        #### Returns:
            - `torch.nn.Embedding`: The input embeddings.
        """
        return self.text_model.embeddings.token_embedding

    def set_input_embeddings(self, embeddings: torch.nn.Embedding) -> None:
        """#### Set the input embeddings.

        #### Args:
            - `embeddings` (torch.nn.Embedding): The input embeddings.
        """
        self.text_model.embeddings.token_embedding = embeddings

    def forward(self, *args, **kwargs) -> tuple:
        """#### Forward pass for the CLIPTextModel module.

        #### Args:
            - `*args`: Variable length argument list.
            - `**kwargs`: Arbitrary keyword arguments.

        #### Returns:
            - `tuple`: The last hidden state, the intermediate output, the projected pooled output, and the raw pooled output.
        """
        x = self.text_model(*args, **kwargs)
        # Project the pooled EOS-token embedding into the shared CLIP embedding space.
        out = self.text_projection(x[2])
        return (x[0], x[1], out, x[2])
```
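A hedged usage sketch, not part of the original module: the config keys mirror those read in `CLIPTextModel_.__init__`, the numeric values follow the common CLIP ViT-L/14 text encoder configuration, and `torch.nn` is assumed to be an acceptable stand-in for the project's `operations` object (the real classes in `modules.clip.Clip` may require the project's own casting-aware wrapper instead).

```python
import torch

# Illustrative config; keys match those consumed by CLIPTextModel_.__init__,
# values follow the usual CLIP ViT-L/14 text encoder (assumption).
clip_l_text_config = {
    "num_hidden_layers": 12,
    "hidden_size": 768,
    "num_attention_heads": 12,
    "intermediate_size": 3072,
    "hidden_act": "quick_gelu",
    "max_position_embeddings": 77,
    "eos_token_id": 49407,
}

# Assumption: torch.nn works as the `operations` factory here because it
# provides Linear, LayerNorm, and Embedding with dtype/device kwargs.
model = CLIPTextModel(
    clip_l_text_config,
    dtype=torch.float32,
    device=torch.device("cpu"),
    operations=torch.nn,
)

tokens = torch.randint(0, 49408, (2, 77))  # dummy token ids, batch of 2
last_hidden, intermediate, projected_pooled, pooled = model(tokens)
print(last_hidden.shape, projected_pooled.shape)  # (2, 77, 768) and (2, 768) with this config
```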