Spaces:
Runtime error
Runtime error
Update pipelines/pipeline_seesr.py
Browse files
pipelines/pipeline_seesr.py
CHANGED
|
@@ -99,9 +99,7 @@ EXAMPLE_DOC_STRING = """
|
|
| 99 |
def kde_grad(x0: torch.Tensor, patch_size = 16, bandwidth = 0.1):
|
| 100 |
# x0: (N, C, H, W) in float32
|
| 101 |
N, C, H, W = x0.shape
|
| 102 |
-
patches = unfold(
|
| 103 |
-
x0, kernel_size=patch_size, stride=patch_size//2
|
| 104 |
-
) # (N, C*ps*ps, M)
|
| 105 |
P, M = patches.shape[1], patches.shape[2]
|
| 106 |
p_i = patches.unsqueeze(1) # (N,1,P,M)
|
| 107 |
p_j = patches.unsqueeze(0) # (1,N,P,M)
|
|
@@ -113,13 +111,15 @@ def kde_grad(x0: torch.Tensor, patch_size = 16, bandwidth = 0.1):
|
|
| 113 |
num = (w.unsqueeze(2) * diff).sum(dim=1) # (N,P,M)
|
| 114 |
denom = w.sum(dim=1, keepdim=True) + 1e-8 # (N,1,M)
|
| 115 |
mshift = num / denom # (N,P,M)
|
|
|
|
| 116 |
# fold back
|
| 117 |
grad = fold(
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
)
|
|
|
|
| 123 |
return grad
|
| 124 |
|
| 125 |
class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|
|
|
|
| 99 |
def kde_grad(x0: torch.Tensor, patch_size = 16, bandwidth = 0.1):
|
| 100 |
# x0: (N, C, H, W) in float32
|
| 101 |
N, C, H, W = x0.shape
|
| 102 |
+
patches = unfold(x0, kernel_size=patch_size, stride=patch_size) # (N, C*ps*ps, M)
|
|
|
|
|
|
|
| 103 |
P, M = patches.shape[1], patches.shape[2]
|
| 104 |
p_i = patches.unsqueeze(1) # (N,1,P,M)
|
| 105 |
p_j = patches.unsqueeze(0) # (1,N,P,M)
|
|
|
|
| 111 |
num = (w.unsqueeze(2) * diff).sum(dim=1) # (N,P,M)
|
| 112 |
denom = w.sum(dim=1, keepdim=True) + 1e-8 # (N,1,M)
|
| 113 |
mshift = num / denom # (N,P,M)
|
| 114 |
+
|
| 115 |
# fold back
|
| 116 |
grad = fold(
|
| 117 |
+
mshift / bandwidth**2,
|
| 118 |
+
output_size=(H, W),
|
| 119 |
+
kernel_size=patch_size,
|
| 120 |
+
stride=patch_size
|
| 121 |
+
) # (N, C, H, W)
|
| 122 |
+
|
| 123 |
return grad
|
| 124 |
|
| 125 |
class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
|