# NOTE(review): the following line is HuggingFace Spaces page residue
# ("Spaces: Running on Zero") captured by the scraper — it is not code.
| import torch | |
| import torch.nn as nn | |
| from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel | |
| from transformers.modeling_outputs import BaseModelOutput | |
# Prefer Ascend NPU acceleration when torch_npu is installed; otherwise
# fall back to CUDA. transfer_to_npu monkey-patches torch's CUDA entry
# points so existing torch.cuda code runs on the NPU transparently.
try:
    import torch_npu  # noqa: F401
    from torch_npu.contrib import transfer_to_npu  # noqa: F401
except ModuleNotFoundError:
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "npu"
class TransformersTextEncoderBase(nn.Module):
    """Encode raw text into token-level hidden states.

    Wraps a HuggingFace ``AutoTokenizer``/``AutoModel`` pair; ``forward``
    returns the model's final hidden states together with a boolean mask
    that is True at real (non-padding) token positions.
    """

    def __init__(self, model_name: str):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(
        self,
        text: list[str],
    ):
        """Tokenize *text* and run it through the encoder.

        Returns a dict with:
            output: last hidden states from the model
                    (presumably (batch, seq_len, hidden) — shape comes
                    from the loaded checkpoint).
            mask:   boolean tensor, True where attention_mask == 1.
        """
        target = self.model.device
        encoded = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        ids = encoded.input_ids.to(target)
        attn = encoded.attention_mask.to(target)
        result: BaseModelOutput = self.model(
            input_ids=ids, attention_mask=attn
        )
        hidden = result.last_hidden_state
        token_mask = (attn == 1).to(target)
        return {"output": hidden, "mask": token_mask}
class T5TextEncoder(TransformersTextEncoderBase):
    """Frozen T5 encoder used as a fixed text feature extractor.

    Loads only the encoder half of a (Flan-)T5 checkpoint, freezes all
    parameters, and always runs inference with gradients disabled and
    autocast off (full precision).
    """

    # FIX: the default used to be a machine-specific private cache path
    # (/mnt/petrelfs/zhengzihao/...), which cannot resolve on any other
    # host. Use the canonical hub ID of the same checkpoint instead;
    # callers that pass an explicit path are unaffected.
    def __init__(self, model_name: str = "google/flan-t5-large"):
        # Deliberately call nn.Module.__init__ directly instead of
        # super().__init__(): the base class would load an
        # AutoTokenizer/AutoModel pair, but T5 needs the encoder-only
        # model and the sentencepiece T5 tokenizer.
        nn.Module.__init__(self)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5EncoderModel.from_pretrained(model_name)
        # Freeze: this module is a feature extractor, never trained.
        for param in self.model.parameters():
            param.requires_grad = False
        self.eval()

    def forward(
        self,
        text: list[str],
    ):
        """Encode *text*; same contract as the base class ``forward``.

        Runs under no_grad, with autocast explicitly disabled so the
        frozen encoder always computes in full precision.
        """
        with torch.no_grad(), torch.amp.autocast(
            device_type=DEVICE_TYPE, enabled=False
        ):
            return super().forward(text)
if __name__ == '__main__':
    # Smoke test: encode one caption and report the embedding shape.
    encoder = T5TextEncoder()
    encoder.eval()
    captions = ["dog barking and cat moving"]
    with torch.no_grad():
        result = encoder(captions)
    print(result["output"].shape)