small fixes
Browse files
- annotation_tab/annotation_logic.py +7 -2
- annotation_tab/annotation_setup.py +7 -1
- app.py +7 -1
- inference_tab/inference_logic.py +30 -16
- requirements.txt +1 -0
annotation_tab/annotation_logic.py
CHANGED
|
@@ -7,7 +7,7 @@ from config import OUTPUT_DIR
|
|
| 7 |
# ==== CONFIG ====
|
| 8 |
IMAGE_FOLDER = os.path.join(OUTPUT_DIR,"blobs")
|
| 9 |
os.makedirs(IMAGE_FOLDER, exist_ok=True)
|
| 10 |
-
CSV_FILE = os.path.join(OUTPUT_DIR,"annotations")
|
| 11 |
|
| 12 |
# ==== STATE ====
|
| 13 |
if os.path.exists(CSV_FILE):
|
|
@@ -56,7 +56,7 @@ def save_annotation(user_text):
|
|
| 56 |
if filename in annotated_ids:
|
| 57 |
df_annotations.loc[df_annotations["blob_id"] == filename, "human_ocr"] = text_value
|
| 58 |
else:
|
| 59 |
-
new_row = pd.DataFrame([{"blob_id": filename, "human_ocr": text_value}])
|
| 60 |
df_annotations = pd.concat([df_annotations, new_row], ignore_index=True)
|
| 61 |
annotated_ids.add(filename)
|
| 62 |
|
|
@@ -117,3 +117,8 @@ def save_and_exit(user_text):
|
|
| 117 |
save_annotation(user_text)
|
| 118 |
threading.Timer(1, shutdown).start()
|
| 119 |
return None, "", gr.update(visible=True, value="Session closed."), ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
# ==== CONFIG ====
|
| 8 |
IMAGE_FOLDER = os.path.join(OUTPUT_DIR,"blobs")
|
| 9 |
os.makedirs(IMAGE_FOLDER, exist_ok=True)
|
| 10 |
+
CSV_FILE = os.path.join(OUTPUT_DIR,"annotations.csv")
|
| 11 |
|
| 12 |
# ==== STATE ====
|
| 13 |
if os.path.exists(CSV_FILE):
|
|
|
|
| 56 |
if filename in annotated_ids:
|
| 57 |
df_annotations.loc[df_annotations["blob_id"] == filename, "human_ocr"] = text_value
|
| 58 |
else:
|
| 59 |
+
new_row = pd.DataFrame([{"blob_id": os.path.splitext(filename)[0], "human_ocr": text_value}])
|
| 60 |
df_annotations = pd.concat([df_annotations, new_row], ignore_index=True)
|
| 61 |
annotated_ids.add(filename)
|
| 62 |
|
|
|
|
| 117 |
save_annotation(user_text)
|
| 118 |
threading.Timer(1, shutdown).start()
|
| 119 |
return None, "", gr.update(visible=True, value="Session closed."), ""
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def get_current_annotations_path():
|
| 123 |
+
import os
|
| 124 |
+
return os.path.join(OUTPUT_DIR, "annotations.csv")
|
annotation_tab/annotation_setup.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from .annotation_logic import (
|
| 3 |
save_and_next, previous_image, delete_and_next, save_and_exit,
|
| 4 |
-
get_current_image_path, get_annotation_for_image
|
| 5 |
)
|
| 6 |
|
| 7 |
def get_annotation_widgets():
|
|
@@ -16,10 +16,16 @@ def get_annotation_widgets():
|
|
| 16 |
next_btn = gr.Button("Save & Next")
|
| 17 |
del_btn = gr.Button("Delete & Next", variant="stop")
|
| 18 |
exit_btn = gr.Button("Save & Exit", variant="secondary")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
next_btn.click(save_and_next, inputs=txt, outputs=[img, txt, message, image_path_display])
|
| 21 |
prev_btn.click(previous_image, outputs=[img, txt, message, image_path_display])
|
| 22 |
del_btn.click(delete_and_next, outputs=[img, txt, message, image_path_display])
|
| 23 |
exit_btn.click(save_and_exit, inputs=txt, outputs=[img, txt, message, image_path_display])
|
|
|
|
|
|
|
| 24 |
|
| 25 |
return [message, image_path_display, img, txt, hint, prev_btn, next_btn, del_btn, exit_btn]
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from .annotation_logic import (
|
| 3 |
save_and_next, previous_image, delete_and_next, save_and_exit,
|
| 4 |
+
get_current_image_path, get_annotation_for_image, get_current_annotations_path
|
| 5 |
)
|
| 6 |
|
| 7 |
def get_annotation_widgets():
|
|
|
|
| 16 |
next_btn = gr.Button("Save & Next")
|
| 17 |
del_btn = gr.Button("Delete & Next", variant="stop")
|
| 18 |
exit_btn = gr.Button("Save & Exit", variant="secondary")
|
| 19 |
+
download_btn = gr.Button("Save Annotations")
|
| 20 |
+
with gr.Row():
|
| 21 |
+
download_file = gr.File(label="Download CSV", interactive=False)
|
| 22 |
+
|
| 23 |
|
| 24 |
next_btn.click(save_and_next, inputs=txt, outputs=[img, txt, message, image_path_display])
|
| 25 |
prev_btn.click(previous_image, outputs=[img, txt, message, image_path_display])
|
| 26 |
del_btn.click(delete_and_next, outputs=[img, txt, message, image_path_display])
|
| 27 |
exit_btn.click(save_and_exit, inputs=txt, outputs=[img, txt, message, image_path_display])
|
| 28 |
+
download_btn.click(lambda: get_current_annotations_path(),outputs=download_file)
|
| 29 |
+
|
| 30 |
|
| 31 |
return [message, image_path_display, img, txt, hint, prev_btn, next_btn, del_btn, exit_btn]
|
app.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import logging
|
| 3 |
from inference_tab import get_inference_widgets, run_inference
|
|
@@ -6,14 +9,17 @@ from annotation_tab import get_annotation_widgets
|
|
| 6 |
# setup logging
|
| 7 |
logging.basicConfig(level=logging.DEBUG)
|
| 8 |
|
|
|
|
| 9 |
with gr.Blocks() as demo:
|
| 10 |
with gr.Tab("Inference"):
|
| 11 |
get_inference_widgets(run_inference)
|
| 12 |
with gr.Tab("Annotation"):
|
| 13 |
get_annotation_widgets()
|
| 14 |
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
demo.launch(server_name="0.0.0.0", server_port=7860, inbrowser=False)
|
| 18 |
|
| 19 |
|
|
|
|
| 1 |
+
# [DEBUG]
|
| 2 |
+
#from osgeo import gdal
|
| 3 |
+
|
| 4 |
import gradio as gr
|
| 5 |
import logging
|
| 6 |
from inference_tab import get_inference_widgets, run_inference
|
|
|
|
| 9 |
# setup logging
|
| 10 |
logging.basicConfig(level=logging.DEBUG)
|
| 11 |
|
| 12 |
+
|
| 13 |
with gr.Blocks() as demo:
|
| 14 |
with gr.Tab("Inference"):
|
| 15 |
get_inference_widgets(run_inference)
|
| 16 |
with gr.Tab("Annotation"):
|
| 17 |
get_annotation_widgets()
|
| 18 |
|
| 19 |
+
# [DEBUG]
|
| 20 |
+
#demo.launch(inbrowser=True)
|
| 21 |
|
| 22 |
+
# [PROD]
|
| 23 |
demo.launch(server_name="0.0.0.0", server_port=7860, inbrowser=False)
|
| 24 |
|
| 25 |
|
inference_tab/inference_logic.py
CHANGED
|
@@ -46,15 +46,13 @@ def run_inference(image_path, gcp_path, city_name, score_th):
|
|
| 46 |
yield msg, None
|
| 47 |
|
| 48 |
# === POST OCR ===
|
| 49 |
-
for msg in fuzzyMatch():
|
| 50 |
if msg.endswith(".csv"):
|
| 51 |
-
yield f"Finished! CSV saved at {msg}", msg
|
| 52 |
else:
|
| 53 |
yield msg, None
|
| 54 |
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
|
| 59 |
def getBBoxes(image_path, tile_size=256, overlap=0.3, confidence_threshold=0.25):
|
| 60 |
yield f"DEBUG: Received image_path: {image_path}"
|
|
@@ -542,7 +540,7 @@ def extractSegments(image_path, min_size=500, margin=10):
|
|
| 542 |
|
| 543 |
|
| 544 |
def blobsOCR(image_path):
|
| 545 |
-
|
| 546 |
# Load model + processor
|
| 547 |
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-str")
|
| 548 |
model = VisionEncoderDecoderModel.from_pretrained("muk42/trocr_streets")
|
|
@@ -558,7 +556,7 @@ def blobsOCR(image_path):
|
|
| 558 |
|
| 559 |
|
| 560 |
# Open output file for writing
|
| 561 |
-
OCR_PATH = os.path.join(OUTPUT_DIR,"ocr")
|
| 562 |
with open(OCR_PATH, "w", encoding="utf-8") as f_out:
|
| 563 |
# Process each image
|
| 564 |
image_folder = os.path.join(OUTPUT_DIR,"blobs")
|
|
@@ -684,15 +682,21 @@ def georefImg(image_path, gcp_path):
|
|
| 684 |
|
| 685 |
|
| 686 |
|
| 687 |
-
yield
|
| 688 |
|
| 689 |
|
| 690 |
def extractStreetNet(city_name):
|
| 691 |
yield f"Extract OSM street network for {city_name}"
|
| 692 |
-
G = ox.graph_from_place(city_name, network_type='
|
| 693 |
G_proj = ox.project_graph(G)
|
| 694 |
-
|
| 695 |
edges_3857 = edges.to_crs(epsg=3857)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 696 |
OSM_PATH=os.path.join(OUTPUT_DIR,"osm_extract.geojson")
|
| 697 |
edges_3857.to_file(OSM_PATH, driver="GeoJSON")
|
| 698 |
yield "Done."
|
|
@@ -709,11 +713,13 @@ def best_street_match(point, query_name, edges_gdf, max_distance=100):
|
|
| 709 |
best_match = process.extractOne(query_name, candidate_names, scorer=fuzz.ratio)
|
| 710 |
return best_match # (name, score, index)
|
| 711 |
|
| 712 |
-
def fuzzyMatch():
|
| 713 |
COORD_PATH=os.path.join(OUTPUT_DIR,"centroids.csv")
|
| 714 |
OCR_PATH=os.path.join(OUTPUT_DIR,"ocr.csv")
|
| 715 |
coords_df = pd.read_csv(COORD_PATH)
|
| 716 |
-
names_df = pd.read_csv(OCR_PATH,
|
|
|
|
|
|
|
| 717 |
merged_df = coords_df.merge(names_df, on="blob_id")
|
| 718 |
|
| 719 |
gdf = gpd.GeoDataFrame(
|
|
@@ -723,13 +729,12 @@ def fuzzyMatch():
|
|
| 723 |
)
|
| 724 |
|
| 725 |
OSM_PATH=os.path.join(OUTPUT_DIR,"osm_extract.geojson")
|
| 726 |
-
osm_gdf = gpd.read_file(OSM_PATH)
|
| 727 |
-
osm_gdf = osm_gdf[osm_gdf['name'].notnull()]
|
| 728 |
|
| 729 |
yield "Process OSM candidates..."
|
| 730 |
results = []
|
| 731 |
for _, row in gdf.iterrows():
|
| 732 |
-
match = best_street_match(row.geometry, row['
|
| 733 |
if match:
|
| 734 |
results.append({
|
| 735 |
"blob_id": row.blob_id,
|
|
@@ -752,4 +757,13 @@ def fuzzyMatch():
|
|
| 752 |
results_df = pd.DataFrame(results)
|
| 753 |
RES_PATH=os.path.join(OUTPUT_DIR,"street_matches.csv")
|
| 754 |
results_df.to_csv(RES_PATH, index=False)
|
| 755 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
yield msg, None
|
| 47 |
|
| 48 |
# === POST OCR ===
|
| 49 |
+
for msg in fuzzyMatch(score_th):
|
| 50 |
if msg.endswith(".csv"):
|
| 51 |
+
yield f"Finished! CSV saved at {msg}. Street labels are ready for manual input.", msg
|
| 52 |
else:
|
| 53 |
yield msg, None
|
| 54 |
|
| 55 |
+
|
|
|
|
|
|
|
| 56 |
|
| 57 |
def getBBoxes(image_path, tile_size=256, overlap=0.3, confidence_threshold=0.25):
|
| 58 |
yield f"DEBUG: Received image_path: {image_path}"
|
|
|
|
| 540 |
|
| 541 |
|
| 542 |
def blobsOCR(image_path):
|
| 543 |
+
yield "Load OCR model.."
|
| 544 |
# Load model + processor
|
| 545 |
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-str")
|
| 546 |
model = VisionEncoderDecoderModel.from_pretrained("muk42/trocr_streets")
|
|
|
|
| 556 |
|
| 557 |
|
| 558 |
# Open output file for writing
|
| 559 |
+
OCR_PATH = os.path.join(OUTPUT_DIR,"ocr.csv")
|
| 560 |
with open(OCR_PATH, "w", encoding="utf-8") as f_out:
|
| 561 |
# Process each image
|
| 562 |
image_folder = os.path.join(OUTPUT_DIR,"blobs")
|
|
|
|
| 682 |
|
| 683 |
|
| 684 |
|
| 685 |
+
yield "Done."
|
| 686 |
|
| 687 |
|
| 688 |
def extractStreetNet(city_name):
|
| 689 |
yield f"Extract OSM street network for {city_name}"
|
| 690 |
+
G = ox.graph_from_place(city_name, network_type='all')
|
| 691 |
G_proj = ox.project_graph(G)
|
| 692 |
+
edges = ox.graph_to_gdfs(G_proj, nodes=False, edges=True, fill_edge_geometry=True)
|
| 693 |
edges_3857 = edges.to_crs(epsg=3857)
|
| 694 |
+
edges_3857 = edges_3857[['osmid','name', 'geometry']]
|
| 695 |
+
edges_3857 = edges_3857[edges_3857['name'].notnull()]
|
| 696 |
+
|
| 697 |
+
edges_3857['name'] = edges_3857['name'].apply(
|
| 698 |
+
lambda x: x[0] if isinstance(x, list) and len(x) > 0 else x)
|
| 699 |
+
|
| 700 |
OSM_PATH=os.path.join(OUTPUT_DIR,"osm_extract.geojson")
|
| 701 |
edges_3857.to_file(OSM_PATH, driver="GeoJSON")
|
| 702 |
yield "Done."
|
|
|
|
| 713 |
best_match = process.extractOne(query_name, candidate_names, scorer=fuzz.ratio)
|
| 714 |
return best_match # (name, score, index)
|
| 715 |
|
| 716 |
+
def fuzzyMatch(score_th):
|
| 717 |
COORD_PATH=os.path.join(OUTPUT_DIR,"centroids.csv")
|
| 718 |
OCR_PATH=os.path.join(OUTPUT_DIR,"ocr.csv")
|
| 719 |
coords_df = pd.read_csv(COORD_PATH)
|
| 720 |
+
names_df = pd.read_csv(OCR_PATH,
|
| 721 |
+
names=['blob_id','pred_text'],
|
| 722 |
+
dtype={"blob_id": "int64", "pred_text": "string"})
|
| 723 |
merged_df = coords_df.merge(names_df, on="blob_id")
|
| 724 |
|
| 725 |
gdf = gpd.GeoDataFrame(
|
|
|
|
| 729 |
)
|
| 730 |
|
| 731 |
OSM_PATH=os.path.join(OUTPUT_DIR,"osm_extract.geojson")
|
| 732 |
+
osm_gdf = gpd.read_file(OSM_PATH,dtype={"name": "str"})
|
|
|
|
| 733 |
|
| 734 |
yield "Process OSM candidates..."
|
| 735 |
results = []
|
| 736 |
for _, row in gdf.iterrows():
|
| 737 |
+
match = best_street_match(row.geometry, row['pred_text'], osm_gdf, max_distance=100)
|
| 738 |
if match:
|
| 739 |
results.append({
|
| 740 |
"blob_id": row.blob_id,
|
|
|
|
| 757 |
results_df = pd.DataFrame(results)
|
| 758 |
RES_PATH=os.path.join(OUTPUT_DIR,"street_matches.csv")
|
| 759 |
results_df.to_csv(RES_PATH, index=False)
|
| 760 |
+
|
| 761 |
+
# Delete from the blobs folder the street-label files (<blob_id>.png) whose OSM match score is greater than or equal to the score threshold
|
| 762 |
+
manual_df = results_df[results_df['osm_match_score'] >= int(score_th)]
|
| 763 |
+
for blob_id in manual_df['blob_id']:
|
| 764 |
+
file_path = os.path.join(OUTPUT_DIR,"blobs",f"{blob_id}.png")
|
| 765 |
+
|
| 766 |
+
if os.path.exists(file_path):
|
| 767 |
+
os.remove(file_path)
|
| 768 |
+
|
| 769 |
+
yield f"{RES_PATH}"
|
requirements.txt
CHANGED
|
@@ -13,4 +13,5 @@ Shapely==2.1.1
|
|
| 13 |
torch==2.7.1
|
| 14 |
transformers==4.53.2
|
| 15 |
ultralytics==8.3.94
|
|
|
|
| 16 |
GDAL==3.6.2
|
|
|
|
| 13 |
torch==2.7.1
|
| 14 |
transformers==4.53.2
|
| 15 |
ultralytics==8.3.94
|
| 16 |
+
huggingface_hub[hf_xet]
|
| 17 |
GDAL==3.6.2
|