aducsdr committed
Commit 4785cd8 · verified · 1 Parent(s): 3fe558a

Delete SeedVR/app.py

Files changed (1)
  1. SeedVR/app.py +0 -235
SeedVR/app.py DELETED
@@ -1,235 +0,0 @@
- # app.py (FINAL AND CORRECTED VERSION)
-
- import spaces
- import subprocess
- import os
- import torch
- import mediapy
- from einops import rearrange
- from omegaconf import OmegaConf
- import datetime
- from tqdm import tqdm
- import gc
- import uuid
- import mimetypes
- import torchvision.transforms as T
- from PIL import Image
- from pathlib import Path
- import gradio as gr
-
- # --- SeedVR modules (now that they are in the environment, we can import them) ---
- from data.image.transforms.divisible_crop import DivisibleCrop
- from data.image.transforms.na_resize import NaResize
- from data.video.transforms.rearrange import Rearrange
- from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
- from torchvision.transforms import Compose, Lambda, Normalize
- from torchvision.io.video import read_video
- from common.distributed import init_torch
- from common.distributed.advanced import init_sequence_parallel
- from projects.video_diffusion_sr.infer import VideoDiffusionInfer
- from common.config import load_config
- from common.distributed.ops import sync_data
- from common.seed import set_seed
- from common.partition import partition_by_size
-
- # --- ENVIRONMENT SETUP (REMOVED) ---
- # REMOVED: the flash-attn and apex installation is already done in the Dockerfile.
- # REMOVED: the model checkpoint download is already done in the Dockerfile.
- # REMOVED: the torch.distributed configuration is handled in a simpler way.
-
- # Check to make sure we are in the right directory
- print(f"Current working directory: {os.getcwd()}")
- if not os.path.exists('./projects'):
-     print("WARNING: The script does not appear to be running from inside the /app/SeedVR directory. Check the WORKDIR in the Dockerfile.")
-
- # Check whether color correction is available
- use_colorfix = os.path.exists("./projects/video_diffusion_sr/color_fix.py")
- if not use_colorfix:
-     print('Warning: color correction (color_fix.py) is not available!')
-
-
- def configure_sequence_parallel(sp_size):
-     if sp_size > 1:
-         init_sequence_parallel(sp_size)
-
- # The @spaces.GPU decorator ensures the function runs on the GPU and manages the allotted duration
- @spaces.GPU(duration=120)
- def configure_runner(sp_size):
-     config_path = os.path.join('./configs_3b', 'main.yaml')
-     config = load_config(config_path)
-     runner = VideoDiffusionInfer(config)
-     OmegaConf.set_readonly(runner.config, False)
-
-     # Initialize torch for a single process
-     os.environ["MASTER_ADDR"] = "127.0.0.1"
-     os.environ["MASTER_PORT"] = "12355"
-     if "RANK" not in os.environ:
-         os.environ["RANK"] = "0"
-     if "WORLD_SIZE" not in os.environ:
-         os.environ["WORLD_SIZE"] = "1"
-
-     init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
-     configure_sequence_parallel(sp_size)
-
-     # The checkpoints are in the ckpts directory, as downloaded by the Dockerfile
-     runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
-     runner.configure_vae_model(checkpoint_path='./ckpts/ema_vae.pth')
-
-     if hasattr(runner.vae, "set_memory_limit"):
-         runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
-     return runner
-
- @spaces.GPU(duration=120)
- def generation_step(runner, text_embeds_dict, cond_latents):
-     def _move_to_cuda(x):
-         return [i.to(torch.device("cuda")) for i in x]
-
-     noises = [torch.randn_like(latent) for latent in cond_latents]
-     aug_noises = [torch.randn_like(latent) for latent in cond_latents]
-     noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
-     noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
-
-     cond_noise_scale = 0.1
-     def _add_noise(x, aug_noise):
-         t = (torch.tensor([1000.0], device=torch.device("cuda")) * cond_noise_scale)
-         shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
-         t = runner.timestep_transform(t, shape)
-         x = runner.schedule.forward(x, aug_noise, t)
-         return x
-
-     conditions = [
-         runner.get_condition(noise, task="sr", latent_blur=_add_noise(latent_blur, aug_noise))
-         for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents)
-     ]
-
-     with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
-         video_tensors = runner.inference(noises=noises, conditions=conditions, dit_offload=False, **text_embeds_dict)
-
-     samples = [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
-     del video_tensors
-     return samples
-
- @spaces.GPU(duration=120)
- def generation_loop(video_path, seed, fps_out, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=720, res_w=1280, sp_size=1):
-     # Gradio passes the path of the temporary file
-     if video_path is None:
-         raise gr.Error("Please upload a video or image file.")
-
-     runner = configure_runner(sp_size)
-
-     def _extract_text_embeds():
-         text_pos_embeds = torch.load('pos_emb.pt')
-         text_neg_embeds = torch.load('neg_emb.pt')
-         return [{"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}]
-
-     def cut_videos(videos, sp_size):
-         if videos.size(1) > 121: videos = videos[:, :121]
-         t = videos.size(1)
-         if t <= 4 * sp_size:
-             padding = torch.cat([videos[:, -1].unsqueeze(1)] * (4 * sp_size - t + 1), dim=1)
-             return torch.cat([videos, padding], dim=1)
-         if (t - 1) % (4 * sp_size) == 0: return videos
-         padding = torch.cat([videos[:, -1].unsqueeze(1)] * (4 * sp_size - ((t - 1) % (4 * sp_size))), dim=1)
-         return torch.cat([videos, padding], dim=1)
-
-     runner.config.diffusion.cfg.scale = cfg_scale
-     runner.config.diffusion.cfg.rescale = cfg_rescale
-     runner.config.diffusion.timesteps.sampling.steps = sample_steps
-     runner.configure_diffusion()
-
-     set_seed(seed % (2**32), same_across_ranks=True)
-     os.makedirs('output/', exist_ok=True)
-
-     original_videos_local = [[os.path.basename(video_path)]]
-     positive_prompts_embeds = _extract_text_embeds()
-
-     video_transform = Compose([
-         NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
-         Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
-         DivisibleCrop((16, 16)),
-         Normalize(0.5, 0.5),
-         Rearrange("t c h w -> c t h w"),
-     ])
-
-     for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)):
-         cond_latents = []
-         media_type, _ = mimetypes.guess_type(video_path)
-         is_image = media_type and media_type.startswith("image")
-         is_video = media_type and media_type.startswith("video")
-
-         if is_video:
-             video_frames = read_video(video_path, output_format="TCHW")[0] / 255.0
-             if video_frames.size(0) > 121: video_frames = video_frames[:121]
-             output_filename = str(uuid.uuid4()) + '.mp4'
-         elif is_image:
-             img = Image.open(video_path).convert("RGB")
-             video_frames = T.ToTensor()(img).unsqueeze(0)
-             output_filename = str(uuid.uuid4()) + '.png'
-         else:
-             raise gr.Error("Unsupported file format. Use a video or an image.")
-
-         output_dir = os.path.join('output', output_filename)
-         cond_latents.append(video_transform(video_frames.to(torch.device("cuda"))))
-
-         ori_lengths = [v.size(1) for v in cond_latents]
-         input_videos = cond_latents
-         if is_video: cond_latents = [cut_videos(v, sp_size) for v in cond_latents]
-
-         cond_latents = runner.vae_encode(cond_latents)
-
-         for i, emb in enumerate(text_embeds["texts_pos"]): text_embeds["texts_pos"][i] = emb.to(torch.device("cuda"))
-         for i, emb in enumerate(text_embeds["texts_neg"]): text_embeds["texts_neg"][i] = emb.to(torch.device("cuda"))
-
-         samples = generation_step(runner, text_embeds, cond_latents=cond_latents)
-         del cond_latents
-
-         for path, input_vid, sample, ori_length in zip(videos, input_videos, samples, ori_lengths):
-             if ori_length < sample.shape[0]: sample = sample[:ori_length]
-             input_vid = rearrange(input_vid, "c t h w -> t c h w")
-             if use_colorfix: sample = wavelet_reconstruction(sample.cpu(), input_vid[:sample.size(0)].cpu())
-             else: sample = sample.cpu()
-
-             sample = rearrange(sample, "t c h w -> t h w c")
-             sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round().to(torch.uint8).numpy()
-
-             if is_image:
-                 mediapy.write_image(output_dir, sample[0])
-             else:
-                 mediapy.write_video(output_dir, sample, fps=fps_out)
-
-     gc.collect()
-     torch.cuda.empty_cache()
-
-     # Return the values to the correct UI components
-     if is_image:
-         return output_dir, None, output_dir
-     else:
-         return None, output_dir, output_dir
-
- # --- Gradio Interface ---
- with gr.Blocks(title="SeedVR2: One-Step Video Restoration") as demo:
-     gr.HTML(...)  # Kept as in the original
-
-     with gr.Row():
-         # FIXED: gr.File changed to gr.Video, which passes a 'filepath' by default
-         input_file = gr.Video(label="Upload image or video")
-         seed = gr.Number(label="Seeds", value=666)
-         fps = gr.Number(label="fps", value=24)
-
-     with gr.Row():
-         output_image = gr.Image(label="Output_Image")
-         output_video = gr.Video(label="Output_Video")
-         download_link = gr.File(label="Download the output")
-
-     run_button = gr.Button("Run")
-     run_button.click(
-         fn=generation_loop,
-         inputs=[input_file, seed, fps],
-         outputs=[output_image, output_video, download_link]
-     )
-
-     gr.Examples(...)  # Kept as in the original
-     gr.HTML(...)  # Kept as in the original
-
- demo.queue(max_size=10)
- demo.launch()