Spaces:
Paused
Paused
| import os | |
| import time | |
| import gradio as gr | |
| from gradio_imageslider import ImageSlider | |
| from comfydeploy import ComfyDeploy | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
| import base64 | |
| import glob | |
| from dotenv import load_dotenv | |
# Pull variables from a local .env file (if any) into the process environment.
load_dotenv()

API_KEY = os.getenv("COMFY_DEPLOY_API_KEY")
DEPLOYMENT_ID = os.getenv("COMFY_DEPLOYMENT_ID")

# Fail fast at import time: both the API key and the deployment id are
# required for every request the app makes.
if not (API_KEY and DEPLOYMENT_ID):
    raise ValueError(
        "Please set COMFY_DEPLOY_API_KEY and COMFY_DEPLOYMENT_ID in your environment variables"
    )

# Shared ComfyDeploy SDK client, authenticated with the bearer token above.
client = ComfyDeploy(bearer_auth=API_KEY)
def get_base64_from_image(image: Image.Image) -> str:
    """Encode *image* as PNG and return the bytes as a base64 text string.

    Args:
        image: The PIL image to serialize.

    Returns:
        The base64 encoding of the PNG bytes, decoded as UTF-8 text.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
def get_profile(profile) -> dict:
    """Copy the ``username``, ``profile`` and ``name`` attributes of *profile* into a dict.

    Args:
        profile: Any object exposing those three attributes (in this app, the
            logged-in user's OAuth profile passed to ``process``).

    Returns:
        A dict keyed by the attribute names, in that order.
    """
    wanted = ("username", "profile", "name")
    return {attribute: getattr(profile, attribute) for attribute in wanted}
async def process(
    image: Image.Image | None = None,
    profile: gr.OAuthProfile | None = None,
    progress: gr.Progress = gr.Progress(),
    *slider_values,
) -> tuple[Image.Image, Image.Image] | None:
    """Run the ComfyDeploy workflow on *image* for a logged-in user.

    Gradio supplies ``profile`` (from the login button) and ``progress``.
    Any extra positional arguments are taken as the values of the module-level
    ``params`` sliders, in that dict's insertion order. Sliders without a
    supplied value fall back to the component's ``value`` attribute.

    Args:
        image: The uploaded image, or ``None`` if nothing was uploaded.
        profile: The OAuth profile of the logged-in user, or ``None``.
        progress: Gradio progress tracker; accepts a fraction in [0, 1].
        *slider_values: Optional values for the ``params`` sliders.

    Returns:
        ``(original, processed)`` for the comparison slider, where
        ``processed`` is ``None`` if the run failed. Returns ``None`` when no
        image was uploaded or the user is not logged in.
    """
    if image is None:
        gr.Info("Please upload an image ")
        return None
    if profile is None:
        gr.Info("Please log in to process the image.")
        return None

    user_data = get_profile(profile)
    print("--------- RUN ----------")
    print(user_data)

    progress(0, desc="Preparing inputs...")
    image_base64 = get_base64_from_image(image)

    # Send numeric slider values. str() of the gr.Slider component object
    # itself would not be the slider's value.
    supplied = dict(zip(params, slider_values))
    inputs = {
        "image": f"data:image/png;base64,{image_base64}",
        **{
            key: str(supplied.get(key, component.value))
            for key, component in params.items()
        },
    }

    output = await process_image(inputs, progress)
    # gr.Progress takes a fraction in [0, 1], not a percentage.
    progress(1.0, desc="Processing completed")
    if output is None:
        gr.Warning("Processing failed. Please try again.")
    return image, output
# Run statuses treated as terminal failures. "failed" is the status the
# polling loop has always checked; "timeout" and "cancelled" are assumed
# additional ComfyDeploy terminal statuses.
# NOTE(review): confirm against the ComfyDeploy run-status enum.
_FAILED_RUN_STATUSES = {"failed", "timeout", "cancelled"}


def _download_first_image(outputs) -> Image.Image | None:
    """Download and open the first image URL found in a run's outputs, or return None."""
    for output in outputs or []:
        if output.data and output.data.images:
            image_url: str = output.data.images[0].url
            response = requests.get(image_url, timeout=60)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
    return None


async def process_image(
    inputs: dict,
    progress: gr.Progress,
    poll_interval: float = 1.0,
    timeout: float | None = None,
) -> Image.Image | None:
    """Start a run on the configured deployment and poll it until it finishes.

    The blocking SDK and HTTP calls run in worker threads, and the wait between
    polls uses ``asyncio.sleep``, so the event loop stays responsive.

    Args:
        inputs: Input mapping sent to the deployment.
        progress: Gradio progress tracker, updated with the run's progress
            and live status on every poll.
        poll_interval: Seconds to wait between status polls.
        timeout: Optional overall limit in seconds. ``None`` waits until the
            run reaches a terminal status.

    Returns:
        The first output image of a successful run. Returns ``None`` if the
        run could not be created, failed, timed out, produced no image, or
        any error occurred (errors are printed, not raised).
    """
    import asyncio

    try:
        result = await asyncio.to_thread(
            client.run.create,
            request={"deployment_id": DEPLOYMENT_ID, "inputs": inputs},
        )
        if not (result and result.object):
            print("Failed to create run")
            return None

        run_id: str = result.object.run_id
        progress(0, desc="Starting processing...")

        loop = asyncio.get_running_loop()
        deadline = None if timeout is None else loop.time() + timeout
        while True:
            run_result = await asyncio.to_thread(client.run.get, run_id=run_id)
            run = run_result.object
            # An empty response still waits below instead of busy-looping.
            if run:
                status = run.live_status or "Cold starting..."
                progress(run.progress or 0, desc=f"Status: {status}")
                if run.status == "success":
                    processed = await asyncio.to_thread(
                        _download_first_image, run.outputs
                    )
                    if processed is None:
                        print("Run succeeded but produced no image output")
                    return processed
                if run.status in _FAILED_RUN_STATUSES:
                    print("Processing failed")
                    return None
            if deadline is not None and loop.time() >= deadline:
                print(f"Processing timed out after {timeout} seconds")
                return None
            await asyncio.sleep(poll_interval)
    except Exception as e:
        # Top-level boundary for the UI: report and signal failure with None.
        print(f"Error: {e}")
        return None
def load_preset_images(pattern: str = "images/inputs/*"):
    """Open every readable image matching *pattern* whose format is supported.

    Each file is opened once. Paths that cannot be opened as images (stray
    non-image files, directories) are skipped. Images in an unsupported
    format are closed and skipped.

    Args:
        pattern: ``glob`` pattern selecting candidate files.

    Returns:
        A list of ``{"name": path, "image": PIL.Image}`` dicts, sorted by path.
    """
    allowed_formats = {"png", "jpg", "jpeg", "gif", "bmp", "webp"}
    presets = []
    for path in sorted(glob.glob(pattern)):
        try:
            image = Image.open(path)
        except OSError:
            # PIL's UnidentifiedImageError and IsADirectoryError both derive from OSError.
            continue
        # Image.format can be None, so guard before lower-casing.
        if (image.format or "").lower() in allowed_formats:
            presets.append({"name": path, "image": image})
        else:
            image.close()
    return presets
def build_example(input_image_path):
    """Build one ``gr.Examples`` row for *input_image_path*.

    The row holds, in order:
    - the input path;
    - the default values for denoise, steps, tile_size, downscale, upscale,
      color_match, controlnet_tile_end and controlnet_tile_strength;
    - an ``(input, output)`` pair for the comparison slider.

    The output image is expected to have the same file name in a sibling
    ``outputs`` directory.

    Args:
        input_image_path: Path of an example input image.

    Returns:
        A 10-element list matching the example inputs of the UI.
    """
    directory, filename = os.path.split(input_image_path)
    parent, leaf = os.path.split(directory)
    if leaf == "inputs":
        # Swap only the directory name, so "inputs" inside the file name survives.
        output_image_path = os.path.join(parent, "outputs", filename)
    else:
        output_image_path = input_image_path.replace("inputs", "outputs")
    return [
        input_image_path,
        0.4,  # denoise
        10,  # steps
        1024,  # tile_size
        1,  # downscale
        4,  # upscale
        0,  # color_match
        1,  # controlnet_tile_end
        0.7,  # controlnet_tile_strength
        (input_image_path, output_image_path),
    ]
def serialize_params(params: dict) -> dict:
    """Turn a mapping of components into plain ``{"value", "label"}`` dicts.

    Args:
        params: Mapping of name -> object exposing ``value`` and ``label``.

    Returns:
        A new dict with the same keys, in the same order.
    """
    serialized = {}
    for name, component in params.items():
        serialized[name] = {"value": component.value, "label": component.label}
    return serialized
# Gradio UI layout:
# - left column: input image plus an accordion of advanced slider parameters;
# - right column: original/processed comparison slider, login button and Run button.
with gr.Blocks() as demo:
    gr.HTML("""
    <div style="display: flex; justify-content: center; text-align:center; flex-direction: column;">
        <h1 style="color: #333;">Creative Image Upscaler</h1>
        <div style="max-width: 800px; margin: 0 auto;">
            <p style="font-size: 16px;">Upload an image and adjust the parameters to enhance your image.</p>
            <p style="font-size: 16px;">Click on the <b>"Run"</b> button to process the image and compare the original and processed images using the slider.</p>
            <p style="font-size: 16px;">⚠️ Note that the images are compressed to reduce the workloads of the demo.</p>
        </div>
    </div>
    """)
    with gr.Row(equal_height=False):
        with gr.Column():
            # TODO: fix the input image overflowing its container.
            input_image = gr.Image(type="pil", label="Input Image", interactive=True)
            with gr.Accordion("Advanced parameters", open=False):
                # Keys are the deployment input names; insertion order matters
                # because the example rows list values in this order.
                params = {
                    "denoise": gr.Slider(0, 1, value=0.4, label="Denoise"),
                    "steps": gr.Slider(1, 25, value=10, label="Steps"),
                    "tile_size": gr.Slider(256, 2048, value=1024, label="Tile Size"),
                    "downscale": gr.Slider(1, 4, value=1, label="Downscale"),
                    "upscale": gr.Slider(1, 4, value=4, label="Upscale"),
                    "color_match": gr.Slider(0, 1, value=0, label="Color Match"),
                    "controlnet_tile_end": gr.Slider(
                        0, 1, value=1, label="ControlNet Tile End"
                    ),
                    "controlnet_tile_strength": gr.Slider(
                        0, 1, value=0.7, label="ControlNet Tile Strength"
                    ),
                }
        with gr.Column():
            image_slider = ImageSlider(
                label="Compare Original and Processed", interactive=False
            )
            login_button = gr.LoginButton(scale=8)
            process_btn = gr.Button("Run", variant="primary", size="lg")

    # Disable the button while a run is in flight. `.then` runs each step even
    # if the previous one raised, so the button is always re-enabled. The
    # handlers take no arguments because their `inputs` list is empty.
    # NOTE(review): only input_image is passed to `process`; the slider values
    # in `params` are not sent as event inputs here.
    process_btn.click(
        fn=lambda: gr.update(interactive=False, value="Processing..."),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    ).then(
        fn=process,
        inputs=[
            input_image,
        ],
        outputs=[image_slider],
        api_name=False,
    ).then(
        fn=lambda: gr.update(interactive=True, value="Run"),
        inputs=[],
        outputs=[process_btn],
        api_name=False,
    )

    # Only real image files become examples, in a stable (sorted) order, and
    # gr.Examples is skipped entirely when there are none.
    example_paths = sorted(
        path
        for path in glob.glob("images/inputs/*")
        if path.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp"))
    )
    examples = [build_example(path) for path in example_paths]
    if examples:
        gr.Examples(
            examples=examples,
            inputs=[input_image, *params.values(), image_slider],
        )
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server with debug output
    # and a share link requested.
    app = demo.queue()
    app.launch(debug=True, share=True)