Update pipelines/pipeline_seesr.py
pipelines/pipeline_seesr.py  (+199 -295)  CHANGED
@@ -22,7 +22,6 @@ import numpy as np
 import PIL.Image
 import torch
 import torch.nn.functional as F
-from torch.nn.functional import unfold, fold
 from torchvision.utils import save_image
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
@@ -96,31 +95,7 @@ EXAMPLE_DOC_STRING = """
         ... ).images[0]
         ```
 """
-
-    # x0: (N, C, H, W) in float32
-    N, C, H, W = x0.shape
-    patches = unfold(
-        x0, kernel_size=patch_size, stride=patch_size//2
-    )  # (N, C*ps*ps, M)
-    P, M = patches.shape[1], patches.shape[2]
-    p_i = patches.unsqueeze(1)  # (N,1,P,M)
-    p_j = patches.unsqueeze(0)  # (1,N,P,M)
-    diff = p_j - p_i  # (N,N,P,M)
-    # Gaussian weights
-    w = torch.exp((-0.5 / bandwidth**2) *
-                  (diff.square().sum(dim=2)))  # (N,N,M)
-    # mean-shift numerator & normalizer
-    num = (w.unsqueeze(2) * diff).sum(dim=1)   # (N,P,M)
-    denom = w.sum(dim=1, keepdim=True) + 1e-8  # (N,1,M)
-    mshift = num / denom  # (N,P,M)
-    # fold back
-    grad = fold(
-        mshift / bandwidth**2,
-        output_size=(H, W),
-        kernel_size=patch_size,
-        stride=patch_size//2
-    )  # (N, C, H, W)
-    return grad
 
 class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
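The removed helper above reaches this view without its `def` line, so the body reads as orphaned code. A minimal self-contained sketch of the same patch-wise KDE mean-shift gradient, assuming the signature `kde_grad(x0, bandwidth, patch_size=16)` (the name and arguments are inferred from the call site `kde_grad(x0_cond, bandwidth=bandwidth)` later in this diff):

import torch
from torch.nn.functional import unfold, fold

def kde_grad(x0, bandwidth, patch_size=16):
    """Patch-wise KDE mean-shift gradient over a batch of images (sketch, assumed signature)."""
    # x0: (N, C, H, W) in float32
    N, C, H, W = x0.shape
    patches = unfold(x0, kernel_size=patch_size, stride=patch_size // 2)   # (N, C*ps*ps, M)
    p_i = patches.unsqueeze(1)                                             # (N, 1, P, M)
    p_j = patches.unsqueeze(0)                                             # (1, N, P, M)
    diff = p_j - p_i                                                       # (N, N, P, M)
    # Gaussian kernel weights between every pair of samples, per patch location
    w = torch.exp((-0.5 / bandwidth**2) * diff.square().sum(dim=2))        # (N, N, M)
    # Mean-shift numerator and normalizer
    num = (w.unsqueeze(2) * diff).sum(dim=1)                               # (N, P, M)
    denom = w.sum(dim=1, keepdim=True) + 1e-8                              # (N, 1, M)
    mshift = num / denom                                                   # (N, P, M)
    # Fold the per-patch shifts back to image space; overlapping patches are summed
    grad = fold(mshift / bandwidth**2, output_size=(H, W),
                kernel_size=patch_size, stride=patch_size // 2)            # (N, C, H, W)
    return grad

Note that `fold` sums overlapping patch contributions rather than averaging them, so the magnitude of the returned gradient grows with the amount of patch overlap.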
@@ -803,6 +778,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1))
 
     @perfcount
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
@@ -832,12 +808,7 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         ram_encoder_hidden_states=None,
         latent_tiled_size=320,
         latent_tiled_overlap=4,
-
-        gamma_0: Optional[float] = 0.1,  # base steering strength
-        use_KDS = True,
-        patch_size = 16,
-        bandwidth = 0.1,
-        args=None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -1025,17 +996,6 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        if use_KDS:
-            # 1) update batch_size to account for the new particles
-            batch_size = batch_size * num_particles
-
-            # 2) now repeat latents/images/prompts
-            latents = latents.repeat_interleave(num_particles, dim=0)
-            image = image.repeat_interleave(num_particles, dim=0)
-            ram_encoder_hidden_states = ram_encoder_hidden_states.repeat_interleave(num_particles, dim=0)
-            prompt_embeds = prompt_embeds.repeat_interleave(num_particles, dim=0)
-            latents.requires_grad_(True)
-
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
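The removed particle setup relies on `repeat_interleave`, which duplicates each batch element in place (all particles of sample 0 come first, then sample 1's, and so on). A tiny illustration with made-up values:

import torch

x = torch.tensor([[1.0], [2.0]])           # batch of 2 samples
print(x.repeat_interleave(3, dim=0))       # -> [[1.], [1.], [1.], [2.], [2.], [2.]]
# Each sample is repeated 3 times consecutively, so particle k of sample n sits at
# row n * 3 + k, which is the ordering the removed KDS setup assumes.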
@@ -1048,220 +1008,184 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                 print(f"[Tiled Latent]: the input size is {image.shape[-2]}x{image.shape[-1]}, need to tiled")
 
             for i, t in enumerate(timesteps):
-
-
-
-
-
 
-                # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                # controlnet(s) inference
                 if guess_mode and do_classifier_free_guidance:
-                    #
    [... old lines 1064-1201 (removed) are not rendered in this view ...]
-                            ofs_x = max(row * tile_size - tile_overlap * row, 0)
-                            ofs_y = max(col * tile_size - tile_overlap * col, 0)
-                            # input tile area on total image
-                            if row == grid_rows-1:
-                                ofs_x = w - tile_size
-                            if col == grid_cols-1:
-                                ofs_y = h - tile_size
-
-                            input_start_x = ofs_x
-                            input_end_x = ofs_x + tile_size
-                            input_start_y = ofs_y
-                            input_end_y = ofs_y + tile_size
-
-                            noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row*grid_cols + col] * tile_weights
-                            contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights
-                    # Average overlapping areas with more than 1 contributor
-                    noise_pred /= contributors
-
-
-                # perform guidance
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
 
-                if use_KDS:
 
-
-
-                    alpha_t = self.scheduler.alphas_cumprod[t].sqrt()
-                    sigma_t = beta_t.sqrt()
-                    x0_pred = (latents - sigma_t * noise_pred) / alpha_t  # shape [2N, C, H, W]
-
-                    # split into unconditional vs. conditional
-                    x0_uncond, x0_cond = x0_pred.chunk(2, dim=0)  # each [N, C, H, W]
-
-                    # 3) Apply KDE steering *only* on the conditional batch
-                    m_shift_cond = kde_grad(x0_cond, bandwidth=bandwidth)  # [N, C, H, W]
-                    delta_t = gamma_0 * (1 - i / (len(timesteps) - 1))
-                    x0_cond_steer = x0_cond + delta_t * m_shift_cond  # steered conditional
-
-                    # 4) Recombine the latents: leave uncond untouched, use steered cond
-                    x0_steer = torch.cat([x0_uncond, x0_cond_steer], dim=0)  # [2N, C, H, W]
-
-                    # 5) Recompute "noise" for DDIM step
-                    noise_pred_kds = (latents - alpha_t * x0_steer) / sigma_t
-
-                    # 6) Determine prev alphas and form next latent per DDIM
-                    if i < len(timesteps) - 1:
-                        next_t = timesteps[i + 1]
-                        alpha_prev = self.scheduler.alphas_cumprod[next_t].sqrt()
-                    else:
-                        alpha_prev = self.scheduler.final_alpha_cumprod.sqrt()
-                    sigma_prev = (1 - alpha_prev**2).sqrt()
-
-                    latents = (
-                        alpha_prev * x0_steer
-                        + sigma_prev * noise_pred_kds
-                    ).detach().requires_grad_(True)
-                else:
-
-                    # compute the previous noisy sample x_t -> x_t-1
-                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
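In the removed KDS branch above, `beta_t` is never defined in the lines this view renders; the two removed lines just before it are blank in the extraction and presumably held that definition. A minimal sketch of the steering update under the usual DDIM assumption `beta_t = 1 - alphas_cumprod[t]` (the helper name `kds_ddim_step` and its argument list are illustrative, not recovered code):

import torch

def kds_ddim_step(latents, noise_pred, scheduler, t, next_t, i, num_steps,
                  kde_grad, gamma_0=0.1, bandwidth=0.1):
    # Assumption: beta_t = 1 - alphas_cumprod[t]; the removed lines defining beta_t are not visible.
    alpha_t = scheduler.alphas_cumprod[t].sqrt()
    sigma_t = (1 - scheduler.alphas_cumprod[t]).sqrt()

    # Predict x0 from the current latents and the noise estimate
    x0_pred = (latents - sigma_t * noise_pred) / alpha_t          # [2N, C, H, W]
    x0_uncond, x0_cond = x0_pred.chunk(2, dim=0)                  # each [N, C, H, W]

    # KDE mean-shift steering on the conditional half only, annealed to zero over the schedule
    delta_t = gamma_0 * (1 - i / (num_steps - 1))
    x0_cond = x0_cond + delta_t * kde_grad(x0_cond, bandwidth=bandwidth)
    x0_steer = torch.cat([x0_uncond, x0_cond], dim=0)

    # Re-derive the implied noise and take a deterministic DDIM step
    noise_pred_kds = (latents - alpha_t * x0_steer) / sigma_t
    if next_t is not None:
        alpha_prev = scheduler.alphas_cumprod[next_t].sqrt()
    else:
        alpha_prev = scheduler.final_alpha_cumprod.sqrt()
    sigma_prev = (1 - alpha_prev**2).sqrt()
    return alpha_prev * x0_steer + sigma_prev * noise_pred_kds

As in the removed code, the mean-shift term is applied only to the conditional half of the batch and shrinks linearly to zero as denoising proceeds.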
@@ -1269,53 +1193,33 @@ class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
 
    [... old lines 1272-1284 (removed) are not rendered in this view ...]
-        # 3) best index
-        best_idx = dists.argmin().item()
-        # 4) select that latent (and its uncond pair)
-        best_uncond = uncond_latents[best_idx:best_idx+1]
-        best_cond = cond_latents[best_idx:best_idx+1]
-        latents = torch.cat([best_uncond, best_cond], dim=0)  # [2, C, H, W]
-
-        # If we do sequential model offloading, let's offload unet and controlnet
-        # manually for max memory savings
-        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-            self.unet.to("cpu")
-            self.controlnet.to("cpu")
-            torch.cuda.empty_cache()
-
         has_nsfw_concept = None
-        if not output_type == "latent":
-            image = self.vae.decode(latents.detach() / self.vae.config.scaling_factor, return_dict=False)[0]  #.flip(1)
-            #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-        else:
-            image = latents.detach()
-            has_nsfw_concept = None
 
    [... old lines 1307-1321 (removed) are not rendered in this view ...]
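The lines that computed `dists`, `uncond_latents`, and `cond_latents` fall in the removed block that this view does not render, so the selection step above cannot be reconstructed exactly. Purely as an illustration of how those surviving lines could be driven, a hypothetical sketch that scores each particle's conditional latent against a reference latent (`ref_latent` and the L2 distance are assumptions, not recovered code):

import torch

def select_best_particle(cond_latents, uncond_latents, ref_latent):
    # cond_latents, uncond_latents: (num_particles, C, H, W); ref_latent: (1, C, H, W)
    # Hypothetical scoring: L2 distance of each particle to the reference.
    dists = (cond_latents - ref_latent).flatten(1).norm(dim=1)   # one distance per particle
    best_idx = dists.argmin().item()
    best_uncond = uncond_latents[best_idx:best_idx + 1]
    best_cond = cond_latents[best_idx:best_idx + 1]
    return torch.cat([best_uncond, best_cond], dim=0)            # [2, C, H, W]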
 import PIL.Image
 import torch
 import torch.nn.functional as F
 from torchvision.utils import save_image
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
         ... ).images[0]
         ```
 """
+
 
 class StableDiffusionControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
     r"""
         return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1))
 
     @perfcount
+    @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         ram_encoder_hidden_states=None,
         latent_tiled_size=320,
         latent_tiled_overlap=4,
+        args=None
     ):
         r"""
         Function invoked when calling the pipeline for generation.
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 8. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
                 print(f"[Tiled Latent]: the input size is {image.shape[-2]}x{image.shape[-1]}, need to tiled")
 
             for i, t in enumerate(timesteps):
+                # skip this step if the timestep is larger than start_steps
+                if t > start_steps:
+                    print(f'pass {t} steps.')
+                    continue
+
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # controlnet(s) inference
+                if guess_mode and do_classifier_free_guidance:
+                    # Infer ControlNet only for the conditional batch.
+                    controlnet_latent_model_input = latents
+                    controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+
+                else:
+                    controlnet_latent_model_input = latent_model_input
+                    controlnet_prompt_embeds = prompt_embeds
+
+                if h * w <= tile_size * tile_size:  # the whole latent fits in one tile; no tiling needed
+                    down_block_res_samples, mid_block_res_sample = [None]*10, None
+                    down_block_res_samples, mid_block_res_sample = self.controlnet(
+                        controlnet_latent_model_input,
+                        t,
+                        encoder_hidden_states=controlnet_prompt_embeds,
+                        controlnet_cond=image,
+                        conditioning_scale=conditioning_scale,
+                        guess_mode=guess_mode,
+                        return_dict=False,
+                        image_encoder_hidden_states=ram_encoder_hidden_states,
+                    )
 
 
                     if guess_mode and do_classifier_free_guidance:
+                        # Inferred ControlNet only for the conditional batch.
+                        # To apply the output of ControlNet to both the unconditional and conditional batches,
+                        # add 0 to the unconditional batch to keep it unchanged.
+                        down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                        mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+                    # predict the noise residual
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                        down_block_additional_residuals=down_block_res_samples,
+                        mid_block_additional_residual=mid_block_res_sample,
+                        return_dict=False,
+                        image_encoder_hidden_states=ram_encoder_hidden_states,
+                    )[0]
+                else:
+                    tile_weights = self._gaussian_weights(tile_size, tile_size, 1)
+                    tile_size = min(tile_size, min(h, w))
+                    tile_weights = self._gaussian_weights(tile_size, tile_size, 1)
+
+                    grid_rows = 0
+                    cur_x = 0
+                    while cur_x < latent_model_input.size(-1):
+                        cur_x = max(grid_rows * tile_size - tile_overlap * grid_rows, 0) + tile_size
+                        grid_rows += 1
+
+                    grid_cols = 0
+                    cur_y = 0
+                    while cur_y < latent_model_input.size(-2):
+                        cur_y = max(grid_cols * tile_size - tile_overlap * grid_cols, 0) + tile_size
+                        grid_cols += 1
+
+                    input_list = []
+                    cond_list = []
+                    img_list = []
+                    noise_preds = []
+                    for row in range(grid_rows):
+                        noise_preds_row = []
+                        for col in range(grid_cols):
+                            if col < grid_cols-1 or row < grid_rows-1:
+                                # extract tile from input image
+                                ofs_x = max(row * tile_size - tile_overlap * row, 0)
+                                ofs_y = max(col * tile_size - tile_overlap * col, 0)
+                            # input tile area on total image
+                            if row == grid_rows-1:
+                                ofs_x = w - tile_size
+                            if col == grid_cols-1:
+                                ofs_y = h - tile_size
+
+                            input_start_x = ofs_x
+                            input_end_x = ofs_x + tile_size
+                            input_start_y = ofs_y
+                            input_end_y = ofs_y + tile_size
+
+                            # input tile dimensions
+                            input_tile = latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
+                            input_list.append(input_tile)
+                            cond_tile = controlnet_latent_model_input[:, :, input_start_y:input_end_y, input_start_x:input_end_x]
+                            cond_list.append(cond_tile)
+                            img_tile = image[:, :, input_start_y*8:input_end_y*8, input_start_x*8:input_end_x*8]
+                            img_list.append(img_tile)
+
+                            if len(input_list) == batch_size or col == grid_cols-1:
+                                input_list_t = torch.cat(input_list, dim=0)
+                                cond_list_t = torch.cat(cond_list, dim=0)
+                                img_list_t = torch.cat(img_list, dim=0)
+                                #print(input_list_t.shape, cond_list_t.shape, img_list_t.shape, fg_mask_list_t.shape)
+
+                                down_block_res_samples, mid_block_res_sample = self.controlnet(
+                                    cond_list_t,
+                                    t,
+                                    encoder_hidden_states=controlnet_prompt_embeds,
+                                    controlnet_cond=img_list_t,
+                                    conditioning_scale=conditioning_scale,
+                                    guess_mode=guess_mode,
+                                    return_dict=False,
+                                    image_encoder_hidden_states=ram_encoder_hidden_states,
+                                )
+
+                                if guess_mode and do_classifier_free_guidance:
+                                    # Inferred ControlNet only for the conditional batch.
+                                    # To apply the output of ControlNet to both the unconditional and conditional batches,
+                                    # add 0 to the unconditional batch to keep it unchanged.
+                                    down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+                                    mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+
+                                # predict the noise residual
+                                model_out = self.unet(
+                                    input_list_t,
+                                    t,
+                                    encoder_hidden_states=prompt_embeds,
+                                    cross_attention_kwargs=cross_attention_kwargs,
+                                    down_block_additional_residuals=down_block_res_samples,
+                                    mid_block_additional_residual=mid_block_res_sample,
+                                    return_dict=False,
+                                    image_encoder_hidden_states=ram_encoder_hidden_states,
+                                )[0]
+
+                                #for sample_i in range(model_out.size(0)):
+                                #    noise_preds_row.append(model_out[sample_i].unsqueeze(0))
+                                input_list = []
+                                cond_list = []
+                                img_list = []
+
+                                noise_preds.append(model_out)
+
+                    # Stitch noise predictions for all tiles
+                    noise_pred = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
+                    contributors = torch.zeros(latent_model_input.shape, device=latent_model_input.device)
+                    # Add each tile contribution to overall latents
+                    for row in range(grid_rows):
+                        for col in range(grid_cols):
+                            if col < grid_cols-1 or row < grid_rows-1:
+                                # extract tile from input image
+                                ofs_x = max(row * tile_size - tile_overlap * row, 0)
+                                ofs_y = max(col * tile_size - tile_overlap * col, 0)
+                            # input tile area on total image
+                            if row == grid_rows-1:
+                                ofs_x = w - tile_size
+                            if col == grid_cols-1:
+                                ofs_y = h - tile_size
+
+                            input_start_x = ofs_x
+                            input_end_x = ofs_x + tile_size
+                            input_start_y = ofs_y
+                            input_end_y = ofs_y + tile_size
+
+                            noise_pred[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += noise_preds[row*grid_cols + col] * tile_weights
+                            contributors[:, :, input_start_y:input_end_y, input_start_x:input_end_x] += tile_weights
+                    # Average overlapping areas with more than 1 contributor
+                    noise_pred /= contributors
+
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
 
 
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
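The tiled branch above accumulates Gaussian-weighted tile predictions into `noise_pred` and divides by the per-pixel sum of weights, so overlapping regions become weighted averages. A self-contained sketch of that accumulate-and-normalize pattern (the `gaussian_weights` helper below is a simplified stand-in for the pipeline's `_gaussian_weights`, and the tile layout is assumed to cover every pixel at least once):

import torch

def gaussian_weights(tile_size, sigma_frac=0.3):
    # Separable Gaussian window peaking at the tile centre (assumed stand-in
    # for the pipeline's _gaussian_weights helper).
    coords = torch.arange(tile_size, dtype=torch.float32) - (tile_size - 1) / 2
    g = torch.exp(-coords**2 / (2 * (sigma_frac * tile_size) ** 2))
    return g[:, None] * g[None, :]                        # (tile, tile)

def stitch_tiles(tile_preds, offsets, full_shape, tile_size):
    # tile_preds: list of (B, C, tile, tile) predictions; offsets: list of (ofs_y, ofs_x).
    acc = torch.zeros(full_shape)
    weight_sum = torch.zeros(full_shape)
    w = gaussian_weights(tile_size)
    for pred, (oy, ox) in zip(tile_preds, offsets):
        acc[:, :, oy:oy + tile_size, ox:ox + tile_size] += pred * w
        weight_sum[:, :, oy:oy + tile_size, ox:ox + tile_size] += w
    # Weighted average in overlaps; assumes weight_sum is nonzero everywhere.
    return acc / weight_sum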
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
 
+        # If we do sequential model offloading, let's offload unet and controlnet
+        # manually for max memory savings
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.unet.to("cpu")
+            self.controlnet.to("cpu")
+            torch.cuda.empty_cache()
+
+        has_nsfw_concept = None
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]  #.flip(1)
+            #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
         has_nsfw_concept = None
 
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
 
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
 
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
 
+        if not return_dict:
+            return (image, has_nsfw_concept)
 
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)