ginipick's picture
Update app.py
c2f47fd verified
raw
history blame
11.6 kB
from typing import Optional
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image
import io
import base64, os
from huggingface_hub import snapshot_download
import traceback
# Import utility functions
from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
# Fetch the OmniParser weights from the Hugging Face Hub on first start only;
# an existing local directory is taken to mean "already downloaded".
repo_id = "microsoft/OmniParser-v2.0"  # Hub repository that holds the weights
local_dir = "weights"  # Local directory the snapshot is stored in
if os.path.exists(local_dir):
    print(f"Weights already exist at: {local_dir}")
else:
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    print(f"Repository downloaded to: {local_dir}")
# Load the detection and captioning models once, at import time. A failure is
# reported on stdout and then re-raised so the app never starts without them.
try:
    yolo_model = get_yolo_model(model_path="weights/icon_detect/model.pt")
    caption_model_processor = get_caption_model_processor(
        model_name="florence2",
        model_name_or_path="weights/icon_caption",
    )
except Exception as load_error:
    print(f"Error loading models: {load_error}")
    raise
else:
    print("Models loaded successfully")
# Markdown/HTML header rendered at the top of the page. Emoji are stored as
# real Unicode characters rather than mis-decoded UTF-8 byte sequences.
# NOTE(review): the memo emoji on the "Supports both ..." line is a best-guess
# restoration of a partially lost glyph - confirm against the intended design.
MARKDOWN = """
# OmniParser V2 Pro🔥
<div style="background-color: #f0f8ff; padding: 15px; border-radius: 10px; margin-bottom: 20px;">
<p style="margin: 0;">🎯 <strong>AI-powered screen understanding tool</strong> that detects UI elements and extracts text with high accuracy.</p>
<p style="margin: 5px 0 0 0;">📝 Supports both PaddleOCR and EasyOCR for flexible text extraction.</p>
</div>
"""
# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print(f"Using device: {DEVICE}")
# Stylesheet handed to gr.Blocks(css=...). One rule per entry; the empty first
# and last entries reproduce the leading and trailing newline of the sheet.
custom_css = "\n".join((
    "",
    "body { background-color: #f0f2f5; }",
    ".gradio-container { font-family: 'Segoe UI', sans-serif; max-width: 1400px; margin: auto; }",
    "h1, h2, h3, h4 { color: #283E51; }",
    "button { border-radius: 6px; transition: all 0.3s ease; }",
    "button:hover { transform: translateY(-2px); box-shadow: 0 4px 12px rgba(0,0,0,0.15); }",
    ".output-image { border: 2px solid #e1e4e8; border-radius: 8px; }",
    "#input_image { border: 2px dashed #4a90e2; border-radius: 8px; }",
    "#input_image:hover { border-color: #2c5aa0; }",
    ".gr-box { border-radius: 8px; }",
    ".gr-padded { padding: 16px; }",
    "",
))
def _run_ocr(image_input, use_paddleocr):
    """Run OCR on the image and return ``(text, ocr_bbox)``, never ``None``.

    OCR is best-effort: any failure is logged and replaced by empty results so
    that element detection can still run on the image.
    """
    try:
        ocr_bbox_rslt, _ = check_ocr_box(
            image_input,
            display_img=False,
            output_bb_format='xyxy',
            goal_filtering=None,
            easyocr_args={'paragraph': False, 'text_threshold': 0.9},
            use_paddleocr=use_paddleocr,
        )
        if ocr_bbox_rslt is None:
            print("OCR returned None, using empty results")
            return [], []

        text, ocr_bbox = ocr_bbox_rslt
        text = [] if text is None else text
        ocr_bbox = [] if ocr_bbox is None else ocr_bbox
        print(f"OCR found {len(text)} text regions")
        return text, ocr_bbox
    except Exception as e:
        print(f"OCR error: {e}, continuing with empty OCR results")
        return [], []


def _format_parsed_content(parsed_content_list) -> str:
    """Render the parsed elements as Markdown text for the results tab."""
    if not parsed_content_list:
        return "ℹ️ No UI elements detected. Try adjusting the detection thresholds."

    # Empty entries are skipped, but the remaining ones keep their original
    # index so the numbering still refers to positions in the parsed list.
    entries = [
        f"**Icon {i}:** {v}"
        for i, v in enumerate(parsed_content_list)
        if v
    ]
    # Markdown treats a single newline as a soft break, which would run all
    # entries together into one paragraph; a blank line keeps one per line.
    return "\n\n".join(["🎯 **Detected Elements:**", *entries])


@spaces.GPU
@torch.inference_mode()
def process(
    image_input,
    box_threshold,
    iou_threshold,
    use_paddleocr,
    imgsz
) -> tuple:
    """Annotate a screenshot and describe the UI elements found in it.

    Args:
        image_input: Image to analyse, or ``None`` when nothing was uploaded.
            Its ``size[0]`` is read as the width (the UI supplies a PIL image).
        box_threshold: Forwarded to ``get_som_labeled_img`` as ``BOX_TRESHOLD``.
        iou_threshold: Forwarded to ``get_som_labeled_img``.
        use_paddleocr: Forwarded to ``check_ocr_box`` to select the OCR backend.
        imgsz: Forwarded to ``get_som_labeled_img``.

    Returns:
        A ``(image, message)`` tuple. On success ``image`` is the annotated
        image and ``message`` the Markdown list of detected elements. If
        detection or decoding fails, the original image is returned together
        with an error message. If no image was supplied, or an unexpected
        error occurs, ``image`` is ``None``.
    """
    # Input validation
    if image_input is None:
        return None, "⚠️ Please upload an image for processing."

    try:
        print(f"Processing with parameters: box_threshold={box_threshold}, "
              f"iou_threshold={iou_threshold}, use_paddleocr={use_paddleocr}, imgsz={imgsz}")

        # Scale the overlay drawing with the image width (relative to 3200 px),
        # clamped to [0.5, 2.0] so labels stay legible on any input size.
        image_width = image_input.size[0]
        box_overlay_ratio = max(0.5, min(2.0, image_width / 3200))
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }

        text, ocr_bbox = _run_ocr(image_input, use_paddleocr)

        # Get labeled image and parsed content via SOM (YOLO + caption model)
        try:
            dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
                image_input,
                yolo_model,
                BOX_TRESHOLD=box_threshold,
                output_coord_in_ratio=True,
                ocr_bbox=ocr_bbox,
                draw_bbox_config=draw_bbox_config,
                caption_model_processor=caption_model_processor,
                ocr_text=text,
                iou_threshold=iou_threshold,
                imgsz=imgsz,
            )
            if dino_labled_img is None:
                raise ValueError("Failed to generate labeled image")
        except Exception as e:
            print(f"Error in SOM processing: {e}")
            # Fall back to the untouched input so the user still sees an image.
            return image_input, f"⚠️ Error during element detection: {str(e)}"

        # The labeled image is decoded here from a base64-encoded payload.
        try:
            image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img)))
            print('Successfully decoded processed image')
        except Exception as e:
            print(f"Error decoding image: {e}")
            return image_input, f"⚠️ Error decoding processed image: {str(e)}"

        parsed_text = _format_parsed_content(parsed_content_list)

        # parsed_content_list may be None; counting it directly would raise and
        # turn an otherwise successful run into an "unexpected error".
        print(f'Finished processing image. Found {len(parsed_content_list or [])} elements.')
        return image, parsed_text

    except Exception as e:
        error_msg = f"⚠️ Unexpected error: {str(e)}"
        print(f"Error during processing: {e}")
        print(traceback.format_exc())
        return None, error_msg
# Build the Gradio UI: a settings/upload column on the left, result tabs on the
# right, and the wiring that sends the inputs through process().
# Emoji in labels are written as real Unicode characters instead of
# mis-decoded UTF-8 byte sequences, which rendered as garbage in the page.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="OmniParser V2 Pro") as demo:
    gr.Markdown(MARKDOWN)

    with gr.Row():
        # Left sidebar: Upload and settings
        with gr.Column(scale=1):
            with gr.Accordion("📤 Upload Image & Settings", open=True):
                image_input_component = gr.Image(
                    type='pil',
                    label='Upload Screenshot/UI Image',
                    elem_id="input_image"
                )

                gr.Markdown("### 🎛️ Detection Settings")

                with gr.Group():
                    box_threshold_component = gr.Slider(
                        label='📊 Box Threshold',
                        minimum=0.01,
                        maximum=1.0,
                        step=0.01,
                        value=0.05,
                        info="Lower values detect more elements (may include false positives)"
                    )
                    iou_threshold_component = gr.Slider(
                        label='🔲 IOU Threshold',
                        minimum=0.01,
                        maximum=1.0,
                        step=0.01,
                        value=0.1,
                        info="Controls overlap filtering (lower = less filtering)"
                    )
                    use_paddleocr_component = gr.Checkbox(
                        label='🔤 Use PaddleOCR',
                        value=True,
                        info="✓ PaddleOCR (faster) | ✗ EasyOCR (more languages)"
                    )
                    # NOTE(review): the original glyph here was partially lost;
                    # the ruler emoji is a best-guess restoration - confirm.
                    imgsz_component = gr.Slider(
                        label='📐 Detection Image Size',
                        minimum=640,
                        maximum=1920,
                        step=32,
                        value=640,
                        info="Higher = better accuracy but slower (640 recommended)"
                    )

                submit_button_component = gr.Button(
                    value='🚀 Process Image',
                    variant='primary',
                    size='lg'
                )

            # Static usage tips shown below the settings
            gr.Markdown("### 💡 Quick Tips")
            gr.Markdown(
                "\n"
                "- **For mobile apps:** Use default settings\n"
                "- **For desktop apps:** Try image size 1280\n"
                "- **For complex UIs:** Lower box threshold to 0.03\n"
                "- **Too many boxes?** Increase IOU threshold\n"
            )

        # Right main area: Results tabs
        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.Tab("🖼️ Annotated Image"):
                    image_output_component = gr.Image(
                        type='pil',
                        label='Processed Image with Annotations',
                        elem_classes=["output-image"]
                    )
                # NOTE(review): the original glyph here was partially lost;
                # the memo emoji is a best-guess restoration - confirm.
                with gr.Tab("📝 Extracted Elements"):
                    text_output_component = gr.Markdown(
                        value="*Parsed elements will appear here after processing...*",
                        elem_classes=["parsed-text"]
                    )

            # Empty status placeholder; nothing in this section writes to it.
            status_text = gr.Markdown("", visible=True)

    # Run process() on click and show its two results in the output tabs
    submit_button_component.click(
        fn=process,
        inputs=[
            image_input_component,
            box_threshold_component,
            iou_threshold_component,
            use_paddleocr_component,
            imgsz_component
        ],
        outputs=[image_output_component, text_output_component],
        show_progress=True
    )

    # Offer clickable sample inputs only when a "samples" directory exists
    if os.path.exists("samples"):
        gr.Examples(
            examples=[
                ["samples/mobile_app.png", 0.05, 0.1, True, 640],
                ["samples/desktop_app.png", 0.05, 0.1, True, 1280],
            ],
            inputs=[
                image_input_component,
                box_threshold_component,
                iou_threshold_component,
                use_paddleocr_component,
                imgsz_component
            ],
            outputs=[image_output_component, text_output_component],
            fn=process,
            cache_examples=False
        )
# Script entry point: enable the request queue (at most 10 pending jobs), then
# serve on all interfaces at port 7860. Startup failures are printed and re-raised.
if __name__ == "__main__":
    try:
        demo.queue(max_size=10)
        demo.launch(share=False, show_error=True, server_name="0.0.0.0", server_port=7860)
    except Exception as launch_error:
        print(f"Failed to launch app: {launch_error}")
        raise