File size: 3,915 Bytes
fcaa164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import json

import PIL.Image
from rich import print

import src.llms as llms
from src.presentation import Picture, Presentation
from src.utils import Config, pbasename, pexists, pjoin


class ImageLabler:
    """
    Extract information about the images used in a presentation.

    For every picture the class records how many times it appears, on which
    slides, its area relative to the slide, its size as reported by PIL, and
    (once generated) a caption. Statistics are keyed by image file basename
    and cached in ``image_stats.json`` under ``config.RUN_DIR``.
    """

    def __init__(self, presentation: "Presentation", config: "Config"):
        """
        Initialize the ImageLabler.

        Collects fresh statistics from the presentation, then replaces the
        entry of every collected image that also exists in the cached stats
        file with the cached entry (which may already hold a caption).

        Args:
            presentation (Presentation): The presentation object.
            config (Config): The configuration object.
        """
        self.presentation = presentation
        self.slide_area = presentation.slide_width.pt * presentation.slide_height.pt
        self.image_stats = {}
        self.stats_file = pjoin(config.RUN_DIR, "image_stats.json")
        self.config = config
        self.collect_images()
        if pexists(self.stats_file):
            # Use a context manager so the cache file handle is always closed.
            with open(self.stats_file, "r", encoding="utf-8") as stats_fp:
                cached_stats: dict[str, dict] = json.load(stats_fp)
            for name, stat in cached_stats.items():
                key = pbasename(name)
                # Cached entries for images no longer in the deck are ignored.
                if key in self.image_stats:
                    self.image_stats[key] = stat

    def apply_stats(self) -> None:
        """
        Apply image captions to the presentation.

        Raises:
            KeyError: If a picture has no stats entry or its entry has no
                "caption" yet (i.e. captions were not generated/loaded).
        """
        for slide in self.presentation.slides:
            for shape in slide.shape_filter(Picture):
                stats = self.image_stats[pbasename(shape.img_path)]
                shape.caption = stats["caption"]

    def caption_images(self) -> dict:
        """
        Generate captions for images in the presentation.

        Images that already carry a "caption" are skipped. The complete stats
        are then written to the stats file and the captions are applied to
        the presentation's picture shapes.

        Returns:
            dict: The image statistics, keyed by image basename.
        """
        with open("prompts/caption.txt", encoding="utf-8") as prompt_fp:
            caption_prompt = prompt_fp.read()
        for image, stats in self.image_stats.items():
            if "caption" not in stats:
                stats["caption"] = llms.vision_model(
                    caption_prompt, pjoin(self.config.IMAGE_DIR, image)
                )
                print("captioned", image, ": ", stats["caption"])
        # ensure_ascii=False emits raw non-ASCII text, so the file must be
        # opened with an encoding that can represent it on every platform.
        with open(self.stats_file, "w", encoding="utf-8") as stats_fp:
            json.dump(
                self.image_stats,
                stats_fp,
                indent=4,
                ensure_ascii=False,
            )
        self.apply_stats()
        return self.image_stats

    def collect_images(self) -> None:
        """
        Collect images from the presentation and gather other information.

        For each distinct image (by basename) this stores:
        "appear_times", "slide_numbers" (sorted, 1-based), "relative_area"
        (shape area as a percentage of the slide area, taken from the first
        occurrence), "size" (as reported by PIL) and "top_ranges_str".
        Entries collected here overwrite same-named entries in
        ``self.image_stats``.
        """
        # Gather into a fresh dict so counts always start from zero, even if
        # this method is called more than once on the same instance.
        collected: dict[str, dict] = {}
        for slide_index, slide in enumerate(self.presentation.slides):
            for shape in slide.shape_filter(Picture):
                image_path = pbasename(shape.data[0])
                stats = collected.get(image_path)
                if stats is None:
                    # Only create the entry on first sight; re-creating it for
                    # every occurrence would reset the counters each time.
                    with PIL.Image.open(
                        pjoin(self.config.IMAGE_DIR, image_path)
                    ) as image:
                        size = image.size
                    stats = collected[image_path] = {
                        "appear_times": 0,
                        "slide_numbers": set(),
                        "relative_area": shape.area / self.slide_area * 100,
                        "size": size,
                    }
                stats["appear_times"] += 1
                stats["slide_numbers"].add(slide_index + 1)
        for stats in collected.values():
            # Sets are not JSON-serializable; store a sorted list instead.
            stats["slide_numbers"] = sorted(stats["slide_numbers"])
            stats["top_ranges_str"] = self._top_ranges_str(stats["slide_numbers"])
        self.image_stats.update(collected)

    def _top_ranges_str(self, slide_numbers, limit=3):
        """
        Format the longest consecutive slide ranges as a string.

        Args:
            slide_numbers (list[int]): Sorted slide numbers.
            limit (int): Maximum number of ranges to include.

        Returns:
            str: Ranges joined by ", ", longest first (ties keep slide
                order); a single-slide range is written as just the number.
        """
        ranges = self._find_ranges(slide_numbers)
        top_ranges = sorted(ranges, key=lambda r: r[1] - r[0], reverse=True)[:limit]
        return ", ".join(
            f"{start}-{end}" if start != end else f"{start}"
            for start, end in top_ranges
        )

    def _find_ranges(self, numbers):
        """
        Find consecutive ranges in a list of numbers.

        Args:
            numbers (list[int]): Numbers in ascending order.

        Returns:
            list[tuple[int, int]]: Inclusive (start, end) pairs, one per run
                of consecutive numbers; empty if ``numbers`` is empty.
        """
        if not numbers:
            return []
        ranges = []
        start = end = numbers[0]
        for num in numbers[1:]:
            if num == end + 1:
                end = num
            else:
                ranges.append((start, end))
                start = end = num
        ranges.append((start, end))
        return ranges