# rookie9's picture
# Upload 77 files
# f582ec6 verified
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel, T5Tokenizer, T5EncoderModel
from transformers.modeling_outputs import BaseModelOutput
# Pick the accelerator backend name used for autocast further down.
# When the optional ``torch_npu`` package can be imported the backend is
# "npu"; when the module is missing we fall back to "cuda".
try:
    import torch_npu  # noqa: F401
    # Not referenced by name in this file; presumably imported for its
    # side effects -- TODO confirm against the torch_npu documentation.
    from torch_npu.contrib import transfer_to_npu  # noqa: F401
except ModuleNotFoundError:
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "npu"
class TransformersTextEncoderBase(nn.Module):
    """
    Wrap a HuggingFace tokenizer/model pair as a text-encoding module.

    Calling the module tokenizes a batch of strings, feeds the token ids and
    attention mask to the model, and returns the model's
    ``last_hidden_state`` together with a boolean mask derived from the
    attention mask.
    """

    def __init__(self, model_name: str):
        """
        Load the tokenizer and model identified by ``model_name``.

        Args:
            model_name: Identifier handed to ``AutoTokenizer.from_pretrained``
                and ``AutoModel.from_pretrained``.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(
        self,
        text: list[str],
    ):
        """
        Encode a batch of strings.

        Args:
            text: The strings to encode.

        Returns:
            A dict with two entries:
              * ``"output"`` -- the model's ``last_hidden_state``.
              * ``"mask"`` -- boolean tensor that is True wherever the
                tokenizer's attention mask equals 1.
        """
        # Inputs are moved to whatever device the wrapped model reports.
        target_device = self.model.device

        # Tokenize with padding and truncation enabled, capping the length at
        # the tokenizer's own ``model_max_length``.
        encoded = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        token_ids = encoded.input_ids.to(target_device)
        attention_mask = encoded.attention_mask.to(target_device)

        model_output: BaseModelOutput = self.model(
            input_ids=token_ids,
            attention_mask=attention_mask,
        )
        hidden_states = model_output.last_hidden_state

        # Turn the integer attention mask into a boolean one.
        valid_positions = attention_mask.eq(1).to(target_device)
        return {"output": hidden_states, "mask": valid_positions}
class T5TextEncoder(TransformersTextEncoderBase):
    """
    Frozen text encoder built on a T5 encoder model.

    Every model parameter has ``requires_grad`` switched off and the module
    is pinned to evaluation mode: ``train()`` is overridden so that a
    recursive ``.train()`` call issued by an enclosing module cannot switch
    the frozen encoder back into training mode.
    """

    def __init__(self, model_name: str = "/mnt/petrelfs/zhengzihao/cache/google-flan-t5-large"):
        """
        Load the T5 tokenizer and encoder and freeze the encoder.

        Args:
            model_name: Identifier handed to ``T5Tokenizer.from_pretrained``
                and ``T5EncoderModel.from_pretrained``. The default is an
                absolute local path; pass another value to load from
                elsewhere.
        """
        # Deliberately call nn.Module.__init__ instead of the parent
        # constructor: the parent would load AutoTokenizer/AutoModel, and the
        # T5-specific classes are loaded here in their place.
        nn.Module.__init__(self)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5EncoderModel.from_pretrained(model_name)
        for param in self.model.parameters():
            param.requires_grad = False
        self.eval()

    def train(self, mode: bool = True):
        """
        Keep the frozen encoder in evaluation mode.

        ``nn.Module.train`` recurses into child modules, so without this
        override a parent's ``.train()`` would undo the ``self.eval()`` done
        in ``__init__``. ``mode`` is accepted for interface compatibility but
        ignored.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        return super().train(False)

    def forward(
        self,
        text: list[str],
    ):
        """
        Encode ``text`` with gradients and autocast disabled.

        Args:
            text: The strings to encode.

        Returns:
            Whatever ``TransformersTextEncoderBase.forward`` returns for
            ``text``.
        """
        # Autocast is explicitly disabled for the encoder pass, and no
        # autograd graph is recorded.
        with torch.no_grad(), torch.amp.autocast(
            device_type=DEVICE_TYPE, enabled=False
        ):
            return super().forward(text)
if __name__ == '__main__':
    # Manual smoke test: encode one prompt and print the shape of the
    # returned hidden states.
    encoder = T5TextEncoder()
    encoder.eval()
    prompts = ["dog barking and cat moving"]
    with torch.no_grad():
        encoded = encoder(prompts)
    print(encoded["output"].shape)