muk42 committed
Commit 2be323a · 1 Parent(s): 74d134c

models moved to huggingface repo
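As the commit message says, the fine-tuned model weights are no longer bundled with this repo; they are pulled from the Hugging Face Hub at runtime. A minimal sketch of the loading pattern this commit introduces in `inference_tab/inference_logic.py` (repo IDs and filenames taken from the diff below; the variable names `detector` and `recognizer` are just for illustration):

```python
from huggingface_hub import hf_hub_download
from ultralytics import YOLO
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Fine-tuned YOLOv9 text detector: download (and cache) the weight file from the Hub
yolo_weights = hf_hub_download(
    repo_id="muk42/yolov9_streets",
    filename="yolov9c_finetuned.pt",
)
detector = YOLO(yolo_weights)

# Fine-tuned TrOCR recognizer: from_pretrained resolves the Hub repo directly
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-str")
recognizer = VisionEncoderDecoderModel.from_pretrained("muk42/trocr_streets")
```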
annotation_tab/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .annotation_setup import get_annotation_widgets
+
+ __all__ = ["get_annotation_widgets"]
annotation_tab/annotation_logic.py ADDED
@@ -0,0 +1,118 @@
+ import os
+ import pandas as pd
+ import threading
+ import gradio as gr
+
+
+ # ==== CONFIG ====
+ IMAGE_FOLDER = "output/blobs"
+ CSV_FILE = "output/manual_annotations.csv"
+
+ # ==== STATE ====
+ if os.path.exists(CSV_FILE):
+     df_annotations = pd.read_csv(CSV_FILE)
+     annotated_ids = set(df_annotations["blob_id"].astype(str).tolist())
+ else:
+     df_annotations = pd.DataFrame(columns=["blob_id", "human_ocr"])
+     df_annotations.to_csv(CSV_FILE, index=False)
+     annotated_ids = set()
+
+ all_images = [f for f in os.listdir(IMAGE_FOLDER) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
+ all_images_paths = [os.path.join(IMAGE_FOLDER, f) for f in all_images]
+ current_index = 0
+
+ def get_current_image_path():
+     if 0 <= current_index < len(all_images_paths):
+         return all_images_paths[current_index]
+     return None
+
+ def is_annotated(image_path):
+     return os.path.basename(image_path) in annotated_ids
+
+ def get_annotation_for_image(image_path):
+     filename = os.path.basename(image_path)
+     row = df_annotations[df_annotations["blob_id"] == filename]
+     if not row.empty:
+         return row["human_ocr"].values[0]
+     return ""
+
+ def find_next_unannotated_index(start):
+     n = len(all_images_paths)
+     idx = start
+     for _ in range(n):
+         idx = (idx + 1) % n
+         if not is_annotated(all_images_paths[idx]):
+             return idx
+     return None
+
+ def save_annotation(user_text):
+     global df_annotations, annotated_ids
+     img_path = get_current_image_path()
+     if img_path:
+         filename = os.path.basename(img_path)
+         text_value = user_text.strip() if user_text and user_text.strip() else ""
+
+         if filename in annotated_ids:
+             df_annotations.loc[df_annotations["blob_id"] == filename, "human_ocr"] = text_value
+         else:
+             new_row = pd.DataFrame([{"blob_id": filename, "human_ocr": text_value}])
+             df_annotations = pd.concat([df_annotations, new_row], ignore_index=True)
+             annotated_ids.add(filename)
+
+         df_annotations.to_csv(CSV_FILE, index=False)
+
+ def save_and_next(user_text):
+     global current_index
+     if get_current_image_path() is None:
+         return None, "", gr.update(visible=True, value="No images available."), "No image loaded"
+
+     save_annotation(user_text)
+     next_idx = find_next_unannotated_index(current_index)
+     if next_idx is None:
+         return None, "", gr.update(visible=True, value="All images annotated."), ""
+
+     current_index = next_idx
+     img_path = get_current_image_path()
+     annotation = get_annotation_for_image(img_path)
+     return img_path, annotation, gr.update(visible=False), img_path
+
+ def previous_image():
+     global current_index
+     if len(all_images_paths) == 0:
+         return None, "", gr.update(visible=True, value="No images available."), "No image loaded"
+
+     current_index = (current_index - 1) % len(all_images_paths)
+     img_path = get_current_image_path()
+     annotation = get_annotation_for_image(img_path)
+     return img_path, annotation, gr.update(visible=False), img_path
+
+ def delete_and_next():
+     global current_index, all_images_paths, annotated_ids, df_annotations
+     img_path = get_current_image_path()
+     if img_path and os.path.exists(img_path):
+         os.remove(img_path)
+
+         filename = os.path.basename(img_path)
+         if filename in annotated_ids:
+             annotated_ids.remove(filename)
+             df_annotations = df_annotations[df_annotations["blob_id"] != filename]
+             df_annotations.to_csv(CSV_FILE, index=False)
+
+         del all_images_paths[current_index]
+
+     if len(all_images_paths) == 0:
+         return None, "", gr.update(visible=True, value="No images left."), "No image loaded"
+
+     current_index = min(current_index, len(all_images_paths) - 1)
+     img_path = get_current_image_path()
+     annotation = get_annotation_for_image(img_path)
+     return img_path, annotation, gr.update(visible=False), img_path
+
+ def shutdown():
+     os._exit(0)
+
+ def save_and_exit(user_text):
+     if get_current_image_path() is not None:
+         save_annotation(user_text)
+     threading.Timer(1, shutdown).start()
+     return None, "", gr.update(visible=True, value="Session closed."), ""
annotation_tab/annotation_setup.py ADDED
@@ -0,0 +1,25 @@
+ import gradio as gr
+ from .annotation_logic import (
+     save_and_next, previous_image, delete_and_next, save_and_exit,
+     get_current_image_path, get_annotation_for_image
+ )
+
+ def get_annotation_widgets():
+     message = gr.Markdown("", visible=False)
+     image_path_display = gr.Markdown(value=get_current_image_path() or "No image loaded", elem_id="image_path")
+     img = gr.Image(type="filepath", value=get_current_image_path(), label="Blob")
+     txt = gr.Textbox(label="Transcription")
+     hint = gr.Markdown("*If there are multiple street names in the image, please separate them with commas.*")
+
+     with gr.Row():
+         prev_btn = gr.Button("Previous")
+         next_btn = gr.Button("Save & Next")
+         del_btn = gr.Button("Delete & Next", variant="stop")
+         exit_btn = gr.Button("Save & Exit", variant="secondary")
+
+     next_btn.click(save_and_next, inputs=txt, outputs=[img, txt, message, image_path_display])
+     prev_btn.click(previous_image, outputs=[img, txt, message, image_path_display])
+     del_btn.click(delete_and_next, outputs=[img, txt, message, image_path_display])
+     exit_btn.click(save_and_exit, inputs=txt, outputs=[img, txt, message, image_path_display])
+
+     return [message, image_path_display, img, txt, hint, prev_btn, next_btn, del_btn, exit_btn]
app.py CHANGED
@@ -1,23 +1,20 @@
  import gradio as gr
  import logging
+ from inference_tab import get_inference_widgets, run_inference
+ from annotation_tab import get_annotation_widgets
 
  # setup logging
  logging.basicConfig(level=logging.DEBUG)
 
- import cv2
- import numpy as np
-
- def process_image_file(img_file):
-     # img_file.name is the path
-     img = cv2.imread(img_file.name)
-     return f"Shape: {img.shape}"
-
- demo = gr.Interface(
-     fn=process_image_file,
-     inputs=gr.File(label="Select Image File"),
-     outputs="text"
- )
+ with gr.Blocks() as demo:
+     with gr.Tab("Inference"):
+         get_inference_widgets(run_inference)
+     with gr.Tab("Annotation"):
+         get_annotation_widgets()
 
 
  demo.launch(server_name="0.0.0.0", server_port=7860, inbrowser=False)
inference_tab/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .inference_setup import get_inference_widgets
+ from .inference_logic import run_inference
+
+ __all__ = ["get_inference_widgets", "run_inference"]
inference_tab/inference_logic.py ADDED
@@ -0,0 +1,748 @@
+ import numpy as np
+ from ultralytics import YOLO
+ import os
+ import json
+ from PIL import Image
+ from ultralytics import SAM
+ import cv2
+ import torch
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
+ import rasterio
+ import rasterio.features
+ from shapely.geometry import shape
+ import pandas as pd
+ import osmnx as ox
+ from osgeo import gdal
+ import geopandas as gpd
+ from rapidfuzz import process, fuzz
+ from huggingface_hub import hf_hub_download
+
+
+ yolo_weights = hf_hub_download(
+     repo_id="muk42/yolov9_streets",
+     filename="yolov9c_finetuned.pt"
+ )
+
+ def run_inference(image_path, gcp_path, city_name, score_th):
+     # ==== TEXT DETECTION ====
+     yield from getBBoxes(image_path)
+     yield from getSegments(image_path)
+     yield from extractSegments(image_path)
+
+     # === TEXT RECOGNITION ===
+     yield from blobsOCR(image_path)
+
+     # === ADD GEO DATA ===
+     yield from georefImg("output/mask.tif", gcp_path)
+     yield from extractCentroids(image_path)
+     yield from extractStreetNet(city_name)
+
+     # === POST OCR ===
+     for msg in fuzzyMatch():
+         if msg.endswith(".csv"):
+             yield f"Finished! CSV saved at {msg}", msg
+         else:
+             yield msg, None
+
+     return f"Street labels are ready for manual input.\nImage: {image_path}", None
+
+
+
+ def getBBoxes(image_path, tile_size=256, overlap=0.3, confidence_threshold=0.25):
+     yield f"DEBUG: Received image_path: {image_path}"
+     image = cv2.imread(image_path)
+     H, W, _ = image.shape
+     model = YOLO(yolo_weights)
+
+     step = int(tile_size * (1 - overlap))
+     all_detections = []
+
+     total_tiles = 0
+     # Calculate total tiles for progress reporting
+     for y in range(0, H, step):
+         for x in range(0, W, step):
+             # Skip small tiles at the edges
+             if y + tile_size > H or x + tile_size > W:
+                 continue
+             total_tiles += 1
+
+     processed_tiles = 0
+
+     # Tile the image and run prediction
+     for y in range(0, H, step):
+         for x in range(0, W, step):
+             tile = image[y:y+tile_size, x:x+tile_size]
+
+             if tile.shape[0] < tile_size or tile.shape[1] < tile_size:
+                 continue
+
+             results = model.predict(source=tile, imgsz=tile_size, conf=confidence_threshold, verbose=False)
+
+             for result in results:
+                 boxes = result.boxes.xyxy.cpu().numpy()
+                 scores = result.boxes.conf.cpu().numpy()
+                 classes = result.boxes.cls.cpu().numpy()
+
+                 for box, score, cls in zip(boxes, scores, classes):
+                     x1, y1, x2, y2 = box
+                     # Shift box coordinates relative to full image
+                     x1 += x
+                     x2 += x
+                     y1 += y
+                     y2 += y
+                     all_detections.append([x1, y1, x2, y2, score, int(cls)])
+
+             processed_tiles += 1
+             yield f"Processed tile {processed_tiles} of {total_tiles}"
+
+     # After all tiles are processed, save detections to JSON
+     boxes_to_save = [
+         {
+             "bbox": [float(x1), float(y1), float(x2), float(y2)],
+             "score": float(conf),
+             "class": int(cls)
+         }
+         for x1, y1, x2, y2, conf, cls in all_detections
+     ]
+
+     output_path = "output/boxes.json"
+     os.makedirs("output", exist_ok=True)
+     with open(output_path, "w") as f:
+         json.dump(boxes_to_save, f, indent=4)
+
+     yield f"Inference complete. Results saved to {output_path}"
+
+
+ def box_inside_global(box, global_box):
+     x1, y1, x2, y2 = box
+     gx1, gy1, gx2, gy2 = global_box
+     return (x1 >= gx1 and y1 >= gy1 and x2 <= gx2 and y2 <= gy2)
+
+ def nms_iou(box1, box2):
+     x1 = max(box1[0], box2[0])
+     y1 = max(box1[1], box2[1])
+     x2 = min(box1[2], box2[2])
+     y2 = min(box1[3], box2[3])
+
+     inter_area = max(0, x2 - x1) * max(0, y2 - y1)
+     box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
+     box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
+     union_area = box1_area + box2_area - inter_area
+
+     return inter_area / union_area if union_area > 0 else 0
+
+ def non_max_suppression(boxes, scores, iou_threshold=0.5):
+     idxs = np.argsort(scores)[::-1]
+     keep = []
+
+     while len(idxs) > 0:
+         current = idxs[0]
+         keep.append(current)
+         idxs = idxs[1:]
+         idxs = np.array([i for i in idxs if nms_iou(boxes[current], boxes[i]) < iou_threshold])
+
+     return keep
+
+
+
+ def tile_image_with_overlap(image_path, tile_size=1024, overlap=256):
+     """Tile PDF image into overlapping RGB tiles."""
+     image = cv2.imread(image_path)
+     height, width, _ = image.shape
+
+     step = tile_size - overlap
+     tile_list = []
+
+     for y in range(0, height, step):
+         for x in range(0, width, step):
+             x_end = min(x + tile_size, width)
+             y_end = min(y + tile_size, height)
+             x_start = max(0, x_end - tile_size)
+             y_start = max(0, y_end - tile_size)
+
+             tile = image[y_start:y_end, x_start:x_end, :]
+             tile_list.append((tile, (x_start, y_start)))
+
+     return tile_list, image.shape
+
+
+ def compute_iou(box1, box2):
+     """Compute Intersection over Union for two boxes."""
+     x1 = max(box1[0], box2[0])
+     y1 = max(box1[1], box2[1])
+     x2 = min(box1[2], box2[2])
+     y2 = min(box1[3], box2[3])
+
+     inter_area = max(0, x2 - x1) * max(0, y2 - y1)
+     area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
+     area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
+     union_area = area1 + area2 - inter_area
+
+     return inter_area / union_area if union_area > 0 else 0
+
+
+ def merge_boxes(boxes, iou_threshold=0.8):
+     """Merge overlapping boxes based on IoU."""
+     merged = []
+     used = [False] * len(boxes)
+
+     for i, box in enumerate(boxes):
+         if used[i]:
+             continue
+         group = [box]
+         used[i] = True
+         for j in range(i + 1, len(boxes)):
+             if used[j]:
+                 continue
+             if compute_iou(box, boxes[j]) > iou_threshold:
+                 group.append(boxes[j])
+                 used[j] = True
+
+         # Merge group into one bounding box
+         x1 = min(b[0] for b in group)
+         y1 = min(b[1] for b in group)
+         x2 = max(b[2] for b in group)
+         y2 = max(b[3] for b in group)
+         merged.append([x1, y1, x2, y2])
+
+     return merged
+
+
+ def box_area(box):
+     return max(0, box[2] - box[0]) * max(0, box[3] - box[1])
+
+ def is_contained(box1, box2, containment_threshold=0.9):
+     # Check if box1 is mostly inside box2
+     x1 = max(box1[0], box2[0])
+     y1 = max(box1[1], box2[1])
+     x2 = min(box1[2], box2[2])
+     y2 = min(box1[3], box2[3])
+
+     inter_area = max(0, x2 - x1) * max(0, y2 - y1)
+     area1 = box_area(box1)
+     area2 = box_area(box2)
+
+     # If intersection covers most of smaller box area, consider contained
+     smaller_area = min(area1, area2)
+     if smaller_area == 0:
+         return False
+     return (inter_area / smaller_area) >= containment_threshold
+
+ def merge_boxes_iterative(boxes, iou_threshold=0.25, containment_threshold=0.75):
+     boxes = boxes.copy()
+     changed = True
+
+     while changed:
+         changed = False
+         merged = []
+         used = [False] * len(boxes)
+
+         for i, box in enumerate(boxes):
+             if used[i]:
+                 continue
+             group = [box]
+             used[i] = True
+             for j in range(i + 1, len(boxes)):
+                 if used[j]:
+                     continue
+                 iou = compute_iou(box, boxes[j])
+                 contained = is_contained(box, boxes[j], containment_threshold)
+                 if iou > iou_threshold or contained:
+                     group.append(boxes[j])
+                     used[j] = True
+
+             # Merge group into one bounding box
+             x1 = min(b[0] for b in group)
+             y1 = min(b[1] for b in group)
+             x2 = max(b[2] for b in group)
+             y2 = max(b[3] for b in group)
+             merged.append([x1, y1, x2, y2])
+
+         if len(merged) < len(boxes):
+             changed = True
+         boxes = merged
+
+     return boxes
+
+
+ def get_corner_points(box):
+     x1, y1, x2, y2 = box
+     return [
+         [x1, y1],  # top-left
+         [x2, y1],  # top-right
+         [x1, y2],  # bottom-left
+         [x2, y2],  # bottom-right
+     ]
+
+
+ def sample_negative_points_outside_boxes(mask, num_points):
+     points = []
+     tries = 0
+     max_tries = num_points * 20  # fail-safe to avoid infinite loops
+     while len(points) < num_points and tries < max_tries:
+         x = np.random.randint(0, mask.shape[1])
+         y = np.random.randint(0, mask.shape[0])
+         if not mask[y, x]:
+             points.append([x, y])
+         tries += 1
+     return np.array(points)
+
+ def get_inset_corner_points(box, margin=5):
+     x1, y1, x2, y2 = box
+
+     # Ensure box is large enough for the margin
+     x1i = min(x1 + margin, x2)
+     y1i = min(y1 + margin, y2)
+     x2i = max(x2 - margin, x1)
+     y2i = max(y2 - margin, y1)
+
+     return [
+         [x1i, y1i],  # top-left (inset)
+         [x2i, y1i],  # top-right
+         [x1i, y2i],  # bottom-left
+         [x2i, y2i],  # bottom-right
+     ]
+
+
+ def getSegments(image_path, iou=0.5, c_th=0.75, edge_margin=10):
+     """
+     iou: IoU threshold for combining bounding boxes
+     c_th: share of the smaller box that must lie inside the larger box to merge
+     edge_margin: pixel margin for tiles
+
+     TBD as user input
+     # define global bounding box to filter out boxes outside of the main map
+     # [COL_MIN, ROW_MIN, COL_MAX, ROW_MAX]
+     #GLOBAL_BOX = [211,470,6198,4723]
+     """
+
+     yield f"Loading SAM model and data..."
+
+     # Load Ultralytics SAM2.1 model
+     model = SAM("sam2.1_l.pt")
+
+     # Load YOLO-predicted boxes
+     with open("output/boxes.json", "r") as f:
+         box_data = json.load(f)
+
+     # ==== PREPARE BOXES =====
+     yield f"Prepare bounding boxes..."
+     # Non-max suppression
+     boxes = np.array([item["bbox"] for item in box_data])
+     scores = np.array([item["score"] for item in box_data])
+     # Run NMS
+     keep_indices = non_max_suppression(boxes, scores, iou)
+     # Filter data
+     box_data = [box_data[i] for i in keep_indices]
+     # Filter boxes inside global bbox (TBD)
+     #box_data = [entry for entry in box_data if box_inside_global(entry["bbox"], GLOBAL_BOX)]
+     boxes_full = [b["bbox"] for b in box_data]  # Format: [x1, y1, x2, y2]
+
+     # Tile the image
+     yield f"Tile the image..."
+     tiles, (full_height, full_width, _) = tile_image_with_overlap(image_path, tile_size=1024, overlap=50)
+
+     # Prepare full-size mask
+     full_mask = np.zeros((full_height, full_width), dtype=np.uint16)
+     instance_id = 1
+
+     yield f"Running predictions..."
+     # enumerate() takes no tqdm-style desc argument, so plain enumerate is used here
+     for tile_idx, (tile_array, (x_offset, y_offset)) in enumerate(tiles):
+
+         tile_height, tile_width, _ = tile_array.shape
+
+         # Select boxes overlapping this tile
+         candidate_boxes = []
+         for x1, y1, x2, y2 in boxes_full:
+             if (x2 > x_offset) and (x1 < x_offset + tile_width) and (y2 > y_offset) and (y1 < y_offset + tile_height):
+                 candidate_boxes.append([x1, y1, x2, y2])
+
+         if not candidate_boxes:
+             continue
+
+         # Merge overlapping boxes
+         merged_boxes = merge_boxes_iterative(candidate_boxes, iou_threshold=iou, containment_threshold=c_th)
+
+         # Adjust boxes to tile-local coordinates
+         local_boxes = []
+         for x1, y1, x2, y2 in merged_boxes:
+             new_x1 = max(0, x1 - x_offset)
+             new_y1 = max(0, y1 - y_offset)
+             new_x2 = min(tile_width, x2 - x_offset)
+             new_y2 = min(tile_height, y2 - y_offset)
+             local_boxes.append([new_x1, new_y1, new_x2, new_y2])
+
+         tile_h, tile_w, _ = tile_array.shape
+         # Filter local_boxes to remove those too close to the tile edges
+         filtered_local_boxes = []
+         for box in local_boxes:
+             x1, y1, x2, y2 = box
+             if (x1 > edge_margin and y1 > edge_margin and (tile_w - x2) > edge_margin and (tile_h - y2) > edge_margin):
+                 filtered_local_boxes.append(box)
+
+         local_boxes = filtered_local_boxes
+
+         if not local_boxes:
+             continue
+
+         # centroids will be positive point prompts as they align well with the text
+         centroids = [((bx1 + bx2) / 2, (by1 + by2) / 2) for bx1, by1, bx2, by2 in local_boxes]
+
+         # [STRATEGY 2] Negative points are within box at the corners
+         #negative_points_per_box = [get_corner_points(box) for box in local_boxes]
+         # [STRATEGY 3] Negative points are within box at the corners with a bit of a margin
+         negative_points_per_box = [get_inset_corner_points(box, margin=2) for box in local_boxes]
+
+         point_coords = []
+         point_labels = []
+
+         for centroid, neg_points in zip(centroids, negative_points_per_box):
+             if not isinstance(neg_points, list):
+                 neg_points = neg_points.tolist()
+             all_points = [centroid] + neg_points
+             all_labels = [1] + [0] * len(neg_points)
+
+             assert len(all_points) == len(all_labels), f"Point-label mismatch: {len(all_points)} vs {len(all_labels)}"
+
+             point_coords.append(all_points)
+             point_labels.append(all_labels)
+
+         results = model(tile_array,
+                         bboxes=local_boxes,
+                         points=point_coords,
+                         labels=point_labels)
+
+         yield f"Merging segmentation masks..."
+         for result in results:
+             if result.masks is None or result.masks.data is None:
+                 continue
+
+             # Create a copy of the tile image to overlay masks on
+             tile_with_masks = tile_array.copy()
+
+             for mask in result.masks.data:  # each mask: (H, W)
+                 mask_np = mask.cpu().numpy().astype(bool)
+
+                 # Create a red overlay for the mask
+                 red_overlay = np.zeros_like(tile_with_masks, dtype=np.uint8)
+                 red_overlay[..., 0] = 255  # Red channel
+
+                 alpha = 0.5  # Transparency factor
+
+                 # Blend the overlay on the tile where mask is True
+                 tile_with_masks = np.where(
+                     mask_np[..., None],
+                     (1 - alpha) * tile_with_masks + alpha * red_overlay,
+                     tile_with_masks
+                 ).astype(np.uint8)
+
+                 # Paste into full-size canvas
+                 y1 = y_offset
+                 y2 = min(y_offset + tile_height, full_height)
+                 x1 = x_offset
+                 x2 = min(x_offset + tile_width, full_width)
+
+                 cropped_mask = mask_np[:y2 - y1, :x2 - x1]
+                 region = full_mask[y1:y2, x1:x2]
+
+                 region[(cropped_mask) & (region == 0)] = instance_id
+                 instance_id += 1
+
+
+     final_mask = Image.fromarray(full_mask)
+     final_mask.save("output/mask.tif")
+
+     yield f"Saved mask with {instance_id - 1} instances"
+
+
+
+ def extractSegments(image_path, min_size=500, margin=10):
+
+     image = cv2.imread(image_path)
+     mask = cv2.imread("output/mask.tif", cv2.IMREAD_UNCHANGED)
+
+     height, width = mask.shape[:2]
+
+     # Get unique labels (excluding background label 0)
+     blob_ids = np.unique(mask)
+     blob_ids = blob_ids[blob_ids != 0]
+
+     yield f"Found {len(blob_ids)} blobs"
+
+     for blob_id in blob_ids:
+         yield f"Processing blob {blob_id}..."
+         # Create a binary mask for the current blob
+         blob_mask = (mask == blob_id).astype(np.uint8)
+
+         # Skip small blobs (WxH)
+         if np.sum(blob_mask) < min_size:
+             continue
+
+         # Find bounding box of the blob
+         ys, xs = np.where(blob_mask)
+         y_min, y_max = ys.min(), ys.max() + 1
+         x_min, x_max = xs.min(), xs.max() + 1
+
+         # Add margin to bounding box while keeping inside image bounds
+         x_min = max(0, x_min - margin)
+         y_min = max(0, y_min - margin)
+         x_max = min(width, x_max + margin)
+         y_max = min(height, y_max + margin)
+
+         # Crop the region from original image
+         cropped_image = image[y_min:y_max, x_min:x_max]
+         cropped_mask = blob_mask[y_min:y_max, x_min:x_max]
+
+         # Apply mask to original image
+         if image.ndim == 3:
+             masked_image = cv2.bitwise_and(cropped_image, cropped_image, mask=cropped_mask)
+         else:
+             masked_image = cv2.bitwise_and(cropped_image, cropped_image, mask=cropped_mask)
+
+         # Save the masked image
+         output_path = os.path.join('output/blobs', f"{blob_id}.png")
+         os.makedirs(os.path.dirname(output_path), exist_ok=True)
+         cv2.imwrite(output_path, masked_image)
+
+     yield f"Done."
+
+
+ def blobsOCR(image_path):
+
+     # Load model + processor
+     processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-str")
+     model = VisionEncoderDecoderModel.from_pretrained("muk42/trocr_streets")
+     image_extensions = (".png")
+
+     # Device setup
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model.to(device)
+     yield f"Running on {device}..."
+
+     # Open output file for writing (read back later by fuzzyMatch as output/ocr.csv)
+     with open("output/ocr.csv", "w", encoding="utf-8") as f_out:
+         # Process each image
+         image_folder = "output/blobs"
+         for filename in os.listdir(image_folder):
+             if filename.lower().endswith(image_extensions):
+                 image_path = os.path.join(image_folder, filename)
+
+                 try:
+                     image = Image.open(image_path).convert("RGB")
+                     # Move inputs to the same device as the model
+                     pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
+
+                     generated_ids = model.generate(pixel_values)
+                     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+                     # Write to file
+                     name = os.path.splitext(os.path.basename(filename))[0]
+                     f_out.write(f'{name},"{generated_text}"\n')
+                     yield f"{filename} → {generated_text}"
+
+                 except Exception as e:
+                     yield f"Error processing {filename}: {e}"
+                     f_out.write(f"{filename}\tERROR: {e}\n")
+
+
+ def extractCentroids(image_path):
+
+     with rasterio.open("output/georeferenced.tif") as src:
+         mask = src.read(1)
+         transform = src.transform
+
+     labels = np.unique(mask)
+     labels = labels[labels != 0]
+
+     data = []
+
+     # Generate polygons and their values
+     shapes_gen = rasterio.features.shapes(mask, mask=(mask != 0), transform=transform)
+
+     # Create a dict to collect polygons by label
+     polygons_by_label = {}
+
+     for geom, val in shapes_gen:
+         if val == 0:
+             continue
+         polygons_by_label.setdefault(val, []).append(shape(geom))
+
+     # For each label, merge polygons and get centroid
+     for idx, label in enumerate(labels):
+         yield f"Processing {idx+1} out of {len(labels)}"
+         polygons = polygons_by_label.get(label)
+         if not polygons:
+             continue
+
+         # Merge polygons of the same label (if multiple parts)
+         multi_poly = polygons[0]
+         for poly in polygons[1:]:
+             multi_poly = multi_poly.union(poly)
+
+         centroid = multi_poly.centroid
+         data.append({"blob_id": label, "x": centroid.x, "y": centroid.y})
+
+     df = pd.DataFrame(data)
+     df.to_csv("output/centroids.csv", index=False)
+     yield f"Saved centroid coordinates of {len(labels)} blobs."
+
+
+
+ def collectBlobs(image_path):
+     filename = os.path.splitext(os.path.basename(image_path))[0]
+     box_dir = "output/blobs"
+     # Get all filenames in the folder (only files, not subfolders)
+     file_names = [f for f in os.listdir(box_dir) if os.path.isfile(os.path.join(box_dir, f))]
+
+     # Save to text file
+     with open(f"output/{filename}_blobs.txt", "w") as f:
+         for name in file_names:
+             yield f"Writing {name}..."
+             f.write(name + "\n")
+
+ def img_shape(image_path):
+     img = cv2.imread(image_path)
+     return img.shape
+
+ def georefImg(image_path, gcp_path):
+     yield "Reading GCP CSV..."
+     df = pd.read_csv(gcp_path)
+
+     H, W, _ = img_shape(image_path)
+
+     # Build GCPs
+     gcps = []
+     for _, r in df.iterrows():
+         gcps.append(
+             gdal.GCP(
+                 float(r['mapX']),
+                 float(r['mapY']),
+                 0,
+                 float(r['sourceX']),
+                 #H-float(r['sourceY'])
+                 abs(float(r['sourceY']))
+             )
+         )
+
+     tmp_file = "output/tmp.tif"
+
+     gdal.Translate(
+         tmp_file,
+         image_path,
+         format="GTiff",
+         GCPs=gcps,
+         outputSRS="EPSG:3857"
+     )
+
+     geo_file = "output/georeferenced.tif"
+     yield "Running gdalwarp..."
+
+     gdal.Warp(
+         geo_file,
+         tmp_file,
+         dstSRS="EPSG:3857",
+         resampleAlg="near",
+         polynomialOrder=1,
+         creationOptions=["COMPRESS=LZW"]
+     )
+
+     yield f"Done."
+
+
+ def extractStreetNet(city_name):
+     yield f"Extract OSM street network for {city_name}"
+     G = ox.graph_from_place(city_name, network_type='drive')
+     G_proj = ox.project_graph(G)
+     nodes, edges = ox.graph_to_gdfs(G_proj)
+     edges_3857 = edges.to_crs(epsg=3857)
+     edges_3857.to_file("output/osm_extract.geojson", driver="GeoJSON")
+     yield "Done."
+
+
+ def best_street_match(point, query_name, edges_gdf, max_distance=100):
+     buffer = point.buffer(max_distance)
+     nearby_edges = edges_gdf[edges_gdf.intersects(buffer)]
+
+     if nearby_edges.empty:
+         return None, 0
+
+     candidate_names = nearby_edges['name'].tolist()
+     best_match = process.extractOne(query_name, candidate_names, scorer=fuzz.ratio)
+     return best_match  # (name, score, index)
+
+ def fuzzyMatch():
+     coords_df = pd.read_csv("output/centroids.csv")
+     # OCR results are written by blobsOCR as headerless "blob_id,pred_text" rows
+     names_df = pd.read_csv("output/ocr.csv", header=None, names=["blob_id", "pred_text"])
+     merged_df = coords_df.merge(names_df, on="blob_id")
+
+     gdf = gpd.GeoDataFrame(
+         merged_df,
+         geometry=gpd.points_from_xy(merged_df.x, merged_df.y),
+         crs="EPSG:3857"
+     )
+
+     osm_gdf = gpd.read_file("output/osm_extract.geojson")
+     osm_gdf = osm_gdf[osm_gdf['name'].notnull()]
+
+     yield "Process OSM candidates..."
+     results = []
+     for _, row in gdf.iterrows():
+         # The OCR text column is pred_text (there is no 'name' column in gdf)
+         match = best_street_match(row.geometry, row['pred_text'], osm_gdf, max_distance=100)
+         if match:
+             results.append({
+                 "blob_id": row.blob_id,
+                 "x": row.x,
+                 "y": row.y,
+                 "blob_name": row.pred_text,
+                 "best_osm_match": match[0],
+                 "osm_match_score": match[1]
+             })
+         else:
+             results.append({
+                 "blob_id": row.blob_id,
+                 "x": row.x,
+                 "y": row.y,
+                 "blob_name": row.pred_text,
+                 "best_osm_match": None,
+                 "osm_match_score": 0
+             })
+
+     results_df = pd.DataFrame(results)
+     results_df.to_csv("output/street_matches.csv", index=False)
+     yield "output/street_matches.csv"
inference_tab/inference_setup.py ADDED
@@ -0,0 +1,19 @@
+ import gradio as gr
+
+ def get_inference_widgets(run_inference):
+     image_input = gr.File(label="Select Image File")
+     gcp_input = gr.File(label="Select GCP Points File", file_types=[".points"])
+     city_name = gr.Textbox(label="Enter city name")
+     score_th = gr.Textbox(label="Enter a score threshold")
+     run_button = gr.Button("Run Inference")
+     output = gr.Textbox(label="Progress", lines=10, interactive=False)
+     download_file = gr.File(label="Download CSV")
+
+     run_button.click(
+         run_inference,
+         inputs=[image_input, gcp_input, city_name, score_th],
+         outputs=[output, download_file]
+     )
+
+     return image_input, gcp_input, city_name, score_th, run_button, output, download_file
packages.txt ADDED
@@ -0,0 +1,2 @@
+ libgdal-dev
+ gdal-bin
requirements.txt CHANGED
@@ -1,3 +1,18 @@
- numpy
- gradio
- opencv-python
+ geopandas==1.0.1
+ gradio==5.42.0
+ numpy==2.3.2
+ opencv_contrib_python==4.10.0.84
+ opencv_python==4.10.0.84
+ opencv_python_headless==4.10.0.84
+ osgeo==0.0.1
+ osmnx==2.0.6
+ pandas==2.3.1
+ Pillow==10.0.0
+ Pillow==11.3.0
+ rapidfuzz==3.13.0
+ rasterio==1.4.3
+ Shapely==2.1.1
+ torch==2.7.1+cu128
+ transformers==4.53.2
+ ultralytics==8.3.94
+ GDAL==3.7.0