from .hidream_model import HidreamModel
from .src.pipelines.hidream_image.pipeline_hidream_image_editing import (
    HiDreamImageEditingPipeline,
)
from .src.schedulers.fm_solvers_unipc import FlowUniPCMultistepScheduler
from toolkit.accelerator import unwrap_model
import torch
from toolkit.prompt_utils import PromptEmbeds
from toolkit.config_modules import GenerateImageConfig
from diffusers.models import HiDreamImageTransformer2DModel
import torch.nn.functional as F
from PIL import Image
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO

class HidreamE1Model(HidreamModel):
    arch = "hidream_e1"
    hidream_transformer_class = HiDreamImageTransformer2DModel
    hidream_pipeline_class = HiDreamImageEditingPipeline

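    # Assembles the image-editing pipeline used for sampling: a flow-matching
    # UniPC scheduler, the VAE, the four text encoder/tokenizer pairs, and the
    # accelerator-unwrapped transformer; aggressive_unloading follows low_vram.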
    def get_generation_pipeline(self):
        scheduler = FlowUniPCMultistepScheduler(
            num_train_timesteps=1000, shift=3.0, use_dynamic_shifting=False
        )
        pipeline: HiDreamImageEditingPipeline = HiDreamImageEditingPipeline(
            scheduler=scheduler,
            vae=self.vae,
            text_encoder=self.text_encoder[0],
            tokenizer=self.tokenizer[0],
            text_encoder_2=self.text_encoder[1],
            tokenizer_2=self.tokenizer[1],
            text_encoder_3=self.text_encoder[2],
            tokenizer_3=self.tokenizer[2],
            text_encoder_4=self.text_encoder[3],
            tokenizer_4=self.tokenizer[3],
            transformer=unwrap_model(self.model),
            aggressive_unloading=self.low_vram,
        )
        pipeline = pipeline.to(self.device_torch)
        return pipeline

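    # Generates one sample: loads the control image referenced by the config,
    # converts it to RGB, resizes it to the target resolution, and passes it to
    # the editing pipeline together with the precomputed prompt embeddings.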
    def generate_single_image(
        self,
        pipeline: HiDreamImageEditingPipeline,
        gen_config: GenerateImageConfig,
        conditional_embeds: PromptEmbeds,
        unconditional_embeds: PromptEmbeds,
        generator: torch.Generator,
        extra: dict,
    ):
        if gen_config.ctrl_img is None:
            raise ValueError(
                "Control image is required for HiDream E1 model generation."
            )
        else:
            control_img = Image.open(gen_config.ctrl_img)
            control_img = control_img.convert("RGB")
            # resize to the target width and height if needed
            if control_img.size != (gen_config.width, gen_config.height):
                control_img = control_img.resize(
                    (gen_config.width, gen_config.height), Image.BILINEAR
                )

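        # PromptEmbeds packs [T5 embeds, Llama3 embeds] in text_embeds (see
        # get_prompt_embeds below), so index 0 feeds the *_t5 arguments and
        # index 1 the *_llama3 arguments; pooled_embeds feeds the pooled path.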
        img = pipeline(
            prompt_embeds_t5=conditional_embeds.text_embeds[0],
            prompt_embeds_llama3=conditional_embeds.text_embeds[1],
            pooled_prompt_embeds=conditional_embeds.pooled_embeds,
            negative_prompt_embeds_t5=unconditional_embeds.text_embeds[0],
            negative_prompt_embeds_llama3=unconditional_embeds.text_embeds[1],
            negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
            height=gen_config.height,
            width=gen_config.width,
            num_inference_steps=gen_config.num_inference_steps,
            guidance_scale=gen_config.guidance_scale,
            latents=gen_config.latents,
            generator=generator,
            image=control_img,
            **extra,
        ).images[0]
        return img

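    # Encodes a single prompt with all four encoders (the same text is passed as
    # prompt, prompt_2, prompt_3, and prompt_4) and packs the result as
    # PromptEmbeds([[t5, llama3], pooled]); classifier-free guidance is disabled,
    # so the negative outputs returned by encode_prompt are unused here.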
    def get_prompt_embeds(self, prompt: str) -> PromptEmbeds:
        self.text_encoder_to(self.device_torch, dtype=self.torch_dtype)
        max_sequence_length = 128
        (
            prompt_embeds_t5,
            negative_prompt_embeds_t5,
            prompt_embeds_llama3,
            negative_prompt_embeds_llama3,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.pipeline.encode_prompt(
            prompt=prompt,
            prompt_2=prompt,
            prompt_3=prompt,
            prompt_4=prompt,
            device=self.device_torch,
            dtype=self.torch_dtype,
            num_images_per_prompt=1,
            max_sequence_length=max_sequence_length,
            do_classifier_free_guidance=False,
        )
        prompt_embeds = [prompt_embeds_t5, prompt_embeds_llama3]
        pe = PromptEmbeds([prompt_embeds, pooled_prompt_embeds])
        return pe

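    # Appends the VAE-encoded control image to the noisy latents along the
    # channel dimension, doubling the channel count (the 32-channel case that
    # get_noise_prediction splits back apart), after resizing the control image
    # to the training resolution.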
    def condition_noisy_latents(
        self, latents: torch.Tensor, batch: "DataLoaderBatchDTO"
    ):
        with torch.no_grad():
            control_tensor = batch.control_tensor
            if control_tensor is not None:
                self.vae.to(self.device_torch)
                # nothing is packed at this stage; rescale the control image from
                # [0, 1] to [-1, 1] and encode it so it can be packed with the
                # latents later
                control_tensor = control_tensor * 2 - 1
                control_tensor = control_tensor.to(
                    self.vae_device_torch, dtype=self.torch_dtype
                )
                # if it does not match the size of batch.tensor (bs, ch, h, w), resize it
                if batch.tensor is not None:
                    target_h, target_w = batch.tensor.shape[2], batch.tensor.shape[3]
                else:
                    # when caching latents, batch.tensor is None; get the size from the file_items instead
                    target_h = batch.file_items[0].crop_height
                    target_w = batch.file_items[0].crop_width
                if (
                    control_tensor.shape[2] != target_h
                    or control_tensor.shape[3] != target_w
                ):
                    control_tensor = F.interpolate(
                        control_tensor, size=(target_h, target_w), mode="bilinear"
                    )
                control_latent = self.encode_images(control_tensor).to(
                    latents.device, latents.dtype
                )
                latents = torch.cat((latents, control_latent), dim=1)
        return latents.detach()

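    # Runs the transformer on a (possibly control-conditioned) latent. A
    # 32-channel input is split into its latent and control halves, which are
    # placed side by side along the width axis before the forward pass; the
    # prediction is then cropped back to the original width and negated.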
    def get_noise_prediction(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,  # 0 to 1000 scale
        text_embeddings: PromptEmbeds,
        **kwargs,
    ):
        with torch.no_grad():
            # make sure the transformer config forces inference-style output
            self.model.config.force_inference_output = True
            has_control = False
            lat_size = latent_model_input.shape[-1]
            if latent_model_input.shape[1] == 32:
                # split the 32-channel input into the noisy latent and the control
                # latent, then place the two halves side by side along the width axis
                lat, control = torch.chunk(latent_model_input, 2, dim=1)
                latent_model_input = torch.cat([lat, control], dim=-1)
                has_control = True
        dtype = self.model.dtype
        device = self.device_torch
        text_embeds = text_embeddings.text_embeds
        # move each text embedding in the list to the model device and dtype
        text_embeds = [te.to(device, dtype=dtype) for te in text_embeds]
        noise_pred = self.transformer(
            hidden_states=latent_model_input,
            timesteps=timestep,
            encoder_hidden_states_t5=text_embeds[0],
            encoder_hidden_states_llama3=text_embeds[1],
            pooled_embeds=text_embeddings.pooled_embeds.to(device, dtype=dtype),
            return_dict=False,
        )[0]
        if has_control:
            # keep only the half of the prediction that corresponds to the noisy latent
            noise_pred = -1.0 * noise_pred[..., :lat_size]
        else:
            noise_pred = -1.0 * noise_pred
        return noise_pred
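

# Minimal shape sketch (illustrative only, not part of the toolkit): it mimics the
# tensor plumbing above with plain tensors, assuming a 16-channel latent space, to
# show how the control latent travels through the channel concatenation in
# condition_noisy_latents and the width-wise rearrangement in get_noise_prediction.
if __name__ == "__main__":
    bs, ch, h, w = 1, 16, 64, 64
    noisy_latent = torch.randn(bs, ch, h, w)
    control_latent = torch.randn(bs, ch, h, w)

    # condition_noisy_latents: concatenate on the channel dimension -> 32 channels
    conditioned = torch.cat((noisy_latent, control_latent), dim=1)
    assert conditioned.shape == (bs, 2 * ch, h, w)

    # get_noise_prediction: split the channels back out and place the halves side
    # by side along the width dimension -> channels return to 16, width doubles
    lat, control = torch.chunk(conditioned, 2, dim=1)
    model_input = torch.cat([lat, control], dim=-1)
    assert model_input.shape == (bs, ch, h, 2 * w)

    # a stand-in for the transformer output with the same spatial shape as its
    # input; only the left half (the noisy-latent region) is kept, then negated
    fake_pred = torch.randn_like(model_input)
    noise_pred = -1.0 * fake_pred[..., :w]
    assert noise_pred.shape == noisy_latent.shape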