Create pipeline.py
pipeline.py  ADDED  (+284, −0)
import safetensors.torch as st
import torch
from diffusers import StableDiffusionXLPipeline
from transformers import T5TokenizerFast, T5EncoderModel

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# ─────────────────────────────────────────────────────────────
# ░ Two-Stream Shunt Adapter
# ─────────────────────────────────────────────────────────────
class TwoStreamShuntAdapter(nn.Module):
    """
    Cross-attentive adapter that aligns T5 and CLIP token streams.

    Returns:
        anchor    : (B, Lc, clip_dim)
        delta     : (B, Lc, clip_dim)
        log_sigma : (B, Lc, clip_dim) – log σ, always finite
        attn_t2c  : (B, heads, Lt, Lc)
        attn_c2t  : (B, heads, Lc, Lt)
        tau       : (heads, 1, 1) – per-head threshold param
        g_pred    : (B, 1) – guidance-scale prediction
        gate      : (B, Lc, 1) – per-token gate ∈ (0, 1)
    """

    def __init__(
        self,
        t5_dim: int = 512,
        clip_dim: int = 768,
        bottleneck: int = 256,
        heads: int = 8,
        tau_init: float = 0.1,
        max_guidance: float = 10.0,
    ):
        super().__init__()
        print("TwoStreamShuntAdapter init")
        self.heads = heads
        self.bneck = bottleneck
        self.max_guidance = max_guidance

        # projections
        self.proj_t5 = nn.Linear(t5_dim, bottleneck)
        self.proj_clip = nn.Linear(clip_dim, bottleneck)

        # cross-attention
        self.cross_t2c = nn.MultiheadAttention(
            bottleneck, heads, batch_first=True, dropout=0.1
        )
        self.cross_c2t = nn.MultiheadAttention(
            bottleneck, heads, batch_first=True, dropout=0.1
        )

        # head-wise τ
        self.tau = nn.Parameter(torch.full((heads, 1, 1), tau_init))

        # convolutional pocket residual (depth-wise)
        self.res1 = nn.Conv1d(
            bottleneck, bottleneck, 3, padding=1, groups=bottleneck
        )
        self.res2 = nn.Conv1d(
            bottleneck, bottleneck, 3, padding=1, groups=bottleneck
        )
        self.norm_res = nn.LayerNorm(bottleneck)

        # fusion + projections
        self.fuse = nn.Linear(2 * bottleneck, bottleneck)

        self.anchor_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.delta_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.logsig_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, clip_dim)
        )
        self.gate_proj = nn.Sequential(
            nn.Linear(bottleneck, bottleneck), nn.GELU(),
            nn.Linear(bottleneck, 1), nn.Sigmoid()
        )
        self.guidance_proj = nn.Sequential(
            nn.LayerNorm(bottleneck), nn.Linear(bottleneck, 1), nn.Sigmoid()
        )

    def load_state_dict(self, state_dict, **kwargs):
        # Strip the "_orig_mod." prefix that torch.compile adds to parameter
        # names, so checkpoints saved from a compiled model load cleanly.
        state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
        return super().load_state_dict(state_dict, **kwargs)

    def forward(self, t5_seq: torch.Tensor, clip_seq: torch.Tensor):
        print("📣 SHUNT FORWARD CALLED")

        B, Lt, _ = t5_seq.size()
        _, Lc, _ = clip_seq.size()

        # 1) project into bottleneck
        t5_b = self.proj_t5(t5_seq)        # (B, Lt, b)
        clip_b = self.proj_clip(clip_seq)  # (B, Lc, b)

        # 2) cross-attention
        t2c, attn_t2c = self.cross_t2c(
            t5_b, clip_b, clip_b, need_weights=True, average_attn_weights=False
        )
        c2t, attn_c2t = self.cross_c2t(
            clip_b, t5_b, t5_b, need_weights=True, average_attn_weights=False
        )

        # 3) convolutional pocket on T5→CLIP
        x = t2c.transpose(1, 2)                   # (B, b, Lt)
        x = F.gelu(self.res1(x))
        x = F.gelu(self.res2(x)).transpose(1, 2)  # (B, Lt, b)
        pocket = self.norm_res(t2c + x)           # (B, Lt, b)

        # 4) fuse pocket average with C2T
        pocket_mean = pocket.mean(1, keepdim=True).expand(-1, Lc, -1)
        h = F.gelu(self.fuse(torch.cat([pocket_mean, c2t], -1)))  # (B, Lc, b)

        # 5) outputs
        anchor = self.anchor_proj(h)       # (B, Lc, 768)
        delta_mean = self.delta_proj(h)    # (B, Lc, 768)
        log_sigma = self.logsig_proj(h)    # (B, Lc, 768)
        gate = self.gate_proj(h)           # (B, Lc, 1)
        delta = delta_mean * gate          # (B, Lc, 768)

        g_tok = self.guidance_proj(h).squeeze(-1)  # (B, Lc)
        g_pred = g_tok.mean(1, keepdim=True) * self.max_guidance

        return anchor, delta, log_sigma, attn_t2c, attn_c2t, self.tau, g_pred, gate

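# Optional, illustrative sanity check of the adapter's I/O contract. This is a
# hedged sketch, not part of the pipeline flow: the 16-token T5 length and the
# 77-token CLIP length below are assumed values chosen only for demonstration.
# It is defined but never called, so the script's behaviour is unchanged.
def _shape_check_demo():
    demo = TwoStreamShuntAdapter().eval()            # random weights, CPU
    with torch.no_grad():
        (anchor, delta, log_sigma,
         attn_t2c, attn_c2t, tau, g_pred, gate) = demo(
            torch.randn(1, 16, 512),   # (B, Lt, t5_dim) – illustrative length
            torch.randn(1, 77, 768),   # (B, Lc, clip_dim)
        )
    assert delta.shape == (1, 77, 768)
    assert gate.shape == (1, 77, 1)
    assert g_pred.shape == (1, 1)
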
# --- 1. load pipeline -------------------------------------------------
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

# --- 2. load tiny-T5 & shunt (fp32) -----------------------------------
t5_tok = T5TokenizerFast.from_pretrained("t5-small")
t5_mod = T5EncoderModel.from_pretrained("t5-small").eval().to("cuda")
shunt = TwoStreamShuntAdapter().float().eval().to("cuda")
shunt.load_state_dict(
    st.load_file("/content/drive/MyDrive/t5-clip-l-shunts/vitl14_t5small_shunt_vanilla_final.safetensors")
)

# --- 3. wrap encode_prompt once ---------------------------------------
orig_encode = pipe.encode_prompt

config = {
    "strength": 1.0,
    "gate_gamma": 1.0,
    "tau_scale": 1.0,
    "guidance_gain": 1.0,
    "guidance_bias": 0.0,
}
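# Note: `config` is declared here but never read later in this script; only the
# module-level `strength` below actually scales Δ. A hedged sketch of how the
# knobs *could* map onto the adapter outputs (illustrative assumption, not the
# behaviour of this commit):
#   delta_eff = delta * config["strength"] * gate.pow(config["gate_gamma"])
#   g_final   = g_pred * config["guidance_gain"] + config["guidance_bias"]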

gen = torch.Generator(device="cuda").manual_seed(420)

strength = 0
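# `strength` is the module-level knob that both encode_prompt wrappers read
# when scaling Δ; it is overwritten inside the sweep loop further below.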

# Reference wrapper: the original "working version", kept for comparison.
# Note that it runs T5 on dummy text, so the shift Δ it produces does not
# depend on the actual prompt – only on the CLIP-L stream.
def stable_encode_prompt_shunted(self, *args, **kw):
    pe, ne, pool, npool = orig_encode(*args, **kw)        # regular call

    # 👉 split: first 768 dims are CLIP-L, rest 1280 are CLIP-G
    clipL, clipG = pe[..., :768], pe[..., 768:]

    # build T5 batch (handles CFG duplication automatically because
    # encode_prompt already concatenated negative & positive if needed)
    bsz = clipL.shape[0]
    texts = ["tmp"] * bsz          # dummy, we only care about hidden states
    t5_ids = t5_tok(texts, return_tensors="pt").input_ids.to("cuda")
    t5_seq = t5_mod(t5_ids).last_hidden_state             # (B, L, 512)

    # run adapter in fp32
    delta = shunt(t5_seq.float(), clipL.float())[1]       # second output is Δ
    delta = delta * strength                              # << strength knob
    clipL_shift = (clipL.float() + delta).to(clipL.dtype)

    pe_shifted = torch.cat([clipL_shift, clipG], dim=-1)
    return pe_shifted, ne, pool, npool
# -----------------------------------------------------------------------
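# To A/B the two wrappers, rebind whichever variant you want before calling
# the pipeline (the script below binds encode_prompt_shunted), e.g.:
#   pipe.encode_prompt = stable_encode_prompt_shunted.__get__(pipe, type(pipe))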

def encode_prompt_shunted(self, *a, **k):
    # 1) run the normal encoder with “style” & “context” already split
    pe, ne, pool, npool = orig_encode(*a, **k)            # (B, 77, 2048)

    # 2) split CLIP-L / CLIP-G
    clipL, clipG = pe[..., :768], pe[..., 768:]

    # 3) build T5 on the *context* text (passed as prompt_2); fall back to
    #    the main prompt if prompt_2 was not supplied
    context = k.get("prompt_2") or k.get("prompt") or (a[0] if a else "")
    t5_ids = t5_tok([context], return_tensors="pt").input_ids.to(pe.device)
    t5_seq = t5_mod(t5_ids).last_hidden_state.float()

    # 4) shunt → Δ (fp32, then cast back)
    delta = shunt(t5_seq, clipL.float())[1].to(clipL.dtype)
    clipL_shift = clipL + delta * strength

    # 5) concatenate back
    pe_shift = torch.cat([clipL_shift, clipG], dim=-1)
    return pe_shift, ne, pool, npool

pipe.encode_prompt = encode_prompt_shunted.__get__(pipe, type(pipe))
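# `__get__` binds the plain function as a method of this pipeline instance,
# so `self` inside encode_prompt_shunted refers to the SDXL pipeline object.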

PROMPT = "a naturally lit and beautiful room with a photorealistic depiction of a woman"
PROMPT_2 = "a realistic depiction of a woman sitting on a chair at a coffee shop sipping coffee, the environment is beautiful"
NEG = "blurry, distorted, monochrome, greyscale, watermark"
STEPS = 50
base_strength = 0.5
base_cfg = 7.5

for i in range(0, 4):
    strength = base_strength + (i * 0.25)
    cfg = base_cfg - (i * 0.25)
    img = pipe(
        PROMPT,
        prompt_2=PROMPT_2,
        negative_prompt=NEG,
        num_inference_steps=STEPS,
        guidance_scale=cfg,   # SDXL expects `guidance_scale`, not `cfg_scale`
        generator=torch.Generator(device="cuda").manual_seed(420),
    ).images[0]
    img.save(f"woman_cfg_{int(cfg*100)}_{int(strength*100)}.png")
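# The sweep above produces four images at (strength, guidance_scale) =
# (0.50, 7.50), (0.75, 7.25), (1.00, 7.00), (1.25, 6.75), saved as
# woman_cfg_750_50.png through woman_cfg_675_125.png.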

# --- 4. generate -------------------------------------------------------
#img = pipe(
#    PROMPT,
#    negative_prompt=NEG,
#    num_inference_steps=STEPS,
#    generator=torch.Generator(device="cuda").manual_seed(420)
#    ).images[0]
#img.save("majestic_baseline.png")

#strength = 0.25
## --- 4. generate -------------------------------------------------------
#img = pipe(
#    PROMPT,
#    negative_prompt=NEG,
#    num_inference_steps=STEPS,
#    generator=torch.Generator(device="cuda").manual_seed(420)
#    ).images[0]
#img.save("majestic_02.png")

#strength = 0.5
## --- 4. generate -------------------------------------------------------
#img = pipe(
#    PROMPT,
#    negative_prompt=NEG,
#    num_inference_steps=STEPS,
#    generator=torch.Generator(device="cuda").manual_seed(420)
#    ).images[0]
#img.save("majestic_05.png")

#strength = 0.75
## --- 4. generate -------------------------------------------------------
#img = pipe(
#    PROMPT,
#    negative_prompt=NEG,
#    num_inference_steps=STEPS,
#    generator=torch.Generator(device="cuda").manual_seed(420)
#    ).images[0]
#img.save("majestic_075.png")