upgrade Finetrainers
finetrainers/models/cogvideox/base_specification.py
CHANGED
@@ -299,7 +299,7 @@ class CogVideoXModelSpecification(ModelSpecification):
         latents = posterior.sample(generator=generator)
         del posterior
 
-        if not self.vae_config.invert_scale_latents:
+        if not getattr(self.vae_config, "invert_scale_latents", False):
             latents = latents * self.vae_config.scaling_factor
 
         if patch_size_t is not None:
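Note: getattr with a False default keeps this path working against VAE configs that predate the invert_scale_latents flag, where plain attribute access would raise AttributeError. A minimal sketch of the pattern, using hypothetical stand-in config objects rather than the real diffusers config class:

from types import SimpleNamespace

old_config = SimpleNamespace(scaling_factor=0.7)                           # flag absent
new_config = SimpleNamespace(scaling_factor=0.7, invert_scale_latents=True)

for config in (old_config, new_config):
    # getattr with a default treats a missing flag as False instead of crashing
    if not getattr(config, "invert_scale_latents", False):
        print("scale latents by", config.scaling_factor)
    else:
        print("skip scaling; the VAE already inverts it")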
finetrainers/models/ltx_video/base_specification.py
CHANGED
@@ -336,8 +336,8 @@ class LTXVideoModelSpecification(ModelSpecification):
         latents = self._pack_latents(latents, patch_size, patch_size_t)
         noise = self._pack_latents(noise, patch_size, patch_size_t)
         noisy_latents = self._pack_latents(noisy_latents, patch_size, patch_size_t)
-
         sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1)
+        timesteps = (sigmas * 1000.0).long()
 
         latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
 
@@ -352,7 +352,6 @@ class LTXVideoModelSpecification(ModelSpecification):
             vae_spatial_compression_ratio,
             vae_spatial_compression_ratio,
         ]
-        timesteps = (sigmas * 1000.0).long()
 
         pred = transformer(
             **latent_model_conditions,
@@ -444,9 +443,9 @@ class LTXVideoModelSpecification(ModelSpecification):
         latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
     ) -> torch.Tensor:
         # Normalize latents across the channel dimension [B, C, F, H, W]
-        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device)
-        latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device)
-        latents = (latents - latents_mean) * scaling_factor / latents_std
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
+        latents = ((latents.float() - latents_mean) * scaling_factor / latents_std).to(latents)
         return latents
 
     @staticmethod
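Note: besides moving the timesteps computation next to the sigmas it derives from, the new _normalize_latents does the arithmetic in float32 and only casts back at the end with .to(latents). A short sketch of the same broadcast-and-upcast pattern, with made-up shapes:

import torch

# Latents are [B, C, F, H, W]; per-channel stats broadcast via view(1, -1, 1, 1, 1).
latents = torch.randn(2, 8, 5, 16, 16, dtype=torch.bfloat16)
latents_mean = torch.randn(8)
latents_std = torch.rand(8) + 0.5
scaling_factor = 1.0

mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)

# Subtract and divide in float32 for numerical stability, then cast back to the
# original dtype and device in one step with .to(latents).
normalized = ((latents.float() - mean) * scaling_factor / std).to(latents)
assert normalized.dtype == latents.dtype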
finetrainers/models/wan/base_specification.py
CHANGED
@@ -39,7 +39,7 @@ class WanLatentEncodeProcessor(ProcessorMixin):
     def __init__(self, output_names: List[str]):
         super().__init__()
         self.output_names = output_names
-        assert len(self.output_names) == 1
+        assert len(self.output_names) == 3
 
     def forward(
         self,
@@ -72,7 +72,10 @@ class WanLatentEncodeProcessor(ProcessorMixin):
         moments = vae._encode(video)
         latents = moments.to(dtype=dtype)
 
-        return {self.output_names[0]: latents}
+        latents_mean = torch.tensor(vae.config.latents_mean)
+        latents_std = 1.0 / torch.tensor(vae.config.latents_std)
+
+        return {self.output_names[0]: latents, self.output_names[1]: latents_mean, self.output_names[2]: latents_std}
 
 
 class WanModelSpecification(ModelSpecification):
@@ -108,7 +111,7 @@ class WanModelSpecification(ModelSpecification):
         if condition_model_processors is None:
             condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])]
         if latent_model_processors is None:
-            latent_model_processors = [WanLatentEncodeProcessor(["latents"])]
+            latent_model_processors = [WanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])]
 
         self.condition_model_processors = condition_model_processors
         self.latent_model_processors = latent_model_processors
@@ -266,7 +269,10 @@ class WanModelSpecification(ModelSpecification):
             "image": image,
             "video": video,
             "generator": generator,
-            "compute_posterior": compute_posterior,
+            # We must force this to False because the latent normalization should be done before
+            # the posterior is computed. The VAE does not handle this any more:
+            # https://github.com/huggingface/diffusers/pull/10998
+            "compute_posterior": False,
             **kwargs,
         }
         input_keys = set(conditions.keys())
@@ -284,20 +290,29 @@ class WanModelSpecification(ModelSpecification):
         compute_posterior: bool = True,
         **kwargs,
     ) -> Tuple[torch.Tensor, ...]:
+        compute_posterior = False  # See explanation in prepare_latents
         if compute_posterior:
             latents = latent_model_conditions.pop("latents")
         else:
-            posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
+            latents = latent_model_conditions.pop("latents")
+            latents_mean = latent_model_conditions.pop("latents_mean")
+            latents_std = latent_model_conditions.pop("latents_std")
+
+            mu, logvar = torch.chunk(latents, 2, dim=1)
+            mu = self._normalize_latents(mu, latents_mean, latents_std)
+            logvar = self._normalize_latents(logvar, latents_mean, latents_std)
+            latents = torch.cat([mu, logvar], dim=1)
+
+            posterior = DiagonalGaussianDistribution(latents)
             latents = posterior.sample(generator=generator)
             del posterior
 
         noise = torch.zeros_like(latents).normal_(generator=generator)
         noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
+        timesteps = (sigmas.flatten() * 1000.0).long()
 
         latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
 
-        timesteps = (sigmas.flatten() * 1000.0).long()
-
         pred = transformer(
             **latent_model_conditions,
             **condition_model_conditions,
@@ -367,3 +382,12 @@ class WanModelSpecification(ModelSpecification):
         transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
         if scheduler is not None:
             scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+    @staticmethod
+    def _normalize_latents(
+        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
+    ) -> torch.Tensor:
+        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
+        latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
+        latents = ((latents.float() - latents_mean) * latents_std).to(latents)
+        return latents
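Note: the Wan path now receives the raw encoder moments (mean and log-variance concatenated on the channel dimension) plus per-channel stats, and normalizes each half before constructing the posterior. The processor stores 1.0 / latents_std under the name latents_std, which is why _normalize_latents multiplies rather than divides. A pure-torch sketch of that order of operations, with made-up shapes and a plain reparameterized sample standing in for diffusers' DiagonalGaussianDistribution:

import torch

def normalize(x, mean, std_inv):
    # Same broadcasting pattern as the _normalize_latents helper above.
    mean = mean.view(1, -1, 1, 1, 1).to(device=x.device)
    std_inv = std_inv.view(1, -1, 1, 1, 1).to(device=x.device)
    return ((x.float() - mean) * std_inv).to(x)

B, C = 2, 16
moments = torch.randn(B, 2 * C, 3, 8, 8)        # [mu | logvar] along dim=1
latents_mean = torch.randn(C)
latents_std_inv = 1.0 / (torch.rand(C) + 0.5)   # reciprocal, as the processor stores it

# 1) split the raw moments, 2) normalize each half, 3) only then sample
mu, logvar = torch.chunk(moments, 2, dim=1)
mu = normalize(mu, latents_mean, latents_std_inv)
logvar = normalize(logvar, latents_mean, latents_std_inv)

# Reparameterized sample from the diagonal Gaussian (stand-in for the diffusers class).
std = torch.exp(0.5 * logvar.clamp(-30.0, 20.0))
sample = mu + std * torch.randn_like(std)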
finetrainers/trainer/sft_trainer/trainer.py
CHANGED
@@ -147,8 +147,11 @@ class SFTTrainer:
 
         # Make sure the trainable params are in float32 if data sharding is not enabled. For FSDP, we need all
         # parameters to be of the same dtype.
-        if self.args.training_type == TrainingType.LORA:
-            cast_training_params([self.transformer], dtype=torch.float32)
+        if parallel_backend.data_sharding_enabled:
+            self.transformer.to(dtype=self.args.transformer_dtype)
+        else:
+            if self.args.training_type == TrainingType.LORA:
+                cast_training_params([self.transformer], dtype=torch.float32)
 
     def _prepare_for_training(self) -> None:
         # 1. Apply parallelism