AnjaliSarawgi commited on
Commit
47e7fc8
·
1 Parent(s): e940076
Files changed (1) hide show
  1. app.py +216 -136
app.py CHANGED
@@ -1,11 +1,13 @@
1
  """
2
  Gradio application for performing OCR on scanned Old Nepali documents.
 
3
  This script is a Gradio port of a Streamlit application originally built
4
  to visualize and edit OCR output. It loads a pre‑trained model for
5
  sequence decoding, accepts an input image (and optional segmentation
6
  XML in ALTO format), performs OCR on segmented lines, highlights tokens
7
  with low confidence and offers downloads of both the raw text and per
8
  token scores.
 
9
  The heavy lifting functions (model loading, pre‑processing, inference
10
  and highlighting) are adapted directly from the Streamlit version. The
11
  UI has been simplified for Gradio: users upload an image and optional
@@ -13,9 +15,12 @@ XML file, choose preprocessing steps and a highlight metric, then run
13
  OCR. The results are displayed alongside the overlaid segmentation
14
  boxes and a table of token scores. An editable textbox lets users
15
  modify the predicted text before downloading it.
 
16
  To run this app locally, install gradio (`pip install gradio`) and
17
  execute this script with Python:
 
18
  python gradio_app.py
 
19
  """
20
 
21
  import io
@@ -88,6 +93,7 @@ FONT_PATH: str = os.path.join("NotoSansDevanagari-Regular.ttf")
88
  @lru_cache(maxsize=1)
89
  def load_model():
90
  """Load the OCR model, tokenizer and feature extractor.
 
91
  Returns
92
  -------
93
  model : VisionEncoderDecoderModel
@@ -116,6 +122,20 @@ def load_model():
116
  #
117
 
118
  def clean_text(text: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  text = unicodedata.normalize("NFC", text)
120
  text = CLEANUP.sub("", text)
121
  return re.sub(r"\s+", "", text)
@@ -123,12 +143,14 @@ def clean_text(text: str) -> str:
123
 
124
  def prepare_image(image: Image.Image, max_side: int = RESIZE_MAX_SIDE) -> Image.Image:
125
  """Resize the image so that its longest side equals max_side.
 
126
  Parameters
127
  ----------
128
  image : PIL.Image
129
  Input image.
130
  max_side : int, optional
131
  Maximum allowed size for the longest side of the image.
 
132
  Returns
133
  -------
134
  PIL.Image
@@ -151,6 +173,7 @@ def get_amp_ctx():
151
  #
152
  def parse_boxes_from_xml(xml_bytes: bytes, level: str = "line", image_size: tuple | None = None):
153
  """Parse ALTO or PAGE XML to extract bounding boxes.
 
154
  Parameters
155
  ----------
156
  xml_bytes : bytes
@@ -161,6 +184,7 @@ def parse_boxes_from_xml(xml_bytes: bytes, level: str = "line", image_size: tupl
161
  If provided, image_size=(width, height) allows rescaling
162
  coordinates to match the actual image. ALTO files often store
163
  absolute page sizes that differ from the image dimensions.
 
164
  Returns
165
  -------
166
  list of dict
@@ -308,12 +332,14 @@ def sort_boxes_reading_order(boxes, y_tol: int = 10):
308
 
309
  def draw_boxes(img: Image.Image, boxes):
310
  """Overlay semi‑transparent red polygons or rectangles on an image.
 
311
  Parameters
312
  ----------
313
  img : PIL.Image
314
  The base image.
315
  boxes : list of dict
316
  Segmentation boxes with either 'points' or 'bbox' keys.
 
317
  Returns
318
  -------
319
  PIL.Image
@@ -343,19 +369,12 @@ def draw_boxes(img: Image.Image, boxes):
343
  # ----------------------------------------------------------------------
344
  # OCR inference per line
345
  #
346
- # def predict_and_score_once(image: Image.Image, line_id: int = 1, topk: int = TOPK):
347
- def predict_and_score_once(
348
- image: Image.Image,
349
- model,
350
- tokenizer,
351
- feature_extractor,
352
- device,
353
- line_id: int = 1,
354
- topk: int = TOPK,
355
- ):
356
  """Run the model on a single cropped line and return predictions and scores.
 
357
  This helper wraps the model.generate call to obtain per‑token
358
  probabilities and derives a DataFrame summarizing each decoding step.
 
359
  Parameters
360
  ----------
361
  image : PIL.Image
@@ -364,6 +383,7 @@ def predict_and_score_once(
364
  Identifier used in the output DataFrame.
365
  topk : int, optional
366
  Number of alternative tokens to keep for each decoding position.
 
367
  Returns
368
  -------
369
  decoded_text : str
@@ -373,7 +393,7 @@ def predict_and_score_once(
373
  columns: line_id, seq_pos, token_id, token, confidence,
374
  rel_prob, entropy, gap12, alt_tokens, alt_probs.
375
  """
376
- # model, tokenizer, feature_extractor, device = load_model()
377
  img = prepare_image(image)
378
  pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values.to(device)
379
  amp_ctx = get_amp_ctx()
@@ -381,7 +401,7 @@ def predict_and_score_once(
381
  try:
382
  out = model.generate(
383
  pixel_values,
384
- max_length=100,
385
  num_beams=5,
386
  do_sample=False,
387
  return_dict_in_generate=True,
@@ -394,7 +414,7 @@ def predict_and_score_once(
394
  if "out of memory" in str(e).lower():
395
  out = model.generate(
396
  pixel_values,
397
- max_length=100,
398
  num_beams=1,
399
  do_sample=False,
400
  return_dict_in_generate=True,
@@ -510,15 +530,16 @@ def parse_alt_tokens(s: str):
510
  return [(t if t is not None else "") for t in (s or "").split("|")]
511
 
512
 
513
-
514
  def highlight_tokens_with_tooltips(
515
  line_text: str, df_tok: pd.DataFrame, red_threshold: float, metric_column: str
516
  ) -> str:
517
  """Insert HTML spans around tokens whose chosen metric exceeds threshold.
 
518
  The metric column can be "rel_prob" (relative probability) or
519
  "entropy". Tokens with a value strictly greater than red_threshold
520
  will be wrapped in a span with a tooltip listing alternative
521
  predictions and their probabilities.
 
522
  Parameters
523
  ----------
524
  line_text : str
@@ -529,6 +550,7 @@ def highlight_tokens_with_tooltips(
529
  Values above this threshold will be highlighted.
530
  metric_column : str
531
  Column name in df_tok used for thresholding.
 
532
  Returns
533
  -------
534
  str
@@ -610,174 +632,232 @@ def _html_escape(s: str) -> str:
610
  # ----------------------------------------------------------------------
611
  # Main OCR wrapper for Gradio
612
  #
613
- import tempfile
 
 
 
 
 
 
 
614
 
615
- def run_ocr(image, xml_file, apply_gray, apply_bin, highlight_metric):
616
- if image is None:
617
- return None, "", None, "", None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
618
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619
  pil_img = Image.fromarray(image).convert("RGB")
 
620
  if apply_gray:
621
  pil_img = pil_img.convert("L").convert("RGB")
622
  if apply_bin:
623
  img_cv = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
624
  _, bin_img = cv2.threshold(img_cv, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
625
  pil_img = Image.fromarray(bin_img).convert("RGB")
626
-
627
  boxes = []
628
- if xml_file:
629
- if isinstance(xml_file, (bytes, bytearray)):
630
- xml_bytes = bytes(xml_file)
631
- elif isinstance(xml_file, str):
632
- with open(xml_file, "rb") as f:
633
- xml_bytes = f.read()
634
- elif hasattr(xml_file, "read"):
635
- xml_bytes = xml_file.read()
636
- elif isinstance(xml_file, dict) and "data" in xml_file:
637
- xml_bytes = xml_file.get("data")
638
- else:
639
- xml_bytes = None
640
-
641
- if xml_bytes:
642
  boxes = parse_boxes_from_xml(xml_bytes, level="line", image_size=pil_img.size)
643
  boxes = sort_boxes_reading_order(boxes)[:MAX_LINES]
644
-
 
 
645
  dfs = []
646
- parts = []
647
- plain_lines = []
648
- model, tokenizer, feature_extractor, device = load_model()
649
-
650
-
651
  if boxes:
 
652
  for idx, b in enumerate(boxes, 1):
653
- x1, y1, x2, y2 = b["bbox"]
654
- crop = pil_img.crop((x1, y1, x2, y2))
655
- # seg_text, df_tok = predict_and_score_once(crop, line_id=idx, topk=TOPK)
656
- seg_text, df_tok = predict_and_score_once(
657
- crop, model, tokenizer, feature_extractor, device, line_id=idx, topk=TOPK
658
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
659
  seg_text = clean_text(seg_text)
660
- plain_lines.append(seg_text)
661
  if highlight_metric == "Relative Probability":
662
- seg_html = highlight_tokens_with_tooltips(seg_text, df_tok, REL_PROB_TH, "rel_prob")
 
663
  else:
664
- seg_html = highlight_tokens_with_tooltips(seg_text, df_tok, 0.10, "entropy")
665
- parts.append(seg_html)
 
 
 
 
666
  dfs.append(df_tok)
667
- predicted_html = "<br>".join(parts)
 
668
  df_all = pd.concat(dfs, ignore_index=True)
669
  else:
670
- # seg_text, df_all = predict_and_score_once(pil_img, line_id=1, topk=TOPK)
671
- seg_text, df_all = predict_and_score_once(
672
- pil_img, model, tokenizer, feature_extractor, device, line_id=1, topk=TOPK
673
- )
674
  seg_text = clean_text(seg_text)
675
  if highlight_metric == "Relative Probability":
676
- seg_html = highlight_tokens_with_tooltips(seg_text, df_all, REL_PROB_TH, "rel_prob")
 
677
  else:
678
- seg_html = highlight_tokens_with_tooltips(seg_text, df_all, 0.10, "entropy")
679
- predicted_html = seg_html
680
-
 
 
 
681
  overlay_img = draw_boxes(pil_img, boxes) if boxes else pil_img
682
-
683
- # Clean text for editing (strip HTML)
684
- # clean_pred_text = re.sub(r"<[^>]+>", "", predicted_html)
685
-
686
- # clean_pred_text = re.sub(r"<[^>]+>", "", predicted_html)
687
- # clean_pred_text = clean_pred_text.replace("<br>", "\n").strip()
688
- clean_pred_text = "\n".join(plain_lines)
689
-
690
-
691
- # Save outputs to temporary files
692
- tmp_dir = tempfile.mkdtemp()
693
- txt_path = os.path.join(tmp_dir, "ocr_prediction.txt")
694
- # csv_path = os.path.join(tmp_dir, "token_scores.csv")
695
-
696
- with open(txt_path, "w", encoding="utf-8") as f:
697
- f.write(clean_pred_text)
698
-
699
- # if df_all is not None and not df_all.empty:
700
- # df_all.to_csv(csv_path, index=False, encoding="utf-8")
701
-
702
-
703
- # return overlay_img, predicted_html, df_all, clean_pred_text, txt_path, csv_path
704
- return overlay_img, predicted_html, clean_pred_text, txt_path
 
 
 
 
 
 
 
 
705
 
706
 
707
  # ----------------------------------------------------------------------
708
  # Build Gradio Interface
709
  #
710
  def create_gradio_interface():
 
711
  with gr.Blocks(title="Old Nepali HTR") as demo:
712
- gr.Markdown("""
713
- # Old Nepali HTR (Gradio)
714
- Upload a scanned image and (optionally) a segmentation XML file.
715
- Choose preprocessing steps and a highlight metric, then click **Run OCR**.
716
- """)
717
-
718
  with gr.Row():
719
  image_input = gr.Image(type="numpy", label="Upload Image")
720
- xml_input = gr.File(label="Upload XML (optional)", type="binary")
721
-
722
  with gr.Row():
723
  apply_gray_checkbox = gr.Checkbox(label="Convert to Grayscale", value=False)
724
  apply_bin_checkbox = gr.Checkbox(label="Binarize", value=False)
725
-
726
-
 
 
727
  run_btn = gr.Button("Run OCR")
 
 
 
 
 
 
 
 
 
 
 
728
 
729
- with gr.Row():
730
- overlay_output = gr.Image(label="Detected Regions")
731
- predictions_output = gr.HTML(
732
- label="Predictions",
733
- container=True,
734
- elem_classes=["predictions-box"]
735
- )
736
-
737
- # Add subtle border to Predictions
738
- gr.HTML("""
739
- <style>
740
- .predictions-box {
741
- border: 1px solid #d0d0d0;
742
- border-radius: 8px;
743
- padding: 12px;
744
- background-color: #fafafa;
745
- min-height: 200px;
746
- overflow-y: auto;
747
- }
748
- </style>
749
- """)
750
-
751
- editable_text = gr.Textbox(label="Edit Recognized Text", lines=8, interactive=True)
752
-
753
- with gr.Row():
754
- download_text = gr.File(label="Download Raw Text (.txt)")
755
- download_edited = gr.File(label="Download Edited Text (.txt)")
756
 
757
- # Run OCR and populate results
758
  run_btn.click(
759
- fn=run_ocr,
760
- inputs=[image_input, xml_input, apply_gray_checkbox, apply_bin_checkbox],
761
- outputs=[
762
- overlay_output,
763
- predictions_output,
764
- editable_text,
765
- download_text,
766
- ],
767
  )
768
 
769
- # Save edited text dynamically
770
- def save_edited_text(text):
771
- import tempfile, os
772
- tmp_dir = tempfile.mkdtemp()
773
- path = os.path.join(tmp_dir, "edited_ocr_text.txt")
774
- with open(path, "w", encoding="utf-8") as f:
775
- f.write(text)
776
- return path
777
 
778
- editable_text.change(fn=save_edited_text, inputs=editable_text, outputs=download_edited)
 
 
 
 
779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
780
 
 
 
 
 
 
781
  return demo
782
 
783
 
 
1
  """
2
  Gradio application for performing OCR on scanned Old Nepali documents.
3
+
4
  This script is a Gradio port of a Streamlit application originally built
5
  to visualize and edit OCR output. It loads a pre‑trained model for
6
  sequence decoding, accepts an input image (and optional segmentation
7
  XML in ALTO format), performs OCR on segmented lines, highlights tokens
8
  with low confidence and offers downloads of both the raw text and per
9
  token scores.
10
+
11
  The heavy lifting functions (model loading, pre‑processing, inference
12
  and highlighting) are adapted directly from the Streamlit version. The
13
  UI has been simplified for Gradio: users upload an image and optional
 
15
  OCR. The results are displayed alongside the overlaid segmentation
16
  boxes and a table of token scores. An editable textbox lets users
17
  modify the predicted text before downloading it.
18
+
19
  To run this app locally, install gradio (`pip install gradio`) and
20
  execute this script with Python:
21
+
22
  python gradio_app.py
23
+
24
  """
25
 
26
  import io
 
93
  @lru_cache(maxsize=1)
94
  def load_model():
95
  """Load the OCR model, tokenizer and feature extractor.
96
+
97
  Returns
98
  -------
99
  model : VisionEncoderDecoderModel
 
122
  #
123
 
124
  def clean_text(text: str) -> str:
125
+ """Normalize and collapse whitespace from a decoded string.
126
+
127
+ Parameters
128
+ ----------
129
+ text : str
130
+ The raw decoded string from the model.
131
+
132
+ Returns
133
+ -------
134
+ str
135
+ The cleaned string with Unicode normalization and whitespace
136
+ removed. All whitespace characters are stripped since the
137
+ predictions are later tokenized at the akshara (syllable) level.
138
+ """
139
  text = unicodedata.normalize("NFC", text)
140
  text = CLEANUP.sub("", text)
141
  return re.sub(r"\s+", "", text)
 
143
 
144
  def prepare_image(image: Image.Image, max_side: int = RESIZE_MAX_SIDE) -> Image.Image:
145
  """Resize the image so that its longest side equals max_side.
146
+
147
  Parameters
148
  ----------
149
  image : PIL.Image
150
  Input image.
151
  max_side : int, optional
152
  Maximum allowed size for the longest side of the image.
153
+
154
  Returns
155
  -------
156
  PIL.Image
 
173
  #
174
  def parse_boxes_from_xml(xml_bytes: bytes, level: str = "line", image_size: tuple | None = None):
175
  """Parse ALTO or PAGE XML to extract bounding boxes.
176
+
177
  Parameters
178
  ----------
179
  xml_bytes : bytes
 
184
  If provided, image_size=(width, height) allows rescaling
185
  coordinates to match the actual image. ALTO files often store
186
  absolute page sizes that differ from the image dimensions.
187
+
188
  Returns
189
  -------
190
  list of dict
 
332
 
333
  def draw_boxes(img: Image.Image, boxes):
334
  """Overlay semi‑transparent red polygons or rectangles on an image.
335
+
336
  Parameters
337
  ----------
338
  img : PIL.Image
339
  The base image.
340
  boxes : list of dict
341
  Segmentation boxes with either 'points' or 'bbox' keys.
342
+
343
  Returns
344
  -------
345
  PIL.Image
 
369
  # ----------------------------------------------------------------------
370
  # OCR inference per line
371
  #
372
+ def predict_and_score_once(image: Image.Image, line_id: int = 1, topk: int = TOPK):
 
 
 
 
 
 
 
 
 
373
  """Run the model on a single cropped line and return predictions and scores.
374
+
375
  This helper wraps the model.generate call to obtain per‑token
376
  probabilities and derives a DataFrame summarizing each decoding step.
377
+
378
  Parameters
379
  ----------
380
  image : PIL.Image
 
383
  Identifier used in the output DataFrame.
384
  topk : int, optional
385
  Number of alternative tokens to keep for each decoding position.
386
+
387
  Returns
388
  -------
389
  decoded_text : str
 
393
  columns: line_id, seq_pos, token_id, token, confidence,
394
  rel_prob, entropy, gap12, alt_tokens, alt_probs.
395
  """
396
+ model, tokenizer, feature_extractor, device = load_model()
397
  img = prepare_image(image)
398
  pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values.to(device)
399
  amp_ctx = get_amp_ctx()
 
401
  try:
402
  out = model.generate(
403
  pixel_values,
404
+ max_length=MAX_LEN,
405
  num_beams=5,
406
  do_sample=False,
407
  return_dict_in_generate=True,
 
414
  if "out of memory" in str(e).lower():
415
  out = model.generate(
416
  pixel_values,
417
+ max_length=MAX_LEN,
418
  num_beams=1,
419
  do_sample=False,
420
  return_dict_in_generate=True,
 
530
  return [(t if t is not None else "") for t in (s or "").split("|")]
531
 
532
 
 
533
  def highlight_tokens_with_tooltips(
534
  line_text: str, df_tok: pd.DataFrame, red_threshold: float, metric_column: str
535
  ) -> str:
536
  """Insert HTML spans around tokens whose chosen metric exceeds threshold.
537
+
538
  The metric column can be "rel_prob" (relative probability) or
539
  "entropy". Tokens with a value strictly greater than red_threshold
540
  will be wrapped in a span with a tooltip listing alternative
541
  predictions and their probabilities.
542
+
543
  Parameters
544
  ----------
545
  line_text : str
 
550
  Values above this threshold will be highlighted.
551
  metric_column : str
552
  Column name in df_tok used for thresholding.
553
+
554
  Returns
555
  -------
556
  str
 
632
  # ----------------------------------------------------------------------
633
  # Main OCR wrapper for Gradio
634
  #
635
+ def run_ocr(
636
+ image: np.ndarray | None,
637
+ xml_file: tuple | None,
638
+ apply_gray: bool,
639
+ apply_bin: bool,
640
+ highlight_metric: str,
641
+ ):
642
+ """Run the OCR pipeline on user inputs and return results for Gradio.
643
 
644
+ Parameters
645
+ ----------
646
+ image : numpy.ndarray or None
647
+ The uploaded image converted to a NumPy array by Gradio. If
648
+ None, the function returns empty results.
649
+ xml_file : tuple or None
650
+ A tuple representing the uploaded XML file as provided by
651
+ gr.File. The first element is the file name and the second is
652
+ bytes. If None, no segmentation is applied and the entire
653
+ image is processed as a single line.
654
+ apply_gray : bool
655
+ Whether to convert the image to grayscale before OCR.
656
+ apply_bin : bool
657
+ Whether to apply binarization (Otsu threshold) before OCR. If
658
+ selected, grayscale conversion is applied first automatically.
659
+ highlight_metric : str
660
+ Which metric to use for highlighting ("Relative Probability" or
661
+ "Entropy").
662
 
663
+ Returns
664
+ -------
665
+ overlay_img : PIL.Image or None
666
+ Image with segmentation boxes drawn. None if no input image.
667
+ predictions_html : str
668
+ HTML formatted predicted text with highlighted tokens.
669
+ df_scores : pandas.DataFrame or None
670
+ DataFrame of per‑token statistics. None if no input image.
671
+ txt_file_path : str or None
672
+ Path to a temporary .txt file containing the plain predicted text.
673
+ csv_file_path : str or None
674
+ Path to a temporary CSV file containing the extended token scores.
675
+ """
676
+ if image is None:
677
+ return None, "", None, None, None
678
+ # Convert the numpy array to a PIL image
679
  pil_img = Image.fromarray(image).convert("RGB")
680
+ # Apply preprocessing as requested
681
  if apply_gray:
682
  pil_img = pil_img.convert("L").convert("RGB")
683
  if apply_bin:
684
  img_cv = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
685
  _, bin_img = cv2.threshold(img_cv, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
686
  pil_img = Image.fromarray(bin_img).convert("RGB")
687
+ # Parse segmentation boxes if XML provided
688
  boxes = []
689
+ if xml_file is not None and isinstance(xml_file, tuple) and len(xml_file) == 2:
690
+ # xml_file comes as (name, bytes) from Gradio
691
+ _, xml_bytes = xml_file
692
+ try:
 
 
 
 
 
 
 
 
 
 
693
  boxes = parse_boxes_from_xml(xml_bytes, level="line", image_size=pil_img.size)
694
  boxes = sort_boxes_reading_order(boxes)[:MAX_LINES]
695
+ except Exception:
696
+ boxes = []
697
+ # Run OCR for each segmented line or the whole image
698
  dfs = []
699
+ concatenated_parts = []
700
+ line_text_by_id = {}
 
 
 
701
  if boxes:
702
+ pad = 2
703
  for idx, b in enumerate(boxes, 1):
704
+ # Create a tight crop around the line
705
+ if "points" in b:
706
+ pts = b["points"]
707
+ mask = Image.new("L", pil_img.size, 0)
708
+ ImageDraw.Draw(mask).polygon(pts, outline=1, fill=255)
709
+ seg_img = Image.new("RGB", pil_img.size, (255, 255, 255))
710
+ seg_img.paste(pil_img, mask=mask)
711
+ xs = [x for x, y in pts]
712
+ ys = [y for x, y in pts]
713
+ x1 = max(0, int(min(xs) - pad))
714
+ y1 = max(0, int(min(ys) - pad))
715
+ x2 = min(pil_img.width, int(max(xs) + pad))
716
+ y2 = min(pil_img.height, int(max(ys) + pad))
717
+ crop = seg_img.crop((x1, y1, x2, y2))
718
+ else:
719
+ x1, y1, x2, y2 = b["bbox"]
720
+ x1p = max(0, x1 - pad)
721
+ y1p = max(0, y1 - pad)
722
+ x2p = min(pil_img.width, x2 + pad)
723
+ y2p = min(pil_img.height, y2 + pad)
724
+ crop = pil_img.crop((x1p, y1p, x2p, y2p))
725
+ # Run inference on the crop
726
+ seg_text, df_tok = predict_and_score_once(crop, line_id=idx, topk=TOPK)
727
  seg_text = clean_text(seg_text)
728
+ # Choose metric
729
  if highlight_metric == "Relative Probability":
730
+ red_threshold = REL_PROB_TH
731
+ metric_col = "rel_prob"
732
  else:
733
+ red_threshold = 0.10 # heuristic threshold for entropy
734
+ metric_col = "entropy"
735
+ # Highlight uncertain tokens
736
+ seg_text_flagged = highlight_tokens_with_tooltips(seg_text, df_tok, red_threshold, metric_col)
737
+ concatenated_parts.append(seg_text_flagged)
738
+ df_tok["line_id"] = idx
739
  dfs.append(df_tok)
740
+ line_text_by_id[idx] = seg_text_flagged
741
+ predicted_html = "<br>".join(concatenated_parts).strip()
742
  df_all = pd.concat(dfs, ignore_index=True)
743
  else:
744
+ # Single pass on the whole image
745
+ seg_text, df_all = predict_and_score_once(pil_img, line_id=1, topk=TOPK)
 
 
746
  seg_text = clean_text(seg_text)
747
  if highlight_metric == "Relative Probability":
748
+ red_threshold = REL_PROB_TH
749
+ metric_col = "rel_prob"
750
  else:
751
+ red_threshold = 0.10
752
+ metric_col = "entropy"
753
+ seg_text_flagged = highlight_tokens_with_tooltips(seg_text, df_all, red_threshold, metric_col)
754
+ predicted_html = seg_text_flagged
755
+ line_text_by_id[1] = seg_text_flagged
756
+ # Draw overlay image
757
  overlay_img = draw_boxes(pil_img, boxes) if boxes else pil_img
758
+ # Create downloads
759
+ df_all = df_all.copy()
760
+ # Drop the last empty token per line to tidy up output
761
+ df_all.sort_values(["line_id", "seq_pos"], inplace=True)
762
+ to_drop = []
763
+ for line_id, group in df_all.groupby("line_id"):
764
+ if group.iloc[-1]["token"].strip() == "":
765
+ to_drop.append(group.index[-1])
766
+ df_all = df_all.drop(index=to_drop)
767
+ # Prepare plain text by stripping HTML tags and replacing <br>
768
+ plain_text = re.sub(r"<[^>]*>", "", predicted_html.replace("<br>", "\n"))
769
+ # Write temporary files
770
+ txt_path = None
771
+ csv_path = None
772
+ try:
773
+ txt_fd = io.NamedTemporaryFile(delete=False, suffix=".txt", mode="w", encoding="utf-8")
774
+ txt_fd.write(plain_text)
775
+ txt_fd.flush()
776
+ txt_path = txt_fd.name
777
+ txt_fd.close()
778
+ except Exception:
779
+ txt_path = None
780
+ try:
781
+ csv_fd = io.NamedTemporaryFile(delete=False, suffix=".csv", mode="w", encoding="utf-8")
782
+ df_all.to_csv(csv_fd, index=False)
783
+ csv_fd.flush()
784
+ csv_path = csv_fd.name
785
+ csv_fd.close()
786
+ except Exception:
787
+ csv_path = None
788
+ return overlay_img, predicted_html, df_all, txt_path, csv_path
789
 
790
 
791
  # ----------------------------------------------------------------------
792
  # Build Gradio Interface
793
  #
794
  def create_gradio_interface():
795
+ """Create and return the Gradio Blocks interface."""
796
  with gr.Blocks(title="Old Nepali HTR") as demo:
797
+ gr.Markdown("""# Old Nepali HTR (Gradio)\n\nUpload a scanned image and (optionally) a segmentation XML file. Choose preprocessing\nsteps and a highlight metric, then click **Run OCR** to extract the text.\nUncertain tokens are highlighted with tooltips showing alternative predictions.\nYou can edit the plain text below and download it or the full token scores.""")
 
 
 
 
 
798
  with gr.Row():
799
  image_input = gr.Image(type="numpy", label="Upload Image")
800
+ xml_input = gr.File(label="Upload segmentation XML (optional)")
 
801
  with gr.Row():
802
  apply_gray_checkbox = gr.Checkbox(label="Convert to Grayscale", value=False)
803
  apply_bin_checkbox = gr.Checkbox(label="Binarize", value=False)
804
+ metric_radio = gr.Radio([
805
+ "Relative Probability",
806
+ "Entropy",
807
+ ], label="Highlight tokens by", value="Relative Probability")
808
  run_btn = gr.Button("Run OCR")
809
+ # Outputs
810
+ overlay_output = gr.Image(label="Detected Regions")
811
+ predictions_output = gr.HTML(label="Predictions (HTML)")
812
+ df_output = gr.DataFrame(label="Token Scores", interactive=False)
813
+ txt_file_output = gr.File(label="Download OCR Prediction (.txt)")
814
+ csv_file_output = gr.File(label="Download Token Scores (.csv)")
815
+ # Editable text
816
+ edited_text = gr.Textbox(
817
+ label="Edit full predicted text", lines=8, interactive=True
818
+ )
819
+ download_edited_btn = gr.Button("Download edited text")
820
 
821
+ # Callback for OCR
822
+ def on_run(image, xml, gray, binarize, metric):
823
+ return run_ocr(image, xml, gray, binarize, metric)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
824
 
 
825
  run_btn.click(
826
+ fn=on_run,
827
+ inputs=[image_input, xml_input, apply_gray_checkbox, apply_bin_checkbox, metric_radio],
828
+ outputs=[overlay_output, predictions_output, df_output, txt_file_output, csv_file_output],
 
 
 
 
 
829
  )
830
 
831
+ # Populate editable text with plain text from predictions
832
+ def update_edited_text(pred_html):
833
+ plain = re.sub(r"<[^>]*>", "", (pred_html or "").replace("<br>", "\n"))
834
+ return plain
 
 
 
 
835
 
836
+ predictions_output.change(
837
+ fn=update_edited_text,
838
+ inputs=predictions_output,
839
+ outputs=edited_text,
840
+ )
841
 
842
+ # Download edited text by writing to a temporary file
843
+ def download_edited(txt):
844
+ if not txt:
845
+ return None
846
+ try:
847
+ fd = io.NamedTemporaryFile(delete=False, suffix=".txt", mode="w", encoding="utf-8")
848
+ fd.write(txt)
849
+ fd.flush()
850
+ path = fd.name
851
+ fd.close()
852
+ return path
853
+ except Exception:
854
+ return None
855
 
856
+ download_edited_btn.click(
857
+ fn=download_edited,
858
+ inputs=edited_text,
859
+ outputs=txt_file_output,
860
+ )
861
  return demo
862
 
863