import torch
import torch.nn as nn
from src.models.conditioner.base import BaseConditioner

from transformers import Qwen3Model, Qwen2Tokenizer


class Qwen3TextEncoder(BaseConditioner):
    """Text conditioner that encodes prompts with a pretrained Qwen3 language model.

    Hidden states are zero-padded or truncated along the channel axis to `embed_dim`.
    """

    def __init__(self, weight_path: str, embed_dim: int = None, max_length: int = 128):
        super().__init__()
        self.tokenizer = Qwen2Tokenizer.from_pretrained(weight_path, max_length=max_length, padding_side="right")
        # self.model = Qwen3Model.from_pretrained(weight_path, attn_implementation="flex_attention").to(torch.bfloat16)
        self.model = Qwen3Model.from_pretrained(weight_path).to(torch.bfloat16)
        # self.model.compile()
        self.uncondition_embedding = None  # lazily cached embedding of the empty prompt
        self.embed_dim = embed_dim  # target channel dimension of the returned embeddings
        self.max_length = max_length
        # torch._dynamo.config.optimize_ddp = False

    def _impl_condition(self, y, metadata: dict = None):
        """Encode a batch of prompts into fixed-length token embeddings."""
        if metadata is None:  # avoid the shared mutable-default-argument pitfall
            metadata = {}
        # Right-pad every prompt to max_length so the batch has a uniform shape.
        tokenized = self.tokenizer(y, truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
        input_ids = tokenized.input_ids.cuda()
        attention_mask = tokenized.attention_mask.cuda()
        # Record the number of non-padding tokens per prompt for downstream masking.
        metadata["valid_length_y"] = torch.sum(attention_mask, dim=-1)
        y = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        # Zero-pad or truncate the hidden dimension so the output matches embed_dim.
        if y.shape[2] < self.embed_dim:
            y = torch.cat([y, torch.zeros(y.shape[0], y.shape[1], self.embed_dim - y.shape[2]).to(y.device, y.dtype)], dim=-1)
        if y.shape[2] > self.embed_dim:
            y = y[:, :, :self.embed_dim]
        return y

    def _impl_uncondition(self, y, metadata: dict = None):
        """Return the null (empty-prompt) embedding, broadcast to the batch size of `y`."""
        # Compute the empty-prompt embedding once and cache it for later calls.
        if self.uncondition_embedding is not None:
            return self.uncondition_embedding.repeat(len(y), 1, 1)
        self.uncondition_embedding = self._impl_condition([""])
        return self.uncondition_embedding.repeat(len(y), 1, 1)
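

# Minimal usage sketch (not part of the original file). The checkpoint path
# "Qwen/Qwen3-0.6B", the embed_dim value, and calling the _impl_* methods
# directly are illustrative assumptions; in the training code these are
# presumably driven through BaseConditioner's public interface and a config.
if __name__ == "__main__":
    encoder = Qwen3TextEncoder("Qwen/Qwen3-0.6B", embed_dim=1024, max_length=128)
    encoder.model.cuda()  # inputs are moved to CUDA inside _impl_condition
    meta = {}
    cond = encoder._impl_condition(["a photo of a cat"], metadata=meta)
    uncond = encoder._impl_uncondition(["a photo of a cat"])
    # cond / uncond: (batch, max_length, embed_dim); meta["valid_length_y"]: non-padding token counts
    print(cond.shape, uncond.shape, meta["valid_length_y"])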