import gradio as gr
import torch
from PIL import Image
import numpy as np
import time
import os
import spaces  # ZeroGPU support on Hugging Face Spaces

try:
    from gen2seg_sd_pipeline import gen2segSDPipeline
    from gen2seg_mae_pipeline import gen2segMAEInstancePipeline
except ImportError as e:
    print(f"Error importing pipeline modules: {e}")
    print("Please ensure gen2seg_sd_pipeline.py and gen2seg_mae_pipeline.py are in the same directory.")
    # Optionally, raise an error or exit if pipelines are critical at startup
    # raise ImportError("Could not import custom pipeline modules. Check file paths.") from e

from transformers import ViTMAEForPreTraining, AutoImageProcessor
# --- Configuration ---
MODEL_IDS = {
    "SD": "reachomk/gen2seg-sd",
    "MAE-H": "reachomk/gen2seg-mae-h"
}

# Check if a GPU is available and set the device accordingly
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- Global Variables for Caching Pipelines ---
sd_pipe_global = None
mae_pipe_global = None

# --- Model Loading Functions ---
def get_sd_pipeline():
    """Loads and caches the gen2seg Stable Diffusion pipeline."""
    global sd_pipe_global
    if sd_pipe_global is None:
        model_id_sd = MODEL_IDS["SD"]
        print(f"Attempting to load SD pipeline from Hugging Face Hub: {model_id_sd}")
        try:
            sd_pipe_global = gen2segSDPipeline.from_pretrained(
                model_id_sd,
                use_safetensors=True,
                # torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,  # Optional: use float16 on GPU
            ).to(DEVICE)
            print(f"SD Pipeline loaded successfully from {model_id_sd} on {DEVICE}.")
        except Exception as e:
            print(f"Error loading SD pipeline from Hugging Face Hub ({model_id_sd}): {e}")
            sd_pipe_global = None  # Ensure it remains None on failure
            # Do not raise gr.Error here; let the main function handle it.
    return sd_pipe_global
def get_mae_pipeline():
    """Loads and caches the gen2seg MAE-H pipeline."""
    global mae_pipe_global
    if mae_pipe_global is None:
        model_id_mae = MODEL_IDS["MAE-H"]
        print(f"Loading MAE-H pipeline with model {model_id_mae} on {DEVICE}...")
        try:
            model = ViTMAEForPreTraining.from_pretrained(model_id_mae)
            model.to(DEVICE)
            model.eval()  # Set to evaluation mode

            # Load the official MAE-H image processor
            # Using "facebook/vit-mae-huge" as per the original app_mae.py
            image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-huge")

            mae_pipe_global = gen2segMAEInstancePipeline(model=model, image_processor=image_processor)
            # The custom MAE pipeline's model is already on the DEVICE.
            print(f"MAE-H Pipeline with model {model_id_mae} loaded successfully on {DEVICE}.")
        except Exception as e:
            print(f"Error loading MAE-H model or pipeline from Hugging Face Hub ({model_id_mae}): {e}")
            mae_pipe_global = None  # Ensure it remains None on failure
            # Do not raise gr.Error here; let the main function handle it.
    return mae_pipe_global
# --- Unified Prediction Function ---
def segment_image(input_image: Image.Image, model_choice: str) -> Image.Image:
    """
    Takes a PIL Image and model choice, performs segmentation, and returns the segmented image.
    """
    if input_image is None:
        raise gr.Error("No image provided. Please upload an image.")

    print(f"Model selected: {model_choice}")

    # Ensure image is in RGB format
    image_rgb = input_image.convert("RGB")
    original_resolution = image_rgb.size  # (width, height)

    seg_array = None
    try:
        if model_choice == "SD":
            pipe_sd = get_sd_pipeline()
            if pipe_sd is None:
                raise gr.Error("The SD segmentation pipeline could not be loaded. "
                               "Please check the Space logs for more details, or try again later.")

            print(f"Running SD inference with image size: {image_rgb.size}")
            start_time = time.time()
            with torch.no_grad():
                # The gen2segSDPipeline expects a single image or a list;
                # its __call__ method handles preprocessing internally.
                seg_output = pipe_sd(image_rgb, match_input_resolution=False).prediction  # Output is before resize

            # seg_output is expected to be a NumPy array or a tensor.
            # Based on gen2seg_sd_pipeline.py: output_type="np" (the default) gives [N, H, W, 1],
            # while output_type="pt" gives [N, 1, H, W].
            # The original app_sd.py converted the tensor to NumPy and squeezed it.
            if isinstance(seg_output, torch.Tensor):
                seg_output = seg_output.cpu().numpy()

            if seg_output.ndim == 4 and seg_output.shape[0] == 1:  # Batch size 1
                if seg_output.shape[1] == 1:  # Grayscale, (1, 1, H, W)
                    seg_array = seg_output.squeeze(0).squeeze(0).astype(np.uint8)
                elif seg_output.shape[-1] == 1:  # Grayscale, (1, H, W, 1)
                    seg_array = seg_output.squeeze(0).squeeze(-1).astype(np.uint8)
                elif seg_output.shape[1] == 3:  # RGB, (1, 3, H, W) -> (H, W, 3)
                    seg_array = np.transpose(seg_output.squeeze(0), (1, 2, 0)).astype(np.uint8)
                elif seg_output.shape[-1] == 3:  # RGB, (1, H, W, 3)
                    seg_array = seg_output.squeeze(0).astype(np.uint8)
                else:  # Fallback for unexpected shapes
                    seg_array = seg_output.squeeze().astype(np.uint8)
            elif seg_output.ndim == 3:  # (H, W, C) or (C, H, W)
                seg_array = seg_output.astype(np.uint8)
            elif seg_output.ndim == 2:  # (H, W)
                seg_array = seg_output.astype(np.uint8)
            else:
                raise TypeError(f"Unexpected SD segmentation output type/shape: {type(seg_output)}, {seg_output.shape}")

            end_time = time.time()
            print(f"SD Inference completed in {end_time - start_time:.2f} seconds.")
        elif model_choice == "MAE-H":
            pipe_mae = get_mae_pipeline()
            if pipe_mae is None:
                raise gr.Error("The MAE-H segmentation pipeline could not be loaded. "
                               "Please check the Space logs for more details, or try again later.")

            print(f"Running MAE-H inference with image size: {image_rgb.size}")
            start_time = time.time()
            with torch.no_grad():
                # The gen2segMAEInstancePipeline expects a list of images;
                # output_type="np" returns a NumPy array.
                pipe_output = pipe_mae([image_rgb], output_type="np")
                # Prediction is (batch_size, height, width, 3) for MAE
                prediction_np = pipe_output.prediction[0]  # Get the first (and only) image prediction
            end_time = time.time()
            print(f"MAE-H Inference completed in {end_time - start_time:.2f} seconds.")

            if not isinstance(prediction_np, np.ndarray):
                # This case should ideally not be reached if output_type="np"
                prediction_np = prediction_np.cpu().numpy()

            # Ensure it's in the expected (H, W, C) format and uint8
            if prediction_np.ndim == 3 and prediction_np.shape[-1] == 3:  # Expected (H, W, 3)
                seg_array = prediction_np.astype(np.uint8)
            else:
                # Attempt to handle other shapes if necessary, or raise an error
                raise gr.Error(f"Unexpected MAE-H prediction shape: {prediction_np.shape}. Expected (H, W, 3).")

            # The MAE pipeline already applies gamma correction and scaling to 0-255.
            # It also ensures 3 channels.
        else:
            raise gr.Error(f"Invalid model choice: {model_choice}. Please select a valid model.")

        if seg_array is None:
            raise gr.Error("Segmentation array was not generated. An unknown error occurred.")

        print(f"Segmentation array generated with shape: {seg_array.shape}, dtype: {seg_array.dtype}")

        # Convert the NumPy array to a PIL Image,
        # handling grayscale or RGB based on the channels of seg_array.
        if seg_array.ndim == 2:  # Grayscale
            segmented_image_pil = Image.fromarray(seg_array, mode='L')
        elif seg_array.ndim == 3 and seg_array.shape[-1] == 3:  # RGB
            segmented_image_pil = Image.fromarray(seg_array, mode='RGB')
        elif seg_array.ndim == 3 and seg_array.shape[-1] == 1:  # Grayscale with channel dim
            segmented_image_pil = Image.fromarray(seg_array.squeeze(-1), mode='L')
        else:
            raise gr.Error(f"Cannot convert seg_array with shape {seg_array.shape} to PIL Image.")

        # Resize back to the original image resolution using LANCZOS for high quality
        segmented_image_pil = segmented_image_pil.resize(original_resolution, Image.Resampling.LANCZOS)

        print(f"Segmented image processed. Output size: {segmented_image_pil.size}, mode: {segmented_image_pil.mode}")
        return segmented_image_pil
    except Exception as e:
        print(f"Error during segmentation with {model_choice}: {e}")
        # Re-raise as gr.Error for Gradio to display, if it is not already one.
        if not isinstance(e, gr.Error):
            # It is often helpful to include the type of the original exception.
            error_type = type(e).__name__
            raise gr.Error(f"An error occurred during segmentation: {error_type} - {str(e)}")
        else:
            raise e  # Re-raise if it's already a gr.Error
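
# A minimal sketch of calling segment_image directly (hypothetical local usage
# outside of Gradio; "example.jpg" and "segmentation.png" are placeholder paths,
# not files shipped with this Space):
#
#   img = Image.open("example.jpg")
#   result = segment_image(img, "SD")   # or "MAE-H"
#   result.save("segmentation.png")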
# --- Gradio Interface ---
title = "gen2seg: Generative Models Enable Generalizable Instance Segmentation Demo (SD & MAE-H)"

description = f"""
<div style="text-align: center; font-family: 'Arial', sans-serif;">
    <p>Upload an image and choose a model architecture to see the instance segmentation result generated by the respective model.</p>
    <p>
        BIG THANKS to Hugging Face for funding our demo with their Academic GPU Grant!
    </p>
    <ul>
        <li><strong>SD</strong>: Based on Stable Diffusion 2.
            <a href="https://huggingface.co/{MODEL_IDS['SD']}" target="_blank">Model Link</a>.
        </li>
        <li><strong>MAE-H</strong>: Based on Masked Autoencoder (Huge).
            <a href="https://huggingface.co/{MODEL_IDS['MAE-H']}" target="_blank">Model Link</a>.
            If you experience tokenizer artifacts or very dark images, you can use gamma correction to handle this.
        </li>
    </ul>
    <p>
        Paper: <a href="https://arxiv.org/abs/2505.15263">https://arxiv.org/abs/2505.15263</a>
    </p>
    <p>
        For faster inference, please check out our GitHub to run the models locally on a GPU:
        <a href="https://github.com/UCDvision/gen2seg" target="_blank">https://github.com/UCDvision/gen2seg</a>
        or check out our Colab demo <a href="https://colab.research.google.com/drive/10lPBP4figJf7MLp9T1b5cDQeU7MgODw3?usp=sharing" target="_blank">here</a>.
    </p>
    <p>If the demo experiences issues, please open an issue on our GitHub.</p>
    <p>If you have not already, please see our webpage at <a href="https://reachomk.github.io/gen2seg" target="_blank">https://reachomk.github.io/gen2seg</a>.</p>
</div>
"""
| article = """ | |
| """ | |
| # Define Gradio inputs | |
| input_image_component = gr.Image(type="pil", label="Input Image") | |
| model_choice_component = gr.Dropdown( | |
| choices=list(MODEL_IDS.keys()), | |
| value="SD", # Default model | |
| label="Choose Segmentation Model Architecture" | |
| ) | |
| # Define Gradio output | |
| output_image_component = gr.Image(type="pil", label="Segmented Image") | |
| # Example images (ensure these paths are correct if you upload examples to your Space) | |
| # For example, if you create an "examples" folder in your Space repo: | |
| # example_paths = [ | |
| # os.path.join("examples", "example1.jpg"), | |
| # os.path.join("examples", "example2.png") | |
| # ] | |
| # Filter out non-existent example files to prevent errors | |
| # example_paths = [ex for ex in example_paths if os.path.exists(ex)] | |
| # Base list of example image paths/URLs | |
| base_example_images = [ | |
| "cats-on-rock-1948.jpg", | |
| "dogs.png", | |
| "000000484893.jpg", | |
| "https://reachomk.github.io/gen2seg/images/comparison/vertical/7.png", | |
| "https://reachomk.github.io/gen2seg/images/comparison/horizontal/11.png", | |
| "https://reachomk.github.io/gen2seg/images/comparison/vertical/2.jpg" | |
| ] | |
# Generate examples pairing each image with both model choices
model_choices_for_examples = list(MODEL_IDS.keys())  # ["SD", "MAE-H"]
formatted_examples = []
for img_path_or_url in base_example_images:
    for model_choice in model_choices_for_examples:
        formatted_examples.append([img_path_or_url, model_choice])

iface = gr.Interface(
    fn=segment_image,
    inputs=[input_image_component, model_choice_component],
    outputs=output_image_component,
    title=title,
    description=description,
    article=article,
    examples=None,  # formatted_examples if formatted_examples else None
    allow_flagging="never",
    theme="shivi/calm_seafoam"
)
if __name__ == "__main__":
    # Optional: pre-load the models on startup if desired.
    # This makes the first inference faster but increases startup time.
    # print("Attempting to pre-load models on startup...")
    try:
        get_sd_pipeline()  # Pre-load the SD model
        print("SD model pre-loaded successfully or was already cached.")
    except Exception as e:
        print(f"Could not pre-load SD model: {e}")

    try:
        get_mae_pipeline()  # Pre-load the MAE-H model
        print("MAE-H model pre-loaded successfully or was already cached.")
    except Exception as e:
        print(f"Could not pre-load MAE-H model: {e}")

    print("Launching Gradio interface...")
    iface.launch()