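"""Cached text-encoder block for a modular Flux pipeline.

Wraps the CLIP and T5 prompt encoders behind an in-memory prompt-embedding
cache that is persisted as a ``cache.safetensors`` file (by default under
``~/.cache/flux_prompt_cache``), so identical prompts are not re-encoded.
"""
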
import hashlib
import os
from typing import List, Optional, Union
import torch
from diffusers import FluxModularPipeline, ModularPipelineBlocks
from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.modular_pipelines import PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
InputParam,
OutputParam,
)
from diffusers.utils import (
    USE_PEFT_BACKEND,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from safetensors import safe_open
from safetensors.torch import save_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

logger = logging.get_logger(__name__)

class CachedFluxTextEncoderStep(ModularPipelineBlocks):
model_name = "flux"
def __init__(
self,
use_cache: bool = True,
cache_dir: Optional[str] = None,
load_from_disk: bool = True,
) -> None:
"""Initialize the cached Flux text encoder step.
Args:
use_cache: Whether to enable caching of prompt embeddings. Defaults to True.
cache_dir: Directory to store cache files. If None, uses ~/.cache/flux_prompt_cache.
load_from_disk: Whether to load existing cache from disk on initialization. Defaults to True.
"""
super().__init__()
self.cache = {} if use_cache else None
if use_cache:
self.cache_dir = cache_dir or os.path.join(
os.path.expanduser("~"), ".cache", "flux_prompt_cache"
)
os.makedirs(self.cache_dir, exist_ok=True)
else:
self.cache_dir = None
# Load existing cache if requested
if load_from_disk and use_cache:
self.load_cache_from_disk()
@property
def description(self) -> str:
return "Text Encoder step that generate text_embeddings to guide the video generation"
@property
def expected_components(self):
return [
ComponentSpec("text_encoder", CLIPTextModel),
ComponentSpec("tokenizer", CLIPTokenizer),
ComponentSpec("text_encoder_2", T5EncoderModel),
ComponentSpec("tokenizer_2", T5TokenizerFast),
]
@property
def expected_configs(self):
return []
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_2"),
InputParam("joint_attention_kwargs"),
]
@property
def intermediate_outputs(self):
return [
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
description="text embeddings used to guide the image generation",
),
OutputParam(
"pooled_prompt_embeds",
type_hint=torch.Tensor,
description="pooled text embeddings used to guide the image generation",
),
OutputParam(
"text_ids",
type_hint=torch.Tensor,
description="ids from the text sequence for RoPE",
),
]
@staticmethod
def check_inputs(block_state):
for prompt in [block_state.prompt, block_state.prompt_2]:
if prompt is not None and (
not isinstance(prompt, str) and not isinstance(prompt, list)
):
raise ValueError(
f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}"
)
def save_cache_to_disk(self):
"""Save the current cache to disk as a safetensors file."""
if not self.cache or not self.cache_dir:
return
cache_file = os.path.join(self.cache_dir, "cache.safetensors")
# Prepare tensors dict for safetensors
tensors_to_save = {}
for key, tensor in self.cache.items():
# Ensure tensor is on CPU before saving
cpu_tensor = (
tensor.cpu() if tensor.device != torch.device("cpu") else tensor
)
tensors_to_save[key] = cpu_tensor
# Save tensors
save_file(tensors_to_save, cache_file)
logger.info(f"Saved {len(tensors_to_save)} cached embeddings to {cache_file}")
def load_cache_from_disk(self):
"""Load cache from disk using memory-mapped safetensors."""
if not self.cache_dir or self.cache is None:
return
cache_file = os.path.join(self.cache_dir, "cache.safetensors")
if not os.path.exists(cache_file):
return
try:
# Open safetensors file in context manager
with safe_open(cache_file, framework="pt", device="cpu") as f:
loaded_count = 0
for key in f.keys():
self.cache[key] = f.get_tensor(key)
loaded_count += 1
logger.debug(
f"Loaded {loaded_count} cached embeddings from {cache_file} (memory-mapped)"
)
except Exception as e:
logger.warning(f"Failed to load cache from disk: {e}")
def clear_cache_from_disk(self):
"""Clear cached safetensors file from disk."""
if not self.cache_dir:
return
cache_file = os.path.join(self.cache_dir, "cache.safetensors")
if os.path.exists(cache_file):
os.remove(cache_file)
logger.info(f"Cleared cache file: {cache_file}")
# Also clear the in-memory cache
if self.cache:
self.cache.clear()
def get_cache_size(self):
"""Get the current cache size in MB."""
if not self.cache_dir:
return 0
cache_file = os.path.join(self.cache_dir, "cache.safetensors")
if os.path.exists(cache_file):
return os.path.getsize(cache_file) / (1024 * 1024) # Convert to MB
return 0
@staticmethod
def _to_cache_key(prompt: str) -> str:
"""Generate a hash key for a single prompt string."""
return hashlib.sha256(prompt.encode()).hexdigest()
@staticmethod
def _get_cached_prompt_embeds(prompts, cache_instance, cache_suffix, device=None):
"""Split prompts into cached and new, returning indices for reconstruction.
Args:
prompts: List of prompt strings to check against cache.
cache_instance: CachedFluxTextEncoderStep instance with cache, or None.
cache_suffix: Suffix to append to cache keys (e.g., "_t5", "_clip").
device: Optional device to move cached tensors to.
Returns:
tuple: (cached_embeds, prompts_to_encode, prompt_indices)
- cached_embeds: List of (idx, embedding) tuples for cached prompts
- prompts_to_encode: List of prompts that need encoding
- prompt_indices: List of original indices for prompts_to_encode
"""
cached_embeds = []
prompts_to_encode = []
prompt_indices = []
for idx, prompt in enumerate(prompts):
cache_key = CachedFluxTextEncoderStep._to_cache_key(prompt + cache_suffix)
if (
cache_instance
and cache_instance.cache
and cache_key in cache_instance.cache
):
cached_tensor = cache_instance.cache[cache_key]
# Move tensor to the correct device if specified
if device is not None and cached_tensor.device != device:
cached_tensor = cached_tensor.to(device)
cached_embeds.append((idx, cached_tensor))
else:
prompts_to_encode.append(prompt)
prompt_indices.append(idx)
return cached_embeds, prompts_to_encode, prompt_indices
@staticmethod
def _cache_prompt_embeds(
prompts, prompt_indices, prompt_embeds, cache_instance, cache_suffix
):
"""Store newly computed embeddings in cache and save to disk.
Args:
prompts: Original full list of prompts.
prompt_indices: Indices of newly encoded prompts in the original list.
prompt_embeds: Newly computed embeddings tensor.
cache_instance: CachedFluxTextEncoderStep instance with cache, or None.
cache_suffix: Suffix to append to cache keys (e.g., "_t5", "_clip").
"""
if not cache_instance or cache_instance.cache is None:
return
for i, idx in enumerate(prompt_indices):
cache_key = CachedFluxTextEncoderStep._to_cache_key(
prompts[idx] + cache_suffix
)
            # Store in the in-memory cache on CPU to save GPU memory
            tensor_slice = prompt_embeds[i : i + 1].detach().cpu()
            cache_instance.cache[cache_key] = tensor_slice
# Save updated cache to disk
cache_instance.save_cache_to_disk()
@staticmethod
def _merge_cached_prompt_embeds(
cached_embeds, prompt_indices, prompt_embeds, batch_size
):
"""Merge cached and newly computed embeddings back into original batch order.
Args:
cached_embeds: List of (idx, embedding) tuples from cache.
prompt_indices: Indices where new embeddings should be placed.
prompt_embeds: Newly computed embeddings tensor, or None if all cached.
batch_size: Total batch size for output tensor.
Returns:
torch.Tensor: Combined embeddings tensor in correct batch order.
"""
all_embeds = [None] * batch_size
# Place cached embeddings
for idx, embed in cached_embeds:
all_embeds[idx] = embed
# Place new embeddings
if prompt_embeds is not None:
for i, idx in enumerate(prompt_indices):
all_embeds[idx] = prompt_embeds[i : i + 1]
return torch.cat(all_embeds, dim=0)
@staticmethod
def _get_t5_prompt_embeds(
components,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 512,
device: torch.device = None,
cache_instance=None,
):
"""Encode prompts using T5 text encoder with caching support.
Args:
components: Pipeline components containing T5 encoder and tokenizer.
prompt: Prompt(s) to encode.
num_images_per_prompt: Number of images per prompt for duplication.
max_sequence_length: Maximum sequence length for tokenization.
device: Device to place tensors on.
cache_instance: CachedFluxTextEncoderStep instance for caching, or None.
Returns:
torch.Tensor: T5 prompt embeddings ready for diffusion model.
"""
dtype = components.text_encoder_2.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
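        # Split prompts into cached hits and prompts that still need encoding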
cached_embeds, prompts_to_encode, prompt_indices = (
CachedFluxTextEncoderStep._get_cached_prompt_embeds(
prompt, cache_instance, "_t5", device
)
)
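        # Fast path: every prompt was found in the cache, so skip the T5 forward pass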
if not prompts_to_encode:
prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
cached_embeds, prompt_indices, None, batch_size
)
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, -1
)
return prompt_embeds
if isinstance(components, TextualInversionLoaderMixin):
prompts_to_encode = components.maybe_convert_prompt(
prompts_to_encode, components.tokenizer_2
)
text_inputs = components.tokenizer_2(
prompts_to_encode,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_length=False,
return_overflowing_tokens=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
# Check for truncation
untruncated_ids = components.tokenizer_2(
prompts_to_encode, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = components.tokenizer_2.batch_decode(
untruncated_ids[:, max_sequence_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
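        # Encode the remaining prompts with T5; index [0] selects the last hidden state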
prompt_embeds = components.text_encoder_2(
text_input_ids.to(device), output_hidden_states=False
)[0]
CachedFluxTextEncoderStep._cache_prompt_embeds(
prompt, prompt_indices, prompt_embeds, cache_instance, "_t5"
)
prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
cached_embeds, prompt_indices, prompt_embeds, batch_size
)
_, seq_len, _ = prompt_embeds.shape
# Duplicate for num_images_per_prompt
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, -1
)
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
return prompt_embeds
@staticmethod
def _get_clip_prompt_embeds(
components,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
device: torch.device = None,
cache_instance=None,
):
"""Encode prompts using CLIP text encoder with caching support.
Args:
components: Pipeline components containing CLIP encoder and tokenizer.
prompt: Prompt(s) to encode.
num_images_per_prompt: Number of images per prompt for duplication.
device: Device to place tensors on.
cache_instance: CachedFluxTextEncoderStep instance for caching, or None.
Returns:
torch.Tensor: CLIP pooled prompt embeddings ready for diffusion model.
"""
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
# Split cached and new prompts
cached_embeds, prompts_to_encode, prompt_indices = (
CachedFluxTextEncoderStep._get_cached_prompt_embeds(
prompt, cache_instance, "_clip", device
)
)
# Early return if all prompts are cached
if not prompts_to_encode:
prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
cached_embeds, prompt_indices, None, batch_size
)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
if prompts_to_encode:
if isinstance(components, TextualInversionLoaderMixin):
prompts_to_encode = components.maybe_convert_prompt(
prompts_to_encode, components.tokenizer
)
text_inputs = components.tokenizer(
prompts_to_encode,
padding="max_length",
max_length=components.tokenizer.model_max_length,
truncation=True,
return_overflowing_tokens=False,
return_length=False,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
tokenizer_max_length = components.tokenizer.model_max_length
untruncated_ids = components.tokenizer(
prompts_to_encode, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[
-1
] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = components.tokenizer.batch_decode(
untruncated_ids[:, tokenizer_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = components.text_encoder(
text_input_ids.to(device), output_hidden_states=False
)
# Use pooled output of CLIPTextModel
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(
dtype=components.text_encoder.dtype, device=device
)
# Cache the new embeddings
CachedFluxTextEncoderStep._cache_prompt_embeds(
prompt, prompt_indices, prompt_embeds, cache_instance, "_clip"
)
# Combine cached and newly encoded embeddings in correct order
prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
cached_embeds,
prompt_indices,
prompt_embeds if prompts_to_encode else None,
batch_size,
)
# Duplicate for num_images_per_prompt
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds
@staticmethod
def encode_prompt(
components,
prompt: Union[str, List[str]] = None,
prompt_2: Union[str, List[str]] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
cache_instance: Optional["CachedFluxTextEncoderStep"] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in all text-encoders
            device (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            max_sequence_length (`int`, defaults to 512):
                Maximum sequence length to use when tokenizing the prompt for the T5 text encoder.
            cache_instance (`CachedFluxTextEncoderStep`, *optional*):
                Block instance whose prompt-embedding cache is consulted before re-encoding; `None` disables caching.
        """
device = device or components._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin):
components._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if components.text_encoder is not None and USE_PEFT_BACKEND:
scale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None and USE_PEFT_BACKEND:
scale_lora_layers(components.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
# We only use the pooled prompt output from the CLIPTextModel
pooled_prompt_embeds = CachedFluxTextEncoderStep._get_clip_prompt_embeds(
components,
prompt=prompt,
device=device,
num_images_per_prompt=num_images_per_prompt,
cache_instance=cache_instance,
)
prompt_embeds = CachedFluxTextEncoderStep._get_t5_prompt_embeds(
components,
prompt=prompt_2,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
cache_instance=cache_instance,
)
if components.text_encoder is not None:
if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder_2, lora_scale)
dtype = (
components.text_encoder.dtype
if components.text_encoder is not None
else torch.bfloat16
)
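        # Flux uses all-zero text token ids of shape (sequence_length, 3) for RoPE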
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
return prompt_embeds, pooled_prompt_embeds, text_ids
@torch.no_grad()
def __call__(
self, components: FluxModularPipeline, state: PipelineState
) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.device = components._execution_device
# Encode input prompt
block_state.text_encoder_lora_scale = (
block_state.joint_attention_kwargs.get("scale", None)
if block_state.joint_attention_kwargs is not None
else None
)
(
block_state.prompt_embeds,
block_state.pooled_prompt_embeds,
block_state.text_ids,
) = self.encode_prompt(
components,
prompt=block_state.prompt,
prompt_2=None,
prompt_embeds=None,
pooled_prompt_embeds=None,
device=block_state.device,
num_images_per_prompt=1, # TODO: hardcoded for now.
max_sequence_length=512,
lora_scale=block_state.text_encoder_lora_scale,
cache_instance=self
if self.cache is not None
else None, # Pass self as cache_instance
)
# Add outputs
self.set_block_state(state, block_state)
return components, state
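
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, commented out): exercising the caching helpers
# defined on CachedFluxTextEncoderStep. Only methods defined in this file are
# referenced; how the block is wired into a full FluxModularPipeline depends
# on the surrounding modular-pipeline setup and is not shown here.
#
#   text_encoder_step = CachedFluxTextEncoderStep(use_cache=True)
#   # ...run the modular pipeline that contains this block one or more times...
#   print(f"prompt cache on disk: {text_encoder_step.get_cache_size():.2f} MB")
#   text_encoder_step.save_cache_to_disk()      # persist in-memory entries
#   text_encoder_step.clear_cache_from_disk()   # drop both disk and memory cache
# ---------------------------------------------------------------------------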