Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -125,7 +125,8 @@ models_rbm = core.Models(
|
|
| 125 |
text_model=models.text_model,
|
| 126 |
tokenizer=models.tokenizer,
|
| 127 |
generator=generator_rbm,
|
| 128 |
-
previewer=models.previewer
|
|
|
|
| 129 |
)
|
| 130 |
|
| 131 |
def reset_inference_state():
|
|
@@ -160,8 +161,10 @@ def reset_inference_state():
|
|
| 160 |
|
| 161 |
models_b.generator.to("cpu") # Keep Stage B generator on CPU for now
|
| 162 |
|
| 163 |
-
# Ensure effnet
|
| 164 |
models_rbm.effnet.to(device)
|
|
|
|
|
|
|
| 165 |
|
| 166 |
# Reset model states
|
| 167 |
models_rbm.generator.eval().requires_grad_(False)
|
|
@@ -204,8 +207,11 @@ def infer(style_description, ref_style_file, caption):
|
|
| 204 |
|
| 205 |
models_b.generator.to(device)
|
| 206 |
|
| 207 |
-
# Ensure effnet
|
| 208 |
models_rbm.effnet.to(device)
|
|
|
|
|
|
|
|
|
|
| 209 |
x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
|
| 210 |
|
| 211 |
conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
|
|
|
|
| 125 |
text_model=models.text_model,
|
| 126 |
tokenizer=models.tokenizer,
|
| 127 |
generator=generator_rbm,
|
| 128 |
+
previewer=models.previewer,
|
| 129 |
+
image_model=models.image_model # Add this line
|
| 130 |
)
|
| 131 |
|
| 132 |
def reset_inference_state():
|
|
|
|
| 161 |
|
| 162 |
models_b.generator.to("cpu") # Keep Stage B generator on CPU for now
|
| 163 |
|
| 164 |
+
# Ensure effnet and image_model are on the correct device
|
| 165 |
models_rbm.effnet.to(device)
|
| 166 |
+
if models_rbm.image_model is not None:
|
| 167 |
+
models_rbm.image_model.to(device)
|
| 168 |
|
| 169 |
# Reset model states
|
| 170 |
models_rbm.generator.eval().requires_grad_(False)
|
|
|
|
| 207 |
|
| 208 |
models_b.generator.to(device)
|
| 209 |
|
| 210 |
+
# Ensure effnet and image_model are on the correct device
|
| 211 |
models_rbm.effnet.to(device)
|
| 212 |
+
if models_rbm.image_model is not None:
|
| 213 |
+
models_rbm.image_model.to(device)
|
| 214 |
+
|
| 215 |
x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
|
| 216 |
|
| 217 |
conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
|