eeuuia committed on
Commit
e6712fd
·
verified ·
1 Parent(s): 29dad4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -41
app.py CHANGED
@@ -15,8 +15,6 @@ import cv2
15
  import shutil
16
  import glob
17
  from pathlib import Path
18
- from diffusers import AutoModel
19
- from diffusers.hooks import apply_group_offloading
20
 
21
  import warnings
22
  import logging
@@ -37,55 +35,34 @@ dtype = torch.bfloat16
37
  device = "cuda" if torch.cuda.is_available() else "cpu"
38
 
39
 
40
- # 1. Definir o repositório base
41
- base_model_repo = "Lightricks/LTX-Video"
42
 
43
- # 2. Carregar o Transformer separadamente para aplicar o casting FP8
44
- print("Carregando Transformer para otimização FP8...")
45
- transformer = AutoModel.from_pretrained(
46
- base_model_repo,
47
- subfolder="transformer",
48
- torch_dtype=dtype
49
- )
50
- # Habilita a conversão dinâmica para FP8 (requer hardware compatível para funcionar)
51
- print("Habilitando layerwise casting para FP8...")
52
- transformer.enable_layerwise_casting(
53
- storage_dtype=torch.float8_e4m3fn, compute_dtype=dtype
54
- )
55
-
56
- print("Desativando 'dynamic shifting' para compatibilidade com a pipeline.")
57
- transformer.config.use_dynamic_shifting = False
58
-
59
- # 3. Carregar a pipeline completa, injetando o Transformer já otimizado
60
- print(f"Carregando a arquitetura da pipeline de {base_model_repo}...")
61
- pipeline = LTXConditionPipeline.from_pretrained(
62
- base_model_repo,
63
- transformer=transformer, # Injeta o transformer otimizado
64
- torch_dtype=dtype,
65
  cache_dir=os.getenv("HF_HOME_CACHE"),
66
  token=os.getenv("HF_TOKEN"),
67
  )
68
 
69
- # 4. Carregar o upsampler (seu repositório é separado e está correto)
70
- print("Carregando upsampler...")
 
 
 
 
 
 
 
71
  pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
72
  "Lightricks/ltxv-spatial-upscaler-0.9.7",
73
  cache_dir=os.getenv("HF_HOME_CACHE"),
74
- vae=pipeline.vae,
75
- torch_dtype=dtype
76
  )
77
 
78
-
79
- # 5. Aplicar o descarregamento de grupos para economizar VRAM
80
- print("Aplicando otimizações de group-offloading para economizar VRAM...")
81
- onload_device = torch.device("cuda")
82
- offload_device = torch.device("cpu")
83
- # O Transformer já tem um método integrado
84
- pipeline.transformer.enable_group_offload(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
85
- # Para os outros componentes, usamos a função auxiliar
86
- apply_group_offloading(pipeline.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
87
- apply_group_offloading(pipeline.vae, onload_device=onload_device, offload_type="leaf_level")
88
-
89
 
90
 
91
  current_dir = Path(__file__).parent
 
15
  import shutil
16
  import glob
17
  from pathlib import Path
 
 
18
 
19
  import warnings
20
  import logging
 
35
  device = "cuda" if torch.cuda.is_available() else "cpu"
36
 
37
 
38
# Single-file FP8-quantized distilled checkpoint for LTX-Video 0.9.8 (13B).
single_file_url = "https://huggingface.co/Lightricks/LTX-Video/resolve/main/ltxv-13b-0.9.8-distilled-fp8.safetensors"

# Load the main condition pipeline from the single-file checkpoint.
# NOTE(review): the kwarg here was `dtype=torch.bfloat16`; diffusers'
# single-file loader documents `torch_dtype`, which is also what the
# upsampler call below uses — made consistent and bound to the
# module-level `dtype` constant (torch.bfloat16, same value).
# For true FP8 compute, torch.float8_e4m3fn may be required on
# supported hardware — TODO confirm.
pipeline = LTXConditionPipeline.from_single_file(
    single_file_url,
    offload_state_dict=False,
    torch_dtype=dtype,
    cache_dir=os.getenv("HF_HOME_CACHE"),
    token=os.getenv("HF_TOKEN"),
)

# Alternative (repo-based) load, kept for reference:
#   LTXConditionPipeline.from_pretrained("Lightricks/LTX-Video-0.9.8-13B-distilled", ...)

# Spatial latent upsampler. It is given the main pipeline's VAE instance
# so both pipelines encode/decode latents with identical weights.
pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
    "Lightricks/ltxv-spatial-upscaler-0.9.7",
    cache_dir=os.getenv("HF_HOME_CACHE"),
    vae=pipeline.vae,
    torch_dtype=dtype,
)

# Place both pipelines on the selected device (CUDA when available) and
# enable VAE tiling to reduce peak VRAM use when decoding large frames.
pipeline.to(device)
pipe_upsample.to(device)
pipeline.vae.enable_tiling()
 
 
 
 
 
 
 
 
66
 
67
 
68
  current_dir = Path(__file__).parent