|
|
|
|
|
import torch |
|
|
import imageio |
|
|
import os |
|
|
import logging |
|
|
from PIL import Image |
|
|
from diffusers import SeedVR2Pipeline |
|
|
from tqdm import tqdm |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class HDSpecialist: |
|
|
""" |
|
|
Implementa o Especialista HD (Δ+) da arquitetura ADUC-SDR. |
|
|
Utiliza o modelo SeedVR2 para realizar a restauração e o aprimoramento |
|
|
de vídeo, atuando como a etapa final de pós-produção. |
|
|
""" |
|
|
def __init__(self): |
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
self.pipeline = None |
|
|
self.model_id = "ByteDance-Seed/SeedVR2-3B" |
|
|
logger.info(f"Especialista HD inicializado. Aguardando carregamento do modelo {self.model_id}...") |
|
|
|
|
|
def _load_pipeline(self): |
|
|
"""Carrega o pipeline do modelo sob demanda para economizar memória.""" |
|
|
if self.pipeline is None: |
|
|
logger.info("Carregando o pipeline SeedVR2... Isso pode levar alguns minutos.") |
|
|
try: |
|
|
self.pipeline = SeedVR2Pipeline.from_pretrained( |
|
|
self.model_id, |
|
|
torch_dtype=torch.float16, |
|
|
variant="fp16" |
|
|
).to(self.device) |
|
|
logger.info("Pipeline SeedVR2 carregado com sucesso.") |
|
|
except Exception as e: |
|
|
logger.error(f"Falha ao carregar o pipeline SeedVR2: {e}") |
|
|
raise |
|
|
|
|
|
def process_video(self, input_video_path: str, output_video_path: str, prompt: str, strength: float = 0.8, batch_size: int = 8) -> str: |
|
|
""" |
|
|
Aplica o aprimoramento HD a um vídeo. |
|
|
|
|
|
Args: |
|
|
input_video_path (str): Caminho para o vídeo de entrada. |
|
|
output_video_path (str): Caminho para salvar o vídeo aprimorado. |
|
|
prompt (str): Um prompt para guiar a restauração (pode ser o prompt global). |
|
|
strength (float): Força do efeito de restauração. |
|
|
batch_size (int): Número de frames para processar por lote para gerenciar a VRAM. |
|
|
|
|
|
Returns: |
|
|
str: O caminho para o vídeo aprimorado. |
|
|
""" |
|
|
self._load_pipeline() |
|
|
if not os.path.exists(input_video_path): |
|
|
logger.error(f"Vídeo de entrada não encontrado em: {input_video_path}") |
|
|
raise FileNotFoundError(f"Vídeo de entrada não encontrado: {input_video_path}") |
|
|
|
|
|
logger.info(f"Iniciando processo HD para: {input_video_path}") |
|
|
|
|
|
|
|
|
reader = imageio.get_reader(input_video_path) |
|
|
fps = reader.get_meta_data()['fps'] |
|
|
input_frames = [Image.fromarray(frame) for frame in reader] |
|
|
reader.close() |
|
|
|
|
|
logger.info(f"Vídeo lido. Total de {len(input_frames)} frames a {fps} FPS.") |
|
|
|
|
|
|
|
|
processed_frames = [] |
|
|
|
|
|
|
|
|
for i in tqdm(range(0, len(input_frames), batch_size), desc="Aprimorando frames com SeedVR2"): |
|
|
batch = input_frames[i:i + batch_size] |
|
|
|
|
|
|
|
|
enhanced_batch = self.pipeline( |
|
|
prompt=prompt, |
|
|
frames=batch, |
|
|
num_inference_steps=20, |
|
|
guidance_scale=7.5, |
|
|
strength=strength, |
|
|
).frames |
|
|
|
|
|
processed_frames.extend(enhanced_batch) |
|
|
|
|
|
|
|
|
logger.info(f"Salvando vídeo aprimorado em: {output_video_path}") |
|
|
writer = imageio.get_writer(output_video_path, fps=fps, codec='libx264', quality=9) |
|
|
for frame in processed_frames: |
|
|
writer.append_data(np.array(frame)) |
|
|
writer.close() |
|
|
|
|
|
logger.info("Processo HD concluído com sucesso.") |
|
|
return output_video_path |
|
|
|
|
|
|
|
|
hd_specialist_singleton = HDSpecialist() |