import json
import os
import shutil
from collections import defaultdict

from jinja2 import Template

import src.llms as llms
from src.model_utils import get_cluster, get_image_embedding, images_cosine_similarity
from src.presentation import Presentation
from src.utils import Config, pexists, pjoin, tenacity


def _read_text(path: str) -> str:
    """Return the whole content of *path* decoded as UTF-8, closing the file."""
    with open(path, encoding="utf-8") as f:
        return f.read()


def _load_json(path: str):
    """Parse the UTF-8 JSON file at *path* and return the decoded object."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def _dump_json(obj, path: str) -> None:
    """Write *obj* to *path* as indented UTF-8 JSON, flushing and closing the file."""
    # ensure_ascii=False emits raw non-ASCII characters, so the encoding must be
    # pinned instead of relying on the platform default.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=4, ensure_ascii=False)


class SlideInducter:
    """
    Stage I: Presentation Analysis.

    This stage is to analyze the presentation: cluster slides into different
    layouts, and extract content schema for each layout.
    """

    def __init__(
        self,
        prs: Presentation,
        ppt_image_folder: str,
        template_image_folder: str,
        config: Config,
        image_models: list,
    ):
        """
        Initialize the SlideInducter.

        Args:
            prs (Presentation): The presentation object.
            ppt_image_folder (str): The folder containing PPT images.
            template_image_folder (str): The folder containing normalized slide images.
            config (Config): The configuration object; `RUN_DIR` and `DEBUG` are read.
            image_models (list): Image models, forwarded unpacked to
                `get_image_embedding`.

        Raises:
            AssertionError: If the two image folders and the presentation do not
                all contain the same number of entries.
        """
        self.prs = prs
        self.config = config
        self.ppt_image_folder = ppt_image_folder
        self.template_image_folder = template_image_folder
        assert (
            len(os.listdir(template_image_folder))
            == len(prs)
            == len(os.listdir(ppt_image_folder))
        ), (
            "Number of slide images does not match the number of slides: "
            f"{template_image_folder!r}, {ppt_image_folder!r}"
        )
        self.image_models = image_models
        # Maps "<layout name>:<content type>" -> {"slides": [...], "template_id": ...}.
        self.slide_induction = defaultdict(lambda: defaultdict(list))
        model_identifier = llms.get_simple_modelname(
            [llms.language_model, llms.vision_model]
        )
        self.output_dir = pjoin(config.RUN_DIR, "template_induct", model_identifier)
        self.split_cache = pjoin(self.output_dir, "split_cache.json")
        self.induct_cache = pjoin(self.output_dir, "induct_cache.json")
        os.makedirs(self.output_dir, exist_ok=True)

    def layout_induct(self):
        """
        Perform layout induction for the presentation.

        If the induction cache file exists, its content is returned as-is.
        Otherwise functional slides (from `category_split`) are grouped by
        "<layout name>:<content type>", every remaining slide is clustered by
        `layout_split`, and the result is written to the induction cache.

        Returns:
            dict: Mapping of layout key to its cluster, plus a
                "functional_keys" entry listing the functional layout keys.
        """
        if pexists(self.induct_cache):
            return _load_json(self.induct_cache)

        content_slides_index, functional_cluster = self.category_split()
        for layout_name, cluster in functional_cluster.items():
            for slide_idx in cluster:
                # Slide indices are 1-based; `prs.slides` is 0-based.
                content_type = self.prs.slides[slide_idx - 1].get_content_type()
                self.slide_induction[layout_name + ":" + content_type][
                    "slides"
                ].append(slide_idx)

        function_slides_index = set()
        for cluster in self.slide_induction.values():
            # The last slide of each functional cluster serves as its template.
            cluster["template_id"] = cluster["slides"][-1]
            function_slides_index.update(cluster["slides"])
        # Snapshot before `layout_split` adds the content layouts.
        functional_keys = list(self.slide_induction.keys())

        # Any slide claimed by neither category is treated as a content slide.
        used_slides_index = function_slides_index.union(content_slides_index)
        all_slides_index = set(range(1, len(self.prs.slides) + 1))
        content_slides_index.update(all_slides_index - used_slides_index)

        self.layout_split(content_slides_index)

        if self.config.DEBUG:
            # Copy each cluster's slide images into its own folder for inspection.
            for layout_name, cluster in self.slide_induction.items():
                cluster_dir = pjoin(self.output_dir, "cluster_slides", layout_name)
                os.makedirs(cluster_dir, exist_ok=True)
                for slide_idx in cluster["slides"]:
                    image_name = f"slide_{slide_idx:04d}.jpg"
                    shutil.copy(
                        pjoin(self.ppt_image_folder, image_name),
                        pjoin(cluster_dir, image_name),
                    )

        self.slide_induction["functional_keys"] = functional_keys
        _dump_json(self.slide_induction, self.induct_cache)
        return self.slide_induction

    def category_split(self):
        """
        Split slides into categories based on their functional purpose.

        The language model is asked (prompt `prompts/category_split.txt`) to
        group functional slides; all other slides are content slides. The
        result is cached in the split cache file and reused when present.

        Returns:
            tuple: `(content_slides_index, functional_cluster)` where the first
                item is a set of 1-based slide indices and the second maps a
                functional layout name to its list of slide indices.
        """
        if pexists(self.split_cache):
            split = _load_json(self.split_cache)
            return set(split["content_slides_index"]), split["functional_cluster"]

        category_split_template = Template(_read_text("prompts/category_split.txt"))
        functional_cluster = llms.language_model(
            category_split_template.render(slides=self.prs.to_text()),
            return_json=True,
        )
        # Flatten the per-category index lists into one set.
        functional_slides = set(sum(functional_cluster.values(), []))
        content_slides_index = set(range(1, len(self.prs) + 1)) - functional_slides

        _dump_json(
            {
                "content_slides_index": list(content_slides_index),
                "functional_cluster": functional_cluster,
            },
            self.split_cache,
        )
        return content_slides_index, functional_cluster

    def layout_split(self, content_slides_index: set[int]):
        """
        Cluster slides into different layouts.

        Content slides are first grouped by (slide layout name, content type),
        then each group is clustered by image-embedding cosine similarity. The
        vision model names every cluster, and the results are stored in
        `self.slide_induction`.

        Args:
            content_slides_index (set[int]): 1-based indices of the slides to cluster.

        Raises:
            AssertionError: If the number of embeddings differs from the
                number of slides.
        """
        embeddings = get_image_embedding(self.template_image_folder, *self.image_models)
        assert len(embeddings) == len(self.prs)
        template = Template(_read_text("prompts/ask_category.txt"))

        content_split = defaultdict(list)
        for slide_idx in content_slides_index:
            slide = self.prs.slides[slide_idx - 1]
            group_key = (slide.slide_layout_name, slide.get_content_type())
            content_split[group_key].append(slide_idx)

        for (layout_name, content_type), slides in content_split.items():
            sub_embeddings = [
                embeddings[f"slide_{slide_idx:04d}.jpg"] for slide_idx in slides
            ]
            similarity = images_cosine_similarity(sub_embeddings)
            for cluster in get_cluster(similarity):
                # `cluster` holds positions within `slides`, not slide indices.
                slide_indexs = [slides[i] for i in cluster]
                # The slide with the most shapes represents the cluster.
                template_id = max(
                    slide_indexs,
                    key=lambda x: len(self.prs.slides[x - 1].shapes),
                )
                cluster_name = (
                    llms.vision_model(
                        template.render(
                            existed_layoutnames=list(self.slide_induction.keys()),
                        ),
                        pjoin(self.ppt_image_folder, f"slide_{template_id:04d}.jpg"),
                    )
                    + ":"
                    + content_type
                )
                self.slide_induction[cluster_name]["template_id"] = template_id
                self.slide_induction[cluster_name]["slides"] = slide_indexs

    @tenacity
    def content_induct(self):
        """
        Perform content schema extraction for the presentation.

        Runs (or loads) the layout induction, then asks the language model
        (prompt `prompts/content_induct.txt`) for a content schema of every
        layout that has a template slide but no schema yet. The updated
        induction is written back to the induction cache.

        Returns:
            dict: The slide induction with a "content_schema" entry per layout.

        Raises:
            ValueError: If a schema element has no `data` field.
            AssertionError: If every element of a generated schema is empty.
        """
        self.slide_induction = self.layout_induct()
        content_induct_prompt = Template(_read_text("prompts/content_induct.txt"))

        for layout_name, cluster in self.slide_induction.items():
            # The "functional_keys" entry is a list and is skipped by this test.
            if "template_id" in cluster and "content_schema" not in cluster:
                schema = llms.language_model(
                    content_induct_prompt.render(
                        slide=self.prs.slides[cluster["template_id"] - 1].to_html(
                            element_id=False, paragraph_id=False
                        )
                    ),
                    return_json=True,
                )
                # Iterate over a copy of the keys because empty entries are removed.
                for k in list(schema.keys()):
                    if "data" not in schema[k]:
                        raise ValueError(f"Cannot find `data` in {k}\n{schema[k]}")
                    if len(schema[k]["data"]) == 0:
                        print(f"Empty content schema: {schema[k]}")
                        schema.pop(k)
                assert len(schema) > 0, "No content schema generated"
                self.slide_induction[layout_name]["content_schema"] = schema

        _dump_json(self.slide_induction, self.induct_cache)
        return self.slide_induction