Spaces:
Runtime error
Runtime error
Update pipelines/pipeline_seesr.py
Browse files- pipelines/pipeline_seesr.py +19 -19
pipelines/pipeline_seesr.py
CHANGED
|
@@ -1226,37 +1226,37 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
|
| 1226 |
|
| 1227 |
if use_KDS:
|
| 1228 |
|
| 1229 |
-
# 2) Compute x₀ prediction
|
| 1230 |
beta_t = 1 - self.scheduler.alphas_cumprod[t]
|
| 1231 |
alpha_t = self.scheduler.alphas_cumprod[t].sqrt()
|
| 1232 |
sigma_t = beta_t.sqrt()
|
| 1233 |
-
x0_pred = (latents - sigma_t * noise_pred) / alpha_t
|
| 1234 |
-
|
| 1235 |
-
# 3) Apply KDE steering
|
| 1236 |
-
m_shift = kde_grad(x0_pred, patch_size=patch_size, bandwidth=bandwidth)
|
| 1237 |
-
delta_t = gamma_0 * (1 - i / (len(timesteps) - 1))
|
| 1238 |
-
x0_steer = x0_pred + delta_t * m_shift
|
| 1239 |
-
# frac = i / (len(timesteps) - 1)
|
| 1240 |
-
# delta_t = 0.0 if frac < 0.3 else 0.3
|
| 1241 |
-
# x0_steer = x0_pred + delta_t * gamma_0 * m_shift
|
| 1242 |
|
|
|
|
|
|
|
| 1243 |
|
| 1244 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1245 |
noise_pred_kds = (latents - alpha_t * x0_steer) / sigma_t
|
| 1246 |
|
| 1247 |
-
#
|
| 1248 |
if i < len(timesteps) - 1:
|
| 1249 |
-
|
| 1250 |
-
|
| 1251 |
else:
|
| 1252 |
-
|
| 1253 |
-
|
| 1254 |
sigma_prev = (1 - alpha_prev**2).sqrt()
|
| 1255 |
|
| 1256 |
-
# 6) Form next latent per DDIM
|
| 1257 |
latents = (
|
| 1258 |
-
|
| 1259 |
-
|
| 1260 |
).detach().requires_grad_(True)
|
| 1261 |
else:
|
| 1262 |
|
|
|
|
| 1226 |
|
| 1227 |
if use_KDS:
|
| 1228 |
|
| 1229 |
+
# 2) Compute x₀ prediction for all particles
|
| 1230 |
beta_t = 1 - self.scheduler.alphas_cumprod[t]
|
| 1231 |
alpha_t = self.scheduler.alphas_cumprod[t].sqrt()
|
| 1232 |
sigma_t = beta_t.sqrt()
|
| 1233 |
+
x0_pred = (latents - sigma_t * noise_pred) / alpha_t # shape [2N, C, H, W]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1234 |
|
| 1235 |
+
# — split into unconditional vs. conditional
|
| 1236 |
+
x0_uncond, x0_cond = x0_pred.chunk(2, dim=0) # each [N, C, H, W]
|
| 1237 |
|
| 1238 |
+
# 3) Apply KDE steering *only* on the conditional batch
|
| 1239 |
+
m_shift_cond = kde_grad(x0_cond, bandwidth=bandwidth) # [N, C, H, W]
|
| 1240 |
+
delta_t = gamma_0 * (1 - i / (len(timesteps) - 1))
|
| 1241 |
+
x0_cond_steer = x0_cond + delta_t * m_shift_cond # steered conditional
|
| 1242 |
+
|
| 1243 |
+
# 4) Recombine the latents: leave uncond untouched, use steered cond
|
| 1244 |
+
x0_steer = torch.cat([x0_uncond, x0_cond_steer], dim=0) # [2N, C, H, W]
|
| 1245 |
+
|
| 1246 |
+
# 5) Recompute “noise” for DDIM step
|
| 1247 |
noise_pred_kds = (latents - alpha_t * x0_steer) / sigma_t
|
| 1248 |
|
| 1249 |
+
# 6) Determine prev alphas and form next latent per DDIM
|
| 1250 |
if i < len(timesteps) - 1:
|
| 1251 |
+
next_t = timesteps[i + 1]
|
| 1252 |
+
alpha_prev = self.scheduler.alphas_cumprod[next_t].sqrt()
|
| 1253 |
else:
|
| 1254 |
+
alpha_prev = self.scheduler.final_alpha_cumprod.sqrt()
|
|
|
|
| 1255 |
sigma_prev = (1 - alpha_prev**2).sqrt()
|
| 1256 |
|
|
|
|
| 1257 |
latents = (
|
| 1258 |
+
alpha_prev * x0_steer
|
| 1259 |
+
+ sigma_prev * noise_pred_kds
|
| 1260 |
).detach().requires_grad_(True)
|
| 1261 |
else:
|
| 1262 |
|