import lightning as L
import torch
from diffusers.pipelines import FluxPipeline
from peft import LoraConfig, get_peft_model_state_dict
import prodigyopt

from ..flux.transformer import tranformer_forward
from ..flux.condition import Condition
from ..flux.pipeline_tools import encode_images, prepare_text_input
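# Note: the relative imports above assume this file lives next to a `flux`
# subpackage (as in the OminiControl repository layout); `prodigyopt` is the
# standalone package providing the Prodigy optimizer.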

class OminiModel(L.LightningModule):
    def __init__(
        self,
        flux_pipe_id: str,
        lora_path: str = None,
        lora_config: dict = None,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        model_config: dict = {},
        optimizer_config: dict = None,
        gradient_checkpointing: bool = False,
    ):
        # Initialize the LightningModule
        super().__init__()
        self.model_config = model_config
        self.optimizer_config = optimizer_config

        # Load the Flux pipeline
        self.flux_pipe: FluxPipeline = (
            FluxPipeline.from_pretrained(flux_pipe_id).to(dtype=dtype).to(device)
        )
        self.transformer = self.flux_pipe.transformer
        self.transformer.gradient_checkpointing = gradient_checkpointing
        self.transformer.train()

        # Freeze the text encoders and VAE; only the LoRA layers are trained
        self.flux_pipe.text_encoder.requires_grad_(False).eval()
        self.flux_pipe.text_encoder_2.requires_grad_(False).eval()
        self.flux_pipe.vae.requires_grad_(False).eval()

        # Initialize LoRA layers
        self.lora_layers = self.init_lora(lora_path, lora_config)

        self.to(device).to(dtype)

    def init_lora(self, lora_path: str, lora_config: dict):
        assert lora_path or lora_config
        if lora_path:
            # TODO: Implement this
            raise NotImplementedError
        else:
            self.transformer.add_adapter(LoraConfig(**lora_config))
            # TODO: Check if this is correct (p.requires_grad)
            lora_layers = filter(
                lambda p: p.requires_grad, self.transformer.parameters()
            )
        return list(lora_layers)
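
    # A hypothetical `lora_config` for the call above, for illustration only;
    # the keys are standard peft.LoraConfig arguments, and the target modules
    # are the attention projections typically adapted in the Flux transformer:
    #
    #   lora_config = {
    #       "r": 4,
    #       "lora_alpha": 4,
    #       "init_lora_weights": "gaussian",
    #       "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
    #   }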

    def save_lora(self, path: str):
        FluxPipeline.save_lora_weights(
            save_directory=path,
            transformer_lora_layers=get_peft_model_state_dict(self.transformer),
            safe_serialization=True,
        )
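
    # Usage sketch (hypothetical path): `model.save_lora("ckpt/omini_lora")`
    # exports only the LoRA weights, in safetensors format.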

    def configure_optimizers(self):
        # Freeze the transformer so only LoRA parameters receive gradients
        self.transformer.requires_grad_(False)
        opt_config = self.optimizer_config

        # Set the trainable parameters
        self.trainable_params = self.lora_layers

        # Unfreeze trainable parameters
        for p in self.trainable_params:
            p.requires_grad_(True)

        # Initialize the optimizer
        if opt_config["type"] == "AdamW":
            optimizer = torch.optim.AdamW(self.trainable_params, **opt_config["params"])
        elif opt_config["type"] == "Prodigy":
            optimizer = prodigyopt.Prodigy(
                self.trainable_params,
                **opt_config["params"],
            )
        elif opt_config["type"] == "SGD":
            optimizer = torch.optim.SGD(self.trainable_params, **opt_config["params"])
        else:
            raise NotImplementedError
        return optimizer
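
    # A hypothetical `optimizer_config` matching the schema read above ("type"
    # selects the branch, "params" is forwarded to the optimizer constructor);
    # the values are illustrative, not tuned recommendations:
    #
    #   optimizer_config = {
    #       "type": "AdamW",
    #       "params": {"lr": 1e-4, "weight_decay": 1e-2},
    #   }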

    def training_step(self, batch, batch_idx):
        step_loss = self.step(batch)
        # Keep an exponential moving average of the loss (decay 0.95) for
        # smoother reporting, e.g. by an external logging callback
        self.log_loss = (
            step_loss.item()
            if not hasattr(self, "log_loss")
            else self.log_loss * 0.95 + step_loss.item() * 0.05
        )
        return step_loss

    def step(self, batch):
        imgs = batch["image"]
        conditions = batch["condition"]
        condition_types = batch["condition_type"]
        prompts = batch["description"]
        position_delta = batch["position_delta"][0]
        position_scale = float(batch.get("position_scale", [1.0])[0])

        # Prepare inputs (the encoders and VAE are frozen, so no gradients)
        with torch.no_grad():
            # Prepare image input
            x_0, img_ids = encode_images(self.flux_pipe, imgs)

            # Prepare text input
            prompt_embeds, pooled_prompt_embeds, text_ids = prepare_text_input(
                self.flux_pipe, prompts
            )

            # Prepare t and x_t; sigmoid(randn) samples timesteps from a
            # logit-normal distribution, concentrating them away from 0 and 1
            t = torch.sigmoid(torch.randn((imgs.shape[0],), device=self.device))
            x_1 = torch.randn_like(x_0).to(self.device)
            t_ = t.unsqueeze(1).unsqueeze(1)
            x_t = ((1 - t_) * x_0 + t_ * x_1).to(self.dtype)
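            # Note: this is the rectified-flow objective used by Flux-style
            # models: x_t moves on a straight line from the clean latents x_0
            # (t = 0) to Gaussian noise x_1 (t = 1), and the transformer is
            # trained below to predict the constant velocity (x_1 - x_0)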

            # Prepare conditions
            condition_latents, condition_ids = encode_images(self.flux_pipe, conditions)

            # Add position delta to the condition token coordinates
            condition_ids[:, 1] += position_delta[0]
            condition_ids[:, 2] += position_delta[1]

            # Optionally rescale coordinates as p -> p * scale + bias; the
            # (scale - 1) / 2 bias recenters the scaled token grid
            if position_scale != 1.0:
                scale_bias = (position_scale - 1.0) / 2
                condition_ids[:, 1] *= position_scale
                condition_ids[:, 2] *= position_scale
                condition_ids[:, 1] += scale_bias
                condition_ids[:, 2] += scale_bias

            # Prepare condition type: one id per condition token, broadcast
            # from the first type in the batch
            condition_type_ids = torch.tensor(
                [
                    Condition.get_type_id(condition_type)
                    for condition_type in condition_types
                ]
            ).to(self.device)
            condition_type_ids = (
                torch.ones_like(condition_ids[:, 0]) * condition_type_ids[0]
            ).unsqueeze(1)

            # Prepare guidance (only guidance-distilled checkpoints such as
            # FLUX.1-dev have guidance embeddings in the transformer)
            guidance = (
                torch.ones_like(t).to(self.device)
                if self.transformer.config.guidance_embeds
                else None
            )

        # Forward pass ("tranformer_forward" is spelled as in the helper it is
        # imported from)
        transformer_out = tranformer_forward(
            self.transformer,
            # Model config
            model_config=self.model_config,
            # Inputs of the condition (new feature)
            condition_latents=condition_latents,
            condition_ids=condition_ids,
            condition_type_ids=condition_type_ids,
            # Inputs to the original transformer
            hidden_states=x_t,
            timestep=t,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=img_ids,
            joint_attention_kwargs=None,
            return_dict=False,
        )
        pred = transformer_out[0]

        # Compute the flow-matching loss against the velocity target
        loss = torch.nn.functional.mse_loss(pred, (x_1 - x_0), reduction="mean")
        self.last_t = t.mean().item()
        return loss
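

# A minimal usage sketch, kept commented out because the values below are
# illustrative assumptions (checkpoint id, LoRA and optimizer settings) and
# training a FLUX.1 pipeline requires a large-memory GPU:
#
# if __name__ == "__main__":
#     model = OminiModel(
#         flux_pipe_id="black-forest-labs/FLUX.1-dev",
#         lora_config={
#             "r": 4,
#             "lora_alpha": 4,
#             "init_lora_weights": "gaussian",
#             "target_modules": ["to_k", "to_q", "to_v", "to_out.0"],
#         },
#         optimizer_config={"type": "Prodigy", "params": {"lr": 1.0}},
#         gradient_checkpointing=True,
#     )
#     trainer = L.Trainer(max_steps=1000)
#     trainer.fit(model, train_dataloaders=...)  # dataloader construction not shown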