import contextlib
import unittest.mock

import torch

from . import first_block_cache
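
# First Block Cache (FBCache) skips most of the diffusion model on steps whose
# first transformer block output barely changes: the first block always runs,
# its residual is compared against the previous step's, and when the relative
# difference stays under `residual_diff_threshold` the cached residual of the
# remaining blocks is reused. (A summary of the technique; the mechanics live
# in the companion `first_block_cache` module.)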


class ApplyFBCacheOnModel:

    def patch(
        self,
        model,
        object_to_patch,
        residual_diff_threshold,
        max_consecutive_cache_hits=-1,
        start=0.0,
        end=1.0,
    ):
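        """Patch a wrapped model to apply First Block Cache.

        residual_diff_threshold: threshold compared against the change in the
            first block's residual (see `first_block_cache` for the exact
            metric); values <= 0 disable caching.
        max_consecutive_cache_hits: cap on back-to-back cached steps; 0
            disables caching and negative values mean no limit.
        start / end: fraction of the sampling schedule (0..1) inside which
            caching is allowed.
        """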
        if residual_diff_threshold <= 0.0 or max_consecutive_cache_hits == 0:
            # Caching is effectively disabled; return the model unchanged.
            return (model[0], )
        # first_block_cache.patch_get_output_data()
        using_validation = max_consecutive_cache_hits >= 0 or start > 0 or end < 1
        if using_validation:
            model_sampling = model[0].get_model_object("model_sampling")
            # Convert the start/end percentages of the sampling schedule into
            # sigmas so they can be compared against the current timestep.
            start_sigma, end_sigma = (float(
                model_sampling.percent_to_sigma(pct)) for pct in (start, end))
            del model_sampling
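
            # Sigmas decrease over the course of sampling, so the percentage
            # window maps to start_sigma >= end_sigma and the in-window check
            # below is `end_sigma <= current_timestep <= start_sigma`.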
            def validate_use_cache(use_cached):
                nonlocal consecutive_cache_hits
                use_cached = use_cached and end_sigma <= current_timestep <= start_sigma
                use_cached = use_cached and (max_consecutive_cache_hits < 0
                                             or consecutive_cache_hits
                                             < max_consecutive_cache_hits)
                consecutive_cache_hits = consecutive_cache_hits + 1 if use_cached else 0
                return use_cached
        else:
            validate_use_cache = None
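
        # State shared by the closures below: used to detect when a new
        # sampling run starts (timestep no longer strictly decreasing, or the
        # model input's shape/dtype/device changed) and to count consecutive
        # cache hits.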
        prev_timestep = None
        prev_input_state = None
        current_timestep = None
        consecutive_cache_hits = 0

        def reset_cache_state():
            # Resets the cache state and hits/time tracking variables.
            nonlocal prev_input_state, prev_timestep, consecutive_cache_hits
            prev_input_state = prev_timestep = None
            consecutive_cache_hits = 0
            first_block_cache.set_current_cache_context(
                first_block_cache.create_cache_context())

        def ensure_cache_state(model_input: torch.Tensor, timestep: float):
            # Validates the current cache state and hits/time tracking
            # variables and triggers a reset if necessary. Also updates
            # current_timestep.
            nonlocal current_timestep
            input_state = (model_input.shape, model_input.dtype,
                           model_input.device)
            need_reset = (
                prev_timestep is None
                or prev_input_state != input_state
                or first_block_cache.get_current_cache_context() is None
                or timestep >= prev_timestep)
            if need_reset:
                reset_cache_state()
            current_timestep = timestep

        def update_cache_state(model_input: torch.Tensor, timestep: float):
            # Updates the previous timestep and input state validation
            # variables.
            nonlocal prev_timestep, prev_input_state
            prev_timestep = timestep
            prev_input_state = (model_input.shape, model_input.dtype,
                                model_input.device)

        # Clone so patching does not mutate the caller's model.
        model = model[0].clone()
        diffusion_model = model.get_model_object(object_to_patch)
        if diffusion_model.__class__.__name__ in ("UNetModel", "Flux"):
            if diffusion_model.__class__.__name__ == "UNetModel":
                create_patch_function = first_block_cache.create_patch_unet_model__forward
            elif diffusion_model.__class__.__name__ == "Flux":
                create_patch_function = first_block_cache.create_patch_flux_forward_orig
            else:
                raise ValueError(
                    f"Unsupported model {diffusion_model.__class__.__name__}")

            patch_forward = create_patch_function(
                diffusion_model,
                residual_diff_threshold=residual_diff_threshold,
                validate_can_use_cache_function=validate_use_cache,
            )
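
            # For UNet and Flux the model's forward method itself is patched.
            # The architectures handled in the else branch below instead get
            # their transformer block lists swapped for each call.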
            def model_unet_function_wrapper(model_function, kwargs):
                try:
                    input = kwargs["input"]
                    timestep = kwargs["timestep"]
                    c = kwargs["c"]
                    t = timestep[0].item()
                    ensure_cache_state(input, t)
                    with patch_forward():
                        result = model_function(input, timestep, **c)
                    update_cache_state(input, t)
                    return result
                except Exception as exc:
                    reset_cache_state()
                    raise exc from None
        else:
            is_non_native_ltxv = False
            if diffusion_model.__class__.__name__ == "LTXVTransformer3D":
                # Non-native LTXV wraps the actual transformer; unwrap it.
                is_non_native_ltxv = True
                diffusion_model = diffusion_model.transformer

            double_blocks_name = None
            single_blocks_name = None
            if hasattr(diffusion_model, "transformer_blocks"):
                double_blocks_name = "transformer_blocks"
            elif hasattr(diffusion_model, "double_blocks"):
                double_blocks_name = "double_blocks"
            elif hasattr(diffusion_model, "joint_blocks"):
                double_blocks_name = "joint_blocks"
            else:
                raise ValueError(
                    f"No double blocks found for {diffusion_model.__class__.__name__}"
                )
            if hasattr(diffusion_model, "single_blocks"):
                single_blocks_name = "single_blocks"

            if is_non_native_ltxv:
                original_create_skip_layer_mask = getattr(
                    diffusion_model, "create_skip_layer_mask", None)
                if original_create_skip_layer_mask is not None:
                    # original_double_blocks = getattr(diffusion_model,
                    #                                  double_blocks_name)

                    def new_create_skip_layer_mask(self, *args, **kwargs):
                        # with unittest.mock.patch.object(self, double_blocks_name,
                        #                                 original_double_blocks):
                        #     return original_create_skip_layer_mask(*args, **kwargs)
                        # return original_create_skip_layer_mask(*args, **kwargs)
                        raise RuntimeError(
                            "STG is not supported with FBCache yet")

                    diffusion_model.create_skip_layer_mask = new_create_skip_layer_mask.__get__(
                        diffusion_model)

            model_cls_name = diffusion_model.__class__.__name__
            cached_transformer_blocks = torch.nn.ModuleList([
                first_block_cache.CachedTransformerBlocks(
                    None if double_blocks_name is None else getattr(
                        diffusion_model, double_blocks_name),
                    None if single_blocks_name is None else getattr(
                        diffusion_model, single_blocks_name),
                    residual_diff_threshold=residual_diff_threshold,
                    validate_can_use_cache_function=validate_use_cache,
                    cat_hidden_states_first=model_cls_name == "HunyuanVideo",
                    return_hidden_states_only=(model_cls_name == "LTXVModel"
                                               or is_non_native_ltxv),
                    clone_original_hidden_states=model_cls_name == "LTXVModel",
                    return_hidden_states_first=model_cls_name
                    != "OpenAISignatureMMDITWrapper",
                    accept_hidden_states_first=model_cls_name
                    != "OpenAISignatureMMDITWrapper",
                )
            ])
            dummy_single_transformer_blocks = torch.nn.ModuleList()
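
            # On each model call the real block lists are temporarily swapped
            # via unittest.mock.patch.object: the double blocks are replaced by
            # the caching wrapper (which also runs the single blocks itself)
            # and the single blocks by the empty ModuleList above.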
            def model_unet_function_wrapper(model_function, kwargs):
                try:
                    input = kwargs["input"]
                    timestep = kwargs["timestep"]
                    c = kwargs["c"]
                    t = timestep[0].item()
                    ensure_cache_state(input, t)
                    with unittest.mock.patch.object(
                            diffusion_model,
                            double_blocks_name,
                            cached_transformer_blocks,
                    ), unittest.mock.patch.object(
                            diffusion_model,
                            single_blocks_name,
                            dummy_single_transformer_blocks,
                    ) if single_blocks_name is not None else contextlib.nullcontext():
                        result = model_function(input, timestep, **c)
                    update_cache_state(input, t)
                    return result
                except Exception as exc:
                    reset_cache_state()
                    raise exc from None

        model.set_model_unet_function_wrapper(model_unet_function_wrapper)
        return (model, )
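
# Hypothetical usage sketch (values illustrative, not part of this module).
# This version of `patch` expects the model wrapped in a 1-tuple, matching
# `model[0].clone()` above:
#
#   patched_model, = ApplyFBCacheOnModel().patch(
#       (model,),
#       "diffusion_model",
#       residual_diff_threshold=0.12,
#       max_consecutive_cache_hits=3,   # cap on back-to-back cached steps
#       start=0.1,                      # cache only within 10%..90% of sampling
#       end=0.9,
#   )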