File size: 3,915 Bytes
fcaa164 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
import json
import PIL.Image
from rich import print
import src.llms as llms
from src.presentation import Picture, Presentation
from src.utils import Config, pbasename, pexists, pjoin
class ImageLabler:
    """
    Extract per-image information from a presentation: caption, pixel size,
    area relative to the slide, and how often / on which slides each image appears.

    Statistics are keyed by image file basename and persisted to
    ``<config.RUN_DIR>/image_stats.json`` so captions can be reused across runs.
    """

    def __init__(self, presentation: Presentation, config: Config):
        """
        Initialize the ImageLabler and collect image statistics.

        If a stats file from a previous run exists, its entries replace the freshly
        collected ones for every image that is still present in the presentation.

        Args:
            presentation (Presentation): The presentation object.
            config (Config): The configuration object; ``RUN_DIR`` and
                ``IMAGE_DIR`` are read from it.
        """
        self.presentation = presentation
        # Slide area, used to express each image's footprint as a percentage.
        self.slide_area = presentation.slide_width.pt * presentation.slide_height.pt
        self.image_stats = {}
        self.stats_file = pjoin(config.RUN_DIR, "image_stats.json")
        self.config = config
        self.collect_images()
        if pexists(self.stats_file):
            # Reuse cached stats (notably captions) so images are not re-captioned;
            # a cached entry replaces the freshly collected one wholesale.
            with open(self.stats_file, "r", encoding="utf-8") as f:
                image_stats: dict[str, dict] = json.load(f)
            for name, stat in image_stats.items():
                if pbasename(name) in self.image_stats:
                    self.image_stats[pbasename(name)] = stat

    def apply_stats(self):
        """
        Copy each image's caption onto the matching Picture shapes.

        Raises:
            KeyError: if a picture's image has no stats entry or no caption yet.
        """
        for slide in self.presentation.slides:
            for shape in slide.shape_filter(Picture):
                # NOTE(review): looked up via shape.img_path, while collect_images keys
                # stats by shape.data[0] -- confirm both resolve to the same basename.
                stats = self.image_stats[pbasename(shape.img_path)]
                shape.caption = stats["caption"]

    def caption_images(self):
        """
        Caption every image that has no caption yet, save all stats to
        ``self.stats_file``, and apply the captions to the presentation.

        Returns:
            dict: ``self.image_stats``, keyed by image basename.
        """
        with open("prompts/caption.txt", encoding="utf-8") as f:
            caption_prompt = f.read()
        for image, stats in self.image_stats.items():
            if "caption" not in stats:
                stats["caption"] = llms.vision_model(
                    caption_prompt, pjoin(self.config.IMAGE_DIR, image)
                )
                print("captioned", image, ": ", stats["caption"])
        # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII characters that the
        # platform default encoding may not be able to represent. The context
        # manager guarantees the file is flushed and closed.
        with open(self.stats_file, "w", encoding="utf-8") as f:
            json.dump(self.image_stats, f, indent=4, ensure_ascii=False)
        self.apply_stats()
        return self.image_stats

    def collect_images(self):
        """
        Collect every Picture in the presentation and record, per image basename:

        - ``appear_times``: number of Picture shapes using the image.
        - ``slide_numbers``: sorted 1-based slide numbers it appears on.
        - ``relative_area``: shape area as a percentage of the slide area
          (taken from the first occurrence).
        - ``size``: PIL ``Image.size`` of the file in ``config.IMAGE_DIR``.
        - ``top_ranges_str``: up to three longest runs of consecutive slides,
          e.g. ``"3-5, 8"``.
        """
        for slide_index, slide in enumerate(self.presentation.slides):
            for shape in slide.shape_filter(Picture):
                image_path = pbasename(shape.data[0])
                # Initialize only on first sighting; re-initializing on every
                # occurrence would reset appear_times and slide_numbers.
                if image_path not in self.image_stats:
                    # Context manager closes the underlying file handle.
                    with PIL.Image.open(
                        pjoin(self.config.IMAGE_DIR, image_path)
                    ) as img:
                        size = img.size
                    self.image_stats[image_path] = {
                        "appear_times": 0,
                        "slide_numbers": set(),
                        "relative_area": shape.area / self.slide_area * 100,
                        "size": size,
                    }
                stats = self.image_stats[image_path]
                stats["appear_times"] += 1
                stats["slide_numbers"].add(slide_index + 1)
        for stats in self.image_stats.values():
            # Convert the set to a sorted list (JSON-serializable, ordered for ranges).
            stats["slide_numbers"] = sorted(stats["slide_numbers"])
            ranges = self._find_ranges(stats["slide_numbers"])
            top_ranges = sorted(ranges, key=lambda r: r[1] - r[0], reverse=True)[:3]
            stats["top_ranges_str"] = ", ".join(
                f"{start}-{end}" if start != end else f"{start}"
                for start, end in top_ranges
            )

    def _find_ranges(self, numbers):
        """
        Group a sorted list of integers into runs of consecutive values.

        Args:
            numbers (list[int]): Sorted integers, e.g. slide numbers.

        Returns:
            list[tuple[int, int]]: Inclusive ``(start, end)`` pairs in input
            order; ``[]`` for empty input.

        >>> ImageLabler._find_ranges(None, [1, 2, 3, 5, 7, 8])
        [(1, 3), (5, 5), (7, 8)]
        """
        if not numbers:
            return []
        ranges = []
        start = end = numbers[0]
        for num in numbers[1:]:
            if num == end + 1:
                end = num
            else:
                ranges.append((start, end))
                start = end = num
        ranges.append((start, end))
        return ranges
|