File size: 33,513 Bytes
6c8fcdd
 
47e7fc8
6c8fcdd
 
 
 
 
 
47e7fc8
6c8fcdd
 
 
 
 
 
 
47e7fc8
6c8fcdd
 
47e7fc8
6c8fcdd
47e7fc8
6c8fcdd
185acd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
864fa38
185acd0
6c8fcdd
 
 
 
 
 
 
 
 
185acd0
6c8fcdd
 
185acd0
6c8fcdd
 
 
 
185acd0
6c8fcdd
 
 
 
185acd0
6c8fcdd
 
 
 
185acd0
6c8fcdd
 
 
 
185acd0
6c8fcdd
 
185acd0
 
 
6c8fcdd
 
 
 
 
 
185acd0
 
6c8fcdd
47e7fc8
6c8fcdd
 
 
 
 
 
 
 
 
 
 
185acd0
6c8fcdd
 
185acd0
 
 
 
 
 
 
 
 
6c8fcdd
 
 
 
185acd0
47e7fc8
 
 
 
 
 
 
 
 
 
 
 
 
 
185acd0
 
 
 
 
 
6c8fcdd
47e7fc8
6c8fcdd
 
 
 
 
 
47e7fc8
6c8fcdd
 
 
 
 
185acd0
 
 
 
 
 
 
 
6c8fcdd
185acd0
 
6c8fcdd
 
 
185acd0
 
6c8fcdd
47e7fc8
6c8fcdd
 
 
 
 
 
 
 
 
 
47e7fc8
6c8fcdd
 
 
 
 
 
 
 
 
185acd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c8fcdd
185acd0
 
 
 
 
 
 
6c8fcdd
47e7fc8
6c8fcdd
 
 
 
 
 
47e7fc8
6c8fcdd
 
 
 
 
 
185acd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c8fcdd
 
 
47e7fc8
6c8fcdd
47e7fc8
6c8fcdd
 
47e7fc8
6c8fcdd
 
 
 
 
 
 
 
47e7fc8
6c8fcdd
 
 
 
 
 
 
 
 
47e7fc8
185acd0
 
 
 
 
 
 
e6ba862
 
185acd0
 
 
 
 
 
 
 
 
 
 
47e7fc8
185acd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c8fcdd
 
 
 
 
 
 
185acd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c8fcdd
47e7fc8
6c8fcdd
 
 
 
47e7fc8
6c8fcdd
 
 
 
 
 
 
 
 
 
47e7fc8
6c8fcdd
 
 
 
 
185acd0
 
e0d71f5
 
185acd0
e0d71f5
185acd0
 
 
e0d71f5
 
 
 
185acd0
 
 
e0d71f5
185acd0
 
 
 
 
 
 
 
 
 
 
 
e0d71f5
185acd0
 
 
e0d71f5
185acd0
 
e0d71f5
185acd0
 
e0d71f5
185acd0
 
 
 
 
 
 
 
e0d71f5
185acd0
 
 
 
e0d71f5
185acd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47e7fc8
 
 
 
 
 
 
 
c59cd27
47e7fc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0d71f5
47e7fc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185acd0
47e7fc8
185acd0
 
 
 
 
 
47e7fc8
3671a94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47e7fc8
185acd0
47e7fc8
 
185acd0
47e7fc8
185acd0
47e7fc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185acd0
47e7fc8
185acd0
47e7fc8
 
185acd0
47e7fc8
 
 
 
 
 
185acd0
47e7fc8
 
185acd0
 
47e7fc8
 
185acd0
 
47e7fc8
 
185acd0
47e7fc8
 
 
 
 
 
e0d71f5
47e7fc8
 
 
 
 
 
 
 
 
 
 
 
6f2b338
864fa38
 
 
 
 
 
 
 
 
185acd0
 
 
 
 
 
47e7fc8
864fa38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185acd0
 
3671a94
 
 
 
 
 
 
 
 
 
864fa38
 
 
 
 
 
 
185acd0
47e7fc8
864fa38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f2b338
 
47e7fc8
 
 
 
6f2b338
864fa38
47e7fc8
864fa38
 
 
8204e08
185acd0
47e7fc8
864fa38
 
 
8204e08
47e7fc8
 
 
 
185acd0
47e7fc8
 
 
 
 
8204e08
c59cd27
 
0862c1e
 
e0d71f5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
"""
Gradio application for performing OCR on scanned Old Nepali documents.

This script is a Gradio port of a Streamlit application originally built
to visualize and edit OCR output. It loads a pre‑trained model for
sequence decoding, accepts an input image (and optional segmentation
XML in ALTO format), performs OCR on segmented lines, highlights tokens
with low confidence and offers downloads of both the raw text and per
token scores.

The heavy lifting functions (model loading, pre‑processing, inference
and highlighting) are adapted directly from the Streamlit version. The
UI has been simplified for Gradio: users upload an image and optional
XML file, choose preprocessing steps and a highlight metric, then run
OCR.  The results are displayed alongside the overlaid segmentation
boxes and a table of token scores.  An editable textbox lets users
modify the predicted text before downloading it.

To run this app locally, install gradio (`pip install gradio`) and
execute this script with Python:

    python gradio_app.py

"""

import io
import os
import re
import base64
import unicodedata
import contextlib
import xml.etree.ElementTree as ET
from collections import defaultdict
from functools import lru_cache

import numpy as np
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
import cv2
import torch
from transformers import (
    VisionEncoderDecoderModel,
    PreTrainedTokenizerFast,
    TrOCRProcessor,
)
from matplotlib import cm
import gradio as gr
import tempfile

# ----------------------------------------------------------------------
# Configuration
#
# These constants control various aspects of the OCR pipeline. You can
# adjust them to trade off accuracy, performance or output volume.

# Upper bound on the decoded sequence length for a single line.  If your
# documents typically have longer lines you can increase this value, but
# beware that very long sequences may cause more memory usage.
MAX_LEN: int = 128

# How many candidate tokens are recorded per decoding step when computing
# per-token statistics (the chosen token plus its closest alternatives).
TOPK: int = 3

# If an XML segmentation file is provided, only process the first
# MAX_LINES lines.  This prevents huge documents from consuming
# excessive resources.
MAX_LINES: int = 120

# Images whose longest side exceeds this number of pixels are shrunk to
# fit before being passed to the OCR model; smaller images keep their
# original size.  Increasing this value may improve accuracy at the cost
# of speed and memory.
RESIZE_MAX_SIDE: int = 800

# Threshold used when highlighting tokens by relative probability.  A
# ratio of Top2/Top1 greater than this value will cause the token to
# be highlighted in red.
REL_PROB_TH: float = 0.70

# Regex matching invisible format characters that are removed while
# cleaning decoded text: soft hyphen (U+00AD), zero width space (U+200B),
# zero width non-joiner (U+200C) and zero width joiner (U+200D).  These
# marks interfere with accurate token matching.
CLEANUP: re.Pattern = re.compile(r"[\u00AD\u200B\u200C\u200D]")

# Default font file for rendering predictions directly on the image.  This
# is a bare file name with no directory component.
FONT_PATH: str = os.path.join("NotoSansDevanagari-Regular.ttf")


# ----------------------------------------------------------------------
# Model loading
#
# Loading the model and associated tokenizer/processor is slow.  Use
# functools.lru_cache to ensure this only happens once per process.

@lru_cache(maxsize=1)
def load_model():
    """Load the OCR model, tokenizer and image processor (cached per process).

    The model and tokenizer are fetched from the ``AnjaliSarawgi/model-oct``
    repository; the image pre-processing configuration comes from
    ``microsoft/trocr-large-handwritten``.  Because of ``lru_cache`` the
    work is done on the first call only and later calls return the same
    tuple.

    Returns
    -------
    model : VisionEncoderDecoderModel
        The loaded model, moved to ``device`` and put in evaluation mode.
    tokenizer : PreTrainedTokenizerFast
        Tokenizer corresponding to the decoder part of the model.
    feature_extractor : callable
        Image processor converting PIL images into model inputs.
    device : torch.device
        The device (CPU or CUDA) used for inference.
    """
    model_path = "AnjaliSarawgi/model-oct"
    # HF_TOKEN is optional: when the variable is unset the token is None and
    # the repository is accessed anonymously.  Set it to use a private model.
    hf_token = os.environ.get("HF_TOKEN")
    model = VisionEncoderDecoderModel.from_pretrained(model_path, token=hf_token)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path, token=hf_token)
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten", token=None)
    # Recent Transformers releases expose the image pre-processor as
    # ``image_processor`` and keep ``feature_extractor`` only as a deprecated
    # alias; prefer the new attribute and fall back for older releases.
    image_processor = getattr(processor, "image_processor", None)
    if image_processor is None:
        image_processor = processor.feature_extractor
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    return model, tokenizer, image_processor, device


# ----------------------------------------------------------------------
# Utility functions
#

def clean_text(text: str) -> str:
    """Normalize a decoded string and strip every whitespace character.

    Parameters
    ----------
    text : str
        The raw decoded string from the model.

    Returns
    -------
    str
        The string after three steps applied in order: Unicode NFC
        normalization, removal of the characters matched by ``CLEANUP``,
        and removal of all whitespace runs.  Whitespace is dropped
        entirely (not collapsed) because the predictions are later
        tokenized at the akshara (syllable) level.
    """
    composed = unicodedata.normalize("NFC", text)
    visible_only = CLEANUP.sub("", composed)
    without_spaces = re.sub(r"\s+", "", visible_only)
    return without_spaces


def prepare_image(image: Image.Image, max_side: int = RESIZE_MAX_SIDE) -> Image.Image:
    """Return an RGB copy of the image whose longest side is at most max_side.

    Parameters
    ----------
    image : PIL.Image
        Input image.
    max_side : int, optional
        Maximum allowed size for the longest side of the image.

    Returns
    -------
    PIL.Image
        The image converted to RGB.  It is shrunk (aspect ratio kept,
        LANCZOS resampling) only when its longest side exceeds
        ``max_side``; smaller images are returned at their original size.
    """
    rgb = image.convert("RGB")
    longest_side = max(rgb.size)
    if longest_side <= max_side:
        return rgb
    # thumbnail() resizes in place and keeps the aspect ratio.
    rgb.thumbnail((max_side, max_side), Image.LANCZOS)
    return rgb


def get_amp_ctx():
    """Return a zero-argument factory for the mixed-precision context manager.

    Returns
    -------
    callable
        Calling the returned object yields a context manager.  When CUDA is
        available it enables autocast on the ``"cuda"`` device type;
        otherwise it is ``contextlib.nullcontext`` and does nothing.
    """
    if not torch.cuda.is_available():
        return contextlib.nullcontext

    def _cuda_autocast():
        # torch.cuda.amp.autocast is deprecated in recent PyTorch releases;
        # torch.autocast with an explicit device type is the replacement.
        return torch.autocast(device_type="cuda")

    return _cuda_autocast


# ----------------------------------------------------------------------
# XML parsing and segmentation
#
def parse_boxes_from_xml(xml_bytes: bytes, level: str = "line", image_size: tuple | None = None):
    """Parse ALTO or PAGE XML to extract bounding boxes.

    Parameters
    ----------
    xml_bytes : bytes
        Raw XML bytes.
    level : {"block", "line", "word"}, optional
        The segmentation level to extract.  For OCR we use "line".
    image_size : tuple or None
        If provided, image_size=(width, height) allows rescaling
        coordinates to match the actual image.  ALTO files often store
        absolute page sizes that differ from the image dimensions.

    Returns
    -------
    list of dict
        Each dict represents a bounding box with keys:
        - "bbox": [x1, y1, x2, y2]
        - "points": list of (x, y) if polygonal coordinates exist
        - "id": line identifier (string)
        - "label": the type of element (e.g. TextLine)
    """
    def _strip_ns(elem):
        for e in elem.iter():
            if isinstance(e.tag, str) and e.tag.startswith("{"):
                e.tag = e.tag.split("}", 1)[1]

    root = ET.parse(io.BytesIO(xml_bytes)).getroot()
    _strip_ns(root)
    boxes = []

    # ALTO format handling
    if root.tag.lower() == "alto":
        tag_map = {"block": "TextBlock", "line": "TextLine", "word": "String"}
        tag = tag_map.get(level, "TextLine")
        page_el = root.find(".//Page")
        page_w = page_h = None
        if page_el is not None:
            try:
                page_w = float(page_el.get("WIDTH") or 0)
                page_h = float(page_el.get("HEIGHT") or 0)
            except Exception:
                page_w = page_h = None
        sx = sy = 1.0
        if image_size and page_w and page_h:
            img_w, img_h = image_size
            sx = (img_w / page_w) if page_w else 1.0
            sy = (img_h / page_h) if page_h else 1.0
        for el in root.findall(f".//{tag}"):
            poly = el.find(".//Shape/Polygon")
            got_box = False
            pts = None
            if poly is not None and poly.get("POINTS"):
                raw = poly.get("POINTS").strip()
                tokens = re.split(r"[ ,]+", raw)
                nums = []
                for t in tokens:
                    try:
                        nums.append(float(t))
                    except Exception:
                        pass
                pts = []
                if len(nums) >= 6 and len(nums) % 2 == 0:
                    for i in range(0, len(nums), 2):
                        pts.append((nums[i] * sx, nums[i + 1] * sy))
                if pts:
                    xs = [p[0] for p in pts]
                    ys = [p[1] for p in pts]
                    x1, x2 = int(min(xs)), int(max(xs))
                    y1, y2 = int(min(ys)), int(max(ys))
                    got_box = (x2 > x1 and y2 > y1)
            if not got_box:
                try:
                    hpos = float(el.get("HPOS", 0)) * sx
                    vpos = float(el.get("VPOS", 0)) * sy
                    width = float(el.get("WIDTH", 0)) * sx
                    height = float(el.get("HEIGHT", 0)) * sy
                    x1, y1 = int(hpos), int(vpos)
                    x2, y2 = int(hpos + width), int(vpos + height)
                except Exception:
                    continue
                if x2 <= x1 or y2 <= y1:
                    continue
            label = tag if tag != "String" else (el.get("CONTENT") or "String")
            boxes.append(
                {
                    "label": label,
                    "bbox": [x1, y1, x2, y2],
                    "source": "alto",
                    "id": el.get("ID", ""),
                    **({"points": pts} if pts else {}),
                }
            )
        return boxes

    # PAGE XML handling
    for region in root.findall(".//TextRegion"):
        coords = region.find(".//Coords")
        pts_attr = coords.get("points") if coords is not None else None
        if not pts_attr:
            continue
        pts = []
        for token in pts_attr.strip().split():
            if "," in token:
                xx, yy = token.split(",", 1)
                try:
                    pts.append((float(xx), float(yy)))
                except Exception:
                    pass
        if not pts:
            continue
        xs = [p[0] for p in pts]
        ys = [p[1] for p in pts]
        x1, x2 = int(min(xs)), int(max(xs))
        y1, y2 = int(min(ys)), int(max(ys))
        if x2 > x1 and y2 > y1:
            boxes.append(
                {
                    "label": "TextRegion",
                    "bbox": [x1, y1, x2, y2],
                    "source": "page",
                    "id": region.get("id", ""),
                }
            )
    if boxes:
        return boxes
    # Fallback: Pascal VOC
    for obj in root.findall(".//object"):
        bb = obj.find("bndbox")
        if bb is None:
            continue
        try:
            xmin = int(float(bb.findtext("xmin")))
            ymin = int(float(bb.findtext("ymin")))
            xmax = int(float(bb.findtext("xmax")))
            ymax = int(float(bb.findtext("ymax")))
            if xmax > xmin and ymax > ymin:
                boxes.append(
                    {
                        "label": (obj.findtext("name") or "region").strip(),
                        "bbox": [xmin, ymin, xmax, ymax],
                        "source": "voc",
                        "id": obj.findtext("name") or "",
                    }
                )
        except Exception:
            pass
    return boxes


def sort_boxes_reading_order(boxes, y_tol: int = 10):
    """Sort bounding boxes top‑to‑bottom then left‑to‑right.

    Boxes are grouped into horizontal bands of height ``y_tol`` (values
    below 1 are treated as 1) based on their top edge.  Bands are ordered
    top to bottom; inside a band boxes are ordered by their left edge, with
    the top edge as the final tie-breaker.

    Parameters
    ----------
    boxes : iterable of dict
        Boxes with a "bbox" entry of the form [x1, y1, x2, y2].
    y_tol : int, optional
        Band height in pixels used to decide that boxes share a row.

    Returns
    -------
    list of dict
        A new list with the same box objects in reading order.
    """
    band_height = max(1, y_tol)

    def key(b):
        x1, y1 = b["bbox"][0], b["bbox"][1]
        # x1 must come before y1: with y1 first the band index is redundant
        # and boxes on the same row would never be ordered left to right.
        return (round(y1 / band_height), x1, y1)

    return sorted(boxes, key=key)


def draw_boxes(img: Image.Image, boxes):
    """Return a copy of the image with numbered red overlays for each box.

    Parameters
    ----------
    img : PIL.Image
        The base image.
    boxes : list of dict
        Segmentation boxes.  A non-empty 'points' entry is drawn as a
        polygon; otherwise the 'bbox' entry is drawn as a rectangle.

    Returns
    -------
    PIL.Image
        An RGB image with a semi‑transparent red shape per box and a small
        red tag in the shape's top-left corner showing its 1-based index.
    """
    outline_rgba = (255, 0, 0, 255)
    fill_rgba = (255, 0, 0, 64)
    tag_w, tag_h = 40, 24

    canvas = img.convert("RGBA")
    layer = Image.new("RGBA", canvas.size, (0, 0, 0, 0))
    pen = ImageDraw.Draw(layer)
    # Rectangle outline width grows with the image, never below 3 pixels.
    line_width = max(3, min(canvas.size) // 200)

    for number, box in enumerate(boxes, start=1):
        polygon = box["points"] if "points" in box else None
        if polygon:
            corners = [(int(px), int(py)) for px, py in polygon]
            pen.polygon(corners, outline=outline_rgba, fill=fill_rgba)
            left = min(c[0] for c in corners)
            top = min(c[1] for c in corners)
        else:
            left, top, right, bottom = map(int, box["bbox"])
            pen.rectangle(
                [left, top, right, bottom],
                outline=outline_rgba,
                width=line_width,
                fill=fill_rgba,
            )
        # Index tag anchored at the top-left corner of the shape.
        pen.rectangle([left, top, left + tag_w, top + tag_h], fill=(255, 0, 0, 190))
        pen.text((left + 6, top + 4), str(number), fill=(255, 255, 255, 255))

    merged = Image.alpha_composite(canvas, layer)
    return merged.convert("RGB")


# ----------------------------------------------------------------------
# OCR inference per line
#
def predict_and_score_once(image: Image.Image, line_id: int = 1, topk: int = TOPK):
    """Run the model on a single cropped line and return predictions and scores.

    This helper wraps the model.generate call (greedy decoding, at most
    MAX_LEN positions) to obtain per‑token probabilities and derives a
    DataFrame summarizing each decoding step.

    If generation runs out of GPU memory, it is retried once without
    collecting scores; in that case the text is still returned but the
    DataFrame has no rows.

    Parameters
    ----------
    image : PIL.Image
        Cropped segment to process.
    line_id : int, optional
        Identifier used in the output DataFrame.
    topk : int, optional
        Number of candidate tokens (chosen token included) to keep for
        each decoding position.

    Returns
    -------
    decoded_text : str
        Cleaned predicted string for the line.
    df : pandas.DataFrame
        Table with one row per scored token containing the following
        columns: line_id, seq_pos, token_id, token, confidence,
        rel_prob, entropy, gap12, alt_tokens, alt_probs.

    Raises
    ------
    RuntimeError
        Any generation error other than an out-of-memory condition.
    """
    model, tokenizer, feature_extractor, device = load_model()
    img = prepare_image(image)
    pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values.to(device)
    amp_ctx = get_amp_ctx()
    # Shared by the primary attempt and the low-memory retry so that both
    # honour the same MAX_LEN limit.
    generate_kwargs = dict(
        max_length=MAX_LEN,
        num_beams=1,
        do_sample=False,
        return_dict_in_generate=True,
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
    )
    with torch.inference_mode(), amp_ctx():
        try:
            out = model.generate(pixel_values, output_scores=True, **generate_kwargs)
        except RuntimeError as e:
            if "out of memory" not in str(e).lower():
                raise
            out = None
        if out is None:
            # Retry outside the except block (so the failed attempt's
            # traceback no longer pins its tensors) and without scores.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            out = model.generate(pixel_values, output_scores=False, **generate_kwargs)

    seq = out.sequences[0]
    decoded_text = clean_text(tokenizer.decode(seq, skip_special_tokens=True))
    # Scores are absent after the low-memory retry; treat that as "no steps"
    # instead of failing when iterating over None.
    step_scores = out.scores or ()
    tokens_rows = []
    # step_scores[i] gives logits for the i+1 token of seq
    for step, (logits, tgt) in enumerate(zip(step_scores, seq[1:]), start=1):
        probs = torch.softmax(logits[0].float().cpu(), dim=-1)
        tgt_id = int(tgt.item())
        conf = float(probs[tgt_id].item())
        tk_vals, tk_idx = torch.topk(probs, k=min(topk, probs.shape[0]))
        tk_idx = tk_idx.tolist()
        tk_vals = tk_vals.tolist()
        # Put the chosen token first and fill the rest with the most
        # probable other candidates.
        if tgt_id in tk_idx:
            j = tk_idx.index(tgt_id)
            tk_idx.pop(j)
            tk_vals.pop(j)
        alt_ids = [tgt_id] + tk_idx[: topk - 1]
        alt_ps = [conf] + tk_vals[: topk - 1]
        alt_tokens = [tokenizer.decode([i], skip_special_tokens=True) for i in alt_ids]
        # clamp_min avoids log(0) for zero-probability entries.
        entropy = float((-probs * (probs.clamp_min(1e-12).log())).sum().item())
        runner_up = alt_ps[1] if len(alt_ps) > 1 else 0.0
        gap12 = float(alt_ps[0] - runner_up)
        rel_prob = float((alt_ps[1] / alt_ps[0]) if (len(alt_ps) > 1 and alt_ps[0] > 0) else 0.0)
        tokens_rows.append(
            {
                "line_id": line_id,
                "seq_pos": step,
                "token_id": tgt_id,
                "token": alt_tokens[0],
                "confidence": conf,
                "rel_prob": rel_prob,
                "entropy": entropy,
                "gap12": gap12,
                "alt_tokens": "|".join(alt_tokens),
                "alt_probs": "|".join([f"{p:.6f}" for p in alt_ps]),
            }
        )
    df = pd.DataFrame(
        tokens_rows,
        columns=[
            "line_id",
            "seq_pos",
            "token_id",
            "token",
            "confidence",
            "rel_prob",
            "entropy",
            "gap12",
            "alt_tokens",
            "alt_probs",
        ],
    )
    return decoded_text, df


# ----------------------------------------------------------------------
# Text splitting into aksharas (syllable units) for highlighting
#
# The following regex and helper functions split a Devanagari string into
# aksharas.  This is necessary to map model tokens back to spans of
# characters when highlighting uncertain predictions.

# Building blocks for AKSHARA_RE.  The multi-character values are written as
# character-class bodies (ranges) so they can be dropped inside ``[...]``.
DEV_CONS = "\u0915-\u0939\u0958-\u095F\u0978-\u097F"  # consonants incl. nukta variants range
INDEP_VOW = "\u0904-\u0914"  # independent vowels
NUKTA = "\u093C"  # nukta
VIRAMA = "\u094D"  # halant/virama
MATRAS = "\u093A-\u094C"  # dependent vowel signs; NOTE(review): this range also spans nukta (U+093C) and avagraha (U+093D)
BINDUS = "\u0901\u0902\u0903"  # chandrabindu, anusvara, visarga
# One akshara is either
#   * a consonant (optionally with nukta) followed by any number of
#     virama + consonant(+nukta) pairs, then at most one matra and at
#     most one bindu/visarga, or
#   * an independent vowel followed by at most one bindu/visarga.
# A virama that is not followed by a consonant is not part of the match.
AKSHARA_RE = re.compile(
    rf"(?:"
    rf"(?:[{DEV_CONS}]{NUKTA}?)(?:{VIRAMA}(?:[{DEV_CONS}]{NUKTA}?))*"  # consonant cluster
    rf"(?:[{MATRAS}])?"  # optional matra
    rf"(?:[{BINDUS}])?"  # optional bindu/visarga
    rf"|"
    rf"(?:[{INDEP_VOW}](?:[{BINDUS}])?)"  # independent vowel (+bindu)
    rf")",
    flags=re.UNICODE,
)


def split_aksharas(s: str):
    """Split a string into Devanagari aksharas and return pieces and spans.

    The string is scanned left to right.  Wherever ``AKSHARA_RE`` matches
    at the current position the whole match becomes one piece; any other
    character becomes a one-character piece.

    Returns
    -------
    (list of str, list of (int, int))
        The pieces, and their (start, end) character offsets in ``s``.
        The spans are contiguous and cover the whole string.
    """
    spans = []
    pos = 0
    total = len(s)
    while pos < total:
        hit = AKSHARA_RE.match(s, pos)
        if hit is not None and hit.end() > pos:
            stop = hit.end()
        else:
            stop = pos + 1
        spans.append((pos, stop))
        pos = stop
    pieces = [s[start:stop] for start, stop in spans]
    return pieces, spans


def parse_alt_probs(s: str):
    """Parse a "|"-separated string of probabilities into a list of floats.

    Empty fields are ignored.  A falsy input yields an empty list, and so
    does any input that cannot be parsed completely.
    """
    try:
        fields = (s or "").split("|")
        return list(map(float, filter(None, fields)))
    except Exception:
        return []


def parse_alt_tokens(s: str):
    """Split a "|"-separated string of alternative tokens into a list.

    A falsy input is treated as the empty string, which yields ``[""]``.
    """
    text = s if s else ""
    return text.split("|")


def highlight_tokens_with_tooltips(
    line_text: str, df_tok: pd.DataFrame, red_threshold: float, metric_column: str
) -> str:
    """Insert HTML spans around tokens whose chosen metric exceeds threshold.

    The metric column can be "rel_prob" (relative probability) or
    "entropy".  Tokens with a value strictly greater than red_threshold
    will be wrapped in a span with a tooltip listing alternative
    predictions and their probabilities.

    Tokens are located in the line in the order of the DataFrame rows:
    each token is searched for after the end of the previously located
    one, so a token that occurs several times in the line is highlighted
    at its own occurrence.  Tokens that cannot be found are skipped.  A
    highlight is widened to whole aksharas, and a token whose aksharas
    are already highlighted is skipped.

    Parameters
    ----------
    line_text : str
        The cleaned line prediction.
    df_tok : pandas.DataFrame
        DataFrame of token statistics for the corresponding line, with
        rows in decoding order.
    red_threshold : float
        Values above this threshold will be highlighted.
    metric_column : str
        Column name in df_tok used for thresholding.

    Returns
    -------
    str
        An HTML-escaped string with <span> elements inserted.
    """
    aks, spans = split_aksharas(line_text)
    joined = "".join(aks)
    used_ranges = []
    insertions = []
    cursor = 0  # search position: end of the last token that was located
    for _, row in df_tok.iterrows():
        token = row.get("token", "").strip()
        if not token:
            continue
        # Locate every token (not only the uncertain ones) so the cursor
        # follows the decoding order through the line.
        start_char_idx = joined.find(token, cursor)
        if start_char_idx == -1:
            continue
        end_char_idx = start_char_idx + len(token)
        cursor = end_char_idx
        try:
            val = float(row.get(metric_column, 0))
        except (TypeError, ValueError):
            continue
        if val <= red_threshold:
            continue
        # Map the character range onto whole aksharas.
        ak_start = ak_end = None
        for i, (span_start, span_end) in enumerate(spans):
            if span_start <= start_char_idx < span_end:
                ak_start = i
            if span_start < end_char_idx <= span_end:
                ak_end = i + 1
                break
        if ak_start is None or ak_end is None:
            continue
        # Avoid overlaps
        if any(r[0] < ak_end and ak_start < r[1] for r in used_ranges):
            continue
        used_ranges.append((ak_start, ak_end))
        # Character positions
        char_start = spans[ak_start][0]
        char_end = spans[ak_end - 1][1]
        # Build tooltip content
        alt_toks = row.get("alt_tokens", "").split("|")
        alt_probs = row.get("alt_probs", "").split("|")
        tooltip_lines = []
        for t, p in zip(alt_toks, alt_probs):
            try:
                prob = float(p)
            except (TypeError, ValueError):
                prob = 0.0
            tooltip_lines.append(f"{t}: {prob:.3f}")
        # Escape exactly once, at the point where the text enters the
        # attribute; escaping each line as well would show entities such
        # as "&amp;" literally in the tooltip.
        tooltip = _html_escape("\n".join(tooltip_lines))
        token_str = _html_escape(line_text[char_start:char_end])
        html_token = f"<span class='ocr-token' data-tooltip='{tooltip}'>{token_str}</span>"
        insertions.append((char_start, char_end, html_token))
    if not insertions:
        return _html_escape(line_text)
    insertions.sort()
    out_parts = []
    last_idx = 0
    for s, e, html_tok in insertions:
        out_parts.append(_html_escape(line_text[last_idx:s]))
        out_parts.append(html_tok)
        last_idx = e
    out_parts.append(_html_escape(line_text[last_idx:]))
    return "".join(out_parts)


# Replacement table for the five characters that are special in HTML text
# and in quoted attribute values.
_HTML_ESCAPE_TABLE = str.maketrans(
    {
        "&": "&amp;",
        "<": "&lt;",
        ">": "&gt;",
        "\"": "&quot;",
        "'": "&#x27;",
    }
)


def _html_escape(s: str) -> str:
    """Return ``s`` with ``&``, ``<``, ``>``, ``"`` and ``'`` replaced by entities."""
    return s.translate(_HTML_ESCAPE_TABLE)


# ----------------------------------------------------------------------
# Main OCR wrapper for Gradio
#
def run_ocr(
    image: np.ndarray | None,
    xml_file: tuple | None,
    apply_gray: bool,
    apply_bin: bool,
    highlight_metric: str,
):
    """Run the OCR pipeline on user inputs and return results for Gradio.

    Parameters
    ----------
    image : numpy.ndarray or None
        The uploaded image converted to a NumPy array by Gradio.  If
        None, the function returns empty results.
    xml_file : tuple or None
        A tuple representing the uploaded XML file as provided by
        gr.File.  The first element is the file name and the second is
        bytes.  If None, no segmentation is applied and the entire
        image is processed as a single line.
    apply_gray : bool
        Whether to convert the image to grayscale before OCR.
    apply_bin : bool
        Whether to apply binarization (Otsu threshold) before OCR.  If
        selected, grayscale conversion is applied first automatically.
    highlight_metric : str
        Which metric to use for highlighting ("Relative Probability" or
        "Entropy").

    Returns
    -------
    overlay_img : PIL.Image or None
        Image with segmentation boxes drawn.  None if no input image.
    predictions_html : str
        HTML formatted predicted text with highlighted tokens.
    df_scores : pandas.DataFrame or None
        DataFrame of per‑token statistics.  None if no input image.
    txt_file_path : str or None
        Path to a temporary .txt file containing the plain predicted text.
    csv_file_path : str or None
        Path to a temporary CSV file containing the extended token scores.
    """
    if image is None:
        return None, "", None, None, None
    # Convert the numpy array to a PIL image
    pil_img = Image.fromarray(image).convert("RGB")
    # Apply preprocessing as requested
    if apply_gray:
        pil_img = pil_img.convert("L").convert("RGB")
    if apply_bin:
        img_cv = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
        _, bin_img = cv2.threshold(img_cv, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        pil_img = Image.fromarray(bin_img).convert("RGB")
    # Parse segmentation boxes if XML provided
    boxes: list = []
    if xml_file:
        # Determine the correct way to extract bytes from the uploaded file.
        xml_bytes = None
        # If gr.File is configured with type="binary", xml_file will be raw bytes.
        if isinstance(xml_file, (bytes, bytearray)):
            xml_bytes = bytes(xml_file)
        # When type="filepath", xml_file would be a str path.
        elif isinstance(xml_file, str):
            try:
                with open(xml_file, "rb") as f:
                    xml_bytes = f.read()
            except Exception:
                xml_bytes = None
        # If a temporary file object is passed in, read its contents.
        elif hasattr(xml_file, "read"):
            try:
                xml_bytes = xml_file.read()
            except Exception:
                xml_bytes = None
        # If xml_file is a dictionary from Gradio (not expected with type="binary"),
        # attempt to extract the data key.
        elif isinstance(xml_file, dict) and "data" in xml_file:
            xml_bytes = xml_file.get("data")
        if xml_bytes:
            try:
                boxes = parse_boxes_from_xml(xml_bytes, level="line", image_size=pil_img.size)
                boxes = sort_boxes_reading_order(boxes)[:MAX_LINES]
            except Exception:
                boxes = []
    # Run OCR for each segmented line or the whole image
    dfs = []
    concatenated_parts = []
    line_text_by_id = {}
    if boxes:
        pad = 2
        for idx, b in enumerate(boxes, 1):
            # Create a tight crop around the line
            if "points" in b:
                pts = b["points"]
                mask = Image.new("L", pil_img.size, 0)
                ImageDraw.Draw(mask).polygon(pts, outline=1, fill=255)
                seg_img = Image.new("RGB", pil_img.size, (255, 255, 255))
                seg_img.paste(pil_img, mask=mask)
                xs = [x for x, y in pts]
                ys = [y for x, y in pts]
                x1 = max(0, int(min(xs) - pad))
                y1 = max(0, int(min(ys) - pad))
                x2 = min(pil_img.width, int(max(xs) + pad))
                y2 = min(pil_img.height, int(max(ys) + pad))
                crop = seg_img.crop((x1, y1, x2, y2))
            else:
                x1, y1, x2, y2 = b["bbox"]
                x1p = max(0, x1 - pad)
                y1p = max(0, y1 - pad)
                x2p = min(pil_img.width, x2 + pad)
                y2p = min(pil_img.height, y2 + pad)
                crop = pil_img.crop((x1p, y1p, x2p, y2p))
            # Run inference on the crop
            seg_text, df_tok = predict_and_score_once(crop, line_id=idx, topk=TOPK)
            seg_text = clean_text(seg_text)
            # Choose metric
            if highlight_metric == "Relative Probability":
                red_threshold = REL_PROB_TH
                metric_col = "rel_prob"
            else:
                red_threshold = 0.10  # heuristic threshold for entropy
                metric_col = "entropy"
            # Highlight uncertain tokens
            seg_text_flagged = highlight_tokens_with_tooltips(seg_text, df_tok, red_threshold, metric_col)
            concatenated_parts.append(seg_text_flagged)
            df_tok["line_id"] = idx
            dfs.append(df_tok)
            line_text_by_id[idx] = seg_text_flagged
        predicted_html = "<br>".join(concatenated_parts).strip()
        df_all = pd.concat(dfs, ignore_index=True)
    else:
        # Single pass on the whole image
        seg_text, df_all = predict_and_score_once(pil_img, line_id=1, topk=TOPK)
        seg_text = clean_text(seg_text)
        if highlight_metric == "Relative Probability":
            red_threshold = REL_PROB_TH
            metric_col = "rel_prob"
        else:
            red_threshold = 0.10
            metric_col = "entropy"
        seg_text_flagged = highlight_tokens_with_tooltips(seg_text, df_all, red_threshold, metric_col)
        predicted_html = seg_text_flagged
        line_text_by_id[1] = seg_text_flagged
    # Draw overlay image
    overlay_img = draw_boxes(pil_img, boxes) if boxes else pil_img
    # Create downloads
    df_all = df_all.copy()
    # Drop the last empty token per line to tidy up output
    df_all.sort_values(["line_id", "seq_pos"], inplace=True)
    to_drop = []
    for line_id, group in df_all.groupby("line_id"):
        if group.iloc[-1]["token"].strip() == "":
            to_drop.append(group.index[-1])
    df_all = df_all.drop(index=to_drop)
    # Prepare plain text by stripping HTML tags and replacing <br>
    plain_text = re.sub(r"<[^>]*>", "", predicted_html.replace("<br>", "\n"))
    # Write temporary files

    # return overlay_img, predicted_html
    # Save plain text to a temporary .txt file
    txt_dir = tempfile.gettempdir()
    txt_path = os.path.join(txt_dir, "predictions.txt")
    with open(txt_path, "w", encoding="utf-8") as f:
        f.write(plain_text)


    return overlay_img, predicted_html, txt_path


# ----------------------------------------------------------------------
# Build Gradio Interface
#
def create_gradio_interface():
    """Create and return the Gradio Blocks interface.

    The layout has an image upload and an optional segmentation XML
    upload, a "Run OCR" button, the overlay image next to the highlighted
    HTML predictions, an editable textbox mirroring the predicted text,
    and a download slot for the plain-text result.

    Preprocessing options and the highlight metric are not exposed in the
    UI; the click handler passes fixed values to ``run_ocr`` (no
    grayscale, no binarization, "Relative Probability").
    """
    import html  # local import: only needed by the textbox callback below

    with gr.Blocks(title="Handwritten Text Recognition (Old Nepali)") as demo:
        gr.Markdown("""# Handwritten Text Recognition (Old Nepali) \n\nUpload an image and (optionally) a segmentation XML file.  Then click **Run OCR** to extract the text.""")
        gr.HTML("""
            <style>
            #prediction-box {
                border: 1px solid #ccc;
                padding: 16px;
                border-radius: 8px;
                background-color: #f9f9f9;
                font-size: 18px;
                line-height: 1.6;
                min-height: 100px;
            }
            </style>
            """)
        with gr.Row():
            image_input = gr.Image(type="numpy", label="Upload Image")
            # When used as an input, gr.File returns either a file path or bytes
            # depending on the `type` parameter.  By setting type="binary" we
            # ensure that the XML content is passed directly as bytes to the
            # callback, avoiding the need to reopen a temporary file.
            xml_input = gr.File(
                label="Upload segmentation XML (optional)",
                file_count="single",
                type="binary",
                file_types=[".xml"],
            )
        run_btn = gr.Button("Run OCR")

        # Outputs
        with gr.Row():
            with gr.Column(scale=2):
                overlay_output = gr.Image(label="Detected Regions")

            with gr.Column(scale=2):
                predictions_output = gr.HTML(
                    label="Predictions (HTML)",
                    elem_id="prediction-box"
                )

        # Editable text
        edited_text = gr.Textbox(
            label="Edit full predicted text", lines=8, interactive=True
        )
        txt_file_output = gr.File(label="Download OCR Prediction (.txt)")

        # Callback for OCR
        def on_run(image, xml):
            return run_ocr(image, xml, False, False, "Relative Probability")

        run_btn.click(
            fn=on_run,
            inputs=[image_input, xml_input],
            outputs=[overlay_output, predictions_output, txt_file_output],
        )

        # Populate editable text with plain text from predictions
        def update_edited_text(pred_html):
            plain = re.sub(r"<[^>]*>", "", (pred_html or "").replace("<br>", "\n"))
            # Decode entities only after the tags are gone, so an escaped
            # "&lt;...&gt;" in the recognized text is never taken for markup.
            return html.unescape(plain)

        predictions_output.change(
            fn=update_edited_text,
            inputs=predictions_output,
            outputs=edited_text,
        )

    return demo


if __name__ == "__main__":
    # Script entry point: build the Blocks app and start serving it.
    create_gradio_interface().launch()