```python
import abc
import types

import torch
from diffusers.models.transformers.transformer_flux import (
    FluxSingleTransformerBlock,
    FluxTransformerBlock,
)


def joint_transformer_forward(self, controller, place_in_transformer):
    def forward(
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        joint_attention_kwargs=None,
    ):
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )

        # Attention.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output

        # Change here: route the image-stream feed-forward output through the controller.
        if controller is not None:
            ff_output = controller(ff_output, place_in_transformer)

        hidden_states = hidden_states + ff_output

        # Process attention outputs for the `encoder_hidden_states`.
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = encoder_hidden_states + context_attn_output

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        context_ff_output = c_gate_mlp.unsqueeze(1) * context_ff_output
        encoder_hidden_states = encoder_hidden_states + context_ff_output

        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states

    return forward
```
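Neither `forward` defines the `controller` it calls: anything callable with the signature `(features, place_in_transformer) -> features` will work. A minimal sketch of that interface, reusing the `abc` and `torch` imports above (the class names `FeatureController` and `FeatureStore` are illustrative, not from this codebase):

```python
class FeatureController(abc.ABC):
    """Illustrative controller interface assumed by the patched forwards."""

    @abc.abstractmethod
    def __call__(self, feats: torch.Tensor, place_in_transformer: str) -> torch.Tensor:
        """Observe (and optionally edit) `feats`; must return a tensor of the same shape."""


class FeatureStore(FeatureController):
    """Example controller: records per-block features and passes them through unchanged."""

    def __init__(self):
        self.features = {}

    def __call__(self, feats, place_in_transformer):
        self.features.setdefault(place_in_transformer, []).append(feats.detach())
        return feats
```

The single-stream FLUX blocks get the same hook, this time on the block output just before the residual connection: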
```python
def single_transformer_forward(self, controller, place_in_transformer):
    def forward(
        hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        joint_attention_kwargs=None,
    ):
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)

        # Change here: route the block output through the controller before the residual add.
        if controller is not None:
            hidden_states = controller(hidden_states, place_in_transformer)

        hidden_states = residual + hidden_states
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return hidden_states

    return forward
```
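Nothing in this snippet attaches the two factories to the model, but their shape (the block passed in as `self`, the patched `forward` returned as a plain closure) points at simple attribute assignment. A plausible registration sketch, assuming the `diffusers` `FluxTransformer2DModel` layout that this block code comes from (`transformer_blocks` and `single_transformer_blocks` module lists); the helper name `register_controller` and the place labels are illustrative:

```python
# Hypothetical helper (not shown above): install the patched forwards on every
# FLUX block and tag each one with a place label the controller can key on.
def register_controller(transformer, controller):
    for i, block in enumerate(transformer.transformer_blocks):
        if isinstance(block, FluxTransformerBlock):
            block.forward = joint_transformer_forward(block, controller, f"joint_{i}")
    for i, block in enumerate(transformer.single_transformer_blocks):
        if isinstance(block, FluxSingleTransformerBlock):
            block.forward = single_transformer_forward(block, controller, f"single_{i}")


# Usage sketch (assuming `pipe` is a loaded FluxPipeline):
#   controller = FeatureStore()
#   register_controller(pipe.transformer, controller)
#   image = pipe("a photo of a cat").images[0]
```

Because each patched `forward` closes over its own block and takes no `self` parameter, plain attribute assignment is enough: the instance-level `forward` shadows the class method when `nn.Module.__call__` dispatches, so every denoising step now routes the block outputs through the controller.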