Spaces:
Paused
Paused
| from typing import Callable | |
| import math | |
| import torch | |
| from torch import Tensor | |
| from torch.nn.functional import group_norm | |
| from einops import rearrange | |
| import comfy.ldm.modules.attention as attention | |
| from comfy.ldm.modules.diffusionmodules import openaimodel | |
| import comfy.model_management as model_management | |
| import comfy.samplers | |
| import comfy.sample | |
| import comfy.utils | |
| from comfy.controlnet import ControlBase | |
| import comfy.ops | |
| from .context import ContextFuseMethod, ContextSchedules, get_context_weights, get_context_windows | |
| from .sample_settings import IterationOptions, SampleSettings, SeedNoiseGeneration, prepare_mask_ad | |
| from .utils_model import ModelTypeSD, wrap_function_to_inject_xformers_bug_info | |
| from .model_injection import InjectionParams, ModelPatcherAndInjector, MotionModelGroup, MotionModelPatcher | |
| from .motion_module_ad import AnimateDiffFormat, AnimateDiffInfo, AnimateDiffVersion, VanillaTemporalModule | |
| from .logger import logger | |
| ################################################################################## | |
| ###################################################################### | |
| # Global variable to use to more conveniently hack variable access into samplers | |
| class AnimateDiffHelper_GlobalState: | |
| def __init__(self): | |
| self.motion_models: MotionModelGroup = None | |
| self.params: InjectionParams = None | |
| self.sample_settings: SampleSettings = None | |
| self.reset() | |
| def initialize(self, model): | |
| # this function is to be run in sampling func | |
| if not self.initialized: | |
| self.initialized = True | |
| if self.motion_models is not None: | |
| self.motion_models.initialize_timesteps(model) | |
| if self.params.context_options is not None: | |
| self.params.context_options.initialize_timesteps(model) | |
| if self.sample_settings.custom_cfg is not None: | |
| self.sample_settings.custom_cfg.initialize_timesteps(model) | |
| def reset(self): | |
| self.initialized = False | |
| self.start_step: int = 0 | |
| self.last_step: int = 0 | |
| self.current_step: int = 0 | |
| self.total_steps: int = 0 | |
| if self.motion_models is not None: | |
| del self.motion_models | |
| self.motion_models = None | |
| if self.params is not None: | |
| del self.params | |
| self.params = None | |
| if self.sample_settings is not None: | |
| del self.sample_settings | |
| self.sample_settings = None | |
| def update_with_inject_params(self, params: InjectionParams): | |
| self.params = params | |
| def is_using_sliding_context(self): | |
| return self.params is not None and self.params.is_using_sliding_context() | |
| def create_exposed_params(self): | |
| # This dict will be exposed to be used by other extensions | |
| # DO NOT change any of the key names | |
| # or I will find you π.π | |
| return { | |
| "full_length": self.params.full_length, | |
| "context_length": self.params.context_options.context_length, | |
| "sub_idxs": self.params.sub_idxs, | |
| } | |
| ADGS = AnimateDiffHelper_GlobalState() | |
######################################################################
##################################################################################
##################################################################################
#### Code Injection ##################################################
# refer to forward_timestep_embed in comfy/ldm/modules/diffusionmodules/openaimodel.py
def forward_timestep_embed_factory() -> Callable:
    """Build a replacement for openaimodel.forward_timestep_embed.

    The replacement mirrors ComfyUI's layer dispatch and additionally routes
    VanillaTemporalModule layers, calling them with ``(x, context)``.
    """
    def _bump_transformer_counters(transformer_options: dict):
        # advance the per-transformer counters, only if the caller is tracking them
        if "transformer_index" in transformer_options:
            transformer_options["transformer_index"] += 1
        if "current_index" in transformer_options:  # keep this for backward compat, for now
            transformer_options["current_index"] += 1

    def forward_timestep_embed(ts, x, emb, context=None, transformer_options=None, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
        # fresh dict per call: a shared mutable default would leak state between calls,
        # since transformer_options is passed into the layers below
        if transformer_options is None:
            transformer_options = {}
        for layer in ts:
            if isinstance(layer, openaimodel.VideoResBlock):
                x = layer(x, emb, num_video_frames, image_only_indicator)
            elif isinstance(layer, openaimodel.TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, VanillaTemporalModule):
                x = layer(x, context)
            elif isinstance(layer, attention.SpatialVideoTransformer):
                x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options)
                _bump_transformer_counters(transformer_options)
            elif isinstance(layer, attention.SpatialTransformer):
                x = layer(x, context, transformer_options)
                _bump_transformer_counters(transformer_options)
            elif isinstance(layer, openaimodel.Upsample):
                x = layer(x, output_shape=output_shape)
            else:
                x = layer(x)
        return x
    return forward_timestep_embed
def unlimited_memory_required(*args, **kwargs):
    """Stand-in for ``model.memory_required`` that always reports 0, whatever it is passed."""
    del args, kwargs  # accepted only for signature compatibility
    return 0
def groupnorm_mm_factory(params: InjectionParams, manual_cast=False):
    """Create a GroupNorm.forward replacement that normalizes across frames too.

    The returned function regroups the flat ``(b f) c h w`` batch into
    ``b c f h w`` so group_norm statistics span all frames of a video, then flattens
    it back. When ``manual_cast`` is True, weight/bias come from
    comfy.ops.cast_bias_weight instead of the module's own parameters.
    """
    def groupnorm_mm_forward(self, input: Tensor) -> Tensor:
        # axes_factor normalizes batch based on total conds and unconds passed in batch;
        # the conds and unconds per batch can change based on VRAM optimizations that may kick in
        frames_per_video = (params.context_options.context_length
                            if params.is_using_sliding_context()
                            else params.full_length)
        batched_conds = input.size(0) // frames_per_video
        stacked = rearrange(input, "(b f) c h w -> b c f h w", b=batched_conds)
        if manual_cast:
            weight, bias = comfy.ops.cast_bias_weight(self, stacked)
        else:
            weight, bias = self.weight, self.bias
        normalized = group_norm(stacked, self.num_groups, weight, bias, self.eps)
        return rearrange(normalized, "b c f h w -> (b f) c h w", b=batched_conds)
    return groupnorm_mm_forward
def get_additional_models_factory(orig_get_additional_models: Callable, motion_models: MotionModelGroup):
    """Wrap comfy.sample.get_additional_models so motion models are reported too.

    The wrapper calls the original, appends every motion model (if any) to the
    returned ``models`` list in place, and passes ``inference_memory`` through.
    """
    def get_additional_models_with_motion(*args, **kwargs):
        models, inference_memory = orig_get_additional_models(*args, **kwargs)
        if motion_models is None:
            return models, inference_memory
        for extra_model in motion_models.models:
            models.append(extra_model)
        # TODO: account for inference memory as well?
        return models, inference_memory
    return get_additional_models_with_motion
######################################################################
##################################################################################
def apply_params_to_motion_models(motion_models: MotionModelGroup, params: InjectionParams):
    """Resolve context settings for this run and configure the motion models.

    Works on a clone of ``params``; the caller's object is not modified.
    VIEW_AS_CONTEXT schedules get a context length equal to the full latent count.
    Sliding context is activated only when there are enough latents for the
    configured context length; otherwise the context is reset.

    Returns:
        The cloned (and possibly adjusted) InjectionParams.
    Raises:
        ValueError: if a motion model cannot encode the required frame window.
    """
    params = params.clone()
    for context in params.context_options.contexts:
        if context.context_schedule == ContextSchedules.VIEW_AS_CONTEXT:
            context.context_length = params.full_length

    full_length = params.full_length
    context_length = params.context_options.context_length
    if context_length:
        if params.context_options.use_on_equal_length:
            enough_latents = full_length >= context_length
        else:
            enough_latents = full_length > context_length
    else:
        enough_latents = False

    if enough_latents:
        relation = "greater than" if full_length > context_length else "equal to"
        logger.info(f"Sliding context window activated - latents passed in ({full_length}) {relation} context_length {context_length}.")
    else:
        if not context_length:
            detail = "with no context_length set"
        elif full_length < context_length:
            detail = f"less than context_length {context_length}"
        else:
            detail = f"equal to context_length {context_length}"
        logger.info(f"Regular AnimateDiff activated - latents passed in ({full_length}) {detail}.")
        params.reset_context()

    if motion_models is not None:
        # re-read: reset_context() above may have changed the context length
        if not params.context_options.context_length:
            # if no context_length, treat video length as intended AD frame window
            for motion_model in motion_models.models:
                if not motion_model.model.is_length_valid_for_encoding_max_len(params.full_length):
                    raise ValueError(f"Without a context window, AnimateDiff model {motion_model.model.mm_info.mm_name} has upper limit of {motion_model.model.encoding_max_len} frames, but received {params.full_length} latents.")
            motion_models.set_video_length(params.full_length, params.full_length)
        else:
            # otherwise, treat context_length as intended AD frame window
            view_options = params.context_options.view_options
            window_length = view_options.context_length if view_options else params.context_options.context_length
            for motion_model in motion_models.models:
                if not motion_model.model.is_length_valid_for_encoding_max_len(window_length):
                    # report the length that was actually validated
                    raise ValueError(f"AnimateDiff model {motion_model.model.mm_info.mm_name} has upper limit of {motion_model.model.encoding_max_len} frames for a context window, but received context length of {window_length}.")
            motion_models.set_video_length(params.context_options.context_length, params.full_length)
        # inject model
        module_str = "modules" if len(motion_models.models) > 1 else "module"
        logger.info(f"Using motion {module_str} {motion_models.get_name_string(show_version=True)}.")
    return params
class FunctionInjectionHolder:
    """Saves, replaces, and later restores the ComfyUI/torch functions that
    AnimateDiff overrides for the duration of one sampling call."""
    def __init__(self):
        pass
    def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionParams):
        """Save the current implementations, then install the AnimateDiff replacements.

        Every original is saved before anything is replaced, so if saving fails
        (e.g. an attribute missing in this ComfyUI version) nothing has been patched yet.
        """
        # Save Original Functions
        self.orig_forward_timestep_embed = openaimodel.forward_timestep_embed # needed to account for VanillaTemporalModule
        self.orig_memory_required = model.model.memory_required # allows for "unlimited area hack" to prevent halving of conds/unconds
        self.orig_groupnorm_forward = torch.nn.GroupNorm.forward # used to normalize latents to remove "flickering" of colors/brightness between frames
        self.orig_groupnorm_manual_cast_forward = comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights
        self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers
        self.orig_prepare_mask = comfy.sample.prepare_mask
        self.orig_get_additional_models = comfy.sample.get_additional_models
        # Inject Functions
        openaimodel.forward_timestep_embed = forward_timestep_embed_factory()
        if params.unlimited_area_hack:
            model.model.memory_required = unlimited_memory_required
        if model.motion_models is not None:
            # only apply groupnorm hack if not [v3 or ([not Hotshot] and SD1.5 and v2 and apply_v2_properly)]
            # (the decision is based on the first motion model's info only)
            info: AnimateDiffInfo = model.motion_models[0].model.mm_info
            if not (info.mm_version == AnimateDiffVersion.V3 or
                    (info.mm_format not in [AnimateDiffFormat.HOTSHOTXL] and info.sd_type == ModelTypeSD.SD1_5 and info.mm_version == AnimateDiffVersion.V2 and params.apply_v2_properly)):
                torch.nn.GroupNorm.forward = groupnorm_mm_factory(params)
                comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True)
                # if mps device (Apple Silicon), disable batched conds to avoid black images with groupnorm hack
                try:
                    if model.load_device.type == "mps":
                        model.model.memory_required = unlimited_memory_required
                except Exception:
                    # deliberate best-effort: any failure reading the device just skips this tweak
                    pass
            del info
        comfy.samplers.sampling_function = evolved_sampling_function
        comfy.sample.prepare_mask = prepare_mask_ad
        comfy.sample.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models)
    def restore_functions(self, model: ModelPatcherAndInjector):
        """Put back every function saved by inject_functions.

        Restoration is sequential. A missing saved original raises AttributeError,
        which stops restoration at that point and is logged rather than raised. This
        happens, for example, if inject_functions failed while saving, or was never called.
        """
        # Restoration
        try:
            model.model.memory_required = self.orig_memory_required
            openaimodel.forward_timestep_embed = self.orig_forward_timestep_embed
            torch.nn.GroupNorm.forward = self.orig_groupnorm_forward
            comfy.ops.manual_cast.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_manual_cast_forward
            comfy.samplers.sampling_function = self.orig_sampling_function
            comfy.sample.prepare_mask = self.orig_prepare_mask
            comfy.sample.get_additional_models = self.orig_get_additional_models
        except AttributeError:
            logger.error("Encountered AttributeError while attempting to restore functions - likely, an error occured while trying " + \
                         "to save original functions before injection, and a more specific error was thrown by ComfyUI.")
def motion_sample_factory(orig_comfy_sample: Callable, is_custom: bool=False) -> Callable:
    """Wrap a comfy sample function so AnimateDiff models are injected around it.

    Args:
        orig_comfy_sample: the sample function being wrapped.
        is_custom: True when wrapping the custom-sampler variant. Then
            adapt_denoise_steps is skipped and no KSampler is built for iteration options.
    Returns:
        ``motion_sample``, with the same call convention as ``orig_comfy_sample``.
    """
    def motion_sample(model: ModelPatcherAndInjector, noise: Tensor, *args, **kwargs):
        """Run ``orig_comfy_sample`` with AnimateDiff functions injected.

        ``args[-1]`` is read as the latent tensor, and ``kwargs["seed"]`` is required.
        Global state is reset and all injected functions are restored in ``finally``,
        even if sampling raises.
        """
        # check if model is intended for injecting
        # (exact type check: anything other than ModelPatcherAndInjector goes straight through)
        if type(model) != ModelPatcherAndInjector:
            return orig_comfy_sample(model, noise, *args, **kwargs)
        # otherwise, injection time
        latents = None
        cached_latents = None
        cached_noise = None
        function_injections = FunctionInjectionHolder()
        try:
            if model.sample_settings.custom_cfg is not None:
                model = model.sample_settings.custom_cfg.patch_model(model)
            # clone params from model
            params = model.motion_injection_params.clone()
            # get amount of latents passed in, and store in params
            latents: Tensor = args[-1]
            params.full_length = latents.size(0)
            # reset global state
            ADGS.reset()

            # apply custom noise, if needed
            disable_noise = kwargs.get("disable_noise") or False
            # NOTE(review): a missing "seed" kwarg raises here, before inject_functions;
            # the finally block's restore_functions would then log its AttributeError message
            seed = kwargs["seed"]

            # apply params to motion model
            params = apply_params_to_motion_models(model.motion_models, params)

            # store and inject functions
            function_injections.inject_functions(model, params)

            # prepare noise_extra_args for noise generation purposes
            noise_extra_args = {"disable_noise": disable_noise}
            params.set_noise_extra_args(noise_extra_args)
            # if noise is not disabled, do noise stuff
            if not disable_noise:
                noise = model.sample_settings.prepare_noise(seed, latents, noise, extra_args=noise_extra_args, force_create_noise=False)

            # callback setup
            original_callback = kwargs.get("callback", None)
            def ad_callback(step, x0, x, total_steps):
                # forward to the user's callback first, then advance the global step counter
                if original_callback is not None:
                    original_callback(step, x0, x, total_steps)
                # update GLOBALSTATE for next iteration
                ADGS.current_step = ADGS.start_step + step + 1
            kwargs["callback"] = ad_callback
            ADGS.motion_models = model.motion_models
            ADGS.sample_settings = model.sample_settings

            # apply adapt_denoise_steps
            args = list(args)
            if model.sample_settings.adapt_denoise_steps and not is_custom:
                # only applicable when denoise and steps are provided (from simple KSampler nodes)
                denoise = kwargs.get("denoise", None)
                steps = args[0]
                if denoise is not None and type(steps) == int:
                    args[0] = max(int(denoise * steps), 1)

            iter_opts = IterationOptions()
            if model.sample_settings is not None:
                iter_opts = model.sample_settings.iteration_opts
            iter_opts.initialize(latents)
            # cache initial noise and latents, if needed
            if iter_opts.cache_init_latents:
                cached_latents = latents.clone()
            if iter_opts.cache_init_noise:
                cached_noise = noise.clone()
            # prepare iter opts preprocess kwargs, if needed
            iter_kwargs = {}
            if iter_opts.need_sampler:
                # -5 for sampler_name (not custom) and sampler (custom)
                model_management.load_model_gpu(model)
                if is_custom:
                    iter_kwargs[IterationOptions.SAMPLER] = None #args[-5]
                else:
                    iter_kwargs[IterationOptions.SAMPLER] = comfy.samplers.KSampler(
                        model.model, steps=999, #steps=args[-7],
                        device=model.current_device, sampler=args[-5],
                        scheduler=args[-4], denoise=kwargs.get("denoise", None),
                        model_options=model.model_options)

            for curr_i in range(iter_opts.iterations):
                # handle GLOBALSTATE vars and step tally
                ADGS.update_with_inject_params(params)
                ADGS.start_step = kwargs.get("start_step") or 0
                ADGS.current_step = ADGS.start_step
                ADGS.last_step = kwargs.get("last_step") or 0
                if iter_opts.iterations > 1:
                    logger.info(f"Iteration {curr_i+1}/{iter_opts.iterations}")
                # perform any iter_opts preprocessing on latents
                latents, noise = iter_opts.preprocess_latents(curr_i=curr_i, model=model, latents=latents, noise=noise,
                                                              cached_latents=cached_latents, cached_noise=cached_noise,
                                                              seed=seed,
                                                              sample_settings=model.sample_settings, noise_extra_args=noise_extra_args,
                                                              **iter_kwargs)
                # latents are passed to the sampler as the last positional argument
                args[-1] = latents

                if model.motion_models is not None:
                    model.motion_models.pre_run(model)
                if model.sample_settings is not None:
                    model.sample_settings.pre_run(model)
                # each iteration's output becomes the input latents of the next one
                latents = wrap_function_to_inject_xformers_bug_info(orig_comfy_sample)(model, noise, *args, **kwargs)
            return latents
        finally:
            del latents
            del noise
            del cached_latents
            del cached_noise
            # reset global state
            ADGS.reset()
            # restore injected functions
            function_injections.restore_functions(model)
            del function_injections
    return motion_sample
def evolved_sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options: dict=None, seed=None):
    """Replacement for comfy.samplers.sampling_function while AnimateDiff is injected.

    It does the following:
    - advances keyframes and contexts for ``timestep``;
    - exposes AD params to other extensions via
      ``model_options["transformer_options"]["ad_params"]``;
    - computes cond/uncond predictions, window by window when sliding context is active;
    - applies CFG plus any post-CFG functions from ``model_options``.

    The caller's ``model_options`` (and its transformer_options) are not mutated.
    ``seed`` is accepted for signature compatibility and unused here.
    """
    if model_options is None:
        model_options = {}
    ADGS.initialize(model)
    if ADGS.motion_models is not None:
        ADGS.motion_models.prepare_current_keyframe(t=timestep)
    if ADGS.params.context_options is not None:
        ADGS.params.context_options.prepare_current_context(t=timestep)
    if ADGS.sample_settings.custom_cfg is not None:
        ADGS.sample_settings.custom_cfg.prepare_current_keyframe(t=timestep)

    # never use cfg1 optimization if using custom_cfg (since can have timesteps and such)
    if ADGS.sample_settings.custom_cfg is None and math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
        uncond_ = None
    else:
        uncond_ = uncond

    # add AD/evolved-sampling params to model_options (transformer_options).
    # Copy both levels: previously a misspelled key check raised KeyError when
    # transformer_options was absent, and ad_params leaked into the caller's dict.
    model_options = model_options.copy()
    transformer_options = model_options.get("transformer_options", {}).copy()
    transformer_options["ad_params"] = ADGS.create_exposed_params()
    model_options["transformer_options"] = transformer_options

    if not ADGS.is_using_sliding_context():
        cond_pred, uncond_pred = comfy.samplers.calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)
    else:
        cond_pred, uncond_pred = sliding_calc_cond_uncond_batch(model, cond, uncond_, x, timestep, model_options)

    if "sampler_cfg_function" in model_options:
        args = {"cond": x - cond_pred, "uncond": x - uncond_pred, "cond_scale": cond_scale, "timestep": timestep, "input": x, "sigma": timestep,
                "cond_denoised": cond_pred, "uncond_denoised": uncond_pred, "model": model, "model_options": model_options}
        cfg_result = x - model_options["sampler_cfg_function"](args)
    else:
        cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale

    for fn in model_options.get("sampler_post_cfg_function", []):
        args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
                "sigma": timestep, "model_options": model_options, "input": x}
        cfg_result = fn(args)

    return cfg_result
# sliding_calc_cond_uncond_batch inspired by ashen's initial hack for 16-frame sliding context:
# https://github.com/comfyanonymous/ComfyUI/compare/master...ashen-sensored:ComfyUI:master
def sliding_calc_cond_uncond_batch(model, cond, uncond, x_in: Tensor, timestep, model_options):
    """Compute cond/uncond predictions per context window and fuse them.

    For each window, x_in, timestep, and the conds are subset to the window's
    frames for every batched cond. The windowed results are then merged back:
    RELATIVE uses a running bias-weighted average; other fuse methods use a
    weighted sum followed by division by the accumulated weights.

    Returns:
        (cond_final, uncond_final), each created with torch.zeros_like(x_in).
    Raises:
        ValueError: if a cond's "control" object has no ``sub_idxs`` support.
    """
    def prepare_control_objects(control: ControlBase, full_idxs: list[int]):
        # walk the whole chain so every linked controlnet receives the window indices
        if control.previous_controlnet is not None:
            prepare_control_objects(control.previous_controlnet, full_idxs)
        control.sub_idxs = full_idxs
        control.full_latent_length = ADGS.params.full_length
        control.context_length = ADGS.params.context_options.context_length

    def get_resized_cond(cond_in, full_idxs) -> list:
        # reuse or resize cond items to match context requirements
        resized_cond = []
        # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
        for actual_cond in cond_in:
            resized_actual_cond = actual_cond.copy()
            # now we are in the inner dict - "pooled_output" is a tensor, "control" is a ControlBase object, "model_conds" is dictionary
            for key in actual_cond:
                try:
                    cond_item = actual_cond[key]
                    if isinstance(cond_item, Tensor):
                        # only tensors whose batch dim matches x_in are per-frame and need subsetting
                        if cond_item.size(0) == x_in.size(0):
                            resized_actual_cond[key] = cond_item[full_idxs]
                        else:
                            resized_actual_cond[key] = cond_item
                    # look for control
                    elif key == "control":
                        # tell controls the expected indices so they can handle them
                        if not hasattr(cond_item, "sub_idxs"):
                            raise ValueError(f"Control type {type(cond_item).__name__} may not support required features for sliding context window; "
                                             "use Control objects from Kosinkadink/ComfyUI-Advanced-ControlNet nodes, or make sure Advanced-ControlNet is updated.")
                        prepare_control_objects(cond_item, full_idxs)
                        resized_actual_cond[key] = cond_item
                    elif isinstance(cond_item, dict):
                        new_cond_item = cond_item.copy()
                        # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
                        for cond_key, cond_value in new_cond_item.items():
                            if isinstance(cond_value, Tensor):
                                if cond_value.size(0) == x_in.size(0):
                                    new_cond_item[cond_key] = cond_value[full_idxs]
                            # if has cond that is a Tensor, check if needs to be subset
                            elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, Tensor):
                                if cond_value.cond.size(0) == x_in.size(0):
                                    new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond[full_idxs])
                        resized_actual_cond[key] = new_cond_item
                    else:
                        resized_actual_cond[key] = cond_item
                finally:
                    del cond_item  # just in case to prevent VRAM issues
            resized_cond.append(resized_actual_cond)
        return resized_cond

    context_options = ADGS.params.context_options
    full_length = ADGS.params.full_length
    fuse_relative = context_options.fuse_method == ContextFuseMethod.RELATIVE
    # get context windows
    context_options.step = ADGS.current_step
    context_windows = get_context_windows(full_length, context_options)
    # figure out how input is split
    batched_conds = x_in.size(0) // full_length

    if ADGS.motion_models is not None:
        ADGS.motion_models.set_view_options(context_options.view_options)

    # prepare final cond, uncond, and out_count
    cond_final = torch.zeros_like(x_in)
    uncond_final = torch.zeros_like(x_in)
    out_count_final = torch.zeros((x_in.shape[0], 1, 1, 1), device=x_in.device)
    bias_final = [0.0] * x_in.shape[0]

    # perform calc_cond_uncond_batch per context window
    for ctx_idxs in context_windows:
        window_length = len(ctx_idxs)
        ADGS.params.sub_idxs = ctx_idxs
        if ADGS.motion_models is not None:
            ADGS.motion_models.set_sub_idxs(ctx_idxs)
            ADGS.motion_models.set_video_length(window_length, full_length)
        # update exposed params
        model_options["transformer_options"]["ad_params"]["sub_idxs"] = ctx_idxs
        model_options["transformer_options"]["ad_params"]["context_length"] = window_length
        # account for all portions of input frames: row n*window_length + pos of the
        # sub-batch corresponds to row n*full_length + ctx_idxs[pos] of x_in
        full_idxs = [(full_length * n) + ind for n in range(batched_conds) for ind in ctx_idxs]
        # get subsections of x, timestep, cond, uncond, cond_concat
        sub_x = x_in[full_idxs]
        sub_timestep = timestep[full_idxs]
        sub_cond = get_resized_cond(cond, full_idxs) if cond is not None else None
        sub_uncond = get_resized_cond(uncond, full_idxs) if uncond is not None else None

        sub_cond_out, sub_uncond_out = comfy.samplers.calc_cond_uncond_batch(model, sub_cond, sub_uncond, sub_x, sub_timestep, model_options)

        if fuse_relative:
            for pos, idx in enumerate(ctx_idxs):
                # bias is the influence of a specific index in relation to the whole context window
                bias = 1 - abs(idx - (ctx_idxs[0] + ctx_idxs[-1]) / 2) / ((ctx_idxs[-1] - ctx_idxs[0] + 1e-2) / 2)
                bias = max(1e-2, bias)
                # take weighted average relative to total bias of current idx
                # and account for batched_conds
                for n in range(batched_conds):
                    full_pos = (full_length * n) + idx
                    # the windowed output is laid out window_length rows per batched cond
                    sub_pos = (window_length * n) + pos
                    bias_total = bias_final[full_pos]
                    prev_weight = bias_total / (bias_total + bias)
                    new_weight = bias / (bias_total + bias)
                    cond_final[full_pos] = cond_final[full_pos] * prev_weight + sub_cond_out[sub_pos] * new_weight
                    uncond_final[full_pos] = uncond_final[full_pos] * prev_weight + sub_uncond_out[sub_pos] * new_weight
                    bias_final[full_pos] = bias_total + bias
        else:
            # add conds and counts based on weights of fuse method
            weights = get_context_weights(window_length, context_options.fuse_method) * batched_conds
            weights_tensor = torch.Tensor(weights).to(device=x_in.device).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            cond_final[full_idxs] += sub_cond_out * weights_tensor
            uncond_final[full_idxs] += sub_uncond_out * weights_tensor
            out_count_final[full_idxs] += weights_tensor

    if fuse_relative:
        # already normalized by the running weighted average, so return as is
        return cond_final, uncond_final
    # normalize cond and uncond via division by context usage counts
    cond_final /= out_count_final
    uncond_final /= out_count_final
    return cond_final, uncond_final