import os
from typing import TYPE_CHECKING, List, Optional

import torch
import yaml
from toolkit.config_modules import GenerateImageConfig, ModelConfig
from toolkit.models.base_model import BaseModel
from diffusers import AutoencoderKL
from toolkit.basic import flush
from toolkit.prompt_utils import PromptEmbeds
from toolkit.samplers.custom_flowmatch_sampler import (
    CustomFlowMatchEulerDiscreteScheduler,
)
from toolkit.accelerator import unwrap_model
from optimum.quanto import freeze
from toolkit.util.quantize import quantize, get_qtype
from .src.pipelines.omnigen2.pipeline_omnigen2 import OmniGen2Pipeline
from .src.models.transformers import OmniGen2Transformer2DModel
from .src.models.transformers.repo import OmniGen2RotaryPosEmbed
from .src.schedulers.scheduling_flow_match_euler_discrete import (
    FlowMatchEulerDiscreteScheduler as OmniFlowMatchEuler,
)
from PIL import Image
from transformers import (
    CLIPProcessor,
    Qwen2_5_VLForConditionalGeneration,
)
import torch.nn.functional as F

if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

scheduler_config = {"num_train_timesteps": 1000}

BASE_MODEL_PATH = "OmniGen2/OmniGen2"


class OmniGen2Model(BaseModel):
    arch = "omnigen2"

    def __init__(
        self,
        device,
        model_config: ModelConfig,
        dtype="bf16",
        custom_pipeline=None,
        noise_scheduler=None,
        **kwargs,
    ):
        super().__init__(
            device, model_config, dtype, custom_pipeline, noise_scheduler, **kwargs
        )
        self.is_flow_matching = True
        self.is_transformer = True
        self.target_lora_modules = ["OmniGen2Transformer2DModel"]
        self._control_latent = None

    # static method to get the noise scheduler
    @staticmethod
    def get_train_scheduler():
        return CustomFlowMatchEulerDiscreteScheduler(**scheduler_config)

    def get_bucket_divisibility(self):
        return 16

    def load_model(self):
        dtype = self.torch_dtype

        self.print_and_status_update("Loading OmniGen2 model")
        # will be updated if we detect an existing checkpoint in the training folder
        model_path = self.model_config.name_or_path
        extras_path = self.model_config.extras_name_or_path

        scheduler = OmniGen2Model.get_train_scheduler()

        self.print_and_status_update("Loading Qwen2.5 VL")
        processor = CLIPProcessor.from_pretrained(
            extras_path, subfolder="processor", use_fast=True
        )
        mllm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            extras_path, subfolder="mllm", torch_dtype=torch.bfloat16
        )
        mllm.to(self.device_torch, dtype=dtype)

        if self.model_config.quantize_te:
            self.print_and_status_update("Quantizing Qwen2.5 VL model")
            quantization_type = get_qtype(self.model_config.qtype_te)
            quantize(mllm, weights=quantization_type)
            freeze(mllm)

        if self.low_vram:
            # unload it for now
            mllm.to("cpu")

        flush()

        self.print_and_status_update("Loading transformer")
        transformer = OmniGen2Transformer2DModel.from_pretrained(
            model_path, subfolder="transformer", torch_dtype=torch.bfloat16
        )
        if not self.low_vram:
            transformer.to(self.device_torch, dtype=dtype)

        if self.model_config.quantize:
            self.print_and_status_update("Quantizing transformer")
            quantization_type = get_qtype(self.model_config.qtype)
            quantize(transformer, weights=quantization_type)
            freeze(transformer)

        if self.low_vram:
            # unload it for now
            transformer.to("cpu")

        flush()

        self.print_and_status_update("Loading vae")
        vae = AutoencoderKL.from_pretrained(
            extras_path, subfolder="vae", torch_dtype=torch.bfloat16
        ).to(self.device_torch, dtype=dtype)
        flush()

        self.print_and_status_update("Loading Qwen2.5 VLProcessor")
        flush()

        if self.low_vram:
            self.print_and_status_update("Moving everything to device")
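            # in low-VRAM mode the text encoder and transformer were parked on the CPU
            # above to keep peak VRAM down while the remaining components loaded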
            # move it all back
            transformer.to(self.device_torch, dtype=dtype)
            vae.to(self.device_torch, dtype=dtype)
            mllm.to(self.device_torch, dtype=dtype)

        # set to eval mode
        # transformer.eval()
        vae.eval()
        mllm.eval()
        mllm.requires_grad_(False)

        pipe: OmniGen2Pipeline = OmniGen2Pipeline(
            transformer=transformer,
            vae=vae,
            scheduler=scheduler,
            mllm=mllm,
            processor=processor,
        )

        flush()

        text_encoder_list = [mllm]
        tokenizer_list = [processor]

        flush()

        # save it to the model class
        self.vae = vae
        self.text_encoder = text_encoder_list  # list of text encoders
        self.tokenizer = tokenizer_list  # list of tokenizers
        self.model = pipe.transformer
        self.pipeline = pipe

        self.freqs_cis = OmniGen2RotaryPosEmbed.get_freqs_cis(
            transformer.config.axes_dim_rope,
            transformer.config.axes_lens,
            theta=10000,
        )

        self.print_and_status_update("Model Loaded")

    def get_generation_pipeline(self):
        scheduler = OmniFlowMatchEuler(
            dynamic_time_shift=True, num_train_timesteps=1000
        )
        pipeline: OmniGen2Pipeline = OmniGen2Pipeline(
            transformer=self.model,
            vae=self.vae,
            scheduler=scheduler,
            mllm=self.text_encoder[0],
            processor=self.tokenizer[0],
        )
        pipeline = pipeline.to(self.device_torch)
        return pipeline

    def generate_single_image(
        self,
        pipeline: OmniGen2Pipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        input_images = []
        if gen_config.ctrl_img is not None:
            control_img = Image.open(gen_config.ctrl_img)
            control_img = control_img.convert("RGB")
            # resize to width and height
            if control_img.size != (gen_config.width, gen_config.height):
                control_img = control_img.resize(
                    (gen_config.width, gen_config.height), Image.BILINEAR
                )
            input_images = [control_img]

        img = pipeline(
            prompt_embeds=conditional_embeds.text_embeds,
            prompt_attention_mask=conditional_embeds.attention_mask,
            negative_prompt_embeds=unconditional_embeds.text_embeds,
            negative_prompt_attention_mask=unconditional_embeds.attention_mask,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            text_guidance_scale=gen_config.guidance_scale,
            image_guidance_scale=1.0,  # reference image guidance scale; add this for controls
            latents=gen_config.latents,
            align_res=False,
            generator=generator,
            input_images=input_images,
            **extra,
        ).images[0]
        return img

    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs,
    ):
        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        try:
            timestep = timestep.expand(latent_model_input.shape[0]).to(
                latent_model_input.dtype
            )
        except Exception:
            pass

        timesteps = timestep / 1000  # convert to 0 to 1 scale
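        # e.g. a scheduler timestep of 750 becomes 0.75 here and 0.25 after the flip below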
        # the timestep for the model starts at 0 instead of 1, so we need to reverse it
        timestep = 1 - timesteps

        model_pred = self.model(
            latent_model_input,
            timestep,
            text_embeddings.text_embeds,
            self.freqs_cis,
            text_embeddings.attention_mask,
            ref_image_hidden_states=self._control_latent,
        )

        return model_pred

    def condition_noisy_latents(
        self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
    ):
        # reset the control latent
        self._control_latent = None
        with torch.no_grad():
            control_tensor = batch.control_tensor
            if control_tensor is not None:
                self.vae.to(self.device_torch)
                # latents are not packed here, so just pass the control tensors through and pack them later
                control_tensor = control_tensor * 2 - 1
                control_tensor = control_tensor.to(
                    self.vae_device_torch, dtype=self.torch_dtype
                )

                # if it is not the size of batch.tensor (bs, ch, h, w), then we need to resize it
                # todo: we may not need to do this, check
                if batch.tensor is not None:
                    target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
                else:
                    # When caching latents, batch.tensor is None. We get the size from the file_items instead.
                    target_h = batch.file_items[0].crop_height
                    target_w = batch.file_items[0].crop_width

                if (
                    control_tensor.shape[2] != target_h
                    or control_tensor.shape[3] != target_w
                ):
                    control_tensor = F.interpolate(
                        control_tensor, size=(target_h, target_w), mode="bilinear"
                    )

                control_latent = self.encode_images(control_tensor).to(
                    latents.device, latents.dtype
                )
                self._control_latent = [
                    [x.squeeze(0)]
                    for x in torch.chunk(control_latent, control_latent.shape[0], dim=0)
                ]

        return latents.detach()

    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [self.pipeline._apply_chat_template(_prompt) for _prompt in prompt]
        self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
        max_sequence_length = 256
        prompt_embeds, prompt_attention_mask, _, _ = self.pipeline.encode_prompt(
            prompt=prompt,
            do_classifier_free_guidance=False,
            device=self.device_torch,
            max_sequence_length=max_sequence_length,
        )
        pe = PromptEmbeds(prompt_embeds)
        pe.attention_mask = prompt_attention_mask
        return pe

    def get_model_has_grad(self):
        # would check a weight to see if it has a grad
        return False

    def get_te_has_grad(self):
        # assume no one wants to finetune the text encoder
        return False

    def save_model(self, output_path, meta, save_dtype):
        # only save the transformer
        transformer: OmniGen2Transformer2DModel = unwrap_model(self.model)
        transformer.save_pretrained(
            save_directory=os.path.join(output_path, "transformer"),
            safe_serialization=True,
        )

        meta_path = os.path.join(output_path, "aitk_meta.yaml")
        with open(meta_path, "w") as f:
            yaml.dump(meta, f)

    def get_loss_target(self, *args, **kwargs):
        noise = kwargs.get("noise")
        batch = kwargs.get("batch")
        # return (noise - batch.latents).detach()
        return (batch.latents - noise).detach()

    def get_transformer_block_names(self) -> Optional[List[str]]:
        # omnigen2 has a few block groups: noise_refiner, ref_image_refiner, context_refiner, and layers.
        # let's include everything except the image refiner until we add support for it
        if self.model_config.model_kwargs.get("use_image_refiner", False):
            return ["noise_refiner", "context_refiner", "ref_image_refiner", "layers"]
        return ["noise_refiner", "context_refiner", "layers"]

    def convert_lora_weights_before_save(self, state_dict):
        # keys currently start with "transformer." but need to start with "diffusion_model." for ComfyUI
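        # illustrative example of the rename (exact module/param names depend on the LoRA wrapper):
        #   "transformer.layers.0.<param>"  ->  "diffusion_model.layers.0.<param>"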
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("transformer.", "diffusion_model.")
            new_sd[new_key] = value
        return new_sd

    def convert_lora_weights_before_load(self, state_dict):
        # saved as "diffusion_model." but needs to be "transformer." for ai-toolkit
        new_sd = {}
        for key, value in state_dict.items():
            new_key = key.replace("diffusion_model.", "transformer.")
            new_sd[new_key] = value
        return new_sd

    def get_base_model_version(self):
        return "omnigen2"
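

# Minimal usage sketch (illustrative only; kept as a comment so it never runs on import).
# It assumes ModelConfig accepts the fields referenced above (name_or_path,
# extras_name_or_path, quantize, qtype_te, ...) as constructor arguments; check
# toolkit.config_modules for the actual signature.
#
#   config = ModelConfig(
#       name_or_path=BASE_MODEL_PATH,
#       extras_name_or_path=BASE_MODEL_PATH,
#   )
#   model = OmniGen2Model(device="cuda:0", model_config=config, dtype="bf16")
#   model.load_model()
#   pipeline = model.get_generation_pipeline()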