Spaces:
Runtime error
Runtime error
Update
Browse files
app.py
CHANGED
|
@@ -76,17 +76,17 @@ def generate_image(model: nn.Module, z: torch.Tensor, truncation_psi: float,
|
|
| 76 |
|
| 77 |
|
| 78 |
@torch.inference_mode()
|
| 79 |
-
def generate_interpolated_images(
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
seed0 = int(np.clip(seed0, 0, np.iinfo(np.uint32).max))
|
| 84 |
seed1 = int(np.clip(seed1, 0, np.iinfo(np.uint32).max))
|
| 85 |
|
| 86 |
z0 = generate_z(model.style_dim, seed0, device)
|
| 87 |
if num_intermediate == -1:
|
| 88 |
out = generate_image(model, z0, psi0, randomize_noise)
|
| 89 |
-
return out
|
| 90 |
|
| 91 |
z1 = generate_z(model.style_dim, seed1, device)
|
| 92 |
vec = z1 - z0
|
|
@@ -98,8 +98,8 @@ def generate_interpolated_images(seed0: int, seed1: int, num_intermediate: int,
|
|
| 98 |
for z, psi in zip(zs, psis):
|
| 99 |
out = generate_image(model, z, psi, randomize_noise)
|
| 100 |
res.append(out)
|
| 101 |
-
|
| 102 |
-
return res
|
| 103 |
|
| 104 |
|
| 105 |
def main():
|
|
@@ -129,7 +129,7 @@ def main():
|
|
| 129 |
gr.inputs.Number(default=29703, label='Seed 1'),
|
| 130 |
gr.inputs.Number(default=55376, label='Seed 2'),
|
| 131 |
gr.inputs.Slider(-1,
|
| 132 |
-
|
| 133 |
step=1,
|
| 134 |
default=3,
|
| 135 |
label='Number of Intermediate Frames'),
|
|
@@ -139,7 +139,11 @@ def main():
|
|
| 139 |
0, 2, step=0.05, default=0.7, label='Truncation psi 2'),
|
| 140 |
gr.inputs.Checkbox(default=False, label='Randomize Noise'),
|
| 141 |
],
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
examples=examples,
|
| 144 |
title=TITLE,
|
| 145 |
description=DESCRIPTION,
|
|
|
|
| 76 |
|
| 77 |
|
| 78 |
@torch.inference_mode()
|
| 79 |
+
def generate_interpolated_images(
|
| 80 |
+
seed0: int, seed1: int, num_intermediate: int, psi0: float,
|
| 81 |
+
psi1: float, randomize_noise: bool, model: nn.Module,
|
| 82 |
+
device: torch.device) -> tuple[list[np.ndarray], np.ndarray]:
|
| 83 |
seed0 = int(np.clip(seed0, 0, np.iinfo(np.uint32).max))
|
| 84 |
seed1 = int(np.clip(seed1, 0, np.iinfo(np.uint32).max))
|
| 85 |
|
| 86 |
z0 = generate_z(model.style_dim, seed0, device)
|
| 87 |
if num_intermediate == -1:
|
| 88 |
out = generate_image(model, z0, psi0, randomize_noise)
|
| 89 |
+
return [out], None
|
| 90 |
|
| 91 |
z1 = generate_z(model.style_dim, seed1, device)
|
| 92 |
vec = z1 - z0
|
|
|
|
| 98 |
for z, psi in zip(zs, psis):
|
| 99 |
out = generate_image(model, z, psi, randomize_noise)
|
| 100 |
res.append(out)
|
| 101 |
+
concatenated = np.hstack(res)
|
| 102 |
+
return res, concatenated
|
| 103 |
|
| 104 |
|
| 105 |
def main():
|
|
|
|
| 129 |
gr.inputs.Number(default=29703, label='Seed 1'),
|
| 130 |
gr.inputs.Number(default=55376, label='Seed 2'),
|
| 131 |
gr.inputs.Slider(-1,
|
| 132 |
+
21,
|
| 133 |
step=1,
|
| 134 |
default=3,
|
| 135 |
label='Number of Intermediate Frames'),
|
|
|
|
| 139 |
0, 2, step=0.05, default=0.7, label='Truncation psi 2'),
|
| 140 |
gr.inputs.Checkbox(default=False, label='Randomize Noise'),
|
| 141 |
],
|
| 142 |
+
[
|
| 143 |
+
gr.outputs.Carousel(gr.outputs.Image(type='numpy'),
|
| 144 |
+
label='Output Images'),
|
| 145 |
+
gr.outputs.Image(type='numpy', label='Concatenated'),
|
| 146 |
+
],
|
| 147 |
examples=examples,
|
| 148 |
title=TITLE,
|
| 149 |
description=DESCRIPTION,
|