Modified s_dict generation
Files changed:
- app.py (+3, -2)
- model/sg2_model.py (+30, -1)
- styleclip/styleclip_global.py (+5, -2)
app.py

@@ -368,8 +368,9 @@ with blocks:
         vid_button = gr.Button("Generate Video")
         loop_styles = gr.inputs.Checkbox(default=True, label="Loop video back to the initial style?")
         with gr.Row():
-            gr.
-
+            with gr.Column():
+                gr.Markdown("Warning: Video generation requires the synthesis of hundreds of frames and is expected to take several minutes.")
+                gr.Markdown("To reduce queue times, we significantly reduced the number of video frames. Using more than 3 styles will further reduce the frames per style, leading to quicker transitions. For better control, we recommend cloning the gradio app, adjusting `num_alphas` in `generate_videos`, and running the code locally.")
         with gr.Column():
             vid_output = gr.outputs.Video(label="Output Video")

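The second note points at the frame-count knob. A minimal sketch of what that adjustment controls, assuming `generate_videos` renders each style transition from `num_alphas` evenly spaced interpolation weights (the helper name below is hypothetical and the exact signature in this Space may differ):

    import numpy as np

    # Hypothetical: each transition between consecutive styles is rendered by
    # sweeping a blending weight alpha from 0 to 1; num_alphas sets the number
    # of frames per transition, so raising it yields smoother, longer videos.
    def transition_alphas(num_alphas=50):
        return np.linspace(0, 1, num_alphas)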
model/sg2_model.py

@@ -526,7 +526,36 @@ class Generator(nn.Module):
         if not input_is_latent:
             styles = [self.style(s) for s in styles]

-        s_codes =
+        s_codes = {  # const block
+                   self.modulation_layers[0]: self.modulation_layers[0](styles[0]),
+                   self.modulation_layers[1]: self.modulation_layers[1](styles[1]),
+                   # conv layers
+                   self.modulation_layers[2]: self.modulation_layers[2](styles[2]),
+                   self.modulation_layers[3]: self.modulation_layers[3](styles[3]),
+                   self.modulation_layers[5]: self.modulation_layers[5](styles[4]),
+                   self.modulation_layers[6]: self.modulation_layers[6](styles[5]),
+                   self.modulation_layers[8]: self.modulation_layers[8](styles[6]),
+                   self.modulation_layers[9]: self.modulation_layers[9](styles[7]),
+                   self.modulation_layers[11]: self.modulation_layers[11](styles[8]),
+                   self.modulation_layers[12]: self.modulation_layers[12](styles[9]),
+                   self.modulation_layers[14]: self.modulation_layers[14](styles[10]),
+                   self.modulation_layers[15]: self.modulation_layers[15](styles[11]),
+                   self.modulation_layers[17]: self.modulation_layers[17](styles[12]),
+                   self.modulation_layers[18]: self.modulation_layers[18](styles[13]),
+                   self.modulation_layers[20]: self.modulation_layers[20](styles[14]),
+                   self.modulation_layers[21]: self.modulation_layers[21](styles[15]),
+                   self.modulation_layers[23]: self.modulation_layers[23](styles[16]),
+                   self.modulation_layers[24]: self.modulation_layers[24](styles[17]),
+                   # toRGB layers
+                   self.modulation_layers[4]: self.modulation_layers[4](styles[3]),
+                   self.modulation_layers[7]: self.modulation_layers[7](styles[5]),
+                   self.modulation_layers[10]: self.modulation_layers[10](styles[7]),
+                   self.modulation_layers[13]: self.modulation_layers[13](styles[9]),
+                   self.modulation_layers[16]: self.modulation_layers[16](styles[11]),
+                   self.modulation_layers[19]: self.modulation_layers[19](styles[13]),
+                   self.modulation_layers[22]: self.modulation_layers[22](styles[15]),
+                   self.modulation_layers[25]: self.modulation_layers[25](styles[17]),
+                   }

         return s_codes

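The explicit dictionary spells out a regular pattern: layers 0 and 1 (the 4x4 const block) consume styles 0 and 1; each subsequent resolution contributes two conv layers, each consuming the next style, followed by a toRGB layer that reuses the style of the conv layer immediately before it. A loop-based sketch of the same mapping, assuming `modulation_layers` keeps that ordering (conv, conv, toRGB per resolution, so indices 4, 7, 10, ... are toRGB):

    def build_s_codes(modulation_layers, styles):
        s_codes = {}
        # const block: layers 0 and 1 consume styles 0 and 1
        s_codes[modulation_layers[0]] = modulation_layers[0](styles[0])
        s_codes[modulation_layers[1]] = modulation_layers[1](styles[1])
        style_idx = 2
        for layer_idx in range(2, 26):
            layer = modulation_layers[layer_idx]
            if (layer_idx - 1) % 3 == 0:
                # toRGB layer: reuses the preceding conv layer's style
                s_codes[layer] = layer(styles[style_idx - 1])
            else:
                # conv layer: consumes the next unused style vector
                s_codes[layer] = layer(styles[style_idx])
                style_idx += 1
        return s_codes

The commit's explicit version trades this loop for a literal layer-to-style table, which makes the correspondence easier to audit.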
styleclip/styleclip_global.py

@@ -120,7 +120,10 @@ def get_direction(neutral_class, target_class, beta, di, clip_model=None):

     dt = class_weights[:, 1] - class_weights[:, 0]
     dt = dt / dt.norm()
-
+
+    dt = dt.float()
+    di = di.float()
+
     relevance = di @ dt
     mask = relevance.abs() > beta
     direction = relevance * mask

@@ -144,7 +147,7 @@ def style_tensor_to_style_dict(style_tensor, refernce_generator):
 def style_dict_to_style_tensor(style_dict, reference_generator):
     style_layers = reference_generator.modulation_layers

-    style_tensor = torch.zeros(
+    style_tensor = torch.zeros(size=(1, 9088))
     for layer in style_dict:
         layer_idx = style_layers.index(layer)
         style_tensor[:, FFHQ_CODE_INDICES[layer_idx][0]:FFHQ_CODE_INDICES[layer_idx][1]] = style_dict[layer]
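The added casts guard the matmul in the first hunk: CLIP models loaded on GPU typically keep their weights in float16, so `dt` (built from CLIP text embeddings) can arrive in half precision while `di` does not, and `di @ dt` then fails or silently downcasts. A toy illustration of the fixed computation; the shapes (9088 S-space channels, 512-dim CLIP embeddings) and the beta value are assumptions for the example:

    import torch

    # As if produced by an fp16 CLIP model: half-precision inputs
    di = torch.randn(9088, 512, dtype=torch.float16)  # per-channel S-space directions
    dt = torch.randn(512, dtype=torch.float16)        # text direction (target - neutral)

    dt, di = dt.float(), di.float()   # the commit's fix: compute in float32
    dt = dt / dt.norm()               # normalize the text direction
    relevance = di @ dt               # relevance of each style channel to the edit
    beta = 0.13                       # disentanglement threshold (example value)
    direction = relevance * (relevance.abs() > beta)  # suppress channels below beta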
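In the second hunk, the hard-coded 9088 is the total dimensionality of StyleGAN2's S space for the FFHQ generator: the style-vector widths of all 26 modulation layers laid end to end. A hypothetical reconstruction of how `FFHQ_CODE_INDICES` could be derived, assuming the standard FFHQ channel configuration and the same layer ordering as `modulation_layers` (the actual constant in the repo may be defined differently):

    # Per-layer style dimensions: conv1 and to_rgb1 for the 4x4 const block,
    # then (conv_up, conv, toRGB) for each resolution from 8x8 to 1024x1024.
    CHANNELS = [
        512, 512,        # 4x4 const block: conv1, to_rgb1
        512, 512, 512,   # 8x8
        512, 512, 512,   # 16x16
        512, 512, 512,   # 32x32
        512, 512, 512,   # 64x64
        512, 256, 256,   # 128x128
        256, 128, 128,   # 256x256
        128, 64, 64,     # 512x512
        64, 32, 32,      # 1024x1024
    ]

    # (start, end) slice of each layer's style vector in the flat S-space tensor
    FFHQ_CODE_INDICES, start = [], 0
    for c in CHANNELS:
        FFHQ_CODE_INDICES.append((start, start + c))
        start += c
    assert start == 9088  # matches torch.zeros(size=(1, 9088)) above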