import json

import PIL.Image
from rich import print

import src.llms as llms
from src.presentation import Picture, Presentation
from src.utils import Config, pbasename, pexists, pjoin


class ImageLabler:
    """
    A class to extract images information, including caption, size, and
    appearance times in a presentation.

    Statistics are kept in ``self.image_stats``, keyed by image base name, and
    are persisted to ``image_stats.json`` under ``config.RUN_DIR``.
    """

    def __init__(self, presentation: Presentation, config: Config):
        """
        Initialize the ImageLabler.

        Collects statistics for every picture in the presentation and then,
        if a stats file from a previous run exists, replaces the freshly
        collected entry of each known image with the stored one.

        Args:
            presentation (Presentation): The presentation object.
            config (Config): The configuration object.
        """
        self.presentation = presentation
        self.slide_area = presentation.slide_width.pt * presentation.slide_height.pt
        self.image_stats = {}
        self.stats_file = pjoin(config.RUN_DIR, "image_stats.json")
        self.config = config
        self.collect_images()
        if pexists(self.stats_file):
            # Context manager so the handle is closed even if parsing fails.
            with open(self.stats_file, "r", encoding="utf-8") as stats_fp:
                image_stats: dict[str, dict] = json.load(stats_fp)
            for name, stat in image_stats.items():
                key = pbasename(name)
                # Stored entries for images no longer in the deck are ignored.
                if key in self.image_stats:
                    self.image_stats[key] = stat

    def apply_stats(self) -> None:
        """
        Apply image captions to the presentation.

        Raises:
            KeyError: If a picture has no stats entry or its entry has no
                "caption" yet.
        """
        for slide in self.presentation.slides:
            for shape in slide.shape_filter(Picture):
                stats = self.image_stats[pbasename(shape.img_path)]
                shape.caption = stats["caption"]

    def caption_images(self) -> dict:
        """
        Generate captions for images in the presentation.

        Images that already have a caption are skipped. The updated stats are
        written to the stats file and the captions are applied to the
        presentation's pictures.

        Returns:
            dict: The image stats mapping, including captions.
        """
        with open("prompts/caption.txt", encoding="utf-8") as prompt_fp:
            caption_prompt = prompt_fp.read()
        for image, stats in self.image_stats.items():
            if "caption" not in stats:
                stats["caption"] = llms.vision_model(
                    caption_prompt, pjoin(self.config.IMAGE_DIR, image)
                )
                print("captioned", image, ": ", stats["caption"])
        # ensure_ascii=False emits raw non-ASCII text, so the file encoding is
        # pinned to UTF-8 instead of relying on the platform default.
        with open(self.stats_file, "w", encoding="utf-8") as stats_fp:
            json.dump(
                self.image_stats,
                stats_fp,
                indent=4,
                ensure_ascii=False,
            )
        self.apply_stats()
        return self.image_stats

    def collect_images(self) -> None:
        """
        Collect images from the presentation and gather other information.

        For every distinct image (keyed by base name) this records how many
        times it appears, the sorted 1-based slide numbers it appears on, its
        area relative to the slide in percent, its pixel size, and a string
        with its (up to three) longest runs of consecutive slides.

        "relative_area" and "size" are taken from the first occurrence of the
        image. Entries already present in ``self.image_stats`` for the
        collected images are overwritten.
        """
        collected: dict[str, dict] = {}
        for slide_index, slide in enumerate(self.presentation.slides):
            for shape in slide.shape_filter(Picture):
                image_path = pbasename(shape.data[0])
                # Create the entry only once per image so that repeated
                # occurrences accumulate instead of resetting the counters.
                if image_path not in collected:
                    with PIL.Image.open(
                        pjoin(self.config.IMAGE_DIR, image_path)
                    ) as img:
                        size = img.size
                    collected[image_path] = {
                        "appear_times": 0,
                        "slide_numbers": set(),
                        "relative_area": shape.area / self.slide_area * 100,
                        "size": size,
                    }
                entry = collected[image_path]
                entry["appear_times"] += 1
                entry["slide_numbers"].add(slide_index + 1)

        for stats in collected.values():
            # Sets are not JSON serializable; store a sorted list instead.
            stats["slide_numbers"] = sorted(stats["slide_numbers"])
            ranges = self._find_ranges(stats["slide_numbers"])
            # Longest runs first; the stable sort keeps earlier runs ahead on ties.
            top_ranges = sorted(ranges, key=lambda x: x[1] - x[0], reverse=True)[:3]
            stats["top_ranges_str"] = ", ".join(
                f"{r[0]}-{r[1]}" if r[0] != r[1] else f"{r[0]}" for r in top_ranges
            )

        self.image_stats.update(collected)

    def _find_ranges(self, numbers) -> list[tuple[int, int]]:
        """
        Find consecutive ranges in a list of numbers.

        Args:
            numbers: Numbers in ascending order.

        Returns:
            list[tuple[int, int]]: Inclusive ``(start, end)`` pairs, one per
            run of consecutive numbers; empty if ``numbers`` is empty.
        """
        if not numbers:
            return []
        ranges = []
        start = end = numbers[0]
        for num in numbers[1:]:
            if num == end + 1:
                end = num
            else:
                ranges.append((start, end))
                start = end = num
        ranges.append((start, end))
        return ranges