Spaces:
Runtime error
Runtime error
Update pipelines/pipeline_seesr.py
Browse files- pipelines/pipeline_seesr.py +27 -1
pipelines/pipeline_seesr.py
CHANGED
|
@@ -95,7 +95,31 @@ EXAMPLE_DOC_STRING = """
|
|
| 95 |
... ).images[0]
|
| 96 |
```
|
| 97 |
"""
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
| 101 |
r"""
|
|
@@ -807,6 +831,8 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoade
|
|
| 807 |
ram_encoder_hidden_states=None,
|
| 808 |
latent_tiled_size=320,
|
| 809 |
latent_tiled_overlap=4,
|
|
|
|
|
|
|
| 810 |
use_KDS=True,
|
| 811 |
args=None,
|
| 812 |
):
|
|
|
|
| 95 |
... ).images[0]
|
| 96 |
```
|
| 97 |
"""
|
| 98 |
+
def kde_grad(x0: torch.Tensor, patch_size: int = 16, bandwidth: float = 0.1) -> torch.Tensor:
    """Return the patch-wise Gaussian-KDE score of each batch element.

    Every element of the batch is treated as one sample ("particle").  The
    input is cut into overlapping ``patch_size`` x ``patch_size`` patches
    (stride ``patch_size // 2``), and for each patch location a Gaussian
    kernel density estimate is built across the batch dimension.  The
    result is the mean-shift vector of each particle divided by
    ``bandwidth**2``, folded back to the spatial layout of ``x0``.

    Args:
        x0: Tensor of shape ``(N, C, H, W)``; the original comment states
            float32 is expected.
        patch_size: Side length of the square patches.
        bandwidth: Standard deviation of the Gaussian kernel.

    Returns:
        Tensor of shape ``(N, C, H, W)``.  Contributions of overlapping
        patches are summed (not averaged) by ``fold``; pixels that no patch
        covers are zero.

    Raises:
        ValueError: If ``x0`` is not 4-dimensional.
    """
    if x0.dim() != 4:
        raise ValueError(
            f"kde_grad expects a 4-D (N, C, H, W) tensor, got shape {tuple(x0.shape)}"
        )
    height, width = x0.shape[-2:]
    # Half-patch overlap; clamp so patch_size == 1 does not yield a zero stride.
    stride = max(patch_size // 2, 1)

    # Qualified calls keep the function independent of a separate
    # `from torch.nn.functional import fold, unfold` elsewhere in the file.
    patches = torch.nn.functional.unfold(
        x0, kernel_size=patch_size, stride=stride
    )  # (N, C*ps*ps, M)

    # diff[i, j] = patch_j - patch_i for every patch location.
    diff = patches.unsqueeze(0) - patches.unsqueeze(1)  # (N, N, P, M)

    # Gaussian weights from the squared patch distance.
    w = torch.exp((-0.5 / bandwidth**2) * diff.square().sum(dim=2))  # (N, N, M)

    # Mean-shift numerator and normalizer, both reduced over the j axis.
    num = (w.unsqueeze(2) * diff).sum(dim=1)  # (N, P, M)
    denom = w.sum(dim=1, keepdim=True) + 1e-8  # (N, 1, M)
    mshift = num / denom  # (N, P, M)

    # Fold the per-patch vectors back onto the image grid.
    grad = torch.nn.functional.fold(
        mshift / bandwidth**2,
        output_size=(height, width),
        kernel_size=patch_size,
        stride=stride,
    )  # (N, C, H, W)
    return grad
|
| 123 |
|
| 124 |
class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
| 125 |
r"""
|
|
|
|
| 831 |
ram_encoder_hidden_states=None,
|
| 832 |
latent_tiled_size=320,
|
| 833 |
latent_tiled_overlap=4,
|
| 834 |
+
num_particles: Optional[int] = 4,
|
| 835 |
+
gamma_0: Optional[float] = 0.1, # base steering strength
|
| 836 |
use_KDS=True,
|
| 837 |
args=None,
|
| 838 |
):
|