import os
import io
import json
import torch
from PIL import Image
from google.cloud import vision
from transformers import AutoTokenizer, AutoModelForTokenClassification

# ====== Set Google Credential ======
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "eastern-entity-450514-u2-c9949243357a.json"

# ====== Load Vision API and Model ======
client = vision.ImageAnnotatorClient()

model_name = "bashyaldhiraj2067/100epoch_test_march19"
tokenizer = AutoTokenizer.from_pretrained("nielsr/lilt-xlm-roberta-base")
model = AutoModelForTokenClassification.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
# ====== Labels ======
labels = [
    "B-O", "I-O",
    "B-CITIZENSHIP_CERTIFICATE_NO", "I-CITIZENSHIP_CERTIFICATE_NO",
    "B-FULL_NAME", "I-FULL_NAME",
    "B-GENDER", "I-GENDER",
    "B-BIRTH_YEAR", "B-BIRTH_MONTH", "B-BIRTH_DAY",
    "I-BIRTH_MONTH", "B-DISTRICT", "B-MUNCIPALITY", "I-MUNCIPALITY",
    "B-WARD_NO", "I-WARD_NO",
    "B-FATHERS_NAME", "I-FATHERS_NAME",
    "B-MOTHERS_NAME", "I-MOTHERS_NAME",
    "I-BIRTH_YEAR", "I-DISTRICT", "I-BIRTH_DAY",
]
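# Note: this hard-coded list must match the id2label ordering the fine-tuned
# model was trained with; if in doubt, model.config.id2label can be used to
# verify the ordering or to build the list programmatically.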
# ====== Normalize Bounding Boxes ======
def normalized_boxes(bbox, width, height):
    # LiLT (like other LayoutLM-style models) expects box coordinates scaled to the 0-1000 range.
    x_min, y_min, x_max, y_max = bbox
    return [
        int((x_min / width) * 1000),
        int((y_min / height) * 1000),
        int((x_max / width) * 1000),
        int((y_max / height) * 1000),
    ]
# ====== OCR ======
def extract_text_and_boxes(image):
    width, height = image.size
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    content = img_byte_arr.getvalue()

    image_vision = vision.Image(content=content)
    response = client.document_text_detection(image=image_vision)

    words, box_list = [], []
    for page in response.full_text_annotation.pages:
        for block in page.blocks:
            for paragraph in block.paragraphs:
                for word in paragraph.words:
                    word_text = ''.join([s.text for s in word.symbols])
                    words.append(word_text)
                    x_min = min([v.x for v in word.bounding_box.vertices])
                    y_min = min([v.y for v in word.bounding_box.vertices])
                    x_max = max([v.x for v in word.bounding_box.vertices])
                    y_max = max([v.y for v in word.bounding_box.vertices])
                    box = normalized_boxes([x_min, y_min, x_max, y_max], width, height)
                    box_list.append(box)
    return words, box_list
# ====== Inference & Entity Extraction ======
def predict_text(image):
    words, norm_boxes = extract_text_and_boxes(image)
    if not words:
        return json.dumps({"error": "No text detected"}, ensure_ascii=False)

    encoding = tokenizer(" ".join(words), truncation=True, max_length=512, return_tensors="pt")
    encoding = {k: v.to(device).long() if k == "input_ids" else v.to(device) for k, v in encoding.items()}

    # Repeat each word's box for every sub-word token so boxes align with input_ids.
    token_boxes = []
    for word, box in zip(words, norm_boxes):
        word_tokens = tokenizer.tokenize(word)
        token_boxes.extend([box] * len(word_tokens))

    cls_box = [0, 0, 0, 0]
    sep_box = [0, 0, 0, 0]
    token_boxes = [cls_box] + token_boxes[:510] + [sep_box]

    # Pad or trim so the bbox tensor has exactly one entry per input token.
    seq_len = encoding["input_ids"].shape[1]
    token_boxes = (token_boxes + [[0, 0, 0, 0]] * seq_len)[:seq_len]
    encoding["bbox"] = torch.tensor([token_boxes]).to(device).long()

    with torch.no_grad():
        outputs = model(**encoding)

    predictions = torch.argmax(outputs.logits, dim=-1).squeeze().tolist()
    tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])
    label_output = [labels[idx] for idx in predictions]
    # ====== Parse Entities ======
    entity_map = {
        "CITIZENSHIP_CERTIFICATE_NO": "",
        "FULL_NAME": "",
        "GENDER": "",
        "DISTRICT": "",
        "MUNCIPALITY": "",
        "BIRTH_YEAR": "",
        "BIRTH_MONTH": "",
        "BIRTH_DAY": "",
        "WARD_NO": "",
        "FATHERS_NAME": "",
        "MOTHERS_NAME": "",
    }
    current_label = None
    for token, label in zip(tokens, label_output):
        # XLM-RoBERTa has no [CLS]/[SEP]; skip its actual special tokens (<s>, </s>, <pad>, ...).
        if token in tokenizer.all_special_tokens:
            continue
        token = token.replace("▁", "").strip()
        if not token:
            continue
        if label.startswith("B-"):
            ent_type = label[2:]
            if ent_type in entity_map:
                entity_map[ent_type] += (" " if entity_map[ent_type] else "") + token
                current_label = ent_type
            else:
                current_label = None
        elif label.startswith("I-"):
            ent_type = label[2:]
            if current_label == ent_type and ent_type in entity_map:
                entity_map[ent_type] += " " + token
            else:
                current_label = None
        else:
            current_label = None

    # Strip extra spaces
    for k in entity_map:
        entity_map[k] = entity_map[k].strip()

    return json.dumps(entity_map, ensure_ascii=False, indent=2)
# ====== Gradio Interface ======
import gradio as gr

gr.Interface(fn=predict_text, inputs=gr.Image(type="pil"), outputs="text").launch()
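# --- Optional local sanity check (illustrative sketch) ---
# "sample_certificate.png" is a hypothetical placeholder; point it at any local
# certificate scan to exercise the OCR + LiLT pipeline without the Gradio UI.
# if __name__ == "__main__":
#     sample = Image.open("sample_certificate.png").convert("RGB")
#     print(predict_text(sample))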