|
|
|
|
|
|
|
|
import spaces |
|
|
import subprocess |
|
|
import os |
|
|
import torch |
|
|
import mediapy |
|
|
from einops import rearrange |
|
|
from omegaconf import OmegaConf |
|
|
import datetime |
|
|
from tqdm import tqdm |
|
|
import gc |
|
|
import uuid |
|
|
import mimetypes |
|
|
import torchvision.transforms as T |
|
|
from PIL import Image |
|
|
from pathlib import Path |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
from data.image.transforms.divisible_crop import DivisibleCrop |
|
|
from data.image.transforms.na_resize import NaResize |
|
|
from data.video.transforms.rearrange import Rearrange |
|
|
from projects.video_diffusion_sr.color_fix import wavelet_reconstruction |
|
|
from torchvision.transforms import Compose, Lambda, Normalize |
|
|
from torchvision.io.video import read_video |
|
|
from common.distributed import init_torch |
|
|
from common.distributed.advanced import init_sequence_parallel |
|
|
from projects.video_diffusion_sr.infer import VideoDiffusionInfer |
|
|
from common.config import load_config |
|
|
from common.distributed.ops import sync_data |
|
|
from common.seed import set_seed |
|
|
from common.partition import partition_by_size |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Startup diagnostics: report where the app is running from and warn when the
# SeedVR project tree is not reachable relative to that directory.
print(f"Diretório de trabalho atual: {os.getcwd()}")
if not os.path.exists("./projects"):
    print("AVISO: O script parece não estar rodando de dentro do diretório /app/SeedVR. Verifique o WORKDIR no Dockerfile.")

# Colour correction (wavelet_reconstruction) is applied only when the
# color_fix module file is present on disk.
# NOTE(review): color_fix is also imported unconditionally at the top of the
# file, so once that import succeeds this flag is effectively always True.
use_colorfix = os.path.exists("./projects/video_diffusion_sr/color_fix.py")
if not use_colorfix:
    print('Atenção: Correção de cor (color_fix.py) não disponível!')
|
|
|
|
|
|
|
|
def configure_sequence_parallel(sp_size):
    """Initialise sequence parallelism for ``sp_size`` greater than 1.

    Any size that is not greater than 1 is a no-op. Always returns ``None``.
    """
    if not sp_size > 1:
        return
    init_sequence_parallel(sp_size)
|
|
|
|
|
|
|
|
@spaces.GPU(duration=120)
def configure_runner(sp_size):
    """Build a VideoDiffusionInfer runner for SeedVR2-3B with DiT and VAE loaded.

    Loads ``./configs_3b/main.yaml``, makes the config writable, prepares the
    single-process distributed environment, and configures sequence
    parallelism. It then loads the DiT (``./ckpts/seedvr2_ema_3b.pth``) and
    VAE (``./ckpts/ema_vae.pth``) checkpoints.

    Args:
        sp_size: Sequence-parallel size forwarded to configure_sequence_parallel.

    Returns:
        The configured runner.

    NOTE(review): generation_loop calls this once per request, so the
    checkpoints are reloaded every time; caching the runner would avoid that.
    """
    import torch.distributed as dist

    config_path = os.path.join('./configs_3b', 'main.yaml')
    config = load_config(config_path)
    runner = VideoDiffusionInfer(config)
    # Callers assign into runner.config (cfg scale, sampling steps), so unlock it.
    OmegaConf.set_readonly(runner.config, False)

    # Single-process defaults for the distributed rendezvous. Values already set
    # by a launcher (e.g. torchrun) take precedence instead of being clobbered,
    # consistent for all four variables.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "12355")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")

    # This function runs on every request. torch.distributed refuses to create
    # the default process group twice; presumably init_torch creates it (it is
    # fed the rendezvous env vars and a timeout) -- confirm. So only run it once.
    if not (dist.is_available() and dist.is_initialized()):
        init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
    configure_sequence_parallel(sp_size)

    runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
    runner.configure_vae_model(checkpoint_path='./ckpts/ema_vae.pth')

    # Apply the configured VAE memory limit when the VAE supports it.
    if hasattr(runner.vae, "set_memory_limit"):
        runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
    return runner
|
|
|
|
|
@spaces.GPU(duration=120)
def generation_step(runner, text_embeds_dict, cond_latents):
    """Run one SeedVR2 super-resolution sampling pass.

    Args:
        runner: Configured VideoDiffusionInfer instance.
        text_embeds_dict: Keyword arguments forwarded to ``runner.inference``
            (the caller passes ``texts_pos`` / ``texts_neg`` embedding lists).
        cond_latents: VAE-encoded condition latents, one per input.

    Returns:
        One tensor per input, rearranged from ``c t h w`` to ``t c h w``.
    """
    cuda = torch.device("cuda")

    # Draw every sampling noise first and then every augmentation noise, so the
    # random stream is consumed in a fixed order.
    noises = [torch.randn_like(latent) for latent in cond_latents]
    aug_noises = [torch.randn_like(latent) for latent in cond_latents]
    synced = sync_data((noises, aug_noises, cond_latents), 0)
    noises, aug_noises, cond_latents = synced
    noises, aug_noises, cond_latents = [
        [tensor.to(cuda) for tensor in group]
        for group in (noises, aug_noises, cond_latents)
    ]

    cond_noise_scale = 0.1

    def _noisy_condition(latent, aug_noise):
        # Push the condition latent forward along the runner's schedule to
        # timestep 1000 * cond_noise_scale (after the runner's timestep transform).
        timestep = torch.tensor([1000.0], device=cuda) * cond_noise_scale
        latent_shape = torch.tensor(latent.shape[1:], device=cuda)[None]
        timestep = runner.timestep_transform(timestep, latent_shape)
        return runner.schedule.forward(latent, aug_noise, timestep)

    conditions = []
    for noise, aug_noise, latent in zip(noises, aug_noises, cond_latents):
        blurred = _noisy_condition(latent, aug_noise)
        conditions.append(runner.get_condition(noise, task="sr", latent_blur=blurred))

    with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
        video_tensors = runner.inference(
            noises=noises,
            conditions=conditions,
            dit_offload=False,
            **text_embeds_dict,
        )

    samples = [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
    del video_tensors
    return samples
|
|
|
|
|
@spaces.GPU(duration=120)
def generation_loop(video_path, seed, fps_out, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=720, res_w=1280, sp_size=1):
    """Restore an uploaded video or image with SeedVR2 and save it under ``output/``.

    Args:
        video_path: Path of the uploaded file. Its MIME type, guessed from the
            file name, decides whether it is processed as a video or an image.
        seed: Random seed. It is cast to int and reduced modulo 2**32.
        fps_out: Frame rate used when writing a video result.
        batch_size: Unused; kept for interface compatibility.
        cfg_scale: Classifier-free guidance scale written into the runner config.
        cfg_rescale: Classifier-free guidance rescale written into the runner config.
        sample_steps: Number of diffusion sampling steps.
        res_h: Target height; NaResize receives sqrt(res_h * res_w).
        res_w: Target width; NaResize receives sqrt(res_h * res_w).
        sp_size: Sequence-parallel size; videos are padded so their frame
            count t satisfies (t - 1) % (4 * sp_size) == 0.

    Returns:
        A triple for the (image, video, download) Gradio outputs:
        ``(path, None, path)`` for images and ``(None, path, path)`` for videos.

    Raises:
        gr.Error: If nothing was uploaded, the file is neither an image nor a
            video, or the video has no decodable frames.
    """
    if video_path is None:
        raise gr.Error("Por favor, faça o upload de um arquivo de vídeo ou imagem.")

    # Classify the upload before loading the model so unsupported files fail
    # fast instead of after a full checkpoint load.
    media_type, _ = mimetypes.guess_type(video_path)
    is_image = bool(media_type) and media_type.startswith("image")
    is_video = bool(media_type) and media_type.startswith("video")
    if not (is_image or is_video):
        raise gr.Error("Formato de arquivo não suportado. Use vídeo ou imagem.")

    runner = configure_runner(sp_size)

    def _extract_text_embeds():
        # Prompt embeddings are loaded from precomputed files, not encoded here.
        text_pos_embeds = torch.load('pos_emb.pt')
        text_neg_embeds = torch.load('neg_emb.pt')
        return [{"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}]

    def cut_videos(videos, sp_size):
        # videos is "c t h w". Keep at most 121 frames, then repeat the last
        # frame until (t - 1) is a multiple of 4 * sp_size. Clips that are too
        # short are padded to 4 * sp_size + 1 frames.
        if videos.size(1) > 121:
            videos = videos[:, :121]
        t = videos.size(1)
        step = 4 * sp_size
        if t <= step:
            pad = step - t + 1
        elif (t - 1) % step == 0:
            return videos
        else:
            pad = step - ((t - 1) % step)
        padding = torch.cat([videos[:, -1].unsqueeze(1)] * pad, dim=1)
        return torch.cat([videos, padding], dim=1)

    runner.config.diffusion.cfg.scale = cfg_scale
    runner.config.diffusion.cfg.rescale = cfg_rescale
    runner.config.diffusion.timesteps.sampling.steps = sample_steps
    runner.configure_diffusion()

    # gr.Number delivers a float by default; seed with an integer.
    set_seed(int(seed) % (2**32), same_across_ranks=True)
    os.makedirs('output/', exist_ok=True)

    original_videos_local = [[os.path.basename(video_path)]]
    positive_prompts_embeds = _extract_text_embeds()

    video_transform = Compose([
        NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
        Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
        DivisibleCrop((16, 16)),
        Normalize(0.5, 0.5),
        Rearrange("t c h w -> c t h w"),
    ])

    for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)):
        if is_video:
            video_frames = read_video(video_path, output_format="TCHW")[0] / 255.0
            if video_frames.size(0) == 0:
                # Without this, the last-frame padding / transforms fail obscurely.
                raise gr.Error("Não foi possível ler quadros do vídeo enviado.")
            video_frames = video_frames[:121]
            output_filename = str(uuid.uuid4()) + '.mp4'
        else:
            # Close the file handle once the pixels are decoded.
            with Image.open(video_path) as img:
                rgb = img.convert("RGB")
            video_frames = T.ToTensor()(rgb).unsqueeze(0)
            output_filename = str(uuid.uuid4()) + '.png'

        output_path = os.path.join('output', output_filename)
        cond_latents = [video_transform(video_frames.to(torch.device("cuda")))]

        ori_lengths = [v.size(1) for v in cond_latents]
        # Keep the un-padded, pixel-space inputs for colour correction below.
        input_videos = cond_latents
        if is_video:
            cond_latents = [cut_videos(v, sp_size) for v in cond_latents]

        cond_latents = runner.vae_encode(cond_latents)

        for key in ("texts_pos", "texts_neg"):
            text_embeds[key] = [emb.to(torch.device("cuda")) for emb in text_embeds[key]]

        samples = generation_step(runner, text_embeds, cond_latents=cond_latents)
        del cond_latents

        for input_vid, sample, ori_length in zip(input_videos, samples, ori_lengths):
            # Drop frames that were added by cut_videos padding.
            sample = sample[:ori_length]
            input_vid = rearrange(input_vid, "c t h w -> t c h w")
            if use_colorfix:
                sample = wavelet_reconstruction(sample.cpu(), input_vid[:sample.size(0)].cpu())
            else:
                sample = sample.cpu()

            # Map from [-1, 1] to uint8 [0, 255] in "t h w c" layout for mediapy.
            sample = rearrange(sample, "t c h w -> t h w c")
            sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round().to(torch.uint8).numpy()

            if is_image:
                mediapy.write_image(output_path, sample[0])
            else:
                mediapy.write_video(output_path, sample, fps=fps_out)

        gc.collect()
        torch.cuda.empty_cache()

    if is_image:
        return output_path, None, output_path
    return None, output_path, output_path
|
|
|
|
|
|
|
|
# Gradio UI: one upload (image or video) plus seed/fps controls, wired to
# generation_loop, whose (image, video, download) triple fills the outputs.
with gr.Blocks(title="SeedVR2: One-Step Video Restoration") as demo:
    gr.HTML(...)

    with gr.Row():
        # generation_loop has both an image and a video branch, so accept either
        # kind of upload. gr.Video is a video-only upload component.
        # type="filepath" hands generation_loop a path string.
        input_file = gr.File(
            label="Upload image or video",
            file_types=["image", "video"],
            type="filepath",
        )
        # precision=0 makes Gradio pass an int rather than a float seed.
        seed = gr.Number(label="Seeds", value=666, precision=0)
        fps = gr.Number(label="fps", value=24)

    with gr.Row():
        output_image = gr.Image(label="Output_Image")
        output_video = gr.Video(label="Output_Video")
        download_link = gr.File(label="Download the output")

    run_button = gr.Button("Run")
    run_button.click(
        fn=generation_loop,
        inputs=[input_file, seed, fps],
        outputs=[output_image, output_video, download_link],
    )

    gr.Examples(...)
    gr.HTML(...)

# At most 10 requests wait in the queue.
demo.queue(max_size=10)
demo.launch()