import abc
import types

import torch
from diffusers.models.transformers.transformer_flux import (
    FluxSingleTransformerBlock, FluxTransformerBlock)

from .flux_transformer_forward import (joint_transformer_forward,
                                       single_transformer_forward)
class FeatureCollector:
    """Hooks a feature controller into the forward passes of a Flux transformer."""

    def __init__(self, transformer, controller, layer_list=[]):
        self.transformer = transformer
        self.controller = controller
        self.layer_list = layer_list

    def register_transformer_control(self):
        # Wrap every block's forward so the controller sees its hidden states.
        # Layers are numbered globally across joint and single blocks.
        index = 0
        for joint_transformer in self.transformer.transformer_blocks:
            place_in_transformer = f'joint_{index}'
            joint_transformer.forward = joint_transformer_forward(
                joint_transformer, self.controller, place_in_transformer)
            index += 1
        for single_transformer in self.transformer.single_transformer_blocks:
            place_in_transformer = f'single_{index}'
            single_transformer.forward = single_transformer_forward(
                single_transformer, self.controller, place_in_transformer)
            index += 1
        self.controller.num_layers = index

    def restore_orig_transformer(self):
        # Re-wrap the blocks without a controller, i.e. plain forward behaviour.
        place_in_transformer = ''
        for joint_transformer in self.transformer.transformer_blocks:
            joint_transformer.forward = joint_transformer_forward(
                joint_transformer, None, place_in_transformer)
        for single_transformer in self.transformer.single_transformer_blocks:
            single_transformer.forward = single_transformer_forward(
                single_transformer, None, place_in_transformer)
class FeatureControl(abc.ABC):
    """Base controller: tracks the current denoising step and transformer layer."""

    def __init__(self):
        self.cur_step = 0
        self.num_layers = -1
        self.cur_layer = 0

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    def forward(self, hidden_states, place_in_transformer: str):
        raise NotImplementedError

    def __call__(self, hidden_states, place_in_transformer: str):
        hidden_states = self.forward(hidden_states, place_in_transformer)
        # Advance the layer counter; once every layer has been visited,
        # move on to the next denoising step.
        self.cur_layer = self.cur_layer + 1
        if self.cur_layer == self.num_layers:
            self.cur_layer = 0
            self.cur_step = self.cur_step + 1
            self.between_steps()
        return hidden_states

    def reset(self):
        self.cur_step = 0
        self.cur_layer = 0
class FeatureReplace(FeatureControl):
    """Copies hidden states from the source sample (batch index 0) into the targets."""

    def __init__(
        self,
        layer_list=[],
        feature_steps=7
    ):
        super().__init__()
        self.layer_list = layer_list
        self.feature_steps = feature_steps

    def forward(self, hidden_states, place_in_transformer):
        layer_index = int(place_in_transformer.split('_')[-1])
        # Only act on the selected layers and during the first feature_steps steps.
        if (layer_index not in self.layer_list) or (self.cur_step >= self.feature_steps):
            return hidden_states
        hs_dim = hidden_states.shape[1]
        t5_dim = 512       # T5 text tokens (Flux max_sequence_length)
        latent_dim = 4096  # image latent tokens (e.g. a 1024x1024 image)
        attn_dim = t5_dim + latent_dim
        index_all = torch.arange(attn_dim)
        t5_index, latent_index = index_all.split([t5_dim, latent_dim])
        mask = torch.ones(hs_dim).to(device=hidden_states.device, dtype=hidden_states.dtype)
        if 'single' in place_in_transformer:
            # Single blocks carry the concatenated [text; image] sequence:
            # zero out the text positions so only image latents are replaced.
            mask[t5_index] = 0
        mask = mask[None, :, None]
        # Blend: the first sample provides the features, the remaining samples receive them.
        source_hs = hidden_states[:1]
        target_hs = hidden_states[1:]
        target_hs = source_hs * mask + target_hs * (1 - mask)
        hidden_states[1:] = target_hs
        return hidden_states
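

# Minimal usage sketch (illustrative, not part of the original file). It assumes
# a diffusers FluxPipeline whose transformer exposes `transformer_blocks` /
# `single_transformer_blocks`, and that the wrapped forwards above invoke the
# controller on each block's hidden states. Model id, prompts, and layer choices
# are placeholder assumptions.
def _example_usage():
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
    controller = FeatureReplace(layer_list=list(range(10, 20)), feature_steps=7)
    collector = FeatureCollector(pipe.transformer, controller)
    collector.register_transformer_control()
    # Batch of two prompts: index 0 supplies features, index 1 receives them.
    # A 1024x1024 output yields 4096 image latent tokens, matching latent_dim above.
    images = pipe(["a photo of a cat", "a photo of a dog"],
                  height=1024, width=1024, num_inference_steps=28).images
    collector.restore_orig_transformer()
    controller.reset()  # reset counters before reusing the controller
    return images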