Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -21,7 +21,7 @@ if not torch.cuda.is_available():
|
|
| 21 |
|
| 22 |
MAX_SEED = np.iinfo(np.int32).max
|
| 23 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
| 24 |
-
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "
|
| 25 |
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
| 26 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
| 27 |
|
|
@@ -105,7 +105,7 @@ if torch.cuda.is_available():
|
|
| 105 |
print("Using DALL-E 3 Consistency Decoder")
|
| 106 |
pipe.vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
|
| 107 |
|
| 108 |
-
|
| 109 |
pipe.enable_model_cpu_offload()
|
| 110 |
else:
|
| 111 |
pipe.to(device)
|
|
@@ -118,33 +118,35 @@ if torch.cuda.is_available():
|
|
| 118 |
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
|
| 119 |
print("Model Compiled!")
|
| 120 |
|
|
|
|
| 121 |
def save_image(img):
|
| 122 |
unique_name = str(uuid.uuid4()) + ".png"
|
| 123 |
img.save(unique_name)
|
| 124 |
return unique_name
|
| 125 |
|
|
|
|
| 126 |
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
| 127 |
if randomize_seed:
|
| 128 |
seed = random.randint(0, MAX_SEED)
|
| 129 |
return seed
|
| 130 |
|
|
|
|
| 131 |
def generate(
|
| 132 |
prompt: str,
|
| 133 |
negative_prompt: str = "",
|
| 134 |
style: str = DEFAULT_STYLE_NAME,
|
| 135 |
use_negative_prompt: bool = False,
|
| 136 |
-
num_imgs: int = 1,
|
| 137 |
seed: int = 0,
|
| 138 |
width: int = 1024,
|
| 139 |
height: int = 1024,
|
| 140 |
-
|
| 141 |
randomize_seed: bool = False,
|
| 142 |
use_resolution_binning: bool = True,
|
| 143 |
progress=gr.Progress(track_tqdm=True),
|
| 144 |
):
|
| 145 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 146 |
generator = torch.Generator().manual_seed(seed)
|
| 147 |
-
|
| 148 |
if not use_negative_prompt:
|
| 149 |
negative_prompt = None # type: ignore
|
| 150 |
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
|
@@ -290,6 +292,7 @@ with gr.Blocks() as demo:
|
|
| 290 |
negative_prompt,
|
| 291 |
style_selection,
|
| 292 |
use_negative_prompt,
|
|
|
|
| 293 |
seed,
|
| 294 |
width,
|
| 295 |
height,
|
|
@@ -300,6 +303,6 @@ with gr.Blocks() as demo:
|
|
| 300 |
api_name="run",
|
| 301 |
)
|
| 302 |
|
| 303 |
-
|
| 304 |
if __name__ == "__main__":
|
| 305 |
-
demo.queue(max_size=20).launch()
|
|
|
|
|
|
| 21 |
|
| 22 |
MAX_SEED = np.iinfo(np.int32).max
|
| 23 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
| 24 |
+
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
|
| 25 |
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
| 26 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
| 27 |
|
|
|
|
| 105 |
print("Using DALL-E 3 Consistency Decoder")
|
| 106 |
pipe.vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
|
| 107 |
|
| 108 |
+
if ENABLE_CPU_OFFLOAD:
|
| 109 |
pipe.enable_model_cpu_offload()
|
| 110 |
else:
|
| 111 |
pipe.to(device)
|
|
|
|
| 118 |
pipe.transformer = torch.compile(pipe.transformer, mode="reduce-overhead", fullgraph=True)
|
| 119 |
print("Model Compiled!")
|
| 120 |
|
| 121 |
+
|
| 122 |
def save_image(img):
|
| 123 |
unique_name = str(uuid.uuid4()) + ".png"
|
| 124 |
img.save(unique_name)
|
| 125 |
return unique_name
|
| 126 |
|
| 127 |
+
|
| 128 |
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
|
| 129 |
if randomize_seed:
|
| 130 |
seed = random.randint(0, MAX_SEED)
|
| 131 |
return seed
|
| 132 |
|
| 133 |
+
|
| 134 |
def generate(
|
| 135 |
prompt: str,
|
| 136 |
negative_prompt: str = "",
|
| 137 |
style: str = DEFAULT_STYLE_NAME,
|
| 138 |
use_negative_prompt: bool = False,
|
|
|
|
| 139 |
seed: int = 0,
|
| 140 |
width: int = 1024,
|
| 141 |
height: int = 1024,
|
| 142 |
+
inference_steps: int = 4,
|
| 143 |
randomize_seed: bool = False,
|
| 144 |
use_resolution_binning: bool = True,
|
| 145 |
progress=gr.Progress(track_tqdm=True),
|
| 146 |
):
|
| 147 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
| 148 |
generator = torch.Generator().manual_seed(seed)
|
| 149 |
+
|
| 150 |
if not use_negative_prompt:
|
| 151 |
negative_prompt = None # type: ignore
|
| 152 |
prompt, negative_prompt = apply_style(style, prompt, negative_prompt)
|
|
|
|
| 292 |
negative_prompt,
|
| 293 |
style_selection,
|
| 294 |
use_negative_prompt,
|
| 295 |
+
num_imgs,
|
| 296 |
seed,
|
| 297 |
width,
|
| 298 |
height,
|
|
|
|
| 303 |
api_name="run",
|
| 304 |
)
|
| 305 |
|
|
|
|
| 306 |
if __name__ == "__main__":
|
| 307 |
+
demo.queue(max_size=20).launch()
|
| 308 |
+
# demo.queue(max_size=20).launch(server_name="0.0.0.0", server_port=11900, debug=True)
|