aducsdr commited on
Commit
9f65008
·
verified ·
1 Parent(s): 70b72ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +221 -41
app.py CHANGED
@@ -1,55 +1,235 @@
1
- # Dockerfile (VERSÃO FINAL COM VÍDEOS DE EXEMPLO)
2
 
3
- # 1. COMEÇAR COM A BASE CORRETA
4
- FROM nvidia/cuda:12.1.1-devel-ubuntu22.04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # 2. INSTALAR DEPENDÊNCIAS DO SISTEMA
7
- ENV DEBIAN_FRONTEND=noninteractive
8
- RUN apt-get update && apt-get install -y --no-install-recommends \
9
- wget \
10
- git \
11
- && apt-get clean \
12
- && rm -rf /var/lib/apt/lists/*
 
 
 
 
 
 
 
13
 
14
- # 3. INSTALAR O MINICONDA
15
- RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \
16
- /bin/bash ~/miniconda.sh -b -p /opt/conda && \
17
- rm ~/miniconda.sh
18
 
19
- # 4. ADICIONAR CONDA AO PATH
20
- ENV PATH /opt/conda/bin:$PATH
 
 
21
 
22
- # 5. ACEITAR OS TERMOS DE SERVIÇO
23
- RUN yes | conda tos accept
 
 
24
 
25
- # 6. ATUALIZAR O CONDA
26
- RUN conda update -n base -c defaults conda
27
 
28
- # 7. CRIAR O AMBIENTE CONDA
29
- COPY environment.yml .
30
- RUN conda env create -f environment.yml && conda clean --all -y
31
 
32
- # 8. INSTALAR FLASH_ATTN DENTRO DO AMBIENTE
33
- RUN conda run -n seedvr pip install "flash_attn==2.5.9.post1" --no-build-isolation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- # 9. DEFINIR O AMBIENTE CONDA COMO PADRÃO
36
- SHELL ["conda", "run", "-n", "seedvr", "/bin/bash", "-c"]
 
 
37
 
38
- # 10. PREPARAR O APLICATIVO
39
- WORKDIR /app
40
- RUN git clone https://github.com/bytedance-seed/SeedVR.git
41
- WORKDIR /app/SeedVR
 
 
 
 
 
 
 
 
42
 
43
- # 11. BAIXAR O MODELO DURANTE A CONSTRUÇÃO
44
- RUN huggingface-cli download ByteDance-Seed/SeedVR2-3B --local-dir ckpts --local-dir-use-symlinks False
 
 
45
 
46
- # 12. !!! NOVO !!! BAIXAR OS VÍDEOS DE EXEMPLO
47
- RUN wget -O 01.mp4 https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/23_1_lq.mp4 && \
48
- wget -O 02.mp4 https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/28_1_lq.mp4 && \
49
- wget -O 03.mp4 https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4
50
 
51
- # 13. COPIAR O CÓDIGO DO NOSSO APP
52
- COPY app.py .
 
53
 
54
- # 14. DEFINIR O COMANDO DE EXECUÇÃO
55
- CMD ["python", "app.py"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py (VERSÃO FINAL E CORRIGIDA)
2
 
3
+ import spaces
4
+ import subprocess
5
+ import os
6
+ import torch
7
+ import mediapy
8
+ from einops import rearrange
9
+ from omegaconf import OmegaConf
10
+ import datetime
11
+ from tqdm import tqdm
12
+ import gc
13
+ import uuid
14
+ import mimetypes
15
+ import torchvision.transforms as T
16
+ from PIL import Image
17
+ from pathlib import Path
18
+ import gradio as gr
19
 
20
+ # --- Módulos do SeedVR (agora que estão no ambiente, podemos importá-los) ---
21
+ from data.image.transforms.divisible_crop import DivisibleCrop
22
+ from data.image.transforms.na_resize import NaResize
23
+ from data.video.transforms.rearrange import Rearrange
24
+ from projects.video_diffusion_sr.color_fix import wavelet_reconstruction
25
+ from torchvision.transforms import Compose, Lambda, Normalize
26
+ from torchvision.io.video import read_video
27
+ from common.distributed import init_torch
28
+ from common.distributed.advanced import init_sequence_parallel
29
+ from projects.video_diffusion_sr.infer import VideoDiffusionInfer
30
+ from common.config import load_config
31
+ from common.distributed.ops import sync_data
32
+ from common.seed import set_seed
33
+ from common.partition import partition_by_size
34
 
35
+ # --- CONFIGURAÇÃO DO AMBIENTE (REMOVIDA) ---
36
+ # REMOVIDO: A instalação do flash-attn e apex já é feita no Dockerfile.
37
+ # REMOVIDO: O download dos checkpoints do modelo já é feito no Dockerfile.
38
+ # REMOVIDO: A configuração de torch.distributed é tratada de forma mais simples.
39
 
40
+ # Verificação para garantir que estamos no diretório certo
41
+ print(f"Diretório de trabalho atual: {os.getcwd()}")
42
+ if not os.path.exists('./projects'):
43
+ print("AVISO: O script parece não estar rodando de dentro do diretório /app/SeedVR. Verifique o WORKDIR no Dockerfile.")
44
 
45
+ # Checa se a correção de cor está disponível
46
+ use_colorfix = os.path.exists("./projects/video_diffusion_sr/color_fix.py")
47
+ if not use_colorfix:
48
+ print('Atenção: Correção de cor (color_fix.py) não disponível!')
49
 
 
 
50
 
51
+ def configure_sequence_parallel(sp_size):
52
+ if sp_size > 1:
53
+ init_sequence_parallel(sp_size)
54
 
55
+ # O decorador @spaces.GPU garante que a função rode na GPU e gerencia a duração
56
+ @spaces.GPU(duration=120)
57
+ def configure_runner(sp_size):
58
+ config_path = os.path.join('./configs_3b', 'main.yaml')
59
+ config = load_config(config_path)
60
+ runner = VideoDiffusionInfer(config)
61
+ OmegaConf.set_readonly(runner.config, False)
62
+
63
+ # Inicializa o torch para um único processo
64
+ os.environ["MASTER_ADDR"] = "127.0.0.1"
65
+ os.environ["MASTER_PORT"] = "12355"
66
+ if "RANK" not in os.environ:
67
+ os.environ["RANK"] = "0"
68
+ if "WORLD_SIZE" not in os.environ:
69
+ os.environ["WORLD_SIZE"] = "1"
70
+
71
+ init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
72
+ configure_sequence_parallel(sp_size)
73
+
74
+ # Os checkpoints estão no diretório ckpts, conforme baixado pelo Dockerfile
75
+ runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
76
+ runner.configure_vae_model(checkpoint_path='./ckpts/ema_vae.pth')
77
+
78
+ if hasattr(runner.vae, "set_memory_limit"):
79
+ runner.vae.set_memory_limit(**runner.config.vae.memory_limit)
80
+ return runner
81
 
82
+ @spaces.GPU(duration=120)
83
+ def generation_step(runner, text_embeds_dict, cond_latents):
84
+ def _move_to_cuda(x):
85
+ return [i.to(torch.device("cuda")) for i in x]
86
 
87
+ noises = [torch.randn_like(latent) for latent in cond_latents]
88
+ aug_noises = [torch.randn_like(latent) for latent in cond_latents]
89
+ noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0)
90
+ noises, aug_noises, cond_latents = list(map(_move_to_cuda, (noises, aug_noises, cond_latents)))
91
+
92
+ cond_noise_scale = 0.1
93
+ def _add_noise(x, aug_noise):
94
+ t = (torch.tensor([1000.0], device=torch.device("cuda")) * cond_noise_scale)
95
+ shape = torch.tensor(x.shape[1:], device=torch.device("cuda"))[None]
96
+ t = runner.timestep_transform(t, shape)
97
+ x = runner.schedule.forward(x, aug_noise, t)
98
+ return x
99
 
100
+ conditions = [
101
+ runner.get_condition(noise, task="sr", latent_blur=_add_noise(latent_blur, aug_noise))
102
+ for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents)
103
+ ]
104
 
105
+ with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True):
106
+ video_tensors = runner.inference(noises=noises, conditions=conditions, dit_offload=False, **text_embeds_dict)
 
 
107
 
108
+ samples = [rearrange(video, "c t h w -> t c h w") for video in video_tensors]
109
+ del video_tensors
110
+ return samples
111
 
112
+ @spaces.GPU(duration=120)
113
+ def generation_loop(video_path, seed, fps_out, batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, res_h=720, res_w=1280, sp_size=1):
114
+ # O Gradio passa o caminho do arquivo temporário
115
+ if video_path is None:
116
+ raise gr.Error("Por favor, faça o upload de um arquivo de vídeo ou imagem.")
117
+
118
+ runner = configure_runner(sp_size)
119
+
120
+ def _extract_text_embeds():
121
+ text_pos_embeds = torch.load('pos_emb.pt')
122
+ text_neg_embeds = torch.load('neg_emb.pt')
123
+ return [{"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]}]
124
+
125
+ def cut_videos(videos, sp_size):
126
+ if videos.size(1) > 121: videos = videos[:, :121]
127
+ t = videos.size(1)
128
+ if t <= 4 * sp_size:
129
+ padding = torch.cat([videos[:, -1].unsqueeze(1)] * (4 * sp_size - t + 1), dim=1)
130
+ return torch.cat([videos, padding], dim=1)
131
+ if (t - 1) % (4 * sp_size) == 0: return videos
132
+ padding = torch.cat([videos[:, -1].unsqueeze(1)] * (4 * sp_size - ((t - 1) % (4 * sp_size))), dim=1)
133
+ return torch.cat([videos, padding], dim=1)
134
+
135
+ runner.config.diffusion.cfg.scale = cfg_scale
136
+ runner.config.diffusion.cfg.rescale = cfg_rescale
137
+ runner.config.diffusion.timesteps.sampling.steps = sample_steps
138
+ runner.configure_diffusion()
139
+
140
+ set_seed(seed % (2**32), same_across_ranks=True)
141
+ os.makedirs('output/', exist_ok=True)
142
+
143
+ original_videos_local = [[os.path.basename(video_path)]]
144
+ positive_prompts_embeds = _extract_text_embeds()
145
+
146
+ video_transform = Compose([
147
+ NaResize(resolution=(res_h * res_w) ** 0.5, mode="area", downsample_only=False),
148
+ Lambda(lambda x: torch.clamp(x, 0.0, 1.0)),
149
+ DivisibleCrop((16, 16)),
150
+ Normalize(0.5, 0.5),
151
+ Rearrange("t c h w -> c t h w"),
152
+ ])
153
+
154
+ for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)):
155
+ cond_latents = []
156
+ media_type, _ = mimetypes.guess_type(video_path)
157
+ is_image = media_type and media_type.startswith("image")
158
+ is_video = media_type and media_type.startswith("video")
159
+
160
+ if is_video:
161
+ video_frames = read_video(video_path, output_format="TCHW")[0] / 255.0
162
+ if video_frames.size(0) > 121: video_frames = video_frames[:121]
163
+ output_filename = str(uuid.uuid4()) + '.mp4'
164
+ elif is_image:
165
+ img = Image.open(video_path).convert("RGB")
166
+ video_frames = T.ToTensor()(img).unsqueeze(0)
167
+ output_filename = str(uuid.uuid4()) + '.png'
168
+ else:
169
+ raise gr.Error("Formato de arquivo não suportado. Use vídeo ou imagem.")
170
+
171
+ output_dir = os.path.join('output', output_filename)
172
+ cond_latents.append(video_transform(video_frames.to(torch.device("cuda"))))
173
+
174
+ ori_lengths = [v.size(1) for v in cond_latents]
175
+ input_videos = cond_latents
176
+ if is_video: cond_latents = [cut_videos(v, sp_size) for v in cond_latents]
177
+
178
+ cond_latents = runner.vae_encode(cond_latents)
179
+
180
+ for i, emb in enumerate(text_embeds["texts_pos"]): text_embeds["texts_pos"][i] = emb.to(torch.device("cuda"))
181
+ for i, emb in enumerate(text_embeds["texts_neg"]): text_embeds["texts_neg"][i] = emb.to(torch.device("cuda"))
182
+
183
+ samples = generation_step(runner, text_embeds, cond_latents=cond_latents)
184
+ del cond_latents
185
+
186
+ for path, input_vid, sample, ori_length in zip(videos, input_videos, samples, ori_lengths):
187
+ if ori_length < sample.shape[0]: sample = sample[:ori_length]
188
+ input_vid = rearrange(input_vid, "c t h w -> t c h w")
189
+ if use_colorfix: sample = wavelet_reconstruction(sample.cpu(), input_vid[:sample.size(0)].cpu())
190
+ else: sample = sample.cpu()
191
+
192
+ sample = rearrange(sample, "t c h w -> t h w c")
193
+ sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round().to(torch.uint8).numpy()
194
+
195
+ if is_image:
196
+ mediapy.write_image(output_dir, sample[0])
197
+ else:
198
+ mediapy.write_video(output_dir, sample, fps=fps_out)
199
+
200
+ gc.collect()
201
+ torch.cuda.empty_cache()
202
+
203
+ # Retorna os valores para os componentes corretos da UI
204
+ if is_image:
205
+ return output_dir, None, output_dir
206
+ else:
207
+ return None, output_dir, output_dir
208
+
209
+ # --- Interface Gradio ---
210
+ with gr.Blocks(title="SeedVR2: One-Step Video Restoration") as demo:
211
+ gr.HTML(...) # Mantido como no original
212
+
213
+ with gr.Row():
214
+ # CORRIGIDO: gr.File para gr.Video, que passa um 'filepath' por padrão
215
+ input_file = gr.Video(label="Upload image or video")
216
+ seed = gr.Number(label="Seeds", value=666)
217
+ fps = gr.Number(label="fps", value=24)
218
+
219
+ with gr.Row():
220
+ output_image = gr.Image(label="Output_Image")
221
+ output_video = gr.Video(label="Output_Video")
222
+ download_link = gr.File(label="Download the output")
223
+
224
+ run_button = gr.Button("Run")
225
+ run_button.click(
226
+ fn=generation_loop,
227
+ inputs=[input_file, seed, fps],
228
+ outputs=[output_image, output_video, download_link]
229
+ )
230
+
231
+ gr.Examples(...) # Mantido como no original
232
+ gr.HTML(...) # Mantido como no original
233
+
234
+ demo.queue(max_size=10)
235
+ demo.launch()