Matis Despujols committed
Commit e9324f7 · verified · 1 Parent(s): 072cabd

Update README.md

Files changed (1)
  1. README.md +13 -198
README.md CHANGED
@@ -1,198 +1,13 @@
- import os
- os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
- os.environ['GRADIO_DEFAULT_LANG'] = 'en'
-
- import gradio as gr
- import torch
- import cv2
- import numpy as np
- from PIL import Image
- from typing import Tuple, List
- from rfdetr.detr import RFDETRMedium
-
- # UI Element classes
- CLASSES = ['button', 'field', 'heading', 'iframe', 'image', 'label', 'link', 'text']
-
- # Single color for all boxes (BGR format for OpenCV)
- BOX_COLOR = (0, 255, 0)  # Green
-
- # Global model variable
- model = None
-
- def load_model(model_path: str = "model.pth"):
-     """Load RF-DETR model"""
-     global model
-     if model is None:
-         print("Loading RF-DETR model...")
-         model = RFDETRMedium(pretrain_weights=model_path, resolution=1600)
-         print("Model loaded successfully!")
-     return model
-
- def draw_detections(
-     image: np.ndarray,
-     boxes: List[Tuple[int, int, int, int]],
-     scores: List[float],
-     classes: List[int],
-     thickness: int = 3,
-     font_scale: float = 0.6
- ) -> np.ndarray:
-     """Draw detection boxes and labels on image"""
-     img_with_boxes = image.copy()
-
-     for box, score, cls_id in zip(boxes, scores, classes):
-         x1, y1, x2, y2 = map(int, box)
-
-         # Draw rectangle
-         cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), BOX_COLOR, thickness)
-
-         # Prepare label with confidence score only
-         label = f"{score:.2f}"
-
-         # Calculate label size and position
-         (label_width, label_height), baseline = cv2.getTextSize(
-             label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness=2
-         )
-
-         # Draw label background
-         label_y = max(y1 - 10, label_height + 10)
-         cv2.rectangle(
-             img_with_boxes,
-             (x1, label_y - label_height - baseline - 5),
-             (x1 + label_width + 5, label_y + baseline - 5),
-             BOX_COLOR,
-             -1
-         )
-
-         # Draw label text
-         cv2.putText(
-             img_with_boxes,
-             label,
-             (x1 + 2, label_y - baseline - 5),
-             cv2.FONT_HERSHEY_SIMPLEX,
-             font_scale,
-             (255, 255, 255),
-             thickness=2
-         )
-
-     return img_with_boxes
-
- @torch.inference_mode()
- def detect_ui_elements(
-     image: Image.Image,
-     confidence_threshold: float,
-     line_thickness: int
- ) -> Tuple[Image.Image, str]:
-     """
-     Detect UI elements in the uploaded image
-
-     Args:
-         image: Input PIL Image
-         confidence_threshold: Minimum confidence score for detections
-         line_thickness: Thickness of bounding box lines
-
-     Returns:
-         Annotated image and detection summary text
-     """
-     try:
-         if image is None:
-             return None, "Please upload an image first."
-
-         # Load model
-         model = load_model()
-
-         # Convert PIL to numpy array (RGB)
-         img_array = np.array(image)
-
-         # Convert RGB to BGR for OpenCV
-         img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
-
-         # Run detection (returns supervision Detections object)
-         detections = model.predict(img_array, threshold=confidence_threshold)
-
-         # Extract detection data
-         filtered_boxes = detections.xyxy  # Bounding boxes in xyxy format
-         filtered_scores = detections.confidence  # Confidence scores
-         filtered_classes = detections.class_id  # Class IDs
-
-         # Draw detections
-         annotated_img = draw_detections(
-             img_bgr,
-             filtered_boxes.tolist(),
-             filtered_scores.tolist(),
-             filtered_classes.tolist(),
-             thickness=line_thickness
-         )
-
-         # Convert back to RGB for display
-         annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
-         annotated_pil = Image.fromarray(annotated_img_rgb)
-
-         # Create summary text
-         summary_text = f"**Total detections:** {len(filtered_boxes)}"
-
-         return annotated_pil, summary_text
-
-     except Exception as e:
-         import traceback
-         error_msg = f"**Error during detection:**\n\n```\n{str(e)}\n\n{traceback.format_exc()}\n```"
-         print(error_msg)  # Also print to logs
-         return None, error_msg
-
- # Gradio interface
- with gr.Blocks(title="CU-1 UI Element Detector", theme=gr.themes.Soft()) as demo:
-
-     gr.Markdown("""
-     # CU-1 UI Element Detector
-
-     Upload a screenshot or UI mockup to automatically detect elements.
-     """)
-
-     with gr.Row():
-         with gr.Column(scale=1):
-             input_image = gr.Image(
-                 type="pil",
-                 label="Upload Screenshot",
-                 height=400,
-                 sources=["upload"]
-             )
-
-             with gr.Accordion("Detection Settings", open=True):
-                 confidence_slider = gr.Slider(
-                     minimum=0.1,
-                     maximum=0.9,
-                     value=0.35,
-                     step=0.05,
-                     label="Confidence Threshold",
-                     info="Higher values = fewer but more confident detections"
-                 )
-
-                 thickness_slider = gr.Slider(
-                     minimum=1,
-                     maximum=6,
-                     value=3,
-                     step=1,
-                     label="Box Line Thickness"
-                 )
-
-             detect_button = gr.Button("Detect Elements", variant="primary", size="lg")
-
-         with gr.Column(scale=1):
-             output_image = gr.Image(
-                 type="pil",
-                 label="Detected Elements",
-                 height=400
-             )
-
-             summary_output = gr.Markdown(label="Detection Summary")
-
-
-     # Connect button
-     detect_button.click(
-         fn=detect_ui_elements,
-         inputs=[input_image, confidence_slider, thickness_slider],
-         outputs=[output_image, summary_output]
-     )
-
- # Launch
- if __name__ == "__main__":
-     demo.queue().launch(share=False)
+ ---
+ title: CU 1
+ emoji: 🏢
+ colorFrom: green
+ colorTo: gray
+ sdk: gradio
+ sdk_version: 5.47.2
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference