import hashlib
import os
from typing import List, Optional, Union

import torch
from diffusers import FluxModularPipeline, ModularPipelineBlocks
from diffusers.loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.modular_pipelines import PipelineState
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)
from diffusers.utils import (
    USE_PEFT_BACKEND,
    logging,
    scale_lora_layers,
    unscale_lora_layers,
)
from safetensors import safe_open
from safetensors.torch import save_file
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

logger = logging.get_logger(__name__)


class CachedFluxTextEncoderStep(ModularPipelineBlocks):
    model_name = "flux"

    def __init__(
        self,
        use_cache: bool = True,
        cache_dir: Optional[str] = None,
        load_from_disk: bool = True,
    ) -> None:
        """Initialize the cached Flux text encoder step.

        Args:
            use_cache: Whether to enable caching of prompt embeddings. Defaults to True.
            cache_dir: Directory to store cache files. If None, uses ~/.cache/flux_prompt_cache.
            load_from_disk: Whether to load an existing cache from disk on initialization. Defaults to True.
        """
        super().__init__()
        self.cache = {} if use_cache else None
        if use_cache:
            self.cache_dir = cache_dir or os.path.join(
                os.path.expanduser("~"), ".cache", "flux_prompt_cache"
            )
            os.makedirs(self.cache_dir, exist_ok=True)
        else:
            self.cache_dir = None

        if load_from_disk and use_cache:
            self.load_cache_from_disk()

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text embeddings to guide the image generation"

    @property
    def expected_components(self):
        return [
            ComponentSpec("text_encoder", CLIPTextModel),
            ComponentSpec("tokenizer", CLIPTokenizer),
            ComponentSpec("text_encoder_2", T5EncoderModel),
            ComponentSpec("tokenizer_2", T5TokenizerFast),
        ]

    @property
    def expected_configs(self):
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_2"),
            InputParam("joint_attention_kwargs"),
        ]

    @property
    def intermediate_outputs(self):
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="text embeddings used to guide the image generation",
            ),
            OutputParam(
                "pooled_prompt_embeds",
                type_hint=torch.Tensor,
                description="pooled text embeddings used to guide the image generation",
            ),
            OutputParam(
                "text_ids",
                type_hint=torch.Tensor,
                description="ids from the text sequence for RoPE",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        for prompt in [block_state.prompt, block_state.prompt_2]:
            if prompt is not None and not isinstance(prompt, (str, list)):
                raise ValueError(
                    f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}"
                )

    def save_cache_to_disk(self):
        """Save the current cache to disk as a safetensors file."""
        if not self.cache or not self.cache_dir:
            return

        cache_file = os.path.join(self.cache_dir, "cache.safetensors")

        # Move any GPU-resident embeddings to CPU before serializing.
        tensors_to_save = {}
        for key, tensor in self.cache.items():
            cpu_tensor = (
                tensor.cpu() if tensor.device != torch.device("cpu") else tensor
            )
            tensors_to_save[key] = cpu_tensor

        save_file(tensors_to_save, cache_file)
        logger.info(f"Saved {len(tensors_to_save)} cached embeddings to {cache_file}")

    def load_cache_from_disk(self):
        """Load cache from disk using memory-mapped safetensors."""
        if not self.cache_dir or self.cache is None:
            return

        cache_file = os.path.join(self.cache_dir, "cache.safetensors")

        if not os.path.exists(cache_file):
            return

        try:
            with safe_open(cache_file, framework="pt", device="cpu") as f:
                loaded_count = 0
                for key in f.keys():
                    self.cache[key] = f.get_tensor(key)
                    loaded_count += 1

            logger.debug(
                f"Loaded {loaded_count} cached embeddings from {cache_file} (memory-mapped)"
            )
        except Exception as e:
            logger.warning(f"Failed to load cache from disk: {e}")

    def clear_cache_from_disk(self):
        """Clear the cached safetensors file from disk and empty the in-memory cache."""
        if not self.cache_dir:
            return

        cache_file = os.path.join(self.cache_dir, "cache.safetensors")
        if os.path.exists(cache_file):
            os.remove(cache_file)
            logger.info(f"Cleared cache file: {cache_file}")

        if self.cache:
            self.cache.clear()

    def get_cache_size(self):
        """Get the current on-disk cache size in MB."""
        if not self.cache_dir:
            return 0

        cache_file = os.path.join(self.cache_dir, "cache.safetensors")
        if os.path.exists(cache_file):
            return os.path.getsize(cache_file) / (1024 * 1024)
        return 0

    @staticmethod
    def _to_cache_key(prompt: str) -> str:
        """Generate a hash key for a single prompt string."""
        return hashlib.sha256(prompt.encode()).hexdigest()

    @staticmethod
    def _get_cached_prompt_embeds(prompts, cache_instance, cache_suffix, device=None):
        """Split prompts into cached and new, returning indices for reconstruction.

        Args:
            prompts: List of prompt strings to check against cache.
            cache_instance: CachedFluxTextEncoderStep instance with cache, or None.
            cache_suffix: Suffix to append to cache keys (e.g., "_t5", "_clip").
            device: Optional device to move cached tensors to.

        Returns:
            tuple: (cached_embeds, prompts_to_encode, prompt_indices)
                - cached_embeds: List of (idx, embedding) tuples for cached prompts
                - prompts_to_encode: List of prompts that need encoding
                - prompt_indices: List of original indices for prompts_to_encode
        """
        cached_embeds = []
        prompts_to_encode = []
        prompt_indices = []

        for idx, prompt in enumerate(prompts):
            cache_key = CachedFluxTextEncoderStep._to_cache_key(prompt + cache_suffix)
            if (
                cache_instance
                and cache_instance.cache
                and cache_key in cache_instance.cache
            ):
                cached_tensor = cache_instance.cache[cache_key]
                if device is not None and cached_tensor.device != device:
                    cached_tensor = cached_tensor.to(device)
                cached_embeds.append((idx, cached_tensor))
            else:
                prompts_to_encode.append(prompt)
                prompt_indices.append(idx)

        return cached_embeds, prompts_to_encode, prompt_indices

    @staticmethod
    def _cache_prompt_embeds(
        prompts, prompt_indices, prompt_embeds, cache_instance, cache_suffix
    ):
        """Store newly computed embeddings in cache and save to disk.

        Args:
            prompts: Original full list of prompts.
            prompt_indices: Indices of newly encoded prompts in the original list.
            prompt_embeds: Newly computed embeddings tensor.
            cache_instance: CachedFluxTextEncoderStep instance with cache, or None.
            cache_suffix: Suffix to append to cache keys (e.g., "_t5", "_clip").
        """
        if not cache_instance or cache_instance.cache is None:
            return

        for i, idx in enumerate(prompt_indices):
            cache_key = CachedFluxTextEncoderStep._to_cache_key(
                prompts[idx] + cache_suffix
            )
            # Slice with i : i + 1 to keep the batch dimension so cached entries can be concatenated later.
            tensor_slice = prompt_embeds[i : i + 1]
            cache_instance.cache[cache_key] = tensor_slice

        cache_instance.save_cache_to_disk()

    @staticmethod
    def _merge_cached_prompt_embeds(
        cached_embeds, prompt_indices, prompt_embeds, batch_size
    ):
        """Merge cached and newly computed embeddings back into original batch order.

        Args:
            cached_embeds: List of (idx, embedding) tuples from cache.
            prompt_indices: Indices where new embeddings should be placed.
            prompt_embeds: Newly computed embeddings tensor, or None if all cached.
            batch_size: Total batch size for output tensor.

        Returns:
            torch.Tensor: Combined embeddings tensor in correct batch order.
        """
        all_embeds = [None] * batch_size

        for idx, embed in cached_embeds:
            all_embeds[idx] = embed

        if prompt_embeds is not None:
            for i, idx in enumerate(prompt_indices):
                all_embeds[idx] = prompt_embeds[i : i + 1]

        return torch.cat(all_embeds, dim=0)

    @staticmethod
    def _get_t5_prompt_embeds(
        components,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        max_sequence_length: int = 512,
        device: torch.device = None,
        cache_instance=None,
    ):
        """Encode prompts using the T5 text encoder with caching support.

        Args:
            components: Pipeline components containing the T5 encoder and tokenizer.
            prompt: Prompt(s) to encode.
            num_images_per_prompt: Number of images per prompt for duplication.
            max_sequence_length: Maximum sequence length for tokenization.
            device: Device to place tensors on.
            cache_instance: CachedFluxTextEncoderStep instance for caching, or None.

        Returns:
            torch.Tensor: T5 prompt embeddings ready for the diffusion model.
        """
        dtype = components.text_encoder_2.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        cached_embeds, prompts_to_encode, prompt_indices = (
            CachedFluxTextEncoderStep._get_cached_prompt_embeds(
                prompt, cache_instance, "_t5", device
            )
        )

        # If every prompt was cached, merge and duplicate without running the encoder.
        if not prompts_to_encode:
            prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
                cached_embeds, prompt_indices, None, batch_size
            )
            _, seq_len, _ = prompt_embeds.shape
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
            prompt_embeds = prompt_embeds.view(
                batch_size * num_images_per_prompt, seq_len, -1
            )
            return prompt_embeds

        if isinstance(components, TextualInversionLoaderMixin):
            prompts_to_encode = components.maybe_convert_prompt(
                prompts_to_encode, components.tokenizer_2
            )

        text_inputs = components.tokenizer_2(
            prompts_to_encode,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        # Warn if the prompt was truncated by the max_sequence_length limit.
        untruncated_ids = components.tokenizer_2(
            prompts_to_encode, padding="longest", return_tensors="pt"
        ).input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = components.tokenizer_2.batch_decode(
                untruncated_ids[:, max_sequence_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = components.text_encoder_2(
            text_input_ids.to(device), output_hidden_states=False
        )[0]

        CachedFluxTextEncoderStep._cache_prompt_embeds(
            prompt, prompt_indices, prompt_embeds, cache_instance, "_t5"
        )

        prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
            cached_embeds, prompt_indices, prompt_embeds, batch_size
        )
        _, seq_len, _ = prompt_embeds.shape

        # Duplicate text embeddings for each generation per prompt.
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        return prompt_embeds

    @staticmethod
    def _get_clip_prompt_embeds(
        components,
        prompt: Union[str, List[str]] = None,
        num_images_per_prompt: int = 1,
        device: torch.device = None,
        cache_instance=None,
    ):
        """Encode prompts using the CLIP text encoder with caching support.

        Args:
            components: Pipeline components containing the CLIP encoder and tokenizer.
            prompt: Prompt(s) to encode.
            num_images_per_prompt: Number of images per prompt for duplication.
            device: Device to place tensors on.
            cache_instance: CachedFluxTextEncoderStep instance for caching, or None.

        Returns:
            torch.Tensor: CLIP pooled prompt embeddings ready for the diffusion model.
        """
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        cached_embeds, prompts_to_encode, prompt_indices = (
            CachedFluxTextEncoderStep._get_cached_prompt_embeds(
                prompt, cache_instance, "_clip", device
            )
        )

        # If every prompt was cached, merge and duplicate without running the encoder.
        if not prompts_to_encode:
            prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
                cached_embeds, prompt_indices, None, batch_size
            )
            prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
            prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
            return prompt_embeds

        if prompts_to_encode:
            if isinstance(components, TextualInversionLoaderMixin):
                prompts_to_encode = components.maybe_convert_prompt(
                    prompts_to_encode, components.tokenizer
                )

            text_inputs = components.tokenizer(
                prompts_to_encode,
                padding="max_length",
                max_length=components.tokenizer.model_max_length,
                truncation=True,
                return_overflowing_tokens=False,
                return_length=False,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids
            tokenizer_max_length = components.tokenizer.model_max_length
            untruncated_ids = components.tokenizer(
                prompts_to_encode, padding="longest", return_tensors="pt"
            ).input_ids

            # Warn if the prompt was truncated by the CLIP context limit.
            if untruncated_ids.shape[-1] >= text_input_ids.shape[
                -1
            ] and not torch.equal(text_input_ids, untruncated_ids):
                removed_text = components.tokenizer.batch_decode(
                    untruncated_ids[:, tokenizer_max_length - 1 : -1]
                )
                logger.warning(
                    "The following part of your input was truncated because CLIP can only handle sequences up to"
                    f" {tokenizer_max_length} tokens: {removed_text}"
                )

            prompt_embeds = components.text_encoder(
                text_input_ids.to(device), output_hidden_states=False
            )

            # Use the pooled output of the CLIP text encoder.
            prompt_embeds = prompt_embeds.pooler_output
            prompt_embeds = prompt_embeds.to(
                dtype=components.text_encoder.dtype, device=device
            )

            CachedFluxTextEncoderStep._cache_prompt_embeds(
                prompt, prompt_indices, prompt_embeds, cache_instance, "_clip"
            )

        prompt_embeds = CachedFluxTextEncoderStep._merge_cached_prompt_embeds(
            cached_embeds,
            prompt_indices,
            prompt_embeds if prompts_to_encode else None,
            batch_size,
        )

        # Duplicate pooled embeddings for each generation per prompt.
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    @staticmethod
    def encode_prompt(
        components,
        prompt: Union[str, List[str]] = None,
        prompt_2: Union[str, List[str]] = None,
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
        cache_instance: Optional["CachedFluxTextEncoderStep"] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
                is used in all text-encoders.
            device (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled text embeddings will be generated from the `prompt` input argument.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or components._execution_device

        # Set the LoRA scale so the patched LoRA layers of the text encoders can pick it up.
        if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin):
            components._lora_scale = lora_scale

            # Dynamically adjust the LoRA scale.
            if components.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(components.text_encoder, lora_scale)
            if components.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(components.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # CLIP provides the pooled embeddings, T5 provides the sequence embeddings.
            pooled_prompt_embeds = CachedFluxTextEncoderStep._get_clip_prompt_embeds(
                components,
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                cache_instance=cache_instance,
            )
            prompt_embeds = CachedFluxTextEncoderStep._get_t5_prompt_embeds(
                components,
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                cache_instance=cache_instance,
            )

        if components.text_encoder is not None:
            if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers.
                unscale_lora_layers(components.text_encoder, lora_scale)

        if components.text_encoder_2 is not None:
            if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers.
                unscale_lora_layers(components.text_encoder_2, lora_scale)

        dtype = (
            components.text_encoder.dtype
            if components.text_encoder is not None
            else torch.bfloat16
        )
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids

    @torch.no_grad()
    def __call__(
        self, components: FluxModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        block_state.text_encoder_lora_scale = (
            block_state.joint_attention_kwargs.get("scale", None)
            if block_state.joint_attention_kwargs is not None
            else None
        )
        (
            block_state.prompt_embeds,
            block_state.pooled_prompt_embeds,
            block_state.text_ids,
        ) = self.encode_prompt(
            components,
            prompt=block_state.prompt,
            prompt_2=None,
            prompt_embeds=None,
            pooled_prompt_embeds=None,
            device=block_state.device,
            num_images_per_prompt=1,
            max_sequence_length=512,
            lora_scale=block_state.text_encoder_lora_scale,
            cache_instance=self if self.cache is not None else None,
        )

        self.set_block_state(state, block_state)
        return components, state
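

# A minimal usage sketch, not part of the block implementation: it assumes the block can be
# instantiated standalone and only exercises the cache helpers defined above. Wiring the block
# into a full Flux modular pipeline and loading the actual text encoders happens elsewhere.
if __name__ == "__main__":
    text_encoder_step = CachedFluxTextEncoderStep(use_cache=True, load_from_disk=True)

    # Cache keys are SHA-256 hashes of the prompt text plus an encoder-specific suffix,
    # so the same prompt gets separate entries for the T5 and CLIP encoders.
    print(text_encoder_step._to_cache_key("a photo of a cat" + "_t5"))
    print(text_encoder_step._to_cache_key("a photo of a cat" + "_clip"))

    # Inspect the on-disk cache; clear_cache_from_disk() would delete it entirely.
    print(f"cache size: {text_encoder_step.get_cache_size():.2f} MB")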