File size: 5,437 Bytes
fcaa164
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import json
import os
import sys
import shutil
from functools import partial
from glob import glob
from time import sleep
from typing import Type

# Install the placeholder only when no key is configured, so a real
# OPENAI_API_KEY already exported in the environment is never overwritten.
os.environ.setdefault('OPENAI_API_KEY', 'Your key here')

# Put the directory two levels above this file at the front of sys.path so
# that the `src.*` packages imported by this script can be resolved.
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
sys.path.insert(0, root_dir)

import func_argparse
import torch

import src.llms as llms
from src.experiment.ablation import (
    PPTCrew_wo_Decoupling,
    PPTCrew_wo_HTML,
    PPTCrew_wo_LayoutInduction,
    PPTCrew_wo_SchemaInduction,
    PPTCrew_wo_Structure,
)
from src.experiment.preprocess import process_filetype
from src.model_utils import get_text_model
from src.multimodal import ImageLabler
from src.pptgen import PPTCrew
from src.presentation import Presentation
from src.utils import Config, older_than, pbasename, pexists, pjoin, ppt_to_images

# Evaluated model pairs, each entry being (language_model, vision_model).
# The `setting_id` argument of get_setting() is an index into this list.
EVAL_MODELS = [
    (llms.qwen2_5, llms.qwen_vl),
    (llms.gpt4o, llms.gpt4o),
    (llms.qwen_vl, llms.qwen_vl),
]

# Ablation ids (the keys of AGENT_CLASS below):
# -1: no ablation, plain PPTCrew (the default of generate_pres)
# 0: w/o layout induction
# 1: w/o schema induction
# 2: w/o decoupling
# 3: w/o html
# 4: with gpt4o template (plain PPTCrew; get_setting() forces the gpt-4o model identifier)
# 5: w/o structure information
# 6: retry 5 times (plain PPTCrew; get_setting() binds retry_times=5)

# Maps an ablation id to the agent class that implements it.
AGENT_CLASS = {
    -1: PPTCrew,
    0: PPTCrew_wo_LayoutInduction,
    1: PPTCrew_wo_SchemaInduction,
    2: PPTCrew_wo_Decoupling,
    3: PPTCrew_wo_HTML,
    4: PPTCrew,
    5: PPTCrew_wo_Structure,
    6: PPTCrew,
}


def get_setting(setting_id: int, ablation_id: int):
    """Resolve the agent class and the names used for one experiment run.

    As a side effect, the selected models are installed globally as
    ``llms.language_model`` and ``llms.vision_model``.

    Args:
        setting_id: Index into ``EVAL_MODELS`` selecting the
            (language_model, vision_model) pair.
        ablation_id: Key of ``AGENT_CLASS``. ``-1`` selects the plain
            ``PPTCrew``; every other id is an ablation and is only allowed
            together with ``setting_id == 0``.

    Returns:
        A tuple ``(agent_class, setting_name, model_identifier)``:
        the callable used to build the agent (a ``functools.partial`` for
        ablation 6), the name of the setting, and the model identifier
        string produced by ``llms.get_simple_modelname`` (overridden with
        ``"gpt-4o+gpt-4o"`` for ablation 4).

    Raises:
        AssertionError: If ``ablation_id`` is unknown, or an ablation is
            requested with ``setting_id != 0``.
        IndexError: If ``setting_id`` is not a valid index of
            ``EVAL_MODELS``.
    """
    assert (
        ablation_id in AGENT_CLASS
    ), f"ablation_id {ablation_id} not in {sorted(AGENT_CLASS)}"
    # Ablations are only defined on top of the first model setting.
    assert (
        ablation_id == -1 or setting_id == 0
    ), "ablation_id != -1 is only allowed when setting_id == 0"

    language_model, vision_model = EVAL_MODELS[setting_id]
    # Membership was asserted above, so plain indexing is safe here.
    agent_class = AGENT_CLASS[ablation_id]
    llms.language_model = language_model
    llms.vision_model = vision_model

    model_identifier = llms.get_simple_modelname([language_model, vision_model])

    if ablation_id == -1:
        setting_name = "PPTCrew-" + model_identifier
    elif ablation_id == 6:
        setting_name = "PPTCrew_retry_5"
        agent_class = partial(agent_class, retry_times=5)
    else:
        setting_name = agent_class.__name__

    if ablation_id == 4:
        # Ablation 4 keeps the current models but reports the gpt-4o
        # identifier instead of the one derived from them.
        setting_name = "PPTCrew_with_gpt4o"
        model_identifier = "gpt-4o+gpt-4o"
    return agent_class, setting_name, model_identifier


def do_generate(
    genclass: Type[PPTCrew],
    setting: str,
    model_identifier: str,
    debug: bool,
    ppt_folder: str,
    thread_id: int,
    num_slides: int = 12,
):
    """Generate a presentation for every PDF of a topic from one reference PPT.

    Loads ``source.pptx`` from ``ppt_folder``, captions its images, reads the
    slide-induction cache for ``model_identifier`` and then, for each folder
    matching ``data/<topic>/pdf/*``, runs the agent unless that run directory
    already contains a ``history`` entry.

    Args:
        genclass: Callable that builds the agent from a text model
            (a ``PPTCrew`` subclass or a partial of one).
        setting: Name of the sub-directory of ``ppt_folder`` that receives
            the outputs of this setting.
        model_identifier: Name of the ``template_induct`` sub-directory
            holding ``induct_cache.json``.
        debug: Forwarded to ``Config``.
        ppt_folder: Folder containing ``source.pptx``. Its second
            ``/``-separated component is used as the topic name.
        thread_id: Worker index, used to spread workers over CUDA devices.
        num_slides: Number of slides requested from the agent.

    Returns:
        None. Returns early when the induction cache is not available.
    """
    app_config = Config(rundir=ppt_folder, debug=debug)

    # Spread workers round-robin over the visible GPUs. With no GPU the
    # modulo would raise ZeroDivisionError, so fall back to the CPU.
    # NOTE(review): assumes get_text_model accepts "cpu" as a device - confirm.
    gpu_count = torch.cuda.device_count()
    device = f"cuda:{thread_id % gpu_count}" if gpu_count else "cpu"
    text_model = get_text_model(device)

    presentation = Presentation.from_file(
        pjoin(ppt_folder, "source.pptx"),
        app_config,
    )
    ImageLabler(presentation, app_config).caption_images()

    induct_cache = pjoin(
        app_config.RUN_DIR, "template_induct", model_identifier, "induct_cache.json"
    )
    if not older_than(induct_cache, wait=True):
        print(f"induct_cache not found: {induct_cache}")
        return
    with open(induct_cache) as cache_file:
        slide_induction = json.load(cache_file)

    try:
        pptgen: PPTCrew = genclass(text_model).set_reference(
            presentation, slide_induction
        )
    except Exception as err:
        # Retry once; a second failure propagates to the caller. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # pass through.
        print(f"set_reference failed: {err!r}")
        pptgen = genclass(text_model).set_reference(presentation, slide_induction)

    # NOTE(review): presumably ppt_folder looks like "data/<topic>/pptx/<name>"
    # with "/" separators - verify against the caller.
    topic = ppt_folder.split("/")[1]
    for pdf_folder in glob(f"data/{topic}/pdf/*"):
        app_config.set_rundir(pjoin(ppt_folder, setting, pbasename(pdf_folder)))
        # A "history" entry in the run directory marks this PDF as already
        # processed for this setting.
        if pexists(pjoin(app_config.RUN_DIR, "history")):
            continue
        with open(pjoin(pdf_folder, "image_caption.json"), "r") as caption_file:
            images = json.load(caption_file)
        with open(pjoin(pdf_folder, "refined_doc.json"), "r") as doc_file:
            doc_json = json.load(doc_file)
        pptgen.generate_pres(app_config, images, num_slides, doc_json)


def generate_pres(
    setting_id: int = 0,
    setting_name: str = None,
    ablation_id: int = -1,
    thread_num: int = 8,
    debug: bool = False,
    topic: str = "*",
    num_slides: int = 12,
):
    """Run slide generation over the "pptx" inputs for one setting.

    Args:
        setting_id: Passed to ``get_setting`` to pick the model pair.
        setting_name: Optional name that replaces the one derived by
            ``get_setting`` when it is non-empty.
        ablation_id: Passed to ``get_setting`` to pick the agent class.
        thread_num: Forwarded to ``process_filetype``.
        debug: Forwarded to ``do_generate``.
        topic: Forwarded to ``process_filetype``.
        num_slides: Forwarded to ``do_generate``.
    """
    agent_class, derived_name, model_identifier = get_setting(
        setting_id, ablation_id
    )
    # An explicitly supplied (truthy) name takes precedence over the derived one.
    setting = setting_name if setting_name else derived_name
    print("generating slides using:", setting)

    # Pre-bind everything that is constant for the run; process_filetype
    # supplies the remaining arguments of do_generate.
    worker = partial(
        do_generate,
        agent_class,
        setting,
        model_identifier,
        debug,
        num_slides=num_slides,
    )
    process_filetype("pptx", worker, thread_num, topic)


def pptx2images(settings: str = "*"):
    """Scan forever for generated presentations and render them to images.

    Every 60 seconds, each run directory matching
    ``data/*/pptx/*/<settings>/*/`` that contains a ``history`` entry is
    visited. If it has a ``final.pptx`` and no image folder yet, the file is
    converted into ``<ppt_folder>/final_images/<setting>/<pdf>``. If
    ``final.pptx`` is missing, a leftover image folder is deleted.

    This function never returns; stop it with Ctrl-C.

    Args:
        settings: Glob pattern selecting the setting directories to scan.
    """
    while True:
        for history_dir in glob(f"data/*/pptx/*/{settings}/*/history"):
            folder = os.path.dirname(history_dir)
            pptx = pjoin(folder, "final.pptx")
            # folder is ".../<ppt_folder>/<setting>/<pdf>".
            ppt_folder, setting, pdf = folder.rsplit("/", 2)
            dst = pjoin(ppt_folder, "final_images", setting, pdf)

            if not pexists(pptx):
                # No final presentation: drop images left over from a
                # previous run so they cannot be mistaken for current output.
                if pexists(dst):
                    print(f"remove {dst}")
                    shutil.rmtree(dst)
                continue

            # NOTE(review): return value is ignored; presumably this call
            # waits until the file has stopped changing - confirm.
            older_than(pptx)
            if pexists(dst):
                continue
            try:
                ppt_to_images(pptx, dst)
            except Exception as err:
                # Keep scanning after a failed conversion, but report which
                # file failed and why. KeyboardInterrupt/SystemExit are not
                # caught, so the loop can still be stopped.
                print(f"pptx to images failed: {pptx}: {err!r}")
        sleep(60)
        print("keep scanning for new pptx")


if __name__ == "__main__":
    # Hand both entry points to func_argparse, which builds the command line.
    func_argparse.main(generate_pres, pptx2images)