Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -43,15 +43,14 @@ if torch.cuda.is_available():
|
|
| 43 |
decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)
|
| 44 |
|
| 45 |
if PREVIEW_IMAGES:
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
# previewer.eval().requires_grad_(False).to(device).to(dtype)
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
|
| 56 |
else:
|
| 57 |
previewer = None
|
|
@@ -97,12 +96,12 @@ def generate(
|
|
| 97 |
callback=callback_prior,
|
| 98 |
)
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
|
| 107 |
decoder_output = decoder_pipeline(
|
| 108 |
image_embeddings=prior_output.image_embeddings,
|
|
|
|
| 43 |
decoder_pipeline.decoder = torch.compile(decoder_pipeline.decoder, mode="max-autotune", fullgraph=True)
|
| 44 |
|
| 45 |
if PREVIEW_IMAGES:
|
| 46 |
+
previewer = Previewer()
|
| 47 |
+
previewer.load_state_dict(torch.load("previewer/previewer_v1_100k.pt")["state_dict"])
|
| 48 |
+
previewer.eval().requires_grad_(False).to(device).to(dtype)
|
|
|
|
| 49 |
|
| 50 |
+
def callback_prior(i, t, latents):
|
| 51 |
+
output = previewer(latents)
|
| 52 |
+
output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).cpu().numpy())
|
| 53 |
+
return output
|
| 54 |
|
| 55 |
else:
|
| 56 |
previewer = None
|
|
|
|
| 96 |
callback=callback_prior,
|
| 97 |
)
|
| 98 |
|
| 99 |
+
if PREVIEW_IMAGES:
|
| 100 |
+
for _ in range(len(DEFAULT_STAGE_C_TIMESTEPS)):
|
| 101 |
+
r = next(prior_output)
|
| 102 |
+
if isinstance(r, list):
|
| 103 |
+
yield r[0]
|
| 104 |
+
prior_output = r
|
| 105 |
|
| 106 |
decoder_output = decoder_pipeline(
|
| 107 |
image_embeddings=prior_output.image_embeddings,
|