eeuuia committed on
Commit
0701f73
verified
1 Parent(s): 9153ad0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -18
app.py CHANGED
@@ -5,7 +5,7 @@ import tempfile
5
  import os
6
  from torchvision import transforms
7
 
8
- from diffusers import LTXLatentUpsamplePipeline
9
  #from pipeline_ltx_condition_control import LTXConditionPipeline, LTXVideoCondition
10
  from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXConditionPipeline, LTXVideoCondition
11
  from diffusers.utils import export_to_video, load_video
@@ -45,31 +45,72 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
45
  # token=os.getenv("HF_TOKEN"),
46
  #)
47
 
48
- from huggingface_hub import hf_hub_download
49
- from safetensors.torch import load_file as safe_load
 
 
 
50
 
51
- # Baixa exatamente a variante desejada do repo oficial:
52
- weight_path = hf_hub_download(
53
- repo_id="Lightricks/LTX-Video",
54
- filename="ltxv-13b-0.9.8-distilled.safetensors",
55
- #revision=os.getenv("LTXV_REVISION", "8984fa25007f376c1a299016d0957a37a2f797bb")
56
- )
 
 
 
57
 
58
- pipeline = LTXConditionPipeline.from_pretrained(
59
  "Lightricks/LTX-Video",
60
- #revision=os.getenv("LTXV_REVISION", "8984fa25007f376c1a299016d0957a37a2f797bb"), # exemplo: SHA atual do repo oficial
61
- offload_state_dict=False,
62
- torch_dtype=torch.bfloat16,
63
- cache_dir=os.getenv("HF_HOME_CACHE"),
64
- token=os.getenv("HF_TOKEN"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  )
66
 
 
 
 
 
67
 
68
- # Carrega o state_dict e aplica no transformer já criado pelo model_index:
69
- state = safe_load(weight_path)
70
- pipeline.transformer.load_state_dict(state, strict=True)
 
 
 
 
 
 
71
 
72
 
 
 
 
 
 
 
 
 
 
73
 
74
 
75
  pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
 
5
  import os
6
  from torchvision import transforms
7
 
8
+ from diffusers import LTXLatentUpsamplePipeline, AutoModel
9
  #from pipeline_ltx_condition_control import LTXConditionPipeline, LTXVideoCondition
10
  from diffusers.pipelines.ltx.pipeline_ltx_condition import LTXConditionPipeline, LTXVideoCondition
11
  from diffusers.utils import export_to_video, load_video
 
45
  # token=os.getenv("HF_TOKEN"),
46
  #)
47
 
48
# --- Model selection & runtime configuration --------------------------------
import os
from pathlib import Path

from huggingface_hub import hf_hub_download  # was dropped in this commit but is still called below

base_repo = "Lightricks/LTX-Video"
checkpoint_path = "ltxv-13b-0.9.8-distilled.safetensors"
upscaler_repo = "Lightricks/ltxv-spatial-upscaler-0.9.7"
CACHE_DIR = os.getenv("HF_HOME_CACHE")  # None falls back to the default HF cache
FPS = 24

# Single floating-point dtype for every component.
# (The previous revision referenced an undefined name `torch_dtype`.)
torch_dtype = torch.bfloat16

# 2. Download the distilled single-file checkpoint from the base repo.
# NOTE(review): this checkpoint is downloaded and validated but its weights are
# never loaded anywhere below — the transformer is read from the repo's
# "transformer" subfolder instead. Confirm whether the distilled weights should
# still be applied (the pre-refactor code loaded them via load_state_dict).
print(f"=== Baixando snapshot do repositório base: {base_repo} ===")
ckpt_path = Path(
    hf_hub_download(repo_id=base_repo, filename=checkpoint_path, cache_dir=CACHE_DIR)
)
if not ckpt_path.is_file():
    raise FileNotFoundError(f"Main checkpoint file not found: {ckpt_path}")

# 3. Load each pipeline component explicitly.
print("=== Carregando componentes da pipeline... ===")

# diffusers' AutoModel resolves both diffusers and transformers *model* classes
# from a subfolder config — fine for the VAE and the text encoder.
vae = AutoModel.from_pretrained(
    base_repo,
    subfolder="vae",
    torch_dtype=torch_dtype,
    cache_dir=CACHE_DIR,
)
text_encoder = AutoModel.from_pretrained(
    base_repo,
    subfolder="text_encoder",
    torch_dtype=torch_dtype,
    cache_dir=CACHE_DIR,
)

# Tokenizers and schedulers are NOT models, so AutoModel cannot load them (the
# previous code would raise here). Fetch them through the pipeline loader while
# skipping the heavy model components, which we load separately above/below.
_aux = LTXConditionPipeline.from_pretrained(
    base_repo,
    vae=None,
    text_encoder=None,
    transformer=None,
    cache_dir=CACHE_DIR,
)
tokenizer = _aux.tokenizer
scheduler = _aux.scheduler
del _aux  # release the partial pipeline wrapper

if getattr(scheduler.config, "use_dynamic_shifting", False):
    print("[Config] Desativando 'use_dynamic_shifting' no scheduler.")
    # Scheduler configs are frozen dicts; register_to_config is the supported
    # mutation path (plain attribute assignment is deprecated/ignored).
    scheduler.register_to_config(use_dynamic_shifting=False)

transformer = AutoModel.from_pretrained(
    base_repo,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    cache_dir=CACHE_DIR,
)
# Store weights in fp8 but compute in bf16 to reduce VRAM pressure.
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)

# 4. Assemble the main pipeline. The constructor accepts only components —
# `cache_dir` is a from_pretrained kwarg and raised TypeError here before.
print("Montando a LTXConditionPipeline...")
pipeline = LTXConditionPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    scheduler=scheduler,
    transformer=transformer,
)
pipeline.to(device)
pipeline.vae.enable_tiling()  # tile VAE decode to bound peak memory
113
+
114
 
115
 
116
  pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(