| """ | |
| Gradio application for performing OCR on scanned Old Nepali documents. | |
| This script is a Gradio port of a Streamlit application originally built | |
| to visualize and edit OCR output. It loads a pre‑trained model for | |
| sequence decoding, accepts an input image (and optional segmentation | |
| XML in ALTO format), performs OCR on segmented lines, highlights tokens | |
| with low confidence and offers downloads of both the raw text and per | |
| token scores. | |
| The heavy lifting functions (model loading, pre‑processing, inference | |
| and highlighting) are adapted directly from the Streamlit version. The | |
| UI has been simplified for Gradio: users upload an image and optional | |
| XML file, choose preprocessing steps and a highlight metric, then run | |
| OCR. The results are displayed alongside the overlaid segmentation | |
| boxes and a table of token scores. An editable textbox lets users | |
| modify the predicted text before downloading it. | |
| To run this app locally, install gradio (`pip install gradio`) and | |
| execute this script with Python: | |
| python gradio_app.py | |
| """ | |
import io
import os
import re
import base64
import unicodedata
import contextlib
import tempfile
import xml.etree.ElementTree as ET
from collections import defaultdict
from functools import lru_cache

import numpy as np
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
import cv2
import torch
from transformers import (
    VisionEncoderDecoderModel,
    PreTrainedTokenizerFast,
    TrOCRProcessor,
)
from matplotlib import cm
import gradio as gr
# ----------------------------------------------------------------------
# Configuration
#
# These constants control various aspects of the OCR pipeline. You can
# adjust them to trade off accuracy, performance and output volume.

# The maximum number of tokens to decode for a single line. If your
# documents typically have longer lines you can increase this value, but
# beware that very long sequences use more memory.
MAX_LEN: int = 128

# How many alternative tokens to keep when computing per-token statistics.
TOPK: int = 3

# If an XML segmentation file is provided, only process the first
# MAX_LINES lines. This prevents huge documents from consuming
# excessive resources.
MAX_LINES: int = 120

# Images are resized such that the longest side does not exceed this
# number of pixels before they are passed to the OCR model. Increasing
# this value may improve accuracy at the cost of speed and memory.
RESIZE_MAX_SIDE: int = 800

# Threshold used when highlighting tokens by relative probability. A
# ratio of Top2/Top1 greater than this value causes the token to be
# highlighted in red.
REL_PROB_TH: float = 0.70

# A regex used to clean up Unicode control characters before text
# normalization. Soft hyphens, zero-width spaces and similar marks
# interfere with accurate token matching.
CLEANUP: re.Pattern = re.compile(r"[\u00AD\u200B\u200C\u200D]")

# Default font path for rendering predictions directly on the image.
FONT_PATH: str = os.path.join("NotoSansDevanagari-Regular.ttf")
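
# A worked example of the relative-probability rule above (illustrative
# numbers only, not taken from a real document): if at one decoding step the
# best token has probability 0.50 and the runner-up 0.40, then
# rel_prob = 0.40 / 0.50 = 0.80 > REL_PROB_TH, so the token is flagged as
# uncertain; with probabilities 0.90 and 0.05 the ratio is roughly 0.056 and
# the token is left unhighlighted.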
# ----------------------------------------------------------------------
# Model loading
#
# Loading the model and associated tokenizer/processor is slow. Use
# functools.lru_cache to ensure this only happens once per process.
@lru_cache(maxsize=1)
def load_model():
    """Load the OCR model, tokenizer and feature extractor.

    Returns
    -------
    model : VisionEncoderDecoderModel
        The loaded model in evaluation mode.
    tokenizer : PreTrainedTokenizerFast
        Tokenizer corresponding to the decoder part of the model.
    feature_extractor : callable
        Feature extractor converting PIL images into model inputs.
    device : torch.device
        The device (CPU or CUDA) used for inference.
    """
    model_path = "AnjaliSarawgi/model-oct"
    # In an offline environment the HF token is None; if you wish to use a
    # private model you can set HF_TOKEN in your environment.
    hf_token = os.environ.get("HF_TOKEN")
    model = VisionEncoderDecoderModel.from_pretrained(model_path, token=hf_token)
    tokenizer = PreTrainedTokenizerFast.from_pretrained(model_path, token=hf_token)
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-large-handwritten", token=None)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device).eval()
    return model, tokenizer, processor.feature_extractor, device
# ----------------------------------------------------------------------
# Utility functions
#
def clean_text(text: str) -> str:
    """Normalize a decoded string and strip all whitespace.

    Parameters
    ----------
    text : str
        The raw decoded string from the model.

    Returns
    -------
    str
        The cleaned string: Unicode-normalized (NFC), with control
        characters and all whitespace removed. Whitespace is stripped
        entirely because the predictions are later tokenized at the
        akshara (syllable) level.
    """
    text = unicodedata.normalize("NFC", text)
    text = CLEANUP.sub("", text)
    return re.sub(r"\s+", "", text)
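
# For illustration (assuming plain Devanagari input): a decoded string such as
# "रा म\u200bको" becomes "रामको" -- the zero-width space is dropped by CLEANUP
# and the remaining whitespace is removed entirely.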
def prepare_image(image: Image.Image, max_side: int = RESIZE_MAX_SIDE) -> Image.Image:
    """Resize the image so that its longest side is at most ``max_side``.

    Parameters
    ----------
    image : PIL.Image
        Input image.
    max_side : int, optional
        Maximum allowed size for the longest side of the image.

    Returns
    -------
    PIL.Image
        The resized image (unchanged if it is already small enough).
    """
    img = image.convert("RGB")
    w, h = img.size
    if max(w, h) > max_side:
        img.thumbnail((max_side, max_side), Image.LANCZOS)
    return img


def get_amp_ctx():
    """Return the appropriate context manager factory for automatic mixed precision."""
    return torch.cuda.amp.autocast if torch.cuda.is_available() else contextlib.nullcontext
# ----------------------------------------------------------------------
# XML parsing and segmentation
#
def parse_boxes_from_xml(xml_bytes: bytes, level: str = "line", image_size: tuple | None = None):
    """Parse ALTO or PAGE XML to extract bounding boxes.

    Parameters
    ----------
    xml_bytes : bytes
        Raw XML bytes.
    level : {"block", "line", "word"}, optional
        The segmentation level to extract. For OCR we use "line".
    image_size : tuple or None
        If provided, image_size=(width, height) allows rescaling
        coordinates to match the actual image. ALTO files often store
        absolute page sizes that differ from the image dimensions.

    Returns
    -------
    list of dict
        Each dict represents a bounding box with keys:
        - "bbox": [x1, y1, x2, y2]
        - "points": list of (x, y) if polygonal coordinates exist
        - "id": line identifier (string)
        - "label": the type of element (e.g. TextLine)
    """
    def _strip_ns(elem):
        for e in elem.iter():
            if isinstance(e.tag, str) and e.tag.startswith("{"):
                e.tag = e.tag.split("}", 1)[1]

    root = ET.parse(io.BytesIO(xml_bytes)).getroot()
    _strip_ns(root)
    boxes = []

    # ALTO format handling
    if root.tag.lower() == "alto":
        tag_map = {"block": "TextBlock", "line": "TextLine", "word": "String"}
        tag = tag_map.get(level, "TextLine")
        page_el = root.find(".//Page")
        page_w = page_h = None
        if page_el is not None:
            try:
                page_w = float(page_el.get("WIDTH") or 0)
                page_h = float(page_el.get("HEIGHT") or 0)
            except Exception:
                page_w = page_h = None
        sx = sy = 1.0
        if image_size and page_w and page_h:
            img_w, img_h = image_size
            sx = (img_w / page_w) if page_w else 1.0
            sy = (img_h / page_h) if page_h else 1.0
        for el in root.findall(f".//{tag}"):
            poly = el.find(".//Shape/Polygon")
            got_box = False
            pts = None
            if poly is not None and poly.get("POINTS"):
                raw = poly.get("POINTS").strip()
                tokens = re.split(r"[ ,]+", raw)
                nums = []
                for t in tokens:
                    try:
                        nums.append(float(t))
                    except Exception:
                        pass
                pts = []
                if len(nums) >= 6 and len(nums) % 2 == 0:
                    for i in range(0, len(nums), 2):
                        pts.append((nums[i] * sx, nums[i + 1] * sy))
                if pts:
                    xs = [p[0] for p in pts]
                    ys = [p[1] for p in pts]
                    x1, x2 = int(min(xs)), int(max(xs))
                    y1, y2 = int(min(ys)), int(max(ys))
                    got_box = (x2 > x1 and y2 > y1)
            if not got_box:
                try:
                    hpos = float(el.get("HPOS", 0)) * sx
                    vpos = float(el.get("VPOS", 0)) * sy
                    width = float(el.get("WIDTH", 0)) * sx
                    height = float(el.get("HEIGHT", 0)) * sy
                    x1, y1 = int(hpos), int(vpos)
                    x2, y2 = int(hpos + width), int(vpos + height)
                except Exception:
                    continue
            if x2 <= x1 or y2 <= y1:
                continue
            label = tag if tag != "String" else (el.get("CONTENT") or "String")
            boxes.append(
                {
                    "label": label,
                    "bbox": [x1, y1, x2, y2],
                    "source": "alto",
                    "id": el.get("ID", ""),
                    **({"points": pts} if pts else {}),
                }
            )
        return boxes

    # PAGE XML handling
    for region in root.findall(".//TextRegion"):
        coords = region.find(".//Coords")
        pts_attr = coords.get("points") if coords is not None else None
        if not pts_attr:
            continue
        pts = []
        for token in pts_attr.strip().split():
            if "," in token:
                xx, yy = token.split(",", 1)
                try:
                    pts.append((float(xx), float(yy)))
                except Exception:
                    pass
        if not pts:
            continue
        xs = [p[0] for p in pts]
        ys = [p[1] for p in pts]
        x1, x2 = int(min(xs)), int(max(xs))
        y1, y2 = int(min(ys)), int(max(ys))
        if x2 > x1 and y2 > y1:
            boxes.append(
                {
                    "label": "TextRegion",
                    "bbox": [x1, y1, x2, y2],
                    "source": "page",
                    "id": region.get("id", ""),
                }
            )
    if boxes:
        return boxes

    # Fallback: Pascal VOC
    for obj in root.findall(".//object"):
        bb = obj.find("bndbox")
        if bb is None:
            continue
        try:
            xmin = int(float(bb.findtext("xmin")))
            ymin = int(float(bb.findtext("ymin")))
            xmax = int(float(bb.findtext("xmax")))
            ymax = int(float(bb.findtext("ymax")))
            if xmax > xmin and ymax > ymin:
                boxes.append(
                    {
                        "label": (obj.findtext("name") or "region").strip(),
                        "bbox": [xmin, ymin, xmax, ymax],
                        "source": "voc",
                        "id": obj.findtext("name") or "",
                    }
                )
        except Exception:
            pass
    return boxes
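
# For reference, a minimal ALTO fragment that the branch above can parse looks
# roughly like this (illustrative only; files exported from real segmentation
# tools carry many more attributes):
#
#   <alto>
#     <Layout>
#       <Page WIDTH="2000" HEIGHT="3000">
#         <TextBlock>
#           <TextLine ID="line_1" HPOS="100" VPOS="200" WIDTH="1500" HEIGHT="80">
#             <Shape><Polygon POINTS="100 200 1600 200 1600 280 100 280"/></Shape>
#           </TextLine>
#         </TextBlock>
#       </Page>
#     </Layout>
#   </alto>
#
# The polygon POINTS are preferred when present; otherwise the HPOS/VPOS/
# WIDTH/HEIGHT attributes of each TextLine are used as the bounding box.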
def sort_boxes_reading_order(boxes, y_tol: int = 10):
    """Sort bounding boxes top-to-bottom, then left-to-right."""
    def key(b):
        x1, y1, x2, y2 = b["bbox"]
        return (round(y1 / max(1, y_tol)), y1, x1)
    return sorted(boxes, key=key)
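
# Boxes whose top edge rounds to the same y_tol-sized band are treated as one
# visual row and ordered by x. With the default y_tol of 10, for example, boxes
# starting at y1=12 and y1=14 both fall into band round(y1 / 10) = 1 and are
# ordered left-to-right, while a box at y1=52 sorts into a later band.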
def draw_boxes(img: Image.Image, boxes):
    """Overlay semi-transparent red polygons or rectangles on an image.

    Parameters
    ----------
    img : PIL.Image
        The base image.
    boxes : list of dict
        Segmentation boxes with either 'points' or 'bbox' keys.

    Returns
    -------
    PIL.Image
        An image with red overlays marking each box. Boxes are numbered
        starting from 1.
    """
    base = img.convert("RGBA")
    overlay = Image.new("RGBA", base.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    thickness = max(3, min(base.size) // 200)
    for i, b in enumerate(boxes, 1):
        if "points" in b and b["points"]:
            pts = [(int(x), int(y)) for x, y in b["points"]]
            draw.polygon(pts, outline=(255, 0, 0, 255), fill=(255, 0, 0, 64))
            xs = [p[0] for p in pts]
            ys = [p[1] for p in pts]
            x1, y1 = min(xs), min(ys)
        else:
            x1, y1, x2, y2 = map(int, b["bbox"])
            draw.rectangle([x1, y1, x2, y2], outline=(255, 0, 0, 255), width=thickness, fill=(255, 0, 0, 64))
        tag_w, tag_h = 40, 24
        draw.rectangle([x1, y1, x1 + tag_w, y1 + tag_h], fill=(255, 0, 0, 190))
        draw.text((x1 + 6, y1 + 4), str(i), fill=(255, 255, 255, 255))
    return Image.alpha_composite(base, overlay).convert("RGB")
# ----------------------------------------------------------------------
# OCR inference per line
#
def predict_and_score_once(image: Image.Image, line_id: int = 1, topk: int = TOPK):
    """Run the model on a single cropped line and return predictions and scores.

    This helper wraps the model.generate call to obtain per-token
    probabilities and derives a DataFrame summarizing each decoding step.

    Parameters
    ----------
    image : PIL.Image
        Cropped segment to process.
    line_id : int, optional
        Identifier used in the output DataFrame.
    topk : int, optional
        Number of alternative tokens to keep for each decoding position.

    Returns
    -------
    decoded_text : str
        Cleaned predicted string for the line.
    df : pandas.DataFrame
        Table with one row per generated token containing the following
        columns: line_id, seq_pos, token_id, token, confidence,
        rel_prob, entropy, gap12, alt_tokens, alt_probs.
    """
    model, tokenizer, feature_extractor, device = load_model()
    img = prepare_image(image)
    pixel_values = feature_extractor(images=img, return_tensors="pt").pixel_values.to(device)
    amp_ctx = get_amp_ctx()
    with torch.inference_mode(), amp_ctx():
        try:
            out = model.generate(
                pixel_values,
                max_length=MAX_LEN,
                num_beams=1,
                do_sample=False,
                return_dict_in_generate=True,
                output_scores=True,
                use_cache=True,
                eos_token_id=tokenizer.eos_token_id,
            )
        except RuntimeError as e:
            # On GPU OOM, retry without per-token scores to reduce memory use.
            if "out of memory" in str(e).lower():
                out = model.generate(
                    pixel_values,
                    max_length=MAX_LEN,
                    num_beams=1,
                    do_sample=False,
                    return_dict_in_generate=True,
                    output_scores=False,
                    use_cache=True,
                    eos_token_id=tokenizer.eos_token_id,
                )
            else:
                raise
    seq = out.sequences[0]
    decoded_text = clean_text(tokenizer.decode(seq, skip_special_tokens=True))
    tokens_rows = []
    # out.scores[i] holds the logits for token i+1 of seq; scores may be None
    # if the OOM fallback above disabled them, hence the `or []` guard.
    for step, (logits, tgt) in enumerate(zip(out.scores or [], seq[1:]), start=1):
        probs = torch.softmax(logits[0].float().cpu(), dim=-1)
        tgt_id = int(tgt.item())
        conf = float(probs[tgt_id].item())
        tk_vals, tk_idx = torch.topk(probs, k=min(topk, probs.shape[0]))
        tk_idx = tk_idx.tolist()
        tk_vals = tk_vals.tolist()
        if tgt_id in tk_idx:
            j = tk_idx.index(tgt_id)
            tk_idx.pop(j)
            tk_vals.pop(j)
        alt_ids = [tgt_id] + tk_idx[: topk - 1]
        alt_ps = [conf] + tk_vals[: topk - 1]
        alt_tokens = [tokenizer.decode([i], skip_special_tokens=True) for i in alt_ids]
        entropy = float((-probs * (probs.clamp_min(1e-12).log())).sum().item())
        gap12 = float(alt_ps[0] - (alt_ps[1] if len(alt_ps) > 1 else 0.0))
        rel_prob = float((alt_ps[1] / alt_ps[0]) if (len(alt_ps) > 1 and alt_ps[0] > 0) else 0.0)
        tokens_rows.append(
            {
                "line_id": line_id,
                "seq_pos": step,
                "token_id": tgt_id,
                "token": alt_tokens[0],
                "confidence": conf,
                "rel_prob": rel_prob,
                "entropy": entropy,
                "gap12": gap12,
                "alt_tokens": "|".join(alt_tokens),
                "alt_probs": "|".join([f"{p:.6f}" for p in alt_ps]),
            }
        )
        del probs
    df = pd.DataFrame(
        tokens_rows,
        columns=[
            "line_id",
            "seq_pos",
            "token_id",
            "token",
            "confidence",
            "rel_prob",
            "entropy",
            "gap12",
            "alt_tokens",
            "alt_probs",
        ],
    )
    return decoded_text, df
# ----------------------------------------------------------------------
# Text splitting into aksharas (syllable units) for highlighting
#
# The following regex and helper functions split a Devanagari string into
# aksharas. This is necessary to map model tokens back to spans of
# characters when highlighting uncertain predictions.
DEV_CONS = "\u0915-\u0939\u0958-\u095F\u0978-\u097F"  # consonants incl. nukta-variant ranges
INDEP_VOW = "\u0904-\u0914"  # independent vowels
NUKTA = "\u093C"  # nukta
VIRAMA = "\u094D"  # halant/virama
MATRAS = "\u093A-\u094C"  # dependent vowel signs
BINDUS = "\u0901\u0902\u0903"  # chandrabindu, anusvara, visarga

AKSHARA_RE = re.compile(
    rf"(?:"
    rf"(?:[{DEV_CONS}]{NUKTA}?)(?:{VIRAMA}(?:[{DEV_CONS}]{NUKTA}?))*"  # consonant cluster
    rf"(?:[{MATRAS}])?"  # optional matra
    rf"(?:[{BINDUS}])?"  # optional bindu/visarga
    rf"|"
    rf"(?:[{INDEP_VOW}](?:[{BINDUS}])?)"  # independent vowel (+ bindu)
    rf")",
    flags=re.UNICODE,
)


def split_aksharas(s: str):
    """Split a string into Devanagari aksharas and return spans."""
    spans = []
    i = 0
    while i < len(s):
        m = AKSHARA_RE.match(s, i)
        if m and m.end() > i:
            spans.append((m.start(), m.end()))
            i = m.end()
        else:
            spans.append((i, i + 1))
            i += 1
    return [s[a:b] for (a, b) in spans], spans
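
# Rough illustration (assuming well-formed Devanagari input):
#   split_aksharas("नेपाली") -> (["ने", "पा", "ली"], [(0, 2), (2, 4), (4, 6)])
# Each consonant together with its dependent vowel sign stays in one unit,
# which is what allows model tokens to be mapped back onto character spans
# in highlight_tokens_with_tooltips below.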
def parse_alt_probs(s: str):
    try:
        return [float(x) for x in (s or "").split("|") if x != ""]
    except Exception:
        return []


def parse_alt_tokens(s: str):
    return [(t if t is not None else "") for t in (s or "").split("|")]
def highlight_tokens_with_tooltips(
    line_text: str, df_tok: pd.DataFrame, red_threshold: float, metric_column: str
) -> str:
    """Insert HTML spans around tokens whose chosen metric exceeds the threshold.

    The metric column can be "rel_prob" (relative probability) or
    "entropy". Tokens with a value strictly greater than red_threshold
    will be wrapped in a span with a tooltip listing alternative
    predictions and their probabilities.

    Parameters
    ----------
    line_text : str
        The cleaned line prediction.
    df_tok : pandas.DataFrame
        DataFrame of token statistics for the corresponding line.
    red_threshold : float
        Values above this threshold will be highlighted.
    metric_column : str
        Column name in df_tok used for thresholding.

    Returns
    -------
    str
        An HTML string with <span> elements inserted.
    """
    aks, spans = split_aksharas(line_text)
    joined = "".join(aks)
    used_ranges = []
    insertions = []
    for _, row in df_tok.iterrows():
        token = row.get("token", "").strip()
        try:
            val = float(row.get(metric_column, 0))
        except Exception:
            continue
        if val <= red_threshold or not token:
            continue
        # Try finding the token in the joined akshara sequence
        start_char_idx = joined.find(token)
        if start_char_idx == -1:
            continue
        # Locate matching akshara span
        ak_start = ak_end = None
        cum_len = 0
        for i, ak in enumerate(aks):
            next_len = cum_len + len(ak)
            if cum_len <= start_char_idx < next_len:
                ak_start = i
            if cum_len < start_char_idx + len(token) <= next_len:
                ak_end = i + 1
                break
            cum_len = next_len
        if ak_start is None or ak_end is None:
            continue
        # Avoid overlaps
        if any(r[0] < ak_end and ak_start < r[1] for r in used_ranges):
            continue
        used_ranges.append((ak_start, ak_end))
        # Character positions
        char_start = spans[ak_start][0]
        char_end = spans[ak_end - 1][1]
        # Build tooltip content
        alt_toks = row.get("alt_tokens", "").split("|")
        alt_probs = row.get("alt_probs", "").split("|")
        tooltip_lines = []
        for t, p in zip(alt_toks, alt_probs):
            try:
                prob = float(p)
            except Exception:
                prob = 0.0
            tooltip_lines.append(f"{_html_escape(t)}: {prob:.3f}")
        tooltip = "\n".join(tooltip_lines)
        token_str = _html_escape(line_text[char_start:char_end])
        html_token = f"<span class='ocr-token' data-tooltip='{_html_escape(tooltip)}'>{token_str}</span>"
        insertions.append((char_start, char_end, html_token))
    if not insertions:
        return _html_escape(line_text)
    insertions.sort()
    out_parts = []
    last_idx = 0
    for s, e, html_tok in insertions:
        out_parts.append(_html_escape(line_text[last_idx:s]))
        out_parts.append(html_tok)
        last_idx = e
    out_parts.append(_html_escape(line_text[last_idx:]))
    return "".join(out_parts)
def _html_escape(s: str) -> str:
    return (
        s.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace("\"", "&quot;")
        .replace("'", "&#39;")
    )
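
# Note: this helper behaves like html.escape(s, quote=True) from the standard
# library, except that the single quote is escaped as &#39; rather than &#x27;.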
# ----------------------------------------------------------------------
# Main OCR wrapper for Gradio
#
def run_ocr(
    image: np.ndarray | None,
    xml_file: bytes | str | dict | None,
    apply_gray: bool,
    apply_bin: bool,
    highlight_metric: str,
):
    """Run the OCR pipeline on user inputs and return results for Gradio.

    Parameters
    ----------
    image : numpy.ndarray or None
        The uploaded image converted to a NumPy array by Gradio. If
        None, the function returns empty results.
    xml_file : bytes, str, dict or None
        The uploaded XML file as provided by gr.File. With
        type="binary" (as configured below) this is the raw bytes; a
        file path string, a file-like object or a dict are also handled
        for robustness. If None, no segmentation is applied and the
        entire image is processed as a single line.
    apply_gray : bool
        Whether to convert the image to grayscale before OCR.
    apply_bin : bool
        Whether to apply binarization (Otsu threshold) before OCR.
        Binarization performs its own grayscale conversion first.
    highlight_metric : str
        Which metric to use for highlighting ("Relative Probability" or
        "Entropy").

    Returns
    -------
    overlay_img : PIL.Image or None
        Image with segmentation boxes drawn. None if no input image.
    predicted_html : str
        HTML-formatted predicted text with highlighted tokens.
    txt_path : str or None
        Path to a temporary .txt file containing the plain predicted text.
    """
    if image is None:
        return None, "", None
    # Convert the numpy array to a PIL image
    pil_img = Image.fromarray(image).convert("RGB")
    # Apply preprocessing as requested
    if apply_gray:
        pil_img = pil_img.convert("L").convert("RGB")
    if apply_bin:
        img_cv = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2GRAY)
        _, bin_img = cv2.threshold(img_cv, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        pil_img = Image.fromarray(bin_img).convert("RGB")
    # Parse segmentation boxes if XML provided
    boxes: list = []
    if xml_file:
        # Determine the correct way to extract bytes from the uploaded file.
        xml_bytes = None
        # If gr.File is configured with type="binary", xml_file will be raw bytes.
        if isinstance(xml_file, (bytes, bytearray)):
            xml_bytes = bytes(xml_file)
        # When type="filepath", xml_file would be a str path.
        elif isinstance(xml_file, str):
            try:
                with open(xml_file, "rb") as f:
                    xml_bytes = f.read()
            except Exception:
                xml_bytes = None
        # If a temporary file object is passed in, read its contents.
        elif hasattr(xml_file, "read"):
            try:
                xml_bytes = xml_file.read()
            except Exception:
                xml_bytes = None
        # If xml_file is a dictionary from Gradio (not expected with type="binary"),
        # attempt to extract the data key.
        elif isinstance(xml_file, dict) and "data" in xml_file:
            xml_bytes = xml_file.get("data")
        if xml_bytes:
            try:
                boxes = parse_boxes_from_xml(xml_bytes, level="line", image_size=pil_img.size)
                boxes = sort_boxes_reading_order(boxes)[:MAX_LINES]
            except Exception:
                boxes = []
    # Choose the highlighting metric and threshold
    if highlight_metric == "Relative Probability":
        red_threshold = REL_PROB_TH
        metric_col = "rel_prob"
    else:
        red_threshold = 0.10  # heuristic threshold for entropy
        metric_col = "entropy"
    # Run OCR for each segmented line or on the whole image
    dfs = []
    concatenated_parts = []
    line_text_by_id = {}
    if boxes:
        pad = 2
        for idx, b in enumerate(boxes, 1):
            # Create a tight crop around the line
            if "points" in b:
                pts = b["points"]
                mask = Image.new("L", pil_img.size, 0)
                ImageDraw.Draw(mask).polygon(pts, outline=1, fill=255)
                seg_img = Image.new("RGB", pil_img.size, (255, 255, 255))
                seg_img.paste(pil_img, mask=mask)
                xs = [x for x, y in pts]
                ys = [y for x, y in pts]
                x1 = max(0, int(min(xs) - pad))
                y1 = max(0, int(min(ys) - pad))
                x2 = min(pil_img.width, int(max(xs) + pad))
                y2 = min(pil_img.height, int(max(ys) + pad))
                crop = seg_img.crop((x1, y1, x2, y2))
            else:
                x1, y1, x2, y2 = b["bbox"]
                x1p = max(0, x1 - pad)
                y1p = max(0, y1 - pad)
                x2p = min(pil_img.width, x2 + pad)
                y2p = min(pil_img.height, y2 + pad)
                crop = pil_img.crop((x1p, y1p, x2p, y2p))
            # Run inference on the crop
            seg_text, df_tok = predict_and_score_once(crop, line_id=idx, topk=TOPK)
            seg_text = clean_text(seg_text)
            # Highlight uncertain tokens
            seg_text_flagged = highlight_tokens_with_tooltips(seg_text, df_tok, red_threshold, metric_col)
            concatenated_parts.append(seg_text_flagged)
            df_tok["line_id"] = idx
            dfs.append(df_tok)
            line_text_by_id[idx] = seg_text_flagged
        predicted_html = "<br>".join(concatenated_parts).strip()
        df_all = pd.concat(dfs, ignore_index=True)
    else:
        # Single pass on the whole image
        seg_text, df_all = predict_and_score_once(pil_img, line_id=1, topk=TOPK)
        seg_text = clean_text(seg_text)
        seg_text_flagged = highlight_tokens_with_tooltips(seg_text, df_all, red_threshold, metric_col)
        predicted_html = seg_text_flagged
        line_text_by_id[1] = seg_text_flagged
    # Draw overlay image
    overlay_img = draw_boxes(pil_img, boxes) if boxes else pil_img
    # Tidy the per-token table (kept for a possible CSV export): drop the last
    # empty token of each line.
    df_all = df_all.copy()
    df_all.sort_values(["line_id", "seq_pos"], inplace=True)
    to_drop = []
    for line_id, group in df_all.groupby("line_id"):
        if group.iloc[-1]["token"].strip() == "":
            to_drop.append(group.index[-1])
    df_all = df_all.drop(index=to_drop)
    # Prepare plain text by stripping HTML tags and replacing <br> with newlines
    plain_text = re.sub(r"<[^>]*>", "", predicted_html.replace("<br>", "\n"))
    # Save plain text to a temporary .txt file for download
    txt_dir = tempfile.gettempdir()
    txt_path = os.path.join(txt_dir, "predictions.txt")
    with open(txt_path, "w", encoding="utf-8") as f:
        f.write(plain_text)
    return overlay_img, predicted_html, txt_path
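
# Programmatic use (sketch only, outside the Gradio UI; the file names are
# hypothetical):
#   img = np.array(Image.open("page.jpg"))
#   with open("page_alto.xml", "rb") as f:
#       xml_bytes = f.read()
#   overlay, html_out, txt_path = run_ocr(
#       img, xml_bytes, False, False, "Relative Probability"
#   )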
# ----------------------------------------------------------------------
# Build Gradio Interface
#
def create_gradio_interface():
    """Create and return the Gradio Blocks interface."""
    with gr.Blocks(title="Handwritten Text Recognition (Old Nepali)") as demo:
        gr.Markdown(
            "# Handwritten Text Recognition (Old Nepali)\n\n"
            "Upload an image and (optionally) a segmentation XML file. "
            "Then click **Run OCR** to extract the text."
        )
        gr.HTML("""
        <style>
        #prediction-box {
            border: 1px solid #ccc;
            padding: 16px;
            border-radius: 8px;
            background-color: #f9f9f9;
            font-size: 18px;
            line-height: 1.6;
            min-height: 100px;
        }
        </style>
        """)
        with gr.Row():
            image_input = gr.Image(type="numpy", label="Upload Image")
            # When used as an input, gr.File returns either a file path or bytes
            # depending on the `type` parameter. Setting type="binary" ensures
            # that the XML content is passed directly as bytes to the callback,
            # avoiding the need to reopen a temporary file.
            xml_input = gr.File(
                label="Upload segmentation XML (optional)",
                file_count="single",
                type="binary",
                file_types=[".xml"],
            )
        # Optional controls carried over from the Streamlit version, currently disabled:
        # with gr.Row():
        #     apply_gray_checkbox = gr.Checkbox(label="Convert to Grayscale", value=False)
        #     apply_bin_checkbox = gr.Checkbox(label="Binarize", value=False)
        #     metric_radio = gr.Radio(
        #         ["Relative Probability", "Entropy"],
        #         label="Highlight tokens by",
        #         value="Relative Probability",
        #     )
        run_btn = gr.Button("Run OCR")
        # Outputs
        with gr.Row():
            with gr.Column(scale=2):
                overlay_output = gr.Image(label="Detected Regions")
            with gr.Column(scale=2):
                predictions_output = gr.HTML(
                    label="Predictions (HTML)",
                    elem_id="prediction-box",
                )
        # Further outputs from the Streamlit version, currently disabled:
        # df_output = gr.DataFrame(label="Token Scores", interactive=False)
        # csv_file_output = gr.File(label="Download Token Scores (.csv)")
        # download_edited_btn = gr.Button("Download edited text")
        # Editable text
        edited_text = gr.Textbox(
            label="Edit full predicted text", lines=8, interactive=True
        )
        txt_file_output = gr.File(label="Download OCR Prediction (.txt)")

        # Callback for OCR: preprocessing and metric choices are fixed to the
        # defaults while the corresponding controls above remain disabled.
        def on_run(image, xml):
            return run_ocr(image, xml, False, False, "Relative Probability")

        run_btn.click(
            fn=on_run,
            # inputs would also include apply_gray_checkbox, apply_bin_checkbox
            # and metric_radio if the optional controls were enabled.
            inputs=[image_input, xml_input],
            outputs=[overlay_output, predictions_output, txt_file_output],
        )

        # Populate the editable textbox with the plain text from the predictions
        def update_edited_text(pred_html):
            plain = re.sub(r"<[^>]*>", "", (pred_html or "").replace("<br>", "\n"))
            return plain

        predictions_output.change(
            fn=update_edited_text,
            inputs=predictions_output,
            outputs=edited_text,
        )
    return demo
if __name__ == "__main__":
    # Create and launch the Gradio interface
    iface = create_gradio_interface()
    iface.launch()
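
# When testing locally you can pass standard launch options, e.g.
# iface.launch(share=True) for a temporary public link or
# iface.launch(server_port=7860) to pin the port; both are regular
# gradio Blocks.launch() parameters.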