Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -13,6 +13,8 @@
|
|
| 13 |
# limitations under the License.
|
| 14 |
|
| 15 |
import gradio as gr
|
|
|
|
|
|
|
| 16 |
import argparse
|
| 17 |
import inspect
|
| 18 |
import os
|
|
@@ -1616,22 +1618,20 @@ if __name__ == "__main__":
|
|
| 1616 |
parser.add_argument('--experiment_name', default="AccDiffusion")
|
| 1617 |
|
| 1618 |
args = parser.parse_args()
|
| 1619 |
-
|
|
|
|
|
|
|
|
|
|
| 1620 |
# GRADIO MODE
|
| 1621 |
|
| 1622 |
-
|
|
|
|
| 1623 |
set_seed(args.seed)
|
| 1624 |
width,height = list(map(int, args.resolution.split(',')))
|
| 1625 |
-
pipe = AccDiffusionSDXLPipeline.from_pretrained(args.model_ckpt, torch_dtype=torch.float16).to("cuda")
|
| 1626 |
-
generator = torch.Generator(device='cuda')
|
| 1627 |
-
generator = generator.manual_seed(args.seed)
|
| 1628 |
cross_attention_kwargs = {"edit_type": "visualize",
|
| 1629 |
"n_self_replace": 0.4,
|
| 1630 |
"n_cross_replace": {"default_": 1.0, "confetti": 0.8},
|
| 1631 |
}
|
| 1632 |
-
|
| 1633 |
-
|
| 1634 |
-
|
| 1635 |
seed = args.seed
|
| 1636 |
generator = generator.manual_seed(seed)
|
| 1637 |
|
|
@@ -1644,8 +1644,8 @@ if __name__ == "__main__":
|
|
| 1644 |
view_batch_size=args.view_batch_size,
|
| 1645 |
stride=args.stride,
|
| 1646 |
cross_attention_kwargs=cross_attention_kwargs,
|
| 1647 |
-
num_inference_steps=
|
| 1648 |
-
guidance_scale =
|
| 1649 |
multi_guidance_scale = args.multi_guidance_scale,
|
| 1650 |
cosine_scale_1=args.cosine_scale_1,
|
| 1651 |
cosine_scale_2=args.cosine_scale_2,
|
|
@@ -1680,7 +1680,7 @@ if __name__ == "__main__":
|
|
| 1680 |
<img src='https://img.shields.io/badge/Project-Page-blue'>
|
| 1681 |
</a>
|
| 1682 |
<a href='https://github.com/lzhxmu/AccDiffusion'>
|
| 1683 |
-
<img src='https://img.shields.io/badge/Code-blue'>
|
| 1684 |
</a>
|
| 1685 |
<a href='https://arxiv.org/abs/2407.10738v1'>
|
| 1686 |
<img src='https://img.shields.io/badge/Paper-Arxiv-red'>
|
|
@@ -1688,6 +1688,9 @@ if __name__ == "__main__":
|
|
| 1688 |
</div>
|
| 1689 |
""")
|
| 1690 |
prompt = gr.Textbox(label="Prompt")
|
|
|
|
|
|
|
|
|
|
| 1691 |
submit_btn = gr.Button("Submit")
|
| 1692 |
output_images = gr.Image(format="png")
|
| 1693 |
gr.Examples(
|
|
@@ -1700,7 +1703,7 @@ if __name__ == "__main__":
|
|
| 1700 |
)
|
| 1701 |
submit_btn.click(
|
| 1702 |
fn = infer,
|
| 1703 |
-
inputs = [prompt],
|
| 1704 |
outputs = [output_images],
|
| 1705 |
show_api=False
|
| 1706 |
)
|
|
|
|
| 13 |
# limitations under the License.
|
| 14 |
|
| 15 |
import gradio as gr
|
| 16 |
+
import spaces
|
| 17 |
+
|
| 18 |
import argparse
|
| 19 |
import inspect
|
| 20 |
import os
|
|
|
|
| 1618 |
parser.add_argument('--experiment_name', default="AccDiffusion")
|
| 1619 |
|
| 1620 |
args = parser.parse_args()
|
| 1621 |
+
|
| 1622 |
+
pipe = AccDiffusionSDXLPipeline.from_pretrained(args.model_ckpt, torch_dtype=torch.float16).to("cuda")
|
| 1623 |
+
generator = torch.Generator(device='cuda')
|
| 1624 |
+
|
| 1625 |
# GRADIO MODE
|
| 1626 |
|
| 1627 |
+
@spaces.GPU()
|
| 1628 |
+
def infer(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
|
| 1629 |
set_seed(args.seed)
|
| 1630 |
width,height = list(map(int, args.resolution.split(',')))
|
|
|
|
|
|
|
|
|
|
| 1631 |
cross_attention_kwargs = {"edit_type": "visualize",
|
| 1632 |
"n_self_replace": 0.4,
|
| 1633 |
"n_cross_replace": {"default_": 1.0, "confetti": 0.8},
|
| 1634 |
}
|
|
|
|
|
|
|
|
|
|
| 1635 |
seed = args.seed
|
| 1636 |
generator = generator.manual_seed(seed)
|
| 1637 |
|
|
|
|
| 1644 |
view_batch_size=args.view_batch_size,
|
| 1645 |
stride=args.stride,
|
| 1646 |
cross_attention_kwargs=cross_attention_kwargs,
|
| 1647 |
+
num_inference_steps=num_inference_steps,
|
| 1648 |
+
guidance_scale = guidance_scale,
|
| 1649 |
multi_guidance_scale = args.multi_guidance_scale,
|
| 1650 |
cosine_scale_1=args.cosine_scale_1,
|
| 1651 |
cosine_scale_2=args.cosine_scale_2,
|
|
|
|
| 1680 |
<img src='https://img.shields.io/badge/Project-Page-blue'>
|
| 1681 |
</a>
|
| 1682 |
<a href='https://github.com/lzhxmu/AccDiffusion'>
|
| 1683 |
+
<img src='https://img.shields.io/badge/Code-github-blue'>
|
| 1684 |
</a>
|
| 1685 |
<a href='https://arxiv.org/abs/2407.10738v1'>
|
| 1686 |
<img src='https://img.shields.io/badge/Paper-Arxiv-red'>
|
|
|
|
| 1688 |
</div>
|
| 1689 |
""")
|
| 1690 |
prompt = gr.Textbox(label="Prompt")
|
| 1691 |
+
with gr.Accordion("Advanced settings", open=False):
|
| 1692 |
+
num_inference_steps = gr.Slider(label="Inference Steps", minimum=2, maximum=50, step=1, value=50)
|
| 1693 |
+
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=510, step=0.1, value=7.5)
|
| 1694 |
submit_btn = gr.Button("Submit")
|
| 1695 |
output_images = gr.Image(format="png")
|
| 1696 |
gr.Examples(
|
|
|
|
| 1703 |
)
|
| 1704 |
submit_btn.click(
|
| 1705 |
fn = infer,
|
| 1706 |
+
inputs = [prompt, num_inference_steps, guidance_scale],
|
| 1707 |
outputs = [output_images],
|
| 1708 |
show_api=False
|
| 1709 |
)
|