Aduc-sdr committed on
Commit
4e9876a
·
verified ·
1 Parent(s): b0aeb06

Update managers/prompt_enhancer_manager.py

Browse files
Files changed (1) hide show
  1. managers/prompt_enhancer_manager.py +91 -79
managers/prompt_enhancer_manager.py CHANGED
@@ -1,80 +1,98 @@
1
- # managers/prompt_enhancer_manager.py
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
- # Version: 2.0.1 (Definitive Fix)
6
  #
7
- # This version re-introduces the essential `attn_implementation="eager"` parameter
8
- # during the loading of the Florence-2 model. This is required to solve the
9
- # '_supports_sdpa' AttributeError in our specific environment, while keeping
10
- # the correct inference pipeline from the functional example.
11
 
12
- import torch
13
  import logging
14
- import yaml
15
  from PIL import Image
16
- from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer
17
- from pathlib import Path
 
 
 
 
 
18
 
19
  logger = logging.getLogger(__name__)
20
 
21
- class PromptEnhancerManager:
 
 
 
 
 
22
  def __init__(self):
23
- logger.info("Initializing Prompt Enhancer Manager...")
24
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
25
- self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
26
- self.system_prompt = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  try:
29
- with open("config.yaml", 'r') as f:
30
- config = yaml.safe_load(f)['specialists']['prompt_enhancer']
31
-
32
- caption_model_name = config['image_caption_model']
33
- llm_model_name = config['llm_model']
34
-
35
- prompt_filename = config.get('prompt_file')
36
- if not prompt_filename:
37
- raise ValueError("Config for prompt_enhancer is missing the 'prompt_file' key.")
38
 
39
- prompt_path = Path(prompt_filename)
40
- if not prompt_path.is_file():
41
- raise FileNotFoundError(f"Enhancer prompt file not found at: {prompt_path}")
42
- self.system_prompt = prompt_path.read_text(encoding="utf-8").strip()
43
- logger.info(f"Loaded system prompt for enhancer from: {prompt_path}")
44
 
45
- logger.info(f"Loading Image Caption Model: {caption_model_name}...")
46
- self.caption_processor = AutoProcessor.from_pretrained(caption_model_name, trust_remote_code=True)
47
 
48
- # <--- CORREÇÃO DEFINITIVA AQUI --->
49
- # Adicionando de volta o parâmetro CRÍTICO para compatibilidade
50
- self.caption_model = AutoModelForCausalLM.from_pretrained(
51
- caption_model_name,
52
- torch_dtype=self.dtype,
53
- trust_remote_code=True,
54
- attn_implementation="eager" # Essencial para evitar o erro _supports_sdpa
55
- ).to(self.device)
56
- # <--- FIM DA CORREÇÃO --->
57
-
58
- logger.info(f"Loading LLM for Prompt Enhancement: {llm_model_name}...")
59
- self.llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
60
- self.llm_model = AutoModelForCausalLM.from_pretrained(
61
- llm_model_name, torch_dtype=self.dtype, device_map="auto"
62
- )
63
 
64
- logger.info("Prompt Enhancer Manager initialized successfully.")
 
 
 
 
65
  except Exception as e:
66
- logger.critical("Failed to initialize PromptEnhancerManager.", exc_info=True)
67
- raise e
68
 
69
- @torch.no_grad()
70
- def get_image_caption(self, image: Image.Image) -> str:
71
  """
72
- Tool 1: Describes a single image using the official Florence-2 inference pipeline.
73
  """
74
- task_prompt = '<MORE_DETAILED_CAPTION>'
 
75
  inputs = self.caption_processor(
76
- text=task_prompt, images=image, return_tensors="pt"
77
- ).to(self.device, self.dtype)
78
 
79
  generated_ids = self.caption_model.generate(
80
  input_ids=inputs["input_ids"],
@@ -83,42 +101,36 @@ class PromptEnhancerManager:
83
  num_beams=3,
84
  )
85
 
 
86
  generated_text = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
87
-
88
  processed_result = self.caption_processor.post_process_generation(
89
  generated_text,
90
  task=task_prompt,
91
- image_size=(image.width, image.height)
92
  )
 
93
 
94
- caption = processed_result[task_prompt]
95
- return caption
96
-
97
- @torch.no_grad()
98
- def get_llm_enhanced_prompt(self, user_content_prompt: str) -> str:
99
  """
100
- Tool 2: Takes a pre-formatted user content prompt, combines it with the
101
- system prompt, and gets a cinematic response from the LLM.
102
  """
103
- messages = [
104
- {"role": "system", "content": self.system_prompt},
105
- {"role": "user", "content": user_content_prompt}
106
- ]
107
-
108
- input_ids = self.llm_tokenizer.apply_chat_template(
109
- messages, add_generation_prompt=True, return_tensors="pt"
110
- ).to(self.llm_model.device)
111
-
112
- outputs = self.llm_model.generate(
113
- input_ids, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.9
114
  )
115
- response = self.llm_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
 
 
116
 
117
- return response.strip()
 
 
 
 
118
 
119
  # --- Singleton Instantiation ---
120
  try:
121
- prompt_enhancer_manager_singleton = PromptEnhancerManager()
122
  except Exception as e:
123
- prompt_enhancer_manager_singleton = None
 
124
  raise e
 
1
+ # engineers/deformes3D_thinker.py
2
  #
3
  # Copyright (C) 2025 Carlos Rodrigues dos Santos
4
  #
5
+ # Version: 4.0.0 (Definitive)
6
  #
7
+ # This is the definitive, robust implementation. It directly contains the prompt
8
+ # enhancement logic copied from the LTX pipeline's utils. It accesses the
9
+ # enhancement models loaded by the LTX Manager and performs the captioning
10
+ # and LLM generation steps locally, ensuring full control and compatibility.
11
 
 
12
  import logging
 
13
  from PIL import Image
14
+ import torch
15
+
16
+ # Importa o singleton do LTX para ter acesso à sua pipeline e aos modelos nela
17
+ from managers.ltx_manager import ltx_manager_singleton
18
+
19
+ # Importa o prompt de sistema do LTX para garantir consistência
20
+ from ltx_video.utils.prompt_enhance_utils import I2V_CINEMATIC_PROMPT
21
 
22
  logger = logging.getLogger(__name__)
23
 
24
+ class Deformes3DThinker:
25
+ """
26
+ The tactical specialist that now directly implements the prompt enhancement
27
+ logic, using the models provided by the LTX pipeline.
28
+ """
29
+
30
  def __init__(self):
31
+ # Acessa a pipeline exposta para obter os modelos necessários
32
+ pipeline = ltx_manager_singleton.prompt_enhancement_pipeline
33
+ if not pipeline:
34
+ raise RuntimeError("Deformes3DThinker could not access the LTX pipeline.")
35
+
36
+ # Armazena os modelos e processadores como atributos diretos
37
+ self.caption_model = pipeline.prompt_enhancer_image_caption_model
38
+ self.caption_processor = pipeline.prompt_enhancer_image_caption_processor
39
+ self.llm_model = pipeline.prompt_enhancer_llm_model
40
+ self.llm_tokenizer = pipeline.prompt_enhancer_llm_tokenizer
41
+
42
+ # Verifica se os modelos foram realmente carregados
43
+ if not all([self.caption_model, self.caption_processor, self.llm_model, self.llm_tokenizer]):
44
+ logger.warning("Deformes3DThinker initialized, but one or more enhancement models were not loaded by the LTX pipeline. Fallback will be used.")
45
+ else:
46
+ logger.info("Deformes3DThinker initialized and successfully linked to LTX enhancement models.")
47
+
48
@torch.no_grad()
def get_enhanced_motion_prompt(self, global_prompt: str, story_history: str,
                               past_keyframe_path: str, present_keyframe_path: str, future_keyframe_path: str,
                               past_scene_desc: str, present_scene_desc: str, future_scene_desc: str) -> str:
    """
    Generate a refined motion prompt by running the enhancement steps locally.

    The present keyframe is captioned, and that caption plus the future
    scene description are sent to the LLM under the LTX cinematic system
    prompt. Of all the parameters, only ``present_keyframe_path``,
    ``present_scene_desc`` and ``future_scene_desc`` are read here; the
    remaining ones are accepted but unused by this method.

    Returns:
        The LLM-enhanced prompt, or a templated fallback sentence built
        from the present and future scene descriptions when a model is
        unavailable or any step raises.
    """
    # Make sure every model is available before trying to use it.
    if not all([self.caption_model, self.caption_processor, self.llm_model, self.llm_tokenizer]):
        logger.warning("Enhancement models not available. Using fallback prompt.")
        return f"A cinematic transition from '{present_scene_desc}' to '{future_scene_desc}'."

    try:
        # The context manager closes the source file as soon as the RGB copy
        # exists, instead of leaving the handle to garbage collection.
        with Image.open(present_keyframe_path) as keyframe_file:
            present_image = keyframe_file.convert("RGB")

        # --- START OF LOGIC COPIED AND ADAPTED FROM LTX ---

        # 1. Caption the reference (present) image.
        image_captions = self._generate_image_captions([present_image])

        # 2. Build the LLM prompt; the future scene plays the "user prompt" role.
        messages = [
            {"role": "system", "content": I2V_CINEMATIC_PROMPT},
            {"role": "user", "content": f"user_prompt: {future_scene_desc}\nimage_caption: {image_captions[0]}"},
        ]

        # 3. Generate and decode the final prompt with the LLM.
        enhanced_prompt = self._generate_and_decode_prompts(messages)

        # --- END OF COPIED AND ADAPTED LOGIC ---

        logger.info(f"Deformes3DThinker received enhanced prompt: '{enhanced_prompt}'")
        return enhanced_prompt

    except Exception as e:
        # Deliberate best-effort: any failure degrades to the fallback prompt
        # rather than propagating to the caller.
        logger.error(f"The Film Director (Deformes3D Thinker) failed during enhancement: {e}. Using fallback.", exc_info=True)
        return f"A smooth, continuous cinematic transition from '{present_scene_desc}' to '{future_scene_desc}'."
86
 
87
+ def _generate_image_captions(self, images: list[Image.Image]) -> list[str]:
 
88
  """
89
+ Lógica interna para gerar captions, copiada do LTX utils.
90
  """
91
+ # O modelo Florence-2 do LTX não usa um system_prompt aqui, mas um task_prompt
92
+ task_prompt = "<MORE_DETAILED_CAPTION>"
93
  inputs = self.caption_processor(
94
+ text=[task_prompt] * len(images), images=images, return_tensors="pt"
95
+ ).to(self.caption_model.device)
96
 
97
  generated_ids = self.caption_model.generate(
98
  input_ids=inputs["input_ids"],
 
101
  num_beams=3,
102
  )
103
 
104
+ # Usa o post_process_generation para extrair a resposta limpa
105
  generated_text = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
 
106
  processed_result = self.caption_processor.post_process_generation(
107
  generated_text,
108
  task=task_prompt,
109
+ image_size=(images[0].width, images[0].height)
110
  )
111
+ return [processed_result[task_prompt]]
112
 
113
+ def _generate_and_decode_prompts(self, messages: list[dict]) -> str:
 
 
 
 
114
  """
115
+ Lógica interna para gerar prompt com o LLM, copiada do LTX utils.
 
116
  """
117
+ text = self.llm_tokenizer.apply_chat_template(
118
+ messages, tokenize=False, add_generation_prompt=True
 
 
 
 
 
 
 
 
 
119
  )
120
+ model_inputs = self.llm_tokenizer([text], return_tensors="pt").to(self.llm_model.device)
121
+
122
+ output_ids = self.llm_model.generate(**model_inputs, max_new_tokens=256)
123
 
124
+ input_ids_len = model_inputs.input_ids.shape[1]
125
+ decoded_prompts = self.llm_tokenizer.batch_decode(
126
+ output_ids[:, input_ids_len:], skip_special_tokens=True
127
+ )
128
+ return decoded_prompts[0].strip()
129
 
130
  # --- Singleton Instantiation ---
131
  try:
132
+ deformes3d_thinker_singleton = Deformes3DThinker()
133
  except Exception as e:
134
+ # A falha já terá sido logada dentro do __init__
135
+ deformes3d_thinker_singleton = None
136
  raise e