fix app torch grad
Browse files
app.py
CHANGED
|
@@ -83,12 +83,11 @@ def text_to_image_generation(input_text, guidance_scale=1.75, generation_timeste
|
|
| 83 |
config=config,
|
| 84 |
)
|
| 85 |
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
|
| 92 |
|
| 93 |
return images[0]
|
| 94 |
|
|
@@ -158,12 +157,12 @@ def text_guided_inpainting(input_text, inpainting_image, inpainting_mask, guidan
|
|
| 158 |
config=config,
|
| 159 |
)
|
| 160 |
|
| 161 |
-
|
| 162 |
-
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
|
| 168 |
return images[0]
|
| 169 |
|
|
@@ -283,11 +282,12 @@ def text_guided_extrapolation(input_img, input_text, left_ext, right_ext, guidan
|
|
| 283 |
|
| 284 |
_, h, w = gen_token_ids.shape
|
| 285 |
gen_token_ids = gen_token_ids.reshape(config.training.batch_size, -1)
|
| 286 |
-
images = vq_model.decode_code(gen_token_ids, shape=(h, w))
|
| 287 |
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
|
|
|
|
|
|
| 291 |
|
| 292 |
return images[0]
|
| 293 |
|
|
|
|
| 83 |
config=config,
|
| 84 |
)
|
| 85 |
|
| 86 |
+
gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
|
| 87 |
+
images = vq_model.decode_code(gen_token_ids)
|
| 88 |
+
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
|
| 89 |
+
images *= 255.0
|
| 90 |
+
images = images.permute(0, 2, 3, 1).cpu().detach().numpy().astype(np.uint8)
|
|
|
|
| 91 |
|
| 92 |
return images[0]
|
| 93 |
|
|
|
|
| 157 |
config=config,
|
| 158 |
)
|
| 159 |
|
| 160 |
+
gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
|
| 161 |
+
images = vq_model.decode_code(gen_token_ids)
|
| 162 |
|
| 163 |
+
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
|
| 164 |
+
images *= 255.0
|
| 165 |
+
images = images.permute(0, 2, 3, 1).cpu().detach().numpy().astype(np.uint8)
|
| 166 |
|
| 167 |
return images[0]
|
| 168 |
|
|
|
|
| 282 |
|
| 283 |
_, h, w = gen_token_ids.shape
|
| 284 |
gen_token_ids = gen_token_ids.reshape(config.training.batch_size, -1)
|
|
|
|
| 285 |
|
| 286 |
+
with torch.no_grad():
|
| 287 |
+
images = vq_model.decode_code(gen_token_ids, shape=(h, w))
|
| 288 |
+
images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
|
| 289 |
+
images *= 255.0
|
| 290 |
+
images = images.permute(0, 2, 3, 1).cpu().detach().numpy().astype(np.uint8)
|
| 291 |
|
| 292 |
return images[0]
|
| 293 |
|