muk42 committed on
Commit
a0f7b41
·
1 Parent(s): 78c781f

small fixes

Browse files
annotation_tab/annotation_logic.py CHANGED
@@ -7,7 +7,7 @@ from config import OUTPUT_DIR
7
  # ==== CONFIG ====
8
  IMAGE_FOLDER = os.path.join(OUTPUT_DIR,"blobs")
9
  os.makedirs(IMAGE_FOLDER, exist_ok=True)
10
- CSV_FILE = os.path.join(OUTPUT_DIR,"annotations")
11
 
12
  # ==== STATE ====
13
  if os.path.exists(CSV_FILE):
@@ -56,7 +56,7 @@ def save_annotation(user_text):
56
  if filename in annotated_ids:
57
  df_annotations.loc[df_annotations["blob_id"] == filename, "human_ocr"] = text_value
58
  else:
59
- new_row = pd.DataFrame([{"blob_id": filename, "human_ocr": text_value}])
60
  df_annotations = pd.concat([df_annotations, new_row], ignore_index=True)
61
  annotated_ids.add(filename)
62
 
@@ -117,3 +117,8 @@ def save_and_exit(user_text):
117
  save_annotation(user_text)
118
  threading.Timer(1, shutdown).start()
119
  return None, "", gr.update(visible=True, value="Session closed."), ""
 
 
 
 
 
 
7
  # ==== CONFIG ====
8
  IMAGE_FOLDER = os.path.join(OUTPUT_DIR,"blobs")
9
  os.makedirs(IMAGE_FOLDER, exist_ok=True)
10
+ CSV_FILE = os.path.join(OUTPUT_DIR,"annotations.csv")
11
 
12
  # ==== STATE ====
13
  if os.path.exists(CSV_FILE):
 
56
  if filename in annotated_ids:
57
  df_annotations.loc[df_annotations["blob_id"] == filename, "human_ocr"] = text_value
58
  else:
59
+ new_row = pd.DataFrame([{"blob_id": os.path.splitext(filename)[0], "human_ocr": text_value}])
60
  df_annotations = pd.concat([df_annotations, new_row], ignore_index=True)
61
  annotated_ids.add(filename)
62
 
 
117
  save_annotation(user_text)
118
  threading.Timer(1, shutdown).start()
119
  return None, "", gr.update(visible=True, value="Session closed."), ""
120
+
121
+
122
+ def get_current_annotations_path():
123
+ import os
124
+ return os.path.join(OUTPUT_DIR, "annotations.csv")
annotation_tab/annotation_setup.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  from .annotation_logic import (
3
  save_and_next, previous_image, delete_and_next, save_and_exit,
4
- get_current_image_path, get_annotation_for_image
5
  )
6
 
7
  def get_annotation_widgets():
@@ -16,10 +16,16 @@ def get_annotation_widgets():
16
  next_btn = gr.Button("Save & Next")
17
  del_btn = gr.Button("Delete & Next", variant="stop")
18
  exit_btn = gr.Button("Save & Exit", variant="secondary")
 
 
 
 
19
 
20
  next_btn.click(save_and_next, inputs=txt, outputs=[img, txt, message, image_path_display])
21
  prev_btn.click(previous_image, outputs=[img, txt, message, image_path_display])
22
  del_btn.click(delete_and_next, outputs=[img, txt, message, image_path_display])
23
  exit_btn.click(save_and_exit, inputs=txt, outputs=[img, txt, message, image_path_display])
 
 
24
 
25
  return [message, image_path_display, img, txt, hint, prev_btn, next_btn, del_btn, exit_btn]
 
1
  import gradio as gr
2
  from .annotation_logic import (
3
  save_and_next, previous_image, delete_and_next, save_and_exit,
4
+ get_current_image_path, get_annotation_for_image, get_current_annotations_path
5
  )
6
 
7
  def get_annotation_widgets():
 
16
  next_btn = gr.Button("Save & Next")
17
  del_btn = gr.Button("Delete & Next", variant="stop")
18
  exit_btn = gr.Button("Save & Exit", variant="secondary")
19
+ download_btn = gr.Button("Save Annotations")
20
+ with gr.Row():
21
+ download_file = gr.File(label="Download CSV", interactive=False)
22
+
23
 
24
  next_btn.click(save_and_next, inputs=txt, outputs=[img, txt, message, image_path_display])
25
  prev_btn.click(previous_image, outputs=[img, txt, message, image_path_display])
26
  del_btn.click(delete_and_next, outputs=[img, txt, message, image_path_display])
27
  exit_btn.click(save_and_exit, inputs=txt, outputs=[img, txt, message, image_path_display])
28
+ download_btn.click(lambda: get_current_annotations_path(),outputs=download_file)
29
+
30
 
31
  return [message, image_path_display, img, txt, hint, prev_btn, next_btn, del_btn, exit_btn]
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import gradio as gr
2
  import logging
3
  from inference_tab import get_inference_widgets, run_inference
@@ -6,14 +9,17 @@ from annotation_tab import get_annotation_widgets
6
  # setup logging
7
  logging.basicConfig(level=logging.DEBUG)
8
 
 
9
  with gr.Blocks() as demo:
10
  with gr.Tab("Inference"):
11
  get_inference_widgets(run_inference)
12
  with gr.Tab("Annotation"):
13
  get_annotation_widgets()
14
 
 
 
15
 
16
-
17
  demo.launch(server_name="0.0.0.0", server_port=7860, inbrowser=False)
18
 
19
 
 
1
+ # [DEBUG]
2
+ #from osgeo import gdal
3
+
4
  import gradio as gr
5
  import logging
6
  from inference_tab import get_inference_widgets, run_inference
 
9
  # setup logging
10
  logging.basicConfig(level=logging.DEBUG)
11
 
12
+
13
  with gr.Blocks() as demo:
14
  with gr.Tab("Inference"):
15
  get_inference_widgets(run_inference)
16
  with gr.Tab("Annotation"):
17
  get_annotation_widgets()
18
 
19
+ # [DEBUG]
20
+ #demo.launch(inbrowser=True)
21
 
22
+ # [PROD]
23
  demo.launch(server_name="0.0.0.0", server_port=7860, inbrowser=False)
24
 
25
 
inference_tab/inference_logic.py CHANGED
@@ -46,15 +46,13 @@ def run_inference(image_path, gcp_path, city_name, score_th):
46
  yield msg, None
47
 
48
  # === POST OCR ===
49
- for msg in fuzzyMatch():
50
  if msg.endswith(".csv"):
51
- yield f"Finished! CSV saved at {msg}", msg
52
  else:
53
  yield msg, None
54
 
55
- return f"Street labels are ready for manual input.\nImage: {image_path}", None
56
-
57
-
58
 
59
  def getBBoxes(image_path, tile_size=256, overlap=0.3, confidence_threshold=0.25):
60
  yield f"DEBUG: Received image_path: {image_path}"
@@ -542,7 +540,7 @@ def extractSegments(image_path, min_size=500, margin=10):
542
 
543
 
544
  def blobsOCR(image_path):
545
-
546
  # Load model + processor
547
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-str")
548
  model = VisionEncoderDecoderModel.from_pretrained("muk42/trocr_streets")
@@ -558,7 +556,7 @@ def blobsOCR(image_path):
558
 
559
 
560
  # Open output file for writing
561
- OCR_PATH = os.path.join(OUTPUT_DIR,"ocr")
562
  with open(OCR_PATH, "w", encoding="utf-8") as f_out:
563
  # Process each image
564
  image_folder = os.path.join(OUTPUT_DIR,"blobs")
@@ -684,15 +682,21 @@ def georefImg(image_path, gcp_path):
684
 
685
 
686
 
687
- yield f"Done."
688
 
689
 
690
  def extractStreetNet(city_name):
691
  yield f"Extract OSM street network for {city_name}"
692
- G = ox.graph_from_place(city_name, network_type='drive')
693
  G_proj = ox.project_graph(G)
694
- nodes, edges = ox.graph_to_gdfs(G_proj)
695
  edges_3857 = edges.to_crs(epsg=3857)
 
 
 
 
 
 
696
  OSM_PATH=os.path.join(OUTPUT_DIR,"osm_extract.geojson")
697
  edges_3857.to_file(OSM_PATH, driver="GeoJSON")
698
  yield "Done."
@@ -709,11 +713,13 @@ def best_street_match(point, query_name, edges_gdf, max_distance=100):
709
  best_match = process.extractOne(query_name, candidate_names, scorer=fuzz.ratio)
710
  return best_match # (name, score, index)
711
 
712
- def fuzzyMatch():
713
  COORD_PATH=os.path.join(OUTPUT_DIR,"centroids.csv")
714
  OCR_PATH=os.path.join(OUTPUT_DIR,"ocr.csv")
715
  coords_df = pd.read_csv(COORD_PATH)
716
- names_df = pd.read_csv(OCR_PATH,sep="\t",columns=[['blob_id','pred_text']])
 
 
717
  merged_df = coords_df.merge(names_df, on="blob_id")
718
 
719
  gdf = gpd.GeoDataFrame(
@@ -723,13 +729,12 @@ def fuzzyMatch():
723
  )
724
 
725
  OSM_PATH=os.path.join(OUTPUT_DIR,"osm_extract.geojson")
726
- osm_gdf = gpd.read_file(OSM_PATH)
727
- osm_gdf = osm_gdf[osm_gdf['name'].notnull()]
728
 
729
  yield "Process OSM candidates..."
730
  results = []
731
  for _, row in gdf.iterrows():
732
- match = best_street_match(row.geometry, row['name'], osm_gdf, max_distance=100)
733
  if match:
734
  results.append({
735
  "blob_id": row.blob_id,
@@ -752,4 +757,13 @@ def fuzzyMatch():
752
  results_df = pd.DataFrame(results)
753
  RES_PATH=os.path.join(OUTPUT_DIR,"street_matches.csv")
754
  results_df.to_csv(RES_PATH, index=False)
755
- yield f"{RES_PATH}/street_matches.csv"
 
 
 
 
 
 
 
 
 
 
46
  yield msg, None
47
 
48
  # === POST OCR ===
49
+ for msg in fuzzyMatch(score_th):
50
  if msg.endswith(".csv"):
51
+ yield f"Finished! CSV saved at {msg}. Street labels are ready for manual input.", msg
52
  else:
53
  yield msg, None
54
 
55
+
 
 
56
 
57
  def getBBoxes(image_path, tile_size=256, overlap=0.3, confidence_threshold=0.25):
58
  yield f"DEBUG: Received image_path: {image_path}"
 
540
 
541
 
542
  def blobsOCR(image_path):
543
+ yield "Load OCR model.."
544
  # Load model + processor
545
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-str")
546
  model = VisionEncoderDecoderModel.from_pretrained("muk42/trocr_streets")
 
556
 
557
 
558
  # Open output file for writing
559
+ OCR_PATH = os.path.join(OUTPUT_DIR,"ocr.csv")
560
  with open(OCR_PATH, "w", encoding="utf-8") as f_out:
561
  # Process each image
562
  image_folder = os.path.join(OUTPUT_DIR,"blobs")
 
682
 
683
 
684
 
685
+ yield "Done."
686
 
687
 
688
  def extractStreetNet(city_name):
689
  yield f"Extract OSM street network for {city_name}"
690
+ G = ox.graph_from_place(city_name, network_type='all')
691
  G_proj = ox.project_graph(G)
692
+ edges = ox.graph_to_gdfs(G_proj, nodes=False, edges=True, fill_edge_geometry=True)
693
  edges_3857 = edges.to_crs(epsg=3857)
694
+ edges_3857 = edges_3857[['osmid','name', 'geometry']]
695
+ edges_3857 = edges_3857[edges_3857['name'].notnull()]
696
+
697
+ edges_3857['name'] = edges_3857['name'].apply(
698
+ lambda x: x[0] if isinstance(x, list) and len(x) > 0 else x)
699
+
700
  OSM_PATH=os.path.join(OUTPUT_DIR,"osm_extract.geojson")
701
  edges_3857.to_file(OSM_PATH, driver="GeoJSON")
702
  yield "Done."
 
713
  best_match = process.extractOne(query_name, candidate_names, scorer=fuzz.ratio)
714
  return best_match # (name, score, index)
715
 
716
+ def fuzzyMatch(score_th):
717
  COORD_PATH=os.path.join(OUTPUT_DIR,"centroids.csv")
718
  OCR_PATH=os.path.join(OUTPUT_DIR,"ocr.csv")
719
  coords_df = pd.read_csv(COORD_PATH)
720
+ names_df = pd.read_csv(OCR_PATH,
721
+ names=['blob_id','pred_text'],
722
+ dtype={"blob_id": "int64", "pred_text": "string"})
723
  merged_df = coords_df.merge(names_df, on="blob_id")
724
 
725
  gdf = gpd.GeoDataFrame(
 
729
  )
730
 
731
  OSM_PATH=os.path.join(OUTPUT_DIR,"osm_extract.geojson")
732
+ osm_gdf = gpd.read_file(OSM_PATH,dtype={"name": "str"})
 
733
 
734
  yield "Process OSM candidates..."
735
  results = []
736
  for _, row in gdf.iterrows():
737
+ match = best_street_match(row.geometry, row['pred_text'], osm_gdf, max_distance=100)
738
  if match:
739
  results.append({
740
  "blob_id": row.blob_id,
 
757
  results_df = pd.DataFrame(results)
758
  RES_PATH=os.path.join(OUTPUT_DIR,"street_matches.csv")
759
  results_df.to_csv(RES_PATH, index=False)
760
+
761
+ # remove street labels from blobs folder that are more than or equal to score threshold
762
+ manual_df = results_df[results_df['osm_match_score'] >= int(score_th)]
763
+ for blob_id in manual_df['blob_id']:
764
+ file_path = os.path.join(OUTPUT_DIR,"blobs",f"{blob_id}.png")
765
+
766
+ if os.path.exists(file_path):
767
+ os.remove(file_path)
768
+
769
+ yield f"{RES_PATH}"
requirements.txt CHANGED
@@ -13,4 +13,5 @@ Shapely==2.1.1
13
  torch==2.7.1
14
  transformers==4.53.2
15
  ultralytics==8.3.94
 
16
  GDAL==3.6.2
 
13
  torch==2.7.1
14
  transformers==4.53.2
15
  ultralytics==8.3.94
16
+ huggingface_hub[hf_xet]
17
  GDAL==3.6.2