primerz committed
Commit 1f2b17f · verified · 1 Parent(s): 75805c0

Delete cappella.py

Files changed (1)
  1. cappella.py +0 -186
cappella.py DELETED
@@ -1,186 +0,0 @@
import torch
from dataclasses import dataclass
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from typing import Tuple


@dataclass
class CappellaResult:
    """
    Holds the 4 tensors required by the SDXL pipeline,
    all guaranteed to have the correct, matching sequence length.
    """
    embeds: torch.Tensor
    pooled_embeds: torch.Tensor
    negative_embeds: torch.Tensor
    negative_pooled_embeds: torch.Tensor


class Cappella:
    """
    A minimal, custom-built prompt encoder for our SDXL pipeline.
    It replaces the 'compel' dependency and is tailored for our exact use case.

    It correctly:
    1. Uses both SDXL tokenizers and text encoders.
    2. Handles prompts longer than 77 tokens by splitting them into chunks
       (fixes the "78 vs 77" mismatch error).
    3. Pads every chunk to max_length so each chunk is exactly 77 tokens.
    4. Returns all 4 required embedding tensors.
    """
    def __init__(self, pipe, device):
        self.tokenizer: CLIPTokenizer = pipe.tokenizer
        self.tokenizer_2: CLIPTokenizer = pipe.tokenizer_2
        self.text_encoder: CLIPTextModel = pipe.text_encoder
        self.text_encoder_2: CLIPTextModelWithProjection = pipe.text_encoder_2
        self.device = device

    @torch.no_grad()
    def __call__(self, prompt: str, negative_prompt: str) -> CappellaResult:
        """
        Encodes the positive and negative prompts.
        Ensures both embedding tensors have the same sequence length.
        """
        # Encode the positive prompt
        pos_embeds, pos_pooled = self._encode_one(prompt)

        # Encode the negative prompt
        neg_embeds, neg_pooled = self._encode_one(negative_prompt)

        # --- START FIX: Pad shorter embeds ---
        # Ensure embeds and negative_embeds have the same sequence length
        seq_len_pos = pos_embeds.shape[1]
        seq_len_neg = neg_embeds.shape[1]

        if seq_len_pos > seq_len_neg:
            # Pad negative embeds
            pad_len = seq_len_pos - seq_len_neg
            padding = torch.zeros(
                (neg_embeds.shape[0], pad_len, neg_embeds.shape[2]),
                device=self.device, dtype=neg_embeds.dtype
            )
            neg_embeds = torch.cat([neg_embeds, padding], dim=1)

        elif seq_len_neg > seq_len_pos:
            # Pad positive embeds
            pad_len = seq_len_neg - seq_len_pos
            padding = torch.zeros(
                (pos_embeds.shape[0], pad_len, pos_embeds.shape[2]),
                device=self.device, dtype=pos_embeds.dtype
            )
            pos_embeds = torch.cat([pos_embeds, padding], dim=1)

        # The positive and negative embeds are now guaranteed to have equal
        # sequence lengths.
        # --- END FIX ---

        return CappellaResult(
            embeds=pos_embeds,
            pooled_embeds=pos_pooled,
            negative_embeds=neg_embeds,
            negative_pooled_embeds=neg_pooled
        )

    def _encode_one(self, prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Runs a single prompt string through both text encoders.
        Handles prompts longer than 77 tokens by chunking.
        """
        # --- Get Tokenizers and Encoders ---
        tokenizers = [self.tokenizer, self.tokenizer_2]
        text_encoders = [self.text_encoder, self.text_encoder_2]

        prompt_embeds_list = []
        pooled_prompt_embeds = None

        for tokenizer, text_encoder in zip(tokenizers, text_encoders):
            # --- Tokenize ---
            # Tokenize without padding or truncation first
            text_inputs = tokenizer(
                prompt,
                padding=False,
                truncation=False,
                return_tensors="pt"
            )
            input_ids = text_inputs.input_ids.to(self.device)

            # --- Chunking ---
            # Manually chunk the input_ids
            max_length = tokenizer.model_max_length
            bos = tokenizer.bos_token_id
            eos = tokenizer.eos_token_id

            # We subtract 2 for BOS and EOS
            chunk_length = max_length - 2

            # Get all token IDs *except* BOS and EOS
            clean_input_ids = input_ids[0, 1:-1]

            # Split into chunks
            chunks = [
                clean_input_ids[i:i + chunk_length]
                for i in range(0, len(clean_input_ids), chunk_length)
            ]

            # --- Prepare Batches ---
            batch_input_ids = []
            for chunk in chunks:
                # Add BOS and EOS
                chunk_with_bos_eos = torch.cat([
                    torch.tensor([bos], dtype=torch.long, device=self.device),
                    chunk.to(torch.long),
                    torch.tensor([eos], dtype=torch.long, device=self.device)
                ])

                # Pad to max_length
                pad_len = max_length - len(chunk_with_bos_eos)
                if pad_len > 0:
                    padding = torch.full(
                        (pad_len,), tokenizer.pad_token_id,
                        dtype=torch.long, device=self.device
                    )
                    chunk_with_bos_eos = torch.cat([chunk_with_bos_eos, padding])

                batch_input_ids.append(chunk_with_bos_eos)

            if not batch_input_ids:
                # Handle empty prompt
                batch_input_ids.append(
                    torch.full(
                        (max_length,), tokenizer.pad_token_id,
                        dtype=torch.long, device=self.device
                    )
                )

            batch_input_ids = torch.stack(batch_input_ids)

            # --- Encode ---
            if text_encoder == self.text_encoder:
                # Text Encoder 1 (CLIP-L)
                # We only need the last_hidden_state
                encoder_output = text_encoder(
                    batch_input_ids,
                    output_hidden_states=False
                )
                # [num_chunks, 77, 768]
                prompt_embeds = encoder_output.last_hidden_state
                prompt_embeds_list.append(prompt_embeds)

            elif text_encoder == self.text_encoder_2:
                # Text Encoder 2 (OpenCLIP-G)
                # We need hidden_states[-2] and the pooled output from the FIRST chunk
                encoder_output = text_encoder(
                    batch_input_ids,
                    output_hidden_states=True
                )
                # [num_chunks, 77, 1280]
                prompt_embeds = encoder_output.hidden_states[-2]
                prompt_embeds_list.append(prompt_embeds)

                # Pooled output comes from the FIRST chunk
                # We use .text_embeds, which is the pooled output
                # [num_chunks, 1280]
                all_pooled = encoder_output.text_embeds
                pooled_prompt_embeds = all_pooled[0:1]  # Keep as [1, 1280]

        # --- Concatenate Chunks ---
        # Reshape from [num_chunks, 77, dim] to [1, num_chunks*77, dim],
        # then concatenate the two encoders' outputs along dim=-1
        embeds_1 = prompt_embeds_list[0].reshape(1, -1, prompt_embeds_list[0].shape[-1])
        embeds_2 = prompt_embeds_list[1].reshape(1, -1, prompt_embeds_list[1].shape[-1])

        prompt_embeds = torch.cat([embeds_1, embeds_2], dim=-1)

        # pooled_prompt_embeds is already [1, 1280] from Encoder 2's first chunk
        return prompt_embeds, pooled_prompt_embeds
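
For reference, below is a minimal usage sketch of the encoder this commit removes, assuming the standard diffusers StableDiffusionXLPipeline; the checkpoint ID, prompts, and variable names are illustrative placeholders, not code from this repository. The four fields of CappellaResult map onto the pipeline's prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, and negative_pooled_prompt_embeds arguments, and the padding step in __call__ is what guarantees the positive and negative embeddings arrive with matching sequence lengths.

# Illustrative sketch only: shows how Cappella was meant to replace
# compel-style prompt encoding for an SDXL pipeline.
import torch
from diffusers import StableDiffusionXLPipeline

from cappella import Cappella  # the module deleted by this commit

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder checkpoint; any SDXL pipeline exposing tokenizer/tokenizer_2
# and text_encoder/text_encoder_2 works the same way.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to(device)

encoder = Cappella(pipe, device)
result = encoder(
    prompt="a watercolor lighthouse at dawn, soft light, highly detailed",
    negative_prompt="blurry, low quality, watermark",
)

# Pass the four tensors straight through as precomputed embeddings.
image = pipe(
    prompt_embeds=result.embeds,
    negative_prompt_embeds=result.negative_embeds,
    pooled_prompt_embeds=result.pooled_embeds,
    negative_pooled_prompt_embeds=result.negative_pooled_embeds,
    num_inference_steps=30,
).images[0]
image.save("lighthouse.png")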