Aduc-sdr committed on
Commit 25456f6 · verified · 1 Parent(s): 8ecb750

Create prompt_enhancer_manager.py

Files changed (1)
  1. managers/prompt_enhancer_manager.py +99 -0
managers/prompt_enhancer_manager.py ADDED
@@ -0,0 +1,99 @@
+ # managers/prompt_enhancer_manager.py
+ #
+ # Copyright (C) 2025 Carlos Rodrigues dos Santos
+ #
+ # Version: 1.0.0
+ #
+ # This is a dedicated specialist responsible for enhancing prompts. It loads
+ # an image captioning model and a powerful LLM to create rich, cinematic
+ # motion prompts based on visual and textual context.
+
+ import torch
+ import logging
+ import yaml
+ from PIL import Image
+ from transformers import AutoProcessor, AutoModelForCausalLM, AutoTokenizer
+ from pathlib import Path
+
+ logger = logging.getLogger(__name__)
+
+ # The system prompt that will guide our LLM
+ ENHANCER_SYSTEM_PROMPT = """You are an expert cinematic director. Your task is to write a single, rich, cinematic motion prompt.
+ Analyze the user's goal and the provided image caption. Synthesize them into a flowing, descriptive paragraph under 150 words.
+ Focus on the action, character expressions, camera movement, and environment. Start directly with the action.
+ The final prompt must be a direct instruction for a video generation AI."""
+
+ class PromptEnhancerManager:
+     def __init__(self):
+         logger.info("Initializing Prompt Enhancer Manager...")
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
+
+         try:
+             with open("config.yaml", 'r') as f:
+                 config = yaml.safe_load(f)['specialists']['prompt_enhancer']
+
+             caption_model_name = config['image_caption_model']
+             llm_model_name = config['llm_model']
+
+             logger.info(f"Loading Image Caption Model: {caption_model_name}...")
+             self.caption_processor = AutoProcessor.from_pretrained(caption_model_name, trust_remote_code=True)
+             self.caption_model = AutoModelForCausalLM.from_pretrained(
+                 caption_model_name, torch_dtype=self.dtype, trust_remote_code=True
+             ).to(self.device)
+
+             logger.info(f"Loading LLM for Prompt Enhancement: {llm_model_name}...")
+             self.llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)
+             self.llm_model = AutoModelForCausalLM.from_pretrained(
+                 llm_model_name,
+                 torch_dtype=self.dtype,
+                 device_map="auto"  # Let accelerate manage distribution across GPUs
+             )
+
+             logger.info("Prompt Enhancer Manager initialized successfully.")
+         except Exception as e:
+             logger.critical("Failed to initialize PromptEnhancerManager.", exc_info=True)
+             raise e
+
+     @torch.no_grad()
+     def generate_enhanced_prompt(self, image: Image.Image, user_prompt: str) -> str:
+         """
+         Takes a reference image and a user prompt, and returns an enhanced,
+         cinematic prompt generated by the LLM.
+         """
+         logger.info("Generating enhanced prompt...")
+
+         # 1. Generate the image caption
+         caption_task_prompt = "<MORE_DETAILED_CAPTION>"
+         inputs = self.caption_processor(
+             text=caption_task_prompt, images=image, return_tensors="pt"
+         ).to(self.device, self.dtype)
+
+         generated_ids = self.caption_model.generate(**inputs, max_new_tokens=1024)
+         generated_texts = self.caption_processor.batch_decode(generated_ids, skip_special_tokens=True)
+         image_caption = generated_texts[0].split(":", 1)[-1].strip()
+         logger.info(f"Generated Image Caption: '{image_caption}'")
+
+         # 2. Build the conversation for the LLM
+         messages = [
+             {"role": "system", "content": ENHANCER_SYSTEM_PROMPT},
+             {"role": "user", "content": f"My Goal: '{user_prompt}'\n\nReference Image Scene: '{image_caption}'"}
+         ]
+
+         input_ids = self.llm_tokenizer.apply_chat_template(
+             messages, add_generation_prompt=True, return_tensors="pt"
+         ).to(self.llm_model.device)
+
+         # 3. Generate the LLM response
+         outputs = self.llm_model.generate(input_ids, max_new_tokens=256, do_sample=True, temperature=0.6, top_p=0.9)
+         response = self.llm_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
+
+         logger.info(f"LLM Enhanced Prompt: '{response}'")
+         return response.strip()
+
+ # --- Singleton Instantiation ---
+ try:
+     prompt_enhancer_manager_singleton = PromptEnhancerManager()
+ except Exception as e:
+     prompt_enhancer_manager_singleton = None
+     raise e
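
The manager expects a config.yaml in the working directory with a specialists.prompt_enhancer section naming both models. The actual model IDs are not part of this commit, so the snippet below is only a hypothetical sketch: microsoft/Florence-2-large is assumed because the "<MORE_DETAILED_CAPTION>" task token and trust_remote_code loading in generate_enhanced_prompt follow Florence-2's convention, and the llm_model entry is an arbitrary placeholder for any instruction-tuned chat model whose tokenizer ships a chat template.

    specialists:
      prompt_enhancer:
        # Hypothetical model choices, not taken from this commit.
        image_caption_model: "microsoft/Florence-2-large"
        llm_model: "meta-llama/Meta-Llama-3-8B-Instruct"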
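
A minimal usage sketch, assuming the repository root is on the import path and a config.yaml like the one above resolves (the image path and goal text are illustrative):

    from PIL import Image
    from managers.prompt_enhancer_manager import prompt_enhancer_manager_singleton

    # Hypothetical inputs for illustration only.
    image = Image.open("reference.png").convert("RGB")
    goal = "A knight slowly draws his sword as rain begins to fall"

    enhanced = prompt_enhancer_manager_singleton.generate_enhanced_prompt(image, goal)
    print(enhanced)  # a single cinematic motion prompt, under 150 words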
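
One caveat when swapping in a different llm_model: not every tokenizer's chat template accepts a "system" role (Gemma's template, for example, rejects it). A small defensive sketch, where supports_system is a caller-supplied assumption about the chosen model rather than anything this commit provides:

    def build_messages(system_prompt: str, user_content: str, supports_system: bool = True):
        # Fold the system prompt into the user turn when the model's
        # chat template does not define a "system" role.
        if supports_system:
            return [{"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content}]
        return [{"role": "user", "content": f"{system_prompt}\n\n{user_content}"}]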