import hashlib
import os
from typing import List, Optional, Union

import torch
from diffusers import FluxModularPipeline, ModularPipelineBlocks
from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.modular_pipelines import PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)
from diffusers.utils import (
    USE_PEFT_BACKEND,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from safetensors import safe_open
from safetensors.torch import save_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

logger = logging.get_logger(__name__)


class CachedFluxTextEncoderStep(ModularPipelineBlocks):
    model_name = "flux"

    def __init__(
        self,
        use_cache: bool = True,
        cache_dir: Optional[str] = None,
        load_from_disk: bool = True,
    ) -> None:
        """Initialize the cached Flux text encoder step.

        Args:
            use_cache: Whether to enable caching of prompt embeddings. Defaults to True.
            cache_dir: Directory to store cache files. If None, uses ~/.cache/flux_prompt_cache.
            load_from_disk: Whether to load existing cache from disk on initialization. Defaults to True.
        """
        super().__init__()
        self.cache = {} if use_cache else None
        if use_cache:
            self.cache_dir = cache_dir or os.path.join(
                os.path.expanduser("~"), ".cache", "flux_prompt_cache"
            )
            os.makedirs(self.cache_dir, exist_ok=True)
        else:
            self.cache_dir = None

        # Load existing cache if requested
        if load_from_disk and use_cache:
            self.load_cache_from_disk()

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text embeddings to guide the image generation"

    @property
    def expected_components(self):
        return [
            ComponentSpec("text_encoder", CLIPTextModel),
            ComponentSpec("tokenizer", CLIPTokenizer),
            ComponentSpec("text_encoder_2", T5EncoderModel),
            ComponentSpec("tokenizer_2", T5TokenizerFast),
        ]

    @property
    def expected_configs(self):
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_2"),
            InputParam("joint_attention_kwargs"),
        ]

    @property
    def intermediate_outputs(self):
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="text embeddings used to guide the image generation",
            ),
            OutputParam(
                "pooled_prompt_embeds",
                type_hint=torch.Tensor,
                description="pooled text embeddings used to guide the image generation",
            ),
            OutputParam(
                "text_ids",
                type_hint=torch.Tensor,
                description="ids from the text sequence for RoPE",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        for prompt in [block_state.prompt, block_state.prompt_2]:
            if prompt is not None and (
                not isinstance(prompt, str) and not isinstance(prompt, list)
            ):
                raise ValueError(
                    f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}"
                )

    def save_cache_to_disk(self):
        """Save the current cache to disk as a safetensors file."""
        if not self.cache or not self.cache_dir:
            return

        cache_file = os.path.join(self.cache_dir, "cache.safetensors")

        # Prepare tensors dict for safetensors
        tensors_to_save = {}
        for key, tensor in self.cache.items():
            # Ensure tensor is on CPU before saving
            cpu_tensor = (
                tensor.cpu() if tensor.device != torch.device("cpu") else tensor
            )
            tensors_to_save[key] = cpu_tensor

        # Save tensors
        save_file(tensors_to_save, cache_file)
        logger.info(f"Saved {len(tensors_to_save)} cached embeddings to {cache_file}")

    def load_cache_from_disk(self):
        """Load cache from disk using memory-mapped safetensors."""
        if not self.cache_dir or self.cache is None:
            return

        cache_file = os.path.join(self.cache_dir, "cache.safetensors")
        if not os.path.exists(cache_file):
            return

        try:
            # Open safetensors file in context manager
            with safe_open(cache_file, framework="pt", device="cpu") as f:
                loaded_count = 0
                for key in f.keys():
                    self.cache[key] = f.get_tensor(key)
                    loaded_count += 1

            logger.debug(
                f"Loaded {loaded_count} cached embeddings from {cache_file} (memory-mapped)"
            )
        except Exception as e:
            logger.warning(f"Failed to load cache from disk: {e}")

    def clear_cache_from_disk(self):
        """Clear cached safetensors file from disk."""
        if not self.cache_dir:
            return

        cache_file = os.path.join(self.cache_dir, "cache.safetensors")
        if os.path.exists(cache_file):
            os.remove(cache_file)
            logger.info(f"Cleared cache file: {cache_file}")

        # Also clear the in-memory cache
        if self.cache:
            self.cache.clear()

    def get_cache_size(self):
        """Get the current cache size in MB."""
        if not self.cache_dir:
            return 0

        cache_file = os.path.join(self.cache_dir, "cache.safetensors")
        if os.path.exists(cache_file):
            return os.path.getsize(cache_file) / (1024 * 1024)  # Convert to MB
        return 0

    @staticmethod
    def _to_cache_key(prompt: str) -> str:
        """Generate a hash key for a single prompt string."""
        return hashlib.sha256(prompt.encode()).hexdigest()

    @staticmethod
    def _get_cached_prompt_embeds(prompts, cache_instance, cache_suffix, device=None):
        """Split prompts into cached and new, returning indices for reconstruction.

        Args:
            prompts: List of prompt strings to check against cache.
            cache_instance: CachedFluxTextEncoderStep instance with cache, or None.
            cache_suffix: Suffix to append to cache keys (e.g., "_t5", "_clip").
            device: Optional device to move cached tensors to.

        Returns:
            tuple: (cached_embeds, prompts_to_encode, prompt_indices)
                - cached_embeds: List of (idx, embedding) tuples for cached prompts
                - prompts_to_encode: List of prompts that need encoding
                - prompt_indices: List of original indices for prompts_to_encode
        """
        cached_embeds = []
        prompts_to_encode = []
        prompt_indices = []

        for idx, prompt in enumerate(prompts):
            cache_key = CachedFluxTextEncoderStep._to_cache_key(prompt + cache_suffix)
            if (
                cache_instance
                and cache_instance.cache
                and cache_key in cache_instance.cache
            ):
                cached_tensor = cache_instance.cache[cache_key]
                # Move tensor to the correct device if specified
                if device is not None and cached_tensor.device != device:
                    cached_tensor = cached_tensor.to(device)
                cached_embeds.append((idx, cached_tensor))
            else:
                prompts_to_encode.append(prompt)
                prompt_indices.append(idx)

        return cached_embeds, prompts_to_encode, prompt_indices

    @staticmethod
    def _cache_prompt_embeds(
        prompts, prompt_indices, prompt_embeds, cache_instance, cache_suffix
    ):
        """Store newly computed embeddings in cache and save to disk.

        Args:
            prompts: Original full list of prompts.
            prompt_indices: Indices of newly encoded prompts in the original list.
            prompt_embeds: Newly computed embeddings tensor.
            cache_instance: CachedFluxTextEncoderStep instance with cache, or None.
            cache_suffix: Suffix to append to cache keys (e.g., "_t5", "_clip").
        """
        if not cache_instance or cache_instance.cache is None:
            return

        for i, idx in enumerate(prompt_indices):
            cache_key = CachedFluxTextEncoderStep._to_cache_key(
                prompts[idx] + cache_suffix
            )
            # Store in memory cache on CPU to save GPU memory
            tensor_slice = prompt_embeds[i : i + 1].detach().cpu()
            cache_instance.cache[cache_key] = tensor_slice

        # Save updated cache to disk
        cache_instance.save_cache_to_disk()

    @staticmethod
    def _merge_cached_prompt_embeds(
        cached_embeds, prompt_indices, prompt_embeds, batch_size
    ):
        """Merge cached and newly computed embeddings back into original batch order.

        Args:
            cached_embeds: List of (idx, embedding) tuples from cache.
            prompt_indices: Indices where new embeddings should be placed.
            prompt_embeds: Newly computed embeddings tensor, or None if all cached.
            batch_size: Total batch size for output tensor.

        Returns:
            torch.Tensor: Combined embeddings tensor in correct batch order.
        """
        all_embeds = [None] * batch_size

        # Place cached embeddings
        for idx, embed in cached_embeds:
            all_embeds[idx] = embed

        # Place new embeddings
        if prompt_embeds is not None:
            for i, idx in enumerate(prompt_indices):
                all_embeds[idx] = prompt_embeds[i : i + 1]

        return torch.cat(all_embeds, dim=0)

    @staticmethod
    def _get_t5_prompt_embeds(
        components,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: torch.device = None,
        cache_instance=None,
    ):
        """Encode prompts using T5 text encoder with caching support.

        Args:
            components: Pipeline components containing T5 encoder and tokenizer.
            prompt: Prompt(s) to encode.
            num_images_per_prompt: Number of images per prompt for duplication.
            max_sequence_length: Maximum sequence length for tokenization.
            device: Device to place tensors on.
            cache_instance: CachedFluxTextEncoderStep instance for caching, or None.

        Returns:
            torch.Tensor: T5 prompt embeddings ready for diffusion model.
        """
        dtype = components.text_encoder_2.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # Split cached and new prompts
        cached_embeds, prompts_to_encode, prompt_indices = (
            CachedFluxTextEncoderStep._get_cached_prompt_embeds(
                prompt, cache_instance, "_t5", device
            )
        )

        # Early return if all prompts are cached
        if not prompts_to_encode:
            prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
                cached_embeds, prompt_indices, None, batch_size
            )
            _, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
            prompt_embeds = prompt_embeds.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            return prompt_embeds

        if isinstance(components, TextualInversionLoaderMixin):
            prompts_to_encode = components.maybe_convert_prompt(
                prompts_to_encode, components.tokenizer_2
            )

        text_inputs = components.tokenizer_2(
            prompts_to_encode,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        # Check for truncation
        untruncated_ids = components.tokenizer_2(
            prompts_to_encode, padding="longest", return_tensors="pt"
        ).input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = components.tokenizer_2.batch_decode(
                untruncated_ids[:, max_sequence_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = components.text_encoder_2(
            text_input_ids.to(device), output_hidden_states=False
        )[0]

        # Cache the new embeddings
        CachedFluxTextEncoderStep._cache_prompt_embeds(
            prompt, prompt_indices, prompt_embeds, cache_instance, "_t5"
        )

        # Combine cached and newly encoded embeddings in correct order
        prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
            cached_embeds, prompt_indices, prompt_embeds, batch_size
        )

        _, seq_len, _ = prompt_embeds.shape
        # Duplicate for num_images_per_prompt
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )

        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        return prompt_embeds

    @staticmethod
    def _get_clip_prompt_embeds(
        components,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        device: torch.device = None,
        cache_instance=None,
    ):
        """Encode prompts using CLIP text encoder with caching support.

        Args:
            components: Pipeline components containing CLIP encoder and tokenizer.
            prompt: Prompt(s) to encode.
            num_images_per_prompt: Number of images per prompt for duplication.
            device: Device to place tensors on.
            cache_instance: CachedFluxTextEncoderStep instance for caching, or None.

        Returns:
            torch.Tensor: CLIP pooled prompt embeddings ready for diffusion model.
        """
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        # Split cached and new prompts
        cached_embeds, prompts_to_encode, prompt_indices = (
            CachedFluxTextEncoderStep._get_cached_prompt_embeds(
                prompt, cache_instance, "_clip", device
            )
        )

        # Early return if all prompts are cached
        if not prompts_to_encode:
            prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
                cached_embeds, prompt_indices, None, batch_size
            )
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
            prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
            return prompt_embeds

        if prompts_to_encode:
            if isinstance(components, TextualInversionLoaderMixin):
                prompts_to_encode = components.maybe_convert_prompt(
                    prompts_to_encode, components.tokenizer
                )

            text_inputs = components.tokenizer(
                prompts_to_encode,
                padding="max_length",
                max_length=components.tokenizer.model_max_length,
                truncation=True,
                return_overflowing_tokens=False,
                return_length=False,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids

            tokenizer_max_length = components.tokenizer.model_max_length
            untruncated_ids = components.tokenizer(
                prompts_to_encode, padding="longest", return_tensors="pt"
            ).input_ids

            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                text_input_ids, untruncated_ids
            ):
                removed_text = components.tokenizer.batch_decode(
                    untruncated_ids[:, tokenizer_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer_max_length} tokens: {removed_text}"
                )

            prompt_embeds = components.text_encoder(
                text_input_ids.to(device), output_hidden_states=False
            )

            # Use pooled output of CLIPTextModel
            prompt_embeds = prompt_embeds.pooler_output
            prompt_embeds = prompt_embeds.to(
                dtype=components.text_encoder.dtype, device=device
            )

            # Cache the new embeddings
            CachedFluxTextEncoderStep._cache_prompt_embeds(
                prompt, prompt_indices, prompt_embeds, cache_instance, "_clip"
            )

        # Combine cached and newly encoded embeddings in correct order
        prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
            cached_embeds,
            prompt_indices,
            prompt_embeds if prompts_to_encode else None,
            batch_size,
        )

        # Duplicate for num_images_per_prompt
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    @staticmethod
    def encode_prompt(
        components,
        prompt: Union[str, List[str]] = None,
        prompt_2: Union[str, List[str]] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
        cache_instance: Optional["CachedFluxTextEncoderStep"] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
                is used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or components._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin):
            components._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if components.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(components.text_encoder, lora_scale)

            if components.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(components.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = CachedFluxTextEncoderStep._get_clip_prompt_embeds(
                components,
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                cache_instance=cache_instance,
            )
            prompt_embeds = CachedFluxTextEncoderStep._get_t5_prompt_embeds(
                components,
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                cache_instance=cache_instance,
            )

        if components.text_encoder is not None:
            if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder, lora_scale)

        if components.text_encoder_2 is not None:
            if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder_2, lora_scale)

        dtype = (
            components.text_encoder.dtype
            if components.text_encoder is not None
            else torch.bfloat16
        )
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids

    @torch.no_grad()
    def __call__(
        self, components: FluxModularPipeline, state: PipelineState
    ) -> PipelineState:
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        # Encode input prompt
        block_state.text_encoder_lora_scale = (
            block_state.joint_attention_kwargs.get("scale", None)
            if block_state.joint_attention_kwargs is not None
            else None
        )
        (
            block_state.prompt_embeds,
            block_state.pooled_prompt_embeds,
            block_state.text_ids,
        ) = self.encode_prompt(
            components,
            prompt=block_state.prompt,
            prompt_2=block_state.prompt_2,
            prompt_embeds=None,
            pooled_prompt_embeds=None,
            device=block_state.device,
            num_images_per_prompt=1,  # TODO: hardcoded for now.
            max_sequence_length=512,
            lora_scale=block_state.text_encoder_lora_scale,
            cache_instance=self if self.cache is not None else None,  # Pass self as cache_instance
        )

        # Add outputs
        self.set_block_state(state, block_state)
        return components, state
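

# ---------------------------------------------------------------------------
# Usage sketch (not part of the block itself). This is a minimal illustration
# of the cache round-trip exposed by CachedFluxTextEncoderStep: key hashing,
# the in-memory cache, and the safetensors save/load/clear helpers defined
# above. Assumptions: diffusers/transformers providing the imports at the top
# of this file are installed, the block can be instantiated standalone (as its
# own __init__ does), and a temporary directory stands in for the default
# ~/.cache/flux_prompt_cache. The cached tensor is random placeholder data,
# not a real T5 embedding.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_cache_dir:
        # Build the block with caching enabled and an isolated cache directory.
        step = CachedFluxTextEncoderStep(use_cache=True, cache_dir=tmp_cache_dir)

        # Simulate an entry the way _cache_prompt_embeds would create one:
        # key = sha256(prompt + suffix), value = a (1, seq_len, dim) tensor.
        key = CachedFluxTextEncoderStep._to_cache_key("a photo of a cat" + "_t5")
        step.cache[key] = torch.randn(1, 512, 64)

        # Persist to disk, drop the in-memory copy, and reload through the
        # memory-mapped safetensors path.
        step.save_cache_to_disk()
        step.cache.clear()
        step.load_cache_from_disk()
        print(f"entries: {len(step.cache)}, size: {step.get_cache_size():.2f} MB")

        # Remove both the on-disk file and the in-memory entries.
        step.clear_cache_from_disk()

    # Running the block end to end requires wiring it into a FluxModularPipeline
    # so that `components` exposes text_encoder/tokenizer/text_encoder_2/
    # tokenizer_2; that wiring depends on the installed diffusers version and
    # is omitted here.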