euiia commited on
Commit
5137a03
·
verified ·
1 Parent(s): f0a989a

Create hd_specialist.py

Browse files
Files changed (1) hide show
  1. hd_specialist.py +97 -0
hd_specialist.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # hd_specialist.py
2
+ import torch
3
+ import imageio
4
+ import os
5
+ import logging
6
+ from PIL import Image
7
+ from diffusers import SeedVR2Pipeline
8
+ from tqdm import tqdm
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ class HDSpecialist:
13
+ """
14
+ Implementa o Especialista HD (Δ+) da arquitetura ADUC-SDR.
15
+ Utiliza o modelo SeedVR2 para realizar a restauração e o aprimoramento
16
+ de vídeo, atuando como a etapa final de pós-produção.
17
+ """
18
+ def __init__(self):
19
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
+ self.pipeline = None
21
+ self.model_id = "ByteDance-Seed/SeedVR2-3B"
22
+ logger.info(f"Especialista HD inicializado. Aguardando carregamento do modelo {self.model_id}...")
23
+
24
+ def _load_pipeline(self):
25
+ """Carrega o pipeline do modelo sob demanda para economizar memória."""
26
+ if self.pipeline is None:
27
+ logger.info("Carregando o pipeline SeedVR2... Isso pode levar alguns minutos.")
28
+ try:
29
+ self.pipeline = SeedVR2Pipeline.from_pretrained(
30
+ self.model_id,
31
+ torch_dtype=torch.float16,
32
+ variant="fp16"
33
+ ).to(self.device)
34
+ logger.info("Pipeline SeedVR2 carregado com sucesso.")
35
+ except Exception as e:
36
+ logger.error(f"Falha ao carregar o pipeline SeedVR2: {e}")
37
+ raise
38
+
39
+ def process_video(self, input_video_path: str, output_video_path: str, prompt: str, strength: float = 0.8, batch_size: int = 8) -> str:
40
+ """
41
+ Aplica o aprimoramento HD a um vídeo.
42
+
43
+ Args:
44
+ input_video_path (str): Caminho para o vídeo de entrada.
45
+ output_video_path (str): Caminho para salvar o vídeo aprimorado.
46
+ prompt (str): Um prompt para guiar a restauração (pode ser o prompt global).
47
+ strength (float): Força do efeito de restauração.
48
+ batch_size (int): Número de frames para processar por lote para gerenciar a VRAM.
49
+
50
+ Returns:
51
+ str: O caminho para o vídeo aprimorado.
52
+ """
53
+ self._load_pipeline()
54
+ if not os.path.exists(input_video_path):
55
+ logger.error(f"Vídeo de entrada não encontrado em: {input_video_path}")
56
+ raise FileNotFoundError(f"Vídeo de entrada não encontrado: {input_video_path}")
57
+
58
+ logger.info(f"Iniciando processo HD para: {input_video_path}")
59
+
60
+ # 1. Ler os frames do vídeo de entrada
61
+ reader = imageio.get_reader(input_video_path)
62
+ fps = reader.get_meta_data()['fps']
63
+ input_frames = [Image.fromarray(frame) for frame in reader]
64
+ reader.close()
65
+
66
+ logger.info(f"Vídeo lido. Total de {len(input_frames)} frames a {fps} FPS.")
67
+
68
+ # 2. Processar frames em lotes
69
+ processed_frames = []
70
+
71
+ # Usamos tqdm para ter uma barra de progresso no console
72
+ for i in tqdm(range(0, len(input_frames), batch_size), desc="Aprimorando frames com SeedVR2"):
73
+ batch = input_frames[i:i + batch_size]
74
+
75
+ # O pipeline retorna uma lista de imagens PIL
76
+ enhanced_batch = self.pipeline(
77
+ prompt=prompt,
78
+ frames=batch,
79
+ num_inference_steps=20,
80
+ guidance_scale=7.5,
81
+ strength=strength,
82
+ ).frames
83
+
84
+ processed_frames.extend(enhanced_batch)
85
+
86
+ # 3. Salvar os frames aprimorados em um novo vídeo
87
+ logger.info(f"Salvando vídeo aprimorado em: {output_video_path}")
88
+ writer = imageio.get_writer(output_video_path, fps=fps, codec='libx264', quality=9)
89
+ for frame in processed_frames:
90
+ writer.append_data(np.array(frame))
91
+ writer.close()
92
+
93
+ logger.info("Processo HD concluído com sucesso.")
94
+ return output_video_path
95
+
96
+ # Instância Singleton para ser usada em toda a aplicação
97
+ hd_specialist_singleton = HDSpecialist()