Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,7 +14,7 @@ from src.flux.xflux_pipeline import XFluxSampler
|
|
| 14 |
args = OmegaConf.load("inference_configs/inference.yaml")
|
| 15 |
# is_schnell = args.model_name == "flux-schnell"
|
| 16 |
# sampler = None
|
| 17 |
-
|
| 18 |
# dtype = torch.bfloat16
|
| 19 |
# dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
|
| 20 |
# vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
|
|
@@ -91,6 +91,7 @@ def generate(image: Image.Image, edit_prompt: str):
|
|
| 91 |
img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
|
| 92 |
|
| 93 |
result = sampler(
|
|
|
|
| 94 |
prompt=edit_prompt,
|
| 95 |
width=args.sample_width,
|
| 96 |
height=args.sample_height,
|
|
|
|
| 14 |
args = OmegaConf.load("inference_configs/inference.yaml")
|
| 15 |
# is_schnell = args.model_name == "flux-schnell"
|
| 16 |
# sampler = None
|
| 17 |
+
device = torch.device("cuda")
|
| 18 |
# dtype = torch.bfloat16
|
| 19 |
# dit = load_flow_model2(args.model_name, device="cpu").to(device, dtype=dtype)
|
| 20 |
# vae = load_ae(args.model_name, device="cpu").to(device, dtype=dtype)
|
|
|
|
| 91 |
img = img.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)
|
| 92 |
|
| 93 |
result = sampler(
|
| 94 |
+
device='cuda',
|
| 95 |
prompt=edit_prompt,
|
| 96 |
width=args.sample_width,
|
| 97 |
height=args.sample_height,
|