import json
import os
from copy import deepcopy

import numpy as np
import torch
import torchvision.transforms as T
from FlagEmbedding import BGEM3FlagModel
from marker.config.parser import ConfigParser
from marker.converters.pdf import PdfConverter
from marker.output import text_from_rendered
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoFeatureExtractor, AutoModel

from utils.src.presentation import Presentation, SlidePage
from utils.src.utils import is_image_path, pjoin

device_count = torch.cuda.device_count()


def prs_dedup(
    presentation: Presentation,
    model: BGEM3FlagModel,
    batchsize: int = 32,
    threshold: float = 0.8,
) -> list[SlidePage]:
    """
    Deduplicate slides in a presentation based on text similarity.

    Args:
        presentation (Presentation): The presentation object containing slides.
        model (BGEM3FlagModel): The model used for generating text embeddings.
        batchsize (int): The batch size for processing slides.
        threshold (float): The similarity threshold for deduplication.

    Returns:
        list[SlidePage]: The removed duplicate slides.
    """
    text_embeddings = get_text_embedding(
        [i.to_text() for i in presentation.slides], model, batchsize
    )
    pre_embedding = text_embeddings[0]
    slide_idx = 1
    duplicates = []
    while slide_idx < len(presentation):
        cur_embedding = text_embeddings[slide_idx]
        # When two consecutive slides are nearly identical, drop the earlier one.
        if torch.cosine_similarity(pre_embedding, cur_embedding, -1) > threshold:
            duplicates.append(slide_idx - 1)
        slide_idx += 1
        pre_embedding = cur_embedding
    # Pop from the end so the earlier indices stay valid while removing.
    return [presentation.slides.pop(i) for i in reversed(duplicates)]
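

# Usage sketch (illustrative, not part of the module): how prs_dedup might be
# called with a text model. `prs` stands in for a Presentation object loaded
# elsewhere in the repository.
#
#   text_model = get_text_model("cuda:0")
#   removed = prs_dedup(prs, text_model, threshold=0.8)
#   print(f"removed {len(removed)} near-duplicate slides")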


def get_text_model(device: str = None) -> BGEM3FlagModel:
    """
    Initialize and return a text model.

    Args:
        device (str): The device to run the model on.

    Returns:
        BGEM3FlagModel: The initialized text model.
    """
    return BGEM3FlagModel(
        "BAAI/bge-m3",
        use_fp16=True,
        device=device,
    )


def get_image_model(device: str = None):
    """
    Initialize and return an image model and its feature extractor.

    Args:
        device (str): The device to run the model on.

    Returns:
        tuple: A tuple containing the feature extractor and the image model.
    """
    model_base = "google/vit-base-patch16-224-in21k"
    return (
        AutoFeatureExtractor.from_pretrained(
            model_base,
            torch_dtype=torch.float16,
            device_map=device,
        ),
        AutoModel.from_pretrained(
            model_base,
            torch_dtype=torch.float16,
            device_map=device,
        ).eval(),
    )


def parse_pdf(
    pdf_path: str,
    output_path: str = None,
    model_lst: list = None,
    save_file: bool = True,
) -> str:
    """
    Parse a PDF file and extract text and images.

    Args:
        pdf_path (str): The path to the PDF file.
        output_path (str): The directory to save the extracted content.
        model_lst (list): The marker model artifacts used for processing the PDF.
        save_file (bool): Whether to write the markdown, images, and metadata
            to `output_path`.

    Returns:
        str: The full text extracted from the PDF, or a (text, rendered) tuple
            when `save_file` is False.
    """
    if save_file:
        os.makedirs(output_path, exist_ok=True)
    config_parser = ConfigParser(
        {
            "output_format": "markdown",
        }
    )
    converter = PdfConverter(
        config=config_parser.generate_config_dict(),
        artifact_dict=model_lst,
        processor_list=config_parser.get_processors(),
        renderer=config_parser.get_renderer(),
    )
    rendered = converter(pdf_path)
    full_text, _, images = text_from_rendered(rendered)
    if save_file:
        with open(pjoin(output_path, "source.md"), "w+", encoding="utf-8") as f:
            f.write(full_text)
        for filename, image in images.items():
            image_filepath = os.path.join(output_path, filename)
            image.save(image_filepath, "JPEG")
        with open(pjoin(output_path, "meta.json"), "w+") as f:
            f.write(json.dumps(rendered.metadata, indent=4))
    if not save_file:
        return full_text, rendered
    return full_text
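

# Usage sketch: parse a PDF into markdown plus extracted images. The
# `create_model_dict` helper reflects the usual marker-pdf setup; check the
# exact import against the installed marker version.
#
#   from marker.models import create_model_dict
#
#   models = create_model_dict()
#   text = parse_pdf("paper.pdf", output_path="parsed/paper", model_lst=models)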


def get_text_embedding(
    text: list[str], model: BGEM3FlagModel, batchsize: int = 32
) -> list[torch.Tensor]:
    """
    Generate text embeddings for a list of text strings.

    Args:
        text (list[str]): A list of text strings (a single string is also accepted).
        model (BGEM3FlagModel): The model used for generating embeddings.
        batchsize (int): The batch size for processing text.

    Returns:
        list[torch.Tensor]: A list of dense text embeddings.
    """
    if isinstance(text, str):
        return torch.tensor(model.encode(text)["dense_vecs"]).to(model.device)
    result = []
    for i in range(0, len(text), batchsize):
        result.extend(
            torch.tensor(model.encode(text[i : i + batchsize])["dense_vecs"]).to(
                model.device
            )
        )
    return result
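

# Usage sketch: the helper accepts either a single string or a list of strings.
# A single string yields one tensor; a list yields a list of tensors.
# `text_model` is assumed to come from get_text_model above.
#
#   vec = get_text_embedding("hello world", text_model)
#   vecs = get_text_embedding(["slide one", "slide two"], text_model, batchsize=2)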


def get_image_embedding(
    image_dir: str, extractor, model, batchsize: int = 16
) -> dict[str, torch.Tensor]:
    """
    Generate image embeddings for images in a directory.

    Args:
        image_dir (str): The directory containing images.
        extractor: The feature extractor for images.
        model: The model used for generating embeddings.
        batchsize (int): The batch size for processing images.

    Returns:
        dict: A dictionary mapping image filenames to their embeddings.
    """
    transform = T.Compose(
        [
            T.Resize(int((256 / 224) * extractor.size["height"])),
            T.CenterCrop(extractor.size["height"]),
            T.ToTensor(),
            T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
        ]
    )
    inputs = []
    embeddings = []
    images = [i for i in sorted(os.listdir(image_dir)) if is_image_path(i)]
    for file in images:
        image = Image.open(pjoin(image_dir, file)).convert("RGB")
        inputs.append(transform(image))
        # Flush a full batch, or whatever remains once the last image is reached.
        if len(inputs) % batchsize == 0 or file == images[-1]:
            batch = {"pixel_values": torch.stack(inputs).to(model.device)}
            embeddings.extend(model(**batch).last_hidden_state.detach())
            inputs.clear()
    return {image: embedding.flatten() for image, embedding in zip(images, embeddings)}
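

# Usage sketch: embed every image extracted from a parsed document and group
# visually similar ones. The directory path is illustrative.
#
#   extractor, vit = get_image_model("cuda:0")
#   embeddings = get_image_embedding("parsed/paper", extractor, vit)
#   sim = images_cosine_similarity(list(embeddings.values()))
#   clusters = get_cluster(sim.numpy(), sim_bound=0.65)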


def images_cosine_similarity(embeddings: list[torch.Tensor]) -> torch.Tensor:
    """
    Calculate the pairwise cosine similarity matrix for a list of embeddings.

    Args:
        embeddings (list[torch.Tensor]): A list of image embeddings.

    Returns:
        torch.Tensor: An N x N similarity matrix (the diagonal is left at zero).
    """
    sim_matrix = torch.zeros((len(embeddings), len(embeddings)))
    for i in range(len(embeddings)):
        for j in range(i + 1, len(embeddings)):
            sim_matrix[i, j] = sim_matrix[j, i] = torch.cosine_similarity(
                embeddings[i], embeddings[j], -1
            )
    return sim_matrix


IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def average_distance(
    similarity: torch.Tensor, idx: int, cluster_idx: list[int]
) -> float:
    """
    Calculate the average similarity between a point (idx) and a cluster (cluster_idx).

    Args:
        similarity (torch.Tensor): The similarity matrix.
        idx (int): The index of the point.
        cluster_idx (list): The indices of the cluster.

    Returns:
        float: The average similarity, or 0 if the point is already in the cluster.
    """
    if idx in cluster_idx:
        return 0
    total_similarity = 0
    for idx_in_cluster in cluster_idx:
        total_similarity += similarity[idx, idx_in_cluster]
    return total_similarity / len(cluster_idx)


def get_cluster(similarity: np.ndarray, sim_bound: float = 0.65):
    """
    Greedily cluster points based on pairwise similarity.

    Note: `similarity` is modified in place (rows and columns of assigned points
    are zeroed out); `sim_copy` preserves the original values for the
    average-linkage computation.

    Args:
        similarity (np.ndarray): The similarity matrix.
        sim_bound (float): The similarity threshold for clustering.

    Returns:
        list: A list of clusters, each a list of point indices.
    """
    num_points = similarity.shape[0]
    clusters = []
    sim_copy = deepcopy(similarity)
    added = [False] * num_points
    while True:
        max_avg_dist = sim_bound
        best_cluster = None
        best_point = None
        # Try to grow an existing cluster with the unassigned point whose
        # average similarity to that cluster is highest (and above sim_bound).
        for c in clusters:
            for point_idx in range(num_points):
                if added[point_idx]:
                    continue
                avg_dist = average_distance(sim_copy, point_idx, c)
                if avg_dist > max_avg_dist:
                    max_avg_dist = avg_dist
                    best_cluster = c
                    best_point = point_idx
        if best_point is not None:
            best_cluster.append(best_point)
            added[best_point] = True
            similarity[best_point, :] = 0
            similarity[:, best_point] = 0
        else:
            # No cluster can be extended: seed a new cluster from the most
            # similar remaining pair, or stop when nothing exceeds sim_bound.
            if similarity.max() < sim_bound:
                break
            i, j = np.unravel_index(np.argmax(similarity), similarity.shape)
            clusters.append([int(i), int(j)])
            added[i] = True
            added[j] = True
            similarity[i, :] = 0
            similarity[:, i] = 0
            similarity[j, :] = 0
            similarity[:, j] = 0
    return clusters
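

# Worked example (illustrative numbers): with this 3x3 similarity matrix and
# the default bound of 0.65, points 0 and 1 are grouped while point 2 stays
# unclustered, since its average similarity to the cluster is only 0.15.
#
#   sim = np.array(
#       [
#           [0.0, 0.9, 0.1],
#           [0.9, 0.0, 0.2],
#           [0.1, 0.2, 0.0],
#       ]
#   )
#   get_cluster(sim)  # -> [[0, 1]]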