import json
import os
from copy import deepcopy
from typing import Any, Optional, Union

import numpy as np
import torch
import torchvision.transforms as T
from FlagEmbedding import BGEM3FlagModel
from marker.config.parser import ConfigParser
from marker.converters.pdf import PdfConverter
from marker.output import text_from_rendered
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoFeatureExtractor, AutoModel

from utils.src.presentation import Presentation, SlidePage
from utils.src.utils import is_image_path, pjoin

device_count = torch.cuda.device_count()


def prs_dedup(
    presentation: Presentation,
    model: BGEM3FlagModel,
    batchsize: int = 32,
    threshold: float = 0.8,
) -> list[SlidePage]:
    """
    Deduplicate slides in a presentation based on text similarity.

    Each slide is compared with the slide directly before it. When the cosine
    similarity of their text embeddings exceeds ``threshold``, the *earlier*
    slide of the pair is removed from ``presentation.slides`` (in place).

    Args:
        presentation (Presentation): The presentation object containing slides.
        model: The model used for generating text embeddings.
        batchsize (int): The batch size for processing slides.
        threshold (float): The similarity threshold for deduplication.

    Returns:
        list: The removed slides, ordered from the last removed position to
        the first. Empty when the presentation has no slides.
    """
    text_embeddings = get_text_embedding(
        [slide.to_text() for slide in presentation.slides], model, batchsize
    )
    # Nothing to compare (and no first embedding to index) without slides.
    if not text_embeddings:
        return []

    duplicates = []
    pre_embedding = text_embeddings[0]
    for slide_idx in range(1, len(presentation)):
        cur_embedding = text_embeddings[slide_idx]
        if torch.cosine_similarity(pre_embedding, cur_embedding, -1) > threshold:
            duplicates.append(slide_idx - 1)
        pre_embedding = cur_embedding

    # Pop from the back so that the remaining indices stay valid.
    return [presentation.slides.pop(i) for i in reversed(duplicates)]


def get_text_model(device: Optional[str] = None) -> BGEM3FlagModel:
    """
    Initialize and return a text model.

    Args:
        device (str): The device to run the model on.

    Returns:
        BGEM3FlagModel: The initialized "BAAI/bge-m3" model (fp16 enabled).
    """
    return BGEM3FlagModel(
        "BAAI/bge-m3",
        use_fp16=True,
        device=device,
    )


def get_image_model(device: Optional[str] = None) -> tuple:
    """
    Initialize and return an image model and its feature extractor.

    Args:
        device (str): The device to run the model on.

    Returns:
        tuple: A tuple containing the feature extractor and the image model
        (the model is put in eval mode).
    """
    model_base = "google/vit-base-patch16-224-in21k"
    return (
        AutoFeatureExtractor.from_pretrained(
            model_base,
            torch_dtype=torch.float16,
            device_map=device,
        ),
        AutoModel.from_pretrained(
            model_base,
            torch_dtype=torch.float16,
            device_map=device,
        ).eval(),
    )


def parse_pdf(
    pdf_path: str,
    output_path: Optional[str] = None,
    model_lst: Optional[list] = None,
    save_file: bool = True,
) -> Union[str, tuple[str, Any]]:
    """
    Parse a PDF file and extract text and images.

    Args:
        pdf_path (str): The path to the PDF file.
        output_path (str): The directory to save the extracted content.
            Required when ``save_file`` is True.
        model_lst (list): The models for processing the PDF; forwarded to
            ``PdfConverter`` as ``artifact_dict``.
        save_file (bool): When True, write ``source.md``, the extracted images
            and ``meta.json`` into ``output_path``.

    Returns:
        str: The full text extracted from the PDF when ``save_file`` is True.
        tuple: ``(full_text, rendered)`` when ``save_file`` is False.

    Raises:
        ValueError: If ``save_file`` is True but ``output_path`` is None.
    """
    if save_file:
        if output_path is None:
            raise ValueError("output_path is required when save_file is True")
        # Create the target directory up front so a bad path fails before the
        # (expensive) conversion runs.
        os.makedirs(output_path, exist_ok=True)

    config_parser = ConfigParser(
        {
            "output_format": "markdown",
        }
    )
    converter = PdfConverter(
        config=config_parser.generate_config_dict(),
        artifact_dict=model_lst,
        processor_list=config_parser.get_processors(),
        renderer=config_parser.get_renderer(),
    )
    rendered = converter(pdf_path)
    full_text, _, images = text_from_rendered(rendered)

    if not save_file:
        return full_text, rendered

    with open(pjoin(output_path, "source.md"), "w", encoding="utf-8") as f:
        f.write(full_text)
    for filename, image in images.items():
        image_filepath = os.path.join(output_path, filename)
        image.save(image_filepath, "JPEG")
    with open(pjoin(output_path, "meta.json"), "w", encoding="utf-8") as f:
        f.write(json.dumps(rendered.metadata, indent=4))
    return full_text


def get_text_embedding(
    text: Union[str, list[str]], model: BGEM3FlagModel, batchsize: int = 32
) -> Union[torch.Tensor, list[torch.Tensor]]:
    """
    Generate text embeddings for a list of text strings.

    Args:
        text (list[str]): A list of text strings. A single string is also
            accepted and is encoded on its own.
        model: The model used for generating embeddings.
        batchsize (int): The batch size for processing text.

    Returns:
        list: A list with one embedding tensor per input string, or a single
        tensor when ``text`` is a single string. Tensors are moved to
        ``model.device``.
    """
    if isinstance(text, str):
        return torch.tensor(model.encode(text)["dense_vecs"]).to(model.device)

    result = []
    for start in range(0, len(text), batchsize):
        batch_vecs = model.encode(text[start : start + batchsize])["dense_vecs"]
        # Extending with a 2-D tensor appends one row tensor per input text.
        result.extend(torch.tensor(batch_vecs).to(model.device))
    return result


def get_image_embedding(
    image_dir: str, extractor, model, batchsize: int = 16
) -> dict[str, torch.Tensor]:
    """
    Generate image embeddings for images in a directory.

    Args:
        image_dir (str): The directory containing images.
        extractor: The feature extractor for images; its ``size``,
            ``image_mean`` and ``image_std`` configure the preprocessing.
        model: The model used for generating embeddings.
        batchsize (int): The batch size for processing images.

    Returns:
        dict: A dictionary mapping image filenames (sorted order) to their
        embeddings, i.e. the flattened ``last_hidden_state`` of each image.
    """
    transform = T.Compose(
        [
            T.Resize(int((256 / 224) * extractor.size["height"])),
            T.CenterCrop(extractor.size["height"]),
            T.ToTensor(),
            T.Normalize(mean=extractor.image_mean, std=extractor.image_std),
        ]
    )

    inputs = []
    embeddings = []
    images = [name for name in sorted(os.listdir(image_dir)) if is_image_path(name)]
    last_idx = len(images) - 1
    for idx, file in enumerate(images):
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle as soon as the pixels have been converted.
        with Image.open(pjoin(image_dir, file)) as raw_image:
            image = raw_image.convert("RGB")
        inputs.append(transform(image))

        # Flush on a full batch, and once more for the trailing partial batch.
        if len(inputs) % batchsize == 0 or idx == last_idx:
            batch = {"pixel_values": torch.stack(inputs).to(model.device)}
            # Inference only: skip building an autograd graph.
            with torch.no_grad():
                hidden = model(**batch).last_hidden_state
            embeddings.extend(hidden.detach())
            inputs.clear()

    return {image: embedding.flatten() for image, embedding in zip(images, embeddings)}


def images_cosine_similarity(embeddings: list[torch.Tensor]) -> torch.Tensor:
    """
    Calculate the cosine similarity matrix for a list of embeddings.

    Args:
        embeddings (list[torch.Tensor]): A list of image embeddings.

    Returns:
        torch.Tensor: A symmetric NxN similarity matrix. The diagonal is left
        at zero (self-similarity is not computed).
    """
    # Materialize once so any iterable of embeddings can be indexed below.
    embeddings = list(embeddings)
    count = len(embeddings)
    sim_matrix = torch.zeros((count, count))
    for i in range(count):
        for j in range(i + 1, count):
            sim_matrix[i, j] = sim_matrix[j, i] = torch.cosine_similarity(
                embeddings[i], embeddings[j], -1
            )
    return sim_matrix


IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def average_distance(
    similarity: torch.Tensor, idx: int, cluster_idx: list[int]
) -> float:
    """
    Calculate the average similarity between a point (idx) and a cluster
    (cluster_idx).

    Despite the name, the value is a mean of *similarities*: larger means
    closer to the cluster.

    Args:
        similarity: The pairwise similarity matrix, indexed as
            ``similarity[i, j]``.
        idx (int): The index of the point.
        cluster_idx (list): The indices of the cluster.

    Returns:
        float: The average similarity, or 0 when the point already belongs to
        the cluster or the cluster is empty.
    """
    # An empty cluster has no members to average over; avoid dividing by zero.
    if not cluster_idx or idx in cluster_idx:
        return 0
    total_similarity = sum(similarity[idx, member] for member in cluster_idx)
    return total_similarity / len(cluster_idx)


def get_cluster(similarity: np.ndarray, sim_bound: float = 0.65):
    """
    Cluster points based on similarity.

    Greedy procedure: repeatedly attach the unassigned point whose average
    similarity to some existing cluster is highest (and above ``sim_bound``);
    when no point qualifies, seed a new cluster from the most similar
    remaining pair, stopping once no remaining pair reaches ``sim_bound``.
    Points that never reach the bound are left out of the result.

    The input matrix is not modified.

    Args:
        similarity (np.ndarray): The square similarity matrix.
        sim_bound (float): The similarity threshold for clustering.

    Returns:
        list: A list of clusters, each a list of point indices.
    """
    num_points = similarity.shape[0]
    if num_points == 0:
        # max()/argmax() are undefined on an empty matrix.
        return []

    clusters = []
    # Work on a private copy: rows/columns of assigned points are zeroed out so
    # they are never selected again, while the caller's matrix stays intact and
    # serves as the pristine source for the averages. deepcopy keeps this
    # independent of the matrix type's own copy API.
    remaining = deepcopy(similarity)
    added = [False] * num_points

    while True:
        max_avg_dist = sim_bound
        best_cluster = None
        best_point = None

        for cluster in clusters:
            for point_idx in range(num_points):
                if added[point_idx]:
                    continue
                avg_dist = average_distance(similarity, point_idx, cluster)
                if avg_dist > max_avg_dist:
                    max_avg_dist = avg_dist
                    best_cluster = cluster
                    best_point = point_idx

        if best_point is not None:
            best_cluster.append(best_point)
            added[best_point] = True
            remaining[best_point, :] = 0
            remaining[:, best_point] = 0
            continue

        if remaining.max() < sim_bound:
            break

        # No point fits an existing cluster: seed a new one from the closest
        # remaining pair.
        i, j = np.unravel_index(np.argmax(remaining), remaining.shape)
        i, j = int(i), int(j)
        clusters.append([i, j])
        for point in (i, j):
            added[point] = True
            remaining[point, :] = 0
            remaining[:, point] = 0

    return clusters