Upload folder using huggingface_hub
- modules/Model/ModelBase.py +23 -14
- modules/NeuralNetwork/unet.py +48 -34
- modules/Utilities/util.py +24 -13
- modules/cond/cond.py +87 -55
- modules/sample/CFG.py +56 -20
- modules/sample/ksampler_util.py +73 -23
- modules/sample/samplers.py +44 -221
- modules/sample/sampling.py +313 -370
- modules/user/GUI.py +8 -4
- modules/user/pipeline.py +6 -6
modules/Model/ModelBase.py
CHANGED

@@ -56,7 +56,9 @@ class BaseModel(torch.nn.Module):
             **unet_config, device=device, operations=operations
         )
         self.model_type = model_type
-        self.model_sampling = sampling.model_sampling( …
+        self.model_sampling = sampling.model_sampling(
+            model_config, model_type, flux=flux
+        )

         self.adm_channels = unet_config.get("adm_in_channels", None)
         if self.adm_channels is None:
@@ -93,26 +95,32 @@ class BaseModel(torch.nn.Module):
         """
         sigma = t
         xc = self.model_sampling.calculate_input(sigma, x)
-        if c_concat is not None:
-            xc = torch.cat([xc] + [c_concat], dim=1)
- …
+        # Optimize concatenation operation by avoiding unnecessary list creation
+        if c_concat is not None:
+            xc = torch.cat((xc, c_concat), dim=1)
+
+        # Determine dtype once to avoid repeated calls to get_dtype()
+        dtype = (
+            self.manual_cast_dtype
+            if self.manual_cast_dtype is not None
+            else self.get_dtype()
+        )

+        # Batch operations to reduce overhead
         xc = xc.to(dtype)
         t = self.model_sampling.timestep(t).float()
-        context = …
+        context = c_crossattn.to(dtype) if c_crossattn is not None else None
+
+        # Process extra conditions more efficiently
         extra_conds = {}
-        for …
- …
-            extra_conds[o] = extra
+        for name, value in kwargs.items():
+            if hasattr(value, "dtype") and value.dtype not in (torch.int, torch.long):
+                extra_conds[name] = value.to(dtype)
+            else:
+                extra_conds[name] = value

+        # Run diffusion model and calculate denoised output
         model_output = self.diffusion_model(
             xc,
             t,
@@ -121,6 +129,7 @@ class BaseModel(torch.nn.Module):
             transformer_options=transformer_options,
             **extra_conds,
         ).float()
+
         return self.model_sampling.calculate_denoised(sigma, model_output, x)

     def get_dtype(self) -> torch.dtype:
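A note on the extra_conds loop above: the dtype rule casts anything float-like to the UNet dtype while leaving torch.int and torch.long tensors (e.g. index ids) untouched. A minimal standalone sketch of that rule, with made-up tensor names:

import torch

def cast_extra_conds(kwargs, dtype):
    # Float-like values are cast to the UNet dtype; int/long tensors pass through.
    out = {}
    for name, value in kwargs.items():
        if hasattr(value, "dtype") and value.dtype not in (torch.int, torch.long):
            out[name] = value.to(dtype)
        else:
            out[name] = value
    return out

conds = cast_extra_conds(
    {"pooled": torch.randn(1, 1280), "ids": torch.zeros(1, dtype=torch.long)},
    torch.float16,
)
assert conds["pooled"].dtype == torch.float16
assert conds["ids"].dtype == torch.long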
modules/NeuralNetwork/unet.py
CHANGED

@@ -304,7 +304,9 @@ class UNetModel1(nn.Module):
         if num_heads_upsample == -1:
             num_heads_upsample = num_heads
         if num_head_channels == -1:
-            assert num_heads != -1, …
+            assert num_heads != -1, (
+                "Either num_heads or num_head_channels has to be set"
+            )

         self.in_channels = in_channels
         self.model_channels = model_channels
@@ -684,36 +686,29 @@ class UNetModel1(nn.Module):
         transformer_options: Dict[str, Any] = {},
         **kwargs: Any,
     ) -> torch.Tensor:
-        """#### Forward pass of the UNet model.
-
-        #### Args:
-            - `x` (torch.Tensor): The input tensor.
-            - `timesteps` (Optional[torch.Tensor], optional): The timesteps tensor. Defaults to None.
-            - `context` (Optional[torch.Tensor], optional): The context tensor. Defaults to None.
-            - `y` (Optional[torch.Tensor], optional): The class labels tensor. Defaults to None.
-            - `control` (Optional[torch.Tensor], optional): The control tensor. Defaults to None.
-            - `transformer_options` (Dict[str, Any], optional): Options for the transformer. Defaults to {}.
-            - `**kwargs` (Any): Additional keyword arguments.
-
-        #### Returns:
-            - `torch.Tensor`: The output tensor.
-        """
+        """#### Forward pass of the UNet model with optimized calculations."""
+        # Setup transformer options (avoid unused variable)
         transformer_options["original_shape"] = list(x.shape)
         transformer_options["transformer_index"] = 0
-        transformer_patches = transformer_options.get("patches", {})

+        # Extract kwargs efficiently
         num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
         image_only_indicator = kwargs.get("image_only_indicator", None)
         time_context = kwargs.get("time_context", None)

- …
-        ).to( …
+        # Validation
+        assert (y is not None) == (self.num_classes is not None), (
+            "must specify y if and only if the model is class-conditional"
+        )
+
+        # Time embedding - optimize by computing with target dtype directly
+        t_emb = sampling_util.timestep_embedding(timesteps, self.model_channels).to(
+            x.dtype
+        )
         emb = self.time_embed(t_emb)
+
+        # Input blocks processing
+        hs = []
         h = x
         for id, module in enumerate(self.input_blocks):
             transformer_options["block"] = ("input", id)
@@ -730,6 +725,7 @@ class UNetModel1(nn.Module):
             h = apply_control1(h, control, "input")
             hs.append(h)

+        # Middle block processing
         transformer_options["block"] = ("middle", 0)
         if self.middle_block is not None:
             h = ResBlock.forward_timestep_embed1(
@@ -744,17 +740,19 @@ class UNetModel1(nn.Module):
         )
         h = apply_control1(h, control, "middle")

+        # Output blocks processing - optimize memory usage
         for id, module in enumerate(self.output_blocks):
             transformer_options["block"] = ("output", id)
             hsp = hs.pop()
             hsp = apply_control1(hsp, control, "output")

+            # Concatenate tensors
             h = torch.cat([h, hsp], dim=1)
-            del hsp
- …
-            else …
- …
+            del hsp  # Free memory immediately
+
+            # Only calculate output shape when needed
+            output_shape = hs[-1].shape if hs else None
+
             h = ResBlock.forward_timestep_embed1(
                 module,
                 h,
@@ -766,11 +764,15 @@ class UNetModel1(nn.Module):
                 num_video_frames=num_video_frames,
                 image_only_indicator=image_only_indicator,
             )
+
+        # Ensure output has correct dtype
         h = h.type(x.dtype)
         return self.out(h)


-def detect_unet_config( …
+def detect_unet_config(
+    state_dict: Dict[str, torch.Tensor], key_prefix: str
+) -> Dict[str, Any]:
     """#### Detect the UNet configuration from a state dictionary.

     #### Args:
@@ -1017,7 +1019,9 @@ def detect_unet_config(state_dict: Dict[str, torch.Tensor], key_prefix: str) ->
                 // model_channels
             )

-            out = transformer.calculate_transformer_depth( …
+            out = transformer.calculate_transformer_depth(
+                prefix, state_dict_keys, state_dict
+            )
             if out is not None:
                 transformer_depth.append(out[0])
                 if context_dim is None:
@@ -1076,7 +1080,9 @@ def detect_unet_config(state_dict: Dict[str, torch.Tensor], key_prefix: str) ->
     return unet_config


-def model_config_from_unet_config( …
+def model_config_from_unet_config(
+    unet_config: Dict[str, Any], state_dict: Optional[Dict[str, torch.Tensor]] = None
+) -> Any:
     """#### Get the model configuration from a UNet configuration.

     #### Args:
@@ -1096,7 +1102,11 @@ def model_config_from_unet_config(unet_config: Dict[str, Any], state_dict: Optio
         return None


-def model_config_from_unet( …
+def model_config_from_unet(
+    state_dict: Dict[str, torch.Tensor],
+    unet_key_prefix: str,
+    use_base_if_no_match: bool = False,
+) -> Any:
     """#### Get the model configuration from a UNet state dictionary.

     #### Args:
@@ -1117,7 +1127,11 @@ def model_config_from_unet(state_dict: Dict[str, torch.Tensor], unet_key_prefix:
 def unet_dtype1(
     device: Optional[torch.device] = None,
     model_params: int = 0,
-    supported_dtypes: List[torch.dtype] = [ …
+    supported_dtypes: List[torch.dtype] = [
+        torch.float16,
+        torch.bfloat16,
+        torch.float32,
+    ],
 ) -> torch.dtype:
     """#### Get the dtype for the UNet model.

@@ -1129,4 +1143,4 @@ def unet_dtype1(
     #### Returns:
         - `torch.dtype`: The dtype for the UNet model.
     """
-    return torch.float16
+    return torch.float16
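The repo's sampling_util.timestep_embedding is not shown in this diff; the sketch below is a generic sinusoidal embedding of the kind such helpers usually compute (an assumed implementation, not the repo's own), combined with the .to(x.dtype) cast from the forward-pass hunk above:

import math
import torch

def timestep_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    # Generic sinusoidal embedding (assumption; the repo's helper may differ).
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    )
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

x_dtype = torch.float16  # stands in for x.dtype in the hunk above
t_emb = timestep_embedding(torch.tensor([0, 500, 999]), 320).to(x_dtype)
assert t_emb.shape == (3, 320) and t_emb.dtype == torch.float16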
modules/Utilities/util.py
CHANGED

@@ -4,7 +4,6 @@ import itertools
 import logging
 import math
 import os
-import pickle
 import safetensors.torch
 import torch

@@ -120,6 +119,18 @@ def state_dict_prefix_replace(
     return out


+def lcm_of_list(numbers):
+    """Calculate LCM of a list of numbers more efficiently."""
+    if not numbers:
+        return 1
+
+    result = numbers[0]
+    for num in numbers[1:]:
+        result = torch.lcm(torch.tensor(result), torch.tensor(num)).item()
+
+    return result
+
+
 def repeat_to_batch_size(
     tensor: torch.Tensor, batch_size: int, dim: int = 0
 ) -> torch.Tensor:
@@ -437,11 +448,11 @@ def tiled_scale_multidim(

     def get_upscale(dim: int, val: int) -> int:
         """#### Get the upscale value.
-
+
         #### Args:
             - `dim` (int): The dimension.
            - `val` (int): The value.
-
+
         #### Returns:
             - `int`: The upscaled value.
         """
@@ -453,11 +464,11 @@ def tiled_scale_multidim(

     def get_downscale(dim: int, val: int) -> int:
         """#### Get the downscale value.
-
+
         #### Args:
             - `dim` (int): The dimension.
             - `val` (int): The value.
-
+
         #### Returns:
             - `int`: The downscaled value.
         """
@@ -469,11 +480,11 @@ def tiled_scale_multidim(

     def get_upscale_pos(dim: int, val: int) -> int:
         """#### Get the upscaled position.
-
+
         #### Args:
             - `dim` (int): The dimension.
             - `val` (int): The value.
-
+
         #### Returns:
             - `int`: The upscaled position.
         """
@@ -485,11 +496,11 @@ def tiled_scale_multidim(

     def get_downscale_pos(dim: int, val: int) -> int:
         """#### Get the downscaled position.
-
+
         #### Args:
             - `dim` (int): The dimension.
             - `val` (int): The value.
-
+
         #### Returns:
             - `int`: The downscaled position.
         """
@@ -508,10 +519,10 @@ def tiled_scale_multidim(

     def mult_list_upscale(a: list) -> list:
         """#### Multiply a list by the upscale amount.
-
+
         #### Args:
             - `a` (list): The list.
-
+
         #### Returns:
             - `list`: The multiplied list.
         """
@@ -601,7 +612,7 @@ def tiled_scale(
     pbar: any = None,
 ):
     """#### Scale an image using a tiled approach.
-
+
     #### Args:
         - `samples` (torch.Tensor): The input samples.
         - `function` (function): The scaling function.
@@ -612,7 +623,7 @@ def tiled_scale(
         - `out_channels` (int, optional): The number of output channels. Defaults to 3.
         - `output_device` (str, optional): The output device. Defaults to "cpu".
        - `pbar` (any, optional): The progress bar. Defaults to None.
-
+
     #### Returns:
         - The scaled image.
     """
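The new lcm_of_list round-trips through torch tensors for every pair. For plain Python ints the standard library does the same job with no tensor allocation; a minimal sketch, assuming Python 3.9+ (math.lcm is variadic and returns 1 when called with no arguments, which matches the empty-list case):

import math

def lcm_of_list_stdlib(numbers):
    # math.lcm() with zero arguments is 1, so no special-casing is needed.
    return math.lcm(*numbers)

assert lcm_of_list_stdlib([77, 154]) == 154
assert lcm_of_list_stdlib([]) == 1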
modules/cond/cond.py
CHANGED

@@ -42,13 +42,13 @@ class CONDRegular:
         return self._copy_with(
             util.repeat_to_batch_size(self.cond, batch_size).to(device)
         )
-
+
     def can_concat(self, other: "CONDRegular") -> bool:
         """#### Check if conditions can be concatenated.
-
+
         #### Args:
             - `other` (CONDRegular): The other condition.
-
+
         #### Returns:
             - `bool`: True if conditions can be concatenated, False otherwise.
         """
@@ -58,10 +58,10 @@ class CONDRegular:

     def concat(self, others: list) -> torch.Tensor:
         """#### Concatenate conditions.
-
+
         #### Args:
             - `others` (list): The list of other conditions.
-
+
         #### Returns:
             - `torch.Tensor`: The concatenated conditions.
         """
@@ -76,11 +76,11 @@ class CONDCrossAttn(CONDRegular):

     def can_concat(self, other: "CONDRegular") -> bool:
         """#### Check if conditions can be concatenated.
-
+
         #### Args:
             - `other` (CONDRegular): The other condition.
-
-        #### Returns:
+
+        #### Returns:
             - `bool`: True if conditions can be concatenated, False otherwise.
         """
         s1 = self.cond.shape
@@ -96,31 +96,34 @@ class CONDCrossAttn(CONDRegular):
         ):  # arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
             return False
         return True
-
-    def concat(self, others: list) -> torch.Tensor:
-        """#### Concatenate cross-attention conditions.
-
- …
-        #### Returns:
-            - `torch.Tensor`: The concatenated conditions.
-        """
+
+    def concat(self, others: list) -> torch.Tensor:
+        """Optimized version of cross-attention condition concatenation."""
         conds = [self.cond]
- …
+        shapes = [self.cond.shape[1]]
+
+        # Collect all conditions and their shapes
         for x in others:
- …
-            conds.append(c)
+            conds.append(x.cond)
+            shapes.append(x.cond.shape[1])

- …
+        # Calculate LCM more efficiently
+        crossattn_max_len = util.lcm_of_list(shapes)
+
+        # Process and concat in one step where possible
+        if all(c.shape[1] == shapes[0] for c in conds):
+            # All same length, simple concatenation
+            return torch.cat(conds)
+        else:
+            # Process conditions that need repeating
+            out = []
+            for c in conds:
+                if c.shape[1] < crossattn_max_len:
+                    repeat_factor = crossattn_max_len // c.shape[1]
+                    # Use repeat instead of individual operations
+                    c = c.repeat(1, repeat_factor, 1)
+                out.append(c)
+            return torch.cat(out)


 def convert_cond(cond: list) -> list:
@@ -277,8 +280,10 @@ def calc_cond_batch(
                 out_c += output[o] * mult[o]
                 out_cts += mult[o]

+    # Vectorize the division at the end
     for i in range(len(out_conds)):
- …
+        # Inplace division is already efficient
+        out_conds[i].div_(out_counts[i])  # Using .div_ instead of /= for clarity

     return out_conds

@@ -328,48 +333,75 @@ def encode_model_conds(
         conds[t] = x
     return conds

- …
-    #### Args:
-        - `conditions` (list): The list of conditions.
-        - `dims` (tuple): The dimensions.
-        - `device` (torch.device): The device.
-    """
-    # We need to decide on an area outside the sampling loop in order to properly generate opposite areas of equal sizes.
-    # While we're doing this, we can also resolve the mask device and scaling for performance reasons
+
+def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
+    """Optimized version that processes areas and masks more efficiently"""
     for i in range(len(conditions)):
         c = conditions[i]
+        # Process area
         if "area" in c:
             area = c["area"]
             if area[0] == "percentage":
- …
+                # Vectorized calculation of area dimensions
                 a = area[1:]
                 a_len = len(a) // 2
-                area = ()
-                for d in range(len(dims)):
-                    area += (max(1, round(a[d] * dims[d])),)
-                for d in range(len(dims)):
-                    area += (round(a[d + a_len] * dims[d]),)

- …
+                # Calculate all dimensions at once using tensor operations
+                dims_tensor = torch.tensor(dims, device="cpu")
+                first_part = torch.tensor(a[:a_len], device="cpu") * dims_tensor
+                second_part = torch.tensor(a[a_len:], device="cpu") * dims_tensor
+
+                # Convert to rounded integers and tuple
+                first_part = torch.max(
+                    torch.ones_like(first_part), torch.round(first_part)
+                )
+                second_part = torch.round(second_part)

+                # Create the new area tuple
+                new_area = tuple(first_part.int().tolist()) + tuple(
+                    second_part.int().tolist()
+                )
+
+                # Create a modified copy with the new area
+                modified = c.copy()
+                modified["area"] = new_area
+                conditions[i] = modified
+
+        # Process mask
         if "mask" in c:
-            mask = c["mask"]
-            mask = mask.to(device=device)
             modified = c.copy()
+            mask = c["mask"].to(device=device)
+
+            # Combine dimension checks and unsqueeze operation
             if len(mask.shape) == len(dims):
                 mask = mask.unsqueeze(0)
+
+            # Only interpolate if needed
             if mask.shape[1:] != dims:
-                mask …
- …
+                # Optimize interpolation by ensuring mask is in the right format for the operation
+                if len(mask.shape) == 3 and mask.shape[0] == 1:
+                    # Already in the right format for interpolation
+                    mask = torch.nn.functional.interpolate(
+                        mask.unsqueeze(1),
+                        size=dims,
+                        mode="bilinear",
+                        align_corners=False,
+                    ).squeeze(1)
+                else:
+                    # Ensure mask is properly formatted for interpolation
+                    mask = torch.nn.functional.interpolate(
+                        mask
+                        if len(mask.shape) > 3 and mask.shape[1] == 1
+                        else mask.unsqueeze(1),
+                        size=dims,
+                        mode="bilinear",
+                        align_corners=False,
+                    ).squeeze(1)

             modified["mask"] = mask
             conditions[i] = modified

+
 def process_conds(
     model: object,
     noise: torch.Tensor,
@@ -442,4 +474,4 @@ def process_conds(
             positive, conds[k], "gligen", lambda cond_cnets, x: cond_cnets[x]
         )

-    return conds
+    return conds
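A shape-level sketch of the repeat-to-LCM concat in CONDCrossAttn.concat above, with made-up token counts: a 77-token and a 154-token cond are aligned by repeating tokens before stacking on the batch dimension.

import torch

a = torch.randn(1, 77, 768)   # e.g. a one-chunk CLIP prompt (hypothetical sizes)
b = torch.randn(1, 154, 768)  # e.g. a two-chunk prompt
max_len = 154                 # lcm(77, 154)

a_rep = a.repeat(1, max_len // a.shape[1], 1)  # (1, 154, 768)
out = torch.cat([a_rep, b])                    # stacked on dim 0: (2, 154, 768)
assert out.shape == (2, 154, 768)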
modules/sample/CFG.py
CHANGED

@@ -30,10 +30,15 @@ def cfg_function(
     #### Returns:
         - `torch.Tensor`: The CFG result.
     """
+    # Check for custom sampler CFG function first
     if "sampler_cfg_function" in model_options:
+        # Precompute differences to avoid redundant operations
+        cond_diff = x - cond_pred
+        uncond_diff = x - uncond_pred
+
         args = {
-            "cond": …
-            "uncond": …
+            "cond": cond_diff,
+            "uncond": uncond_diff,
             "cond_scale": cond_scale,
             "timestep": timestep,
             "input": x,
@@ -45,9 +50,18 @@ def cfg_function(
         }
         cfg_result = x - model_options["sampler_cfg_function"](args)
     else:
- …
+        # Standard CFG calculation - optimized to avoid intermediate tensor allocation
+        # When cond_scale = 1.0, we can just return cond_pred without computation
+        if math.isclose(cond_scale, 1.0):
+            cfg_result = cond_pred
+        else:
+            # Fused operation: uncond_pred + (cond_pred - uncond_pred) * cond_scale
+            # Equivalent to: uncond_pred * (1 - cond_scale) + cond_pred * cond_scale
+            cfg_result = torch.lerp(uncond_pred, cond_pred, cond_scale)
+
+    # Apply post-CFG functions if any
+    post_cfg_functions = model_options.get("sampler_post_cfg_function", [])
+    if post_cfg_functions:
         args = {
             "denoised": cfg_result,
             "cond": cond,
@@ -59,7 +73,12 @@ def cfg_function(
             "model_options": model_options,
             "input": x,
         }
- …
+
+        # Apply each post-CFG function in sequence
+        for fn in post_cfg_functions:
+            cfg_result = fn(args)
+            # Update the denoised result for the next function
+            args["denoised"] = cfg_result

     return cfg_result

@@ -89,21 +108,29 @@ def sampling_function(
     #### Returns:
         - `torch.Tensor`: The sampled tensor.
     """
- …
+    # Optimize conditional logic for uncond
+    uncond_ = (
+        None
+        if (
+            math.isclose(cond_scale, 1.0)
+            and not model_options.get("disable_cfg1_optimization", False)
+        )
+        else uncond
+    )

+    # Create conditions list once
     conds = [condo, uncond_]
-    out = cond.calc_cond_batch(model, conds, x, timestep, model_options)

- …
+    # Get model predictions for both conditions
+    cond_outputs = cond.calc_cond_batch(model, conds, x, timestep, model_options)
+
+    # Apply pre-CFG functions if any
+    pre_cfg_functions = model_options.get("sampler_pre_cfg_function", [])
+    if pre_cfg_functions:
+        # Create args dictionary once
         args = {
             "conds": conds,
-            "conds_out": …
+            "conds_out": cond_outputs,
             "cond_scale": cond_scale,
             "timestep": timestep,
             "input": x,
@@ -111,12 +138,20 @@ def sampling_function(
             "model": model,
             "model_options": model_options,
         }
-        out = fn(args)

+        # Apply each pre-CFG function
+        for fn in pre_cfg_functions:
+            cond_outputs = fn(args)
+            args["conds_out"] = cond_outputs
+
+    # Extract conditional and unconditional outputs explicitly for clarity
+    cond_pred, uncond_pred = cond_outputs[0], cond_outputs[1]
+
+    # Apply the CFG function
     return cfg_function(
         model,
-        …
-        …
+        cond_pred,
+        uncond_pred,
         cond_scale,
         x,
         timestep,
@@ -128,6 +163,7 @@ def sampling_function(

 class CFGGuider:
     """#### Class for guiding the sampling process with CFG."""
+
     def __init__(self, model_patcher, flux=False):
         """#### Initialize the CFGGuider.

@@ -315,4 +351,4 @@ class CFGGuider:
         del self.inner_model
         del self.conds
         del self.loaded_models
-    return output
+    return output
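The torch.lerp call above relies on the identity lerp(uncond, cond, s) = uncond + s * (cond - uncond), which is exactly the classic CFG combination; lerp also extrapolates for s > 1, the usual CFG regime. A quick standalone check with random tensors:

import torch

cond_pred, uncond_pred = torch.randn(4), torch.randn(4)
cond_scale = 7.5  # > 1: lerp extrapolates, as CFG requires

classic = uncond_pred + (cond_pred - uncond_pred) * cond_scale
fused = torch.lerp(uncond_pred, cond_pred, cond_scale)
assert torch.allclose(classic, fused, atol=1e-6)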
modules/sample/ksampler_util.py
CHANGED

@@ -46,6 +46,7 @@ def pre_run_control(model: torch.nn.Module, conds: list) -> None:

     def percent_to_timestep_function(a):
         return s.percent_to_sigma(a)
+
     if "control" in x:
         x["control"].pre_run(model, percent_to_timestep_function)

@@ -96,9 +97,13 @@ def apply_empty_x_to_equal_area(
             uncond[temp[1]] = n


- …
+# Define the namedtuple class once outside the function for reuse
+CondObj = collections.namedtuple(
+    "cond_obj", ["input_x", "mult", "conditioning", "area", "control", "patches"]
+)
+
+
+def get_area_and_mult(conds: dict, x_in: torch.Tensor, timestep_in: int) -> CondObj:
     """#### Get the area and multiplier.

     #### Args:
@@ -109,26 +114,39 @@ def get_area_and_mult(
     #### Returns:
         - `collections.namedtuple`: The area and multiplier.
     """
- …
-    mult = mask * strength
+    # Cache shape information to avoid repeated access
+    x_shape = x_in.shape
+
+    # Define area dimensions in one operation
+    area = (x_shape[2], x_shape[3], 0, 0)
+
+    # Extract input region efficiently
+    # Since area[2] and area[3] are 0, this is essentially taking the full tensor
+    # But we maintain the slice operation for consistency
+    input_x = x_in[:, :, : area[0], : area[1]]
+
+    # Create multiplier tensor directly without intermediate mask creation
+    # This avoids an unnecessary tensor allocation and multiplication
+    mult = torch.ones_like(input_x)  # strength is 1.0, so just create ones directly

+    # Prepare conditioning dictionary with cached device and batch_size
     conditioning = {}
     model_conds = conds["model_conds"]
+    batch_size = x_shape[0]
+    device = x_in.device
+
+    # Process conditions with cached parameters
     for c in model_conds:
         conditioning[c] = model_conds[c].process_cond(
-            batch_size= …
+            batch_size=batch_size, device=device, area=area
         )

+    # Get control directly without redundant variable assignment
     control = conds.get("control", None)
     patches = None
- …
-    )
-    return cond_obj(input_x, mult, conditioning, area, control, patches)
+
+    # Use the pre-defined namedtuple class instead of creating it every call
+    return CondObj(input_x, mult, conditioning, area, control, patches)


 def normal_scheduler(
@@ -158,6 +176,7 @@ def normal_scheduler(
     sigs += [0.0]
     return torch.FloatTensor(sigs)

+
 def simple_scheduler(model_sampling: torch.nn.Module, steps: int) -> torch.FloatTensor:
     """#### Create a simple scheduler.

@@ -176,21 +195,52 @@ def simple_scheduler(model_sampling: torch.nn.Module, steps: int) -> torch.Float
     sigs += [0.0]
     return torch.FloatTensor(sigs)

+
 # Implemented based on: https://arxiv.org/abs/2407.12173
 def beta_scheduler(model_sampling, steps, alpha=0.6, beta=0.6):
- …
+    """Creates a beta scheduler for noise levels based on the beta distribution.
+
+    This optimized implementation efficiently computes sigmas using the beta
+    distribution and caches calculations where possible.
+
+    Args:
+        model_sampling: Model sampling module
+        steps: Number of steps
+        alpha: Alpha parameter for beta distribution
+        beta: Beta parameter for beta distribution
+
+    Returns:
+        torch.FloatTensor: Tensor of sigma values for each step
+    """
+    # Calculate total timesteps once
+    total_timesteps = len(model_sampling.sigmas) - 1

-    sigs = []
-    last_t = -1
-    for t in ts:
-        if t != last_t:
-            sigs += [float(model_sampling.sigmas[int(t)])]
-        last_t = t
-    sigs += [0.0]
+    # Cache the sigma table for reuse below
+    model_sigmas = model_sampling.sigmas
+
+    # Generate evenly spaced values in [0,1) interval
+    ts_normalized = np.linspace(0, 1, steps, endpoint=False)
+
+    # Apply beta inverse CDF to get sampled time points - vectorized operation
+    ts_beta = scipy.stats.beta.ppf(1 - ts_normalized, alpha, beta)
+
+    # Scale to timestep indices and round to integers
+    ts_indices = np.rint(ts_beta * total_timesteps).astype(np.int32)
+
+    # Use numpy's unique function with return_index to efficiently find unique values
+    # while preserving order
+    unique_ts, indices = np.unique(ts_indices, return_index=True)
+    ordered_unique_ts = unique_ts[np.argsort(indices)]
+
+    # Map indices to sigma values efficiently
+    sigs = [float(model_sigmas[idx]) for idx in ordered_unique_ts]
+
+    # Add final sigma value of 0.0
+    sigs.append(0.0)

     return torch.FloatTensor(sigs)

+
 def calculate_sigmas(
     model_sampling: torch.nn.Module, scheduler_name: str, steps: int
 ) -> torch.Tensor:
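The rewritten beta_scheduler selects timesteps by pushing a uniform grid through the inverse CDF (ppf) of a Beta(alpha, beta) distribution; with alpha = beta = 0.6 the distribution is U-shaped, so steps cluster toward both ends of the schedule. A standalone sketch with assumed sizes (in the real function the indices look up model_sampling.sigmas):

import numpy as np
import scipy.stats

steps, alpha, beta_p, total_timesteps = 10, 0.6, 0.6, 999  # hypothetical values

ts = np.linspace(0, 1, steps, endpoint=False)
# Inverse CDF maps the uniform grid to beta-distributed time points.
idx = np.rint(scipy.stats.beta.ppf(1 - ts, alpha, beta_p) * total_timesteps).astype(np.int32)

# De-duplicate while preserving the original (descending) order.
unique, first = np.unique(idx, return_index=True)
ordered = unique[np.argsort(first)]
print(ordered)  # starts at 999 and descends, denser near the endpoints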
modules/sample/samplers.py
CHANGED

@@ -142,181 +142,15 @@ def sample_euler(
     return x


- …
-    model,
-    x,
-    sigmas,
-    extra_args=None,
-    callback=None,
-    disable=None,
-    eta=1.0,
-    s_noise=1.0,
-    noise_sampler=None,
-    r=1 / 2,
-    pipeline=False,
-    seed=None,
-):
-    # Pre-calculate common values
-    device = x.device
-    global disable_gui
-    disable_gui = pipeline
-
-    if not disable_gui:
-        from modules.AutoEncoders import taesd
-        from modules.user import app_instance
-
-    # Early return check
-    if len(sigmas) <= 1:
-        return x
-
-    # Pre-allocate tensors and values
-    s_in = torch.ones((x.shape[0],), device=device)
-    n_steps = len(sigmas) - 1
-    extra_args = {} if extra_args is None else extra_args
-
-    # Define helper functions
-    def sigma_fn(t):
-        return (-t).exp()
-
-    def t_fn(sigma):
-        return -sigma.log()
-
-    # Initialize noise sampler
-    if noise_sampler is None:
-        noise_sampler = sampling_util.BrownianTreeNoiseSampler(
-            x, sigmas[sigmas > 0].min(), sigmas.max(), seed=seed, cpu=True
-        )
-
-    for i in trange(n_steps, disable=disable):
-        if (
-            not pipeline
-            and hasattr(app_instance.app, "interrupt_flag")
-            and app_instance.app.interrupt_flag
-        ):
-            return x
-
-        if not pipeline:
-            app_instance.app.progress.set(i / n_steps)
-
-        # Model inference
-        denoised = model(x, sigmas[i] * s_in, **extra_args)
-
-        if callback is not None:
-            callback({"x": x, "i": i, "sigma": sigmas[i], "denoised": denoised})
-
-        if sigmas[i + 1] == 0:
-            # Single fused Euler step
-            x = x + util.to_d(x, sigmas[i], denoised) * (sigmas[i + 1] - sigmas[i])
-        else:
-            # Fused DPM-Solver++ steps
-            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
-            s = t + (t_next - t) * r
-
-            # Step 1 - Combined calculations
-            sd, su = sampling_util.get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
-            s_ = t_fn(sd)
-            x_2 = (
-                (sigma_fn(s_) / sigma_fn(t)) * x
-                - (t - s_).expm1() * denoised
-                + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
-            )
-
-            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
-
-            # Step 2 - Combined calculations
-            sd, su = sampling_util.get_ancestral_step(
-                sigma_fn(t), sigma_fn(t_next), eta
-            )
-            t_next_ = t_fn(sd)
-
-            # Final update in single calculation
-            x = (
-                (sigma_fn(t_next_) / sigma_fn(t)) * x
-                - (t - t_next_).expm1()
-                * ((1 - 1 / (2 * r)) * denoised + (1 / (2 * r)) * denoised_2)
-                + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
-            )
-
-        # Preview updates
-        if not pipeline and app_instance.app.previewer_var.get() and i % 5 == 0:
-            threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
-
-    return x
-
-
-@torch.no_grad()
-def sample_dpmpp_2m(
-    model,
-    x,
-    sigmas,
-    extra_args=None,
-    callback=None,
-    disable=None,
-    pipeline=False,
+def set_model_options_post_cfg_function(
+    model_options, post_cfg_function, disable_cfg1_optimization=False
 ):
-    ""
- …
-    if not disable_gui:
-        from modules.AutoEncoders import taesd
-        from modules.user import app_instance
-
-    # Pre-allocate tensors and transform sigmas
-    s_in = torch.ones((x.shape[0],), device=device)
-    t_steps = -torch.log(sigmas)  # Fused calculation
-
-    # Pre-calculate all needed values in one go
-    sigma_steps = torch.exp(-t_steps)  # Fused calculation
-    ratios = sigma_steps[1:] / sigma_steps[:-1]
-    h_steps = t_steps[1:] - t_steps[:-1]
-
-    old_denoised = None
-    extra_args = {} if extra_args is None else extra_args
-
-    for i in trange(len(sigmas) - 1, disable=disable):
-        if (
-            not pipeline
-            and hasattr(app_instance.app, "interrupt_flag")
-            and app_instance.app.interrupt_flag
-        ):
-            return x
-
-        if not pipeline:
-            app_instance.app.progress.set(i / (len(sigmas) - 1))
-
-        # Fused model inference and update calculations
-        denoised = model(x, sigmas[i] * s_in, **extra_args)
-
-        if callback is not None:
-            callback(
-                {
-                    "x": x,
-                    "i": i,
-                    "sigma": sigmas[i],
-                    "sigma_hat": sigmas[i],
-                    "denoised": denoised,
-                }
-            )
-
-        # Combined update step
-        x = ratios[i] * x - (-h_steps[i]).expm1() * (
-            denoised
-            if old_denoised is None or sigmas[i + 1] == 0
-            else (1 + h_steps[i - 1] / (2 * h_steps[i])) * denoised
-            - (h_steps[i - 1] / (2 * h_steps[i])) * old_denoised
-        )
-
-        old_denoised = denoised
-
-        # Preview updates
-        if not pipeline and app_instance.app.previewer_var.get() and i % 5 == 0:
-            threading.Thread(target=taesd.taesd_preview, args=(x,)).start()
-
-    return x
+    model_options["sampler_post_cfg_function"] = model_options.get(
+        "sampler_post_cfg_function", []
+    ) + [post_cfg_function]
+    if disable_cfg1_optimization:
+        model_options["disable_cfg1_optimization"] = True
+    return model_options


 @torch.no_grad()
@@ -354,17 +188,26 @@ def sample_dpmpp_2m_cfgpp(
     ratios = sigma_steps[1:] / sigma_steps[:-1]
     h_steps = t_steps[1:] - t_steps[:-1]

-    # CFG
- …
-        progress = step / n_steps
-        return cfg_scale + (cfg_min - cfg_scale) * progress
+    # Pre-calculate CFG schedule for the entire sampling process
+    steps = torch.arange(n_steps, device=device)
+    cfg_values = cfg_scale + (cfg_min - cfg_scale) * (steps / n_steps)

     old_denoised = None
     old_uncond_denoised = None
     extra_args = {} if extra_args is None else extra_args

- …
+    # Define post-CFG function once outside the loop
+    def post_cfg_function(args):
+        nonlocal old_uncond_denoised
+        old_uncond_denoised = args["uncond_denoised"]
+        return args["denoised"]
+
+    model_options = extra_args.get("model_options", {}).copy()
+    extra_args["model_options"] = set_model_options_post_cfg_function(
+        model_options, post_cfg_function, disable_cfg1_optimization=True
+    )
+
+    for i in trange(n_steps, disable=disable):
         if (
             not pipeline
             and hasattr(app_instance.app, "interrupt_flag")
@@ -373,20 +216,10 @@ def sample_dpmpp_2m_cfgpp(
             return x

         if not pipeline:
-            app_instance.app.progress.set(i / …
-
-        # Get current CFG scale
-        current_cfg = get_cfg_scale(i)
-
-        def post_cfg_function(args):
-            nonlocal old_uncond_denoised
-            old_uncond_denoised = args["uncond_denoised"]
-            return args["denoised"]
-
- …
-            model_options, post_cfg_function, disable_cfg1_optimization=True
-        )
+            app_instance.app.progress.set(i / n_steps)
+
+        # Use pre-calculated CFG scale
+        current_cfg = cfg_values[i]

         # Fused model inference and update calculations
         denoised = model(x, sigmas[i] * s_in, **extra_args)
@@ -406,27 +239,29 @@ def sample_dpmpp_2m_cfgpp(
             }
         )

-        # CFG++ update step
+        # CFG++ update step using optimized operations
         if old_uncond_denoised is None or sigmas[i + 1] == 0:
-            # First step or last step - …
-            cfg_denoised = uncond_denoised
+            # First step or last step - use torch.lerp for efficient interpolation
+            cfg_denoised = torch.lerp(uncond_denoised, denoised, current_cfg)
         else:
-            # …
-            x0_coeff = cfg_x0_scale * current_cfg
-            s_coeff = cfg_s_scale * current_cfg
-
-            # Momentum terms
+            # Fused momentum calculations
             h_ratio = h_steps[i - 1] / (2 * h_steps[i])
-
-            uncond_momentum = (
-                1 + h_ratio
-            ) * uncond_denoised - h_ratio * old_uncond_denoised
+            h_ratio_plus_1 = 1 + h_ratio

-            # …
- …
+            # Use fused multiply-add operations for momentum terms
+            momentum = torch.addcmul(denoised * h_ratio_plus_1, old_denoised, -h_ratio)
+            uncond_momentum = torch.addcmul(
+                uncond_denoised * h_ratio_plus_1, old_uncond_denoised, -h_ratio
+            )

- …
+            # Optimized interpolation for CFG++ update
+            cfg_denoised = torch.lerp(
+                uncond_momentum, momentum, current_cfg * cfg_x0_scale
+            )
+
+        # Apply update with pre-calculated expm1
+        h_expm1 = torch.expm1(-h_steps[i])
+        x = ratios[i] * x - h_expm1 * cfg_denoised

         old_denoised = denoised
         old_uncond_denoised = uncond_denoised
@@ -438,17 +273,6 @@ def sample_dpmpp_2m_cfgpp(
     return x


-def set_model_options_post_cfg_function(
-    model_options, post_cfg_function, disable_cfg1_optimization=False
-):
-    model_options["sampler_post_cfg_function"] = model_options.get(
-        "sampler_post_cfg_function", []
-    ) + [post_cfg_function]
-    if disable_cfg1_optimization:
-        model_options["disable_cfg1_optimization"] = True
-    return model_options
-
-
 @torch.no_grad()
 def sample_dpmpp_sde_cfgpp(
     model,
@@ -572,7 +396,6 @@ def sample_dpmpp_sde_cfgpp(
         else:
             # CFG++ with momentum
             x0_coeff = cfg_x0_scale * current_cfg
-            s_coeff = cfg_s_scale * current_cfg

             # Calculate momentum terms
             h_ratio = (t - s_) / (2 * (t - t_next))
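The momentum terms in sample_dpmpp_2m_cfgpp above use torch.addcmul(input, t1, t2), which computes input + t1 * t2 in one fused kernel. A quick standalone check that the fused form matches the unfused momentum expression (1 + r) * d - r * d_old:

import torch

d, d_old = torch.randn(3), torch.randn(3)
r = torch.tensor(0.25)  # stands in for h_steps[i - 1] / (2 * h_steps[i])

fused = torch.addcmul(d * (1 + r), d_old, -r)
plain = (1 + r) * d - r * d_old
assert torch.allclose(fused, plain)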
modules/sample/sampling.py
CHANGED
|
@@ -76,6 +76,7 @@ class EPS:
|
|
| 76 |
if max_denoise:
|
| 77 |
noise = noise * torch.sqrt(1.0 + sigma**2.0)
|
| 78 |
else:
|
|
|
|
| 79 |
noise = noise * sigma
|
| 80 |
|
| 81 |
noise += latent_image
|
@@ -513,153 +514,22 @@ def ksampler(
     #### Returns:
         - `KSAMPLER`: The KSAMPLER object.
     """
-    if sampler_name == "dpmpp_2m":
-
-        def dpmpp_2m_function(
-            model: torch.nn.Module,
-            noise: torch.Tensor,
-            sigmas: torch.Tensor,
-            extra_args: dict,
-            callback: callable,
-            disable: bool,
-            pipeline: bool,
-            **extra_options,
-        ) -> torch.Tensor:
-            sigma_min = sigmas[-1]
-            if sigma_min == 0:
-                sigma_min = sigmas[-2]
-            return samplers.sample_dpmpp_2m(
-                model,
-                noise,
-                sigmas,
-                extra_args=extra_args,
-                callback=callback,
-                disable=disable,
-                pipeline=pipeline,
-                **extra_options,
-            )
-
-        sampler_function = dpmpp_2m_function
-
-    elif sampler_name == "dpmpp_2m_cfgpp":
-
-        def dpmpp_2m_dy_function(
-            model: torch.nn.Module,
-            noise: torch.Tensor,
-            sigmas: torch.Tensor,
-            extra_args: dict,
-            callback: callable,
-            disable: bool,
-            pipeline: bool,
-            **extra_options,
-        ) -> torch.Tensor:
-            sigma_min = sigmas[-1]
-            if sigma_min == 0:
-                sigma_min = sigmas[-2]
-            return samplers.sample_dpmpp_2m_cfgpp(
-                model,
-                noise,
-                sigmas,
-                extra_args=extra_args,
-                callback=callback,
-                disable=disable,
-                pipeline=pipeline,
-                **extra_options,
-            )
-
-        sampler_function = dpmpp_2m_dy_function
-
-    elif sampler_name == "dpmpp_sde":
-
-        def dpmpp_sde_function(
-            model: torch.nn.Module,
-            noise: torch.Tensor,
-            sigmas: torch.Tensor,
-            extra_args: dict,
-            callback: callable,
-            disable: bool,
-            pipeline: bool,
-            **extra_options,
-        ) -> torch.Tensor:
-            return samplers.sample_dpmpp_sde(
-                model,
-                noise,
-                sigmas,
-                extra_args=extra_args,
-                callback=callback,
-                disable=disable,
-                pipeline=pipeline,
-                **extra_options,
-            )
-
-        sampler_function = dpmpp_sde_function
+    if sampler_name == "dpmpp_2m_cfgpp":
+        sampler_function = samplers.sample_dpmpp_2m_cfgpp
 
     elif sampler_name == "euler_ancestral":
-
-        def euler_ancestral_function(
-            model: torch.nn.Module,
-            noise: torch.Tensor,
-            sigmas: torch.Tensor,
-            extra_args: dict,
-            callback: callable,
-            disable: bool,
-            pipeline: bool,
-        ) -> torch.Tensor:
-            return samplers.sample_euler_ancestral(
-                model,
-                noise,
-                sigmas,
-                extra_args=extra_args,
-                callback=callback,
-                disable=disable,
-                pipeline=pipeline,
-                **extra_options,
-            )
-
-        sampler_function = euler_ancestral_function
+        sampler_function = samplers.sample_euler_ancestral
 
     elif sampler_name == "dpmpp_sde_cfgpp":
-
-        def dpmpp_sde_dy_function(
-            model: torch.nn.Module,
-            noise: torch.Tensor,
-            sigmas: torch.Tensor,
-            extra_args: dict,
-            callback: callable,
-            disable: bool,
-            pipeline: bool,
-            **extra_options,
-        ) -> torch.Tensor:
-            return samplers.sample_dpmpp_sde_cfgpp(
-                model,
-                noise,
-                sigmas,
-                extra_args=extra_args,
-                callback=callback,
-                disable=disable,
-                pipeline=pipeline,
-                **extra_options,
-            )
-
-        sampler_function = dpmpp_sde_dy_function
+        sampler_function = samplers.sample_dpmpp_sde_cfgpp
 
     elif sampler_name == "euler":
-
-        def euler_function(
-            ...
-        ) -> torch.Tensor:
-            return samplers.sample_euler(
-                model,
-                noise,
-                sigmas,
-                extra_args=extra_args,
-                callback=callback,
-                disable=disable,
-                pipeline=pipeline,
-                **extra_options,
-            )
-
-        sampler_function = euler_function
+        sampler_function = samplers.sample_euler
+
+    else:
+        # Default fallback
+        sampler_function = samplers.sample_euler
+        print(f"Warning: Unknown sampler '{sampler_name}', falling back to euler")
 
     return KSAMPLER(sampler_function, extra_options, inpaint_options)
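The refactor can drop the per-name wrapper closures because every sampler shares the same `(model, noise, sigmas, extra_args, callback, disable, pipeline, **extra_options)` calling convention, so the functions are usable directly. The same dispatch could equally be a lookup table; a sketch of that alternative (not what the commit does; the import path mirrors the repo layout):

from modules.sample import samplers

SAMPLER_TABLE = {
    "dpmpp_2m_cfgpp": samplers.sample_dpmpp_2m_cfgpp,
    "euler_ancestral": samplers.sample_euler_ancestral,
    "dpmpp_sde_cfgpp": samplers.sample_dpmpp_sde_cfgpp,
    "euler": samplers.sample_euler,
}

def resolve_sampler(name: str):
    # Unknown names fall back to euler, matching the else-branch above.
    fn = SAMPLER_TABLE.get(name)
    if fn is None:
        print(f"Warning: Unknown sampler '{name}', falling back to euler")
        fn = samplers.sample_euler
    return fn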
@@ -734,49 +604,49 @@ def sampler_object(name: str, pipeline: bool = False) -> KSAMPLER:
     return sampler
 
 
-class KSampler1:
-    """..."""
+class KSampler:
+    """A unified sampler class that replaces both KSampler1 and KSampler2."""
 
     def __init__(
         self,
-        model: torch.nn.Module,
-        steps: int,
-        device,
+        model: torch.nn.Module = None,
+        steps: int = None,
         sampler: str = None,
         scheduler: str = None,
-        denoise: float = None,
+        denoise: float = 1.0,
         model_options: dict = {},
         pipeline: bool = False,
     ):
-        """...
-            - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
-        """
+        """Initialize the KSampler class.
+
+        Args:
+            model (torch.nn.Module, optional): The model to use for sampling. Required for direct sampling.
+            steps (int, optional): The number of steps. Required for direct sampling.
+            sampler (str, optional): The sampler name. Defaults to None.
+            scheduler (str, optional): The scheduler name. Defaults to None.
+            denoise (float, optional): The denoise factor. Defaults to 1.0.
+            model_options (dict, optional): The model options. Defaults to {}.
+            pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
+        """
         self.model = model
-        self.device = device
+        self.device = model.load_device if model is not None else None
         self.scheduler = scheduler
-        self.sampler = sampler
-        self.set_steps(steps, denoise)
+        self.sampler_name = sampler
         self.denoise = denoise
         self.model_options = model_options
        self.pipeline = pipeline
 
+        if model is not None and steps is not None:
+            self.set_steps(steps, denoise)
+
     def calculate_sigmas(self, steps: int) -> torch.Tensor:
-        """...
-        """
+        """Calculate the sigmas for the given steps.
+
+        Args:
+            steps (int): The number of steps.
+
+        Returns:
+            torch.Tensor: The calculated sigmas.
+        """
         sigmas = ksampler_util.calculate_sigmas(
             self.model.get_model_object("model_sampling"), self.scheduler, steps
@@ -784,11 +654,11 @@ class KSampler1:
         return sigmas
 
     def set_steps(self, steps: int, denoise: float = None):
-        """...
-        """
+        """Set the steps and calculate the sigmas.
+
+        Args:
+            steps (int): The number of steps.
+            denoise (float, optional): The denoise factor. Defaults to None.
+        """
         self.steps = steps
         if denoise is None or denoise > 0.9999:
 
@@ -801,7 +671,29 @@ class KSampler1:
         sigmas = self.calculate_sigmas(new_steps).to(self.device)
         self.sigmas = sigmas[-(steps + 1) :]
 
-    def sample(
+    def _process_sigmas(self, sigmas, start_step, last_step, force_full_denoise):
+        """Process sigmas based on start_step and last_step.
+
+        Args:
+            sigmas (torch.Tensor): The sigmas tensor.
+            start_step (int, optional): The start step. Defaults to None.
+            last_step (int, optional): The last step. Defaults to None.
+            force_full_denoise (bool): Whether to force full denoise.
+
+        Returns:
+            torch.Tensor: The processed sigmas.
+        """
+        if last_step is not None and last_step < (len(sigmas) - 1):
+            sigmas = sigmas[: last_step + 1]
+            if force_full_denoise:
+                sigmas[-1] = 0
+
+        if start_step is not None and start_step < (len(sigmas) - 1):
+            sigmas = sigmas[start_step:]
+
+        return sigmas
+
+    def direct_sample(
         self,
         noise: torch.Tensor,
         positive: torch.Tensor,
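Worked example of the slicing above (toy sigma values): the schedule runs from high noise toward zero, `last_step` trims the tail, `force_full_denoise` pins the final sigma to 0, and `start_step` skips the noisiest steps for img2img-style partial denoising:

import torch

sigmas = torch.tensor([14.6, 8.0, 4.1, 2.0, 0.9, 0.3])

trimmed = sigmas[: 4 + 1].clone()   # last_step = 4
trimmed[-1] = 0                     # force_full_denoise
print(trimmed)                      # tensor([14.6000,  8.0000,  4.1000,  2.0000,  0.0000])
print(trimmed[2:])                  # start_step = 2 -> tensor([4.1000, 2.0000, 0.0000])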
@@ -816,48 +708,45 @@ class KSampler1:
         callback: callable = None,
         disable_pbar: bool = False,
         seed: int = None,
-        pipeline: bool = False,
         flux: bool = False,
     ) -> torch.Tensor:
-        """...
-        """
+        """Sample directly with the initialized model and parameters.
+
+        Args:
+            noise (torch.Tensor): The noise tensor.
+            positive (torch.Tensor): The positive tensor.
+            negative (torch.Tensor): The negative tensor.
+            cfg (float): The CFG value.
+            latent_image (torch.Tensor, optional): The latent image tensor. Defaults to None.
+            start_step (int, optional): The start step. Defaults to None.
+            last_step (int, optional): The last step. Defaults to None.
+            force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
+            denoise_mask (torch.Tensor, optional): The denoise mask tensor. Defaults to None.
+            sigmas (torch.Tensor, optional): The sigmas tensor. Defaults to None.
+            callback (callable, optional): The callback function. Defaults to None.
+            disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
+            seed (int, optional): The seed value. Defaults to None.
+            flux (bool, optional): Whether to use flux mode. Defaults to False.
+
+        Returns:
+            torch.Tensor: The sampled tensor.
+        """
+        if self.model is None:
+            raise ValueError("Model must be provided for direct sampling")
+
         if sigmas is None:
             sigmas = self.sigmas
 
-        if last_step is not None and last_step < (len(sigmas) - 1):
-            sigmas = sigmas[: last_step + 1]
-            if force_full_denoise:
-                sigmas[-1] = 0
-
-        if start_step is not None:
-            if start_step < (len(sigmas) - 1):
-                sigmas = sigmas[start_step:]
-            else:
-                if latent_image is not None:
-                    return latent_image
-                else:
-                    return torch.zeros_like(noise)
-
-        sampler = sampler_object(self.sampler, pipeline=pipeline)
+        sigmas = self._process_sigmas(sigmas, start_step, last_step, force_full_denoise)
+
+        # Early return if needed
+        if start_step is not None and start_step >= (len(sigmas) - 1):
+            if latent_image is not None:
+                return latent_image
+            else:
+                return torch.zeros_like(noise)
+
+        sampler_obj = sampler_object(self.sampler_name, pipeline=self.pipeline)
 
         return sample(
             self.model,
 
@@ -866,7 +755,7 @@ class KSampler1:
             negative,
             cfg,
             self.device,
-            sampler,
+            sampler_obj,
             sigmas,
             self.model_options,
             latent_image=latent_image,
 
@@ -874,11 +763,117 @@ class KSampler1:
             callback=callback,
             disable_pbar=disable_pbar,
             seed=seed,
-            pipeline=pipeline,
+            pipeline=self.pipeline,
             flux=flux,
         )
 
+    def sample(
+        self,
+        model: torch.nn.Module = None,
+        seed: int = None,
+        steps: int = None,
+        cfg: float = None,
+        sampler_name: str = None,
+        scheduler: str = None,
+        positive: torch.Tensor = None,
+        negative: torch.Tensor = None,
+        latent_image: torch.Tensor = None,
+        denoise: float = None,
+        start_step: int = None,
+        last_step: int = None,
+        force_full_denoise: bool = False,
+        noise_mask: torch.Tensor = None,
+        callback: callable = None,
+        disable_pbar: bool = False,
+        disable_noise: bool = False,
+        pipeline: bool = False,
+        flux: bool = False,
+    ) -> tuple:
+        """Unified sampling interface that works both as direct sampling and through the common_ksampler.
+
+        This method can be used in two ways:
+        1. If model is provided, it will create a temporary sampler and use that
+        2. If model is None, it will use the pre-initialized model and parameters
+
+        Args:
+            model (torch.nn.Module, optional): The model to use for sampling. If None, uses pre-initialized model.
+            seed (int, optional): The seed value.
+            steps (int, optional): The number of steps. If None, uses pre-initialized steps.
+            cfg (float, optional): The CFG value.
+            sampler_name (str, optional): The sampler name. If None, uses pre-initialized sampler.
+            scheduler (str, optional): The scheduler name. If None, uses pre-initialized scheduler.
+            positive (torch.Tensor, optional): The positive tensor.
+            negative (torch.Tensor, optional): The negative tensor.
+            latent_image (torch.Tensor, optional): The latent image tensor.
+            denoise (float, optional): The denoise factor. If None, uses pre-initialized denoise.
+            start_step (int, optional): The start step. Defaults to None.
+            last_step (int, optional): The last step. Defaults to None.
+            force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
+            noise_mask (torch.Tensor, optional): The noise mask tensor. Defaults to None.
+            callback (callable, optional): The callback function. Defaults to None.
+            disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
+            disable_noise (bool, optional): Whether to disable noise. Defaults to False.
+            pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
+            flux (bool, optional): Whether to use flux mode. Defaults to False.
+
+        Returns:
+            tuple: The output tuple containing either (latent_dict,) or the sampled tensor.
+        """
+        # Case 1: Use pre-initialized model for direct sampling
+        if model is None:
+            if latent_image is None:
+                raise ValueError(
+                    "latent_image must be provided when using pre-initialized model"
+                )
+
+            return (
+                self.direct_sample(
+                    None,  # noise will be generated in common_ksampler
+                    positive,
+                    negative,
+                    cfg,
+                    latent_image,
+                    start_step,
+                    last_step,
+                    force_full_denoise,
+                    noise_mask,
+                    None,  # sigmas will use pre-calculated ones
+                    callback,
+                    disable_pbar,
+                    seed,
+                    flux,
+                ),
+            )
+
+        # Case 2: Use common_ksampler approach with provided model
+        else:
+            # For backwards compatibility with KSampler2 usage pattern
+            if isinstance(latent_image, dict):
+                latent = latent_image
+            else:
+                latent = {"samples": latent_image}
+
+            return common_ksampler(
+                model,
+                seed,
+                steps,
+                cfg,
+                sampler_name or self.sampler_name,
+                scheduler or self.scheduler,
+                positive,
+                negative,
+                latent,
+                denoise or self.denoise,
+                disable_noise,
+                start_step,
+                last_step,
+                force_full_denoise,
+                pipeline or self.pipeline,
+                flux,
+            )
+
+
+# Refactor sample1 to use KSampler directly
 def sample1(
     model: torch.nn.Module,
     noise: torch.Tensor,
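Usage sketch for the unified class (all names and argument values are placeholders, and "normal" is just an assumed scheduler name):

from modules.sample import sampling

# KSampler2-style: no-arg construction; the model goes to sample(), which
# routes through common_ksampler and returns a (latent_dict,) tuple.
ks = sampling.KSampler()
(result,) = ks.sample(
    model=model, seed=42, steps=20, cfg=7.0,
    sampler_name="dpmpp_2m_cfgpp", scheduler="normal",
    positive=positive, negative=negative,
    latent_image={"samples": latent}, denoise=1.0,
)

# KSampler1-style: pre-initialize, then sample directly; returns a tensor.
ks = sampling.KSampler(model=model, steps=20, sampler="dpmpp_2m_cfgpp",
                       scheduler="normal", denoise=1.0)
samples = ks.direct_sample(noise, positive, negative, 7.0, latent_image=latent)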
@@ -902,37 +897,37 @@ def sample1(
     pipeline: bool = False,
     flux: bool = False,
 ) -> torch.Tensor:
-    """...
-    """
+    """Sample using the given parameters with the unified KSampler.
+
+    Args:
+        model (torch.nn.Module): The model.
+        noise (torch.Tensor): The noise tensor.
+        steps (int): The number of steps.
+        cfg (float): The CFG value.
+        sampler_name (str): The sampler name.
+        scheduler (str): The scheduler name.
+        positive (torch.Tensor): The positive tensor.
+        negative (torch.Tensor): The negative tensor.
+        latent_image (torch.Tensor): The latent image tensor.
+        denoise (float, optional): The denoise factor. Defaults to 1.0.
+        disable_noise (bool, optional): Whether to disable noise. Defaults to False.
+        start_step (int, optional): The start step. Defaults to None.
+        last_step (int, optional): The last step. Defaults to None.
+        force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
+        noise_mask (torch.Tensor, optional): The noise mask tensor. Defaults to None.
+        sigmas (torch.Tensor, optional): The sigmas tensor. Defaults to None.
+        callback (callable, optional): The callback function. Defaults to None.
+        disable_pbar (bool, optional): Whether to disable the progress bar. Defaults to False.
+        seed (int, optional): The seed value. Defaults to None.
+        pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
+        flux (bool, optional): Whether to use flux mode. Defaults to False.
+
+    Returns:
+        torch.Tensor: The sampled tensor.
     """
-    sampler = KSampler1(
-        model,
+    sampler = KSampler(
+        model=model,
         steps=steps,
-        device=model.load_device,
         sampler=sampler_name,
         scheduler=scheduler,
         denoise=denoise,
 
@@ -940,7 +935,7 @@ def sample1(
         pipeline=pipeline,
     )
 
-    samples = sampler.sample(
+    samples = sampler.direct_sample(
         noise,
         positive,
         negative,
 
@@ -954,147 +949,12 @@ def sample1(
         callback=callback,
         disable_pbar=disable_pbar,
         seed=seed,
-        pipeline=pipeline,
         flux=flux,
     )
     samples = samples.to(Device.intermediate_device())
     return samples
 
 
-def common_ksampler(
-    model: torch.nn.Module,
-    seed: int,
-    steps: int,
-    cfg: float,
-    sampler_name: str,
-    scheduler: str,
-    positive: torch.Tensor,
-    negative: torch.Tensor,
-    latent: dict,
-    denoise: float = 1.0,
-    disable_noise: bool = False,
-    start_step: int = None,
-    last_step: int = None,
-    force_full_denoise: bool = False,
-    pipeline: bool = False,
-    flux: bool = False,
-) -> tuple:
-    """#### Common ksampler function.
-
-    #### Args:
-        - `model` (torch.nn.Module): The model.
-        - `seed` (int): The seed value.
-        - `steps` (int): The number of steps.
-        - `cfg` (float): The CFG value.
-        - `sampler_name` (str): The sampler name.
-        - `scheduler` (str): The scheduler name.
-        - `positive` (torch.Tensor): The positive tensor.
-        - `negative` (torch.Tensor): The negative tensor.
-        - `latent` (dict): The latent dictionary.
-        - `denoise` (float, optional): The denoise factor. Defaults to 1.0.
-        - `disable_noise` (bool, optional): Whether to disable noise. Defaults to False.
-        - `start_step` (int, optional): The start step. Defaults to None.
-        - `last_step` (int, optional): The last step. Defaults to None.
-        - `force_full_denoise` (bool, optional): Whether to force full denoise. Defaults to False.
-        - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
-
-    #### Returns:
-        - `tuple`: The output tuple containing the latent dictionary and samples.
-    """
-    latent_image = latent["samples"]
-    latent_image = Latent.fix_empty_latent_channels(model, latent_image)
-
-    if disable_noise:
-        noise = torch.zeros(
-            latent_image.size(),
-            dtype=latent_image.dtype,
-            layout=latent_image.layout,
-            device="cpu",
-        )
-    else:
-        batch_inds = latent["batch_index"] if "batch_index" in latent else None
-        noise = ksampler_util.prepare_noise(latent_image, seed, batch_inds)
-
-    noise_mask = None
-    if "noise_mask" in latent:
-        noise_mask = latent["noise_mask"]
-    samples = sample1(
-        model,
-        noise,
-        steps,
-        cfg,
-        sampler_name,
-        scheduler,
-        positive,
-        negative,
-        latent_image,
-        denoise=denoise,
-        disable_noise=disable_noise,
-        start_step=start_step,
-        last_step=last_step,
-        force_full_denoise=force_full_denoise,
-        noise_mask=noise_mask,
-        seed=seed,
-        pipeline=pipeline,
-        flux=flux,
-    )
-    out = latent.copy()
-    out["samples"] = samples
-    return (out,)
-
-
-class KSampler2:
-    """#### Class for KSampler2."""
-
-    def sample(
-        self,
-        model: torch.nn.Module,
-        seed: int,
-        steps: int,
-        cfg: float,
-        sampler_name: str,
-        scheduler: str,
-        positive: torch.Tensor,
-        negative: torch.Tensor,
-        latent_image: torch.Tensor,
-        denoise: float = 1.0,
-        pipeline: bool = False,
-        flux: bool = False,
-    ) -> tuple:
-        """#### Sample using the KSampler2.
-
-        #### Args:
-            - `model` (torch.nn.Module): The model.
-            - `seed` (int): The seed value.
-            - `steps` (int): The number of steps.
-            - `cfg` (float): The CFG value.
-            - `sampler_name` (str): The sampler name.
-            - `scheduler` (str): The scheduler name.
-            - `positive` (torch.Tensor): The positive tensor.
-            - `negative` (torch.Tensor): The negative tensor.
-            - `latent_image` (torch.Tensor): The latent image tensor.
-            - `denoise` (float, optional): The denoise factor. Defaults to 1.0.
-            - `pipeline` (bool, optional): Whether to use the pipeline. Defaults to False.
-
-        #### Returns:
-            - `tuple`: The output tuple containing the latent dictionary and samples.
-        """
-        return common_ksampler(
-            model,
-            seed,
-            steps,
-            cfg,
-            sampler_name,
-            scheduler,
-            positive,
-            negative,
-            latent_image,
-            denoise=denoise,
-            pipeline=pipeline,
-            flux=flux,
-        )
-
-
 class ModelType(Enum):
     """#### Enum for Model Types."""
@@ -1187,3 +1047,86 @@ def sample_custom(
     )
     samples = samples.to(Device.intermediate_device())
     return samples
+
+
+def common_ksampler(
+    model: torch.nn.Module,
+    seed: int,
+    steps: int,
+    cfg: float,
+    sampler_name: str,
+    scheduler: str,
+    positive: torch.Tensor,
+    negative: torch.Tensor,
+    latent: dict,
+    denoise: float = 1.0,
+    disable_noise: bool = False,
+    start_step: int = None,
+    last_step: int = None,
+    force_full_denoise: bool = False,
+    pipeline: bool = False,
+    flux: bool = False,
+) -> tuple:
+    """Common ksampler function.
+
+    Args:
+        model (torch.nn.Module): The model.
+        seed (int): The seed value.
+        steps (int): The number of steps.
+        cfg (float): The CFG value.
+        sampler_name (str): The sampler name.
+        scheduler (str): The scheduler name.
+        positive (torch.Tensor): The positive tensor.
+        negative (torch.Tensor): The negative tensor.
+        latent (dict): The latent dictionary.
+        denoise (float, optional): The denoise factor. Defaults to 1.0.
+        disable_noise (bool, optional): Whether to disable noise. Defaults to False.
+        start_step (int, optional): The start step. Defaults to None.
+        last_step (int, optional): The last step. Defaults to None.
+        force_full_denoise (bool, optional): Whether to force full denoise. Defaults to False.
+        pipeline (bool, optional): Whether to use the pipeline. Defaults to False.
+        flux (bool, optional): Whether to use flux mode. Defaults to False.
+
+    Returns:
+        tuple: The output tuple containing the latent dictionary and samples.
+    """
+    latent_image = latent["samples"]
+    latent_image = Latent.fix_empty_latent_channels(model, latent_image)
+
+    if disable_noise:
+        noise = torch.zeros(
+            latent_image.size(),
+            dtype=latent_image.dtype,
+            layout=latent_image.layout,
+            device="cpu",
+        )
+    else:
+        batch_inds = latent["batch_index"] if "batch_index" in latent else None
+        noise = ksampler_util.prepare_noise(latent_image, seed, batch_inds)
+
+    noise_mask = None
+    if "noise_mask" in latent:
+        noise_mask = latent["noise_mask"]
+    samples = sample1(
+        model,
+        noise,
+        steps,
+        cfg,
+        sampler_name,
+        scheduler,
+        positive,
+        negative,
+        latent_image,
+        denoise=denoise,
+        disable_noise=disable_noise,
+        start_step=start_step,
+        last_step=last_step,
+        force_full_denoise=force_full_denoise,
+        noise_mask=noise_mask,
+        seed=seed,
+        pipeline=pipeline,
+        flux=flux,
+    )
+    out = latent.copy()
+    out["samples"] = samples
+    return (out,)
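The `prepare_noise` call is what makes a given seed reproduce the same initial latents. A minimal sketch of the usual approach (the repo's ksampler_util.prepare_noise may handle batch_inds and generators differently):

import torch

def prepare_noise_sketch(latent_image: torch.Tensor, seed: int) -> torch.Tensor:
    # A generator keyed by the seed makes runs reproducible regardless of
    # global RNG state; noise is drawn on CPU, matching the zeros path above.
    generator = torch.manual_seed(seed)
    return torch.randn(
        latent_image.size(),
        dtype=latent_image.dtype,
        layout=latent_image.layout,
        generator=generator,
        device="cpu",
    )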
modules/user/GUI.py
CHANGED
@@ -449,7 +449,9 @@ class App(tk.Tk):
         img_tensor = img_tensor.unsqueeze(0)
         self.interrupt_flag = False
         self.sampler = (
-            "..."
+            "dpmpp_sde_cfgpp"
+            if not self.prioritize_speed_var.get()
+            else "dpmpp_2m_cfgpp"
         )
         with torch.inference_mode():
             (
 
@@ -612,7 +614,7 @@ class App(tk.Tk):
         )
         self.cliptextencode = Clip.CLIPTextEncode()
         self.emptylatentimage = Latent.EmptyLatentImage()
-        self.ksampler_instance = sampling.KSampler2()
+        self.ksampler_instance = sampling.KSampler()
         self.vaedecode = VariationalAE.VAEDecode()
         self.latent_upscale = upscale.LatentUpscale()
         self.upscalemodelloader = USDU_upscaler.UpscaleModelLoader()
 
@@ -637,7 +639,9 @@ class App(tk.Tk):
         self.generation_threads.append(current_thread)
         self.interrupt_flag = False
         self.sampler = (
-            "..."
+            "dpmpp_sde_cfgpp"
+            if not self.prioritize_speed_var.get()
+            else "dpmpp_2m_cfgpp"
         )
         try:
             # Disable generate button during generation
 
@@ -955,7 +959,7 @@ class App(tk.Tk):
         unetloadergguf = Quantizer.UnetLoaderGGUF()
         cliptextencodeflux = Quantizer.CLIPTextEncodeFlux()
         conditioningzeroout = Quantizer.ConditioningZeroOut()
-        ksampler = sampling.KSampler2()
+        ksampler = sampling.KSampler()
         vaedecode = VariationalAE.VAEDecode()
         unetloadergguf_10 = unetloadergguf.load_unet(
             unet_name="flux1-dev-Q8_0.gguf"
modules/user/pipeline.py
CHANGED
@@ -92,7 +92,7 @@ def pipeline(
     hidiffoptimizer = msw_msa_attention.ApplyMSWMSAAttentionSimple()
     cliptextencode = Clip.CLIPTextEncode()
     emptylatentimage = Latent.EmptyLatentImage()
-    ksampler_instance = sampling.KSampler2()
+    ksampler_instance = sampling.KSampler()
     vaedecode = VariationalAE.VAEDecode()
     saveimage = ImageSaver.SaveImage()
     latent_upscale = upscale.LatentUpscale()
 
@@ -187,7 +187,7 @@ def pipeline(
     unetloadergguf = Quantizer.UnetLoaderGGUF()
     cliptextencodeflux = Quantizer.CLIPTextEncodeFlux()
    conditioningzeroout = Quantizer.ConditioningZeroOut()
-    ksampler = sampling.KSampler2()
+    ksampler = sampling.KSampler()
     unetloadergguf_10 = unetloadergguf.load_unet(
         unet_name="flux1-dev-Q8_0.gguf"
     )
 
@@ -283,10 +283,10 @@ def pipeline(
         )
     else:
         applystablefast_158 = loraloader_274
-        fb_cache = fbcache_nodes.ApplyFBCacheOnModel()
-        applystablefast_158 = fb_cache.patch(
-            applystablefast_158, "diffusion_model", 0.120
-        )
+        # fb_cache = fbcache_nodes.ApplyFBCacheOnModel()
+        # applystablefast_158 = fb_cache.patch(
+        #     applystablefast_158, "diffusion_model", 0.120
+        # )
 
     ksampler_239 = ksampler_instance.sample(
         seed=seed,
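The pipeline now drives the unified sampler through its model-provided path; a sketch of the pattern that call site relies on (argument values are placeholders, "normal" is an assumed scheduler name, and the real call passes its own conditioning and latent):

ksampler_instance = sampling.KSampler()
result = ksampler_instance.sample(
    model=applystablefast_158,
    seed=seed,
    steps=20,
    cfg=7.0,
    sampler_name="dpmpp_2m_cfgpp",
    scheduler="normal",
    positive=positive_cond,
    negative=negative_cond,
    latent_image=latent,
    denoise=1.0,
    pipeline=True,
)
latent_out = result[0]["samples"]  # sample() returns a (latent_dict,) tuple here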