Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import requests | |
| import tempfile | |
| import gradio as gr | |
| from PIL import Image, ImageDraw | |
| from config import config, theme | |
| from public.data.images.loras.flux1 import loras as flux1_loras | |
| # os.makedirs(config.get("HF_HOME"), exist_ok=True) | |
| # UI | |
# Stylesheets are presumably applied in list order (later rules win at equal
# specificity), so use a stable alphabetical order instead of the arbitrary
# os.listdir order, and ignore anything that is not a .css file.
_CSS_DIR = "static/css"

with gr.Blocks(
    theme=theme,
    fill_width=True,
    css_paths=[
        os.path.join(_CSS_DIR, name)
        for name in sorted(os.listdir(_CSS_DIR))
        if name.endswith(".css")
    ],
) as demo:
    # Server-side per-session state; not referenced elsewhere in this file yet.
    data_state = gr.State()
    # Client-side state kept by the browser (gr.BrowserState).
    local_state = gr.BrowserState(
        {
            "selected_loras": [],
        }
    )
# Sampler names for the "Scheduler" dropdown; the first entry is the default.
SCHEDULER_CHOICES = [
    "Euler a",
    "Euler",
    "LMS",
    "Heun",
    "DPM++ 2",
    "DPM++ 2 a",
    "DPM++ SDE",
    "DPM++ SDE Karras",
    "DDIM",
    "PLMS",
]

with gr.Row():
    # Left column: generation settings.
    with gr.Column(scale=1):
        gr.Label("AllFlux", show_label=False)
        with gr.Accordion("Settings", open=True):
            # Output resolution in pixels, adjustable in 64 px steps.
            with gr.Group():
                height_slider = gr.Slider(
                    label="Height",
                    minimum=64,
                    maximum=2048,
                    step=64,
                    value=1024,
                    interactive=True,
                )
                width_slider = gr.Slider(
                    label="Width",
                    minimum=64,
                    maximum=2048,
                    step=64,
                    value=1024,
                    interactive=True,
                )
            with gr.Group():
                num_images_slider = gr.Slider(
                    label="Number of Images",
                    minimum=1,
                    maximum=4,
                    step=1,
                    value=1,
                    interactive=True,
                )
                toggles = gr.CheckboxGroup(
                    choices=["Realtime", "Randomize Seed"],
                    value=["Randomize Seed"],
                    interactive=True,
                    show_label=False,
                )
        with gr.Accordion("Advanced", open=False):
            num_steps_slider = gr.Slider(
                label="Steps", minimum=1, maximum=100, step=1, value=20, interactive=True
            )
            guidance_scale_slider = gr.Slider(
                label="Guidance Scale",
                minimum=1,
                maximum=10,
                step=0.1,
                value=3.5,
                interactive=True,
            )
            # Upper bound is 2**32 - 1.
            seed_slider = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=4294967295,
                step=1,
                value=42,
                interactive=True,
            )
            # step=2 over [2, 4] leaves exactly two choices: 2 or 4.
            upscale_slider = gr.Slider(
                label="Upscale", minimum=2, maximum=4, step=2, value=2, interactive=True
            )
            scheduler_dropdown = gr.Dropdown(
                label="Scheduler",
                choices=SCHEDULER_CHOICES,
                value=SCHEDULER_CHOICES[0],
                interactive=True,
            )
        gr.LoginButton()
        gr.Markdown(
            """
Yurrrrrrrrrrrr, WIP
"""
        )
# Centre column: prompt entry, latest outputs and output history.
with gr.Column(scale=3):
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                placeholder="Enter your prompt here...",
                lines=3,
                show_label=False,
            )
        with gr.Row():
            with gr.Column(scale=3):
                submit_btn = gr.Button("Submit")
            with gr.Column(scale=1):
                ai_improve_btn = gr.Button("💡", link="#improve-prompt")

    # Results of the most recent run plus actions that operate on them.
    with gr.Group():
        output_gallery = gr.Gallery(label="Outputs", height=500, interactive=False)
        with gr.Row():
            upscale_selected_btn = gr.Button("Upscale Selected", size="sm")
            upscale_all_btn = gr.Button("Upscale All", size="sm")
            create_similar_btn = gr.Button("Create Similar", size="sm")

    with gr.Accordion("Output History", open=False):
        with gr.Group():
            output_history_gallery = gr.Gallery(
                height=500, interactive=False, show_label=False
            )
            with gr.Row():
                clear_history_btn = gr.Button("Clear All", size="sm")
                download_history_btn = gr.Button("Download All", size="sm")
with gr.Accordion("Image Playground", open=True):

    def show_info(content: str | None = None):
        """Add a "Show Info" checkbox that reveals a help note when ticked.

        Must be called inside a Blocks layout context; the checkbox and the
        note appear where the call is made.

        Args:
            content: Optional extra text appended to the help note.
        """
        info_checkbox = gr.Checkbox(value=False, label="Show Info", interactive=True)

        # An inner function that is merely defined never runs. gr.render
        # re-runs it whenever the checkbox changes and displays the
        # components it creates, so the note appears only while ticked.
        @gr.render(inputs=info_checkbox)
        def _render_info(show: bool):
            if show:
                # `or ''` keeps a literal "None" out of the note.
                gr.Markdown(
                    f"""Sup, need some help here, please check the community tab. {content or ''}"""
                )
with gr.Tabs():
    with gr.Tab("Img 2 Img"):
        with gr.Group():
            img2img_img = gr.Image(interactive=True, show_label=False)
            # Presumably the img2img denoising strength (0-1) — confirm
            # against the generation handler.
            img2img_strength_slider = gr.Slider(
                label="Strength",
                minimum=0,
                maximum=1,
                step=0.1,
                value=1.0,
                interactive=True,
            )
        show_info()
with gr.Tab("Inpaint"):
    with gr.Group():
        inpaint_img = gr.ImageMask(show_label=False, interactive=True, type="pil")
        generate_mask_btn = gr.Button("Remove Background", size="sm")
        use_fill_pipe_inpaint = gr.Checkbox(
            value=True,
            label="Use Fill Pipeline 🧪",
            interactive=True,
        )
    show_info()

    def _fit_inpaint_editor_height(editor_value):
        """Resize the mask editor to the uploaded image's height plus 96 px.

        ``editor_value`` is the ImageMask value dict (PIL images since
        ``type="pil"``). The background is tried first, then the composite,
        then any layer, so an upload with an empty ``layers`` list no longer
        raises IndexError. Returns a no-op update when no image is present.
        """
        if not editor_value:
            return gr.update()
        candidates = [editor_value.get("background"), editor_value.get("composite")]
        candidates.extend(editor_value.get("layers") or [])
        image = next((c for c in candidates if c is not None), None)
        if image is None:
            return gr.update()
        # Extra 96 px, presumably room for the editor toolbar.
        return gr.update(height=image.height + 96)

    inpaint_img.upload(
        fn=_fit_inpaint_editor_height,
        inputs=inpaint_img,
        outputs=inpaint_img,
    )
with gr.Tab("Outpaint"):
    outpaint_img = gr.Image(type="pil", interactive=True, show_label=False)
    with gr.Row(equal_height=True):
        with gr.Column(scale=3):
            # NOTE(review): no handler in this file reads this radio; the mask
            # preview takes its canvas size from the Width/Height sliders.
            ratio_9_16 = gr.Radio(
                label="Image Ratio",
                choices=["9:16", "16:9", "1:1", "Height & Width"],
                value="9:16",
                interactive=True,
                container=True,
            )
        with gr.Column(scale=1):
            # Where the source image is placed on the outpaint canvas.
            mask_position = gr.Dropdown(
                label="Alignment",
                choices=["Middle", "Left", "Right", "Top", "Bottom"],
                value="Middle",
                interactive=True,
            )
    with gr.Group():
        # Size of the source image relative to its fitted size on the canvas.
        resize_options = gr.Radio(
            label="Resize",
            choices=["Full", "75%", "50%", "33%", "25%", "Custom"],
            value="Full",
            interactive=True,
        )
# Custom resize percentage. A real slider (rather than an empty gr.State) is
# needed so that choosing "Custom" hands prepare_image_and_mask an actual
# number instead of None. It stays hidden until "Custom" is selected; the
# name is kept because the preview handler lists it as an input.
resize_option_custom = gr.Slider(
    minimum=1,
    maximum=100,
    value=50,
    step=1,
    label="Custom Size %",
    interactive=True,
    visible=False,
)


def resize_options_render(resize_option):
    """Return an update that shows the custom-size slider only for "Custom"."""
    return gr.update(visible=resize_option == "Custom")


resize_options.change(
    fn=resize_options_render,
    inputs=resize_options,
    outputs=resize_option_custom,
)
with gr.Accordion("Advanced settings", open=False):
    with gr.Group():
        # Share of the pasted image's width/height that the mask reaches
        # into on each enabled edge.
        mask_overlap_slider = gr.Slider(
            label="Mask Overlap %",
            minimum=1,
            maximum=50,
            step=1,
            value=10,
            interactive=True,
        )
        with gr.Row():
            overlap_top = gr.Checkbox(label="Overlap Top", value=True, interactive=True)
            overlap_right = gr.Checkbox(
                label="Overlap Right", value=True, interactive=True
            )
        with gr.Row():
            overlap_left = gr.Checkbox(label="Overlap Left", value=True, interactive=True)
            overlap_bottom = gr.Checkbox(
                label="Overlap Bottom", value=True, interactive=True
            )
    mask_preview_btn = gr.Button("Preview", interactive=True)
    # Created hidden (visible=False).
    mask_preview_img = gr.Image(visible=False, interactive=True, show_label=False)
| def prepare_image_and_mask( | |
| image, | |
| width, | |
| height, | |
| overlap_percentage, | |
| resize_option, | |
| custom_resize_percentage, | |
| alignment, | |
| overlap_left, | |
| overlap_right, | |
| overlap_top, | |
| overlap_bottom, | |
| ): | |
| target_size = (width, height) | |
| scale_factor = min( | |
| target_size[0] / image.width, | |
| target_size[1] / image.height, | |
| ) | |
| new_width = int(image.width * scale_factor) | |
| new_height = int(image.height * scale_factor) | |
| source = image.resize( | |
| (new_width, new_height), Image.LANCZOS | |
| ) | |
| if resize_option == "Full": | |
| resize_percentage = 100 | |
| elif resize_option == "75%": | |
| resize_percentage = 75 | |
| elif resize_option == "50%": | |
| resize_percentage = 50 | |
| elif resize_option == "33%": | |
| resize_percentage = 33 | |
| elif resize_option == "25%": | |
| resize_percentage = 25 | |
| else: # Custom | |
| resize_percentage = custom_resize_percentage | |
| # Calculate new dimensions based on percentage | |
| resize_factor = resize_percentage / 100 | |
| new_width = int(source.width * resize_factor) | |
| new_height = int(source.height * resize_factor) | |
| # Ensure minimum size of 64 pixels | |
| new_width = max(new_width, 64) | |
| new_height = max(new_height, 64) | |
| # Resize the image | |
| source = source.resize( | |
| (new_width, new_height), Image.LANCZOS | |
| ) | |
| # Calculate the overlap in pixels based on the percentage | |
| overlap_x = int(new_width * (overlap_percentage / 100)) | |
| overlap_y = int(new_height * (overlap_percentage / 100)) | |
| # Ensure minimum overlap of 1 pixel | |
| overlap_x = max(overlap_x, 1) | |
| overlap_y = max(overlap_y, 1) | |
| # Calculate margins based on alignment | |
| if alignment == "Middle": | |
| margin_x = (target_size[0] - new_width) // 2 | |
| margin_y = (target_size[1] - new_height) // 2 | |
| elif alignment == "Left": | |
| margin_x = 0 | |
| margin_y = (target_size[1] - new_height) // 2 | |
| elif alignment == "Right": | |
| margin_x = target_size[0] - new_width | |
| margin_y = (target_size[1] - new_height) // 2 | |
| elif alignment == "Top": | |
| margin_x = (target_size[0] - new_width) // 2 | |
| margin_y = 0 | |
| elif alignment == "Bottom": | |
| margin_x = (target_size[0] - new_width) // 2 | |
| margin_y = target_size[1] - new_height | |
| # Adjust margins to eliminate gaps | |
| margin_x = max( | |
| 0, min(margin_x, target_size[0] - new_width) | |
| ) | |
| margin_y = max( | |
| 0, min(margin_y, target_size[1] - new_height) | |
| ) | |
| # Create a new background image and paste the resized source image | |
| background = Image.new( | |
| "RGB", target_size, (255, 255, 255) | |
| ) | |
| background.paste(source, (margin_x, margin_y)) | |
| # Create the mask | |
| mask = Image.new("L", target_size, 255) | |
| mask_draw = ImageDraw.Draw(mask) | |
| # Calculate overlap areas | |
| white_gaps_patch = 2 | |
| left_overlap = ( | |
| margin_x + overlap_x | |
| if overlap_left | |
| else margin_x + white_gaps_patch | |
| ) | |
| right_overlap = ( | |
| margin_x + new_width - overlap_x | |
| if overlap_right | |
| else margin_x + new_width - white_gaps_patch | |
| ) | |
| top_overlap = ( | |
| margin_y + overlap_y | |
| if overlap_top | |
| else margin_y + white_gaps_patch | |
| ) | |
| bottom_overlap = ( | |
| margin_y + new_height - overlap_y | |
| if overlap_bottom | |
| else margin_y + new_height - white_gaps_patch | |
| ) | |
| if alignment == "Left": | |
| left_overlap = ( | |
| margin_x + overlap_x | |
| if overlap_left | |
| else margin_x | |
| ) | |
| elif alignment == "Right": | |
| right_overlap = ( | |
| margin_x + new_width - overlap_x | |
| if overlap_right | |
| else margin_x + new_width | |
| ) | |
| elif alignment == "Top": | |
| top_overlap = ( | |
| margin_y + overlap_y | |
| if overlap_top | |
| else margin_y | |
| ) | |
| elif alignment == "Bottom": | |
| bottom_overlap = ( | |
| margin_y + new_height - overlap_y | |
| if overlap_bottom | |
| else margin_y + new_height | |
| ) | |
| # Draw the mask | |
| mask_draw.rectangle( | |
| [ | |
| (left_overlap, top_overlap), | |
| (right_overlap, bottom_overlap), | |
| ], | |
| fill=0, | |
| ) | |
| return background, mask | |
def _preview_outpaint_mask(*args):
    """Show the outpaint canvas with the to-be-generated area tinted red.

    Takes the same positional arguments as ``prepare_image_and_mask``. Its
    mask is 255 where content will be generated, so ``Image.composite``
    picks the tinted canvas there and the plain canvas elsewhere.
    """
    background, mask = prepare_image_and_mask(*args)
    tint = Image.new("RGB", background.size, (255, 0, 0))
    tinted = Image.blend(background, tint, 0.5)
    preview = Image.composite(tinted, background, mask)
    # The preview image starts hidden, so the update must also reveal it.
    return gr.update(value=preview, visible=True)


# prepare_image_and_mask returns (background, mask); routing that to
# [mask_preview_img, outpaint_img] replaced the user's upload with the mask.
# Only the preview component is updated now.
mask_preview_btn.click(
    fn=_preview_outpaint_mask,
    inputs=[
        outpaint_img,
        width_slider,
        height_slider,
        mask_overlap_slider,
        resize_options,
        resize_option_custom,
        mask_position,
        overlap_left,
        overlap_right,
        overlap_top,
        overlap_bottom,
    ],
    outputs=mask_preview_img,
)
# Clearing the preview hides it again.
mask_preview_img.clear(
    fn=lambda: gr.update(visible=False),
    outputs=mask_preview_img,
)
use_fill_pipe_outpaint = gr.Checkbox(
    value=True,
    label="Use Fill Pipeline 🧪",
    interactive=True,
)
show_info()
with gr.Tab("In-Context"):
    with gr.Group():
        incontext_img = gr.Image(interactive=True, show_label=False)
    # Reference: https://huggingface.co/spaces/Yuanshi/OminiControl
    show_info(content="1024 res is in beta")

with gr.Tab("IP-Adapter"):
    with gr.Group():
        ip_adapter_img = gr.Image(interactive=True, show_label=False)
        ip_adapter_img_scale = gr.Slider(
            label="Scale",
            minimum=0,
            maximum=1,
            step=0.1,
            value=0.7,
            interactive=True,
        )
    # Reference: https://huggingface.co/InstantX/FLUX.1-dev-IP-Adapter
    show_info(content="1024 res is in beta")
def _controlnet_tab(tab_name, default_scale):
    """Build one ControlNet tab and return its three input components.

    The tab holds a control image, a "ControlNet Conditioning Scale" slider
    (0-1, step 0.05, starting at ``default_scale``) and a "Preprocessed"
    checkbox (default ticked). Must be called inside a ``gr.Tabs`` context.

    Returns:
        ``(image, conditioning_scale_slider, is_preprocessed_checkbox)``
    """
    with gr.Tab(tab_name):
        with gr.Group():
            image = gr.Image(show_label=False, interactive=True)
            with gr.Row(equal_height=True):
                with gr.Column(scale=3):
                    conditioning_scale = gr.Slider(
                        minimum=0,
                        maximum=1,
                        value=default_scale,
                        step=0.05,
                        label="ControlNet Conditioning Scale",
                        interactive=True,
                    )
                with gr.Column(scale=1):
                    is_preprocessed = gr.Checkbox(
                        value=True,
                        label="Preprocessed",
                        interactive=True,
                    )
    return image, conditioning_scale, is_preprocessed


# Same tab order, component names and default scales as before.
canny_img, canny_controlnet_conditioning_scale, canny_img_is_preprocessed = (
    _controlnet_tab("Canny", 0.65)
)
tile_img, tile_controlnet_conditioning_scale, tile_img_is_preprocessed = (
    _controlnet_tab("Tile", 0.45)
)
depth_img, depth_controlnet_conditioning_scale, depth_img_is_preprocessed = (
    _controlnet_tab("Depth", 0.55)
)
blur_img, blur_controlnet_conditioning_scale, blur_img_is_preprocessed = (
    _controlnet_tab("Blur", 0.45)
)
pose_img, pose_controlnet_conditioning_scale, pose_img_is_preprocessed = (
    _controlnet_tab("Pose", 0.55)
)
gray_img, gray_controlnet_conditioning_scale, gray_img_is_preprocessed = (
    _controlnet_tab("Gray", 0.45)
)
(
    low_quality_img,
    low_quality_controlnet_conditioning_scale,
    low_quality_img_is_preprocessed,
) = _controlnet_tab("Low Quality", 0.4)
| # with gr.Tab("Official Canny"): | |
| # with gr.Group(): | |
| # gr.HTML( | |
| # """ | |
| # <script | |
| # type="module" | |
| # src="https://gradio.s3-us-west-2.amazonaws.com/5.6.0/gradio.js" | |
| # ></script> | |
| # <gradio-app src="https://black-forest-labs-flux-1-canny-dev.hf.space"></gradio-app> | |
| # """ | |
| # ) | |
| # with gr.Tab("Official Depth"): | |
| # with gr.Group(): | |
| # gr.HTML( | |
| # """ | |
| # <script | |
| # type="module" | |
| # src="https://gradio.s3-us-west-2.amazonaws.com/5.6.0/gradio.js" | |
| # ></script> | |
| # <gradio-app src="https://black-forest-labs-flux-1-depth-dev.hf.space"></gradio-app> | |
| # """ | |
| # ) | |
with gr.Tab("Auto Trainer"):
    # gr.HTML inserts markup without executing <script> tags, so a
    # <gradio-app> web component (which needs gradio.js loaded) never
    # renders. An iframe on the Space's *.hf.space URL needs no script.
    gr.HTML(
        """
<iframe
    src="https://autotrain-projects-train-flux-lora-ease.hf.space"
    style="width: 100%; height: 900px; border: 0;"
></iframe>
"""
    )
# Not read by any handler in this file yet.
resize_mode_radio = gr.Radio(
    label="Resize Mode",
    choices=["Crop & Resize", "Resize Only", "Resize & Fill"],
    value="Resize & Fill",
    interactive=True,
)
with gr.Accordion("Prompt Generator", open=False):
    # The <gradio-app> tag here had no gradio.js script of its own (and
    # gr.HTML does not run scripts), so it never rendered; embed the Space
    # directly with an iframe instead.
    gr.HTML(
        """
<iframe
    src="https://gokaygokay-flux-prompt-generator.hf.space"
    style="width: 100%; height: 800px; border: 0;"
></iframe>
"""
    )
# Right column: LoRA picker.
with gr.Column(scale=1):
    with gr.Accordion("Loras", open=True):
        # Per-session list of LoRAs the user has added.
        selected_loras = gr.State([])
        # One (image, caption) tile per preset LoRA.
        gallery_items = [(entry["image"], entry["title"]) for entry in flux1_loras]
        lora_selector = gr.Gallery(
            value=gallery_items,
            columns=3,
            allow_preview=False,
            container=False,
            show_label=False,
            show_download_button=False,
            show_fullscreen_button=False,
        )
        with gr.Group():
            lora_selected = gr.Textbox(
                placeholder="Select a Lora to apply...",
                container=False,
                show_label=False,
            )
            add_lora_btn = gr.Button("Add Lora", size="sm")
        gr.Markdown(
            "*You can add a Lora by entering a URL or a Hugging Face repo path."
        )
# update the selected_loras state with the new lora

# Timeouts (seconds) for metadata calls and for LoRA file downloads.
_LORA_API_TIMEOUT = 30
_LORA_DOWNLOAD_TIMEOUT = 600
# "owner/repo" Hugging Face repo id; repo names may contain dots.
_HF_REPO_ID_RE = re.compile(r"^[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+$")
_HF_URL_PREFIX = "https://huggingface.co/"


def _auth_headers(token):
    """Return a Bearer auth header dict, or {} when no token is configured."""
    return {"Authorization": f"Bearer {token}"} if token else {}


def _get_json(url, headers):
    """GET ``url`` and return its decoded JSON body; raise on HTTP errors."""
    response = requests.get(url, headers=headers, timeout=_LORA_API_TIMEOUT)
    response.raise_for_status()
    return response.json()


def _download_to_temp(url, file_name, headers=None, params=None):
    """Stream ``url`` into the system temp dir as ``file_name``; return the path.

    Only the basename of ``file_name`` is used, so a name supplied by a
    remote API cannot point outside the temp directory.

    Raises:
        ValueError: if ``file_name`` has no usable basename.
        requests.RequestException: on network/HTTP failure.
    """
    safe_name = os.path.basename(file_name or "")
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Unusable file name: {file_name!r}")
    file_path = os.path.join(tempfile.gettempdir(), safe_name)
    # Stream to disk instead of holding a multi-hundred-MB file in memory.
    with requests.get(
        url,
        headers=headers,
        params=params,
        stream=True,
        timeout=_LORA_DOWNLOAD_TIMEOUT,
    ) as response:
        response.raise_for_status()
        with open(file_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    return file_path


def _lora_from_civitai(url):
    """Download a Flux LoRA referenced by a CivitAI model URL.

    Returns:
        ``(title, weights_path, info, weight)``: ``info`` is the joined
        trained words (or None); ``weight`` is a strength parsed from the
        version description (or None).

    Raises:
        ValueError: no model id, no versions, non-Flux base model, or no
            .safetensors file.
        requests.RequestException: on network/HTTP failure.
    """
    model_match = re.search(r"/models/(\d+)", url)
    if not model_match:
        raise ValueError("No model id found in the CivitAI URL")
    version_match = re.search(r"modelVersionId=(\d+)", url)

    api_token = config.get("CIVITAI_TOKEN")
    headers = _auth_headers(api_token)
    if version_match:
        data = _get_json(
            f"https://civitai.com/api/v1/model-versions/{version_match.group(1)}",
            headers,
        )
    else:
        data = _get_json(
            f"https://civitai.com/api/v1/models/{model_match.group(1)}", headers
        )

    if "modelVersions" in data:
        # Models endpoint: use the first listed version (assumed to be the
        # latest — TODO confirm against the CivitAI API docs).
        if not data["modelVersions"]:
            raise ValueError("This model has no versions")
        version_data = data["modelVersions"][0]
    else:
        version_data = data

    # Decide compatibility on the base model name alone; also accepting any
    # name containing "1" let e.g. "SD 1.5" / "SDXL 1.0" LoRAs through.
    if "flux" not in (version_data.get("baseModel") or "").lower():
        raise ValueError("This LoRA is not compatible with Flux base model")

    safetensor_file = next(
        (
            f
            for f in version_data.get("files") or []
            if f.get("name", "").endswith(".safetensors")
        ),
        None,
    )
    if not safetensor_file:
        raise ValueError("No .safetensor file found")

    # Passing the token via params lets requests merge it with any query
    # string the download URL already carries (a bare "?token=" did not).
    params = {"token": api_token} if api_token else None
    weights = _download_to_temp(
        safetensor_file["downloadUrl"],
        safetensor_file["name"],
        headers=headers,
        params=params,
    )

    # The model-versions response presumably names the parent model under
    # "model"; fall back to the top-level "name" as before.
    parent = data.get("model")
    title = (parent.get("name") if isinstance(parent, dict) else None) or data.get("name")

    weight = None
    strength_match = re.search(
        r"strength[:\s]+(\d*\.?\d+)",
        version_data.get("description") or "",
        re.IGNORECASE,
    )
    if strength_match:
        weight = float(strength_match.group(1))

    info = ", ".join(version_data.get("trainedWords") or []) or None
    return title, weights, info, weight


def _lora_from_hf_repo(repo_id):
    """Download the first top-level .safetensors file of a HF model repo.

    Returns:
        ``(title, weights_path, info, weight)``: ``info`` comes from the
        model card's ``trigger_words`` / ``instance_prompt`` (or None);
        ``weight`` from the card's ``weight`` field (or None).

    Raises:
        ValueError: repo tagged but not as "flux-lora", or no .safetensors file.
        requests.RequestException: on network/HTTP failure.
    """
    headers = _auth_headers(config.get("HF_TOKEN"))
    data = _get_json(f"https://huggingface.co/api/models/{repo_id}", headers)
    if "tags" in data and "flux-lora" not in data["tags"]:
        raise ValueError("This model is not tagged as a Flux LoRA")

    # The tree listing needs a revision segment; "main" matches the
    # resolve/main download URL below.
    files = _get_json(f"https://huggingface.co/api/models/{repo_id}/tree/main", headers)
    safetensor_file = next(
        (f for f in files if f.get("path", "").endswith(".safetensors")),
        None,
    )
    if not safetensor_file:
        raise ValueError("No .safetensor file found")

    path = safetensor_file["path"]
    weights = _download_to_temp(
        f"https://huggingface.co/{repo_id}/resolve/main/{path}",
        path,
        headers=headers,
    )

    title = data.get("name") or repo_id.split("/")[-1]
    card = data.get("cardData") or {}

    weight = None
    if "weight" in card:
        try:
            weight = float(card["weight"])
        except (ValueError, TypeError):
            weight = None

    # Repo tags ("diffusers", "license:...", "base_model:...") are not
    # trigger words, so only model-card fields are used. Either field may be
    # a single string or a list.
    trigger_words = []
    for key in ("trigger_words", "instance_prompt"):
        value = card.get(key)
        if isinstance(value, str):
            trigger_words.append(value)
        elif isinstance(value, list):
            trigger_words.extend(str(v) for v in value)
    info = ", ".join(trigger_words) or None
    return title, weights, info, weight


def add_lora(lora_selected, current_loras=None):
    """Resolve a LoRA selection and return the updated selected-LoRA list.

    Args:
        lora_selected: An ``int`` index into ``flux1_loras`` (gallery pick), a
            CivitAI model URL, a ``https://huggingface.co/owner/repo`` URL, or
            an ``owner/repo`` Hugging Face id.
        current_loras: Current ``selected_loras`` state value (list of
            dicts); ``None`` is treated as empty. Not mutated.

    Returns:
        A new list: ``current_loras`` plus one dict with keys ``"title"``,
        ``"weights"`` (local .safetensors path, or the preset's value),
        ``"info"`` (trigger words or None) and ``"weight"`` (suggested
        strength or None).

    Raises:
        gr.Error: for empty/unsupported input or when fetching fails. gr.Error
            only reaches the UI when raised, so failures are no longer
            swallowed into an empty entry.
    """
    weight = None
    if isinstance(lora_selected, int) and not isinstance(lora_selected, bool):
        if not 0 <= lora_selected < len(flux1_loras):
            raise gr.Error(f"Unknown LoRA index: {lora_selected}")
        # Index the preset list the gallery was built from; the gallery
        # component (lora_selector) itself is not subscriptable.
        preset = flux1_loras[lora_selected]
        title = preset["title"]
        weights = preset["weights"]
        info = preset["trigger_word"]
    elif isinstance(lora_selected, str) and lora_selected.strip():
        source = lora_selected.strip()
        if source.startswith(_HF_URL_PREFIX):
            # Accept https://huggingface.co/owner/repo[/...] as owner/repo.
            source = "/".join(source[len(_HF_URL_PREFIX):].split("/")[:2])
        if source.startswith("http"):
            if "civitai.com/models/" not in source:
                raise gr.Error(
                    "Unsupported URL: use a CivitAI model URL or a Hugging Face repo."
                )
            try:
                title, weights, info, weight = _lora_from_civitai(source)
            except Exception as e:  # UI boundary: surface any failure to the user
                raise gr.Error(f"Error processing CivitAI URL: {str(e)}") from e
        elif _HF_REPO_ID_RE.match(source):
            try:
                title, weights, info, weight = _lora_from_hf_repo(source)
            except Exception as e:  # UI boundary: surface any failure to the user
                raise gr.Error(f"Error processing Hugging Face repo: {str(e)}") from e
        else:
            raise gr.Error("Enter a CivitAI URL or a Hugging Face repo (owner/repo).")
    else:
        raise gr.Error("Select a LoRA or enter a URL / Hugging Face repo first.")

    return [
        *(current_loras or []),
        {
            "title": title,
            "weights": weights,  # i.e safetensors file path
            "info": info,
            "weight": weight,
        },
    ]


# add_lora was previously never attached to any event.
add_lora_btn.click(
    fn=add_lora,
    inputs=[lora_selected, selected_loras],
    outputs=selected_loras,
)


def _add_lora_from_gallery(current_loras, evt: gr.SelectData):
    """Add the preset LoRA whose gallery tile was clicked."""
    return add_lora(evt.index, current_loras)


lora_selector.select(
    fn=_add_lora_from_gallery,
    inputs=selected_loras,
    outputs=selected_loras,
)
# render the selected_loras state as sliders
@gr.render(inputs=selected_loras)
def render_selected_loras(loras):
    """Render one strength slider per selected LoRA.

    gr.render re-runs this whenever the ``selected_loras`` state changes.
    Releasing a slider writes its value into that LoRA's ``"weight"`` entry.

    Args:
        loras: Current ``selected_loras`` value: a list of dicts with
            "title" and "info" and an optional "weight". None is treated as
            empty.
    """
    for index, lora in enumerate(loras or []):
        weight = lora.get("weight")
        initial = 0.8 if weight is None else weight
        lora_slider = gr.Slider(
            # An explicit range is needed: the Slider default is 0-100.
            minimum=0,
            maximum=max(2.0, initial),
            step=0.05,
            value=initial,
            label=lora.get("title") or f"LoRA {index + 1}",
            info=lora.get("info"),
            interactive=True,
        )

        # The handler receives the slider's *value* (not the component), so
        # the entry is located by its bound position. ``index=index`` binds
        # it now, avoiding the late-binding closure pitfall.
        def update_lora_weight(new_weight, current_loras, index=index):
            updated = [dict(item) for item in current_loras or []]
            if index < len(updated):
                updated[index]["weight"] = new_weight
            return updated

        # .release fires once per adjustment, so the state (and this render)
        # is not rebuilt on every drag step.
        lora_slider.release(
            fn=update_lora_weight,
            inputs=[lora_slider, selected_loras],
            outputs=selected_loras,
        )
# Launch only when executed as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()