|
|
from comfy.ldm.modules import attention as comfy_attention |
|
|
import logging |
|
|
import comfy.model_patcher |
|
|
import comfy.utils |
|
|
import comfy.sd |
|
|
import torch |
|
|
import comfy.model_management as mm |
|
|
from comfy.cli_args import args |
|
|
|
|
|
sageattn_modes = [ |
|
|
"disabled", |
|
|
"auto", |
|
|
"auto_speed", |
|
|
"auto_quality", |
|
|
"sageattn_qk_int8_pv_fp16_cuda", |
|
|
"sageattn_qk_int8_pv_fp16_triton", |
|
|
"sageattn_qk_int8_pv_fp8_cuda", |
|
|
"sageattn_qk_int8_pv_fp8_cuda++", |
|
|
] |
|
|
|
|
|
_initialized = False |
|
|
|
|
|
_sage_warned_once = False |
|
|
_sage_generic_warned_once = False |
|
|
_original_functions = {} |
|
|
|
|
|
|
|
|
|
|
|
CURRENT_PV_ACCUM = None |
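# Optional global override for the sage kernels' pv_accum_dtype; a per-call value in
# transformer_options["sageattn"]["pv_accum_dtype"] takes precedence (see attention_sage below).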
|
|
|
|
|
|
|
|
_attn_entropy_enabled = False |
|
|
_attn_entropy_last = None |
|
|
_attn_probe_heads_cap = 4 |
|
|
_attn_probe_tokens_cap = 1024 |
|
|
|
|
|
def enable_attention_entropy_capture(enable: bool, max_tokens: int = 1024, max_heads: int = 4): |
|
|
"""Toggle capturing a tiny attention entropy map during optimized_attention. |
|
|
Stores a normalized map per forward pass; consumer may upsample to latent size. |
|
|
""" |
|
|
global _attn_entropy_enabled, _attn_probe_tokens_cap, _attn_probe_heads_cap, _attn_entropy_last |
|
|
_attn_entropy_enabled = bool(enable) |
|
|
_attn_probe_tokens_cap = int(max(128, min(16384, max_tokens))) |
|
|
_attn_probe_heads_cap = int(max(1, min(32, max_heads))) |
|
|
if not _attn_entropy_enabled: |
|
|
_attn_entropy_last = None |
|
|
|
|
|
def get_attention_entropy_map(clear: bool = False): |
|
|
"""Return last captured attention entropy map (B,1,h',w') in [0,1] or None.""" |
|
|
global _attn_entropy_last |
|
|
out = _attn_entropy_last |
|
|
if clear: |
|
|
_attn_entropy_last = None |
|
|
return out |
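# Minimal usage sketch (consumer side; names outside this module are illustrative):
#
#   enable_attention_entropy_capture(True, max_tokens=1024, max_heads=4)
#   _ = run_sampling(...)                        # any forward pass through optimized_attention
#   ent = get_attention_entropy_map(clear=True)  # (B, 1, h', w') in [0, 1], or None
#   if ent is not None:
#       ent = torch.nn.functional.interpolate(ent, size=latent.shape[-2:], mode="bilinear")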
|
|
|
|
|
|
|
|
_kv_prune_enabled = False |
|
|
_kv_prune_keep = 0.85 |
|
|
_kv_prune_min_tokens = 128 |
|
|
|
|
|
def set_kv_prune(enable: bool, keep: float = 0.85, min_tokens: int = 128): |
|
|
"""Enable lightweight K/V token pruning inside optimized attention. |
|
|
- Applies only to self-attention (len(Q)==len(K)). |
|
|
- Keeps top-`keep` fraction of keys/values by L2 energy of K, averaged over heads. |
|
|
- Skips pruning when an attention mask is provided (shape mismatch risk). |
|
|
""" |
|
|
global _kv_prune_enabled, _kv_prune_keep, _kv_prune_min_tokens |
|
|
_kv_prune_enabled = bool(enable) |
|
|
try: |
|
|
_kv_prune_keep = float(max(0.5, min(1.0, keep))) |
|
|
except Exception: |
|
|
_kv_prune_keep = 0.85 |
|
|
try: |
|
|
_kv_prune_min_tokens = int(max(1, min_tokens)) |
|
|
except Exception: |
|
|
_kv_prune_min_tokens = 128 |
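# Minimal usage sketch: keep ~85% of self-attention K/V tokens for sequences of 128+ tokens.
#
#   set_kv_prune(True, keep=0.85, min_tokens=128)
#   # ... run sampling; pruning happens inside attention_sage for unmasked self-attention only ...
#   set_kv_prune(False)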
|
|
|
|
|
if not _initialized: |
|
|
_original_functions["orig_attention"] = comfy_attention.optimized_attention |
|
|
_original_functions["original_patch_model"] = comfy.model_patcher.ModelPatcher.patch_model |
|
|
_original_functions["original_load_lora_for_models"] = comfy.sd.load_lora_for_models |
|
|
_initialized = True |
|
|
|
|
|
class MGSagpuBaseLoader: |
|
|
original_linear = None |
|
|
cublas_patched = False |
|
|
|
|
|
@torch.compiler.disable() |
|
|
def _patch_modules(self, patch_cublaslinear, sage_attention): |
|
|
from comfy.ops import disable_weight_init, CastWeightBiasOp, cast_bias_weight |
|
|
|
|
|
if sage_attention != "disabled": |
|
|
print("Patching comfy attention to use sageattn") |
|
|
try: |
|
|
from sageattention import sageattn |
|
|
from sageattention import ( |
|
|
sageattn_qk_int8_pv_fp16_cuda, |
|
|
sageattn_qk_int8_pv_fp16_triton, |
|
|
sageattn_qk_int8_pv_fp8_cuda, |
|
|
sageattn_qk_int8_pv_fp8_cuda_sm90, |
|
|
) |
|
|
except ImportError: |
|
|
from SageAttention import sageattn |
|
|
from SageAttention import ( |
|
|
sageattn_qk_int8_pv_fp16_cuda, |
|
|
sageattn_qk_int8_pv_fp16_triton, |
|
|
sageattn_qk_int8_pv_fp8_cuda, |
|
|
sageattn_qk_int8_pv_fp8_cuda_sm90, |
|
|
) |
|
|
def set_sage_func(sage_attention): |
|
|
|
|
|
def select_auto(quality: bool): |
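                    # Map CUDA compute capability to a kernel (assuming standard NVIDIA SM naming):
                    #   SM 12.0 (Blackwell consumer) and SM 8.9 (Ada): fp8 PV CUDA kernel
                    #   SM 9.0 (Hopper): fp8 PV CUDA kernel, SM90 variant
                    #   SM 8.0 / 8.6 (Ampere): int8 QK + fp16 PV CUDA kernel with fp32 accumulation
                    # Anything else, or any failure, falls back to the generic sageattn dispatcher.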
|
|
def _auto(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"): |
|
|
major, minor = torch.cuda.get_device_capability(torch.cuda.current_device()) if torch.cuda.is_available() else (0, 0) |
|
|
try: |
|
|
if major == 12 and minor == 0: |
|
|
|
|
|
pv = "fp32+fp32" if quality else "fp32+fp16" |
|
|
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype=pv, tensor_layout=tensor_layout) |
|
|
elif major == 9: |
|
|
|
|
|
pv = "fp32+fp32" if quality else "fp32+fp32" |
|
|
return sageattn_qk_int8_pv_fp8_cuda_sm90(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype=pv, tensor_layout=tensor_layout) |
|
|
elif major == 8 and minor == 9: |
|
|
pv = "fp32+fp32" if quality else "fp32+fp16" |
|
|
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype=pv, tensor_layout=tensor_layout) |
|
|
elif major == 8 and minor in (0, 6): |
|
|
|
|
|
|
|
|
return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout) |
|
|
return _auto |
|
|
if sage_attention == "auto": |
|
|
return select_auto(quality=False) |
|
|
if sage_attention == "auto_speed": |
|
|
return select_auto(quality=False) |
|
|
if sage_attention == "auto_quality": |
|
|
return select_auto(quality=True) |
|
|
elif sage_attention == "sageattn_qk_int8_pv_fp16_cuda": |
|
|
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"): |
|
|
return sageattn_qk_int8_pv_fp16_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32", tensor_layout=tensor_layout) |
|
|
return func |
|
|
elif sage_attention == "sageattn_qk_int8_pv_fp16_triton": |
|
|
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"): |
|
|
return sageattn_qk_int8_pv_fp16_triton(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout) |
|
|
return func |
|
|
elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda": |
|
|
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"): |
|
|
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp32", tensor_layout=tensor_layout) |
|
|
return func |
|
|
elif sage_attention == "sageattn_qk_int8_pv_fp8_cuda++": |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
if torch.cuda.is_available(): |
|
|
major, minor = torch.cuda.get_device_capability(torch.cuda.current_device()) |
|
|
if not (major == 8 and minor == 9): |
|
|
logging.warning(f"sageattn_qk_int8_pv_fp8_cuda++ requires SM89, but detected SM{major}{minor}. Falling back to auto kernel selection.") |
|
|
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"): |
|
|
return sageattn(q, k, v, is_causal=is_causal, attn_mask=attn_mask, tensor_layout=tensor_layout) |
|
|
return func |
|
|
except Exception: |
|
|
pass |
|
|
def func(q, k, v, is_causal=False, attn_mask=None, tensor_layout="NHD"): |
|
|
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, is_causal=is_causal, attn_mask=attn_mask, pv_accum_dtype="fp32+fp16", tensor_layout=tensor_layout) |
|
|
return func |
|
|
|
|
|
sage_func = set_sage_func(sage_attention) |
|
|
|
|
|
@torch.compiler.disable() |
|
|
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, transformer_options=None, **kwargs): |
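                # Drop-in replacement for comfy_attention.optimized_attention:
                #   1) normalize Q/K/V layout for SageAttention (NHD or HND),
                #   2) optionally prune low-energy K/V tokens (unmasked self-attention only),
                #   3) run the selected sage kernel, falling back to plain sageattn and then
                #      comfy's PyTorch attention,
                #   4) optionally capture a small attention-entropy map for downstream consumers.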
|
|
if skip_reshape: |
|
|
b, _, _, dim_head = q.shape |
|
|
tensor_layout="HND" |
|
|
else: |
|
|
b, _, dim_head = q.shape |
|
|
dim_head //= heads |
|
|
q, k, v = map( |
|
|
lambda t: t.view(b, -1, heads, dim_head), |
|
|
(q, k, v), |
|
|
) |
|
|
tensor_layout="NHD" |
|
|
if mask is not None: |
|
|
|
|
|
if mask.ndim == 2: |
|
|
mask = mask.unsqueeze(0) |
|
|
|
|
|
if mask.ndim == 3: |
|
|
mask = mask.unsqueeze(1) |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
if _kv_prune_enabled and (mask is None): |
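                        # Rank keys by L2 energy averaged over heads and keep only the top
                        # fraction; V is gathered with the same indices so rows stay aligned.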
|
|
import math |
|
|
if tensor_layout == "NHD": |
|
|
|
|
|
Bn, Nq, Hn, Dh = q.shape |
|
|
Nk = k.shape[1] |
|
|
if Nq == Nk and Nk >= _kv_prune_min_tokens: |
|
|
keep = max(1, int(math.ceil(float(_kv_prune_keep) * Nk))) |
|
|
if keep < Nk: |
|
|
|
|
|
imp = (k.pow(2).sum(dim=-1)).mean(dim=2) |
|
|
top = torch.topk(imp, k=keep, dim=1, largest=True, sorted=False).indices |
|
|
idx = top.unsqueeze(-1).unsqueeze(-1).expand(Bn, keep, Hn, Dh) |
|
|
k = torch.gather(k, dim=1, index=idx) |
|
|
v = torch.gather(v, dim=1, index=idx) |
|
|
else: |
|
|
|
|
|
Bb, Hn, Nq, Dh = q.shape |
|
|
Nk = k.shape[2] |
|
|
if Nq == Nk and Nk >= _kv_prune_min_tokens: |
|
|
keep = max(1, int(math.ceil(float(_kv_prune_keep) * Nk))) |
|
|
if keep < Nk: |
|
|
imp = (k.pow(2).sum(dim=-1)).mean(dim=1) |
|
|
top = torch.topk(imp, k=keep, dim=1, largest=True, sorted=False).indices |
|
|
idx = top.unsqueeze(1).unsqueeze(-1).expand(Bb, Hn, keep, Dh) |
|
|
k = torch.gather(k, dim=2, index=idx) |
|
|
v = torch.gather(v, dim=2, index=idx) |
|
|
except Exception: |
|
|
|
|
|
pass |
|
|
|
|
|
try: |
|
|
pv_override = None |
|
|
if transformer_options and isinstance(transformer_options, dict): |
|
|
so = transformer_options.get("sageattn") |
|
|
if isinstance(so, dict): |
|
|
pv_override = so.get("pv_accum_dtype", None) |
|
|
if pv_override is None: |
|
|
pv_override = CURRENT_PV_ACCUM |
|
|
|
|
|
if pv_override is not None: |
|
|
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout, pv_accum_dtype=pv_override) |
|
|
else: |
|
|
out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) |
|
|
except Exception as e: |
|
|
global _sage_generic_warned_once |
|
|
if not _sage_generic_warned_once: |
|
|
logging.warning(f"Error running sage attention: {e}. Falling back.") |
|
|
_sage_generic_warned_once = True |
|
|
try: |
|
|
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) |
|
|
except Exception: |
|
|
|
|
|
if tensor_layout == "NHD": |
|
|
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v)) |
|
|
return comfy_attention.attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape, transformer_options=transformer_options, **kwargs) |
|
|
|
|
|
try: |
|
|
if _attn_entropy_enabled: |
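                        # Cheap probe: sub-sample tokens and heads, recompute a small
                        # softmax(QK^T / sqrt(d)), take per-query entropy, and store it as a
                        # normalized (B, 1, h, w) map for get_attention_entropy_map().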
|
|
import torch |
|
|
with torch.inference_mode(): |
|
|
if tensor_layout == "HND": |
|
|
|
|
|
q_probe = q.transpose(1, 2) |
|
|
k_probe = k.transpose(1, 2) |
|
|
else: |
|
|
q_probe = q |
|
|
k_probe = k |
|
|
B_, N_, H_, Dh = q_probe.shape |
|
|
|
|
|
h_cap = min(H_, _attn_probe_heads_cap) |
|
|
step = max(1, N_ // _attn_probe_tokens_cap) |
|
|
q_s = q_probe[:, ::step, :h_cap, :].transpose(1, 2) |
|
|
k_s = k_probe[:, ::step, :h_cap, :].transpose(1, 2) |
|
|
scale = (float(Dh) ** -0.5) |
|
|
|
|
|
logits = torch.matmul(q_s * scale, k_s.transpose(-1, -2)) |
|
|
p = torch.softmax(logits, dim=-1) |
|
|
|
|
|
eps = 1e-9 |
|
|
Hq = -(p * (p.clamp_min(eps).log())).sum(dim=-1) |
|
|
Hq = Hq.mean(dim=1) |
|
|
|
|
|
import math |
|
|
Q = Hq.shape[-1] |
|
|
w = int(math.sqrt(Q)) |
|
|
w = max(1, w) |
|
|
h = max(1, Q // w) |
|
|
                            if h * w < Q:
                                # truncate so the probe reshapes cleanly into (h, w)
                                Hq = Hq[..., : (h * w)]
                            elif h * w > Q:
                                # defensive: with h = Q // w this branch should not trigger
                                pad = (h * w) - Q
                                Hq = torch.cat([Hq, Hq[..., -1:].expand(B_, pad)], dim=-1)
|
|
Hmap = Hq.reshape(B_, 1, h, w) |
|
|
|
|
|
Hmin = Hmap.amin(dim=(2, 3), keepdim=True) |
|
|
Hmax = Hmap.amax(dim=(2, 3), keepdim=True) |
|
|
Hn = (Hmap - Hmin) / (Hmax - Hmin + 1e-6) |
|
|
global _attn_entropy_last |
|
|
_attn_entropy_last = Hn.detach() |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
if tensor_layout == "HND": |
|
|
if not skip_output_reshape: |
|
|
out = ( |
|
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head) |
|
|
) |
|
|
else: |
|
|
if skip_output_reshape: |
|
|
out = out.transpose(1, 2) |
|
|
else: |
|
|
out = out.reshape(b, -1, heads * dim_head) |
|
|
return out |
|
|
|
|
|
comfy_attention.optimized_attention = attention_sage |
|
|
comfy.ldm.hunyuan_video.model.optimized_attention = attention_sage |
|
|
comfy.ldm.flux.math.optimized_attention = attention_sage |
|
|
comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = attention_sage |
|
|
comfy.ldm.cosmos.blocks.optimized_attention = attention_sage |
|
|
comfy.ldm.wan.model.optimized_attention = attention_sage |
|
|
|
|
|
else: |
|
|
print("Restoring initial comfy attention") |
|
|
comfy_attention.optimized_attention = _original_functions.get("orig_attention") |
|
|
comfy.ldm.hunyuan_video.model.optimized_attention = _original_functions.get("orig_attention") |
|
|
comfy.ldm.flux.math.optimized_attention = _original_functions.get("orig_attention") |
|
|
comfy.ldm.genmo.joint_model.asymm_models_joint.optimized_attention = _original_functions.get("orig_attention") |
|
|
comfy.ldm.cosmos.blocks.optimized_attention = _original_functions.get("orig_attention") |
|
|
comfy.ldm.wan.model.optimized_attention = _original_functions.get("orig_attention") |
|
|
|
|
|
if patch_cublaslinear: |
|
|
if not MGSagpuBaseLoader.cublas_patched: |
|
|
MGSagpuBaseLoader.original_linear = disable_weight_init.Linear |
|
|
try: |
|
|
from cublas_ops import CublasLinear |
|
|
except ImportError: |
|
|
                    raise Exception("Can't import 'torch-cublas-hgemm'; install it from https://github.com/aredden/torch-cublas-hgemm")
|
|
|
|
|
class PatchedLinear(CublasLinear, CastWeightBiasOp): |
|
|
def reset_parameters(self): |
|
|
pass |
|
|
|
|
|
def forward_comfy_cast_weights(self, input): |
|
|
weight, bias = cast_bias_weight(self, input) |
|
|
return torch.nn.functional.linear(input, weight, bias) |
|
|
|
|
|
def forward(self, *args, **kwargs): |
|
|
if self.comfy_cast_weights: |
|
|
return self.forward_comfy_cast_weights(*args, **kwargs) |
|
|
else: |
|
|
return super().forward(*args, **kwargs) |
|
|
|
|
|
disable_weight_init.Linear = PatchedLinear |
|
|
MGSagpuBaseLoader.cublas_patched = True |
|
|
else: |
|
|
if MGSagpuBaseLoader.cublas_patched: |
|
|
disable_weight_init.Linear = MGSagpuBaseLoader.original_linear |
|
|
MGSagpuBaseLoader.cublas_patched = False |
|
|
|
|
|
from comfy.patcher_extension import CallbacksMP |
|
|
class MGSagpuAttention(MGSagpuBaseLoader): |
|
|
@classmethod |
|
|
def INPUT_TYPES(s): |
|
|
return {"required": { |
|
|
"model": ("MODEL",), |
|
|
"sage_attention": (sageattn_modes, {"default": False, "tooltip": "Global patch comfy attention to use sageattn, once patched to revert back to normal you would need to run this node again with disabled option."}), |
|
|
}} |
|
|
|
|
|
RETURN_TYPES = ("MODEL", ) |
|
|
FUNCTION = "patch" |
|
|
    DESCRIPTION = "Node for patching the attention mode. This doesn't use the model patching system and thus can't be reverted without running the node again with the 'disabled' option."
|
|
EXPERIMENTAL = False |
|
|
CATEGORY = "MagicNodes" |
|
|
|
|
|
def patch(self, model, sage_attention): |
|
|
model_clone = model.clone() |
|
|
@torch.compiler.disable() |
|
|
def patch_attention_enable(model): |
|
|
self._patch_modules(False, sage_attention) |
|
|
@torch.compiler.disable() |
|
|
def patch_attention_disable(model): |
|
|
self._patch_modules(False, "disabled") |
|
|
|
|
|
model_clone.add_callback(CallbacksMP.ON_PRE_RUN, patch_attention_enable) |
|
|
model_clone.add_callback(CallbacksMP.ON_CLEANUP, patch_attention_disable) |
|
|
|
|
|
return model_clone, |
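# Minimal programmatic sketch (assumes `model` is a loaded ComfyUI MODEL):
#
#   patched, = MGSagpuAttention().patch(model, "auto")
#   # attention is swapped in via the ON_PRE_RUN callback and restored on ON_CLEANUP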
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import inspect as _inspect |
|
|
try: |
|
|
from comfy.ldm.modules import attention as _cm_attn |
|
|
except Exception as _e: |
|
|
_cm_attn = None |
|
|
|
|
|
_nag_patch_active = False |
|
|
_nag_params = {"scale": 5.0, "tau": 2.5, "alpha": 0.25} |
|
|
_original_functions.setdefault("orig_crossattn_forward", None) |
|
|
_original_functions.setdefault("orig_crossattn_sig", None) |
|
|
|
|
|
def _call_orig_crossattn(self, x, context=None, **kwargs): |
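    # Call the saved original CrossAttention.forward, filtering kwargs to the parameters
    # its signature actually accepts so the shim stays compatible across comfy versions.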
|
|
|
|
|
f = _original_functions.get("orig_crossattn_forward", None) |
|
|
if f is None: |
|
|
|
|
|
return self.__class__.forward(self, x, context=context, **kwargs) |
|
|
sig = _original_functions.get("orig_crossattn_sig", None) |
|
|
if sig is None: |
|
|
try: |
|
|
sig = _inspect.signature(f) |
|
|
_original_functions["orig_crossattn_sig"] = sig |
|
|
except Exception: |
|
|
sig = None |
|
|
if sig is not None: |
|
|
allowed = set(sig.parameters.keys()) |
|
|
fkwargs = {k: v for k, v in kwargs.items() if k in allowed} |
|
|
else: |
|
|
fkwargs = kwargs |
|
|
try: |
|
|
return f(self, x, context=context, **fkwargs) |
|
|
except TypeError: |
|
|
|
|
|
fkwargs.pop("attn_precision", None) |
|
|
fkwargs.pop("transformer_options", None) |
|
|
try: |
|
|
return f(self, x, context=context, **fkwargs) |
|
|
except Exception: |
|
|
|
|
|
return self.__class__.forward(self, x, context=context, **kwargs) |
|
|
|
|
|
def _kj_crossattn_forward_nag(self, x, context=None, value=None, mask=None, **kwargs): |
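    # Guided cross-attention forward: expects uncond/cond stacked along dim 0 of `context`
    # (context[0] = uncond, context[1] = cond); otherwise, or on any failure, it defers to
    # the original CrossAttention.forward.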
|
|
|
|
|
if (not _nag_patch_active) or (_cm_attn is None): |
|
|
return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs) |
|
|
try: |
|
|
if context is None or not torch.is_tensor(context): |
|
|
return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs) |
|
|
|
|
|
|
|
|
if context.shape[0] < 2: |
|
|
return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs) |
|
|
|
|
|
|
|
|
|
|
|
x_has_pair = (torch.is_tensor(x) and x.shape[0] == 2) |
|
|
x_u = x[0:1] if x_has_pair else x |
|
|
x_c = x[1:2] if x_has_pair else x |
|
|
|
|
|
c_u, c_c = context[0:1], context[1:2] |
|
|
|
|
|
|
|
|
v = kwargs.get("value", value) |
|
|
if torch.is_tensor(v) and v.shape[0] == 2: |
|
|
v_u, v_c = v[0:1], v[1:2] |
|
|
else: |
|
|
v_u = v_c = v |
|
|
|
|
|
|
|
|
|
|
|
out_u = _call_orig_crossattn(self, x_u, context=c_u, value=v_u, mask=mask, **kwargs) |
|
|
|
|
|
z_pos = _call_orig_crossattn(self, x_c, context=c_c, value=v_c, mask=mask, **kwargs) |
|
|
|
|
|
z_neg = _call_orig_crossattn(self, x_c, context=c_u, value=v_u, mask=mask, **kwargs) |
|
|
|
|
|
|
|
|
phi = float(_nag_params.get("scale", 5.0)) |
|
|
tau = float(_nag_params.get("tau", 2.5)) |
|
|
alpha = float(_nag_params.get("alpha", 0.25)) |
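        # NAG-style guidance: extrapolate g = phi * z_pos - (phi - 1) * z_neg, clamp its
        # L1 norm to tau * ||z_pos||_1, then blend with z_pos using weight alpha.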
|
|
|
|
|
g = z_pos * phi - z_neg * (phi - 1.0) |
|
|
|
|
|
def _l1_norm(t): |
|
|
return torch.sum(torch.abs(t), dim=-1, keepdim=True).clamp_min(1e-6) |
|
|
s_pos = _l1_norm(z_pos) |
|
|
s_g = _l1_norm(g) |
|
|
scale = (s_pos * tau) / s_g |
|
|
g = torch.where((s_g > s_pos * tau), g * scale, g) |
|
|
|
|
|
z_guided = g * alpha + z_pos * (1.0 - alpha) |
|
|
if x_has_pair: |
|
|
return torch.cat([out_u, z_guided], dim=0) |
|
|
else: |
|
|
return z_guided |
|
|
except Exception as e: |
|
|
|
|
|
return _call_orig_crossattn(self, x, context=context, value=value, mask=mask, **kwargs) |
|
|
|
|
|
def enable_crossattention_nag_patch(enable: bool, nag_scale: float = 5.0, nag_tau: float = 2.5, nag_alpha: float = 0.25): |
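    """Toggle the NAG (normalized attention guidance) patch on comfy's CrossAttention.forward.

    Returns True if the patch (or the restore) was applied, False when comfy attention is
    unavailable or the original forward could not be captured.
    """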
|
|
|
|
|
|
|
|
|
|
|
|
|
|
global _nag_patch_active, _nag_params |
|
|
if _cm_attn is None: |
|
|
return False |
|
|
if enable: |
|
|
_nag_params = {"scale": float(nag_scale), "tau": float(nag_tau), "alpha": float(nag_alpha)} |
|
|
if _original_functions.get("orig_crossattn_forward", None) is None: |
|
|
try: |
|
|
_original_functions["orig_crossattn_forward"] = _cm_attn.CrossAttention.forward |
|
|
try: |
|
|
_original_functions["orig_crossattn_sig"] = _inspect.signature(_cm_attn.CrossAttention.forward) |
|
|
except Exception: |
|
|
_original_functions["orig_crossattn_sig"] = None |
|
|
except Exception: |
|
|
return False |
|
|
|
|
|
try: |
|
|
_cm_attn.CrossAttention.forward = _kj_crossattn_forward_nag |
|
|
_nag_patch_active = True |
|
|
return True |
|
|
except Exception: |
|
|
return False |
|
|
else: |
|
|
|
|
|
if _original_functions.get("orig_crossattn_forward", None) is not None: |
|
|
try: |
|
|
_cm_attn.CrossAttention.forward = _original_functions["orig_crossattn_forward"] |
|
|
except Exception: |
|
|
pass |
|
|
_nag_patch_active = False |
|
|
return True |
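# Minimal usage sketch: enable the cross-attention NAG patch for a run, then restore it.
#
#   enable_crossattention_nag_patch(True, nag_scale=5.0, nag_tau=2.5, nag_alpha=0.25)
#   # ... run sampling; [uncond, cond] cross-attention batches receive guidance ...
#   enable_crossattention_nag_patch(False)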
|
|
|
|
|
|
|
|
PatchSageAttention = MGSagpuAttention |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|