from typing import Tuple, List, Sequence, Optional, Union
from torchvision import transforms
from torch import nn, Tensor
from PIL import Image
from pathlib import Path
from bs4 import BeautifulSoup as bs
import numpy as np
import numpy.typing as npt
from numpy import uint8

ImageType = npt.NDArray[uint8]

from transformers import AutoModelForObjectDetection
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch

from unitable import UnitablePredictor
from doctrfiles import DoctrWordDetector, DoctrTextRecognizer
from utils import crop_an_Image, cropImageExtraMargin
from utils import denoisingAndSharpening

# Based on this notebook:
# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Table%20Transformer/Inference_with_Table_Transformer_(TATR)_for_parsing_tables.ipynb
class MaxResize(object):
    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
        return resized_image
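
# Example (sketch): MaxResize(800) shrinks a 1600x1200 PIL image to 800x600,
# preserving the aspect ratio (scale = 800 / max(1600, 1200) = 0.5). Note that
# images whose largest side is already below max_size get scaled up (scale > 1).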
html_table_template = (
    lambda table: f"""<html>
                      <head> <meta charset="UTF-8">
                      <style>
                      table, th, td {{
                        border: 1px solid black;
                        font-size: 10px;
                      }}
                      </style> </head>
                      <body>
                      <table frame="hsides" rules="groups" width="100%">
                        {table}
                      </table> </body> </html>"""
)
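
# `iob` is used by objects_to_crops below but is neither defined nor imported in
# this file. The following is a minimal sketch consistent with how it is called
# there (intersection area over the area of the first box); the original TATR
# code ships its own version, which should be preferred if it is on the path.
def iob(bbox1, bbox2):
    """Intersection over the area of bbox1; boxes are [xmin, ymin, xmax, ymax]."""
    x1 = max(bbox1[0], bbox2[0])
    y1 = max(bbox1[1], bbox2[1])
    x2 = min(bbox1[2], bbox2[2])
    y2 = min(bbox1[3], bbox2[3])
    intersection = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    return intersection / bbox1_area if bbox1_area > 0 else 0.0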
class DetectionAndOcrTable1():
    def __init__(self, englishFlag=True):
        self.unitablePredictor = UnitablePredictor()
        self.wordDetector = DoctrWordDetector(
            architecture="db_resnet50",
            path_weights="doctrfiles/models/db_resnet50-79bd7d70.pt",
            path_config_json="doctrfiles/models/db_resnet50_config.json")
        if englishFlag:
            self.textRecognizer = DoctrTextRecognizer(
                architecture="master",
                path_weights="./doctrfiles/models/master-fde31e4a.pt",
                path_config_json="./doctrfiles/models/master.json")
        else:
            self.textRecognizer = DoctrTextRecognizer(
                architecture="parseq",
                path_weights="./doctrfiles/models/doctr-multilingual-parseq.bin",
                path_config_json="./doctrfiles/models/multilingual-parseq-config.json")
    @staticmethod
    def build_table_from_html_and_cell(
        structure: List[str], content: Optional[List[str]] = None
    ) -> List[str]:
        """Build table from html and cell token list."""
        assert structure is not None
        html_code = list()
        # Deal with an empty table.
        if content is None:
            content = ["placeholder"] * len(structure)
        for tag in structure:
            if tag in ("<td>[]</td>", ">[]</td>"):
                if len(content) == 0:
                    continue
                cell = content.pop(0)
                html_code.append(tag.replace("[]", cell))
            else:
                html_code.append(tag)
        return html_code
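
    # Example (sketch): the "[]" placeholders are filled with cell texts in order:
    #   build_table_from_html_and_cell(
    #       ["<tr>", "<td>[]</td>", "<td>[]</td>", "</tr>"], ["foo", "bar"])
    #   -> ["<tr>", "<td>foo</td>", "<td>bar</td>", "</tr>"]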
    @staticmethod
    def save_detection(detected_lines_images: List[ImageType], prefix='./res/test1/res_'):
        for i, img in enumerate(detected_lines_images):
            pilimg = Image.fromarray(img)
            pilimg.save(prefix + str(i) + '.png')
    # For output bounding box post-processing.
    @staticmethod
    def box_cxcywh_to_xyxy(x):
        x_c, y_c, w, h = x.unbind(-1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=1)
    @staticmethod
    def rescale_bboxes(out_bbox, size):
        img_w, img_h = size
        b = DetectionAndOcrTable1.box_cxcywh_to_xyxy(out_bbox)
        b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
        return b
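
    # Example: a normalized DETR-style box (cx=0.5, cy=0.5, w=0.2, h=0.4) maps to
    # corners (0.4, 0.3, 0.6, 0.7); rescale_bboxes then multiplies by
    # (img_w, img_h, img_w, img_h) to obtain pixel coordinates.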
    @staticmethod
    def outputs_to_objects(outputs, img_size, id2label):
        m = outputs.logits.softmax(-1).max(-1)
        pred_labels = list(m.indices.detach().cpu().numpy())[0]
        pred_scores = list(m.values.detach().cpu().numpy())[0]
        pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
        pred_bboxes = [elem.tolist() for elem in DetectionAndOcrTable1.rescale_bboxes(pred_bboxes, img_size)]

        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if class_label != 'no object':
                objects.append({'label': class_label, 'score': float(score),
                                'bbox': [float(elem) for elem in bbox]})
        return objects
    @staticmethod
    def fig2img(fig):
        """Convert a Matplotlib figure to a PIL Image and return it."""
        import io
        buf = io.BytesIO()
        fig.savefig(buf)
        buf.seek(0)
        img = Image.open(buf)
        return img
    # To crop the tables out of the image, the TATR authors employ some padding
    # to make sure the borders of the table are included.
    @staticmethod
    def objects_to_crops(img, tokens, objects, class_thresholds, padding=10):
        """
        Process the bounding boxes produced by the table detection model into
        cropped table images and cropped tokens.
        """
        table_crops = []
        for obj in objects:
            # A bit redundant here, since the caller already filters by score,
            # but kept close to the original TATR code.
            if obj['score'] < class_thresholds[obj['label']]:
                continue

            cropped_table = {}
            bbox = obj['bbox']
            bbox = [bbox[0] - padding, bbox[1] - padding, bbox[2] + padding, bbox[3] + padding]
            cropped_img = img.crop(bbox)

            # Add a white margin around the cropped image.
            padded_width = cropped_img.width + 40
            padded_height = cropped_img.height + 40
            new_img_np = np.full((padded_height, padded_width, 3), fill_value=255, dtype=np.uint8)
            y_offset = (padded_height - cropped_img.height) // 2
            x_offset = (padded_width - cropped_img.width) // 2
            new_img_np[y_offset:y_offset + cropped_img.height, x_offset:x_offset + cropped_img.width] = np.array(cropped_img)
            padded_img = Image.fromarray(new_img_np, 'RGB')

            # tokens is empty in this pipeline, so this filter is effectively a no-op.
            table_tokens = [token for token in tokens if iob(token['bbox'], bbox) >= 0.5]
            for token in table_tokens:
                token['bbox'] = [token['bbox'][0] - bbox[0] + padding,
                                 token['bbox'][1] - bbox[1] + padding,
                                 token['bbox'][2] - bbox[0] + padding,
                                 token['bbox'][3] - bbox[1] + padding]

            # If the table is predicted to be rotated, rotate the cropped image and tokens/words:
            if obj['label'] == 'table rotated':
                padded_img = padded_img.rotate(270, expand=True)
                for token in table_tokens:
                    bbox = token['bbox']
                    bbox = [padded_img.size[0] - bbox[3] - 1,
                            bbox[0],
                            padded_img.size[0] - bbox[1] - 1,
                            bbox[2]]
                    token['bbox'] = bbox

            cropped_table['image'] = padded_img
            cropped_table['tokens'] = table_tokens
            table_crops.append(cropped_table)
        return table_crops
    @staticmethod
    def visualize_detected_tables(img, det_tables, out_path=None):
        plt.imshow(img, interpolation="lanczos")
        fig = plt.gcf()
        fig.set_size_inches(20, 20)
        ax = plt.gca()

        for det_table in det_tables:
            bbox = det_table['bbox']
            if det_table['label'] == 'table':
                facecolor = (1, 0, 0.45)
                edgecolor = (1, 0, 0.45)
                alpha = 0.3
                linewidth = 2
                hatch = '//////'
            elif det_table['label'] == 'table rotated':
                facecolor = (0.95, 0.6, 0.1)
                edgecolor = (0.95, 0.6, 0.1)
                alpha = 0.3
                linewidth = 2
                hatch = '//////'
            else:
                continue

            # Draw the fill, the outline, and the hatching as three separate patches.
            rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=linewidth,
                                     edgecolor='none', facecolor=facecolor, alpha=0.1)
            ax.add_patch(rect)
            rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=linewidth,
                                     edgecolor=edgecolor, facecolor='none', linestyle='-', alpha=alpha)
            ax.add_patch(rect)
            rect = patches.Rectangle(bbox[:2], bbox[2] - bbox[0], bbox[3] - bbox[1], linewidth=0,
                                     edgecolor=edgecolor, facecolor='none', linestyle='-', hatch=hatch, alpha=0.2)
            ax.add_patch(rect)

        plt.xticks([], [])
        plt.yticks([], [])

        legend_elements = [Patch(facecolor=(1, 0, 0.45), edgecolor=(1, 0, 0.45),
                                 label='Table', hatch='//////', alpha=0.3),
                           Patch(facecolor=(0.95, 0.6, 0.1), edgecolor=(0.95, 0.6, 0.1),
                                 label='Table (rotated)', hatch='//////', alpha=0.3)]
        plt.legend(handles=legend_elements, bbox_to_anchor=(0.5, -0.02), loc='upper center', borderaxespad=0,
                   fontsize=10, ncol=2)
        plt.gcf().set_size_inches(10, 10)
        plt.axis('off')

        if out_path is not None:
            plt.savefig(out_path, bbox_inches='tight', dpi=150)

        return fig
    def predict(self, image: Image.Image, debugfolder_filename_page_name, denoise=False):
        """
        Step 0: Locate the tables using the table detection model.
        Step 1: Parse each cropped table with Unitable.
        """
        print("Running table transformer + Unitable hybrid model")

        # Step 0: Locate the table using table detection.
        # First we load a Table Transformer pre-trained for table detection. We use the
        # "no_timm" revision to load the checkpoint with a Transformers-native backbone.
        model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        # Prepare the image for the model.
        detection_transform = transforms.Compose([
            MaxResize(800),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        pixel_values = detection_transform(image).unsqueeze(0)
        pixel_values = pixel_values.to(device)

        # Next, we forward the pixel values through the model.
        # The model outputs logits of shape (batch_size, num_queries, num_labels + 1).
        # The +1 is for the "no object" class.
        with torch.no_grad():
            outputs = model(pixel_values)

        # Update id2label to include "no object".
        id2label = model.config.id2label
        id2label[len(model.config.id2label)] = "no object"

        # Example output:
        # [{'label': 'table', 'score': 0.9999570846557617, 'bbox': [110.24547576904297, 73.31171417236328, 1024.609130859375, 308.7159423828125]}]
        objects = DetectionAndOcrTable1.outputs_to_objects(outputs, image.size, id2label)

        # Only keep objects with a score greater than 0.95.
        objects = [obj for obj in objects if obj['score'] > 0.95]
        print("Detected objects from the table transformer:")
        print(objects)
        # Initialize here so an empty detection result still returns an empty list.
        table_codes = []
        if objects:
            # Next, we crop the table out of the image. For that, the TATR authors
            # employ some padding to make sure the borders of the table are included.
            tokens = []
            detection_class_thresholds = {
                # Somewhat redundant with the 0.95 score filter above, but kept to
                # stay close to the original code.
                "table": 0.95,
                "table rotated": 0.95,
                "no object": 10
            }
            crop_padding = 10

            tables_crops = DetectionAndOcrTable1.objects_to_crops(image, tokens, objects, detection_class_thresholds, padding=crop_padding)
            cropped_tables = []
            for i in range(len(tables_crops)):
                cropped_table = tables_crops[i]['image'].convert("RGB")
                cropped_table.save(debugfolder_filename_page_name + "cropped_table_" + str(i) + ".png")
                cropped_tables.append(cropped_table)
            # Step 1: Unitable.
            # The predictor takes PIL Images as input.
            if denoise:
                cropped_tables = denoisingAndSharpening(cropped_tables)
            pred_htmls, pred_bboxs = self.unitablePredictor.predict(cropped_tables, debugfolder_filename_page_name)

            for k in range(len(cropped_tables)):
                pred_html = pred_htmls[k]
                pred_bbox = pred_bboxs[k]

                # Some tables have a lot of words in their header, and the doctr word
                # detector doesn't work well when cell crops aren't roughly square.
                # So we count the header cells here and handle multi-line header
                # cells separately below.
                table_header_cells = 0
                header_exists = False
                for cell in pred_html:
                    if cell == '>[]</td>' or cell == '<td>[]</td>':
                        table_header_cells += 1
                    if cell == '</thead>':
                        header_exists = True
                        break
                if not header_exists:
                    table_header_cells = 0
                # Step 2: OCR each predicted cell.
                pred_cell = []
                cell_imgs_to_viz = []
                cell_img_num = 0

                # Find the height of a single line, taken from the shortest header cell.
                one_line_height = 100000
                for i in range(table_header_cells):
                    box = pred_bbox[i]
                    xmin, ymin, xmax, ymax = box
                    current_box_height = abs(ymax - ymin)
                    if current_box_height < one_line_height:
                        one_line_height = current_box_height
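
                # Example: if the header cell heights are [18, 18, 40], one_line_height
                # becomes 18 and the 40-px cell is treated as multi-line in the loop below.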
                for box in pred_bbox:
                    xmin, ymin, xmax, ymax = box
                    fourbytwo = np.array([
                        [xmin, ymin],
                        [xmax, ymin],
                        [xmax, ymax],
                        [xmin, ymax]
                    ], dtype=np.float32)
                    current_box_height = abs(ymax - ymin)
                    # Reset per cell so an empty crop can't reuse the previous cell's images.
                    input_to_recog = []

                    # Header cells with more than one line: run word detection first,
                    # then recognize each detected word crop.
                    if table_header_cells > 0 and current_box_height > one_line_height + 5:
                        cell_img = cropImageExtraMargin([fourbytwo], cropped_tables[k], margin=1.4)[0]
                        table_header_cells -= 1
                        # detection_results is a list of 4 x 2 word boxes.
                        detection_results = self.wordDetector.predict(cell_img, sort_vertical=True)
                        if detection_results == []:
                            input_to_recog.append(cell_img)
                        else:
                            for wordbox in detection_results:
                                cropped_image = crop_an_Image(wordbox.box, cell_img)
                                if cropped_image.shape[0] > 0 and cropped_image.shape[1] > 0:
                                    input_to_recog.append(cropped_image)
                                else:
                                    print("Empty image")
                    else:
                        cell_img = crop_an_Image(fourbytwo, cropped_tables[k])
                        if table_header_cells > 0:
                            table_header_cells -= 1
                        if cell_img.shape[0] > 0 and cell_img.shape[1] > 0:
                            input_to_recog = [cell_img]
                            cell_imgs_to_viz.append(cell_img)

                    if input_to_recog != []:
                        words = self.textRecognizer.predict_for_tables(input_to_recog)
                        cell_output = " ".join(words)
                        pred_cell.append(cell_output)
                    else:
                        # Don't lose empty cells.
                        pred_cell.append("")
                print(pred_cell)
                # Step 3: Fill the recognized text into the predicted HTML structure.
                pred_code = self.build_table_from_html_and_cell(pred_html, pred_cell)
                pred_code = "".join(pred_code)
                pred_code = html_table_template(pred_code)

                # prettify() gives a formatted and indented string representation
                # of the HTML document.
                soup = bs(pred_code, "html.parser")
                table_code = soup.prettify()
                print(table_code)

                # Append the extracted table to table_codes.
                table_codes.append(table_code)
        return table_codes
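
# Example usage (sketch): the input path and debug prefix below are assumptions;
# adjust them to the local setup. The model weight paths in __init__ must exist.
if __name__ == "__main__":
    table_extractor = DetectionAndOcrTable1(englishFlag=True)
    page = Image.open("./res/test1/page.png").convert("RGB")  # hypothetical input image
    tables_html = table_extractor.predict(page, debugfolder_filename_page_name="./res/test1/page_")
    for html in tables_html:
        print(html)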