Update managers/prompt_enhancer_manager.py
Browse files
managers/prompt_enhancer_manager.py
CHANGED
|
@@ -1,80 +1,98 @@
|
|
| 1 |
-
#
|
| 2 |
#
|
| 3 |
# Copyright (C) 2025 Carlos Rodrigues dos Santos
|
| 4 |
#
|
| 5 |
-
# Version:
|
| 6 |
#
|
| 7 |
-
# This
|
| 8 |
-
#
|
| 9 |
-
#
|
| 10 |
-
#
|
| 11 |
|
| 12 |
-
import torch
|
| 13 |
import logging
|
| 14 |
-
import yaml
|
| 15 |
from PIL import Image
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
-
class
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
def __init__(self):
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
try:
|
| 29 |
-
|
| 30 |
-
config = yaml.safe_load(f)['specialists']['prompt_enhancer']
|
| 31 |
-
|
| 32 |
-
caption_model_name = config['image_caption_model']
|
| 33 |
-
llm_model_name = config['llm_model']
|
| 34 |
-
|
| 35 |
-
prompt_filename = config.get('prompt_file')
|
| 36 |
-
if not prompt_filename:
|
| 37 |
-
raise ValueError("Config for prompt_enhancer is missing the 'prompt_file' key.")
|
| 38 |
|
| 39 |
-
|
| 40 |
-
if not prompt_path.is_file():
|
| 41 |
-
raise FileNotFoundError(f"Enhancer prompt file not found at: {prompt_path}")
|
| 42 |
-
self.system_prompt = prompt_path.read_text(encoding="utf-8").strip()
|
| 43 |
-
logger.info(f"Loaded system prompt for enhancer from: {prompt_path}")
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
|
| 48 |
-
#
|
| 49 |
-
#
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
logger.info(f"Loading LLM for Prompt Enhancement: {llm_model_name}...")
|
| 59 |
-
self.llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
|
| 60 |
-
self.llm_model = AutoModelForCausalLM.from_pretrained(
|
| 61 |
-
llm_model_name, torch_dtype=self.dtype, device_map="auto"
|
| 62 |
-
)
|
| 63 |
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
except Exception as e:
|
| 66 |
-
logger.
|
| 67 |
-
|
| 68 |
|
| 69 |
-
|
| 70 |
-
def get_image_caption(self, image: Image.Image) -> str:
|
| 71 |
"""
|
| 72 |
-
|
| 73 |
"""
|
| 74 |
-
|
|
|
|
| 75 |
inputs = self.caption_processor(
|
| 76 |
-
text=task_prompt, images=
|
| 77 |
-
).to(self.device
|
| 78 |
|
| 79 |
generated_ids = self.caption_model.generate(
|
| 80 |
input_ids=inputs["input_ids"],
|
|
@@ -83,42 +101,36 @@ class PromptEnhancerManager:
|
|
| 83 |
num_beams=3,
|
| 84 |
)
|
| 85 |
|
|
|
|
| 86 |
generated_text = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
| 87 |
-
|
| 88 |
processed_result = self.caption_processor.post_process_generation(
|
| 89 |
generated_text,
|
| 90 |
task=task_prompt,
|
| 91 |
-
image_size=(
|
| 92 |
)
|
|
|
|
| 93 |
|
| 94 |
-
|
| 95 |
-
return caption
|
| 96 |
-
|
| 97 |
-
@torch.no_grad()
|
| 98 |
-
def get_llm_enhanced_prompt(self, user_content_prompt: str) -> str:
|
| 99 |
"""
|
| 100 |
-
|
| 101 |
-
system prompt, and gets a cinematic response from the LLM.
|
| 102 |
"""
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
{"role": "user", "content": user_content_prompt}
|
| 106 |
-
]
|
| 107 |
-
|
| 108 |
-
input_ids = self.llm_tokenizer.apply_chat_template(
|
| 109 |
-
messages, add_generation_prompt=True, return_tensors="pt"
|
| 110 |
-
).to(self.llm_model.device)
|
| 111 |
-
|
| 112 |
-
outputs = self.llm_model.generate(
|
| 113 |
-
input_ids, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.9
|
| 114 |
)
|
| 115 |
-
|
|
|
|
|
|
|
| 116 |
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
# --- Singleton Instantiation ---
|
| 120 |
try:
|
| 121 |
-
|
| 122 |
except Exception as e:
|
| 123 |
-
|
|
|
|
| 124 |
raise e
|
|
|
|
| 1 |
+
# engineers/deformes3D_thinker.py
|
| 2 |
#
|
| 3 |
# Copyright (C) 2025 Carlos Rodrigues dos Santos
|
| 4 |
#
|
| 5 |
+
# Version: 4.0.0 (Definitive)
|
| 6 |
#
|
| 7 |
+
# This is the definitive, robust implementation. It directly contains the prompt
|
| 8 |
+
# enhancement logic copied from the LTX pipeline's utils. It accesses the
|
| 9 |
+
# enhancement models loaded by the LTX Manager and performs the captioning
|
| 10 |
+
# and LLM generation steps locally, ensuring full control and compatibility.
|
| 11 |
|
|
|
|
| 12 |
import logging
|
|
|
|
| 13 |
from PIL import Image
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
# Importa o singleton do LTX para ter acesso à sua pipeline e aos modelos nela
|
| 17 |
+
from managers.ltx_manager import ltx_manager_singleton
|
| 18 |
+
|
| 19 |
+
# Importa o prompt de sistema do LTX para garantir consistência
|
| 20 |
+
from ltx_video.utils.prompt_enhance_utils import I2V_CINEMATIC_PROMPT
|
| 21 |
|
| 22 |
logger = logging.getLogger(__name__)
|
| 23 |
|
| 24 |
+
class Deformes3DThinker:
|
| 25 |
+
"""
|
| 26 |
+
The tactical specialist that now directly implements the prompt enhancement
|
| 27 |
+
logic, using the models provided by the LTX pipeline.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
def __init__(self):
|
| 31 |
+
# Acessa a pipeline exposta para obter os modelos necessários
|
| 32 |
+
pipeline = ltx_manager_singleton.prompt_enhancement_pipeline
|
| 33 |
+
if not pipeline:
|
| 34 |
+
raise RuntimeError("Deformes3DThinker could not access the LTX pipeline.")
|
| 35 |
+
|
| 36 |
+
# Armazena os modelos e processadores como atributos diretos
|
| 37 |
+
self.caption_model = pipeline.prompt_enhancer_image_caption_model
|
| 38 |
+
self.caption_processor = pipeline.prompt_enhancer_image_caption_processor
|
| 39 |
+
self.llm_model = pipeline.prompt_enhancer_llm_model
|
| 40 |
+
self.llm_tokenizer = pipeline.prompt_enhancer_llm_tokenizer
|
| 41 |
+
|
| 42 |
+
# Verifica se os modelos foram realmente carregados
|
| 43 |
+
if not all([self.caption_model, self.caption_processor, self.llm_model, self.llm_tokenizer]):
|
| 44 |
+
logger.warning("Deformes3DThinker initialized, but one or more enhancement models were not loaded by the LTX pipeline. Fallback will be used.")
|
| 45 |
+
else:
|
| 46 |
+
logger.info("Deformes3DThinker initialized and successfully linked to LTX enhancement models.")
|
| 47 |
+
|
| 48 |
+
@torch.no_grad()
|
| 49 |
+
def get_enhanced_motion_prompt(self, global_prompt: str, story_history: str,
|
| 50 |
+
past_keyframe_path: str, present_keyframe_path: str, future_keyframe_path: str,
|
| 51 |
+
past_scene_desc: str, present_scene_desc: str, future_scene_desc: str) -> str:
|
| 52 |
+
"""
|
| 53 |
+
Generates a refined motion prompt by directly executing the enhancement pipeline logic.
|
| 54 |
+
"""
|
| 55 |
+
# Verifica se os modelos estão disponíveis antes de tentar usá-los
|
| 56 |
+
if not all([self.caption_model, self.caption_processor, self.llm_model, self.llm_tokenizer]):
|
| 57 |
+
logger.warning("Enhancement models not available. Using fallback prompt.")
|
| 58 |
+
return f"A cinematic transition from '{present_scene_desc}' to '{future_scene_desc}'."
|
| 59 |
|
| 60 |
try:
|
| 61 |
+
present_image = Image.open(present_keyframe_path).convert("RGB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
+
# --- INÍCIO DA LÓGICA COPIADA E ADAPTADA DO LTX ---
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
+
# 1. Gerar a caption da imagem de referência (presente)
|
| 66 |
+
image_captions = self._generate_image_captions([present_image])
|
| 67 |
|
| 68 |
+
# 2. Construir o prompt para o LLM
|
| 69 |
+
# Usamos a cena futura como o "prompt do usuário"
|
| 70 |
+
messages = [
|
| 71 |
+
{"role": "system", "content": I2V_CINEMATIC_PROMPT},
|
| 72 |
+
{"role": "user", "content": f"user_prompt: {future_scene_desc}\nimage_caption: {image_captions[0]}"},
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
# 3. Gerar e decodificar o prompt final com o LLM
|
| 76 |
+
enhanced_prompt = self._generate_and_decode_prompts(messages)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
+
# --- FIM DA LÓGICA COPIADA E ADAPTADA ---
|
| 79 |
+
|
| 80 |
+
logger.info(f"Deformes3DThinker received enhanced prompt: '{enhanced_prompt}'")
|
| 81 |
+
return enhanced_prompt
|
| 82 |
+
|
| 83 |
except Exception as e:
|
| 84 |
+
logger.error(f"The Film Director (Deformes3D Thinker) failed during enhancement: {e}. Using fallback.", exc_info=True)
|
| 85 |
+
return f"A smooth, continuous cinematic transition from '{present_scene_desc}' to '{future_scene_desc}'."
|
| 86 |
|
| 87 |
+
def _generate_image_captions(self, images: list[Image.Image]) -> list[str]:
|
|
|
|
| 88 |
"""
|
| 89 |
+
Lógica interna para gerar captions, copiada do LTX utils.
|
| 90 |
"""
|
| 91 |
+
# O modelo Florence-2 do LTX não usa um system_prompt aqui, mas um task_prompt
|
| 92 |
+
task_prompt = "<MORE_DETAILED_CAPTION>"
|
| 93 |
inputs = self.caption_processor(
|
| 94 |
+
text=[task_prompt] * len(images), images=images, return_tensors="pt"
|
| 95 |
+
).to(self.caption_model.device)
|
| 96 |
|
| 97 |
generated_ids = self.caption_model.generate(
|
| 98 |
input_ids=inputs["input_ids"],
|
|
|
|
| 101 |
num_beams=3,
|
| 102 |
)
|
| 103 |
|
| 104 |
+
# Usa o post_process_generation para extrair a resposta limpa
|
| 105 |
generated_text = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
|
|
|
| 106 |
processed_result = self.caption_processor.post_process_generation(
|
| 107 |
generated_text,
|
| 108 |
task=task_prompt,
|
| 109 |
+
image_size=(images[0].width, images[0].height)
|
| 110 |
)
|
| 111 |
+
return [processed_result[task_prompt]]
|
| 112 |
|
| 113 |
+
def _generate_and_decode_prompts(self, messages: list[dict]) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
"""
|
| 115 |
+
Lógica interna para gerar prompt com o LLM, copiada do LTX utils.
|
|
|
|
| 116 |
"""
|
| 117 |
+
text = self.llm_tokenizer.apply_chat_template(
|
| 118 |
+
messages, tokenize=False, add_generation_prompt=True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
)
|
| 120 |
+
model_inputs = self.llm_tokenizer([text], return_tensors="pt").to(self.llm_model.device)
|
| 121 |
+
|
| 122 |
+
output_ids = self.llm_model.generate(**model_inputs, max_new_tokens=256)
|
| 123 |
|
| 124 |
+
input_ids_len = model_inputs.input_ids.shape[1]
|
| 125 |
+
decoded_prompts = self.llm_tokenizer.batch_decode(
|
| 126 |
+
output_ids[:, input_ids_len:], skip_special_tokens=True
|
| 127 |
+
)
|
| 128 |
+
return decoded_prompts[0].strip()
|
| 129 |
|
| 130 |
# --- Singleton Instantiation ---
|
| 131 |
try:
|
| 132 |
+
deformes3d_thinker_singleton = Deformes3DThinker()
|
| 133 |
except Exception as e:
|
| 134 |
+
# A falha já terá sido logada dentro do __init__
|
| 135 |
+
deformes3d_thinker_singleton = None
|
| 136 |
raise e
|