Spaces:
Runtime error
Runtime error
Update pipelines/pipeline_seesr.py
Browse files- pipelines/pipeline_seesr.py +38 -57
pipelines/pipeline_seesr.py
CHANGED
|
@@ -99,7 +99,9 @@ EXAMPLE_DOC_STRING = """
|
|
| 99 |
def kde_grad(x0: torch.Tensor, patch_size = 16, bandwidth = 0.1):
|
| 100 |
# x0: (N, C, H, W) in float32
|
| 101 |
N, C, H, W = x0.shape
|
| 102 |
-
patches = unfold(
|
|
|
|
|
|
|
| 103 |
P, M = patches.shape[1], patches.shape[2]
|
| 104 |
p_i = patches.unsqueeze(1) # (N,1,P,M)
|
| 105 |
p_j = patches.unsqueeze(0) # (1,N,P,M)
|
|
@@ -111,15 +113,13 @@ def kde_grad(x0: torch.Tensor, patch_size = 16, bandwidth = 0.1):
|
|
| 111 |
num = (w.unsqueeze(2) * diff).sum(dim=1) # (N,P,M)
|
| 112 |
denom = w.sum(dim=1, keepdim=True) + 1e-8 # (N,1,M)
|
| 113 |
mshift = num / denom # (N,P,M)
|
| 114 |
-
|
| 115 |
# fold back
|
| 116 |
grad = fold(
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
)
|
| 122 |
-
|
| 123 |
return grad
|
| 124 |
|
| 125 |
class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
|
@@ -835,8 +835,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
|
| 835 |
num_particles: Optional[int] = 4,
|
| 836 |
gamma_0: Optional[float] = 0.1, # base steering strength
|
| 837 |
use_KDS = True,
|
| 838 |
-
bandwidth = 0.1,
|
| 839 |
patch_size = 16,
|
|
|
|
| 840 |
args=None,
|
| 841 |
):
|
| 842 |
r"""
|
|
@@ -1050,9 +1050,9 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
|
| 1050 |
for i, t in enumerate(timesteps):
|
| 1051 |
with torch.no_grad():
|
| 1052 |
# pass, if the timestep is larger than start_steps
|
| 1053 |
-
|
| 1054 |
-
|
| 1055 |
-
|
| 1056 |
|
| 1057 |
# expand the latents if we are doing classifier free guidance
|
| 1058 |
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
|
@@ -1189,7 +1189,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
|
| 1189 |
cond_list = []
|
| 1190 |
img_list = []
|
| 1191 |
|
| 1192 |
-
|
| 1193 |
|
| 1194 |
# Stitch noise predictions for all tiles
|
| 1195 |
noise_pred = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
|
|
@@ -1226,69 +1226,50 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
|
| 1226 |
|
| 1227 |
if use_KDS:
|
| 1228 |
|
| 1229 |
-
# 2) Compute x₀ prediction
|
| 1230 |
beta_t = 1 - self.scheduler.alphas_cumprod[t]
|
| 1231 |
alpha_t = self.scheduler.alphas_cumprod[t].sqrt()
|
| 1232 |
sigma_t = beta_t.sqrt()
|
| 1233 |
-
x0_pred = (latents - sigma_t * noise_pred) / alpha_t
|
| 1234 |
-
|
| 1235 |
-
#
|
| 1236 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1237 |
|
| 1238 |
-
# 3) Apply KDE steering *only* on the conditional batch
|
| 1239 |
-
m_shift_cond = kde_grad(x0_cond, patch_size=patch_size, bandwidth=bandwidth) # [N, C, H, W]
|
| 1240 |
-
delta_t = gamma_0 * (1 - i / (len(timesteps) - 1))
|
| 1241 |
-
x0_cond_steer = x0_cond + delta_t * m_shift_cond # steered conditional
|
| 1242 |
|
| 1243 |
-
# 4)
|
| 1244 |
-
x0_steer = torch.cat([x0_uncond, x0_cond_steer], dim=0) # [2N, C, H, W]
|
| 1245 |
-
|
| 1246 |
-
# 5) Recompute “noise” for DDIM step
|
| 1247 |
noise_pred_kds = (latents - alpha_t * x0_steer) / sigma_t
|
| 1248 |
|
| 1249 |
-
#
|
| 1250 |
if i < len(timesteps) - 1:
|
| 1251 |
-
|
| 1252 |
-
|
| 1253 |
else:
|
| 1254 |
-
|
|
|
|
| 1255 |
sigma_prev = (1 - alpha_prev**2).sqrt()
|
| 1256 |
|
|
|
|
| 1257 |
latents = (
|
| 1258 |
-
|
| 1259 |
-
|
| 1260 |
).detach().requires_grad_(True)
|
| 1261 |
else:
|
| 1262 |
|
| 1263 |
# compute the previous noisy sample x_t -> x_t-1
|
| 1264 |
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 1265 |
|
| 1266 |
-
|
| 1267 |
-
|
| 1268 |
-
|
| 1269 |
-
|
| 1270 |
-
|
| 1271 |
-
# 1) ensemble mean
|
| 1272 |
-
mean_cond = cond_latents.mean(dim=0, keepdim=True) # [1, C, H, W]
|
| 1273 |
-
# 2) distances
|
| 1274 |
-
dists = ((cond_latents - mean_cond)
|
| 1275 |
-
.view(cond_latents.size(0), -1)
|
| 1276 |
-
.pow(2)
|
| 1277 |
-
.sum(dim=1)) # [N]
|
| 1278 |
-
# 3) best index
|
| 1279 |
-
best_idx = dists.argmin().item()
|
| 1280 |
-
# 4) select that latent (and its uncond pair)
|
| 1281 |
-
best_uncond = uncond_latents[best_idx:best_idx+1]
|
| 1282 |
-
best_cond = cond_latents [best_idx:best_idx+1]
|
| 1283 |
-
latents = torch.cat([best_uncond, best_cond], dim=0) # [2, C, H, W]
|
| 1284 |
-
|
| 1285 |
-
# call the callback, if provided
|
| 1286 |
-
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1287 |
-
progress_bar.update()
|
| 1288 |
-
if callback is not None and i % callback_steps == 0:
|
| 1289 |
-
callback(i, t, latents)
|
| 1290 |
|
| 1291 |
-
|
| 1292 |
# If we do sequential model offloading, let's offload unet and controlnet
|
| 1293 |
# manually for max memory savings
|
| 1294 |
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
|
|
|
| 99 |
def kde_grad(x0: torch.Tensor, patch_size = 16, bandwidth = 0.1):
|
| 100 |
# x0: (N, C, H, W) in float32
|
| 101 |
N, C, H, W = x0.shape
|
| 102 |
+
patches = unfold(
|
| 103 |
+
x0, kernel_size=patch_size, stride=patch_size//2
|
| 104 |
+
) # (N, C*ps*ps, M)
|
| 105 |
P, M = patches.shape[1], patches.shape[2]
|
| 106 |
p_i = patches.unsqueeze(1) # (N,1,P,M)
|
| 107 |
p_j = patches.unsqueeze(0) # (1,N,P,M)
|
|
|
|
| 113 |
num = (w.unsqueeze(2) * diff).sum(dim=1) # (N,P,M)
|
| 114 |
denom = w.sum(dim=1, keepdim=True) + 1e-8 # (N,1,M)
|
| 115 |
mshift = num / denom # (N,P,M)
|
|
|
|
| 116 |
# fold back
|
| 117 |
grad = fold(
|
| 118 |
+
mshift / bandwidth**2,
|
| 119 |
+
output_size=(H, W),
|
| 120 |
+
kernel_size=patch_size,
|
| 121 |
+
stride=patch_size//2
|
| 122 |
+
) # (N, C, H, W)
|
|
|
|
| 123 |
return grad
|
| 124 |
|
| 125 |
class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
|
|
|
| 835 |
num_particles: Optional[int] = 4,
|
| 836 |
gamma_0: Optional[float] = 0.1, # base steering strength
|
| 837 |
use_KDS = True,
|
|
|
|
| 838 |
patch_size = 16,
|
| 839 |
+
bandwidth = 0.1,
|
| 840 |
args=None,
|
| 841 |
):
|
| 842 |
r"""
|
|
|
|
| 1050 |
for i, t in enumerate(timesteps):
|
| 1051 |
with torch.no_grad():
|
| 1052 |
# pass, if the timestep is larger than start_steps
|
| 1053 |
+
if t > start_steps:
|
| 1054 |
+
print(f'pass {t} steps.')
|
| 1055 |
+
continue
|
| 1056 |
|
| 1057 |
# expand the latents if we are doing classifier free guidance
|
| 1058 |
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
|
|
|
| 1189 |
cond_list = []
|
| 1190 |
img_list = []
|
| 1191 |
|
| 1192 |
+
noise_preds.append(model_out)
|
| 1193 |
|
| 1194 |
# Stitch noise predictions for all tiles
|
| 1195 |
noise_pred = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
|
|
|
|
| 1226 |
|
| 1227 |
if use_KDS:
|
| 1228 |
|
| 1229 |
+
# 2) Compute x₀ prediction
|
| 1230 |
beta_t = 1 - self.scheduler.alphas_cumprod[t]
|
| 1231 |
alpha_t = self.scheduler.alphas_cumprod[t].sqrt()
|
| 1232 |
sigma_t = beta_t.sqrt()
|
| 1233 |
+
x0_pred = (latents - sigma_t * noise_pred) / alpha_t
|
| 1234 |
+
|
| 1235 |
+
# 3) Apply KDE steering
|
| 1236 |
+
m_shift = kde_grad(x0_pred, patch_size=patch_size, bandwidth=bandwidth)
|
| 1237 |
+
delta_t = gamma_0 * (1 - i / (len(timesteps) - 1))
|
| 1238 |
+
x0_steer = x0_pred + delta_t * m_shift
|
| 1239 |
+
# frac = i / (len(timesteps) - 1)
|
| 1240 |
+
# delta_t = 0.0 if frac < 0.3 else 0.3
|
| 1241 |
+
# x0_steer = x0_pred + delta_t * gamma_0 * m_shift
|
| 1242 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1243 |
|
| 1244 |
+
# 4) Recompute “noise” for DDIM step
|
|
|
|
|
|
|
|
|
|
| 1245 |
noise_pred_kds = (latents - alpha_t * x0_steer) / sigma_t
|
| 1246 |
|
| 1247 |
+
# 5) Determine prev alphas
|
| 1248 |
if i < len(timesteps) - 1:
|
| 1249 |
+
next_t = timesteps[i + 1]
|
| 1250 |
+
alpha_prev = self.scheduler.alphas_cumprod[next_t].sqrt()
|
| 1251 |
else:
|
| 1252 |
+
alpha_prev = self.scheduler.final_alpha_cumprod.sqrt()
|
| 1253 |
+
|
| 1254 |
sigma_prev = (1 - alpha_prev**2).sqrt()
|
| 1255 |
|
| 1256 |
+
# 6) Form next latent per DDIM
|
| 1257 |
latents = (
|
| 1258 |
+
alpha_prev * x0_steer
|
| 1259 |
+
+ sigma_prev * noise_pred_kds
|
| 1260 |
).detach().requires_grad_(True)
|
| 1261 |
else:
|
| 1262 |
|
| 1263 |
# compute the previous noisy sample x_t -> x_t-1
|
| 1264 |
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 1265 |
|
| 1266 |
+
# call the callback, if provided
|
| 1267 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 1268 |
+
progress_bar.update()
|
| 1269 |
+
if callback is not None and i % callback_steps == 0:
|
| 1270 |
+
callback(i, t, latents)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1271 |
|
| 1272 |
+
with torch.no_grad():
|
| 1273 |
# If we do sequential model offloading, let's offload unet and controlnet
|
| 1274 |
# manually for max memory savings
|
| 1275 |
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|