Spaces:
Running
on
Zero
Running
on
Zero
| from typing import Optional | |
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| import io | |
| import base64, os | |
| from huggingface_hub import snapshot_download | |
| import traceback | |
| import warnings | |
| import sys | |
# Suppress warnings that are noisy but not actionable for this app:
# every FutureWarning, plus any warning whose message mentions _supports_sdpa.
# Filters are inserted in this order, so the _supports_sdpa one ends up first.
for _warning_filter in ({"category": FutureWarning}, {"message": ".*_supports_sdpa.*"}):
    warnings.filterwarnings("ignore", **_warning_filter)
del _warning_filter
# Simple monkey patch for transformers - avoid recursion
def simple_patch_transformers():
    """Make attention-implementation checks fall back to "eager" on _supports_sdpa errors.

    Wraps ``transformers.modeling_utils.PreTrainedModel._check_and_adjust_attn_implementation``:

    * If the model has no ``_supports_sdpa`` attribute, try to set it to False
      (best effort; see below).
    * If the original check then raises an ``AttributeError`` mentioning
      ``_supports_sdpa``, return ``"eager"`` instead of propagating it. Any other
      ``AttributeError`` is re-raised.

    The patch is applied at most once; repeated calls are no-ops. If patching
    itself fails (e.g. the method does not exist in the installed transformers),
    a warning is printed and nothing else happens.
    """
    import functools

    try:
        import transformers.modeling_utils as modeling_utils

        pretrained_cls = modeling_utils.PreTrainedModel
        original_check = pretrained_cls._check_and_adjust_attn_implementation
        if getattr(original_check, "_sdpa_fallback_patch", False):
            # An earlier call already installed the wrapper; don't stack another.
            return

        @functools.wraps(original_check)
        def patched_check(self, *args, **kwargs):
            if not hasattr(self, '_supports_sdpa'):
                # object.__setattr__ bypasses nn.Module.__setattr__ (avoids
                # recursion), but it still raises AttributeError when the
                # subclass defines _supports_sdpa as a property without a
                # setter. Treat the assignment as best effort and let the
                # handler below deal with the original check's failure.
                try:
                    object.__setattr__(self, '_supports_sdpa', False)
                except AttributeError:
                    pass
            try:
                return original_check(self, *args, **kwargs)
            except AttributeError as e:
                if '_supports_sdpa' in str(e):
                    # Return default attention implementation
                    return "eager"
                raise

        # Marker used above to keep the patch idempotent.
        patched_check._sdpa_fallback_patch = True
        pretrained_cls._check_and_adjust_attn_implementation = patched_check
        print("Applied simple transformers patch")
    except Exception as e:
        print(f"Warning: Could not patch transformers: {e}")
# The transformers patch must be installed before util.utils is imported.
# NOTE(review): presumably util.utils (or something it imports) touches model
# code at import time — confirm before reordering these statements.
simple_patch_transformers()

from util.utils import (  # noqa: E402  (must follow the patch above)
    check_ocr_box,
    get_caption_model_processor,
    get_som_labeled_img,
    get_yolo_model,
)
# Download repository: fetch the OmniParser v2 weights from the Hugging Face Hub
# into ./weights unless they are already present.
repo_id = "microsoft/OmniParser-v2.0"
local_dir = "weights"
# Checking only that the directory exists would treat a partially downloaded
# snapshot (e.g. an interrupted first start) as complete forever. Instead check
# the files this app's loaders read: the YOLO detector weights and the caption
# model weights.
_required_weight_files = (
    os.path.join(local_dir, "icon_detect", "model.pt"),
    os.path.join(local_dir, "icon_caption", "model.safetensors"),
)
if all(os.path.exists(path) for path in _required_weight_files):
    print(f"Weights already exist at: {local_dir}")
else:
    # snapshot_download fetches the repository files into local_dir.
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    print(f"Repository downloaded to: {local_dir}")
# Custom function to load caption model
def _mark_no_sdpa(model):
    """Best-effort: set ``model._supports_sdpa = False``.

    ``object.__setattr__`` bypasses ``nn.Module.__setattr__`` (avoids recursion).
    If the model class defines ``_supports_sdpa`` as a read-only property, the
    assignment raises ``AttributeError``; there is nothing to assign to in that
    case, so the error is ignored.
    """
    try:
        object.__setattr__(model, '_supports_sdpa', False)
    except AttributeError:
        pass


def _load_caption_model_auto(model_name_or_path, device):
    """Strategy 2: ``AutoProcessor`` + ``AutoModelForCausalLM`` with eager attention."""
    from transformers import AutoProcessor, AutoModelForCausalLM

    print(f"Loading caption model from {model_name_or_path}...")
    processor = AutoProcessor.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        torch_dtype=torch.float16 if device.type == 'cuda' else torch.float32,
        trust_remote_code=True,
        attn_implementation="eager",  # avoid SDPA entirely
        low_cpu_mem_usage=True,
    )
    if not hasattr(model, '_supports_sdpa'):
        _mark_no_sdpa(model)
    if device.type == 'cuda':
        model = model.to(device)
    print("Model loaded successfully with alternative method")
    return {'model': model, 'processor': processor}


def _load_caption_model_manual(model_name_or_path, device):
    """Strategy 3: import ``modeling_florence2.py`` from disk and load weights by hand.

    Raises:
        FileNotFoundError: the model definition or the weight file is missing.
        AttributeError: the definition has no ``Florence2ForConditionalGeneration``.
    """
    print("Attempting manual model loading...")
    from transformers import AutoProcessor, AutoConfig
    import importlib.util

    processor = AutoProcessor.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
    )
    config = AutoConfig.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
    )

    model_file = os.path.join(model_name_or_path, "modeling_florence2.py")
    if not os.path.exists(model_file):
        raise FileNotFoundError(f"Model definition not found: {model_file}")
    # NOTE(review): if modeling_florence2.py uses relative imports, executing it
    # outside its package will fail here — confirm against the actual file.
    spec = importlib.util.spec_from_file_location("modeling_florence2_custom", model_file)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    model_class = getattr(module, 'Florence2ForConditionalGeneration', None)
    if model_class is None:
        raise AttributeError(f"{model_file} does not define Florence2ForConditionalGeneration")
    model = model_class(config)
    # Set the attribute before loading weights
    _mark_no_sdpa(model)

    weight_file = os.path.join(model_name_or_path, "model.safetensors")
    if not os.path.exists(weight_file):
        # Without weights the model is randomly initialised; refuse to return it.
        raise FileNotFoundError(f"Model weights not found: {weight_file}")
    from safetensors.torch import load_file

    incompatible = model.load_state_dict(load_file(weight_file), strict=False)
    missing = getattr(incompatible, "missing_keys", [])
    unexpected = getattr(incompatible, "unexpected_keys", [])
    if missing or unexpected:
        print(f"Warning: {len(missing)} missing / {len(unexpected)} unexpected keys "
              f"while loading {weight_file}")

    # A directly constructed nn.Module starts in training mode (dropout active);
    # switch to inference mode as from_pretrained would.
    model.eval()
    if device.type == 'cuda':
        model = model.to(device).half()  # Use half precision on GPU
    print("Model loaded successfully with manual method")
    return {'model': model, 'processor': processor}


def load_caption_model_safe(model_name="florence2", model_name_or_path="weights/icon_caption"):
    """Load the icon-caption model, trying progressively more manual strategies.

    1. The project's ``get_caption_model_processor(model_name, model_name_or_path)``.
    2. ``AutoProcessor`` / ``AutoModelForCausalLM.from_pretrained`` with eager
       attention and ``trust_remote_code=True``.
    3. Manual instantiation of ``Florence2ForConditionalGeneration`` from a local
       ``modeling_florence2.py`` plus weights from ``model.safetensors``.

    Args:
        model_name: model family passed to ``get_caption_model_processor``
            (used by strategy 1 only).
        model_name_or_path: path (or identifier) of the caption model files.

    Returns:
        Whatever strategy 1 returns, or a dict ``{'model': ..., 'processor': ...}``
        from strategies 2/3.

    Raises:
        RuntimeError: if all three strategies fail.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Method 1: Try original function
    try:
        return get_caption_model_processor(model_name, model_name_or_path)
    except Exception as e:
        print(f"Original loading failed: {e}, trying alternative...")

    # Method 2: Load with specific configs
    try:
        return _load_caption_model_auto(model_name_or_path, device)
    except Exception as e:
        print(f"Alternative loading also failed: {e}")

    # Method 3: Manual loading as last resort. Any failure here (including a
    # missing definition or weight file) is reported as RuntimeError instead
    # of silently returning None.
    try:
        return _load_caption_model_manual(model_name_or_path, device)
    except Exception as e:
        print(f"Manual loading failed: {e}")
        raise RuntimeError(f"Could not load model with any method: {e}") from e
# Load models once at startup. If anything fails, both globals end up as None;
# process() checks for None and reports the problem instead of running.
try:
    print("Loading YOLO model...")
    yolo_model = get_yolo_model(model_path=f"{local_dir}/icon_detect/model.pt")
    print("YOLO model loaded successfully")

    print("Loading caption model...")
    caption_model_processor = load_caption_model_safe(
        model_name_or_path=f"{local_dir}/icon_caption"
    )
    print("Caption model loaded successfully")
except Exception as e:
    print(f"Critical error loading models: {e}", traceback.format_exc(), sep="\n")
    yolo_model = caption_model_processor = None
# UI Configuration
# Device selection is reported once at startup.
# NOTE(review): DEVICE is not referenced elsewhere in the visible file.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# Header shown at the top of the app (rendered by gr.Markdown).
MARKDOWN = """
# OmniParser V2 Pro🔥
<div style="background-color: #f0f8ff; padding: 15px; border-radius: 10px; margin-bottom: 20px;">
<p style="margin: 0;">🎯 <strong>AI-powered screen understanding tool</strong> that detects UI elements and extracts text with high accuracy.</p>
<p style="margin: 5px 0 0 0;">📝 Supports both PaddleOCR and EasyOCR for flexible text extraction.</p>
</div>
"""

# Page styling passed to gr.Blocks(css=...); #input_image matches the upload
# widget's elem_id.
custom_css = """
body { background-color: #f0f2f5; }
.gradio-container { font-family: 'Segoe UI', sans-serif; max-width: 1400px; margin: auto; }
h1, h2, h3, h4 { color: #283E51; }
button { border-radius: 6px; transition: all 0.3s ease; }
button:hover { transform: translateY(-2px); box-shadow: 0 4px 12px rgba(0,0,0,0.15); }
.output-image { border: 2px solid #e1e4e8; border-radius: 8px; }
#input_image { border: 2px dashed #4a90e2; border-radius: 8px; }
#input_image:hover { border-color: #2c5aa0; }
"""
# On a ZeroGPU Space, GPU work must run inside a function decorated with
# ``spaces.GPU`` (the file imports ``spaces`` for this). When the package or
# the decorator is unavailable, fall back to a no-op so the app still runs.
try:
    import spaces as _spaces

    _gpu = _spaces.GPU
except (ImportError, AttributeError):

    def _gpu(fn):
        """No-op stand-in for ``spaces.GPU`` outside Hugging Face Spaces."""
        return fn


@_gpu
def process(
    image_input,
    box_threshold,
    iou_threshold,
    use_paddleocr,
    imgsz
) -> tuple:
    """Run OCR plus UI-element detection/captioning on an uploaded screenshot.

    Args:
        image_input: the uploaded image (PIL, from the ``type='pil'`` input),
            or None if nothing was uploaded.
        box_threshold: passed to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: passed to ``get_som_labeled_img`` as ``iou_threshold``.
        use_paddleocr: passed to ``check_ocr_box`` to choose PaddleOCR over EasyOCR.
        imgsz: passed to ``get_som_labeled_img`` as ``imgsz``.

    Returns:
        ``(image, markdown_text)``. On success ``image`` is the annotated image;
        on failure the text starts with "⚠️" and the image is None or the
        unmodified input.
    """
    if image_input is None:
        return None, "⚠️ Please upload an image for processing."
    if caption_model_processor is None or yolo_model is None:
        return None, "⚠️ Models not loaded properly. Please restart the application."
    try:
        print(f"Processing: box_threshold={box_threshold}, iou_threshold={iou_threshold}, "
              f"use_paddleocr={use_paddleocr}, imgsz={imgsz}")

        # Scale drawing sizes with image width (1.0 at 3200 px), clamped to [0.5, 2.0].
        image_width = image_input.size[0]
        box_overlay_ratio = max(0.5, min(2.0, image_width / 3200))
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }

        # OCR is best effort: on failure, continue with detection only.
        try:
            ocr_bbox_rslt, _ = check_ocr_box(
                image_input,
                display_img=False,
                output_bb_format='xyxy',
                goal_filtering=None,
                easyocr_args={'paragraph': False, 'text_threshold': 0.9},
                use_paddleocr=use_paddleocr
            )
            if ocr_bbox_rslt is None:
                text, ocr_bbox = [], []
            else:
                text, ocr_bbox = ocr_bbox_rslt
            text = text if text is not None else []
            ocr_bbox = ocr_bbox if ocr_bbox is not None else []
            print(f"OCR found {len(text)} text regions")
        except Exception as e:
            print(f"OCR error: {e}")
            text, ocr_bbox = [], []

        # Object detection and captioning
        try:
            if isinstance(caption_model_processor, dict) and 'model' in caption_model_processor:
                model = caption_model_processor['model']
                if not hasattr(model, '_supports_sdpa'):
                    try:
                        object.__setattr__(model, '_supports_sdpa', False)
                    except AttributeError:
                        # Read-only property on the model class; nothing to set.
                        pass
            dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
                image_input,
                yolo_model,
                BOX_TRESHOLD=box_threshold,
                output_coord_in_ratio=True,
                ocr_bbox=ocr_bbox,
                draw_bbox_config=draw_bbox_config,
                caption_model_processor=caption_model_processor,
                ocr_text=text,
                iou_threshold=iou_threshold,
                imgsz=imgsz
            )
            if dino_labled_img is None:
                raise ValueError("Failed to generate labeled image")
        except Exception as e:
            print(f"Detection error: {e}")
            return image_input, f"⚠️ Error during detection: {str(e)}"

        # The labeled image comes back base64-encoded.
        try:
            image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
        except Exception as e:
            print(f"Image decode error: {e}")
            return image_input, f"⚠️ Error decoding image: {str(e)}"

        # Format results as a Markdown list. Markdown treats a single newline
        # as a soft break, so plain "**Element i:** ...\n" lines would run
        # together into one paragraph. Indices are kept even when an entry is
        # skipped so they stay aligned with the list positions.
        elements = parsed_content_list or []
        if elements:
            parsed_text = "🎯 **Detected Elements:**\n\n" + "".join(
                f"- **Element {i}:** {v}\n" for i, v in enumerate(elements) if v
            )
        else:
            parsed_text = "ℹ️ No UI elements detected. Try adjusting the thresholds."
        print(f'Processing complete. Found {len(elements)} elements.')
        return image, parsed_text
    except Exception as e:
        print(f"Processing error: {e}")
        print(traceback.format_exc())
        return None, f"⚠️ Error: {str(e)}"
# Build UI: a narrow left column with the upload widget and detection
# settings, and a wider right column with tabs for the annotated image and the
# Markdown list of detected elements.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown(MARKDOWN)
    if any(loaded is None for loaded in (caption_model_processor, yolo_model)):
        gr.Markdown("### ⚠️ Warning: Models failed to load. Please check logs.")

    with gr.Row():
        # Left: inputs and settings.
        with gr.Column(scale=1):
            with gr.Accordion("📤 Upload & Settings", open=True):
                image_input_component = gr.Image(
                    elem_id="input_image",
                    label='Upload Screenshot',
                    type='pil',
                )
                gr.Markdown("### 🎛️ Detection Settings")
                box_threshold_component = gr.Slider(
                    minimum=0.01, maximum=1.0, step=0.01, value=0.05,
                    label='Box Threshold',
                    info="Lower = more detections",
                )
                iou_threshold_component = gr.Slider(
                    minimum=0.01, maximum=1.0, step=0.01, value=0.1,
                    label='IOU Threshold',
                    info="Overlap filtering",
                )
                use_paddleocr_component = gr.Checkbox(value=True, label='Use PaddleOCR')
                imgsz_component = gr.Slider(
                    minimum=640, maximum=1920, step=32, value=640,
                    label='Image Size',
                )
            submit_button_component = gr.Button(variant='primary', value='🚀 Process')

        # Right: results.
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("🖼️ Result"):
                    image_output_component = gr.Image(label='Annotated Image', type='pil')
                with gr.Tab("📝 Elements"):
                    text_output_component = gr.Markdown(value="*Results will appear here...*")

    # Input order must match process()'s positional parameters.
    submit_button_component.click(
        fn=process,
        inputs=[
            image_input_component,
            box_threshold_component,
            iou_threshold_component,
            use_paddleocr_component,
            imgsz_component,
        ],
        outputs=[image_output_component, text_output_component],
        show_progress=True,
    )
# Launch the app when this file is run as a script.
if __name__ == "__main__":
    try:
        # Cap the request queue at 10 pending events.
        demo.queue(max_size=10)
        demo.launch(
            share=False,
            show_error=True,
            server_name="0.0.0.0",
            server_port=7860
        )
    except Exception as e:
        # Log the full traceback, then re-raise. Swallowing the error would let
        # the process exit with status 0 and hide the failure from whatever
        # runs the app.
        print(f"Launch failed: {e}")
        print(traceback.format_exc())
        raise