|
|
import json |
|
|
import os |
|
|
import sys |
|
|
import shutil |
|
|
from functools import partial |
|
|
from glob import glob |
|
|
from time import sleep |
|
|
from typing import Type |
|
|
|
|
|
# Install the placeholder API key only when the environment does not already
# provide one, so a real key exported by the caller is never overwritten.
os.environ.setdefault('OPENAI_API_KEY', 'Your key here')

# The repository root is two directory levels above this file. Putting it at
# the front of sys.path lets the `src.*` imports resolve when this file is run
# directly as a script.
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))
sys.path.insert(0, root_dir)
|
|
|
|
|
import func_argparse |
|
|
import torch |
|
|
|
|
|
import src.llms as llms |
|
|
from src.experiment.ablation import ( |
|
|
PPTCrew_wo_Decoupling, |
|
|
PPTCrew_wo_HTML, |
|
|
PPTCrew_wo_LayoutInduction, |
|
|
PPTCrew_wo_SchemaInduction, |
|
|
PPTCrew_wo_Structure, |
|
|
) |
|
|
from src.experiment.preprocess import process_filetype |
|
|
from src.model_utils import get_text_model |
|
|
from src.multimodal import ImageLabler |
|
|
from src.pptgen import PPTCrew |
|
|
from src.presentation import Presentation |
|
|
from src.utils import Config, older_than, pbasename, pexists, pjoin, ppt_to_images |
|
|
|
|
|
|
|
|
# (language model, vision model) pairs. `get_setting` selects one pair by
# position, so the order of the entries is significant.
EVAL_MODELS = [
    (llms.qwen2_5, llms.qwen_vl),  # setting_id 0
    (llms.gpt4o, llms.gpt4o),  # setting_id 1
    (llms.qwen_vl, llms.qwen_vl),  # setting_id 2
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Ablation id -> agent class used for generation. Ids -1, 4 and 6 all map to
# the full PPTCrew; `get_setting` distinguishes them by setting name and, for
# id 6, by binding `retry_times=5`.
AGENT_CLASS = {
    -1: PPTCrew,
    0: PPTCrew_wo_LayoutInduction,
    1: PPTCrew_wo_SchemaInduction,
    2: PPTCrew_wo_Decoupling,
    3: PPTCrew_wo_HTML,
    4: PPTCrew,
    5: PPTCrew_wo_Structure,
    6: PPTCrew,
}
|
|
|
|
|
|
|
|
def get_setting(setting_id: int, ablation_id: int):
    """Resolve an experiment configuration into an agent class and its names.

    Picks the (language, vision) model pair at ``EVAL_MODELS[setting_id]``,
    installs it on the ``llms`` module as ``llms.language_model`` and
    ``llms.vision_model`` (a module-level side effect), and looks up the agent
    class registered for ``ablation_id`` in ``AGENT_CLASS``.

    Args:
        setting_id: Index into ``EVAL_MODELS``.
        ablation_id: Key of ``AGENT_CLASS``. ``-1`` selects the full agent;
            any other id is only accepted together with ``setting_id == 0``.

    Returns:
        A tuple ``(agent_class, setting_name, model_identifier)``. For
        ``ablation_id == 6`` the agent class is a ``functools.partial`` with
        ``retry_times=5`` bound. For ``ablation_id == 4`` the setting name and
        model identifier are fixed strings naming gpt-4o, independent of the
        selected model pair.

    Raises:
        AssertionError: If ``ablation_id`` is unknown, or an ablation is
            requested with a non-zero ``setting_id``.
        IndexError: If ``setting_id`` is outside ``EVAL_MODELS``.
    """
    assert ablation_id in AGENT_CLASS, f"ablation_id {ablation_id} not in {AGENT_CLASS}"
    # The previous message described the opposite of this condition: ablations
    # (any id other than -1) are restricted to the first model setting.
    assert (
        ablation_id == -1 or setting_id == 0
    ), "ablation_id != -1 is only allowed when setting_id == 0"

    language_model, vision_model = EVAL_MODELS[setting_id]
    agent_class = AGENT_CLASS[ablation_id]

    # Publish the chosen models on the llms module for the rest of the run.
    llms.language_model = language_model
    llms.vision_model = vision_model

    if ablation_id == -1:
        setting_name = "PPTCrew-" + llms.get_simple_modelname(
            [language_model, vision_model]
        )
    elif ablation_id == 6:
        setting_name = "PPTCrew_retry_5"
        agent_class = partial(agent_class, retry_times=5)
    else:
        setting_name = agent_class.__name__

    model_identifier = llms.get_simple_modelname(
        [llms.language_model, llms.vision_model]
    )

    # Ablation 4 overrides both names with gpt-4o labels while keeping the
    # model pair selected above.
    if ablation_id == 4:
        setting_name = "PPTCrew_with_gpt4o"
        model_identifier = "gpt-4o+gpt-4o"

    return agent_class, setting_name, model_identifier
|
|
|
|
|
|
|
|
def do_generate(
    genclass: Type[PPTCrew],
    setting: str,
    model_identifier: str,
    debug: bool,
    ppt_folder: str,
    thread_id: int,
    num_slides: int = 12,
):
    """Generate presentations for one reference deck against every PDF of its topic.

    Loads ``source.pptx`` from ``ppt_folder``, captions its images, reads the
    template-induction cache stored under
    ``<RUN_DIR>/template_induct/<model_identifier>/induct_cache.json`` and then
    runs the agent once per folder matching ``data/<topic>/pdf/*``. A PDF
    folder whose run directory already contains ``history`` is skipped.

    Args:
        genclass: Agent class (or a partial of it) called with the text model.
        setting: Name of the sub-folder of ``ppt_folder`` that receives outputs.
        model_identifier: Sub-folder of ``template_induct`` holding the cache.
        debug: Forwarded to ``Config``.
        ppt_folder: Folder of the reference deck. Its second ``/``-separated
            component is used as the topic name, so it is assumed to look like
            ``data/<topic>/...`` -- TODO confirm against the caller.
        thread_id: Worker index, used to spread workers across CUDA devices.
        num_slides: Number of slides forwarded to the agent's ``generate_pres``.

    Returns:
        None. Returns early, after printing a message, when
        ``older_than(induct_cache, wait=True)`` is falsy.
    """
    app_config = Config(rundir=ppt_folder, debug=debug)

    # Spread workers round-robin over the visible GPUs. With no GPU the modulo
    # would divide by zero, so fall back to the CPU device string
    # (assumes get_text_model accepts any torch device string -- TODO confirm).
    gpu_count = torch.cuda.device_count()
    device = f"cuda:{thread_id % gpu_count}" if gpu_count else "cpu"
    text_model = get_text_model(device)

    presentation = Presentation.from_file(
        pjoin(ppt_folder, "source.pptx"),
        app_config,
    )
    ImageLabler(presentation, app_config).caption_images()

    induct_cache = pjoin(
        app_config.RUN_DIR, "template_induct", model_identifier, "induct_cache.json"
    )
    if not older_than(induct_cache, wait=True):
        print(f"induct_cache not found: {induct_cache}")
        return
    with open(induct_cache) as cache_file:
        slide_induction = json.load(cache_file)

    # set_reference is attempted twice: the first failure is reported and the
    # call is repeated once; a second failure propagates to the caller.
    try:
        pptgen: PPTCrew = genclass(text_model).set_reference(
            presentation, slide_induction
        )
    except Exception as exc:
        print(f"set_reference failed, retrying once: {exc!r}")
        pptgen = genclass(text_model).set_reference(presentation, slide_induction)

    topic = ppt_folder.split("/")[1]
    for pdf_folder in glob(f"data/{topic}/pdf/*"):
        app_config.set_rundir(pjoin(ppt_folder, setting, pbasename(pdf_folder)))
        # An existing "history" entry marks this deck/PDF pair as already run.
        if pexists(pjoin(app_config.RUN_DIR, "history")):
            continue
        with open(pjoin(pdf_folder, "image_caption.json"), "r") as caption_file:
            images = json.load(caption_file)
        with open(pjoin(pdf_folder, "refined_doc.json"), "r") as doc_file:
            doc_json = json.load(doc_file)
        pptgen.generate_pres(app_config, images, num_slides, doc_json)
|
|
|
|
|
|
|
|
def generate_pres(
    setting_id: int = 0,
    setting_name: str = None,
    ablation_id: int = -1,
    thread_num: int = 8,
    debug: bool = False,
    topic: str = "*",
    num_slides: int = 12,
):
    """Run slide generation over the pptx files selected by topic.

    Args:
        setting_id: index of the model pair, forwarded to get_setting
        setting_name: optional override for the name derived by get_setting
        ablation_id: ablation variant, forwarded to get_setting
        thread_num: forwarded to process_filetype
        debug: forwarded to do_generate
        topic: forwarded to process_filetype
        num_slides: forwarded to do_generate
    """
    agent_class, default_name, model_identifier = get_setting(setting_id, ablation_id)

    # An explicit (non-empty) name wins over the one derived from the setting.
    run_name = setting_name if setting_name else default_name
    print("generating slides using:", run_name)

    # Bind everything that is fixed for this run; process_filetype supplies
    # the remaining arguments of do_generate.
    worker = partial(
        do_generate,
        agent_class,
        run_name,
        model_identifier,
        debug,
        num_slides=num_slides,
    )
    process_filetype("pptx", worker, thread_num, topic)
|
|
|
|
|
|
|
|
def pptx2images(settings: str = "*"):
    """Poll forever and render every finished ``final.pptx`` into images.

    Every 60 seconds, scans ``data/*/pptx/*/<settings>/*/history``. For each
    match, the parent folder is treated as a run directory:

    * if it has no ``final.pptx``, any existing image folder for it is
      removed;
    * otherwise, unless the image folder already exists, ``final.pptx`` is
      converted with ``ppt_to_images`` into
      ``<ppt_folder>/final_images/<setting>/<pdf>``.

    A failed conversion is reported and the scan continues with the next run.
    This function never returns; stop it with Ctrl-C.

    Args:
        settings: Glob pattern for the setting folder name; ``"*"`` matches all.
    """
    while True:
        for history_dir in glob(f"data/*/pptx/*/{settings}/*/history"):
            folder = os.path.dirname(history_dir)
            pptx = pjoin(folder, "final.pptx")
            # The last two path components are the setting and PDF names.
            ppt_folder, setting, pdf = folder.rsplit("/", 2)
            dst = pjoin(ppt_folder, "final_images", setting, pdf)

            if not pexists(pptx):
                # No final deck for this run: drop images left from before.
                if pexists(dst):
                    print(f"remove {dst}")
                    shutil.rmtree(dst)
                continue

            # Return value intentionally unused; presumably this waits until
            # the file is old enough to be completely written -- TODO confirm.
            older_than(pptx)
            if pexists(dst):
                continue
            try:
                ppt_to_images(pptx, dst)
            except Exception as exc:
                # Catch Exception rather than everything so that Ctrl-C and
                # SystemExit can still terminate the polling loop.
                print(f"pptx to images failed: {pptx}: {exc!r}")
        sleep(60)
        print("keep scanning for new pptx")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand both entry points to func_argparse, which builds the command-line
    # interface from them.
    cli_commands = (generate_pres, pptx2images)
    func_argparse.main(*cli_commands)
|
|
|