diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/common/cache.py b/common/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..89592fe8747a0b68b8553729abe908c6f06a5aa5 --- /dev/null +++ b/common/cache.py @@ -0,0 +1,47 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import Callable + + +class Cache: + """Caching reusable args for faster inference""" + + def __init__(self, disable=False, prefix="", cache=None): + self.cache = cache if cache is not None else {} + self.disable = disable + self.prefix = prefix + + def __call__(self, key: str, fn: Callable): + if self.disable: + return fn() + + key = self.prefix + key + try: + result = self.cache[key] + except KeyError: + result = fn() + self.cache[key] = result + return result + + def namespace(self, namespace: str): + return Cache( + disable=self.disable, + prefix=self.prefix + namespace + ".", + cache=self.cache, + ) + + def get(self, key: str): + key = self.prefix + key + return self.cache[key] diff --git a/common/config.py b/common/config.py new file mode 100644 index 0000000000000000000000000000000000000000..f963e8229b8352ef514422609bcbaf9b8c761b15 --- /dev/null +++ b/common/config.py @@ -0,0 +1,110 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Configuration utility functions +""" + +import importlib +from typing import Any, Callable, List, Union +from omegaconf import DictConfig, ListConfig, OmegaConf + +OmegaConf.register_new_resolver("eval", eval) + + +def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]: + """ + Load a configuration. Will resolve inheritance. 
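+
+    Illustrative usage (the paths and keys below are hypothetical, shown only to
+    demonstrate ``__inherit__`` resolution and dotlist overrides):
+
+        # configs/child.yaml contains:
+        #   __inherit__: configs/parent.yaml
+        #   model:
+        #     hidden_size: 1024
+        config = load_config("configs/child.yaml", argv=["model.hidden_size=2048"])
+        # Parent keys are merged in first; child keys win over parent, argv wins over child.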
+ """ + config = OmegaConf.load(path) + if argv is not None: + config_argv = OmegaConf.from_dotlist(argv) + config = OmegaConf.merge(config, config_argv) + config = resolve_recursive(config, resolve_inheritance) + return config + + +def resolve_recursive( + config: Any, + resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]], +) -> Any: + config = resolver(config) + if isinstance(config, DictConfig): + for k in config.keys(): + v = config.get(k) + if isinstance(v, (DictConfig, ListConfig)): + config[k] = resolve_recursive(v, resolver) + if isinstance(config, ListConfig): + for i in range(len(config)): + v = config.get(i) + if isinstance(v, (DictConfig, ListConfig)): + config[i] = resolve_recursive(v, resolver) + return config + + +def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any: + """ + Recursively resolve inheritance if the config contains: + __inherit__: path/to/parent.yaml or a ListConfig of such paths. + """ + if isinstance(config, DictConfig): + inherit = config.pop("__inherit__", None) + + if inherit: + inherit_list = inherit if isinstance(inherit, ListConfig) else [inherit] + + parent_config = None + for parent_path in inherit_list: + assert isinstance(parent_path, str) + parent_config = ( + load_config(parent_path) + if parent_config is None + else OmegaConf.merge(parent_config, load_config(parent_path)) + ) + + if len(config.keys()) > 0: + config = OmegaConf.merge(parent_config, config) + else: + config = parent_config + return config + + +def import_item(path: str, name: str) -> Any: + """ + Import a python item. Example: import_item("path.to.file", "MyClass") -> MyClass + """ + return getattr(importlib.import_module(path), name) + + +def create_object(config: DictConfig) -> Any: + """ + Create an object from config. + The config is expected to contains the following: + __object__: + path: path.to.module + name: MyClass + args: as_config | as_params (default to as_config) + """ + item = import_item( + path=config.__object__.path, + name=config.__object__.name, + ) + args = config.__object__.get("args", "as_config") + if args == "as_config": + return item(config) + if args == "as_params": + config = OmegaConf.to_object(config) + config.pop("__object__") + return item(**config) + raise NotImplementedError(f"Unknown args type: {args}") \ No newline at end of file diff --git a/common/decorators.py b/common/decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..332a32d7b838cf7f8be902b9ae4895bad5edcd2e --- /dev/null +++ b/common/decorators.py @@ -0,0 +1,147 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Decorators. 
+""" + +import functools +import threading +import time +from typing import Callable +import torch + +from common.distributed import barrier_if_distributed, get_global_rank, get_local_rank +from common.logger import get_logger + +logger = get_logger(__name__) + + +def log_on_entry(func: Callable) -> Callable: + """ + Functions with this decorator will log the function name at entry. + When using multiple decorators, this must be applied innermost to properly capture the name. + """ + + def log_on_entry_wrapper(*args, **kwargs): + logger.info(f"Entering {func.__name__}") + return func(*args, **kwargs) + + return log_on_entry_wrapper + + +def barrier_on_entry(func: Callable) -> Callable: + """ + Functions with this decorator will start executing when all ranks are ready to enter. + """ + + def barrier_on_entry_wrapper(*args, **kwargs): + barrier_if_distributed() + return func(*args, **kwargs) + + return barrier_on_entry_wrapper + + +def _conditional_execute_wrapper_factory(execute: bool, func: Callable) -> Callable: + """ + Helper function for local_rank_zero_only and global_rank_zero_only. + """ + + def conditional_execute_wrapper(*args, **kwargs): + # Only execute if needed. + result = func(*args, **kwargs) if execute else None + # All GPUs must wait. + barrier_if_distributed() + # Return results. + return result + + return conditional_execute_wrapper + + +def _asserted_wrapper_factory(condition: bool, func: Callable, err_msg: str = "") -> Callable: + """ + Helper function for some functions with special constraints, + especially functions called by other global_rank_zero_only / local_rank_zero_only ones, + in case they are wrongly invoked in other scenarios. + """ + + def asserted_execute_wrapper(*args, **kwargs): + assert condition, err_msg + result = func(*args, **kwargs) + return result + + return asserted_execute_wrapper + + +def local_rank_zero_only(func: Callable) -> Callable: + """ + Functions with this decorator will only execute on local rank zero. + """ + return _conditional_execute_wrapper_factory(get_local_rank() == 0, func) + + +def global_rank_zero_only(func: Callable) -> Callable: + """ + Functions with this decorator will only execute on global rank zero. + """ + return _conditional_execute_wrapper_factory(get_global_rank() == 0, func) + + +def assert_only_global_rank_zero(func: Callable) -> Callable: + """ + Functions with this decorator are only accessible to processes with global rank zero. + """ + return _asserted_wrapper_factory( + get_global_rank() == 0, func, err_msg="Not accessible to processes with global_rank != 0" + ) + + +def assert_only_local_rank_zero(func: Callable) -> Callable: + """ + Functions with this decorator are only accessible to processes with local rank zero. + """ + return _asserted_wrapper_factory( + get_local_rank() == 0, func, err_msg="Not accessible to processes with local_rank != 0" + ) + + +def new_thread(func: Callable) -> Callable: + """ + Functions with this decorator will run in a new thread. + The function will return the thread, which can be joined to wait for completion. + """ + + def new_thread_wrapper(*args, **kwargs): + thread = threading.Thread(target=func, args=args, kwargs=kwargs) + thread.start() + return thread + + return new_thread_wrapper + + +def log_runtime(func: Callable) -> Callable: + """ + Functions with this decorator will logging the runtime. 
+ """ + + @functools.wraps(func) + def wrapped(*args, **kwargs): + torch.distributed.barrier() + start = time.perf_counter() + result = func(*args, **kwargs) + torch.distributed.barrier() + logger.info(f"Completed {func.__name__} in {time.perf_counter() - start:.3f} seconds.") + return result + + return wrapped diff --git a/common/diffusion/__init__.py b/common/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..034e36ef7f9eb0b3ae94280165e622a362e9fc1e --- /dev/null +++ b/common/diffusion/__init__.py @@ -0,0 +1,56 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Diffusion package. +""" + +from .config import ( + create_sampler_from_config, + create_sampling_timesteps_from_config, + create_schedule_from_config, +) +from .samplers.base import Sampler +from .samplers.euler import EulerSampler +from .schedules.base import Schedule +from .schedules.lerp import LinearInterpolationSchedule +from .timesteps.base import SamplingTimesteps, Timesteps +from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps +from .types import PredictionType, SamplingDirection +from .utils import classifier_free_guidance, classifier_free_guidance_dispatcher, expand_dims + +__all__ = [ + # Configs + "create_sampler_from_config", + "create_sampling_timesteps_from_config", + "create_schedule_from_config", + # Schedules + "Schedule", + "DiscreteVariancePreservingSchedule", + "LinearInterpolationSchedule", + # Samplers + "Sampler", + "EulerSampler", + # Timesteps + "Timesteps", + "SamplingTimesteps", + # Types + "PredictionType", + "SamplingDirection", + "UniformTrailingSamplingTimesteps", + # Utils + "classifier_free_guidance", + "classifier_free_guidance_dispatcher", + "expand_dims", +] diff --git a/common/diffusion/config.py b/common/diffusion/config.py new file mode 100644 index 0000000000000000000000000000000000000000..f1d0468d88b5dd5f0d787c75ed3df06742d0a483 --- /dev/null +++ b/common/diffusion/config.py @@ -0,0 +1,74 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Utility functions for creating schedules and samplers from config. 
+""" + +import torch +from omegaconf import DictConfig + +from .samplers.base import Sampler +from .samplers.euler import EulerSampler +from .schedules.base import Schedule +from .schedules.lerp import LinearInterpolationSchedule +from .timesteps.base import SamplingTimesteps +from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps + + +def create_schedule_from_config( + config: DictConfig, + device: torch.device, + dtype: torch.dtype = torch.float32, +) -> Schedule: + """ + Create a schedule from configuration. + """ + if config.type == "lerp": + return LinearInterpolationSchedule(T=config.get("T", 1.0)) + + raise NotImplementedError + + +def create_sampler_from_config( + config: DictConfig, + schedule: Schedule, + timesteps: SamplingTimesteps, +) -> Sampler: + """ + Create a sampler from configuration. + """ + if config.type == "euler": + return EulerSampler( + schedule=schedule, + timesteps=timesteps, + prediction_type=config.prediction_type, + ) + raise NotImplementedError + + +def create_sampling_timesteps_from_config( + config: DictConfig, + schedule: Schedule, + device: torch.device, + dtype: torch.dtype = torch.float32, +) -> SamplingTimesteps: + if config.type == "uniform_trailing": + return UniformTrailingSamplingTimesteps( + T=schedule.T, + steps=config.steps, + shift=config.get("shift", 1.0), + device=device, + ) + raise NotImplementedError \ No newline at end of file diff --git a/common/diffusion/samplers/base.py b/common/diffusion/samplers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..8e65f19896b6d5844e769762e76d699b96abc733 --- /dev/null +++ b/common/diffusion/samplers/base.py @@ -0,0 +1,108 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Sampler base class. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable +import torch +from tqdm import tqdm + +from ..schedules.base import Schedule +from ..timesteps.base import SamplingTimesteps +from ..types import PredictionType, SamplingDirection +from ..utils import assert_schedule_timesteps_compatible + + +@dataclass +class SamplerModelArgs: + x_t: torch.Tensor + t: torch.Tensor + i: int + + +class Sampler(ABC): + """ + Samplers are ODE/SDE solvers. + """ + + def __init__( + self, + schedule: Schedule, + timesteps: SamplingTimesteps, + prediction_type: PredictionType, + return_endpoint: bool = True, + ): + assert_schedule_timesteps_compatible( + schedule=schedule, + timesteps=timesteps, + ) + self.schedule = schedule + self.timesteps = timesteps + self.prediction_type = prediction_type + self.return_endpoint = return_endpoint + + @abstractmethod + def sample( + self, + x: torch.Tensor, + f: Callable[[SamplerModelArgs], torch.Tensor], + ) -> torch.Tensor: + """ + Generate a new sample given the the intial sample x and score function f. 
+ """ + + def get_next_timestep( + self, + t: torch.Tensor, + ) -> torch.Tensor: + """ + Get the next sample timestep. + Support multiple different timesteps t in a batch. + If no more steps, return out of bound value -1 or T+1. + """ + T = self.timesteps.T + steps = len(self.timesteps) + curr_idx = self.timesteps.index(t) + next_idx = curr_idx + 1 + bound = -1 if self.timesteps.direction == SamplingDirection.backward else T + 1 + + s = self.timesteps[next_idx.clamp_max(steps - 1)] + s = s.where(next_idx < steps, bound) + return s + + def get_endpoint( + self, + pred: torch.Tensor, + x_t: torch.Tensor, + t: torch.Tensor, + ) -> torch.Tensor: + """ + Get to the endpoint of the probability flow. + """ + x_0, x_T = self.schedule.convert_from_pred(pred, self.prediction_type, x_t, t) + return x_0 if self.timesteps.direction == SamplingDirection.backward else x_T + + def get_progress_bar(self): + """ + Get progress bar for sampling. + """ + return tqdm( + iterable=range(len(self.timesteps) - (0 if self.return_endpoint else 1)), + dynamic_ncols=True, + desc=self.__class__.__name__, + ) diff --git a/common/diffusion/samplers/euler.py b/common/diffusion/samplers/euler.py new file mode 100644 index 0000000000000000000000000000000000000000..5994979a43658b7ebb75316cefea737d1c54681b --- /dev/null +++ b/common/diffusion/samplers/euler.py @@ -0,0 +1,89 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + + +""" +Euler ODE solver. +""" + +from typing import Callable +import torch +from einops import rearrange +from torch.nn import functional as F + +from models.dit_v2 import na + +from ..types import PredictionType +from ..utils import expand_dims +from .base import Sampler, SamplerModelArgs + + +class EulerSampler(Sampler): + """ + The Euler method is the simplest ODE solver. + + """ + + def sample( + self, + x: torch.Tensor, + f: Callable[[SamplerModelArgs], torch.Tensor], + ) -> torch.Tensor: + timesteps = self.timesteps.timesteps + progress = self.get_progress_bar() + i = 0 + for t, s in zip(timesteps[:-1], timesteps[1:]): + pred = f(SamplerModelArgs(x, t, i)) + x = self.step_to(pred, x, t, s) + i += 1 + progress.update() + + if self.return_endpoint: + t = timesteps[-1] + pred = f(SamplerModelArgs(x, t, i)) + x = self.get_endpoint(pred, x, t) + progress.update() + return x + + def step( + self, + pred: torch.Tensor, + x_t: torch.Tensor, + t: torch.Tensor, + ) -> torch.Tensor: + """ + Step to the next timestep. + """ + return self.step_to(pred, x_t, t, self.get_next_timestep(t)) + + def step_to( + self, + pred: torch.Tensor, + x_t: torch.Tensor, + t: torch.Tensor, + s: torch.Tensor, + ) -> torch.Tensor: + """ + Steps from x_t at timestep t to x_s at timestep s. Returns x_s. + """ + t = expand_dims(t, x_t.ndim) + s = expand_dims(s, x_t.ndim) + T = self.schedule.T + # Step from x_t to x_s. 
+ pred_x_0, pred_x_T = self.schedule.convert_from_pred(pred, self.prediction_type, x_t, t) + pred_x_s = self.schedule.forward(pred_x_0, pred_x_T, s.clamp(0, T)) + # Clamp x_s to x_0 and x_T if s is out of bound. + pred_x_s = pred_x_s.where(s >= 0, pred_x_0) + pred_x_s = pred_x_s.where(s <= T, pred_x_T) + return pred_x_s diff --git a/common/diffusion/schedules/base.py b/common/diffusion/schedules/base.py new file mode 100644 index 0000000000000000000000000000000000000000..bcf6c6b6460977c6e2687e225c5c913a928bf812 --- /dev/null +++ b/common/diffusion/schedules/base.py @@ -0,0 +1,131 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Schedule base class. +""" + +from abc import ABC, abstractmethod, abstractproperty +from typing import Tuple, Union +import torch + +from ..types import PredictionType +from ..utils import expand_dims + + +class Schedule(ABC): + """ + Diffusion schedules are uniquely defined by T, A, B: + + x_t = A(t) * x_0 + B(t) * x_T, where t in [0, T] + + Schedules can be continuous or discrete. + """ + + @abstractproperty + def T(self) -> Union[int, float]: + """ + Maximum timestep inclusive. + Schedule is continuous if float, discrete if int. + """ + + @abstractmethod + def A(self, t: torch.Tensor) -> torch.Tensor: + """ + Interpolation coefficient A. + Returns tensor with the same shape as t. + """ + + @abstractmethod + def B(self, t: torch.Tensor) -> torch.Tensor: + """ + Interpolation coefficient B. + Returns tensor with the same shape as t. + """ + + # ---------------------------------------------------- + + def snr(self, t: torch.Tensor) -> torch.Tensor: + """ + Signal to noise ratio. + Returns tensor with the same shape as t. + """ + return (self.A(t) ** 2) / (self.B(t) ** 2) + + def isnr(self, snr: torch.Tensor) -> torch.Tensor: + """ + Inverse signal to noise ratio. + Returns tensor with the same shape as snr. + Subclass may implement. + """ + raise NotImplementedError + + # ---------------------------------------------------- + + def is_continuous(self) -> bool: + """ + Whether the schedule is continuous. + """ + return isinstance(self.T, float) + + def forward(self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """ + Diffusion forward function. + """ + t = expand_dims(t, x_0.ndim) + return self.A(t) * x_0 + self.B(t) * x_T + + def convert_from_pred( + self, pred: torch.Tensor, pred_type: PredictionType, x_t: torch.Tensor, t: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Convert from prediction. Return predicted x_0 and x_T. 
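+
+        Each case inverts x_t = A(t) * x_0 + B(t) * x_T. For example, for pred_type x_T:
+        x_0 = (x_t - B(t) * x_T) / A(t); for v_lerp (pred = x_T - x_0):
+        x_0 = (x_t - B(t) * pred) / (A(t) + B(t)). The v_cos branch additionally relies on
+        a cosine schedule, i.e. A(t)^2 + B(t)^2 = 1.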
+ """ + t = expand_dims(t, x_t.ndim) + A_t = self.A(t) + B_t = self.B(t) + + if pred_type == PredictionType.x_T: + pred_x_T = pred + pred_x_0 = (x_t - B_t * pred_x_T) / A_t + elif pred_type == PredictionType.x_0: + pred_x_0 = pred + pred_x_T = (x_t - A_t * pred_x_0) / B_t + elif pred_type == PredictionType.v_cos: + pred_x_0 = A_t * x_t - B_t * pred + pred_x_T = A_t * pred + B_t * x_t + elif pred_type == PredictionType.v_lerp: + pred_x_0 = (x_t - B_t * pred) / (A_t + B_t) + pred_x_T = (x_t + A_t * pred) / (A_t + B_t) + else: + raise NotImplementedError + + return pred_x_0, pred_x_T + + def convert_to_pred( + self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor, pred_type: PredictionType + ) -> torch.FloatTensor: + """ + Convert to prediction target given x_0 and x_T. + """ + if pred_type == PredictionType.x_T: + return x_T + if pred_type == PredictionType.x_0: + return x_0 + if pred_type == PredictionType.v_cos: + t = expand_dims(t, x_0.ndim) + return self.A(t) * x_T - self.B(t) * x_0 + if pred_type == PredictionType.v_lerp: + return x_T - x_0 + raise NotImplementedError diff --git a/common/diffusion/schedules/lerp.py b/common/diffusion/schedules/lerp.py new file mode 100644 index 0000000000000000000000000000000000000000..56b42bc17538b3217b2209234fc723ac3f58a746 --- /dev/null +++ b/common/diffusion/schedules/lerp.py @@ -0,0 +1,55 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Linear interpolation schedule (lerp). +""" + +from typing import Union +import torch + +from .base import Schedule + + +class LinearInterpolationSchedule(Schedule): + """ + Linear interpolation schedule (lerp) is proposed by flow matching and rectified flow. + It leads to straighter probability flow theoretically. It is also used by Stable Diffusion 3. + + + + x_t = (1 - t) * x_0 + t * x_T + + Can be either continuous or discrete. + """ + + def __init__(self, T: Union[int, float] = 1.0): + self._T = T + + @property + def T(self) -> Union[int, float]: + return self._T + + def A(self, t: torch.Tensor) -> torch.Tensor: + return 1 - (t / self.T) + + def B(self, t: torch.Tensor) -> torch.Tensor: + return t / self.T + + # ---------------------------------------------------- + + def isnr(self, snr: torch.Tensor) -> torch.Tensor: + t = self.T / (1 + snr**0.5) + t = t if self.is_continuous() else t.round().int() + return t diff --git a/common/diffusion/timesteps/base.py b/common/diffusion/timesteps/base.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a598103547694d5ef4dc5db0be1e5be2deb60c --- /dev/null +++ b/common/diffusion/timesteps/base.py @@ -0,0 +1,72 @@ +from abc import ABC, abstractmethod +from typing import Sequence, Union +import torch + +from ..types import SamplingDirection + + +class Timesteps(ABC): + """ + Timesteps base class. 
+ """ + + def __init__(self, T: Union[int, float]): + assert T > 0 + self._T = T + + @property + def T(self) -> Union[int, float]: + """ + Maximum timestep inclusive. + int if discrete, float if continuous. + """ + return self._T + + def is_continuous(self) -> bool: + """ + Whether the schedule is continuous. + """ + return isinstance(self.T, float) + + +class SamplingTimesteps(Timesteps): + """ + Sampling timesteps. + It defines the discretization of sampling steps. + """ + + def __init__( + self, + T: Union[int, float], + timesteps: torch.Tensor, + direction: SamplingDirection, + ): + assert timesteps.ndim == 1 + super().__init__(T) + self.timesteps = timesteps + self.direction = direction + + def __len__(self) -> int: + """ + Number of sampling steps. + """ + return len(self.timesteps) + + def __getitem__(self, idx: Union[int, torch.IntTensor]) -> torch.Tensor: + """ + The timestep at the sampling step. + Returns a scalar tensor if idx is int, + or tensor of the same size if idx is a tensor. + """ + return self.timesteps[idx] + + def index(self, t: torch.Tensor) -> torch.Tensor: + """ + Find index by t. + Return index of the same shape as t. + Index is -1 if t not found in timesteps. + """ + i, j = t.reshape(-1, 1).eq(self.timesteps).nonzero(as_tuple=True) + idx = torch.full_like(t, fill_value=-1, dtype=torch.int) + idx.view(-1)[i] = j.int() + return idx diff --git a/common/diffusion/timesteps/sampling/trailing.py b/common/diffusion/timesteps/sampling/trailing.py new file mode 100644 index 0000000000000000000000000000000000000000..248d986aedaaff8f417c32a42e9d9e3a61012f58 --- /dev/null +++ b/common/diffusion/timesteps/sampling/trailing.py @@ -0,0 +1,49 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +import torch + +from ...types import SamplingDirection +from ..base import SamplingTimesteps + + +class UniformTrailingSamplingTimesteps(SamplingTimesteps): + """ + Uniform trailing sampling timesteps. + Defined in (https://arxiv.org/abs/2305.08891) + + Shift is proposed in SD3 for RF schedule. + Defined in (https://arxiv.org/pdf/2403.03206) eq.23 + """ + + def __init__( + self, + T: int, + steps: int, + shift: float = 1.0, + device: torch.device = "cpu", + ): + # Create trailing timesteps. + timesteps = torch.arange(1.0, 0.0, -1.0 / steps, device=device) + + # Shift timesteps. + timesteps = shift * timesteps / (1 + (shift - 1) * timesteps) + + # Scale to T range. + if isinstance(T, float): + timesteps = timesteps * T + else: + timesteps = timesteps.mul(T + 1).sub(1).round().int() + + super().__init__(T=T, timesteps=timesteps, direction=SamplingDirection.backward) diff --git a/common/diffusion/types.py b/common/diffusion/types.py new file mode 100644 index 0000000000000000000000000000000000000000..076295f2be24dadc79da20a5f335b391eb9543bb --- /dev/null +++ b/common/diffusion/types.py @@ -0,0 +1,59 @@ +# // Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Type definitions. +""" + +from enum import Enum + + +class PredictionType(str, Enum): + """ + x_0: + Predict data sample. + x_T: + Predict noise sample. + Proposed by DDPM (https://arxiv.org/abs/2006.11239) + Proved problematic by zsnr paper (https://arxiv.org/abs/2305.08891) + v_cos: + Predict velocity dx/dt based on the cosine schedule (A_t * x_T - B_t * x_0). + Proposed by progressive distillation (https://arxiv.org/abs/2202.00512) + v_lerp: + Predict velocity dx/dt based on the lerp schedule (x_T - x_0). + Proposed by rectified flow (https://arxiv.org/abs/2209.03003) + """ + + x_0 = "x_0" + x_T = "x_T" + v_cos = "v_cos" + v_lerp = "v_lerp" + + +class SamplingDirection(str, Enum): + """ + backward: Sample from x_T to x_0 for data generation. + forward: Sample from x_0 to x_T for noise inversion. + """ + + backward = "backward" + forward = "forward" + + @staticmethod + def reverse(direction): + if direction == SamplingDirection.backward: + return SamplingDirection.forward + if direction == SamplingDirection.forward: + return SamplingDirection.backward + raise NotImplementedError diff --git a/common/diffusion/utils.py b/common/diffusion/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..69d4aec34f59b293e2354744a4329008063a30e3 --- /dev/null +++ b/common/diffusion/utils.py @@ -0,0 +1,84 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Utility functions. +""" + +from typing import Callable +import torch + + +def expand_dims(tensor: torch.Tensor, ndim: int): + """ + Expand tensor to target ndim. New dims are added to the right. + For example, if the tensor shape was (8,), target ndim is 4, return (8, 1, 1, 1). + """ + shape = tensor.shape + (1,) * (ndim - tensor.ndim) + return tensor.reshape(shape) + + +def assert_schedule_timesteps_compatible(schedule, timesteps): + """ + Check if schedule and timesteps are compatible. + """ + if schedule.T != timesteps.T: + raise ValueError("Schedule and timesteps must have the same T.") + if schedule.is_continuous() != timesteps.is_continuous(): + raise ValueError("Schedule and timesteps must have the same continuity.") + + +def classifier_free_guidance( + pos: torch.Tensor, + neg: torch.Tensor, + scale: float, + rescale: float = 0.0, +): + """ + Apply classifier-free guidance. 
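+
+    Computes cfg = neg + scale * (pos - neg). When ``rescale`` > 0, the result is pushed
+    towards the per-sample std of ``pos``: cfg *= rescale * std(pos) / std(cfg) + (1 - rescale).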
+ """ + # Classifier-free guidance (https://arxiv.org/abs/2207.12598) + cfg = neg + scale * (pos - neg) + + # Classifier-free guidance rescale (https://arxiv.org/pdf/2305.08891.pdf) + if rescale != 0.0: + pos_std = pos.std(dim=list(range(1, pos.ndim)), keepdim=True) + cfg_std = cfg.std(dim=list(range(1, cfg.ndim)), keepdim=True) + factor = pos_std / cfg_std + factor = rescale * factor + (1 - rescale) + cfg *= factor + + return cfg + + +def classifier_free_guidance_dispatcher( + pos: Callable, + neg: Callable, + scale: float, + rescale: float = 0.0, +): + """ + Optionally execute models depending on classifer-free guidance scale. + """ + # If scale is 1, no need to execute neg model. + if scale == 1.0: + return pos() + + # Otherwise, execute both pos nad neg models and apply cfg. + return classifier_free_guidance( + pos=pos(), + neg=neg(), + scale=scale, + rescale=rescale, + ) diff --git a/common/distributed/__init__.py b/common/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b4f873ae3e5524c88942bb27ec98ac98c3b5b5 --- /dev/null +++ b/common/distributed/__init__.py @@ -0,0 +1,37 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Distributed package. +""" + +from .basic import ( + barrier_if_distributed, + convert_to_ddp, + get_device, + get_global_rank, + get_local_rank, + get_world_size, + init_torch, +) + +__all__ = [ + "barrier_if_distributed", + "convert_to_ddp", + "get_device", + "get_global_rank", + "get_local_rank", + "get_world_size", + "init_torch", +] diff --git a/common/distributed/advanced.py b/common/distributed/advanced.py new file mode 100644 index 0000000000000000000000000000000000000000..f55fe20ab45494c96124b072d628273d49def1fa --- /dev/null +++ b/common/distributed/advanced.py @@ -0,0 +1,208 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Advanced distributed functions for sequence parallel. 
+""" + +from typing import Optional, List +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.fsdp import ShardingStrategy + +from .basic import get_global_rank, get_world_size + + +_DATA_PARALLEL_GROUP = None +_SEQUENCE_PARALLEL_GROUP = None +_SEQUENCE_PARALLEL_CPU_GROUP = None +_MODEL_SHARD_CPU_INTER_GROUP = None +_MODEL_SHARD_CPU_INTRA_GROUP = None +_MODEL_SHARD_INTER_GROUP = None +_MODEL_SHARD_INTRA_GROUP = None +_SEQUENCE_PARALLEL_GLOBAL_RANKS = None + + +def get_data_parallel_group() -> Optional[dist.ProcessGroup]: + """ + Get data parallel process group. + """ + return _DATA_PARALLEL_GROUP + + +def get_sequence_parallel_group() -> Optional[dist.ProcessGroup]: + """ + Get sequence parallel process group. + """ + return _SEQUENCE_PARALLEL_GROUP + + +def get_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]: + """ + Get sequence parallel CPU process group. + """ + return _SEQUENCE_PARALLEL_CPU_GROUP + + +def get_data_parallel_rank() -> int: + """ + Get data parallel rank. + """ + group = get_data_parallel_group() + return dist.get_rank(group) if group else get_global_rank() + + +def get_data_parallel_world_size() -> int: + """ + Get data parallel world size. + """ + group = get_data_parallel_group() + return dist.get_world_size(group) if group else get_world_size() + + +def get_sequence_parallel_rank() -> int: + """ + Get sequence parallel rank. + """ + group = get_sequence_parallel_group() + return dist.get_rank(group) if group else 0 + + +def get_sequence_parallel_world_size() -> int: + """ + Get sequence parallel world size. + """ + group = get_sequence_parallel_group() + return dist.get_world_size(group) if group else 1 + + +def get_model_shard_cpu_intra_group() -> Optional[dist.ProcessGroup]: + """ + Get the CPU intra process group of model sharding. + """ + return _MODEL_SHARD_CPU_INTRA_GROUP + + +def get_model_shard_cpu_inter_group() -> Optional[dist.ProcessGroup]: + """ + Get the CPU inter process group of model sharding. + """ + return _MODEL_SHARD_CPU_INTER_GROUP + + +def get_model_shard_intra_group() -> Optional[dist.ProcessGroup]: + """ + Get the GPU intra process group of model sharding. + """ + return _MODEL_SHARD_INTRA_GROUP + + +def get_model_shard_inter_group() -> Optional[dist.ProcessGroup]: + """ + Get the GPU inter process group of model sharding. + """ + return _MODEL_SHARD_INTER_GROUP + + +def init_sequence_parallel(sequence_parallel_size: int): + """ + Initialize sequence parallel. + """ + global _DATA_PARALLEL_GROUP + global _SEQUENCE_PARALLEL_GROUP + global _SEQUENCE_PARALLEL_CPU_GROUP + global _SEQUENCE_PARALLEL_GLOBAL_RANKS + assert dist.is_initialized() + world_size = dist.get_world_size() + rank = dist.get_rank() + data_parallel_size = world_size // sequence_parallel_size + for i in range(data_parallel_size): + start_rank = i * sequence_parallel_size + end_rank = (i + 1) * sequence_parallel_size + ranks = range(start_rank, end_rank) + group = dist.new_group(ranks) + cpu_group = dist.new_group(ranks, backend="gloo") + if rank in ranks: + _SEQUENCE_PARALLEL_GROUP = group + _SEQUENCE_PARALLEL_CPU_GROUP = cpu_group + _SEQUENCE_PARALLEL_GLOBAL_RANKS = list(ranks) + + +def init_model_shard_group( + *, + sharding_strategy: ShardingStrategy, + device_mesh: Optional[DeviceMesh] = None, +): + """ + Initialize process group of model sharding. 
+ """ + global _MODEL_SHARD_INTER_GROUP + global _MODEL_SHARD_INTRA_GROUP + global _MODEL_SHARD_CPU_INTER_GROUP + global _MODEL_SHARD_CPU_INTRA_GROUP + assert dist.is_initialized() + world_size = dist.get_world_size() + if device_mesh is not None: + num_shards_per_group = device_mesh.shape[1] + elif sharding_strategy == ShardingStrategy.NO_SHARD: + num_shards_per_group = 1 + elif sharding_strategy in [ + ShardingStrategy.HYBRID_SHARD, + ShardingStrategy._HYBRID_SHARD_ZERO2, + ]: + num_shards_per_group = torch.cuda.device_count() + else: + num_shards_per_group = world_size + num_groups = world_size // num_shards_per_group + device_mesh = (num_groups, num_shards_per_group) + + gpu_mesh_2d = init_device_mesh("cuda", device_mesh, mesh_dim_names=("inter", "intra")) + cpu_mesh_2d = init_device_mesh("cpu", device_mesh, mesh_dim_names=("inter", "intra")) + + _MODEL_SHARD_INTER_GROUP = gpu_mesh_2d.get_group("inter") + _MODEL_SHARD_INTRA_GROUP = gpu_mesh_2d.get_group("intra") + _MODEL_SHARD_CPU_INTER_GROUP = cpu_mesh_2d.get_group("inter") + _MODEL_SHARD_CPU_INTRA_GROUP = cpu_mesh_2d.get_group("intra") + +def get_sequence_parallel_global_ranks() -> List[int]: + """ + Get all global ranks of the sequence parallel process group + that the caller rank belongs to. + """ + if _SEQUENCE_PARALLEL_GLOBAL_RANKS is None: + return [dist.get_rank()] + return _SEQUENCE_PARALLEL_GLOBAL_RANKS + + +def get_next_sequence_parallel_rank() -> int: + """ + Get the next global rank of the sequence parallel process group + that the caller rank belongs to. + """ + sp_global_ranks = get_sequence_parallel_global_ranks() + sp_rank = get_sequence_parallel_rank() + sp_size = get_sequence_parallel_world_size() + return sp_global_ranks[(sp_rank + 1) % sp_size] + + +def get_prev_sequence_parallel_rank() -> int: + """ + Get the previous global rank of the sequence parallel process group + that the caller rank belongs to. + """ + sp_global_ranks = get_sequence_parallel_global_ranks() + sp_rank = get_sequence_parallel_rank() + sp_size = get_sequence_parallel_world_size() + return sp_global_ranks[(sp_rank + sp_size - 1) % sp_size] \ No newline at end of file diff --git a/common/distributed/basic.py b/common/distributed/basic.py new file mode 100644 index 0000000000000000000000000000000000000000..f829aec01eba2cc44d7274b6a0430155c6d42af6 --- /dev/null +++ b/common/distributed/basic.py @@ -0,0 +1,84 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Distributed basic functions. +""" + +import os +from datetime import timedelta +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel + + +def get_global_rank() -> int: + """ + Get the global rank, the global index of the GPU. + """ + return int(os.environ.get("RANK", "0")) + + +def get_local_rank() -> int: + """ + Get the local rank, the local index of the GPU. 
+ """ + return int(os.environ.get("LOCAL_RANK", "0")) + + +def get_world_size() -> int: + """ + Get the world size, the total amount of GPUs. + """ + return int(os.environ.get("WORLD_SIZE", "1")) + + +def get_device() -> torch.device: + """ + Get current rank device. + """ + return torch.device("cuda", get_local_rank()) + + +def barrier_if_distributed(*args, **kwargs): + """ + Synchronizes all processes if under distributed context. + """ + if dist.is_initialized(): + return dist.barrier(*args, **kwargs) + + +def init_torch(cudnn_benchmark=True, timeout=timedelta(seconds=600)): + """ + Common PyTorch initialization configuration. + """ + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + torch.backends.cudnn.benchmark = cudnn_benchmark + torch.cuda.set_device(get_local_rank()) + dist.init_process_group( + backend="nccl", + rank=get_global_rank(), + world_size=get_world_size(), + timeout=timeout, + ) + + +def convert_to_ddp(module: torch.nn.Module, **kwargs) -> DistributedDataParallel: + return DistributedDataParallel( + module=module, + device_ids=[get_local_rank()], + output_device=get_local_rank(), + **kwargs, + ) diff --git a/common/distributed/meta_init_utils.py b/common/distributed/meta_init_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..794cd0b8162de596064e494c0b8140a04b9c36a0 --- /dev/null +++ b/common/distributed/meta_init_utils.py @@ -0,0 +1,41 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +import torch +from rotary_embedding_torch import RotaryEmbedding +from torch import nn +from torch.distributed.fsdp._common_utils import _is_fsdp_flattened + +__all__ = ["meta_non_persistent_buffer_init_fn"] + + +def meta_non_persistent_buffer_init_fn(module: nn.Module) -> nn.Module: + """ + Used for materializing `non-persistent tensor buffers` while model resuming. + + Since non-persistent tensor buffers are not saved in state_dict, + when initializing model with meta device, user should materialize those buffers manually. + + Currently, only `rope.dummy` is this special case. + """ + with torch.no_grad(): + for submodule in module.modules(): + if not isinstance(submodule, RotaryEmbedding): + continue + for buffer_name, buffer in submodule.named_buffers(recurse=False): + if buffer.is_meta and "dummy" in buffer_name: + materialized_buffer = torch.zeros_like(buffer, device="cpu") + setattr(submodule, buffer_name, materialized_buffer) + assert not any(b.is_meta for n, b in module.named_buffers()) + return module diff --git a/common/distributed/ops.py b/common/distributed/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..9b2ae02a6f77de3a8a31d217e0e1f2a6b359c3be --- /dev/null +++ b/common/distributed/ops.py @@ -0,0 +1,494 @@ +# // Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Distributed ops for supporting sequence parallel. +""" + +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import torch +import torch.distributed as dist +from torch import Tensor + +from common.cache import Cache +from common.distributed.advanced import ( + get_sequence_parallel_group, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, +) + +from .basic import get_device + +_SEQ_DATA_BUF = defaultdict(lambda: [None, None, None]) +_SEQ_DATA_META_SHAPES = defaultdict() +_SEQ_DATA_META_DTYPES = defaultdict() +_SEQ_DATA_ASYNC_COMMS = defaultdict(list) +_SYNC_BUFFER = defaultdict(dict) + + +def single_all_to_all( + local_input: Tensor, + scatter_dim: int, + gather_dim: int, + group: dist.ProcessGroup, + async_op: bool = False, +): + """ + A function to do all-to-all on a tensor + """ + seq_world_size = dist.get_world_size(group) + prev_scatter_dim = scatter_dim + if scatter_dim != 0: + local_input = local_input.transpose(0, scatter_dim) + if gather_dim == 0: + gather_dim = scatter_dim + scatter_dim = 0 + + inp_shape = list(local_input.shape) + inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size + input_t = local_input.reshape( + [seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :] + ).contiguous() + output = torch.empty_like(input_t) + comm = dist.all_to_all_single(output, input_t, group=group, async_op=async_op) + if async_op: + # let user's code transpose & reshape + return output, comm, prev_scatter_dim + + # first dim is seq_world_size, so we can split it directly + output = torch.cat(output.split(1), dim=gather_dim + 1).squeeze(0) + if prev_scatter_dim: + output = output.transpose(0, prev_scatter_dim).contiguous() + return output + + +def _all_to_all( + local_input: Tensor, + scatter_dim: int, + gather_dim: int, + group: dist.ProcessGroup, +): + seq_world_size = dist.get_world_size(group) + input_list = [ + t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim) + ] + output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_dim).contiguous() + + +class SeqAllToAll(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + local_input: Tensor, + scatter_dim: int, + gather_dim: int, + async_op: bool, + ) -> Tensor: + ctx.group = group + ctx.scatter_dim = scatter_dim + ctx.gather_dim = gather_dim + ctx.async_op = async_op + if async_op: + output, comm, prev_scatter_dim = single_all_to_all( + local_input, scatter_dim, gather_dim, group, async_op=async_op + ) + ctx.prev_scatter_dim = prev_scatter_dim + return output, comm + + return _all_to_all(local_input, scatter_dim, gather_dim, group) + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: + if 
ctx.async_op: + input_t = torch.cat(grad_output[0].split(1), dim=ctx.gather_dim + 1).squeeze(0) + if ctx.prev_scatter_dim: + input_t = input_t.transpose(0, ctx.prev_scatter_dim) + else: + input_t = grad_output[0] + return ( + None, + _all_to_all(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group), + None, + None, + None, + ) + + +class Slice(torch.autograd.Function): + @staticmethod + def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int) -> Tensor: + ctx.group = group + ctx.rank = dist.get_rank(group) + seq_world_size = dist.get_world_size(group) + ctx.seq_world_size = seq_world_size + ctx.dim = dim + dim_size = local_input.shape[dim] + return local_input.split(dim_size // seq_world_size, dim=dim)[ctx.rank].contiguous() + + @staticmethod + def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None]: + dim_size = list(grad_output.size()) + split_size = dim_size[0] + dim_size[0] = dim_size[0] * ctx.seq_world_size + output = torch.empty(dim_size, dtype=grad_output.dtype, device=torch.cuda.current_device()) + dist._all_gather_base(output, grad_output, group=ctx.group) + return (None, torch.cat(output.split(split_size), dim=ctx.dim), None) + + +class Gather(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + local_input: Tensor, + dim: int, + grad_scale: Optional[bool] = False, + ) -> Tensor: + ctx.group = group + ctx.rank = dist.get_rank(group) + ctx.dim = dim + ctx.grad_scale = grad_scale + seq_world_size = dist.get_world_size(group) + ctx.seq_world_size = seq_world_size + dim_size = list(local_input.size()) + split_size = dim_size[0] + ctx.part_size = dim_size[dim] + dim_size[0] = dim_size[0] * seq_world_size + output = torch.empty(dim_size, dtype=local_input.dtype, device=torch.cuda.current_device()) + dist._all_gather_base(output, local_input.contiguous(), group=ctx.group) + return torch.cat(output.split(split_size), dim=dim) + + @staticmethod + def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]: + if ctx.grad_scale: + grad_output = grad_output * ctx.seq_world_size + return ( + None, + grad_output.split(ctx.part_size, dim=ctx.dim)[ctx.rank].contiguous(), + None, + None, + ) + + +def gather_seq_scatter_heads_qkv( + qkv_tensor: Tensor, + *, + seq_dim: int, + qkv_shape: Optional[Tensor] = None, + cache: Cache = Cache(disable=True), + restore_shape: bool = True, +): + """ + A func to sync splited qkv tensor + qkv_tensor: the tensor we want to do alltoall with. The last dim must + be the projection_idx, which we will split into 3 part. 
After + spliting, the gather idx will be projecttion_idx + 1 + seq_dim: gather_dim for all2all comm + restore_shape: if True, output will has the same shape length as input + """ + group = get_sequence_parallel_group() + if not group: + return qkv_tensor + world = get_sequence_parallel_world_size() + orig_shape = qkv_tensor.shape + scatter_dim = qkv_tensor.dim() + bef_all2all_shape = list(orig_shape) + qkv_proj_dim = bef_all2all_shape[-1] + bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3] + qkv_tensor = qkv_tensor.view(bef_all2all_shape) + qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, False) + if restore_shape: + out_shape = list(orig_shape) + out_shape[seq_dim] *= world + out_shape[-1] = qkv_proj_dim // world + qkv_tensor = qkv_tensor.view(out_shape) + + # remove padding + if qkv_shape is not None: + unpad_dim_size = cache( + "unpad_dim_size", lambda: torch.sum(torch.prod(qkv_shape, dim=-1)).item() + ) + if unpad_dim_size % world != 0: + padding_size = qkv_tensor.size(seq_dim) - unpad_dim_size + qkv_tensor = _unpad_tensor(qkv_tensor, seq_dim, padding_size) + return qkv_tensor + + +def slice_inputs(x: Tensor, dim: int, padding: bool = True): + """ + A func to slice the input sequence in sequence parallel + """ + group = get_sequence_parallel_group() + if group is None: + return x + sp_rank = get_sequence_parallel_rank() + sp_world = get_sequence_parallel_world_size() + dim_size = x.shape[dim] + unit = (dim_size + sp_world - 1) // sp_world + if padding and dim_size % sp_world: + padding_size = sp_world - (dim_size % sp_world) + x = _pad_tensor(x, dim, padding_size) + slc = [slice(None)] * len(x.shape) + slc[dim] = slice(unit * sp_rank, unit * (sp_rank + 1)) + return x[slc] + + +def remove_seqeunce_parallel_padding(x: Tensor, dim: int, unpad_dim_size: int): + """ + A func to remove the padding part of the tensor based on its original shape + """ + group = get_sequence_parallel_group() + if group is None: + return x + sp_world = get_sequence_parallel_world_size() + if unpad_dim_size % sp_world == 0: + return x + padding_size = sp_world - (unpad_dim_size % sp_world) + assert (padding_size + unpad_dim_size) % sp_world == 0 + return _unpad_tensor(x, dim=dim, padding_size=padding_size) + + +def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor: + """ + A func to sync attention result with alltoall in sequence parallel + """ + group = get_sequence_parallel_group() + if not group: + return x + dim_size = x.size(seq_dim) + sp_world = get_sequence_parallel_world_size() + if dim_size % sp_world != 0: + padding_size = sp_world - (dim_size % sp_world) + x = _pad_tensor(x, seq_dim, padding_size) + return SeqAllToAll.apply(group, x, seq_dim, head_dim, False) + + +def gather_seq_scatter_heads(x: Tensor, seq_dim: int, head_dim: int) -> Tensor: + """ + A func to sync embedding input with alltoall in sequence parallel + """ + group = get_sequence_parallel_group() + if not group: + return x + return SeqAllToAll.apply(group, x, head_dim, seq_dim, False) + + +def scatter_heads(x: Tensor, dim: int) -> Tensor: + """ + A func to split heads before attention in sequence parallel + """ + group = get_sequence_parallel_group() + if not group: + return x + return Slice.apply(group, x, dim) + + +def gather_heads(x: Tensor, dim: int, grad_scale: Optional[bool] = False) -> Tensor: + """ + A func to gather heads for the attention result in sequence parallel + """ + group = get_sequence_parallel_group() + if not group: + return x + return 
Gather.apply(group, x, dim, grad_scale) + + +def gather_outputs( + x: Tensor, + *, + gather_dim: int, + padding_dim: Optional[int] = None, + unpad_shape: Optional[Tensor] = None, + cache: Cache = Cache(disable=True), + scale_grad=True, +): + """ + A func to gather the outputs for the model result in sequence parallel + """ + group = get_sequence_parallel_group() + if not group: + return x + x = Gather.apply(group, x, gather_dim, scale_grad) + if padding_dim is not None: + unpad_dim_size = cache( + "unpad_dim_size", lambda: torch.sum(torch.prod(unpad_shape, dim=1)).item() + ) + x = remove_seqeunce_parallel_padding(x, padding_dim, unpad_dim_size) + return x + + +def _pad_tensor(x: Tensor, dim: int, padding_size: int): + shape = list(x.shape) + shape[dim] = padding_size + pad = torch.zeros(shape, dtype=x.dtype, device=x.device) + return torch.cat([x, pad], dim=dim) + + +def _unpad_tensor(x: Tensor, dim: int, padding_size): + slc = [slice(None)] * len(x.shape) + slc[dim] = slice(0, -padding_size) + return x[slc] + + +def _broadcast_data(data, shape, dtype, src, group, async_op): + comms = [] + if isinstance(data, (list, tuple)): + for i, sub_shape in enumerate(shape): + comms += _broadcast_data(data[i], sub_shape, dtype[i], src, group, async_op) + elif isinstance(data, dict): + for key, sub_data in data.items(): + comms += _broadcast_data(sub_data, shape[key], dtype[key], src, group, async_op) + elif isinstance(data, Tensor): + comms.append(dist.broadcast(data, src=src, group=group, async_op=async_op)) + return comms + + +def _traverse(data: Any, op: Callable) -> Union[None, List, Dict, Any]: + if isinstance(data, (list, tuple)): + return [_traverse(sub_data, op) for sub_data in data] + elif isinstance(data, dict): + return {key: _traverse(sub_data, op) for key, sub_data in data.items()} + elif isinstance(data, Tensor): + return op(data) + else: + return None + + +def _get_shapes(data): + return _traverse(data, op=lambda x: x.shape) + + +def _get_dtypes(data): + return _traverse(data, op=lambda x: x.dtype) + + +def _construct_broadcast_buffer(shapes, dtypes, device): + if isinstance(shapes, torch.Size): + return torch.empty(shapes, dtype=dtypes, device=device) + + if isinstance(shapes, (list, tuple)): + buffer = [] + for i, sub_shape in enumerate(shapes): + buffer.append(_construct_broadcast_buffer(sub_shape, dtypes[i], device)) + elif isinstance(shapes, dict): + buffer = {} + for key, sub_shape in shapes.items(): + buffer[key] = _construct_broadcast_buffer(sub_shape, dtypes[key], device) + else: + return None + return buffer + + +class SPDistForward: + """A forward tool to sync different result across sp group + + Args: + module: a function or module to process users input + sp_step: current training step to judge which rank to broadcast its result to all + name: a distinct str to save meta and async comm + comm_shape: if different ranks have different shape, mark this arg to True + device: the device for current rank, can be empty + """ + + def __init__( + self, + name: str, + comm_shape: bool, + device: torch.device = None, + ): + self.name = name + self.comm_shape = comm_shape + if device: + self.device = device + else: + self.device = get_device() + + def __call__(self, inputs) -> Any: + group = get_sequence_parallel_group() + if not group: + yield inputs + else: + device = self.device + sp_world = get_sequence_parallel_world_size() + sp_rank = get_sequence_parallel_rank() + for local_step in range(sp_world): + src_rank = dist.get_global_rank(group, local_step) + is_src = sp_rank == 
local_step + local_shapes = [] + local_dtypes = [] + if local_step == 0: + local_result = inputs + _SEQ_DATA_BUF[self.name][-1] = local_result + local_shapes = _get_shapes(local_result) + local_dtypes = _get_dtypes(local_result) + if self.comm_shape: + group_shapes_lists = [None] * sp_world + dist.all_gather_object(group_shapes_lists, local_shapes, group=group) + _SEQ_DATA_META_SHAPES[self.name] = group_shapes_lists + else: + _SEQ_DATA_META_SHAPES[self.name] = [local_shapes] * sp_world + _SEQ_DATA_META_DTYPES[self.name] = local_dtypes + shapes = _SEQ_DATA_META_SHAPES[self.name][local_step] + dtypes = _SEQ_DATA_META_DTYPES[self.name] + buf_id = local_step % 2 + if local_step == 0: + sync_data = ( + local_result + if is_src + else _construct_broadcast_buffer(shapes, dtypes, device) + ) + _broadcast_data(sync_data, shapes, dtypes, src_rank, group, False) + _SEQ_DATA_BUF[self.name][buf_id] = sync_data + + # wait for async comm ops + if _SEQ_DATA_ASYNC_COMMS[self.name]: + for comm in _SEQ_DATA_ASYNC_COMMS[self.name]: + comm.wait() + # before return the sync result, do async broadcast for next batch + if local_step < sp_world - 1: + next_buf_id = 1 - buf_id + shapes = _SEQ_DATA_META_SHAPES[self.name][local_step + 1] + src_rank = dist.get_global_rank(group, local_step + 1) + is_src = sp_rank == local_step + 1 + next_sync_data = ( + _SEQ_DATA_BUF[self.name][-1] + if is_src + else _construct_broadcast_buffer(shapes, dtypes, device) + ) + _SEQ_DATA_ASYNC_COMMS[self.name] = _broadcast_data( + next_sync_data, shapes, dtypes, src_rank, group, True + ) + _SEQ_DATA_BUF[self.name][next_buf_id] = next_sync_data + yield _SEQ_DATA_BUF[self.name][buf_id] + + +sync_inputs = SPDistForward(name="bef_fwd", comm_shape=True) + + +def sync_data(data, sp_idx, name="tmp"): + group = get_sequence_parallel_group() + if group is None: + return data + # if sp_idx in _SYNC_BUFFER[name]: + # return _SYNC_BUFFER[name][sp_idx] + sp_rank = get_sequence_parallel_rank() + src_rank = dist.get_global_rank(group, sp_idx) + objects = [data] if sp_rank == sp_idx else [None] + dist.broadcast_object_list(objects, src=src_rank, group=group) + # _SYNC_BUFFER[name] = {sp_idx: objects[0]} + return objects[0] diff --git a/common/logger.py b/common/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..faf795f0aecb2b16471c99802f2240880d701830 --- /dev/null +++ b/common/logger.py @@ -0,0 +1,44 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Logging utility functions. 
+""" + +import logging +import sys +from typing import Optional + +from common.distributed import get_global_rank, get_local_rank, get_world_size + +_default_handler = logging.StreamHandler(sys.stdout) +_default_handler.setFormatter( + logging.Formatter( + "%(asctime)s " + + (f"[Rank:{get_global_rank()}]" if get_world_size() > 1 else "") + + (f"[LocalRank:{get_local_rank()}]" if get_world_size() > 1 else "") + + "[%(threadName).12s][%(name)s][%(levelname).5s] " + + "%(message)s" + ) +) + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Get a logger. + """ + logger = logging.getLogger(name) + logger.addHandler(_default_handler) + logger.setLevel(logging.INFO) + return logger diff --git a/common/partition.py b/common/partition.py new file mode 100644 index 0000000000000000000000000000000000000000..648c87fe2a61294c09704b9af3e47f5a8570c215 --- /dev/null +++ b/common/partition.py @@ -0,0 +1,59 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +""" +Partition utility functions. +""" + +from typing import Any, List + + +def partition_by_size(data: List[Any], size: int) -> List[List[Any]]: + """ + Partition a list by size. + When indivisible, the last group contains fewer items than the target size. + + Examples: + - data: [1,2,3,4,5] + - size: 2 + - return: [[1,2], [3,4], [5]] + """ + assert size > 0 + return [data[i : (i + size)] for i in range(0, len(data), size)] + + +def partition_by_groups(data: List[Any], groups: int) -> List[List[Any]]: + """ + Partition a list by groups. + When indivisible, some groups may have more items than others. + + Examples: + - data: [1,2,3,4,5] + - groups: 2 + - return: [[1,3,5], [2,4]] + """ + assert groups > 0 + return [data[i::groups] for i in range(groups)] + + +def shift_list(data: List[Any], n: int) -> List[Any]: + """ + Rotate a list by n elements. + + Examples: + - data: [1,2,3,4,5] + - n: 3 + - return: [4,5,1,2,3] + """ + return data[(n % len(data)) :] + data[: (n % len(data))] diff --git a/common/seed.py b/common/seed.py new file mode 100644 index 0000000000000000000000000000000000000000..52866de72fcf98f4a2ceff51a55986780a8b701a --- /dev/null +++ b/common/seed.py @@ -0,0 +1,30 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
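A minimal, illustrative sketch (not part of the diff) of the padding arithmetic used by slice_inputs and _pad_tensor in common/distributed/ops.py above. It replays the same ceil-division chunking without needing an initialized process group; the helper name, rank, and world size below are made up for illustration.

import torch

def slice_for_rank(x: torch.Tensor, dim: int, rank: int, world: int) -> torch.Tensor:
    """Pad x along dim to a multiple of world, then take this rank's chunk."""
    dim_size = x.shape[dim]
    unit = (dim_size + world - 1) // world          # ceil division, as in slice_inputs
    if dim_size % world:
        pad_shape = list(x.shape)
        pad_shape[dim] = world - (dim_size % world)
        x = torch.cat([x, torch.zeros(pad_shape, dtype=x.dtype, device=x.device)], dim=dim)
    slc = [slice(None)] * x.ndim
    slc[dim] = slice(unit * rank, unit * (rank + 1))
    return x[slc]

x = torch.arange(10).float().unsqueeze(-1)          # sequence length 10, feature dim 1
chunks = [slice_for_rank(x, dim=0, rank=r, world=4) for r in range(4)]
print([c.shape[0] for c in chunks])                 # [3, 3, 3, 3]; the last chunk carries 2 zero-padded rows

remove_seqeunce_parallel_padding then strips those zero rows again after the tensor is gathered back to full sequence length.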
+ +import random +from typing import Optional +import numpy as np +import torch + +from common.distributed import get_global_rank + + +def set_seed(seed: Optional[int], same_across_ranks: bool = False): + """Function that sets the seed for pseudo-random number generators.""" + if seed is not None: + seed += get_global_rank() if not same_across_ranks else 0 + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + diff --git a/configs_3b/main.yaml b/configs_3b/main.yaml new file mode 100644 index 0000000000000000000000000000000000000000..78579065f27852990354bd565b5375a679a76035 --- /dev/null +++ b/configs_3b/main.yaml @@ -0,0 +1,88 @@ +__object__: + path: projects.video_diffusion_sr.train + name: VideoDiffusionTrainer + +dit: + model: + __object__: + path: models.dit_v2.nadit + name: NaDiT + args: as_params + vid_in_channels: 33 + vid_out_channels: 16 + vid_dim: 2560 + vid_out_norm: fusedrms + txt_in_dim: 5120 + txt_in_norm: fusedln + txt_dim: ${.vid_dim} + emb_dim: ${eval:'6 * ${.vid_dim}'} + heads: 20 + head_dim: 128 # llm-like + expand_ratio: 4 + norm: fusedrms + norm_eps: 1.0e-05 + ada: single + qk_bias: False + qk_norm: fusedrms + patch_size: [ 1,2,2 ] + num_layers: 32 # llm-like + mm_layers: 10 + mlp_type: swiglu + msa_type: None + block_type: ${eval:'${.num_layers} * ["mmdit_sr"]'} # space-full + window: ${eval:'${.num_layers} * [(4,3,3)]'} # space-full + window_method: ${eval:'${.num_layers} // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]'} # space-full + rope_type: mmrope3d + rope_dim: 128 + compile: False + gradient_checkpoint: True + fsdp: + sharding_strategy: _HYBRID_SHARD_ZERO2 + +ema: + decay: 0.9998 + +vae: + model: + __inherit__: models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml + freeze_encoder: False + # gradient_checkpoint: True + slicing: + split_size: 4 + memory_device: same + memory_limit: + conv_max_mem: 0.5 + norm_max_mem: 0.5 + checkpoint: ./ckpts/ema_vae.pth + scaling_factor: 0.9152 + compile: False + grouping: False + dtype: bfloat16 + +diffusion: + schedule: + type: lerp + T: 1000.0 + sampler: + type: euler + prediction_type: v_lerp + timesteps: + training: + type: logitnormal + loc: 0.0 + scale: 1.0 + sampling: + type: uniform_trailing + steps: 50 + transform: True + loss: + type: v_lerp + cfg: + scale: 7.5 + rescale: 0 + +condition: + i2v: 0.0 + v2v: 0.0 + sr: 1.0 + noise_scale: 0.25 diff --git a/configs_7b/main.yaml b/configs_7b/main.yaml new file mode 100644 index 0000000000000000000000000000000000000000..51c5eaf880788ff941bcce84b2548e3f21646339 --- /dev/null +++ b/configs_7b/main.yaml @@ -0,0 +1,85 @@ +__object__: + path: projects.video_diffusion_sr.train + name: VideoDiffusionTrainer + +dit: + model: + __object__: + path: models.dit.nadit + name: NaDiT + args: as_params + vid_in_channels: 33 + vid_out_channels: 16 + vid_dim: 3072 + txt_in_dim: 5120 + txt_dim: ${.vid_dim} + emb_dim: ${eval:'6 * ${.vid_dim}'} + heads: 24 + head_dim: 128 # llm-like + expand_ratio: 4 + norm: fusedrms + norm_eps: 1e-5 + ada: single + qk_bias: False + qk_rope: True + qk_norm: fusedrms + patch_size: [ 1,2,2 ] + num_layers: 36 # llm-like + shared_mlp: False + shared_qkv: False + mlp_type: normal + block_type: ${eval:'${.num_layers} * ["mmdit_sr"]'} # space-full + window: ${eval:'${.num_layers} * [(4,3,3)]'} # space-full + window_method: ${eval:'${.num_layers} // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]'} # space-full + compile: False + gradient_checkpoint: True + fsdp: + sharding_strategy: _HYBRID_SHARD_ZERO2 + +ema: + decay: 0.9998 + +vae: + 
model: + __inherit__: models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml + freeze_encoder: False + # gradient_checkpoint: True + slicing: + split_size: 4 + memory_device: same + memory_limit: + conv_max_mem: 0.5 + norm_max_mem: 0.5 + checkpoint: ./ckpts/ema_vae.pth + scaling_factor: 0.9152 + compile: False + grouping: False + dtype: bfloat16 + +diffusion: + schedule: + type: lerp + T: 1000.0 + sampler: + type: euler + prediction_type: v_lerp + timesteps: + training: + type: logitnormal + loc: 0.0 + scale: 1.0 + sampling: + type: uniform_trailing + steps: 50 + transform: True + loss: + type: v_lerp + cfg: + scale: 7.5 + rescale: 0 + +condition: + i2v: 0.0 + v2v: 0.0 + sr: 1.0 + noise_scale: 0.25 diff --git a/data/image/transforms/area_resize.py b/data/image/transforms/area_resize.py new file mode 100644 index 0000000000000000000000000000000000000000..9f621dae1b0af40f58e090405db1ac7338110980 --- /dev/null +++ b/data/image/transforms/area_resize.py @@ -0,0 +1,135 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +import math +import random +from typing import Union +import torch +from PIL import Image +from torchvision.transforms import functional as TVF +from torchvision.transforms.functional import InterpolationMode + + +class AreaResize: + def __init__( + self, + max_area: float, + downsample_only: bool = False, + interpolation: InterpolationMode = InterpolationMode.BICUBIC, + ): + self.max_area = max_area + self.downsample_only = downsample_only + self.interpolation = interpolation + + def __call__(self, image: Union[torch.Tensor, Image.Image]): + + if isinstance(image, torch.Tensor): + height, width = image.shape[-2:] + elif isinstance(image, Image.Image): + width, height = image.size + else: + raise NotImplementedError + + scale = math.sqrt(self.max_area / (height * width)) + + # keep original height and width for small pictures. + scale = 1 if scale >= 1 and self.downsample_only else scale + + resized_height, resized_width = round(height * scale), round(width * scale) + + return TVF.resize( + image, + size=(resized_height, resized_width), + interpolation=self.interpolation, + ) + + +class AreaRandomCrop: + def __init__( + self, + max_area: float, + ): + self.max_area = max_area + + def get_params(self, input_size, output_size): + """Get parameters for ``crop`` for a random crop. + + Args: + img (PIL Image): Image to be cropped. + output_size (tuple): Expected output size of the crop. + + Returns: + tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 
+ """ + # w, h = _get_image_size(img) + h, w = input_size + th, tw = output_size + if w <= tw and h <= th: + return 0, 0, h, w + + i = random.randint(0, h - th) + j = random.randint(0, w - tw) + return i, j, th, tw + + def __call__(self, image: Union[torch.Tensor, Image.Image]): + if isinstance(image, torch.Tensor): + height, width = image.shape[-2:] + elif isinstance(image, Image.Image): + width, height = image.size + else: + raise NotImplementedError + + resized_height = math.sqrt(self.max_area / (width / height)) + resized_width = (width / height) * resized_height + + # print('>>>>>>>>>>>>>>>>>>>>>') + # print((height, width)) + # print( (resized_height, resized_width)) + + resized_height, resized_width = round(resized_height), round(resized_width) + i, j, h, w = self.get_params((height, width), (resized_height, resized_width)) + image = TVF.crop(image, i, j, h, w) + return image + +class ScaleResize: + def __init__( + self, + scale: float, + ): + self.scale = scale + + def __call__(self, image: Union[torch.Tensor, Image.Image]): + if isinstance(image, torch.Tensor): + height, width = image.shape[-2:] + interpolation_mode = InterpolationMode.BILINEAR + antialias = True if image.ndim == 4 else "warn" + elif isinstance(image, Image.Image): + width, height = image.size + interpolation_mode = InterpolationMode.LANCZOS + antialias = "warn" + else: + raise NotImplementedError + + scale = self.scale + + # keep original height and width for small pictures + + resized_height, resized_width = round(height * scale), round(width * scale) + image = TVF.resize( + image, + size=(resized_height, resized_width), + interpolation=interpolation_mode, + antialias=antialias, + ) + return image diff --git a/data/image/transforms/divisible_crop.py b/data/image/transforms/divisible_crop.py new file mode 100644 index 0000000000000000000000000000000000000000..d1815b03ee1ce99486143aca24b9023ab0b3973c --- /dev/null +++ b/data/image/transforms/divisible_crop.py @@ -0,0 +1,40 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
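A quick numeric check (illustrative values only) of the AreaResize rule above: the image is rescaled by sqrt(max_area / (height * width)), which brings the area close to max_area while preserving aspect ratio.

import math

height, width = 720, 1280            # example input resolution
max_area = 1024 * 576                # example target area

scale = math.sqrt(max_area / (height * width))
resized = (round(height * scale), round(width * scale))
print(scale, resized)                            # 0.8, (576, 1024) -> area 589824 == max_area
print(resized[0] / resized[1], height / width)   # aspect ratio preserved: 0.5625 == 0.5625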
+ +from typing import Union +import torch +from PIL import Image +from torchvision.transforms import functional as TVF + + +class DivisibleCrop: + def __init__(self, factor): + if not isinstance(factor, tuple): + factor = (factor, factor) + + self.height_factor, self.width_factor = factor[0], factor[1] + + def __call__(self, image: Union[torch.Tensor, Image.Image]): + if isinstance(image, torch.Tensor): + height, width = image.shape[-2:] + elif isinstance(image, Image.Image): + width, height = image.size + else: + raise NotImplementedError + + cropped_height = height - (height % self.height_factor) + cropped_width = width - (width % self.width_factor) + + image = TVF.center_crop(img=image, output_size=(cropped_height, cropped_width)) + return image diff --git a/data/image/transforms/na_resize.py b/data/image/transforms/na_resize.py new file mode 100644 index 0000000000000000000000000000000000000000..d230e25e3ca1710ad6261d8e14541a97732b9a30 --- /dev/null +++ b/data/image/transforms/na_resize.py @@ -0,0 +1,50 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import Literal +from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Resize + +from .area_resize import AreaResize +from .side_resize import SideResize + + +def NaResize( + resolution: int, + mode: Literal["area", "side"], + downsample_only: bool, + interpolation: InterpolationMode = InterpolationMode.BICUBIC, +): + if mode == "area": + return AreaResize( + max_area=resolution**2, + downsample_only=downsample_only, + interpolation=interpolation, + ) + if mode == "side": + return SideResize( + size=resolution, + downsample_only=downsample_only, + interpolation=interpolation, + ) + if mode == "square": + return Compose( + [ + Resize( + size=resolution, + interpolation=interpolation, + ), + CenterCrop(resolution), + ] + ) + raise ValueError(f"Unknown resize mode: {mode}") diff --git a/data/image/transforms/side_resize.py b/data/image/transforms/side_resize.py new file mode 100644 index 0000000000000000000000000000000000000000..6e07402b2187a048b99d995d68ead12f790f5724 --- /dev/null +++ b/data/image/transforms/side_resize.py @@ -0,0 +1,54 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
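A minimal usage sketch (not part of the diff) pairing the two transforms above: resize to a target area, then crop so both sides are divisible by a chosen factor. The factor 16, the resolution, and the assumption that the data.image.transforms package is importable from the repo root are all illustrative, not values mandated by the repository.

import torch
from torchvision.transforms import Compose

from data.image.transforms.na_resize import NaResize
from data.image.transforms.divisible_crop import DivisibleCrop

pipeline = Compose([
    NaResize(resolution=1024, mode="area", downsample_only=False),
    DivisibleCrop(16),
])

image = torch.rand(3, 719, 1277)     # arbitrary odd resolution
out = pipeline(image)
print(out.shape)                     # torch.Size([3, 768, 1360]); both spatial dims are multiples of 16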
+ +from typing import Union +import torch +from PIL import Image +from torchvision.transforms import InterpolationMode +from torchvision.transforms import functional as TVF + + +class SideResize: + def __init__( + self, + size: int, + downsample_only: bool = False, + interpolation: InterpolationMode = InterpolationMode.BICUBIC, + ): + self.size = size + self.downsample_only = downsample_only + self.interpolation = interpolation + + def __call__(self, image: Union[torch.Tensor, Image.Image]): + """ + Args: + image (PIL Image or Tensor): Image to be scaled. + + Returns: + PIL Image or Tensor: Rescaled image. + """ + if isinstance(image, torch.Tensor): + height, width = image.shape[-2:] + elif isinstance(image, Image.Image): + width, height = image.size + else: + raise NotImplementedError + + if self.downsample_only and min(width, height) < self.size: + # keep original height and width for small pictures. + size = min(width, height) + else: + size = self.size + + return TVF.resize(image, size, self.interpolation) diff --git a/data/video/transforms/rearrange.py b/data/video/transforms/rearrange.py new file mode 100644 index 0000000000000000000000000000000000000000..895347991d71043742777f103d32b62c80284660 --- /dev/null +++ b/data/video/transforms/rearrange.py @@ -0,0 +1,24 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from einops import rearrange + + +class Rearrange: + def __init__(self, pattern: str, **kwargs): + self.pattern = pattern + self.kwargs = kwargs + + def __call__(self, x): + return rearrange(x, self.pattern, **self.kwargs) diff --git a/models/dit/attention.py b/models/dit/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ac0cadbcd62e7d40700108d2857cb587f794fcee --- /dev/null +++ b/models/dit/attention.py @@ -0,0 +1,46 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
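A tiny sketch (illustrative pattern, assuming the data.video.transforms package is importable): the Rearrange transform above simply forwards its pattern and keyword arguments to einops.rearrange, so it drops into a torchvision-style Compose pipeline like any other transform.

import torch
from data.video.transforms.rearrange import Rearrange

to_channels_first = Rearrange("t h w c -> c t h w")
frames = torch.rand(16, 256, 256, 3)       # (frames, height, width, channels)
print(to_channels_first(frames).shape)     # torch.Size([3, 16, 256, 256])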
+ +import torch +import torch.nn.functional as F + +from flash_attn import flash_attn_varlen_func + +from torch import nn + +class TorchAttention(nn.Module): + def tflops(self, args, kwargs, output) -> float: + assert len(args) == 0 or len(args) > 2, "query, key should both provided by args / kwargs" + q = kwargs.get("query") or args[0] + k = kwargs.get("key") or args[1] + b, h, sq, d = q.shape + b, h, sk, d = k.shape + return b * h * (4 * d * (sq / 1e6) * (sk / 1e6)) + + def forward(self, *args, **kwargs): + return F.scaled_dot_product_attention(*args, **kwargs) + + +class FlashAttentionVarlen(nn.Module): + def tflops(self, args, kwargs, output) -> float: + cu_seqlens_q = kwargs["cu_seqlens_q"] + cu_seqlens_k = kwargs["cu_seqlens_k"] + _, h, d = output.shape + seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]) / 1e6 + seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) / 1e6 + return h * (4 * d * (seqlens_q * seqlens_k).sum()) + + def forward(self, *args, **kwargs): + kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled() + return flash_attn_varlen_func(*args, **kwargs) \ No newline at end of file diff --git a/models/dit/blocks/__init__.py b/models/dit/blocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3195b400a407b871a6c19b67cf25239c5c3f196d --- /dev/null +++ b/models/dit/blocks/__init__.py @@ -0,0 +1,25 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from .mmdit_window_block import MMWindowTransformerBlock + +dit_blocks = { + "mmdit_window": MMWindowTransformerBlock, +} + + +def get_block(block_type: str): + if block_type in dit_blocks: + return dit_blocks[block_type] + raise NotImplementedError(f"{block_type} is not supported") diff --git a/models/dit/blocks/mmdit_window_block.py b/models/dit/blocks/mmdit_window_block.py new file mode 100644 index 0000000000000000000000000000000000000000..eacaa093658f62fb483086215cfb6ac72a2dc9fd --- /dev/null +++ b/models/dit/blocks/mmdit_window_block.py @@ -0,0 +1,233 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
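An illustrative check of the FLOP accounting in TorchAttention.tflops above: each (batch, head) pair costs about 4 * d * sq * sk FLOPs (two matmuls, QK^T and attention-times-V, at 2 FLOPs per multiply-accumulate), and the two /1e6 factors turn the result into TFLOPs. The shape below is made up.

b, h, sq, sk, d = 2, 24, 4096, 4096, 128
tflops = b * h * (4 * d * (sq / 1e6) * (sk / 1e6))
print(tflops)   # ~0.41 TFLOPs for this example shape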
+ +from typing import Tuple, Union +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.utils import _triple + +from common.distributed.ops import ( + gather_heads, + gather_heads_scatter_seq, + gather_seq_scatter_heads_qkv, + scatter_heads, +) + +from ..attention import TorchAttention +from ..mlp import get_mlp +from ..mm import MMArg, MMModule +from ..modulation import ada_layer_type +from ..normalization import norm_layer_type +from ..rope import RotaryEmbedding3d + + +class MMWindowAttention(nn.Module): + def __init__( + self, + vid_dim: int, + txt_dim: int, + heads: int, + head_dim: int, + qk_bias: bool, + qk_rope: bool, + qk_norm: norm_layer_type, + qk_norm_eps: float, + window: Union[int, Tuple[int, int, int]], + window_method: str, + shared_qkv: bool, + ): + super().__init__() + dim = MMArg(vid_dim, txt_dim) + inner_dim = heads * head_dim + qkv_dim = inner_dim * 3 + + self.window = _triple(window) + self.window_method = window_method + assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window)) + + self.head_dim = head_dim + self.proj_qkv = MMModule(nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_qkv) + self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_qkv) + self.norm_q = MMModule(qk_norm, dim=head_dim, eps=qk_norm_eps, elementwise_affine=True) + self.norm_k = MMModule(qk_norm, dim=head_dim, eps=qk_norm_eps, elementwise_affine=True) + self.rope = RotaryEmbedding3d(dim=head_dim // 2) if qk_rope else None + self.attn = TorchAttention() + + def forward( + self, + vid: torch.FloatTensor, # b T H W c + txt: torch.FloatTensor, # b L c + txt_mask: torch.BoolTensor, # b L + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + # Project q, k, v. + vid_qkv, txt_qkv = self.proj_qkv(vid, txt) + vid_qkv = gather_seq_scatter_heads_qkv(vid_qkv, seq_dim=2) + _, T, H, W, _ = vid_qkv.shape + _, L, _ = txt.shape + + if self.window_method == "win": + nt, nh, nw = self.window + tt, hh, ww = T // nt, H // nh, W // nw + elif self.window_method == "win_by_size": + tt, hh, ww = self.window + tt, hh, ww = ( + tt if tt > 0 else T, + hh if hh > 0 else H, + ww if ww > 0 else W, + ) + nt, nh, nw = T // tt, H // hh, W // ww + else: + raise NotImplementedError + + vid_qkv = rearrange(vid_qkv, "b T H W (o h d) -> o b h (T H W) d", o=3, d=self.head_dim) + txt_qkv = rearrange(txt_qkv, "b L (o h d) -> o b h L d", o=3, d=self.head_dim) + txt_qkv = scatter_heads(txt_qkv, dim=2) + + vid_q, vid_k, vid_v = vid_qkv.unbind() + txt_q, txt_k, txt_v = txt_qkv.unbind() + + vid_q, txt_q = self.norm_q(vid_q, txt_q) + vid_k, txt_k = self.norm_k(vid_k, txt_k) + + if self.rope: + vid_q, vid_k = self.rope(vid_q, vid_k, (T, H, W)) + + def vid_window(v): + return rearrange( + v, + "b h (nt tt nh hh nw ww) d -> b h (nt nh nw) (tt hh ww) d", + hh=hh, + ww=ww, + tt=tt, + nh=nh, + nw=nw, + nt=nt, + ) + + def txt_window(t): + return rearrange(t, "b h L d -> b h 1 L d").expand(-1, -1, nt * nh * nw, -1, -1) + + # Process video attention. 
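+ # Each video window attends to its own tt*hh*ww tokens plus the full text sequence,
+ # so the text mask is left-padded with tt*hh*ww True entries (video positions are
+ # always visible) before being broadcast across the nt*nh*nw windows.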
+ vid_msk = F.pad(txt_mask, (tt * hh * ww, 0), value=True) + vid_msk = rearrange(vid_msk, "b l -> b 1 1 1 l").expand(-1, 1, 1, tt * hh * ww, -1) + vid_out = self.attn( + vid_window(vid_q), + torch.cat([vid_window(vid_k), txt_window(txt_k)], dim=-2), + torch.cat([vid_window(vid_v), txt_window(txt_v)], dim=-2), + vid_msk, + ) + vid_out = rearrange( + vid_out, + "b h (nt nh nw) (tt hh ww) d -> b (nt tt) (nh hh) (nw ww) (h d)", + hh=hh, + ww=ww, + tt=tt, + nh=nh, + nw=nw, + ) + vid_out = gather_heads_scatter_seq(vid_out, head_dim=4, seq_dim=2) + + # Process text attention. + txt_msk = F.pad(txt_mask, (T * H * W, 0), value=True) + txt_msk = rearrange(txt_msk, "b l -> b 1 1 l").expand(-1, 1, L, -1) + txt_out = self.attn( + txt_q, + torch.cat([vid_k, txt_k], dim=-2), + torch.cat([vid_v, txt_v], dim=-2), + txt_msk, + ) + txt_out = rearrange(txt_out, "b h L d -> b L (h d)") + txt_out = gather_heads(txt_out, dim=2) + + # Project output. + vid_out, txt_out = self.proj_out(vid_out, txt_out) + return vid_out, txt_out + + +class MMWindowTransformerBlock(nn.Module): + def __init__( + self, + *, + vid_dim: int, + txt_dim: int, + emb_dim: int, + heads: int, + head_dim: int, + expand_ratio: int, + norm: norm_layer_type, + norm_eps: float, + ada: ada_layer_type, + qk_bias: bool, + qk_rope: bool, + qk_norm: norm_layer_type, + window: Union[int, Tuple[int, int, int]], + window_method: str, + shared_qkv: bool, + shared_mlp: bool, + mlp_type: str, + **kwargs, + ): + super().__init__() + dim = MMArg(vid_dim, txt_dim) + self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False) + self.attn = MMWindowAttention( + vid_dim=vid_dim, + txt_dim=txt_dim, + heads=heads, + head_dim=head_dim, + qk_bias=qk_bias, + qk_rope=qk_rope, + qk_norm=qk_norm, + qk_norm_eps=norm_eps, + window=window, + window_method=window_method, + shared_qkv=shared_qkv, + ) + self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False) + self.mlp = MMModule( + get_mlp(mlp_type), + dim=dim, + expand_ratio=expand_ratio, + shared_weights=shared_mlp, + ) + self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"]) + + def forward( + self, + vid: torch.FloatTensor, + txt: torch.FloatTensor, + txt_mask: torch.BoolTensor, + emb: torch.FloatTensor, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_attn, txt_attn = self.attn_norm(vid, txt) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, emb=emb, layer="attn", mode="in") + vid_attn, txt_attn = self.attn(vid_attn, txt_attn, txt_mask=txt_mask) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, emb=emb, layer="attn", mode="out") + vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) + + vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, emb=emb, layer="mlp", mode="in") + vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, emb=emb, layer="mlp", mode="out") + vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) + + return vid_mlp, txt_mlp diff --git a/models/dit/embedding.py b/models/dit/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..e972244f5767c9f34e5e77bb180ae720ce88b89c --- /dev/null +++ b/models/dit/embedding.py @@ -0,0 +1,62 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. 
+# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import Optional, Union +import torch +from diffusers.models.embeddings import get_timestep_embedding +from torch import nn + + +def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]): + return emb1 if emb2 is None else emb1 + emb2 + + +class TimeEmbedding(nn.Module): + def __init__( + self, + sinusoidal_dim: int, + hidden_dim: int, + output_dim: int, + ): + super().__init__() + self.sinusoidal_dim = sinusoidal_dim + self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim) + self.proj_hid = nn.Linear(hidden_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.act = nn.SiLU() + + def forward( + self, + timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], + device: torch.device, + dtype: torch.dtype, + ) -> torch.FloatTensor: + if not torch.is_tensor(timestep): + timestep = torch.tensor([timestep], device=device, dtype=dtype) + if timestep.ndim == 0: + timestep = timestep[None] + + emb = get_timestep_embedding( + timesteps=timestep, + embedding_dim=self.sinusoidal_dim, + flip_sin_to_cos=False, + downscale_freq_shift=0, + ) + emb = emb.to(dtype) + emb = self.proj_in(emb) + emb = self.act(emb) + emb = self.proj_hid(emb) + emb = self.act(emb) + emb = self.proj_out(emb) + return emb diff --git a/models/dit/mlp.py b/models/dit/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..2d05cb021f3e3c6ac05c0e7ae1aa8a6d29475b87 --- /dev/null +++ b/models/dit/mlp.py @@ -0,0 +1,62 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
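A minimal sketch of the TimeEmbedding module above (illustrative sizes; requires torch and diffusers): a scalar timestep is expanded to a sinusoidal vector of width sinusoidal_dim, then projected through the small SiLU MLP to output_dim.

import torch
from models.dit.embedding import TimeEmbedding

emb_in = TimeEmbedding(sinusoidal_dim=256, hidden_dim=512, output_dim=3072)
emb = emb_in(timestep=500, device=torch.device("cpu"), dtype=torch.float32)
print(emb.shape)   # torch.Size([1, 3072]); one embedding row per (broadcast) timestep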
+ +from typing import Optional +import torch +import torch.nn.functional as F +from torch import nn + + +def get_mlp(mlp_type: Optional[str] = "normal"): + if mlp_type == "normal": + return MLP + elif mlp_type == "swiglu": + return SwiGLUMLP + + +class MLP(nn.Module): + def __init__( + self, + dim: int, + expand_ratio: int, + ): + super().__init__() + self.proj_in = nn.Linear(dim, dim * expand_ratio) + self.act = nn.GELU("tanh") + self.proj_out = nn.Linear(dim * expand_ratio, dim) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + x = self.proj_in(x) + x = self.act(x) + x = self.proj_out(x) + return x + + +class SwiGLUMLP(nn.Module): + def __init__( + self, + dim: int, + expand_ratio: int, + multiple_of: int = 256, + ): + super().__init__() + hidden_dim = int(2 * dim * expand_ratio / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False) + self.proj_out = nn.Linear(hidden_dim, dim, bias=False) + self.proj_in = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) + return x diff --git a/models/dit/mm.py b/models/dit/mm.py new file mode 100644 index 0000000000000000000000000000000000000000..49be1f5915a61d8ea27f3e3718f35e5c9af662e7 --- /dev/null +++ b/models/dit/mm.py @@ -0,0 +1,67 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
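A worked example of the SwiGLUMLP sizing rule above: the hidden width is two thirds of the plain-MLP width, rounded up to a multiple of 256, so the parameter count stays roughly comparable despite the extra gate projection. The values mirror the 3B config (vid_dim=2560, expand_ratio=4, mlp_type swiglu) purely for illustration.

dim, expand_ratio, multiple_of = 2560, 4, 256
hidden = int(2 * dim * expand_ratio / 3)                            # 6826
hidden = multiple_of * ((hidden + multiple_of - 1) // multiple_of)  # rounded up
print(hidden)                                                       # 6912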
+ +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Tuple +import torch +from torch import nn + + +@dataclass +class MMArg: + vid: Any + txt: Any + + +def get_args(key: str, args: List[Any]) -> List[Any]: + return [getattr(v, key) if isinstance(v, MMArg) else v for v in args] + + +def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()} + + +class MMModule(nn.Module): + def __init__( + self, + module: Callable[..., nn.Module], + *args, + shared_weights: bool = False, + **kwargs, + ): + super().__init__() + self.shared_weights = shared_weights + if self.shared_weights: + assert get_args("vid", args) == get_args("txt", args) + assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) + self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) + else: + self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) + self.txt = module(*get_args("txt", args), **get_kwargs("txt", kwargs)) + + def forward( + self, + vid: torch.FloatTensor, + txt: torch.FloatTensor, + *args, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_module = self.vid if not self.shared_weights else self.all + txt_module = self.txt if not self.shared_weights else self.all + vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs)) + txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs)) + return vid, txt diff --git a/models/dit/modulation.py b/models/dit/modulation.py new file mode 100644 index 0000000000000000000000000000000000000000..cd3b41f6c457396ac65403d88edc3d5ad3382262 --- /dev/null +++ b/models/dit/modulation.py @@ -0,0 +1,97 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import Callable, List, Optional +import torch +from einops import rearrange +from torch import nn + +from common.cache import Cache +from common.distributed.ops import slice_inputs + +# (dim: int, emb_dim: int) +ada_layer_type = Callable[[int, int], nn.Module] + + +def get_ada_layer(ada_layer: str) -> ada_layer_type: + if ada_layer == "single": + return AdaSingle + raise NotImplementedError(f"{ada_layer} is not supported") + + +def expand_dims(x: torch.Tensor, dim: int, ndim: int): + """ + Expand tensor "x" to "ndim" by adding empty dims at "dim". + Example: x is (b d), target ndim is 5, add dim at 1, return (b 1 1 1 d). 
+ """ + shape = x.shape + shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:] + return x.reshape(shape) + + +class AdaSingle(nn.Module): + def __init__( + self, + dim: int, + emb_dim: int, + layers: List[str], + ): + assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" + super().__init__() + self.dim = dim + self.emb_dim = emb_dim + self.layers = layers + for l in layers: + self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5)) + self.register_parameter(f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1)) + self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5)) + + def forward( + self, + hid: torch.FloatTensor, # b ... c + emb: torch.FloatTensor, # b d + layer: str, + mode: str, + cache: Cache = Cache(disable=True), + branch_tag: str = "", + hid_len: Optional[torch.LongTensor] = None, # b + ) -> torch.FloatTensor: + idx = self.layers.index(layer) + emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] + emb = expand_dims(emb, 1, hid.ndim + 1) + + if hid_len is not None: + emb = cache( + f"emb_repeat_{idx}_{branch_tag}", + lambda: slice_inputs( + torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]), + dim=0, + ), + ) + + shiftA, scaleA, gateA = emb.unbind(-1) + shiftB, scaleB, gateB = ( + getattr(self, f"{layer}_shift"), + getattr(self, f"{layer}_scale"), + getattr(self, f"{layer}_gate"), + ) + + if mode == "in": + return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) + if mode == "out": + return hid.mul_(gateA + gateB) + raise NotImplementedError + + def extra_repr(self) -> str: + return f"dim={self.dim}, emb_dim={self.emb_dim}, layers={self.layers}" \ No newline at end of file diff --git a/models/dit/na.py b/models/dit/na.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbd546c4705b3b9c7c19a9823f9d113a0447616 --- /dev/null +++ b/models/dit/na.py @@ -0,0 +1,241 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from itertools import chain +from typing import Callable, Dict, List, Tuple +import einops +import torch + + +def flatten( + hid: List[torch.FloatTensor], # List of (*** c) +) -> Tuple[ + torch.FloatTensor, # (L c) + torch.LongTensor, # (b n) +]: + assert len(hid) > 0 + shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid]) + hid = torch.cat([x.flatten(0, -2) for x in hid]) + return hid, shape + + +def unflatten( + hid: torch.FloatTensor, # (L c) or (L ... c) + hid_shape: torch.LongTensor, # (b n) +) -> List[torch.Tensor]: # List of (*** c) or (*** ... c) + hid_len = hid_shape.prod(-1) + hid = hid.split(hid_len.tolist()) + hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)] + return hid + + +def concat( + vid: torch.FloatTensor, # (VL ... c) + txt: torch.FloatTensor, # (TL ... c) + vid_len: torch.LongTensor, # (b) + txt_len: torch.LongTensor, # (b) +) -> torch.FloatTensor: # (L ... 
c) + vid = torch.split(vid, vid_len.tolist()) + txt = torch.split(txt, txt_len.tolist()) + return torch.cat(list(chain(*zip(vid, txt)))) + + +def concat_idx( + vid_len: torch.LongTensor, # (b) + txt_len: torch.LongTensor, # (b) +) -> Tuple[ + Callable, + Callable, +]: + device = vid_len.device + vid_idx = torch.arange(vid_len.sum(), device=device) + txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) + tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len) + src_idx = torch.argsort(tgt_idx) + return ( + lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx), + lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]), + ) + + +def unconcat( + all: torch.FloatTensor, # (L ... c) + vid_len: torch.LongTensor, # (b) + txt_len: torch.LongTensor, # (b) +) -> Tuple[ + torch.FloatTensor, # (VL ... c) + torch.FloatTensor, # (TL ... c) +]: + interleave_len = list(chain(*zip(vid_len.tolist(), txt_len.tolist()))) + all = all.split(interleave_len) + vid = torch.cat(all[0::2]) + txt = torch.cat(all[1::2]) + return vid, txt + + +def repeat_concat( + vid: torch.FloatTensor, # (VL ... c) + txt: torch.FloatTensor, # (TL ... c) + vid_len: torch.LongTensor, # (n*b) + txt_len: torch.LongTensor, # (b) + txt_repeat: List, # (n) +) -> torch.FloatTensor: # (L ... c) + vid = torch.split(vid, vid_len.tolist()) + txt = torch.split(txt, txt_len.tolist()) + txt = [[x] * n for x, n in zip(txt, txt_repeat)] + txt = list(chain(*txt)) + return torch.cat(list(chain(*zip(vid, txt)))) + + +def repeat_concat_idx( + vid_len: torch.LongTensor, # (n*b) + txt_len: torch.LongTensor, # (b) + txt_repeat: torch.LongTensor, # (n) +) -> Tuple[ + Callable, + Callable, +]: + device = vid_len.device + vid_idx = torch.arange(vid_len.sum(), device=device) + txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) + txt_repeat_list = txt_repeat.tolist() + tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat) + src_idx = torch.argsort(tgt_idx) + txt_idx_len = len(tgt_idx) - len(vid_idx) + repeat_txt_len = (txt_len * txt_repeat).tolist() + + def unconcat_coalesce(all): + """ + Un-concat vid & txt, and coalesce the repeated txt. + e.g. vid [0 1 2 3 4 5 6 7 8] -> 3 splits -> [0 1 2] [3 4 5] [6 7 8] + txt [9 10] + repeat_concat ==> [0 1 2 9 10 3 4 5 9 10 6 7 8 9 10] + 1. argsort re-index ==> [0 1 2 3 4 5 6 7 8 9 9 9 10 10 10] + split ==> vid_out [0 1 2 3 4 5 6 7 8] txt_out [9 9 9 10 10 10] + 2. reshape & mean for each sample to coalesce the repeated txt. + """ + vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len]) + txt_out_coalesced = [] + for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list): + txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1) + txt_out_coalesced.append(txt) + return vid_out, torch.cat(txt_out_coalesced) + + # Note: Backward of torch.index_select is non-deterministic when existing repeated index, + # the difference may cumulative like torch.repeat_interleave, so we use vanilla index here. 
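+ # The first closure interleaves each sample's video windows with (repeated) copies of
+ # its text tokens; the second undoes the interleave and averages the repeated text
+ # copies back to one set per sample (see unconcat_coalesce above).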
+ return ( + lambda vid, txt: torch.cat([vid, txt])[tgt_idx], + lambda all: unconcat_coalesce(all), + ) + + +def rearrange( + hid: torch.FloatTensor, # (L c) + hid_shape: torch.LongTensor, # (b n) + pattern: str, + **kwargs: Dict[str, int], +) -> Tuple[ + torch.FloatTensor, + torch.LongTensor, +]: + return flatten([einops.rearrange(h, pattern, **kwargs) for h in unflatten(hid, hid_shape)]) + + +def rearrange_idx( + hid_shape: torch.LongTensor, # (b n) + pattern: str, + **kwargs: Dict[str, int], +) -> Tuple[Callable, Callable, torch.LongTensor]: + hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) + tgt_idx, tgt_shape = rearrange(hid_idx, hid_shape, pattern, **kwargs) + tgt_idx = tgt_idx.squeeze(-1) + src_idx = torch.argsort(tgt_idx) + return ( + lambda hid: torch.index_select(hid, 0, tgt_idx), + lambda hid: torch.index_select(hid, 0, src_idx), + tgt_shape, + ) + + +def repeat( + hid: torch.FloatTensor, # (L c) + hid_shape: torch.LongTensor, # (b n) + pattern: str, + **kwargs: Dict[str, torch.LongTensor], # (b) +) -> Tuple[ + torch.FloatTensor, + torch.LongTensor, +]: + hid = unflatten(hid, hid_shape) + kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))] + return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)]) + + +def pack( + samples: List[torch.Tensor], # List of (h w c). +) -> Tuple[ + List[torch.Tensor], # groups [(b1 h1 w1 c1), (b2 h2 w2 c2)] + List[List[int]], # reversal indices. +]: + batches = {} + indices = {} + for i, sample in enumerate(samples): + shape = sample.shape + batches[shape] = batches.get(shape, []) + indices[shape] = indices.get(shape, []) + batches[shape].append(sample) + indices[shape].append(i) + + batches = list(map(torch.stack, batches.values())) + indices = list(indices.values()) + return batches, indices + + +def unpack( + batches: List[torch.Tensor], + indices: List[List[int]], +) -> List[torch.Tensor]: + samples = [None] * (max(chain(*indices)) + 1) + for batch, index in zip(batches, indices): + for sample, i in zip(batch.unbind(), index): + samples[i] = sample + return samples + + +def window( + hid: torch.FloatTensor, # (L c) + hid_shape: torch.LongTensor, # (b n) + window_fn: Callable[[torch.Tensor], List[torch.Tensor]], +): + hid = unflatten(hid, hid_shape) + hid = list(map(window_fn, hid)) + hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device) + hid, hid_shape = flatten(list(chain(*hid))) + return hid, hid_shape, hid_windows + + +def window_idx( + hid_shape: torch.LongTensor, # (b n) + window_fn: Callable[[torch.Tensor], List[torch.Tensor]], +): + hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) + tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn) + tgt_idx = tgt_idx.squeeze(-1) + src_idx = torch.argsort(tgt_idx) + return ( + lambda hid: torch.index_select(hid, 0, tgt_idx), + lambda hid: torch.index_select(hid, 0, src_idx), + tgt_shape, + tgt_windows, + ) diff --git a/models/dit/nablocks/__init__.py b/models/dit/nablocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..afa206db157786d9e4cf830bec09bd3a390bd9a8 --- /dev/null +++ b/models/dit/nablocks/__init__.py @@ -0,0 +1,25 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. 
+# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from .mmsr_block import NaMMSRTransformerBlock + +nadit_blocks = { + "mmdit_sr": NaMMSRTransformerBlock, +} + + +def get_nablock(block_type: str): + if block_type in nadit_blocks: + return nadit_blocks[block_type] + raise NotImplementedError(f"{block_type} is not supported") diff --git a/models/dit/nablocks/mmsr_block.py b/models/dit/nablocks/mmsr_block.py new file mode 100644 index 0000000000000000000000000000000000000000..b75652efc070188268bb84b35352b543e1a3746b --- /dev/null +++ b/models/dit/nablocks/mmsr_block.py @@ -0,0 +1,248 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import Tuple, Union +import torch +from einops import rearrange +from torch.nn import functional as F + +# from ..cache import Cache +from common.cache import Cache +from common.distributed.ops import gather_heads_scatter_seq, gather_seq_scatter_heads_qkv + +from .. 
import na +from ..attention import FlashAttentionVarlen +from ..blocks.mmdit_window_block import MMWindowAttention, MMWindowTransformerBlock +from ..mm import MMArg +from ..modulation import ada_layer_type +from ..normalization import norm_layer_type +from ..rope import NaRotaryEmbedding3d +from ..window import get_window_op + + +class NaSwinAttention(MMWindowAttention): + def __init__( + self, + vid_dim: int, + txt_dim: int, + heads: int, + head_dim: int, + qk_bias: bool, + qk_rope: bool, + qk_norm: norm_layer_type, + qk_norm_eps: float, + window: Union[int, Tuple[int, int, int]], + window_method: str, + shared_qkv: bool, + **kwargs, + ): + super().__init__( + vid_dim=vid_dim, + txt_dim=txt_dim, + heads=heads, + head_dim=head_dim, + qk_bias=qk_bias, + qk_rope=qk_rope, + qk_norm=qk_norm, + qk_norm_eps=qk_norm_eps, + window=window, + window_method=window_method, + shared_qkv=shared_qkv, + ) + self.rope = NaRotaryEmbedding3d(dim=head_dim // 2) if qk_rope else None + self.attn = FlashAttentionVarlen() + self.window_op = get_window_op(window_method) + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + + vid_qkv, txt_qkv = self.proj_qkv(vid, txt) + vid_qkv = gather_seq_scatter_heads_qkv( + vid_qkv, + seq_dim=0, + qkv_shape=vid_shape, + cache=cache.namespace("vid"), + ) + txt_qkv = gather_seq_scatter_heads_qkv( + txt_qkv, + seq_dim=0, + qkv_shape=txt_shape, + cache=cache.namespace("txt"), + ) + + # re-org the input seq for window attn + cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3") + + def make_window(x: torch.Tensor): + t, h, w, _ = x.shape + window_slices = self.window_op((t, h, w), self.window) + return [x[st, sh, sw] for (st, sh, sw) in window_slices] + + window_partition, window_reverse, window_shape, window_count = cache_win( + "win_transform", + lambda: na.window_idx(vid_shape, make_window), + ) + vid_qkv_win = window_partition(vid_qkv) + + vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim) + txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) + + vid_q, vid_k, vid_v = vid_qkv_win.unbind(1) + txt_q, txt_k, txt_v = txt_qkv.unbind(1) + + vid_q, txt_q = self.norm_q(vid_q, txt_q) + vid_k, txt_k = self.norm_k(vid_k, txt_k) + + txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) + + vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) + txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) + all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) + concat_win, unconcat_win = cache_win( + "mm_pnp", lambda: na.repeat_concat_idx(vid_len_win, txt_len, window_count) + ) + + # window rope + if self.rope: + vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) + + out = self.attn( + q=concat_win(vid_q, txt_q).bfloat16(), + k=concat_win(vid_k, txt_k).bfloat16(), + v=concat_win(vid_v, txt_v).bfloat16(), + cu_seqlens_q=cache_win( + "vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() + ), + cu_seqlens_k=cache_win( + "vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() + ), + max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()), + max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()), + ).type_as(vid_q) + + # text pooling + vid_out, txt_out = unconcat_win(out) + + vid_out = rearrange(vid_out, "l h 
d -> l (h d)") + txt_out = rearrange(txt_out, "l h d -> l (h d)") + vid_out = window_reverse(vid_out) + + vid_out = gather_heads_scatter_seq(vid_out, head_dim=1, seq_dim=0) + txt_out = gather_heads_scatter_seq(txt_out, head_dim=1, seq_dim=0) + + vid_out, txt_out = self.proj_out(vid_out, txt_out) + + return vid_out, txt_out + + +class NaMMSRTransformerBlock(MMWindowTransformerBlock): + def __init__( + self, + *, + vid_dim: int, + txt_dim: int, + emb_dim: int, + heads: int, + head_dim: int, + expand_ratio: int, + norm: norm_layer_type, + norm_eps: float, + ada: ada_layer_type, + qk_bias: bool, + qk_rope: bool, + qk_norm: norm_layer_type, + shared_qkv: bool, + shared_mlp: bool, + mlp_type: str, + **kwargs, + ): + super().__init__( + vid_dim=vid_dim, + txt_dim=txt_dim, + emb_dim=emb_dim, + heads=heads, + head_dim=head_dim, + expand_ratio=expand_ratio, + norm=norm, + norm_eps=norm_eps, + ada=ada, + qk_bias=qk_bias, + qk_rope=qk_rope, + qk_norm=qk_norm, + shared_qkv=shared_qkv, + shared_mlp=shared_mlp, + mlp_type=mlp_type, + **kwargs, + ) + + self.attn = NaSwinAttention( + vid_dim=vid_dim, + txt_dim=txt_dim, + heads=heads, + head_dim=head_dim, + qk_bias=qk_bias, + qk_rope=qk_rope, + qk_norm=qk_norm, + qk_norm_eps=norm_eps, + shared_qkv=shared_qkv, + **kwargs, + ) + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + emb: torch.FloatTensor, + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.LongTensor, + torch.LongTensor, + ]: + hid_len = MMArg( + cache("vid_len", lambda: vid_shape.prod(-1)), + cache("txt_len", lambda: txt_shape.prod(-1)), + ) + ada_kwargs = { + "emb": emb, + "hid_len": hid_len, + "cache": cache, + "branch_tag": MMArg("vid", "txt"), + } + + vid_attn, txt_attn = self.attn_norm(vid, txt) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs) + vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs) + vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) + + vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs) + vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs) + vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) + + return vid_mlp, txt_mlp, vid_shape, txt_shape diff --git a/models/dit/nadit.py b/models/dit/nadit.py new file mode 100644 index 0000000000000000000000000000000000000000..7e778236db6a70f49a364db6e84bf7539c0b58ac --- /dev/null +++ b/models/dit/nadit.py @@ -0,0 +1,350 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
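An illustrative sketch (assuming the models.dit package is importable) of the native-resolution packing that the blocks above and the NaDiT model below consume: samples of different sizes are flattened into a single (L, c) token tensor plus a per-sample shape tensor, and na.unflatten restores the original list. The shapes below are made up.

import torch
from models.dit import na

latents = [
    torch.rand(4, 30, 52, 16),    # (T, H, W, C) latent of one video
    torch.rand(1, 64, 64, 16),    # an image latent treated as a 1-frame video
]
vid, vid_shape = na.flatten(latents)
print(vid.shape)                  # torch.Size([10336, 16])  (4*30*52 + 1*64*64 tokens)
print(vid_shape)                  # tensor([[ 4, 30, 52], [ 1, 64, 64]])

restored = na.unflatten(vid, vid_shape)
print([tuple(x.shape) for x in restored])   # [(4, 30, 52, 16), (1, 64, 64, 16)]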
+ +from dataclasses import dataclass +from typing import Optional, Tuple, Union, Callable +import torch +from torch import nn + +from common.cache import Cache +from common.distributed.ops import slice_inputs + +from . import na +from .embedding import TimeEmbedding +from .modulation import get_ada_layer +from .nablocks import get_nablock +from .normalization import get_norm_layer +from .patch import NaPatchIn, NaPatchOut + +# Fake func, no checkpointing is required for inference +def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs): + return module(*args, **kwargs) + +@dataclass +class NaDiTOutput: + vid_sample: torch.Tensor + + +class NaDiT(nn.Module): + """ + Native Resolution Diffusion Transformer (NaDiT) + """ + + gradient_checkpointing = False + + def __init__( + self, + vid_in_channels: int, + vid_out_channels: int, + vid_dim: int, + txt_in_dim: Optional[int], + txt_dim: Optional[int], + emb_dim: int, + heads: int, + head_dim: int, + expand_ratio: int, + norm: Optional[str], + norm_eps: float, + ada: str, + qk_bias: bool, + qk_rope: bool, + qk_norm: Optional[str], + patch_size: Union[int, Tuple[int, int, int]], + num_layers: int, + block_type: Union[str, Tuple[str]], + shared_qkv: bool = False, + shared_mlp: bool = False, + mlp_type: str = "normal", + window: Optional[Tuple] = None, + window_method: Optional[Tuple[str]] = None, + temporal_window_size: int = None, + temporal_shifted: bool = False, + **kwargs, + ): + ada = get_ada_layer(ada) + norm = get_norm_layer(norm) + qk_norm = get_norm_layer(qk_norm) + if isinstance(block_type, str): + block_type = [block_type] * num_layers + elif len(block_type) != num_layers: + raise ValueError("The ``block_type`` list should equal to ``num_layers``.") + super().__init__() + self.vid_in = NaPatchIn( + in_channels=vid_in_channels, + patch_size=patch_size, + dim=vid_dim, + ) + self.txt_in = ( + nn.Linear(txt_in_dim, txt_dim) + if txt_in_dim and txt_in_dim != txt_dim + else nn.Identity() + ) + self.emb_in = TimeEmbedding( + sinusoidal_dim=256, + hidden_dim=max(vid_dim, txt_dim), + output_dim=emb_dim, + ) + + if window is None or isinstance(window[0], int): + window = [window] * num_layers + if window_method is None or isinstance(window_method, str): + window_method = [window_method] * num_layers + if temporal_window_size is None or isinstance(temporal_window_size, int): + temporal_window_size = [temporal_window_size] * num_layers + if temporal_shifted is None or isinstance(temporal_shifted, bool): + temporal_shifted = [temporal_shifted] * num_layers + + self.blocks = nn.ModuleList( + [ + get_nablock(block_type[i])( + vid_dim=vid_dim, + txt_dim=txt_dim, + emb_dim=emb_dim, + heads=heads, + head_dim=head_dim, + expand_ratio=expand_ratio, + norm=norm, + norm_eps=norm_eps, + ada=ada, + qk_bias=qk_bias, + qk_rope=qk_rope, + qk_norm=qk_norm, + shared_qkv=shared_qkv, + shared_mlp=shared_mlp, + mlp_type=mlp_type, + window=window[i], + window_method=window_method[i], + temporal_window_size=temporal_window_size[i], + temporal_shifted=temporal_shifted[i], + **kwargs, + ) + for i in range(num_layers) + ] + ) + self.vid_out = NaPatchOut( + out_channels=vid_out_channels, + patch_size=patch_size, + dim=vid_dim, + ) + + self.need_txt_repeat = block_type[0] in [ + "mmdit_stwin", + "mmdit_stwin_spatial", + "mmdit_stwin_3d_spatial", + ] + + def set_gradient_checkpointing(self, enable: bool): + self.gradient_checkpointing = enable + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + 
vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b + disable_cache: bool = True, # for test + ): + # Text input. + if txt_shape.size(-1) == 1 and self.need_txt_repeat: + txt, txt_shape = na.repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) + # slice vid after patching in when using sequence parallelism + txt = slice_inputs(txt, dim=0) + txt = self.txt_in(txt) + + # Video input. + # Sequence parallel slicing is done inside patching class. + vid, vid_shape = self.vid_in(vid, vid_shape) + + # Embedding input. + emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) + + # Body + cache = Cache(disable=disable_cache) + for i, block in enumerate(self.blocks): + vid, txt, vid_shape, txt_shape = gradient_checkpointing( + enabled=(self.gradient_checkpointing and self.training), + module=block, + vid=vid, + txt=txt, + vid_shape=vid_shape, + txt_shape=txt_shape, + emb=emb, + cache=cache, + ) + + vid, vid_shape = self.vid_out(vid, vid_shape, cache) + return NaDiTOutput(vid_sample=vid) + + +class NaDiTUpscaler(nn.Module): + """ + Native Resolution Diffusion Transformer (NaDiT) + """ + + gradient_checkpointing = False + + def __init__( + self, + vid_in_channels: int, + vid_out_channels: int, + vid_dim: int, + txt_in_dim: Optional[int], + txt_dim: Optional[int], + emb_dim: int, + heads: int, + head_dim: int, + expand_ratio: int, + norm: Optional[str], + norm_eps: float, + ada: str, + qk_bias: bool, + qk_rope: bool, + qk_norm: Optional[str], + patch_size: Union[int, Tuple[int, int, int]], + num_layers: int, + block_type: Union[str, Tuple[str]], + shared_qkv: bool = False, + shared_mlp: bool = False, + mlp_type: str = "normal", + window: Optional[Tuple] = None, + window_method: Optional[Tuple[str]] = None, + temporal_window_size: int = None, + temporal_shifted: bool = False, + **kwargs, + ): + ada = get_ada_layer(ada) + norm = get_norm_layer(norm) + qk_norm = get_norm_layer(qk_norm) + if isinstance(block_type, str): + block_type = [block_type] * num_layers + elif len(block_type) != num_layers: + raise ValueError("The ``block_type`` list should equal to ``num_layers``.") + super().__init__() + self.vid_in = NaPatchIn( + in_channels=vid_in_channels, + patch_size=patch_size, + dim=vid_dim, + ) + self.txt_in = ( + nn.Linear(txt_in_dim, txt_dim) + if txt_in_dim and txt_in_dim != txt_dim + else nn.Identity() + ) + self.emb_in = TimeEmbedding( + sinusoidal_dim=256, + hidden_dim=max(vid_dim, txt_dim), + output_dim=emb_dim, + ) + + self.emb_scale = TimeEmbedding( + sinusoidal_dim=256, + hidden_dim=max(vid_dim, txt_dim), + output_dim=emb_dim, + ) + + if window is None or isinstance(window[0], int): + window = [window] * num_layers + if window_method is None or isinstance(window_method, str): + window_method = [window_method] * num_layers + if temporal_window_size is None or isinstance(temporal_window_size, int): + temporal_window_size = [temporal_window_size] * num_layers + if temporal_shifted is None or isinstance(temporal_shifted, bool): + temporal_shifted = [temporal_shifted] * num_layers + + self.blocks = nn.ModuleList( + [ + get_nablock(block_type[i])( + vid_dim=vid_dim, + txt_dim=txt_dim, + emb_dim=emb_dim, + heads=heads, + head_dim=head_dim, + expand_ratio=expand_ratio, + norm=norm, + norm_eps=norm_eps, + ada=ada, + qk_bias=qk_bias, + qk_rope=qk_rope, + qk_norm=qk_norm, + shared_qkv=shared_qkv, + shared_mlp=shared_mlp, + mlp_type=mlp_type, + window=window[i], + window_method=window_method[i], + 
temporal_window_size=temporal_window_size[i], + temporal_shifted=temporal_shifted[i], + **kwargs, + ) + for i in range(num_layers) + ] + ) + self.vid_out = NaPatchOut( + out_channels=vid_out_channels, + patch_size=patch_size, + dim=vid_dim, + ) + + self.need_txt_repeat = block_type[0] in [ + "mmdit_stwin", + "mmdit_stwin_spatial", + "mmdit_stwin_3d_spatial", + ] + + def set_gradient_checkpointing(self, enable: bool): + self.gradient_checkpointing = enable + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b + downscale: Union[int, float, torch.IntTensor, torch.FloatTensor], # b + disable_cache: bool = False, # for test + ): + + # Text input. + if txt_shape.size(-1) == 1 and self.need_txt_repeat: + txt, txt_shape = na.repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) + # slice vid after patching in when using sequence parallelism + txt = slice_inputs(txt, dim=0) + txt = self.txt_in(txt) + + # Video input. + # Sequence parallel slicing is done inside patching class. + vid, vid_shape = self.vid_in(vid, vid_shape) + + # Embedding input. + emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) + emb_scale = self.emb_scale(downscale, device=vid.device, dtype=vid.dtype) + emb = emb + emb_scale + + # Body + cache = Cache(disable=disable_cache) + for i, block in enumerate(self.blocks): + vid, txt, vid_shape, txt_shape = gradient_checkpointing( + enabled=(self.gradient_checkpointing and self.training), + module=block, + vid=vid, + txt=txt, + vid_shape=vid_shape, + txt_shape=txt_shape, + emb=emb, + cache=cache, + ) + + vid, vid_shape = self.vid_out(vid, vid_shape, cache) + return NaDiTOutput(vid_sample=vid) diff --git a/models/dit/normalization.py b/models/dit/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..98827a9c71f9fd6e461937774d022b68844aee34 --- /dev/null +++ b/models/dit/normalization.py @@ -0,0 +1,63 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
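NaDiT.forward above consumes packed, variable-resolution inputs rather than padded batches: vid is a single (L, c) token sequence and vid_shape records each sample's (T, H, W). Here is a small sketch of that packing convention; the tensor sizes are arbitrary, and the flattening is re-implemented inline rather than importing the repo's na helpers.

import torch

# Two videos with different native resolutions (T, H, W, C), packed into one (L, C)
# sequence plus a per-sample shape tensor, matching the (vid, vid_shape) convention
# of NaDiT.forward above.
samples = [torch.randn(4, 30, 52, 16), torch.randn(2, 45, 80, 16)]
vid = torch.cat([s.flatten(0, -2) for s in samples])        # (sum(T*H*W), C)
vid_shape = torch.tensor([s.shape[:-1] for s in samples])   # (b, 3)

assert vid.shape[0] == vid_shape.prod(-1).sum()

NaDiTUpscaler follows the same convention and simply adds a second embedding, emb_in(timestep) + emb_scale(downscale), before running the blocks.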
+ +from typing import Callable, Optional +from diffusers.models.normalization import RMSNorm +from torch import nn + +# (dim: int, eps: float, elementwise_affine: bool) +norm_layer_type = Callable[[int, float, bool], nn.Module] + + +def get_norm_layer(norm_type: Optional[str]) -> norm_layer_type: + + def _norm_layer(dim: int, eps: float, elementwise_affine: bool): + if norm_type is None: + return nn.Identity() + + if norm_type == "layer": + return nn.LayerNorm( + normalized_shape=dim, + eps=eps, + elementwise_affine=elementwise_affine, + ) + + if norm_type == "rms": + return RMSNorm( + dim=dim, + eps=eps, + elementwise_affine=elementwise_affine, + ) + + if norm_type == "fusedln": + from apex.normalization import FusedLayerNorm + + return FusedLayerNorm( + normalized_shape=dim, + elementwise_affine=elementwise_affine, + eps=eps, + ) + + if norm_type == "fusedrms": + from apex.normalization import FusedRMSNorm + + return FusedRMSNorm( + normalized_shape=dim, + elementwise_affine=elementwise_affine, + eps=eps, + ) + + raise NotImplementedError(f"{norm_type} is not supported") + + return _norm_layer diff --git a/models/dit/patch.py b/models/dit/patch.py new file mode 100644 index 0000000000000000000000000000000000000000..d98158e34a94e0447ed82b92fbfa289bf1a2be1d --- /dev/null +++ b/models/dit/patch.py @@ -0,0 +1,112 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import Tuple, Union +import torch +from einops import rearrange +from torch import nn +from torch.nn.modules.utils import _triple + +from common.cache import Cache +from common.distributed.ops import gather_outputs, slice_inputs + +from . 
import na + + +class PatchIn(nn.Module): + def __init__( + self, + in_channels: int, + patch_size: Union[int, Tuple[int, int, int]], + dim: int, + ): + super().__init__() + t, h, w = _triple(patch_size) + self.patch_size = t, h, w + self.proj = nn.Linear(in_channels * t * h * w, dim) + + def forward( + self, + vid: torch.Tensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) + vid = self.proj(vid) + return vid + + +class PatchOut(nn.Module): + def __init__( + self, + out_channels: int, + patch_size: Union[int, Tuple[int, int, int]], + dim: int, + ): + super().__init__() + t, h, w = _triple(patch_size) + self.patch_size = t, h, w + self.proj = nn.Linear(dim, out_channels * t * h * w) + + def forward( + self, + vid: torch.Tensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + vid = self.proj(vid) + vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) + return vid + + +class NaPatchIn(PatchIn): + def forward( + self, + vid: torch.Tensor, # l c + vid_shape: torch.LongTensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + if not (t == h == w == 1): + vid, vid_shape = na.rearrange( + vid, vid_shape, "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w + ) + # slice vid after patching in when using sequence parallelism + vid = slice_inputs(vid, dim=0) + vid = self.proj(vid) + return vid, vid_shape + + +class NaPatchOut(PatchOut): + def forward( + self, + vid: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, + cache: Cache = Cache(disable=True), + ) -> Tuple[ + torch.FloatTensor, + torch.LongTensor, + ]: + t, h, w = self.patch_size + vid = self.proj(vid) + # gather vid before patching out when enabling sequence parallelism + vid = gather_outputs( + vid, + gather_dim=0, + padding_dim=0, + unpad_shape=vid_shape, + cache=cache.namespace("vid"), + ) + if not (t == h == w == 1): + vid, vid_shape = na.rearrange( + vid, vid_shape, "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w + ) + return vid, vid_shape diff --git a/models/dit/rope.py b/models/dit/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..32a4815a1b349001cb86ea6d752fb4f91f6e655e --- /dev/null +++ b/models/dit/rope.py @@ -0,0 +1,101 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from functools import lru_cache +from typing import Tuple +import torch +from einops import rearrange +from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb +from torch import nn + +from common.cache import Cache + + +class RotaryEmbeddingBase(nn.Module): + def __init__(self, dim: int, rope_dim: int): + super().__init__() + self.rope = RotaryEmbedding( + dim=dim // rope_dim, + freqs_for="pixel", + max_freq=256, + ) + # 1. Set model.requires_grad_(True) after model creation will make + # the `requires_grad=False` for rope freqs no longer hold. + # 2. 
Even if we don't set requires_grad_(True) explicitly, + # FSDP is not memory efficient when handling fsdp_wrap + # with mixed requires_grad=True/False. + # With above consideration, it is easier just remove the freqs + # out of nn.Parameters when `learned_freq=False` + freqs = self.rope.freqs + del self.rope.freqs + self.rope.register_buffer("freqs", freqs.data) + + @lru_cache(maxsize=128) + def get_axial_freqs(self, *dims): + return self.rope.get_axial_freqs(*dims) + + +class RotaryEmbedding3d(RotaryEmbeddingBase): + def __init__(self, dim: int): + super().__init__(dim, rope_dim=3) + + def forward( + self, + q: torch.FloatTensor, # b h l d + k: torch.FloatTensor, # b h l d + size: Tuple[int, int, int], + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + T, H, W = size + freqs = self.get_axial_freqs(T, H, W) + q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) + k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) + q = apply_rotary_emb(freqs, q) + k = apply_rotary_emb(freqs, k) + q = rearrange(q, "b h T H W d -> b h (T H W) d") + k = rearrange(k, "b h T H W d -> b h (T H W) d") + return q, k + + +class NaRotaryEmbedding3d(RotaryEmbedding3d): + def forward( + self, + q: torch.FloatTensor, # L h d + k: torch.FloatTensor, # L h d + shape: torch.LongTensor, + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape)) + q = rearrange(q, "L h d -> h L d") + k = rearrange(k, "L h d -> h L d") + q = apply_rotary_emb(freqs, q.float()).to(q.dtype) + k = apply_rotary_emb(freqs, k.float()).to(k.dtype) + q = rearrange(q, "h L d -> L h d") + k = rearrange(k, "h L d -> L h d") + return q, k + + def get_freqs( + self, + shape: torch.LongTensor, + ) -> torch.Tensor: + freq_list = [] + for f, h, w in shape.tolist(): + freqs = self.get_axial_freqs(f, h, w) + freq_list.append(freqs.view(-1, freqs.size(-1))) + return torch.cat(freq_list, dim=0) diff --git a/models/dit/window.py b/models/dit/window.py new file mode 100644 index 0000000000000000000000000000000000000000..b7475921ae283cf76d82bff7521233c133f54bfd --- /dev/null +++ b/models/dit/window.py @@ -0,0 +1,83 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
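As a quick sanity check of the patchify step in (Na)PatchIn above, here is the per-sample rearrange in isolation; sizes are arbitrary, and the packed-sequence handling and sequence-parallel slicing are omitted.

import torch
from einops import rearrange

# Per-sample view of the patchify step: a (T*t, H*h, W*w, c) latent grid
# becomes (T, H, W, t*h*w*c) tokens, as in the Na* channels-last path above.
t, h, w = 1, 2, 2                   # patch_size
vid = torch.randn(4, 8, 8, 16)      # one sample
tokens = rearrange(vid, "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w)
assert tokens.shape == (4, 4, 4, 1 * 2 * 2 * 16)

NaPatchOut inverts the same pattern after projecting back to out_channels * t * h * w.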
+ +from math import ceil +from typing import Tuple +import math + +def get_window_op(name: str): + if name == "720pwin_by_size_bysize": + return make_720Pwindows_bysize + if name == "720pswin_by_size_bysize": + return make_shifted_720Pwindows_bysize + raise ValueError(f"Unknown windowing method: {name}") + + +# -------------------------------- Windowing -------------------------------- # +def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): + t, h, w = size + resized_nt, resized_nh, resized_nw = num_windows + #cal windows under 720p + scale = math.sqrt((45 * 80) / (h * w)) + resized_h, resized_w = round(h * scale), round(w * scale) + wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. + wt = ceil(min(t, 30) / resized_nt) # window size. + nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. + return [ + ( + slice(it * wt, min((it + 1) * wt, t)), + slice(ih * wh, min((ih + 1) * wh, h)), + slice(iw * ww, min((iw + 1) * ww, w)), + ) + for iw in range(nw) + if min((iw + 1) * ww, w) > iw * ww + for ih in range(nh) + if min((ih + 1) * wh, h) > ih * wh + for it in range(nt) + if min((it + 1) * wt, t) > it * wt + ] + +def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): + t, h, w = size + resized_nt, resized_nh, resized_nw = num_windows + #cal windows under 720p + scale = math.sqrt((45 * 80) / (h * w)) + resized_h, resized_w = round(h * scale), round(w * scale) + wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. + wt = ceil(min(t, 30) / resized_nt) # window size. + + st, sh, sw = ( # shift size. + 0.5 if wt < t else 0, + 0.5 if wh < h else 0, + 0.5 if ww < w else 0, + ) + nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size. + nt, nh, nw = ( # number of window. + nt + 1 if st > 0 else 1, + nh + 1 if sh > 0 else 1, + nw + 1 if sw > 0 else 1, + ) + return [ + ( + slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)), + slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)), + slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)), + ) + for iw in range(nw) + if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0) + for ih in range(nh) + if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0) + for it in range(nt) + if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0) + ] diff --git a/models/dit_v2/attention.py b/models/dit_v2/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..9201fe095778db21ebd3384d163b0ccac4b35664 --- /dev/null +++ b/models/dit_v2/attention.py @@ -0,0 +1,46 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
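To make the window arithmetic above concrete, the following re-derives (without importing the module) what make_720Pwindows_bysize computes for a latent grid that is already at the 45x80 reference size, so the 720p rescale is a no-op.

from math import ceil, sqrt

# Worked example of the 720p-relative tiling in make_720Pwindows_bysize above.
t, h, w = 16, 45, 80        # latent (T, H, W)
nt, nh, nw = 2, 3, 4        # requested number of windows per axis
scale = sqrt((45 * 80) / (h * w))                                  # 1.0 here
wh, ww = ceil(round(h * scale) / nh), ceil(round(w * scale) / nw)  # 15, 20
wt = ceil(min(t, 30) / nt)                                         # 8
n_windows = ceil(t / wt) * ceil(h / wh) * ceil(w / ww)
assert (wt, wh, ww) == (8, 15, 20) and n_windows == 24

The shifted variant offsets the grid by half a window along every axis that is actually partitioned (window smaller than the axis), which is why it can return one extra window per shifted axis.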
+ +import torch +import torch.nn.functional as F + +from flash_attn import flash_attn_varlen_func + +from torch import nn + +class TorchAttention(nn.Module): + def tflops(self, args, kwargs, output) -> float: + assert len(args) == 0 or len(args) > 2, "query, key should both provided by args / kwargs" + q = kwargs.get("query") or args[0] + k = kwargs.get("key") or args[1] + b, h, sq, d = q.shape + b, h, sk, d = k.shape + return b * h * (4 * d * (sq / 1e6) * (sk / 1e6)) + + def forward(self, *args, **kwargs): + return F.scaled_dot_product_attention(*args, **kwargs) + + +class FlashAttentionVarlen(nn.Module): + def tflops(self, args, kwargs, output) -> float: + cu_seqlens_q = kwargs["cu_seqlens_q"] + cu_seqlens_k = kwargs["cu_seqlens_k"] + _, h, d = output.shape + seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]) / 1e6 + seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) / 1e6 + return h * (4 * d * (seqlens_q * seqlens_k).sum()) + + def forward(self, *args, **kwargs): + kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled() + return flash_attn_varlen_func(*args, **kwargs) diff --git a/models/dit_v2/embedding.py b/models/dit_v2/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..e972244f5767c9f34e5e77bb180ae720ce88b89c --- /dev/null +++ b/models/dit_v2/embedding.py @@ -0,0 +1,62 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import Optional, Union +import torch +from diffusers.models.embeddings import get_timestep_embedding +from torch import nn + + +def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]): + return emb1 if emb2 is None else emb1 + emb2 + + +class TimeEmbedding(nn.Module): + def __init__( + self, + sinusoidal_dim: int, + hidden_dim: int, + output_dim: int, + ): + super().__init__() + self.sinusoidal_dim = sinusoidal_dim + self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim) + self.proj_hid = nn.Linear(hidden_dim, hidden_dim) + self.proj_out = nn.Linear(hidden_dim, output_dim) + self.act = nn.SiLU() + + def forward( + self, + timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], + device: torch.device, + dtype: torch.dtype, + ) -> torch.FloatTensor: + if not torch.is_tensor(timestep): + timestep = torch.tensor([timestep], device=device, dtype=dtype) + if timestep.ndim == 0: + timestep = timestep[None] + + emb = get_timestep_embedding( + timesteps=timestep, + embedding_dim=self.sinusoidal_dim, + flip_sin_to_cos=False, + downscale_freq_shift=0, + ) + emb = emb.to(dtype) + emb = self.proj_in(emb) + emb = self.act(emb) + emb = self.proj_hid(emb) + emb = self.act(emb) + emb = self.proj_out(emb) + return emb diff --git a/models/dit_v2/mlp.py b/models/dit_v2/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..2d05cb021f3e3c6ac05c0e7ae1aa8a6d29475b87 --- /dev/null +++ b/models/dit_v2/mlp.py @@ -0,0 +1,62 @@ +# // Copyright (c) 2025 Bytedance Ltd. 
and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import Optional +import torch +import torch.nn.functional as F +from torch import nn + + +def get_mlp(mlp_type: Optional[str] = "normal"): + if mlp_type == "normal": + return MLP + elif mlp_type == "swiglu": + return SwiGLUMLP + + +class MLP(nn.Module): + def __init__( + self, + dim: int, + expand_ratio: int, + ): + super().__init__() + self.proj_in = nn.Linear(dim, dim * expand_ratio) + self.act = nn.GELU("tanh") + self.proj_out = nn.Linear(dim * expand_ratio, dim) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + x = self.proj_in(x) + x = self.act(x) + x = self.proj_out(x) + return x + + +class SwiGLUMLP(nn.Module): + def __init__( + self, + dim: int, + expand_ratio: int, + multiple_of: int = 256, + ): + super().__init__() + hidden_dim = int(2 * dim * expand_ratio / 3) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False) + self.proj_out = nn.Linear(hidden_dim, dim, bias=False) + self.proj_in = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) + return x diff --git a/models/dit_v2/mm.py b/models/dit_v2/mm.py new file mode 100644 index 0000000000000000000000000000000000000000..344f89a8fa22b9a5473b8d25f208085a630f0c85 --- /dev/null +++ b/models/dit_v2/mm.py @@ -0,0 +1,74 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
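FlashAttentionVarlen above is driven by cumulative sequence lengths rather than attention masks. The snippet below shows the F.pad(cumsum) construction that the attention blocks in this diff use to turn per-sample packed lengths into cu_seqlens; the lengths are arbitrary.

import torch
import torch.nn.functional as F

# Per-sample lengths of the packed sequence (e.g. vid+txt tokens per sample)
all_len = torch.tensor([3, 5, 2])
cu_seqlens = F.pad(all_len.cumsum(0), (1, 0)).int()    # tensor([ 0,  3,  8, 10], dtype=int32)
max_seqlen = all_len.max().item()                      # 5
assert cu_seqlens.tolist() == [0, 3, 8, 10] and max_seqlen == 5

flash_attn_varlen_func then treats each [cu_seqlens[i], cu_seqlens[i+1]) slice of the packed q/k/v as an independent sequence.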
+ +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Tuple +import torch +from torch import nn + + +@dataclass +class MMArg: + vid: Any + txt: Any + + +def get_args(key: str, args: List[Any]) -> List[Any]: + return [getattr(v, key) if isinstance(v, MMArg) else v for v in args] + + +def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()} + + +class MMModule(nn.Module): + def __init__( + self, + module: Callable[..., nn.Module], + *args, + shared_weights: bool = False, + vid_only: bool = False, + **kwargs, + ): + super().__init__() + self.shared_weights = shared_weights + self.vid_only = vid_only + if self.shared_weights: + assert get_args("vid", args) == get_args("txt", args) + assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) + self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) + else: + self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) + self.txt = ( + module(*get_args("txt", args), **get_kwargs("txt", kwargs)) + if not vid_only + else None + ) + + def forward( + self, + vid: torch.FloatTensor, + txt: torch.FloatTensor, + *args, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_module = self.vid if not self.shared_weights else self.all + vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs)) + if not self.vid_only: + txt_module = self.txt if not self.shared_weights else self.all + txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs)) + return vid, txt diff --git a/models/dit_v2/modulation.py b/models/dit_v2/modulation.py new file mode 100644 index 0000000000000000000000000000000000000000..9e14bb005ef2d0a2c7205f593c483e8862a42858 --- /dev/null +++ b/models/dit_v2/modulation.py @@ -0,0 +1,102 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import Callable, List, Optional +import torch +from einops import rearrange +from torch import nn + +from common.cache import Cache +from common.distributed.ops import slice_inputs + +# (dim: int, emb_dim: int) +ada_layer_type = Callable[[int, int], nn.Module] + + +def get_ada_layer(ada_layer: str) -> ada_layer_type: + if ada_layer == "single": + return AdaSingle + raise NotImplementedError(f"{ada_layer} is not supported") + + +def expand_dims(x: torch.Tensor, dim: int, ndim: int): + """ + Expand tensor "x" to "ndim" by adding empty dims at "dim". + Example: x is (b d), target ndim is 5, add dim at 1, return (b 1 1 1 d). 
+ """ + shape = x.shape + shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:] + return x.reshape(shape) + + +class AdaSingle(nn.Module): + def __init__( + self, + dim: int, + emb_dim: int, + layers: List[str], + modes: List[str] = ["in", "out"], + ): + assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" + super().__init__() + self.dim = dim + self.emb_dim = emb_dim + self.layers = layers + for l in layers: + if "in" in modes: + self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5)) + self.register_parameter( + f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1) + ) + if "out" in modes: + self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5)) + + def forward( + self, + hid: torch.FloatTensor, # b ... c + emb: torch.FloatTensor, # b d + layer: str, + mode: str, + cache: Cache = Cache(disable=True), + branch_tag: str = "", + hid_len: Optional[torch.LongTensor] = None, # b + ) -> torch.FloatTensor: + idx = self.layers.index(layer) + emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] + emb = expand_dims(emb, 1, hid.ndim + 1) + + if hid_len is not None: + emb = cache( + f"emb_repeat_{idx}_{branch_tag}", + lambda: slice_inputs( + torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]), + dim=0, + ), + ) + + shiftA, scaleA, gateA = emb.unbind(-1) + shiftB, scaleB, gateB = ( + getattr(self, f"{layer}_shift", None), + getattr(self, f"{layer}_scale", None), + getattr(self, f"{layer}_gate", None), + ) + + if mode == "in": + return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) + if mode == "out": + return hid.mul_(gateA + gateB) + raise NotImplementedError + + def extra_repr(self) -> str: + return f"dim={self.dim}, emb_dim={self.emb_dim}, layers={self.layers}" \ No newline at end of file diff --git a/models/dit_v2/na.py b/models/dit_v2/na.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbd546c4705b3b9c7c19a9823f9d113a0447616 --- /dev/null +++ b/models/dit_v2/na.py @@ -0,0 +1,241 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from itertools import chain +from typing import Callable, Dict, List, Tuple +import einops +import torch + + +def flatten( + hid: List[torch.FloatTensor], # List of (*** c) +) -> Tuple[ + torch.FloatTensor, # (L c) + torch.LongTensor, # (b n) +]: + assert len(hid) > 0 + shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid]) + hid = torch.cat([x.flatten(0, -2) for x in hid]) + return hid, shape + + +def unflatten( + hid: torch.FloatTensor, # (L c) or (L ... c) + hid_shape: torch.LongTensor, # (b n) +) -> List[torch.Tensor]: # List of (*** c) or (*** ... c) + hid_len = hid_shape.prod(-1) + hid = hid.split(hid_len.tolist()) + hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)] + return hid + + +def concat( + vid: torch.FloatTensor, # (VL ... c) + txt: torch.FloatTensor, # (TL ... 
c) + vid_len: torch.LongTensor, # (b) + txt_len: torch.LongTensor, # (b) +) -> torch.FloatTensor: # (L ... c) + vid = torch.split(vid, vid_len.tolist()) + txt = torch.split(txt, txt_len.tolist()) + return torch.cat(list(chain(*zip(vid, txt)))) + + +def concat_idx( + vid_len: torch.LongTensor, # (b) + txt_len: torch.LongTensor, # (b) +) -> Tuple[ + Callable, + Callable, +]: + device = vid_len.device + vid_idx = torch.arange(vid_len.sum(), device=device) + txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) + tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len) + src_idx = torch.argsort(tgt_idx) + return ( + lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx), + lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]), + ) + + +def unconcat( + all: torch.FloatTensor, # (L ... c) + vid_len: torch.LongTensor, # (b) + txt_len: torch.LongTensor, # (b) +) -> Tuple[ + torch.FloatTensor, # (VL ... c) + torch.FloatTensor, # (TL ... c) +]: + interleave_len = list(chain(*zip(vid_len.tolist(), txt_len.tolist()))) + all = all.split(interleave_len) + vid = torch.cat(all[0::2]) + txt = torch.cat(all[1::2]) + return vid, txt + + +def repeat_concat( + vid: torch.FloatTensor, # (VL ... c) + txt: torch.FloatTensor, # (TL ... c) + vid_len: torch.LongTensor, # (n*b) + txt_len: torch.LongTensor, # (b) + txt_repeat: List, # (n) +) -> torch.FloatTensor: # (L ... c) + vid = torch.split(vid, vid_len.tolist()) + txt = torch.split(txt, txt_len.tolist()) + txt = [[x] * n for x, n in zip(txt, txt_repeat)] + txt = list(chain(*txt)) + return torch.cat(list(chain(*zip(vid, txt)))) + + +def repeat_concat_idx( + vid_len: torch.LongTensor, # (n*b) + txt_len: torch.LongTensor, # (b) + txt_repeat: torch.LongTensor, # (n) +) -> Tuple[ + Callable, + Callable, +]: + device = vid_len.device + vid_idx = torch.arange(vid_len.sum(), device=device) + txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) + txt_repeat_list = txt_repeat.tolist() + tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat) + src_idx = torch.argsort(tgt_idx) + txt_idx_len = len(tgt_idx) - len(vid_idx) + repeat_txt_len = (txt_len * txt_repeat).tolist() + + def unconcat_coalesce(all): + """ + Un-concat vid & txt, and coalesce the repeated txt. + e.g. vid [0 1 2 3 4 5 6 7 8] -> 3 splits -> [0 1 2] [3 4 5] [6 7 8] + txt [9 10] + repeat_concat ==> [0 1 2 9 10 3 4 5 9 10 6 7 8 9 10] + 1. argsort re-index ==> [0 1 2 3 4 5 6 7 8 9 9 9 10 10 10] + split ==> vid_out [0 1 2 3 4 5 6 7 8] txt_out [9 9 9 10 10 10] + 2. reshape & mean for each sample to coalesce the repeated txt. + """ + vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len]) + txt_out_coalesced = [] + for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list): + txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1) + txt_out_coalesced.append(txt) + return vid_out, torch.cat(txt_out_coalesced) + + # Note: Backward of torch.index_select is non-deterministic when existing repeated index, + # the difference may cumulative like torch.repeat_interleave, so we use vanilla index here. 
+ return ( + lambda vid, txt: torch.cat([vid, txt])[tgt_idx], + lambda all: unconcat_coalesce(all), + ) + + +def rearrange( + hid: torch.FloatTensor, # (L c) + hid_shape: torch.LongTensor, # (b n) + pattern: str, + **kwargs: Dict[str, int], +) -> Tuple[ + torch.FloatTensor, + torch.LongTensor, +]: + return flatten([einops.rearrange(h, pattern, **kwargs) for h in unflatten(hid, hid_shape)]) + + +def rearrange_idx( + hid_shape: torch.LongTensor, # (b n) + pattern: str, + **kwargs: Dict[str, int], +) -> Tuple[Callable, Callable, torch.LongTensor]: + hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) + tgt_idx, tgt_shape = rearrange(hid_idx, hid_shape, pattern, **kwargs) + tgt_idx = tgt_idx.squeeze(-1) + src_idx = torch.argsort(tgt_idx) + return ( + lambda hid: torch.index_select(hid, 0, tgt_idx), + lambda hid: torch.index_select(hid, 0, src_idx), + tgt_shape, + ) + + +def repeat( + hid: torch.FloatTensor, # (L c) + hid_shape: torch.LongTensor, # (b n) + pattern: str, + **kwargs: Dict[str, torch.LongTensor], # (b) +) -> Tuple[ + torch.FloatTensor, + torch.LongTensor, +]: + hid = unflatten(hid, hid_shape) + kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))] + return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)]) + + +def pack( + samples: List[torch.Tensor], # List of (h w c). +) -> Tuple[ + List[torch.Tensor], # groups [(b1 h1 w1 c1), (b2 h2 w2 c2)] + List[List[int]], # reversal indices. +]: + batches = {} + indices = {} + for i, sample in enumerate(samples): + shape = sample.shape + batches[shape] = batches.get(shape, []) + indices[shape] = indices.get(shape, []) + batches[shape].append(sample) + indices[shape].append(i) + + batches = list(map(torch.stack, batches.values())) + indices = list(indices.values()) + return batches, indices + + +def unpack( + batches: List[torch.Tensor], + indices: List[List[int]], +) -> List[torch.Tensor]: + samples = [None] * (max(chain(*indices)) + 1) + for batch, index in zip(batches, indices): + for sample, i in zip(batch.unbind(), index): + samples[i] = sample + return samples + + +def window( + hid: torch.FloatTensor, # (L c) + hid_shape: torch.LongTensor, # (b n) + window_fn: Callable[[torch.Tensor], List[torch.Tensor]], +): + hid = unflatten(hid, hid_shape) + hid = list(map(window_fn, hid)) + hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device) + hid, hid_shape = flatten(list(chain(*hid))) + return hid, hid_shape, hid_windows + + +def window_idx( + hid_shape: torch.LongTensor, # (b n) + window_fn: Callable[[torch.Tensor], List[torch.Tensor]], +): + hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) + tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn) + tgt_idx = tgt_idx.squeeze(-1) + src_idx = torch.argsort(tgt_idx) + return ( + lambda hid: torch.index_select(hid, 0, tgt_idx), + lambda hid: torch.index_select(hid, 0, src_idx), + tgt_shape, + tgt_windows, + ) diff --git a/models/dit_v2/nablocks/__init__.py b/models/dit_v2/nablocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c1a9da26ef760575192042ea32b01bd9cd1a267d --- /dev/null +++ b/models/dit_v2/nablocks/__init__.py @@ -0,0 +1,26 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. 
+# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from .mmsr_block import NaMMSRTransformerBlock + + +nadit_blocks = { + "mmdit_sr": NaMMSRTransformerBlock, +} + + +def get_nablock(block_type: str): + if block_type in nadit_blocks: + return nadit_blocks[block_type] + raise NotImplementedError(f"{block_type} is not supported") diff --git a/models/dit_v2/nablocks/attention/__init__.py b/models/dit_v2/nablocks/attention/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a7561025245d888d26ade38f25668efb216cd907 --- /dev/null +++ b/models/dit_v2/nablocks/attention/__init__.py @@ -0,0 +1,25 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from .mmattn import NaMMAttention + +attns = { + "mm_full": NaMMAttention, +} + + +def get_attn(attn_type: str): + if attn_type in attns: + return attns[attn_type] + raise NotImplementedError(f"{attn_type} is not supported") diff --git a/models/dit_v2/nablocks/attention/mmattn.py b/models/dit_v2/nablocks/attention/mmattn.py new file mode 100644 index 0000000000000000000000000000000000000000..4fea9cb9c6fa2f82dd1aba46d658a04a19a11305 --- /dev/null +++ b/models/dit_v2/nablocks/attention/mmattn.py @@ -0,0 +1,266 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import Optional, Tuple, Union +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.utils import _triple + +from common.cache import Cache +from common.distributed.ops import gather_heads_scatter_seq, gather_seq_scatter_heads_qkv + +from ... 
import na +from ...attention import FlashAttentionVarlen +from ...mm import MMArg, MMModule +from ...normalization import norm_layer_type +from ...rope import get_na_rope +from ...window import get_window_op +from itertools import chain + + +class NaMMAttention(nn.Module): + def __init__( + self, + vid_dim: int, + txt_dim: int, + heads: int, + head_dim: int, + qk_bias: bool, + qk_norm: norm_layer_type, + qk_norm_eps: float, + rope_type: Optional[str], + rope_dim: int, + shared_weights: bool, + **kwargs, + ): + super().__init__() + dim = MMArg(vid_dim, txt_dim) + inner_dim = heads * head_dim + qkv_dim = inner_dim * 3 + self.head_dim = head_dim + self.proj_qkv = MMModule( + nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights + ) + self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_weights) + self.norm_q = MMModule( + qk_norm, + dim=head_dim, + eps=qk_norm_eps, + elementwise_affine=True, + shared_weights=shared_weights, + ) + self.norm_k = MMModule( + qk_norm, + dim=head_dim, + eps=qk_norm_eps, + elementwise_affine=True, + shared_weights=shared_weights, + ) + + self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim) + self.attn = FlashAttentionVarlen() + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_qkv, txt_qkv = self.proj_qkv(vid, txt) + vid_qkv = gather_seq_scatter_heads_qkv( + vid_qkv, + seq_dim=0, + qkv_shape=vid_shape, + cache=cache.namespace("vid"), + ) + txt_qkv = gather_seq_scatter_heads_qkv( + txt_qkv, + seq_dim=0, + qkv_shape=txt_shape, + cache=cache.namespace("txt"), + ) + vid_qkv = rearrange(vid_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) + txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) + + vid_q, vid_k, vid_v = vid_qkv.unbind(1) + txt_q, txt_k, txt_v = txt_qkv.unbind(1) + + vid_q, txt_q = self.norm_q(vid_q, txt_q) + vid_k, txt_k = self.norm_k(vid_k, txt_k) + + if self.rope: + if self.rope.mm: + vid_q, vid_k, txt_q, txt_k = self.rope( + vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache + ) + else: + vid_q, vid_k = self.rope(vid_q, vid_k, vid_shape, cache) + + vid_len = cache("vid_len", lambda: vid_shape.prod(-1)) + txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) + all_len = cache("all_len", lambda: vid_len + txt_len) + + concat, unconcat = cache("mm_pnp", lambda: na.concat_idx(vid_len, txt_len)) + + attn = self.attn( + q=concat(vid_q, txt_q).bfloat16(), + k=concat(vid_k, txt_k).bfloat16(), + v=concat(vid_v, txt_v).bfloat16(), + cu_seqlens_q=cache("mm_seqlens", lambda: F.pad(all_len.cumsum(0), (1, 0)).int()), + cu_seqlens_k=cache("mm_seqlens", lambda: F.pad(all_len.cumsum(0), (1, 0)).int()), + max_seqlen_q=cache("mm_maxlen", lambda: all_len.max().item()), + max_seqlen_k=cache("mm_maxlen", lambda: all_len.max().item()), + ).type_as(vid_q) + + attn = rearrange(attn, "l h d -> l (h d)") + vid_out, txt_out = unconcat(attn) + vid_out = gather_heads_scatter_seq(vid_out, head_dim=1, seq_dim=0) + txt_out = gather_heads_scatter_seq(txt_out, head_dim=1, seq_dim=0) + + vid_out, txt_out = self.proj_out(vid_out, txt_out) + return vid_out, txt_out + + +class NaSwinAttention(NaMMAttention): + def __init__( + self, + *args, + window: Union[int, Tuple[int, int, int]], + window_method: str, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.window = _triple(window) + self.window_method = window_method + 
assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window)) + + self.window_op = get_window_op(window_method) + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + + vid_qkv, txt_qkv = self.proj_qkv(vid, txt) + vid_qkv = gather_seq_scatter_heads_qkv( + vid_qkv, + seq_dim=0, + qkv_shape=vid_shape, + cache=cache.namespace("vid"), + ) + txt_qkv = gather_seq_scatter_heads_qkv( + txt_qkv, + seq_dim=0, + qkv_shape=txt_shape, + cache=cache.namespace("txt"), + ) + + # re-org the input seq for window attn + cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3") + + def make_window(x: torch.Tensor): + t, h, w, _ = x.shape + window_slices = self.window_op((t, h, w), self.window) + return [x[st, sh, sw] for (st, sh, sw) in window_slices] + + window_partition, window_reverse, window_shape, window_count = cache_win( + "win_transform", + lambda: na.window_idx(vid_shape, make_window), + ) + vid_qkv_win = window_partition(vid_qkv) + + vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim) + txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) + + vid_q, vid_k, vid_v = vid_qkv_win.unbind(1) + txt_q, txt_k, txt_v = txt_qkv.unbind(1) + + vid_q, txt_q = self.norm_q(vid_q, txt_q) + vid_k, txt_k = self.norm_k(vid_k, txt_k) + + txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) + + vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) + txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) + all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) + concat_win, unconcat_win = cache_win( + "mm_pnp", lambda: na.repeat_concat_idx(vid_len_win, txt_len, window_count) + ) + + # window rope + if self.rope: + if self.rope.mm: + # repeat text q and k for window mmrope + _, num_h, _ = txt_q.shape + txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)") + txt_q_repeat = na.unflatten(txt_q_repeat, txt_shape) + txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)] + txt_q_repeat = list(chain(*txt_q_repeat)) + txt_q_repeat, txt_shape_repeat = na.flatten(txt_q_repeat) + txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h) + + txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)") + txt_k_repeat = na.unflatten(txt_k_repeat, txt_shape) + txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)] + txt_k_repeat = list(chain(*txt_k_repeat)) + txt_k_repeat, _ = na.flatten(txt_k_repeat) + txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h) + + vid_q, vid_k, txt_q, txt_k = self.rope( + vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win + ) + else: + vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) + + out = self.attn( + q=concat_win(vid_q, txt_q).bfloat16(), + k=concat_win(vid_k, txt_k).bfloat16(), + v=concat_win(vid_v, txt_v).bfloat16(), + cu_seqlens_q=cache_win( + "vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() + ), + cu_seqlens_k=cache_win( + "vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() + ), + max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()), + max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()), + ).type_as(vid_q) + + # text pooling + vid_out, txt_out = unconcat_win(out) + + vid_out = rearrange(vid_out, "l h d -> l (h d)") 
+ txt_out = rearrange(txt_out, "l h d -> l (h d)") + vid_out = window_reverse(vid_out) + + vid_out = gather_heads_scatter_seq(vid_out, head_dim=1, seq_dim=0) + txt_out = gather_heads_scatter_seq(txt_out, head_dim=1, seq_dim=0) + + vid_out, txt_out = self.proj_out(vid_out, txt_out) + + return vid_out, txt_out \ No newline at end of file diff --git a/models/dit_v2/nablocks/mmsr_block.py b/models/dit_v2/nablocks/mmsr_block.py new file mode 100644 index 0000000000000000000000000000000000000000..407c5b3eac3d0e572a148283ac322cf50a77d8a4 --- /dev/null +++ b/models/dit_v2/nablocks/mmsr_block.py @@ -0,0 +1,119 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import Tuple +import torch +import torch.nn as nn + +# from ..cache import Cache +from common.cache import Cache + +from .attention.mmattn import NaSwinAttention +from ..mm import MMArg +from ..modulation import ada_layer_type +from ..normalization import norm_layer_type +from ..mm import MMArg, MMModule +from ..mlp import get_mlp + + +class NaMMSRTransformerBlock(nn.Module): + def __init__( + self, + *, + vid_dim: int, + txt_dim: int, + emb_dim: int, + heads: int, + head_dim: int, + expand_ratio: int, + norm: norm_layer_type, + norm_eps: float, + ada: ada_layer_type, + qk_bias: bool, + qk_norm: norm_layer_type, + mlp_type: str, + shared_weights: bool, + rope_type: str, + rope_dim: int, + is_last_layer: bool, + **kwargs, + ): + super().__init__() + dim = MMArg(vid_dim, txt_dim) + self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights,) + + self.attn = NaSwinAttention( + vid_dim=vid_dim, + txt_dim=txt_dim, + heads=heads, + head_dim=head_dim, + qk_bias=qk_bias, + qk_norm=qk_norm, + qk_norm_eps=norm_eps, + rope_type=rope_type, + rope_dim=rope_dim, + shared_weights=shared_weights, + window=kwargs.pop("window", None), + window_method=kwargs.pop("window_method", None), + ) + + self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer) + self.mlp = MMModule( + get_mlp(mlp_type), + dim=dim, + expand_ratio=expand_ratio, + shared_weights=shared_weights, + vid_only=is_last_layer + ) + self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer) + self.is_last_layer = is_last_layer + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + emb: torch.FloatTensor, + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.LongTensor, + torch.LongTensor, + ]: + hid_len = MMArg( + cache("vid_len", lambda: vid_shape.prod(-1)), + cache("txt_len", lambda: txt_shape.prod(-1)), + ) + ada_kwargs = { + "emb": emb, + "hid_len": hid_len, + "cache": cache, + "branch_tag": MMArg("vid", "txt"), + } + + vid_attn, txt_attn = 
self.attn_norm(vid, txt) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs) + vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs) + vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) + + vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs) + vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs) + vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) + + return vid_mlp, txt_mlp, vid_shape, txt_shape diff --git a/models/dit_v2/nadit.py b/models/dit_v2/nadit.py new file mode 100644 index 0000000000000000000000000000000000000000..fe9d7f85fa38e330069d1888cdd996468c719144 --- /dev/null +++ b/models/dit_v2/nadit.py @@ -0,0 +1,246 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union, Callable +import torch +from torch import nn + +from common.cache import Cache +from common.distributed.ops import slice_inputs + +from . 
import na +from .embedding import TimeEmbedding +from .modulation import get_ada_layer +from .nablocks import get_nablock +from .normalization import get_norm_layer +from .patch import get_na_patch_layers + +# Fake func, no checkpointing is required for inference +def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs): + return module(*args, **kwargs) + +@dataclass +class NaDiTOutput: + vid_sample: torch.Tensor + + +class NaDiT(nn.Module): + """ + Native Resolution Diffusion Transformer (NaDiT) + """ + + gradient_checkpointing = False + + def __init__( + self, + vid_in_channels: int, + vid_out_channels: int, + vid_dim: int, + txt_in_dim: Union[int, List[int]], + txt_dim: Optional[int], + emb_dim: int, + heads: int, + head_dim: int, + expand_ratio: int, + norm: Optional[str], + norm_eps: float, + ada: str, + qk_bias: bool, + qk_norm: Optional[str], + patch_size: Union[int, Tuple[int, int, int]], + num_layers: int, + block_type: Union[str, Tuple[str]], + mm_layers: Union[int, Tuple[bool]], + mlp_type: str = "normal", + patch_type: str = "v1", + rope_type: Optional[str] = "rope3d", + rope_dim: Optional[int] = None, + window: Optional[Tuple] = None, + window_method: Optional[Tuple[str]] = None, + msa_type: Optional[Tuple[str]] = None, + mca_type: Optional[Tuple[str]] = None, + txt_in_norm: Optional[str] = None, + txt_in_norm_scale_factor: int = 0.01, + txt_proj_type: Optional[str] = "linear", + vid_out_norm: Optional[str] = None, + **kwargs, + ): + ada = get_ada_layer(ada) + norm = get_norm_layer(norm) + qk_norm = get_norm_layer(qk_norm) + rope_dim = rope_dim if rope_dim is not None else head_dim // 2 + if isinstance(block_type, str): + block_type = [block_type] * num_layers + elif len(block_type) != num_layers: + raise ValueError("The ``block_type`` list should equal to ``num_layers``.") + super().__init__() + NaPatchIn, NaPatchOut = get_na_patch_layers(patch_type) + self.vid_in = NaPatchIn( + in_channels=vid_in_channels, + patch_size=patch_size, + dim=vid_dim, + ) + if not isinstance(txt_in_dim, int): + self.txt_in = nn.ModuleList([]) + for in_dim in txt_in_dim: + txt_norm_layer = get_norm_layer(txt_in_norm)(txt_dim, norm_eps, True) + if txt_proj_type == "linear": + txt_proj_layer = nn.Linear(in_dim, txt_dim) + else: + txt_proj_layer = nn.Sequential( + nn.Linear(in_dim, in_dim), nn.GELU("tanh"), nn.Linear(in_dim, txt_dim) + ) + torch.nn.init.constant_(txt_norm_layer.weight, txt_in_norm_scale_factor) + self.txt_in.append( + nn.Sequential( + txt_proj_layer, + txt_norm_layer, + ) + ) + else: + self.txt_in = ( + nn.Linear(txt_in_dim, txt_dim) + if txt_in_dim and txt_in_dim != txt_dim + else nn.Identity() + ) + self.emb_in = TimeEmbedding( + sinusoidal_dim=256, + hidden_dim=max(vid_dim, txt_dim), + output_dim=emb_dim, + ) + + if window is None or isinstance(window[0], int): + window = [window] * num_layers + if window_method is None or isinstance(window_method, str): + window_method = [window_method] * num_layers + + if msa_type is None or isinstance(msa_type, str): + msa_type = [msa_type] * num_layers + if mca_type is None or isinstance(mca_type, str): + mca_type = [mca_type] * num_layers + + self.blocks = nn.ModuleList( + [ + get_nablock(block_type[i])( + vid_dim=vid_dim, + txt_dim=txt_dim, + emb_dim=emb_dim, + heads=heads, + head_dim=head_dim, + expand_ratio=expand_ratio, + norm=norm, + norm_eps=norm_eps, + ada=ada, + qk_bias=qk_bias, + qk_norm=qk_norm, + shared_weights=not ( + (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i] + ), + 
mlp_type=mlp_type, + window=window[i], + window_method=window_method[i], + msa_type=msa_type[i], + mca_type=mca_type[i], + rope_type=rope_type, + rope_dim=rope_dim, + is_last_layer=(i == num_layers - 1), + **kwargs, + ) + for i in range(num_layers) + ] + ) + + self.vid_out_norm = None + if vid_out_norm is not None: + self.vid_out_norm = get_norm_layer(vid_out_norm)( + dim=vid_dim, + eps=norm_eps, + elementwise_affine=True, + ) + self.vid_out_ada = ada( + dim=vid_dim, + emb_dim=emb_dim, + layers=["out"], + modes=["in"], + ) + + self.vid_out = NaPatchOut( + out_channels=vid_out_channels, + patch_size=patch_size, + dim=vid_dim, + ) + + def set_gradient_checkpointing(self, enable: bool): + self.gradient_checkpointing = enable + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: Union[torch.FloatTensor, List[torch.FloatTensor]], # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: Union[torch.LongTensor, List[torch.LongTensor]], # b 1 + timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b + disable_cache: bool = False, # for test + ): + cache = Cache(disable=disable_cache) + + # slice vid after patching in when using sequence parallelism + if isinstance(txt, list): + assert isinstance(self.txt_in, nn.ModuleList) + txt = [ + na.unflatten(fc(i), s) for fc, i, s in zip(self.txt_in, txt, txt_shape) + ] # B L D + txt, txt_shape = na.flatten([torch.cat(t, dim=0) for t in zip(*txt)]) + txt = slice_inputs(txt, dim=0) + else: + txt = slice_inputs(txt, dim=0) + txt = self.txt_in(txt) + + # Video input. + # Sequence parallel slicing is done inside patching class. + vid, vid_shape = self.vid_in(vid, vid_shape, cache) + + # Embedding input. + emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) + + # Body + for i, block in enumerate(self.blocks): + vid, txt, vid_shape, txt_shape = gradient_checkpointing( + enabled=(self.gradient_checkpointing and self.training), + module=block, + vid=vid, + txt=txt, + vid_shape=vid_shape, + txt_shape=txt_shape, + emb=emb, + cache=cache, + ) + + # Video output norm. + if self.vid_out_norm: + vid = self.vid_out_norm(vid) + vid = self.vid_out_ada( + vid, + emb=emb, + layer="out", + mode="in", + hid_len=cache("vid_len", lambda: vid_shape.prod(-1)), + cache=cache, + branch_tag="vid", + ) + + # Video output. + vid, vid_shape = self.vid_out(vid, vid_shape, cache) + return NaDiTOutput(vid_sample=vid) diff --git a/models/dit_v2/normalization.py b/models/dit_v2/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..98827a9c71f9fd6e461937774d022b68844aee34 --- /dev/null +++ b/models/dit_v2/normalization.py @@ -0,0 +1,63 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
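+# Illustrative sketch (hypothetical helper, not defined elsewhere in this patch): the
+# NaDiT forward above treats a batch of variable-resolution videos as one packed token
+# sequence `vid` of shape [L, C] plus a per-sample shape tensor `vid_shape` of shape
+# [B, 3] holding (T, H, W); per-sample token counts are recovered with
+# `vid_shape.prod(-1)`, which is what the `cache("vid_len", ...)` lookups rely on.
+def _demo_packed_sequence_layout():
+    import torch
+
+    vid_shape = torch.tensor([[4, 30, 40], [2, 45, 80]])   # (T, H, W) per sample
+    lengths = vid_shape.prod(-1)                            # tensor([4800, 7200])
+    vid = torch.randn(int(lengths.sum()), 1024)             # packed tokens, [L, C]
+    per_sample = vid.split(lengths.tolist(), dim=0)         # unpack back to B tensors
+    return [x.shape for x in per_sample]                    # [4800, 1024], [7200, 1024]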
+ +from typing import Callable, Optional +from diffusers.models.normalization import RMSNorm +from torch import nn + +# (dim: int, eps: float, elementwise_affine: bool) +norm_layer_type = Callable[[int, float, bool], nn.Module] + + +def get_norm_layer(norm_type: Optional[str]) -> norm_layer_type: + + def _norm_layer(dim: int, eps: float, elementwise_affine: bool): + if norm_type is None: + return nn.Identity() + + if norm_type == "layer": + return nn.LayerNorm( + normalized_shape=dim, + eps=eps, + elementwise_affine=elementwise_affine, + ) + + if norm_type == "rms": + return RMSNorm( + dim=dim, + eps=eps, + elementwise_affine=elementwise_affine, + ) + + if norm_type == "fusedln": + from apex.normalization import FusedLayerNorm + + return FusedLayerNorm( + normalized_shape=dim, + elementwise_affine=elementwise_affine, + eps=eps, + ) + + if norm_type == "fusedrms": + from apex.normalization import FusedRMSNorm + + return FusedRMSNorm( + normalized_shape=dim, + elementwise_affine=elementwise_affine, + eps=eps, + ) + + raise NotImplementedError(f"{norm_type} is not supported") + + return _norm_layer diff --git a/models/dit_v2/patch/__init__.py b/models/dit_v2/patch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e3c9783163f1e671f2d946dfad39ca33b12843d --- /dev/null +++ b/models/dit_v2/patch/__init__.py @@ -0,0 +1,19 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +def get_na_patch_layers(patch_type="v1"): + assert patch_type in ["v1"] + if patch_type == "v1": + from .patch_v1 import NaPatchIn, NaPatchOut + return NaPatchIn, NaPatchOut diff --git a/models/dit_v2/patch/patch_v1.py b/models/dit_v2/patch/patch_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..0231bc0905e70e1fc702fe088fb2d0dac30fcc71 --- /dev/null +++ b/models/dit_v2/patch/patch_v1.py @@ -0,0 +1,127 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import Tuple, Union +import torch +from einops import rearrange +from torch import nn +from torch.nn.modules.utils import _triple + +from common.cache import Cache +from common.distributed.ops import gather_outputs, slice_inputs + +from .. 
import na + + +class PatchIn(nn.Module): + def __init__( + self, + in_channels: int, + patch_size: Union[int, Tuple[int, int, int]], + dim: int, + ): + super().__init__() + t, h, w = _triple(patch_size) + self.patch_size = t, h, w + self.proj = nn.Linear(in_channels * t * h * w, dim) + + def forward( + self, + vid: torch.Tensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + if t > 1: + assert vid.size(2) % t == 1 + vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) + vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) + vid = self.proj(vid) + return vid + + +class PatchOut(nn.Module): + def __init__( + self, + out_channels: int, + patch_size: Union[int, Tuple[int, int, int]], + dim: int, + ): + super().__init__() + t, h, w = _triple(patch_size) + self.patch_size = t, h, w + self.proj = nn.Linear(dim, out_channels * t * h * w) + + def forward( + self, + vid: torch.Tensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + vid = self.proj(vid) + vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) + if t > 1: + vid = vid[:, :, (t - 1) :] + return vid + + +class NaPatchIn(PatchIn): + def forward( + self, + vid: torch.Tensor, # l c + vid_shape: torch.LongTensor, + cache: Cache = Cache(disable=True), # for test + ) -> torch.Tensor: + cache = cache.namespace("patch") + vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) + t, h, w = self.patch_size + if not (t == h == w == 1): + vid = na.unflatten(vid, vid_shape) + for i in range(len(vid)): + if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: + vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0) + vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w) + vid, vid_shape = na.flatten(vid) + + # slice vid after patching in when using sequence parallelism + vid = slice_inputs(vid, dim=0) + vid = self.proj(vid) + return vid, vid_shape + + +class NaPatchOut(PatchOut): + def forward( + self, + vid: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, + cache: Cache = Cache(disable=True), # for test + ) -> Tuple[ + torch.FloatTensor, + torch.LongTensor, + ]: + cache = cache.namespace("patch") + vid_shape_before_patchify = cache.get("vid_shape_before_patchify") + + t, h, w = self.patch_size + vid = self.proj(vid) + # gather vid before patching out when enabling sequence parallelism + vid = gather_outputs( + vid, gather_dim=0, padding_dim=0, unpad_shape=vid_shape, cache=cache.namespace("vid") + ) + if not (t == h == w == 1): + vid = na.unflatten(vid, vid_shape) + for i in range(len(vid)): + vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w) + if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: + vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :] + vid, vid_shape = na.flatten(vid) + + return vid, vid_shape diff --git a/models/dit_v2/rope.py b/models/dit_v2/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..ceb5458ba2829417a93124b9e06a86b74a523765 --- /dev/null +++ b/models/dit_v2/rope.py @@ -0,0 +1,150 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. 
+# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from functools import lru_cache +from typing import Optional, Tuple +import torch +from einops import rearrange +from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb +from torch import nn + +from common.cache import Cache + + +class RotaryEmbeddingBase(nn.Module): + def __init__(self, dim: int, rope_dim: int): + super().__init__() + self.rope = RotaryEmbedding( + dim=dim // rope_dim, + freqs_for="pixel", + max_freq=256, + ) + # 1. Set model.requires_grad_(True) after model creation will make + # the `requires_grad=False` for rope freqs no longer hold. + # 2. Even if we don't set requires_grad_(True) explicitly, + # FSDP is not memory efficient when handling fsdp_wrap + # with mixed requires_grad=True/False. + # With above consideration, it is easier just remove the freqs + # out of nn.Parameters when `learned_freq=False` + freqs = self.rope.freqs + del self.rope.freqs + self.rope.register_buffer("freqs", freqs.data) + + @lru_cache(maxsize=128) + def get_axial_freqs(self, *dims): + return self.rope.get_axial_freqs(*dims) + + +class RotaryEmbedding3d(RotaryEmbeddingBase): + def __init__(self, dim: int): + super().__init__(dim, rope_dim=3) + self.mm = False + + def forward( + self, + q: torch.FloatTensor, # b h l d + k: torch.FloatTensor, # b h l d + size: Tuple[int, int, int], + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + T, H, W = size + freqs = self.get_axial_freqs(T, H, W) + q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) + k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) + q = apply_rotary_emb(freqs, q.float()).to(q.dtype) + k = apply_rotary_emb(freqs, k.float()).to(k.dtype) + q = rearrange(q, "b h T H W d -> b h (T H W) d") + k = rearrange(k, "b h T H W d -> b h (T H W) d") + return q, k + + +class MMRotaryEmbeddingBase(RotaryEmbeddingBase): + def __init__(self, dim: int, rope_dim: int): + super().__init__(dim, rope_dim) + self.rope = RotaryEmbedding( + dim=dim // rope_dim, + freqs_for="lang", + theta=10000, + ) + freqs = self.rope.freqs + del self.rope.freqs + self.rope.register_buffer("freqs", freqs.data) + self.mm = True + + +class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): + def __init__(self, dim: int): + super().__init__(dim, rope_dim=3) + + def forward( + self, + vid_q: torch.FloatTensor, # L h d + vid_k: torch.FloatTensor, # L h d + vid_shape: torch.LongTensor, # B 3 + txt_q: torch.FloatTensor, # L h d + txt_k: torch.FloatTensor, # L h d + txt_shape: torch.LongTensor, # B 1 + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_freqs, txt_freqs = cache( + "mmrope_freqs_3d", + lambda: self.get_freqs(vid_shape, txt_shape), + ) + vid_q = rearrange(vid_q, "L h d -> h L d") + vid_k = rearrange(vid_k, "L h d -> h L d") + vid_q = apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype) + vid_k = apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype) + vid_q = rearrange(vid_q, "h L d -> L h d") + vid_k = rearrange(vid_k, "h L d -> L h d") + + txt_q = rearrange(txt_q, "L h d -> h L d") + txt_k = 
rearrange(txt_k, "L h d -> h L d") + txt_q = apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype) + txt_k = apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype) + txt_q = rearrange(txt_q, "h L d -> L h d") + txt_k = rearrange(txt_k, "h L d -> L h d") + return vid_q, vid_k, txt_q, txt_k + + def get_freqs( + self, + vid_shape: torch.LongTensor, + txt_shape: torch.LongTensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + ]: + vid_freqs = self.get_axial_freqs(1024, 128, 128) + txt_freqs = self.get_axial_freqs(1024) + vid_freq_list, txt_freq_list = [], [] + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) + txt_freq = txt_freqs[:l].repeat(1, 3).reshape(-1, vid_freqs.size(-1)) + vid_freq_list.append(vid_freq) + txt_freq_list.append(txt_freq) + return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0) + + +def get_na_rope(rope_type: Optional[str], dim: int): + if rope_type is None: + return None + if rope_type == "mmrope3d": + return NaMMRotaryEmbedding3d(dim=dim) + raise NotImplementedError(f"{rope_type} is not supported.") diff --git a/models/dit_v2/window.py b/models/dit_v2/window.py new file mode 100644 index 0000000000000000000000000000000000000000..b7475921ae283cf76d82bff7521233c133f54bfd --- /dev/null +++ b/models/dit_v2/window.py @@ -0,0 +1,83 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from math import ceil +from typing import Tuple +import math + +def get_window_op(name: str): + if name == "720pwin_by_size_bysize": + return make_720Pwindows_bysize + if name == "720pswin_by_size_bysize": + return make_shifted_720Pwindows_bysize + raise ValueError(f"Unknown windowing method: {name}") + + +# -------------------------------- Windowing -------------------------------- # +def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): + t, h, w = size + resized_nt, resized_nh, resized_nw = num_windows + #cal windows under 720p + scale = math.sqrt((45 * 80) / (h * w)) + resized_h, resized_w = round(h * scale), round(w * scale) + wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. + wt = ceil(min(t, 30) / resized_nt) # window size. + nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. 
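+    # Note: nt/nh/nw above are the number of windows along each axis (the per-axis
+    # window sizes are wt/wh/ww). The comprehension below emits one (t, h, w) slice
+    # triple per non-empty window, clipping the last window on each axis to the tensor
+    # bounds, so the returned windows tile the full volume exactly once.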
+ return [ + ( + slice(it * wt, min((it + 1) * wt, t)), + slice(ih * wh, min((ih + 1) * wh, h)), + slice(iw * ww, min((iw + 1) * ww, w)), + ) + for iw in range(nw) + if min((iw + 1) * ww, w) > iw * ww + for ih in range(nh) + if min((ih + 1) * wh, h) > ih * wh + for it in range(nt) + if min((it + 1) * wt, t) > it * wt + ] + +def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): + t, h, w = size + resized_nt, resized_nh, resized_nw = num_windows + #cal windows under 720p + scale = math.sqrt((45 * 80) / (h * w)) + resized_h, resized_w = round(h * scale), round(w * scale) + wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. + wt = ceil(min(t, 30) / resized_nt) # window size. + + st, sh, sw = ( # shift size. + 0.5 if wt < t else 0, + 0.5 if wh < h else 0, + 0.5 if ww < w else 0, + ) + nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size. + nt, nh, nw = ( # number of window. + nt + 1 if st > 0 else 1, + nh + 1 if sh > 0 else 1, + nw + 1 if sw > 0 else 1, + ) + return [ + ( + slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)), + slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)), + slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)), + ) + for iw in range(nw) + if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0) + for ih in range(nh) + if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0) + for it in range(nt) + if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0) + ] diff --git a/models/video_vae_v3/modules/attn_video_vae.py b/models/video_vae_v3/modules/attn_video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..edaf817452af1df8c85746f07d017e8802d989b0 --- /dev/null +++ b/models/video_vae_v3/modules/attn_video_vae.py @@ -0,0 +1,1345 @@ +# Copyright (c) 2023 HuggingFace Team +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates. +# SPDX-License-Identifier: Apache License, Version 2.0 (the "License") +# +# This file has been modified by ByteDance Ltd. and/or its affiliates. on 1st June 2025 +# +# Original file was released under Apache License, Version 2.0 (the "License"), with the full license text +# available at http://www.apache.org/licenses/LICENSE-2.0. +# +# This modified file is released under the same license. 
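+# Illustrative usage sketch (hypothetical helper; assumes the repository root is
+# importable): the windowing helpers in models/dit_v2/window.py return a list of
+# (t, h, w) slice triples; for the non-shifted variant these form an exact partition
+# of the latent volume, which the assert below checks.
+def _demo_720p_windows():
+    import torch
+    from models.dit_v2.window import get_window_op
+
+    size = (30, 45, 80)  # latent (T, H, W), roughly a 720p clip at this token grid
+    windows = get_window_op("720pwin_by_size_bysize")(size, num_windows=(1, 3, 5))
+    covered = torch.zeros(size, dtype=torch.int64)
+    for ts, hs, ws in windows:
+        covered[ts, hs, ws] += 1
+    assert torch.all(covered == 1)  # every token lies in exactly one window
+    return windows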
+ + +from contextlib import nullcontext +from typing import Literal, Optional, Tuple, Union +import diffusers +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.attention_processor import Attention, SpatialNorm +from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution +from diffusers.models.downsampling import Downsample2D +from diffusers.models.lora import LoRACompatibleConv +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.resnet import ResnetBlock2D +from diffusers.models.unets.unet_2d_blocks import DownEncoderBlock2D, UpDecoderBlock2D +from diffusers.models.upsampling import Upsample2D +from diffusers.utils import is_torch_version +from diffusers.utils.accelerate_utils import apply_forward_hook +from einops import rearrange + +from common.distributed.advanced import get_sequence_parallel_world_size +from common.logger import get_logger +from models.video_vae_v3.modules.causal_inflation_lib import ( + InflatedCausalConv3d, + causal_norm_wrapper, + init_causal_conv3d, + remove_head, +) +from models.video_vae_v3.modules.context_parallel_lib import ( + causal_conv_gather_outputs, + causal_conv_slice_inputs, +) +from models.video_vae_v3.modules.global_config import set_norm_limit +from models.video_vae_v3.modules.types import ( + CausalAutoencoderOutput, + CausalDecoderOutput, + CausalEncoderOutput, + MemoryState, + _inflation_mode_t, + _memory_device_t, + _receptive_field_t, +) + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +class Upsample3D(Upsample2D): + """A 3D upsampling layer with an optional convolution.""" + + def __init__( + self, + *args, + inflation_mode: _inflation_mode_t = "tail", + temporal_up: bool = False, + spatial_up: bool = True, + slicing: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + conv = self.conv if self.name == "conv" else self.Conv2d_0 + + assert type(conv) is not nn.ConvTranspose2d + # Note: lora_layer is not passed into constructor in the original implementation. + # So we make a simplification. + conv = init_causal_conv3d( + self.channels, + self.out_channels, + 3, + padding=1, + inflation_mode=inflation_mode, + ) + + self.temporal_up = temporal_up + self.spatial_up = spatial_up + self.temporal_ratio = 2 if temporal_up else 1 + self.spatial_ratio = 2 if spatial_up else 1 + self.slicing = slicing + + assert not self.interpolate + # [Override] MAGViT v2 implementation + if not self.interpolate: + upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio + self.upscale_conv = nn.Conv3d( + self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 + ) + identity = ( + torch.eye(self.channels) + .repeat(upscale_ratio, 1) + .reshape_as(self.upscale_conv.weight) + ) + self.upscale_conv.weight.data.copy_(identity) + nn.init.zeros_(self.upscale_conv.bias) + + if self.name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward( + self, + hidden_states: torch.FloatTensor, + output_size: Optional[int] = None, + memory_state: MemoryState = MemoryState.DISABLED, + **kwargs, + ) -> torch.FloatTensor: + assert hidden_states.shape[1] == self.channels + + if hasattr(self, "norm") and self.norm is not None: + # [Overridden] change to causal norm. 
+ hidden_states = causal_norm_wrapper(self.norm, hidden_states) + + if self.use_conv_transpose: + return self.conv(hidden_states) + + if self.slicing: + split_size = hidden_states.size(2) // 2 + hidden_states = list( + hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2) + ) + else: + hidden_states = [hidden_states] + + for i in range(len(hidden_states)): + hidden_states[i] = self.upscale_conv(hidden_states[i]) + hidden_states[i] = rearrange( + hidden_states[i], + "b (x y z c) f h w -> b c (f z) (h x) (w y)", + x=self.spatial_ratio, + y=self.spatial_ratio, + z=self.temporal_ratio, + ) + + # [Overridden] For causal temporal conv + if self.temporal_up and memory_state != MemoryState.ACTIVE: + hidden_states[0] = remove_head(hidden_states[0]) + + if not self.slicing: + hidden_states = hidden_states[0] + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states, memory_state=memory_state) + else: + hidden_states = self.Conv2d_0(hidden_states, memory_state=memory_state) + + if not self.slicing: + return hidden_states + else: + return torch.cat(hidden_states, dim=2) + + +class Downsample3D(Downsample2D): + """A 3D downsampling layer with an optional convolution.""" + + def __init__( + self, + *args, + inflation_mode: _inflation_mode_t = "tail", + spatial_down: bool = False, + temporal_down: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + conv = self.conv + self.temporal_down = temporal_down + self.spatial_down = spatial_down + + self.temporal_ratio = 2 if temporal_down else 1 + self.spatial_ratio = 2 if spatial_down else 1 + + self.temporal_kernel = 3 if temporal_down else 1 + self.spatial_kernel = 3 if spatial_down else 1 + + if type(conv) in [nn.Conv2d, LoRACompatibleConv]: + # Note: lora_layer is not passed into constructor in the original implementation. + # So we make a simplification. + conv = init_causal_conv3d( + self.channels, + self.out_channels, + kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel), + stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), + padding=( + 1 if self.temporal_down else 0, + self.padding if self.spatial_down else 0, + self.padding if self.spatial_down else 0, + ), + inflation_mode=inflation_mode, + ) + elif type(conv) is nn.AvgPool2d: + assert self.channels == self.out_channels + conv = nn.AvgPool3d( + kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), + stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), + ) + else: + raise NotImplementedError + + if self.name == "conv": + self.Conv2d_0 = conv + self.conv = conv + else: + self.conv = conv + + def forward( + self, + hidden_states: torch.FloatTensor, + memory_state: MemoryState = MemoryState.DISABLED, + **kwargs, + ) -> torch.FloatTensor: + + assert hidden_states.shape[1] == self.channels + + if hasattr(self, "norm") and self.norm is not None: + # [Overridden] change to causal norm. 
+ hidden_states = causal_norm_wrapper(self.norm, hidden_states) + + if self.use_conv and self.padding == 0 and self.spatial_down: + pad = (0, 1, 0, 1) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) + + assert hidden_states.shape[1] == self.channels + + hidden_states = self.conv(hidden_states, memory_state=memory_state) + + return hidden_states + + +class ResnetBlock3D(ResnetBlock2D): + def __init__( + self, + *args, + inflation_mode: _inflation_mode_t = "tail", + time_receptive_field: _receptive_field_t = "half", + slicing: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.conv1 = init_causal_conv3d( + self.in_channels, + self.out_channels, + kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3), + stride=1, + padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1), + inflation_mode=inflation_mode, + ) + + self.conv2 = init_causal_conv3d( + self.out_channels, + self.conv2.out_channels, + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + if self.up: + assert type(self.upsample) is Upsample2D + self.upsample = Upsample3D( + self.in_channels, + use_conv=False, + inflation_mode=inflation_mode, + slicing=slicing, + ) + elif self.down: + assert type(self.downsample) is Downsample2D + self.downsample = Downsample3D( + self.in_channels, + use_conv=False, + padding=1, + name="op", + inflation_mode=inflation_mode, + ) + + if self.use_in_shortcut: + self.conv_shortcut = init_causal_conv3d( + self.in_channels, + self.conv_shortcut.out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=(self.conv_shortcut.bias is not None), + inflation_mode=inflation_mode, + ) + + def forward( + self, input_tensor, temb, memory_state: MemoryState = MemoryState.DISABLED, **kwargs + ): + hidden_states = input_tensor + + hidden_states = causal_norm_wrapper(self.norm1, hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + # upsample_nearest_nhwc fails with large batch sizes. 
+ # see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + input_tensor = input_tensor.contiguous() + hidden_states = hidden_states.contiguous() + input_tensor = self.upsample(input_tensor, memory_state=memory_state) + hidden_states = self.upsample(hidden_states, memory_state=memory_state) + elif self.downsample is not None: + input_tensor = self.downsample(input_tensor, memory_state=memory_state) + hidden_states = self.downsample(hidden_states, memory_state=memory_state) + + hidden_states = self.conv1(hidden_states, memory_state=memory_state) + + if self.time_emb_proj is not None: + if not self.skip_time_act: + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = causal_norm_wrapper(self.norm2, hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states, memory_state=memory_state) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class DownEncoderBlock3D(DownEncoderBlock2D): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + inflation_mode: _inflation_mode_t = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_down: bool = True, + spatial_down: bool = True, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + num_layers=num_layers, + resnet_eps=resnet_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_pre_norm=resnet_pre_norm, + output_scale_factor=output_scale_factor, + add_downsample=add_downsample, + downsample_padding=downsample_padding, + ) + resnets = [] + temporal_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + # [Override] Replace module. + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ) + temporal_modules.append(nn.Identity()) + + self.resnets = nn.ModuleList(resnets) + self.temporal_modules = nn.ModuleList(temporal_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + # [Override] Replace module. 
+ Downsample3D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + temporal_down=temporal_down, + spatial_down=spatial_down, + inflation_mode=inflation_mode, + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + memory_state: MemoryState = MemoryState.DISABLED, + **kwargs, + ) -> torch.FloatTensor: + for resnet, temporal in zip(self.resnets, self.temporal_modules): + hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) + hidden_states = temporal(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, memory_state=memory_state) + + return hidden_states + + +class UpDecoderBlock3D(UpDecoderBlock2D): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + inflation_mode: _inflation_mode_t = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_up: bool = True, + spatial_up: bool = True, + slicing: bool = False, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + num_layers=num_layers, + resnet_eps=resnet_eps, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_pre_norm=resnet_pre_norm, + output_scale_factor=output_scale_factor, + add_upsample=add_upsample, + temb_channels=temb_channels, + ) + resnets = [] + temporal_modules = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + # [Override] Replace module. 
+ ResnetBlock3D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + slicing=slicing, + ) + ) + + temporal_modules.append(nn.Identity()) + + self.resnets = nn.ModuleList(resnets) + self.temporal_modules = nn.ModuleList(temporal_modules) + + if add_upsample: + # [Override] Replace module & use learnable upsample + self.upsamplers = nn.ModuleList( + [ + Upsample3D( + out_channels, + use_conv=True, + out_channels=out_channels, + temporal_up=temporal_up, + spatial_up=spatial_up, + interpolate=False, + inflation_mode=inflation_mode, + slicing=slicing, + ) + ] + ) + else: + self.upsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + memory_state: MemoryState = MemoryState.DISABLED, + ) -> torch.FloatTensor: + for resnet, temporal in zip(self.resnets, self.temporal_modules): + hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) + hidden_states = temporal(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, memory_state=memory_state) + + return hidden_states + + +class UNetMidBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + inflation_mode: _inflation_mode_t = "tail", + time_receptive_field: _receptive_field_t = "half", + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + # there is always at least one resnet + resnets = [ + # [Override] Replace module. + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. " + f"Defaulting `attention_head_dim` to `in_channels`: {in_channels}." 
+ ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=( + resnet_groups if resnet_time_scale_shift == "default" else None + ), + spatial_norm_dim=( + temb_channels if resnet_time_scale_shift == "spatial" else None + ), + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, memory_state: MemoryState = MemoryState.DISABLED): + video_length, frame_height, frame_width = hidden_states.size()[-3:] + hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = attn(hidden_states, temb=temb) + hidden_states = rearrange( + hidden_states, "(b f) c h w -> b c f h w", f=video_length + ) + hidden_states = resnet(hidden_states, temb, memory_state=memory_state) + + return hidden_states + + +class Encoder3D(nn.Module): + r""" + [Override] override most logics to support extra condition input and causal conv + + The `Encoder` layer of a variational autoencoder that encodes + its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. + See `~diffusers.models.unet_2d_blocks.get_down_block` + for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. + See `~diffusers.models.activations.get_activation` for available options. + double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels for the last block. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), + block_out_channels: Tuple[int, ...] 
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + # [Override] add extra_cond_dim, temporal down num + temporal_down_num: int = 2, + extra_cond_dim: int = None, + gradient_checkpoint: bool = False, + inflation_mode: _inflation_mode_t = "tail", + time_receptive_field: _receptive_field_t = "half", + ): + super().__init__() + self.layers_per_block = layers_per_block + self.temporal_down_num = temporal_down_num + + self.conv_in = init_causal_conv3d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + self.extra_cond_dim = extra_cond_dim + + self.conv_extra_cond = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + # [Override] to support temporal down block design + is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 + # Note: take the last ones + + assert down_block_type == "DownEncoderBlock3D" + + down_block = DownEncoderBlock3D( + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + # Note: Don't know why set it as 0 + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + temporal_down=is_temporal_down_block, + spatial_down=True, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + self.down_blocks.append(down_block) + + def zero_module(module): + # Zero out the parameters of a module and return it. 
+ for p in module.parameters(): + p.detach().zero_() + return module + + self.conv_extra_cond.append( + zero_module( + nn.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0) + ) + if self.extra_cond_dim is not None and self.extra_cond_dim > 0 + else None + ) + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # out + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 + ) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = init_causal_conv3d( + block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode + ) + + self.gradient_checkpointing = gradient_checkpoint + + def forward( + self, + sample: torch.FloatTensor, + extra_cond=None, + memory_state: MemoryState = MemoryState.DISABLED, + ) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + sample = self.conv_in(sample, memory_state=memory_state) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # down + # [Override] add extra block and extra cond + for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, memory_state, use_reentrant=False + ) + if extra_block is not None: + sample = sample + F.interpolate(extra_block(extra_cond), size=sample.shape[2:]) + + # middle + sample = self.mid_block(sample, memory_state=memory_state) + + # sample = torch.utils.checkpoint.checkpoint( + # create_custom_forward(self.mid_block), sample, use_reentrant=False + # ) + + else: + # down + # [Override] add extra block and extra cond + for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): + sample = down_block(sample, memory_state=memory_state) + if extra_block is not None: + sample = sample + F.interpolate(extra_block(extra_cond), size=sample.shape[2:]) + + # middle + sample = self.mid_block(sample, memory_state=memory_state) + + # post-process + sample = causal_norm_wrapper(self.conv_norm_out, sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample, memory_state=memory_state) + + return sample + + +class Decoder3D(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that + decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. + See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. 
+ act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. + See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + # [Override] add temporal up block + inflation_mode: _inflation_mode_t = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_up_num: int = 2, + slicing_up_num: int = 0, + gradient_checkpoint: bool = False, + ): + super().__init__() + self.layers_per_block = layers_per_block + self.temporal_up_num = temporal_up_num + + self.conv_in = init_causal_conv3d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + print(f"slicing_up_num: {slicing_up_num}") + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + is_temporal_up_block = i < self.temporal_up_num + is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num + # Note: Keep symmetric + + assert up_block_type == "UpDecoderBlock3D" + up_block = UpDecoderBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=norm_type, + temb_channels=temb_channels, + temporal_up=is_temporal_up_block, + slicing=is_slicing_up_block, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 + ) + self.conv_act = nn.SiLU() + self.conv_out = init_causal_conv3d( + block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode + ) + + self.gradient_checkpointing = gradient_checkpoint + + # Note: Just copy from Decoder. 
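+    # The forward below mirrors diffusers' Decoder.forward, with two causal-video
+    # changes: every block also receives `memory_state` so the inflated causal convs
+    # can carry their temporal cache across slices, and output normalization goes
+    # through causal_norm_wrapper rather than being applied directly.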
+ def forward( + self, + sample: torch.FloatTensor, + latent_embeds: Optional[torch.FloatTensor] = None, + memory_state: MemoryState = MemoryState.DISABLED, + ) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample, memory_state=memory_state) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + latent_embeds, + memory_state, + use_reentrant=False, + ) + else: + # middle + sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), sample, latent_embeds, memory_state + ) + else: + # middle + sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds, memory_state=memory_state) + + # post-process + sample = causal_norm_wrapper(self.conv_norm_out, sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample, memory_state=memory_state) + + return sample + + +class AutoencoderKL(diffusers.AutoencoderKL): + """ + We simply inherit the model code from diffusers + """ + + def __init__(self, attention: bool = True, *args, **kwargs): + super().__init__(*args, **kwargs) + + # A hacky way to remove attention. + if not attention: + self.encoder.mid_block.attentions = torch.nn.ModuleList([None]) + self.decoder.mid_block.attentions = torch.nn.ModuleList([None]) + + def load_state_dict(self, state_dict, strict=True): + # Newer version of diffusers changed the model keys, + # causing incompatibility with old checkpoints. + # They provided a method for conversion. We call conversion before loading state_dict. 
+ convert_deprecated_attention_blocks = getattr( + self, "_convert_deprecated_attention_blocks", None + ) + if callable(convert_deprecated_attention_blocks): + convert_deprecated_attention_blocks(state_dict) + return super().load_state_dict(state_dict, strict) + + +class VideoAutoencoderKL(diffusers.AutoencoderKL): + """ + We simply inherit the model code from diffusers + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock3D",), + up_block_types: Tuple[str] = ("UpDecoderBlock3D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + scaling_factor: float = 0.18215, + force_upcast: float = True, + attention: bool = True, + temporal_scale_num: int = 2, + slicing_up_num: int = 0, + gradient_checkpoint: bool = False, + inflation_mode: _inflation_mode_t = "tail", + time_receptive_field: _receptive_field_t = "full", + slicing_sample_min_size: int = 32, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + *args, + **kwargs, + ): + extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None + self.slicing_sample_min_size = slicing_sample_min_size + self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + # [Override] make sure it can be normally initialized + down_block_types=tuple( + [down_block_type.replace("3D", "2D") for down_block_type in down_block_types] + ), + up_block_types=tuple( + [up_block_type.replace("3D", "2D") for up_block_type in up_block_types] + ), + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + latent_channels=latent_channels, + norm_num_groups=norm_num_groups, + sample_size=sample_size, + scaling_factor=scaling_factor, + force_upcast=force_upcast, + *args, + **kwargs, + ) + + # pass init params to Encoder + self.encoder = Encoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + extra_cond_dim=extra_cond_dim, + # [Override] add temporal_down_num parameter + temporal_down_num=temporal_scale_num, + gradient_checkpoint=gradient_checkpoint, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # pass init params to Decoder + self.decoder = Decoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + # [Override] add temporal_up_num parameter + temporal_up_num=temporal_scale_num, + slicing_up_num=slicing_up_num, + gradient_checkpoint=gradient_checkpoint, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + self.quant_conv = ( + init_causal_conv3d( + in_channels=2 * latent_channels, + out_channels=2 * latent_channels, + kernel_size=1, + inflation_mode=inflation_mode, + ) + if use_quant_conv + else None + ) + self.post_quant_conv = ( + init_causal_conv3d( + in_channels=latent_channels, + out_channels=latent_channels, + kernel_size=1, + inflation_mode=inflation_mode, + ) + if use_post_quant_conv + else None + ) + + # A hacky way to remove attention. 
+ if not attention: + self.encoder.mid_block.attentions = torch.nn.ModuleList([None]) + self.decoder.mid_block.attentions = torch.nn.ModuleList([None]) + + @apply_forward_hook + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + h = self.slicing_encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + @apply_forward_hook + def decode( + self, z: torch.Tensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.Tensor]: + decoded = self.slicing_decode(z) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def _encode( + self, x: torch.Tensor, memory_state: MemoryState = MemoryState.DISABLED + ) -> torch.Tensor: + _x = x.to(self.device) + _x = causal_conv_slice_inputs(_x, self.slicing_sample_min_size, memory_state=memory_state) + h = self.encoder(_x, memory_state=memory_state) + if self.quant_conv is not None: + output = self.quant_conv(h, memory_state=memory_state) + else: + output = h + output = causal_conv_gather_outputs(output) + return output.to(x.device) + + def _decode( + self, z: torch.Tensor, memory_state: MemoryState = MemoryState.DISABLED + ) -> torch.Tensor: + _z = z.to(self.device) + _z = causal_conv_slice_inputs(_z, self.slicing_latent_min_size, memory_state=memory_state) + if self.post_quant_conv is not None: + _z = self.post_quant_conv(_z, memory_state=memory_state) + output = self.decoder(_z, memory_state=memory_state) + output = causal_conv_gather_outputs(output) + return output.to(z.device) + + def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: + sp_size = get_sequence_parallel_world_size() + if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: + x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2) + encoded_slices = [ + self._encode( + torch.cat((x[:, :, :1], x_slices[0]), dim=2), + memory_state=MemoryState.INITIALIZING, + ) + ] + for x_idx in range(1, len(x_slices)): + encoded_slices.append( + self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE) + ) + return torch.cat(encoded_slices, dim=2) + else: + return self._encode(x) + + def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: + sp_size = get_sequence_parallel_world_size() + if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: + z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) + decoded_slices = [ + self._decode( + torch.cat((z[:, :, :1], z_slices[0]), dim=2), + memory_state=MemoryState.INITIALIZING, + ) + ] + for z_idx in range(1, len(z_slices)): + decoded_slices.append( + self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE) + ) + return torch.cat(decoded_slices, dim=2) + else: + return self._decode(z) + + def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + raise NotImplementedError + + def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: + raise NotImplementedError + + def forward( + self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs + ): + # x: [b c t h w] + if mode == "encode": + h = self.encode(x) + return h.latent_dist + elif mode == "decode": + h = self.decode(x) + return h.sample + else: + h = self.encode(x) + h = self.decode(h.latent_dist.mode()) + return h.sample + + def load_state_dict(self, state_dict, strict=False): + # Newer version of diffusers changed the model keys, + # causing 
incompatibility with old checkpoints. + # They provided a method for conversion. + # We call conversion before loading state_dict. + convert_deprecated_attention_blocks = getattr( + self, "_convert_deprecated_attention_blocks", None + ) + if callable(convert_deprecated_attention_blocks): + convert_deprecated_attention_blocks(state_dict) + return super().load_state_dict(state_dict, strict) + + +class VideoAutoencoderKLWrapper(VideoAutoencoderKL): + def __init__( + self, + *args, + spatial_downsample_factor: int, + temporal_downsample_factor: int, + freeze_encoder: bool, + **kwargs, + ): + self.spatial_downsample_factor = spatial_downsample_factor + self.temporal_downsample_factor = temporal_downsample_factor + self.freeze_encoder = freeze_encoder + super().__init__(*args, **kwargs) + + def forward(self, x: torch.FloatTensor) -> CausalAutoencoderOutput: + with torch.no_grad() if self.freeze_encoder else nullcontext(): + z, p = self.encode(x) + x = self.decode(z).sample + return CausalAutoencoderOutput(x, z, p) + + def encode(self, x: torch.FloatTensor) -> CausalEncoderOutput: + if x.ndim == 4: + x = x.unsqueeze(2) + p = super().encode(x).latent_dist + z = p.sample().squeeze(2) + return CausalEncoderOutput(z, p) + + def decode(self, z: torch.FloatTensor) -> CausalDecoderOutput: + if z.ndim == 4: + z = z.unsqueeze(2) + x = super().decode(z).sample.squeeze(2) + return CausalDecoderOutput(x) + + def preprocess(self, x: torch.Tensor): + # x should in [B, C, T, H, W], [B, C, H, W] + assert x.ndim == 4 or x.size(2) % 4 == 1 + return x + + def postprocess(self, x: torch.Tensor): + # x should in [B, C, T, H, W], [B, C, H, W] + return x + + def set_causal_slicing( + self, + *, + split_size: Optional[int], + memory_device: _memory_device_t, + ): + assert ( + split_size is None or memory_device is not None + ), "if split_size is set, memory_device must not be None." + if split_size is not None: + self.enable_slicing() + self.slicing_sample_min_size = split_size + self.slicing_latent_min_size = split_size // self.temporal_downsample_factor + else: + self.disable_slicing() + for module in self.modules(): + if isinstance(module, InflatedCausalConv3d): + module.set_memory_device(memory_device) + + def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): + set_norm_limit(norm_max_mem) + for m in self.modules(): + if isinstance(m, InflatedCausalConv3d): + m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) diff --git a/models/video_vae_v3/modules/causal_inflation_lib.py b/models/video_vae_v3/modules/causal_inflation_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..fdd3cbe2512b119d76729c4103325ac22e0b12fe --- /dev/null +++ b/models/video_vae_v3/modules/causal_inflation_lib.py @@ -0,0 +1,460 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
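To make the chunking in `slicing_encode` above concrete, here is a small self-contained sketch (an illustrative helper of my own, not part of the patch) that reproduces only the temporal split: frame 0 is glued to the first chunk (handled with MemoryState.INITIALIZING), and the remaining frames are cut into chunks of `slicing_sample_min_size * sp_size` (MemoryState.ACTIVE). The real method additionally runs the encoder and threads the memory state through every causal conv.

import torch


def split_like_slicing_encode(x: torch.Tensor, min_size: int, sp_size: int = 1):
    # Mimics only the temporal chunking of VideoAutoencoderKL.slicing_encode.
    if (x.shape[2] - 1) <= min_size * sp_size:
        return [x]                                       # single pass (MemoryState.DISABLED)
    tail = x[:, :, 1:].split(min_size * sp_size, dim=2)
    chunks = [torch.cat((x[:, :, :1], tail[0]), dim=2)]  # first chunk: MemoryState.INITIALIZING
    chunks += list(tail[1:])                             # later chunks: MemoryState.ACTIVE
    return chunks


x = torch.zeros(1, 3, 17, 8, 8)  # [B, C, T, H, W] with T = 1 + 16
print([c.shape[2] for c in split_like_slicing_encode(x, min_size=4)])
# [5, 4, 4, 4]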
+ +import math +from contextlib import contextmanager +from typing import List, Optional, Union +import torch +import torch.distributed as dist +import torch.nn.functional as F +from diffusers.models.normalization import RMSNorm +from einops import rearrange +from torch import Tensor, nn +from torch.nn import Conv3d + +from common.distributed.advanced import ( + get_next_sequence_parallel_rank, + get_prev_sequence_parallel_rank, + get_sequence_parallel_group, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, +) +from common.logger import get_logger +from models.video_vae_v3.modules.context_parallel_lib import cache_send_recv, get_cache_size +from models.video_vae_v3.modules.global_config import get_norm_limit +from models.video_vae_v3.modules.types import MemoryState, _inflation_mode_t, _memory_device_t + +logger = get_logger(__name__) + + +@contextmanager +def ignore_padding(model): + orig_padding = model.padding + model.padding = (0, 0, 0) + try: + yield + finally: + model.padding = orig_padding + + +class InflatedCausalConv3d(Conv3d): + def __init__( + self, + *args, + inflation_mode: _inflation_mode_t, + memory_device: _memory_device_t = "same", + **kwargs, + ): + self.inflation_mode = inflation_mode + self.memory = None + super().__init__(*args, **kwargs) + self.temporal_padding = self.padding[0] + self.memory_device = memory_device + self.padding = (0, *self.padding[1:]) # Remove temporal pad to keep causal. + self.memory_limit = float("inf") + + def set_memory_limit(self, value: float): + self.memory_limit = value + + def set_memory_device(self, memory_device: _memory_device_t): + self.memory_device = memory_device + + def memory_limit_conv( + self, + x, + *, + split_dim=3, + padding=(0, 0, 0, 0, 0, 0), + prev_cache=None, + ): + # Compatible with no limit. + if math.isinf(self.memory_limit): + if prev_cache is not None: + x = torch.cat([prev_cache, x], dim=split_dim - 1) + return super().forward(x) + + # Compute tensor shape after concat & padding. + shape = torch.tensor(x.size()) + if prev_cache is not None: + shape[split_dim - 1] += prev_cache.size(split_dim - 1) + shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0) + memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB + logger.debug( + f"x:{(shape, x.dtype)} {memory_occupy:.3f}GiB " + f"prev_cache:{prev_cache.shape if prev_cache is not None else None}" + ) + if memory_occupy < self.memory_limit or split_dim == x.ndim: + if prev_cache is not None: + x = torch.cat([prev_cache, x], dim=split_dim - 1) + x = F.pad(x, padding, value=0.0) + with ignore_padding(self): + return super().forward(x) + + logger.debug( + f"Exceed memory limit {memory_occupy} > {self.memory_limit}, split dim {split_dim}" + ) + + # Split input (& prev_cache). + num_splits = math.ceil(memory_occupy / self.memory_limit) + size_per_split = x.size(split_dim) // num_splits + split_sizes = [size_per_split] * (num_splits - 1) + split_sizes += [x.size(split_dim) - sum(split_sizes)] + + x = list(x.split(split_sizes, dim=split_dim)) + logger.debug(f"Conv inputs: {[inp.size() for inp in x]} {x[0].dtype}") + if prev_cache is not None: + prev_cache = list(prev_cache.split(split_sizes, dim=split_dim)) + + # Loop Fwd. + cache = None + for idx in range(len(x)): + # Concat prev cache from last dim + if prev_cache is not None: + x[idx] = torch.cat([prev_cache[idx], x[idx]], dim=split_dim - 1) + + # Get padding pattern. 
+ lpad_dim = (x[idx].ndim - split_dim - 1) * 2 + rpad_dim = lpad_dim + 1 + padding = list(padding) + padding[lpad_dim] = self.padding[split_dim - 2] if idx == 0 else 0 + padding[rpad_dim] = self.padding[split_dim - 2] if idx == len(x) - 1 else 0 + pad_len = padding[lpad_dim] + padding[rpad_dim] + padding = tuple(padding) + + # Prepare cache for next slice (this dim). + next_cache = None + cache_len = cache.size(split_dim) if cache is not None else 0 + next_catch_size = get_cache_size( + conv_module=self, + input_len=x[idx].size(split_dim) + cache_len, + pad_len=pad_len, + dim=split_dim - 2, + ) + if next_catch_size != 0: + assert next_catch_size <= x[idx].size(split_dim) + next_cache = ( + x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim) + ) + + # Recursive. + x[idx] = self.memory_limit_conv( + x[idx], + split_dim=split_dim + 1, + padding=padding, + prev_cache=cache, + ) + + # Update cache. + cache = next_cache + + logger.debug(f"Conv outputs, concat(dim={split_dim}): {[d.size() for d in x]}") + return torch.cat(x, split_dim) + + def forward( + self, + input: Union[Tensor, List[Tensor]], + memory_state: MemoryState = MemoryState.UNSET, + ) -> Tensor: + assert memory_state != MemoryState.UNSET + if memory_state != MemoryState.ACTIVE: + self.memory = None + if ( + math.isinf(self.memory_limit) + and torch.is_tensor(input) + and get_sequence_parallel_group() is None + ): + return self.basic_forward(input, memory_state) + return self.slicing_forward(input, memory_state) + + def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET): + mem_size = self.stride[0] - self.kernel_size[0] + if (self.memory is not None) and (memory_state == MemoryState.ACTIVE): + input = extend_head(input, memory=self.memory, times=-1) + else: + input = extend_head(input, times=self.temporal_padding * 2) + memory = ( + input[:, :, mem_size:].detach() + if (mem_size != 0 and memory_state != MemoryState.DISABLED) + else None + ) + if ( + memory_state != MemoryState.DISABLED + and not self.training + and (self.memory_device is not None) + ): + self.memory = memory + if self.memory_device == "cpu" and self.memory is not None: + self.memory = self.memory.to("cpu") + return super().forward(input) + + def slicing_forward( + self, + input: Union[Tensor, List[Tensor]], + memory_state: MemoryState = MemoryState.UNSET, + ) -> Tensor: + squeeze_out = False + if torch.is_tensor(input): + input = [input] + squeeze_out = True + + cache_size = self.kernel_size[0] - self.stride[0] + cache = cache_send_recv( + input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2 + ) + + # For slice=4 and sp=2, and 17 frames in total + # sp0 sp1 + # slice 0: [`0 0` 0 1 2 {3 4}] [{3 4} 5 6 (7 8)] extend=`0 0` cache={3 4} memory=(7 8) + # slice 1: [(7 8) 9 10 {11 12}] [{11 12} 13 14 15 16] + sp_rank = get_sequence_parallel_rank() + sp_size = get_sequence_parallel_world_size() + sp_group = get_sequence_parallel_group() + send_dst = get_next_sequence_parallel_rank() + recv_src = get_prev_sequence_parallel_rank() + if ( + memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing + and not self.training + and (self.memory_device is not None) + and sp_rank in [0, sp_size - 1] + and cache_size != 0 + ): + if cache_size > input[-1].size(2) and cache is not None and len(input) == 1: + input[0] = torch.cat([cache, input[0]], dim=2) + cache = None + assert cache_size <= input[-1].size(2) + if sp_size == 1: + self.memory = input[-1][:, :, 
-cache_size:].detach().contiguous() + else: + if sp_rank == sp_size - 1: + dist.send( + input[-1][:, :, -cache_size:].detach().contiguous(), + send_dst, + group=sp_group, + ) + if sp_rank == 0: + shape = list(input[0].size()) + shape[2] = cache_size + self.memory = torch.empty( + *shape, device=input[0].device, dtype=input[0].dtype + ).contiguous() + dist.recv(self.memory, recv_src, group=sp_group) + if self.memory_device == "cpu" and self.memory is not None: + self.memory = self.memory.to("cpu") + + padding = tuple(x for x in reversed(self.padding) for _ in range(2)) + for i in range(len(input)): + # Prepare cache for next input slice. + next_cache = None + cache_size = 0 + if i < len(input) - 1: + cache_len = cache.size(2) if cache is not None else 0 + cache_size = get_cache_size(self, input[i].size(2) + cache_len, pad_len=0) + if cache_size != 0: + if cache_size > input[i].size(2) and cache is not None: + input[i] = torch.cat([cache, input[i]], dim=2) + cache = None + assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}" + next_cache = input[i][:, :, -cache_size:] + + # Conv forward for this input slice. + input[i] = self.memory_limit_conv( + input[i], + padding=padding, + prev_cache=cache, + ) + + # Update cache. + cache = next_cache + + return input[0] if squeeze_out else input + + def tflops(self, args, kwargs, output) -> float: + if torch.is_tensor(output): + output_numel = output.numel() + elif isinstance(output, list): + output_numel = sum(o.numel() for o in output) + else: + raise NotImplementedError + return (2 * math.prod(self.kernel_size) * self.in_channels * (output_numel / 1e6)) / 1e6 + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + if self.inflation_mode != "none": + state_dict = modify_state_dict( + self, + state_dict, + prefix, + inflate_weight_fn=inflate_weight, + inflate_bias_fn=inflate_bias, + ) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + (strict and self.inflation_mode == "none"), + missing_keys, + unexpected_keys, + error_msgs, + ) + + +def init_causal_conv3d( + *args, + inflation_mode: _inflation_mode_t, + **kwargs, +): + """ + Initialize a Causal-3D convolution layer. + Parameters: + inflation_mode: Listed as below. It's compatible with all the 3D-VAE checkpoints we have. + - none: No inflation will be conducted. + The loading logic of state dict will fall back to default. + - tail / replicate: Refer to the definition of `InflatedCausalConv3d`. 
+ """ + return InflatedCausalConv3d(*args, inflation_mode=inflation_mode, **kwargs) + + +def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: + input_dtype = x.dtype + if isinstance(norm_layer, (nn.LayerNorm, RMSNorm)): + if x.ndim == 4: + x = rearrange(x, "b c h w -> b h w c") + x = norm_layer(x) + x = rearrange(x, "b h w c -> b c h w") + return x.to(input_dtype) + if x.ndim == 5: + x = rearrange(x, "b c t h w -> b t h w c") + x = norm_layer(x) + x = rearrange(x, "b t h w c -> b c t h w") + return x.to(input_dtype) + if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): + if x.ndim <= 4: + return norm_layer(x).to(input_dtype) + if x.ndim == 5: + t = x.size(2) + x = rearrange(x, "b c t h w -> (b t) c h w") + memory_occupy = x.numel() * x.element_size() / 1024**3 + if isinstance(norm_layer, nn.GroupNorm) and memory_occupy > get_norm_limit(): + num_chunks = min(4 if x.element_size() == 2 else 2, norm_layer.num_groups) + logger.debug(f"large tensor {x.shape}, norm in {num_chunks} chunks") + assert norm_layer.num_groups % num_chunks == 0 + num_groups_per_chunk = norm_layer.num_groups // num_chunks + + x = list(x.chunk(num_chunks, dim=1)) + weights = norm_layer.weight.chunk(num_chunks, dim=0) + biases = norm_layer.bias.chunk(num_chunks, dim=0) + for i, (w, b) in enumerate(zip(weights, biases)): + x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps) + x[i] = x[i].to(input_dtype) + x = torch.cat(x, dim=1) + else: + x = norm_layer(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x.to(input_dtype) + raise NotImplementedError + + +def remove_head(tensor: Tensor, times: int = 1) -> Tensor: + """ + Remove duplicated first frame features in the up-sampling process. + """ + sp_rank = get_sequence_parallel_rank() + if times == 0 or sp_rank > 0: + return tensor + return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2) + + +def extend_head(tensor: Tensor, times: int = 2, memory: Optional[Tensor] = None) -> Tensor: + """ + When memory is None: + - Duplicate first frame features in the down-sampling process. + When memory is not None: + - Concatenate memory features with the input features to keep temporal consistency. + """ + if memory is not None: + return torch.cat((memory.to(tensor), tensor), dim=2) + assert times >= 0, "Invalid input for function 'extend_head'!" + if times == 0: + return tensor + else: + tile_repeat = [1] * tensor.ndim + tile_repeat[2] = times + return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2) + + +def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str): + """ + Inflate a 2D convolution weight matrix to a 3D one. + Parameters: + weight_2d: The weight matrix of 2D conv to be inflated. + weight_3d: The weight matrix of 3D conv to be initialized. + inflation_mode: the mode of inflation + """ + assert inflation_mode in ["tail", "replicate"] + assert weight_3d.shape[:2] == weight_2d.shape[:2] + with torch.no_grad(): + if inflation_mode == "replicate": + depth = weight_3d.size(2) + weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) + else: + weight_3d.fill_(0.0) + weight_3d[:, :, -1].copy_(weight_2d) + return weight_3d + + +def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str): + """ + Inflate a 2D convolution bias tensor to a 3D one + Parameters: + bias_2d: The bias tensor of 2D conv to be inflated. + bias_3d: The bias tensor of 3D conv to be initialized. 
+ inflation_mode: Placeholder to align `inflate_weight`. + """ + assert bias_3d.shape == bias_2d.shape + with torch.no_grad(): + bias_3d.copy_(bias_2d) + return bias_3d + + +def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn): + """ + the main function to inflated 2D parameters to 3D. + """ + weight_name = prefix + "weight" + bias_name = prefix + "bias" + if weight_name in state_dict: + weight_2d = state_dict[weight_name] + if weight_2d.dim() == 4: + # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w) + weight_3d = inflate_weight_fn( + weight_2d=weight_2d, + weight_3d=layer.weight, + inflation_mode=layer.inflation_mode, + ) + state_dict[weight_name] = weight_3d + else: + return state_dict + # It's a 3d state dict, should not do inflation on both bias and weight. + if bias_name in state_dict: + bias_2d = state_dict[bias_name] + if bias_2d.dim() == 1: + # Assuming the 2D biases are 1D tensors (out_channels,) + bias_3d = inflate_bias_fn( + bias_2d=bias_2d, + bias_3d=layer.bias, + inflation_mode=layer.inflation_mode, + ) + state_dict[bias_name] = bias_3d + return state_dict diff --git a/models/video_vae_v3/modules/context_parallel_lib.py b/models/video_vae_v3/modules/context_parallel_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..55cfe481ee2ade7434166bfda0b83589b423137c --- /dev/null +++ b/models/video_vae_v3/modules/context_parallel_lib.py @@ -0,0 +1,164 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import List +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import Tensor + +from common.distributed import get_device +from common.distributed.advanced import ( + get_next_sequence_parallel_rank, + get_prev_sequence_parallel_rank, + get_sequence_parallel_group, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, +) +from common.distributed.ops import Gather +from common.logger import get_logger +from models.video_vae_v3.modules.types import MemoryState + +logger = get_logger(__name__) + + +def causal_conv_slice_inputs(x, split_size, memory_state): + sp_size = get_sequence_parallel_world_size() + sp_group = get_sequence_parallel_group() + sp_rank = get_sequence_parallel_rank() + if sp_group is None: + return x + + assert memory_state != MemoryState.UNSET + leave_out = 1 if memory_state != MemoryState.ACTIVE else 0 + + # Should have at least sp_size slices. 
+ num_slices = (x.size(2) - leave_out) // split_size + assert num_slices >= sp_size, f"{num_slices} < {sp_size}" + + split_sizes = [split_size + leave_out] + [split_size] * (num_slices - 1) + split_sizes += [x.size(2) - sum(split_sizes)] + assert sum(split_sizes) == x.size(2) + + split_sizes = torch.tensor(split_sizes) + slices_per_rank = len(split_sizes) // sp_size + split_sizes = split_sizes.split( + [slices_per_rank] * (sp_size - 1) + [len(split_sizes) - slices_per_rank * (sp_size - 1)] + ) + split_sizes = list(map(lambda s: s.sum().item(), split_sizes)) + logger.debug(f"split_sizes: {split_sizes}") + return x.split(split_sizes, dim=2)[sp_rank] + + +def causal_conv_gather_outputs(x): + sp_group = get_sequence_parallel_group() + sp_size = get_sequence_parallel_world_size() + if sp_group is None: + return x + + # Communicate shapes. + unpad_lens = torch.empty((sp_size,), device=get_device(), dtype=torch.long) + local_unpad_len = torch.tensor([x.size(2)], device=get_device(), dtype=torch.long) + torch.distributed.all_gather_into_tensor(unpad_lens, local_unpad_len, group=sp_group) + + # Padding to max_len for gather. + max_len = unpad_lens.max() + x_pad = F.pad(x, (0, 0, 0, 0, 0, max_len - x.size(2))).contiguous() + + # Gather outputs. + x_pad = Gather.apply(sp_group, x_pad, 2, True) + + # Remove padding. + x_pad_lists = list(x_pad.chunk(sp_size, dim=2)) + for i, (x_pad, unpad_len) in enumerate(zip(x_pad_lists, unpad_lens)): + x_pad_lists[i] = x_pad[:, :, :unpad_len] + + return torch.cat(x_pad_lists, dim=2) + + +def get_output_len(conv_module, input_len, pad_len, dim=0): + dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 + output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 + return output_len + + +def get_cache_size(conv_module, input_len, pad_len, dim=0): + dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 + output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 + remain_len = ( + input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size) + ) + overlap_len = dilated_kernerl_size - conv_module.stride[dim] + cache_len = overlap_len + remain_len # >= 0 + logger.debug( + f"I:{input_len}, " + f"P:{pad_len}, " + f"K:{conv_module.kernel_size[dim]}, " + f"S:{conv_module.stride[dim]}, " + f"O:{output_len}, " + f"Cache:{cache_len}" + ) + assert output_len > 0 + return cache_len + + +def cache_send_recv(tensor: List[Tensor], cache_size, times, memory=None): + sp_group = get_sequence_parallel_group() + sp_rank = get_sequence_parallel_rank() + sp_size = get_sequence_parallel_world_size() + send_dst = get_next_sequence_parallel_rank() + recv_src = get_prev_sequence_parallel_rank() + recv_buffer = None + recv_req = None + + logger.debug( + f"[sp{sp_rank}] cur_tensors:{[(t.size(), t.dtype) for t in tensor]}, times: {times}" + ) + if sp_rank == 0 or sp_group is None: + if memory is not None: + recv_buffer = memory.to(tensor[0]) + elif times > 0: + tile_repeat = [1] * tensor[0].ndim + tile_repeat[2] = times + recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat) + + if cache_size != 0 and sp_group is not None: + if sp_rank > 0: + shape = list(tensor[0].size()) + shape[2] = cache_size + recv_buffer = torch.empty( + *shape, device=tensor[0].device, dtype=tensor[0].dtype + ).contiguous() + recv_req = dist.irecv(recv_buffer, recv_src, group=sp_group) + if sp_rank < sp_size - 1: + if cache_size > tensor[-1].size(2) and 
len(tensor) == 1: + logger.debug(f"[sp{sp_rank}] force concat before send {tensor[-1].size()}") + if recv_req is not None: + recv_req.wait() + tensor[0] = torch.cat([recv_buffer, tensor[0]], dim=2) + recv_buffer = None + assert cache_size <= tensor[-1].size( + 2 + ), f"Not enough value to cache, got {tensor[-1].size()}, cache_size={cache_size}" + dist.isend( + tensor[-1][:, :, -cache_size:].detach().contiguous(), send_dst, group=sp_group + ) + if recv_req is not None: + recv_req.wait() + + logger.debug( + f"[sp{sp_rank}] recv_src:{recv_src}, " + f"recv_buffer:{recv_buffer.size() if recv_buffer is not None else None}" + ) + return recv_buffer diff --git a/models/video_vae_v3/modules/global_config.py b/models/video_vae_v3/modules/global_config.py new file mode 100644 index 0000000000000000000000000000000000000000..863117570a8aadde38b8eae8f1aa16480cd9f7ca --- /dev/null +++ b/models/video_vae_v3/modules/global_config.py @@ -0,0 +1,28 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import Optional + +_NORM_LIMIT = float("inf") + + +def get_norm_limit(): + return _NORM_LIMIT + + +def set_norm_limit(value: Optional[float] = None): + global _NORM_LIMIT + if value is None: + value = float("inf") + _NORM_LIMIT = value diff --git a/models/video_vae_v3/modules/inflated_layers.py b/models/video_vae_v3/modules/inflated_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..8dfa4841a4ba3e4f758831396497b246614c7bb5 --- /dev/null +++ b/models/video_vae_v3/modules/inflated_layers.py @@ -0,0 +1,106 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
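For intuition about the cache bookkeeping in `get_cache_size` above: the frames a causal conv must carry over to the next temporal slice are the kernel/stride overlap plus whatever the stride leaves unconsumed. A standalone restatement of the same arithmetic (an illustrative helper, not part of the patch):

def cache_size(input_len: int, pad_len: int, kernel: int, stride: int, dilation: int = 1) -> int:
    # Same formula as get_cache_size, minus the conv module and logging.
    dilated_kernel = dilation * (kernel - 1) + 1
    output_len = (input_len + pad_len - dilated_kernel) // stride + 1
    remain = input_len + pad_len - ((output_len - 1) * stride + dilated_kernel)
    return (dilated_kernel - stride) + remain


print(cache_size(input_len=8, pad_len=0, kernel=3, stride=1))  # 2: two overlapping frames
print(cache_size(input_len=9, pad_len=0, kernel=3, stride=2))  # 1: one overlap, no remainder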
+ +from functools import partial +from typing import Literal, Optional +from torch import Tensor +from torch.nn import Conv3d + +from models.video_vae_v3.modules.inflated_lib import ( + MemoryState, + extend_head, + inflate_bias, + inflate_weight, + modify_state_dict, +) + +_inflation_mode_t = Literal["none", "tail", "replicate"] +_memory_device_t = Optional[Literal["cpu", "same"]] + + +class InflatedCausalConv3d(Conv3d): + def __init__( + self, + *args, + inflation_mode: _inflation_mode_t, + memory_device: _memory_device_t = "same", + **kwargs, + ): + self.inflation_mode = inflation_mode + self.memory = None + super().__init__(*args, **kwargs) + self.temporal_padding = self.padding[0] + self.memory_device = memory_device + self.padding = (0, *self.padding[1:]) # Remove temporal pad to keep causal. + + def set_memory_device(self, memory_device: _memory_device_t): + self.memory_device = memory_device + + def forward(self, input: Tensor, memory_state: MemoryState = MemoryState.DISABLED) -> Tensor: + mem_size = self.stride[0] - self.kernel_size[0] + if (self.memory is not None) and (memory_state == MemoryState.ACTIVE): + input = extend_head(input, memory=self.memory) + else: + input = extend_head(input, times=self.temporal_padding * 2) + memory = ( + input[:, :, mem_size:].detach() + if (mem_size != 0 and memory_state != MemoryState.DISABLED) + else None + ) + if ( + memory_state != MemoryState.DISABLED + and not self.training + and (self.memory_device is not None) + ): + self.memory = memory + if self.memory_device == "cpu" and self.memory is not None: + self.memory = self.memory.to("cpu") + return super().forward(input) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + if self.inflation_mode != "none": + state_dict = modify_state_dict( + self, + state_dict, + prefix, + inflate_weight_fn=partial(inflate_weight, position="tail"), + inflate_bias_fn=partial(inflate_bias, position="tail"), + ) + super()._load_from_state_dict( + state_dict, + prefix, + local_metadata, + (strict and self.inflation_mode == "none"), + missing_keys, + unexpected_keys, + error_msgs, + ) + + +def init_causal_conv3d( + *args, + inflation_mode: _inflation_mode_t, + **kwargs, +): + """ + Initialize a Causal-3D convolution layer. + Parameters: + inflation_mode: Listed as below. It's compatible with all the 3D-VAE checkpoints we have. + - none: No inflation will be conducted. + The loading logic of state dict will fall back to default. + - tail / replicate: Refer to the definition of `InflatedCausalConv3d`. + """ + return InflatedCausalConv3d(*args, inflation_mode=inflation_mode, **kwargs) diff --git a/models/video_vae_v3/modules/inflated_lib.py b/models/video_vae_v3/modules/inflated_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..cbdaf3138bb5994c4702185426f854a4660cc6a4 --- /dev/null +++ b/models/video_vae_v3/modules/inflated_lib.py @@ -0,0 +1,156 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# // See the License for the specific language governing permissions and +# // limitations under the License. + +from enum import Enum +from typing import Optional +import numpy as np +import torch +from diffusers.models.normalization import RMSNorm +from einops import rearrange +from torch import Tensor, nn + +from common.logger import get_logger + +logger = get_logger(__name__) + + +class MemoryState(Enum): + """ + State[Disabled]: No memory bank will be enabled. + State[Initializing]: The model is handling the first clip, + need to reset / initialize the memory bank. + State[Active]: There has been some data in the memory bank. + """ + + DISABLED = 0 + INITIALIZING = 1 + ACTIVE = 2 + + +def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: + if isinstance(norm_layer, (nn.LayerNorm, RMSNorm)): + if x.ndim == 4: + x = rearrange(x, "b c h w -> b h w c") + x = norm_layer(x) + x = rearrange(x, "b h w c -> b c h w") + return x + if x.ndim == 5: + x = rearrange(x, "b c t h w -> b t h w c") + x = norm_layer(x) + x = rearrange(x, "b t h w c -> b c t h w") + return x + if isinstance(norm_layer, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): + if x.ndim <= 4: + return norm_layer(x) + if x.ndim == 5: + t = x.size(2) + x = rearrange(x, "b c t h w -> (b t) c h w") + x = norm_layer(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x + raise NotImplementedError + + +def remove_head(tensor: Tensor, times: int = 1) -> Tensor: + """ + Remove duplicated first frame features in the up-sampling process. + """ + if times == 0: + return tensor + return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2) + + +def extend_head( + tensor: Tensor, times: Optional[int] = 2, memory: Optional[Tensor] = None +) -> Tensor: + """ + When memory is None: + - Duplicate first frame features in the down-sampling process. + When memory is not None: + - Concatenate memory features with the input features to keep temporal consistency. + """ + if times == 0: + return tensor + if memory is not None: + return torch.cat((memory.to(tensor), tensor), dim=2) + else: + tile_repeat = np.ones(tensor.ndim).astype(int) + tile_repeat[2] = times + return torch.cat(tensors=(torch.tile(tensor[:, :, :1], list(tile_repeat)), tensor), dim=2) + + +def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor, inflation_mode: str): + """ + Inflate a 2D convolution weight matrix to a 3D one. + Parameters: + weight_2d: The weight matrix of 2D conv to be inflated. + weight_3d: The weight matrix of 3D conv to be initialized. + inflation_mode: the mode of inflation + """ + assert inflation_mode in ["constant", "replicate"] + assert weight_3d.shape[:2] == weight_2d.shape[:2] + with torch.no_grad(): + if inflation_mode == "replicate": + depth = weight_3d.size(2) + weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) + else: + weight_3d.fill_(0.0) + weight_3d[:, :, -1].copy_(weight_2d) + return weight_3d + + +def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor, inflation_mode: str): + """ + Inflate a 2D convolution bias tensor to a 3D one + Parameters: + bias_2d: The bias tensor of 2D conv to be inflated. + bias_3d: The bias tensor of 3D conv to be initialized. + inflation_mode: Placeholder to align `inflate_weight`. 
+ """ + assert bias_3d.shape == bias_2d.shape + with torch.no_grad(): + bias_3d.copy_(bias_2d) + return bias_3d + + +def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn): + """ + the main function to inflated 2D parameters to 3D. + """ + weight_name = prefix + "weight" + bias_name = prefix + "bias" + if weight_name in state_dict: + weight_2d = state_dict[weight_name] + if weight_2d.dim() == 4: + # Assuming the 2D weights are 4D tensors (out_channels, in_channels, h, w) + weight_3d = inflate_weight_fn( + weight_2d=weight_2d, + weight_3d=layer.weight, + inflation_mode=layer.inflation_mode, + ) + state_dict[weight_name] = weight_3d + else: + return state_dict + # It's a 3d state dict, should not do inflation on both bias and weight. + if bias_name in state_dict: + bias_2d = state_dict[bias_name] + if bias_2d.dim() == 1: + # Assuming the 2D biases are 1D tensors (out_channels,) + bias_3d = inflate_bias_fn( + bias_2d=bias_2d, + bias_3d=layer.bias, + inflation_mode=layer.inflation_mode, + ) + state_dict[bias_name] = bias_3d + return state_dict diff --git a/models/video_vae_v3/modules/types.py b/models/video_vae_v3/modules/types.py new file mode 100644 index 0000000000000000000000000000000000000000..5a030d2d284f9535f2a84c1f9befcd3f82d8d9ff --- /dev/null +++ b/models/video_vae_v3/modules/types.py @@ -0,0 +1,76 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from enum import Enum +from typing import Dict, Literal, NamedTuple, Optional +import torch + +_receptive_field_t = Literal["half", "full"] +_inflation_mode_t = Literal["none", "tail", "replicate"] +_memory_device_t = Optional[Literal["cpu", "same"]] +_gradient_checkpointing_t = Optional[Literal["half", "full"]] +_selective_checkpointing_t = Optional[Literal["coarse", "fine"]] + +class DiagonalGaussianDistribution: + def __init__(self, mean: torch.Tensor, logvar: torch.Tensor): + self.mean = mean + self.logvar = torch.clamp(logvar, -30.0, 20.0) + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + + def mode(self) -> torch.Tensor: + return self.mean + + def sample(self) -> torch.FloatTensor: + return self.mean + self.std * torch.randn_like(self.mean) + + def kl(self) -> torch.Tensor: + return 0.5 * torch.sum( + self.mean**2 + self.var - 1.0 - self.logvar, + dim=list(range(1, self.mean.ndim)), + ) + +class MemoryState(Enum): + """ + State[Disabled]: No memory bank will be enabled. + State[Initializing]: The model is handling the first clip, need to reset the memory bank. + State[Active]: There has been some data in the memory bank. + State[Unset]: Error state, indicating users didn't pass correct memory state in. 
+ """ + + DISABLED = 0 + INITIALIZING = 1 + ACTIVE = 2 + UNSET = 3 + + +class QuantizerOutput(NamedTuple): + latent: torch.Tensor + extra_loss: torch.Tensor + statistics: Dict[str, torch.Tensor] + + +class CausalAutoencoderOutput(NamedTuple): + sample: torch.Tensor + latent: torch.Tensor + posterior: Optional[DiagonalGaussianDistribution] + + +class CausalEncoderOutput(NamedTuple): + latent: torch.Tensor + posterior: Optional[DiagonalGaussianDistribution] + + +class CausalDecoderOutput(NamedTuple): + sample: torch.Tensor diff --git a/models/video_vae_v3/modules/video_vae.py b/models/video_vae_v3/modules/video_vae.py new file mode 100644 index 0000000000000000000000000000000000000000..1b169431c637ba273de7e6a2340c64206746ef28 --- /dev/null +++ b/models/video_vae_v3/modules/video_vae.py @@ -0,0 +1,955 @@ +# Copyright (c) 2023 HuggingFace Team +# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates. +# SPDX-License-Identifier: Apache License, Version 2.0 (the "License") +# +# This file has been modified by ByteDance Ltd. and/or its affiliates. on 1st June 2025 +# +# Original file was released under Apache License, Version 2.0 (the "License"), with the full license text +# available at http://www.apache.org/licenses/LICENSE-2.0. +# +# This modified file is released under the same license. + +from contextlib import nullcontext +from typing import Optional, Tuple, Literal, Callable, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution +from einops import rearrange + +from common.distributed.advanced import get_sequence_parallel_world_size +from common.logger import get_logger +from models.video_vae_v3.modules.causal_inflation_lib import ( + InflatedCausalConv3d, + causal_norm_wrapper, + init_causal_conv3d, + remove_head, +) +from models.video_vae_v3.modules.context_parallel_lib import ( + causal_conv_gather_outputs, + causal_conv_slice_inputs, +) +from models.video_vae_v3.modules.global_config import set_norm_limit +from models.video_vae_v3.modules.types import ( + CausalAutoencoderOutput, + CausalDecoderOutput, + CausalEncoderOutput, + MemoryState, + _inflation_mode_t, + _memory_device_t, + _receptive_field_t, + _selective_checkpointing_t, +) + +logger = get_logger(__name__) # pylint: disable=invalid-name + +# Fake func, no checkpointing is required for inference +def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs): + return module(*args, **kwargs) + +class ResnetBlock2D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv2d layer. + If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. 
+ """ + + def __init__( + self, *, in_channels: int, out_channels: Optional[int] = None, dropout: float = 0.0 + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.nonlinearity = nn.SiLU() + + self.norm1 = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.norm2 = torch.nn.GroupNorm( + num_groups=32, num_channels=out_channels, eps=1e-6, affine=True + ) + + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + self.use_in_shortcut = self.in_channels != out_channels + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + hidden = input_tensor + + hidden = self.norm1(hidden) + hidden = self.nonlinearity(hidden) + hidden = self.conv1(hidden) + + hidden = self.norm2(hidden) + hidden = self.nonlinearity(hidden) + hidden = self.dropout(hidden) + hidden = self.conv2(hidden) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden + + return output_tensor + +class Upsample3D(nn.Module): + """A 3D upsampling layer.""" + + def __init__( + self, + channels: int, + inflation_mode: _inflation_mode_t = "tail", + temporal_up: bool = False, + spatial_up: bool = True, + slicing: bool = False, + ): + super().__init__() + self.channels = channels + self.conv = init_causal_conv3d( + self.channels, self.channels, kernel_size=3, padding=1, inflation_mode=inflation_mode + ) + + self.temporal_up = temporal_up + self.spatial_up = spatial_up + self.temporal_ratio = 2 if temporal_up else 1 + self.spatial_ratio = 2 if spatial_up else 1 + self.slicing = slicing + + upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio + self.upscale_conv = nn.Conv3d( + self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 + ) + identity = ( + torch.eye(self.channels).repeat(upscale_ratio, 1).reshape_as(self.upscale_conv.weight) + ) + + self.upscale_conv.weight.data.copy_(identity) + nn.init.zeros_(self.upscale_conv.bias) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + memory_state: MemoryState, + ) -> torch.FloatTensor: + return gradient_checkpointing( + self.custom_forward, + hidden_states, + memory_state, + enabled=self.training and self.gradient_checkpointing, + ) + + def custom_forward( + self, + hidden_states: torch.FloatTensor, + memory_state: MemoryState, + ) -> torch.FloatTensor: + assert hidden_states.shape[1] == self.channels + + if self.slicing: + split_size = hidden_states.size(2) // 2 + hidden_states = list( + hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2) + ) + else: + hidden_states = [hidden_states] + + for i in range(len(hidden_states)): + hidden_states[i] = self.upscale_conv(hidden_states[i]) + hidden_states[i] = rearrange( + hidden_states[i], + "b (x y z c) f h w -> b c (f z) (h x) (w y)", + x=self.spatial_ratio, + y=self.spatial_ratio, + z=self.temporal_ratio, + ) + + # [Overridden] For causal temporal conv + if self.temporal_up and memory_state != MemoryState.ACTIVE: + hidden_states[0] = remove_head(hidden_states[0]) + + if self.slicing: 
+ hidden_states = self.conv(hidden_states, memory_state=memory_state) + return torch.cat(hidden_states, dim=2) + else: + return self.conv(hidden_states[0], memory_state=memory_state) + + +class Downsample3D(nn.Module): + """A 3D downsampling layer.""" + + def __init__( + self, + channels: int, + inflation_mode: _inflation_mode_t = "tail", + temporal_down: bool = False, + spatial_down: bool = True, + ): + super().__init__() + self.channels = channels + self.temporal_down = temporal_down + self.spatial_down = spatial_down + + self.temporal_ratio = 2 if temporal_down else 1 + self.spatial_ratio = 2 if spatial_down else 1 + + self.temporal_kernel = 3 if temporal_down else 1 + self.spatial_kernel = 3 if spatial_down else 1 + + self.conv = init_causal_conv3d( + self.channels, + self.channels, + kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel), + stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), + padding=((1 if self.temporal_down else 0), 0, 0), + inflation_mode=inflation_mode, + ) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + memory_state: MemoryState, + ) -> torch.FloatTensor: + return gradient_checkpointing( + self.custom_forward, + hidden_states, + memory_state, + enabled=self.training and self.gradient_checkpointing, + ) + + def custom_forward( + self, + hidden_states: torch.FloatTensor, + memory_state: MemoryState, + ) -> torch.FloatTensor: + + assert hidden_states.shape[1] == self.channels + + if self.spatial_down: + hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0) + + hidden_states = self.conv(hidden_states, memory_state=memory_state) + return hidden_states + + +class ResnetBlock3D(ResnetBlock2D): + def __init__( + self, + *args, + inflation_mode: _inflation_mode_t = "tail", + time_receptive_field: _receptive_field_t = "half", + **kwargs, + ): + super().__init__(*args, **kwargs) + self.conv1 = init_causal_conv3d( + self.in_channels, + self.out_channels, + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.conv2 = init_causal_conv3d( + self.out_channels, + self.out_channels, + kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3), + stride=1, + padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1), + inflation_mode=inflation_mode, + ) + + if self.use_in_shortcut: + self.conv_shortcut = init_causal_conv3d( + self.in_channels, + self.out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=(self.conv_shortcut.bias is not None), + inflation_mode=inflation_mode, + ) + self.gradient_checkpointing = False + + def forward(self, input_tensor: torch.Tensor, memory_state: MemoryState = MemoryState.UNSET): + return gradient_checkpointing( + self.custom_forward, + input_tensor, + memory_state, + enabled=self.training and self.gradient_checkpointing, + ) + + def custom_forward( + self, input_tensor: torch.Tensor, memory_state: MemoryState = MemoryState.UNSET + ): + assert memory_state != MemoryState.UNSET + hidden_states = input_tensor + + hidden_states = causal_norm_wrapper(self.norm1, hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.conv1(hidden_states, memory_state=memory_state) + + hidden_states = causal_norm_wrapper(self.norm2, hidden_states) + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states, memory_state=memory_state) + + if self.conv_shortcut is not None: + 
input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class DownEncoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + add_downsample: bool = True, + inflation_mode: _inflation_mode_t = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_down: bool = True, + spatial_down: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + self.downsamplers = None + if add_downsample: + # Todo: Refactor this line before V5 Image VAE Training. + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + channels=out_channels, + inflation_mode=inflation_mode, + temporal_down=temporal_down, + spatial_down=spatial_down, + ) + ] + ) + + def forward( + self, hidden_states: torch.FloatTensor, memory_state: MemoryState + ) -> torch.FloatTensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, memory_state=memory_state) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, memory_state=memory_state) + + return hidden_states + + +class UpDecoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + add_upsample: bool = True, + inflation_mode: _inflation_mode_t = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_up: bool = True, + spatial_up: bool = True, + slicing: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=input_channels, + out_channels=out_channels, + dropout=dropout, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + self.upsamplers = None + # Todo: Refactor this line before V5 Image VAE Training. 
+ if add_upsample: + self.upsamplers = nn.ModuleList( + [ + Upsample3D( + channels=out_channels, + inflation_mode=inflation_mode, + temporal_up=temporal_up, + spatial_up=spatial_up, + slicing=slicing, + ) + ] + ) + + def forward( + self, hidden_states: torch.FloatTensor, memory_state: MemoryState + ) -> torch.FloatTensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, memory_state=memory_state) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, memory_state=memory_state) + + return hidden_states + + +class UNetMidBlock3D(nn.Module): + def __init__( + self, + channels: int, + dropout: float = 0.0, + inflation_mode: _inflation_mode_t = "tail", + time_receptive_field: _receptive_field_t = "half", + ): + super().__init__() + self.resnets = nn.ModuleList( + [ + ResnetBlock3D( + in_channels=channels, + out_channels=channels, + dropout=dropout, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ), + ResnetBlock3D( + in_channels=channels, + out_channels=channels, + dropout=dropout, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ), + ] + ) + + def forward(self, hidden_states: torch.Tensor, memory_state: MemoryState): + for resnet in self.resnets: + hidden_states = resnet(hidden_states, memory_state) + return hidden_states + + +class Encoder3D(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes + its input into a latent representation. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + double_z: bool = True, + temporal_down_num: int = 2, + inflation_mode: _inflation_mode_t = "tail", + time_receptive_field: _receptive_field_t = "half", + selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.temporal_down_num = temporal_down_num + + self.conv_in = init_causal_conv3d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 + # Note: take the last one + + down_block = DownEncoderBlock3D( + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + temporal_down=is_temporal_down_block, + spatial_down=True, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3D( + channels=block_out_channels[-1], + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # out + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], num_groups=32, eps=1e-6 + ) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = init_causal_conv3d( + block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode + ) + + assert len(selective_checkpointing) == len(self.down_blocks) + self.set_gradient_checkpointing(selective_checkpointing) + + def 
set_gradient_checkpointing(self, checkpointing_types): + gradient_checkpointing = [] + for down_block, sac_type in zip(self.down_blocks, checkpointing_types): + if sac_type == "coarse": + gradient_checkpointing.append(True) + elif sac_type == "fine": + for n, m in down_block.named_modules(): + if hasattr(m, "gradient_checkpointing"): + m.gradient_checkpointing = True + logger.debug(f"set gradient_checkpointing: {n}") + gradient_checkpointing.append(False) + else: + gradient_checkpointing.append(False) + self.gradient_checkpointing = gradient_checkpointing + logger.info(f"[Encoder3D] gradient_checkpointing: {checkpointing_types}") + + def forward(self, sample: torch.FloatTensor, memory_state: MemoryState) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + sample = self.conv_in(sample, memory_state=memory_state) + # down + for down_block, sac in zip(self.down_blocks, self.gradient_checkpointing): + sample = gradient_checkpointing( + down_block, + sample, + memory_state=memory_state, + enabled=self.training and sac, + ) + + # middle + sample = self.mid_block(sample, memory_state=memory_state) + + # post-process + sample = causal_norm_wrapper(self.conv_norm_out, sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample, memory_state=memory_state) + + return sample + + +class Decoder3D(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that + decodes its latent representation into an output sample. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + inflation_mode: _inflation_mode_t = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_up_num: int = 2, + slicing_up_num: int = 0, + selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), + ): + super().__init__() + self.layers_per_block = layers_per_block + self.temporal_up_num = temporal_up_num + + self.conv_in = init_causal_conv3d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.up_blocks = nn.ModuleList([]) + + # mid + self.mid_block = UNetMidBlock3D( + channels=block_out_channels[-1], + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + is_temporal_up_block = i < self.temporal_up_num + is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num + # Note: Keep symmetric + + up_block = UpDecoderBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block, + temporal_up=is_temporal_up_block, + slicing=is_slicing_up_block, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + self.up_blocks.append(up_block) + + # out + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=32, eps=1e-6 + ) + self.conv_act = nn.SiLU() + self.conv_out = init_causal_conv3d( + block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode + ) + + assert len(selective_checkpointing) == len(self.up_blocks) + self.set_gradient_checkpointing(selective_checkpointing) + + def 
set_gradient_checkpointing(self, checkpointing_types): + gradient_checkpointing = [] + for up_block, sac_type in zip(self.up_blocks, checkpointing_types): + if sac_type == "coarse": + gradient_checkpointing.append(True) + elif sac_type == "fine": + for n, m in up_block.named_modules(): + if hasattr(m, "gradient_checkpointing"): + m.gradient_checkpointing = True + logger.debug(f"set gradient_checkpointing: {n}") + gradient_checkpointing.append(False) + else: + gradient_checkpointing.append(False) + self.gradient_checkpointing = gradient_checkpointing + logger.info(f"[Decoder3D] gradient_checkpointing: {checkpointing_types}") + + def forward(self, sample: torch.FloatTensor, memory_state: MemoryState) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample, memory_state=memory_state) + + # middle + sample = self.mid_block(sample, memory_state=memory_state) + + # up + for up_block, sac in zip(self.up_blocks, self.gradient_checkpointing): + sample = gradient_checkpointing( + up_block, + sample, + memory_state=memory_state, + enabled=self.training and sac, + ) + + # post-process + sample = causal_norm_wrapper(self.conv_norm_out, sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample, memory_state=memory_state) + + return sample + + +class VideoAutoencoderKL(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + latent_channels: int = 4, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + enc_selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), + dec_selective_checkpointing: Tuple[_selective_checkpointing_t] = ("none",), + temporal_scale_num: int = 3, + slicing_up_num: int = 0, + inflation_mode: _inflation_mode_t = "tail", + time_receptive_field: _receptive_field_t = "half", + slicing_sample_min_size: int = None, + spatial_downsample_factor: int = 16, + temporal_downsample_factor: int = 8, + freeze_encoder: bool = False, + ): + super().__init__() + self.spatial_downsample_factor = spatial_downsample_factor + self.temporal_downsample_factor = temporal_downsample_factor + self.freeze_encoder = freeze_encoder + if slicing_sample_min_size is None: + slicing_sample_min_size = temporal_downsample_factor + self.slicing_sample_min_size = slicing_sample_min_size + self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) + + # pass init params to Encoder + self.encoder = Encoder3D( + in_channels=in_channels, + out_channels=latent_channels, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + double_z=True, + temporal_down_num=temporal_scale_num, + selective_checkpointing=enc_selective_checkpointing, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # pass init params to Decoder + self.decoder = Decoder3D( + in_channels=latent_channels, + out_channels=out_channels, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + # [Override] add temporal_up_num parameter + temporal_up_num=temporal_scale_num, + slicing_up_num=slicing_up_num, + selective_checkpointing=dec_selective_checkpointing, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + self.quant_conv = ( + init_causal_conv3d( + in_channels=2 * latent_channels, + out_channels=2 * latent_channels, + kernel_size=1, + inflation_mode=inflation_mode, + ) + if use_quant_conv + else None + ) + self.post_quant_conv = 
( + init_causal_conv3d( + in_channels=latent_channels, + out_channels=latent_channels, + kernel_size=1, + inflation_mode=inflation_mode, + ) + if use_post_quant_conv + else None + ) + + self.use_slicing = False + + def enable_slicing(self): + self.use_slicing = True + + def disable_slicing(self): + self.use_slicing = False + + def encode(self, x: torch.FloatTensor) -> CausalEncoderOutput: + if x.ndim == 4: + x = x.unsqueeze(2) + h = self.slicing_encode(x) + p = DiagonalGaussianDistribution(h) + z = p.sample() + return CausalEncoderOutput(z, p) + + def decode(self, z: torch.FloatTensor) -> CausalDecoderOutput: + if z.ndim == 4: + z = z.unsqueeze(2) + x = self.slicing_decode(z) + return CausalDecoderOutput(x) + + def _encode(self, x: torch.Tensor, memory_state: MemoryState) -> torch.Tensor: + x = causal_conv_slice_inputs(x, self.slicing_sample_min_size, memory_state=memory_state) + h = self.encoder(x, memory_state=memory_state) + h = self.quant_conv(h, memory_state=memory_state) if self.quant_conv is not None else h + h = causal_conv_gather_outputs(h) + return h + + def _decode(self, z: torch.Tensor, memory_state: MemoryState) -> torch.Tensor: + z = causal_conv_slice_inputs(z, self.slicing_latent_min_size, memory_state=memory_state) + z = ( + self.post_quant_conv(z, memory_state=memory_state) + if self.post_quant_conv is not None + else z + ) + x = self.decoder(z, memory_state=memory_state) + x = causal_conv_gather_outputs(x) + return x + + def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: + sp_size = get_sequence_parallel_world_size() + if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: + x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2) + encoded_slices = [ + self._encode( + torch.cat((x[:, :, :1], x_slices[0]), dim=2), + memory_state=MemoryState.INITIALIZING, + ) + ] + for x_idx in range(1, len(x_slices)): + encoded_slices.append( + self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE) + ) + return torch.cat(encoded_slices, dim=2) + else: + return self._encode(x, memory_state=MemoryState.DISABLED) + + def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: + sp_size = get_sequence_parallel_world_size() + if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: + z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) + decoded_slices = [ + self._decode( + torch.cat((z[:, :, :1], z_slices[0]), dim=2), + memory_state=MemoryState.INITIALIZING, + ) + ] + for z_idx in range(1, len(z_slices)): + decoded_slices.append( + self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE) + ) + return torch.cat(decoded_slices, dim=2) + else: + return self._decode(z, memory_state=MemoryState.DISABLED) + + def forward(self, x: torch.FloatTensor) -> CausalAutoencoderOutput: + with torch.no_grad() if self.freeze_encoder else nullcontext(): + z, p = self.encode(x) + x = self.decode(z).sample + return CausalAutoencoderOutput(x, z, p) + + def preprocess(self, x: torch.Tensor): + # x should in [B, C, T, H, W], [B, C, H, W] + assert x.ndim == 4 or x.size(2) % self.temporal_downsample_factor == 1 + return x + + def postprocess(self, x: torch.Tensor): + # x should in [B, C, T, H, W], [B, C, H, W] + return x + + def set_causal_slicing( + self, + *, + split_size: Optional[int], + memory_device: _memory_device_t, + ): + assert ( + split_size is None or memory_device is not None + ), "if split_size is set, memory_device must not be None." 
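+ # split_size is the number of input frames processed per temporal slice; the matching
+ # latent slice length is split_size // temporal_downsample_factor.
+ # Illustrative example: with temporal_downsample_factor == 4,
+ #   set_causal_slicing(split_size=16, memory_device="cpu")
+ # gives slicing_sample_min_size = 16 and slicing_latent_min_size = 16 // 4 = 4.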
+ if split_size is not None: + self.enable_slicing() + self.slicing_sample_min_size = split_size + self.slicing_latent_min_size = split_size // self.temporal_downsample_factor + else: + self.disable_slicing() + for module in self.modules(): + if isinstance(module, InflatedCausalConv3d): + module.set_memory_device(memory_device) + + def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): + set_norm_limit(norm_max_mem) + for m in self.modules(): + if isinstance(m, InflatedCausalConv3d): + m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) + + +class VideoAutoencoderKLWrapper(VideoAutoencoderKL): + def __init__( + self, *args, spatial_downsample_factor: int, temporal_downsample_factor: int, **kwargs + ): + self.spatial_downsample_factor = spatial_downsample_factor + self.temporal_downsample_factor = temporal_downsample_factor + super().__init__(*args, **kwargs) + + def forward(self, x) -> CausalAutoencoderOutput: + z, _, p = self.encode(x) + x, _ = self.decode(z) + return CausalAutoencoderOutput(x, z, None, p) + + def encode(self, x) -> CausalEncoderOutput: + if x.ndim == 4: + x = x.unsqueeze(2) + p = super().encode(x).latent_dist + z = p.sample().squeeze(2) + return CausalEncoderOutput(z, None, p) + + def decode(self, z) -> CausalDecoderOutput: + if z.ndim == 4: + z = z.unsqueeze(2) + x = super().decode(z).sample.squeeze(2) + return CausalDecoderOutput(x, None) + + def preprocess(self, x): + # x should in [B, C, T, H, W], [B, C, H, W] + assert x.ndim == 4 or x.size(2) % 4 == 1 + return x + + def postprocess(self, x): + # x should in [B, C, T, H, W], [B, C, H, W] + return x + + def set_causal_slicing( + self, + *, + split_size: Optional[int], + memory_device: Optional[Literal["cpu", "same"]], + ): + assert ( + split_size is None or memory_device is not None + ), "if split_size is set, memory_device must not be None." 
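+ # Mirrors VideoAutoencoderKL.set_causal_slicing: enable or disable temporal slicing,
+ # update the slice sizes, and propagate memory_device to every InflatedCausalConv3d
+ # via set_memory_device ("cpu" keeps the conv's cached slice state on the host,
+ # "same" keeps it alongside the module).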
+ if split_size is not None: + self.enable_slicing() + else: + self.disable_slicing() + self.slicing_sample_min_size = split_size + if split_size is not None: + self.slicing_latent_min_size = split_size // self.temporal_downsample_factor + for module in self.modules(): + if isinstance(module, InflatedCausalConv3d): + module.set_memory_device(memory_device) \ No newline at end of file diff --git a/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml b/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..58309522b791171f9d39f78ea1eaf57bab2a28fe --- /dev/null +++ b/models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml @@ -0,0 +1,33 @@ +__object__: + path: models.video_vae_v3.modules.attn_video_vae + name: VideoAutoencoderKLWrapper + args: as_params + +act_fn: silu +block_out_channels: + - 128 + - 256 + - 512 + - 512 +down_block_types: + - DownEncoderBlock3D + - DownEncoderBlock3D + - DownEncoderBlock3D + - DownEncoderBlock3D +in_channels: 3 +latent_channels: 16 +layers_per_block: 2 +norm_num_groups: 32 +out_channels: 3 +slicing_sample_min_size: 4 +temporal_scale_num: 2 +inflation_mode: pad +up_block_types: + - UpDecoderBlock3D + - UpDecoderBlock3D + - UpDecoderBlock3D + - UpDecoderBlock3D +spatial_downsample_factor: 8 +temporal_downsample_factor: 4 +use_quant_conv: False +use_post_quant_conv: False diff --git a/projects/inference_seedvr2_3b.py b/projects/inference_seedvr2_3b.py new file mode 100644 index 0000000000000000000000000000000000000000..298cb217451bcf939b5bb0134aca63348c2a5639 --- /dev/null +++ b/projects/inference_seedvr2_3b.py @@ -0,0 +1,322 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +import os +import torch +import mediapy +from einops import rearrange +from omegaconf import OmegaConf +print(os.getcwd()) +import datetime +from tqdm import tqdm +import gc + + +from data.image.transforms.divisible_crop import DivisibleCrop +from data.image.transforms.na_resize import NaResize +from data.video.transforms.rearrange import Rearrange +if os.path.exists("./projects/video_diffusion_sr/color_fix.py"): + from projects.video_diffusion_sr.color_fix import wavelet_reconstruction + use_colorfix=True +else: + use_colorfix = False + print('Note!!!!!! 
Color fix is not avaliable!') +from torchvision.transforms import Compose, Lambda, Normalize +from torchvision.io.video import read_video +import argparse + + +from common.distributed import ( + get_device, + init_torch, +) + +from common.distributed.advanced import ( + get_data_parallel_rank, + get_data_parallel_world_size, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + init_sequence_parallel, +) + +from projects.video_diffusion_sr.infer import VideoDiffusionInfer +from common.config import load_config +from common.distributed.ops import sync_data +from common.seed import set_seed +from common.partition import partition_by_groups, partition_by_size + + +def configure_sequence_parallel(sp_size): + if sp_size > 1: + init_sequence_parallel(sp_size) + +def configure_runner(sp_size): + config_path = os.path.join('./configs_3b', 'main.yaml') + config = load_config(config_path) + runner = VideoDiffusionInfer(config) + OmegaConf.set_readonly(runner.config, False) + + init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600)) + configure_sequence_parallel(sp_size) + runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth') + runner.configure_vae_model() + # Set memory limit. + if hasattr(runner.vae, "set_memory_limit"): + runner.vae.set_memory_limit(**runner.config.vae.memory_limit) + return runner + +def generation_step(runner, text_embeds_dict, cond_latents): + def _move_to_cuda(x): + return [i.to(get_device()) for i in x] + + noises = [torch.randn_like(latent) for latent in cond_latents] + aug_noises = [torch.randn_like(latent) for latent in cond_latents] + print(f"Generating with noise shape: {noises[0].size()}.") + noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0) + noises, aug_noises, cond_latents = list( + map(lambda x: _move_to_cuda(x), (noises, aug_noises, cond_latents)) + ) + cond_noise_scale = 0.0 + + def _add_noise(x, aug_noise): + t = ( + torch.tensor([1000.0], device=get_device()) + * cond_noise_scale + ) + shape = torch.tensor(x.shape[1:], device=get_device())[None] + t = runner.timestep_transform(t, shape) + print( + f"Timestep shifting from" + f" {1000.0 * cond_noise_scale} to {t}." + ) + x = runner.schedule.forward(x, aug_noise, t) + return x + + conditions = [ + runner.get_condition( + noise, + task="sr", + latent_blur=_add_noise(latent_blur, aug_noise), + ) + for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents) + ] + + with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True): + video_tensors = runner.inference( + noises=noises, + conditions=conditions, + dit_offload=True, + **text_embeds_dict, + ) + + samples = [ + ( + rearrange(video[:, None], "c t h w -> t c h w") + if video.ndim == 3 + else rearrange(video, "c t h w -> t c h w") + ) + for video in video_tensors + ] + del video_tensors + + return samples + +def generation_loop(runner, video_path='./test_videos', output_dir='./results', batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, seed=666, res_h=1280, res_w=720, sp_size=1): + + def _build_pos_and_neg_prompt(): + # read positive prompt + positive_text = "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \ + hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \ + skin pore detailing, hyper sharpness, perfect without deformations." 
+ # read negative prompt + negative_text = "painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \ + CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \ + signature, jpeg artifacts, deformed, lowres, over-smooth" + return positive_text, negative_text + + def _build_test_prompts(video_path): + positive_text, negative_text = _build_pos_and_neg_prompt() + original_videos = [] + prompts = {} + video_list = os.listdir(video_path) + for f in video_list: + if f.endswith(".mp4"): + original_videos.append(f) + prompts[f] = positive_text + print(f"Total prompts to be generated: {len(original_videos)}") + return original_videos, prompts, negative_text + + def _extract_text_embeds(): + # Text encoder forward. + positive_prompts_embeds = [] + for texts_pos in tqdm(original_videos_local): + text_pos_embeds = torch.load('pos_emb.pt') + text_neg_embeds = torch.load('neg_emb.pt') + + positive_prompts_embeds.append( + {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]} + ) + gc.collect() + torch.cuda.empty_cache() + return positive_prompts_embeds + + def cut_videos(videos, sp_size): + t = videos.size(1) + if t <= 4 * sp_size: + print(f"Cut input video size: {videos.size()}") + padding = [videos[:, -1].unsqueeze(1)] * (4 * sp_size - t + 1) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + return videos + if (t - 1) % (4 * sp_size) == 0: + return videos + else: + padding = [videos[:, -1].unsqueeze(1)] * ( + 4 * sp_size - ((t - 1) % (4 * sp_size)) + ) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + assert (videos.size(1) - 1) % (4 * sp_size) == 0 + return videos + + # classifier-free guidance + runner.config.diffusion.cfg.scale = cfg_scale + runner.config.diffusion.cfg.rescale = cfg_rescale + # sampling steps + runner.config.diffusion.timesteps.sampling.steps = sample_steps + runner.configure_diffusion() + + # set random seed + set_seed(seed, same_across_ranks=True) + os.makedirs(output_dir, exist_ok=True) + tgt_path = output_dir + + # get test prompts + original_videos, _, _ = _build_test_prompts(video_path) + + # divide the prompts into different groups + original_videos_group = partition_by_groups( + original_videos, + get_data_parallel_world_size() // get_sequence_parallel_world_size(), + ) + # store prompt mapping + original_videos_local = original_videos_group[ + get_data_parallel_rank() // get_sequence_parallel_world_size() + ] + original_videos_local = partition_by_size(original_videos_local, batch_size) + + # pre-extract the text embeddings + positive_prompts_embeds = _extract_text_embeds() + + video_transform = Compose( + [ + NaResize( + resolution=( + res_h * res_w + ) + ** 0.5, + mode="area", + # Upsample image, model only trained for high res. 
+ downsample_only=False, + ), + Lambda(lambda x: torch.clamp(x, 0.0, 1.0)), + DivisibleCrop((16, 16)), + Normalize(0.5, 0.5), + Rearrange("t c h w -> c t h w"), + ] + ) + + # generation loop + for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)): + # read condition latents + cond_latents = [] + for video in videos: + video = ( + read_video( + os.path.join(video_path, video), output_format="TCHW" + )[0] + / 255.0 + ) + print(f"Read video size: {video.size()}") + cond_latents.append(video_transform(video.to(get_device()))) + + ori_lengths = [video.size(1) for video in cond_latents] + input_videos = cond_latents + cond_latents = [cut_videos(video, sp_size) for video in cond_latents] + + runner.dit.to("cpu") + print(f"Encoding videos: {list(map(lambda x: x.size(), cond_latents))}") + runner.vae.to(get_device()) + cond_latents = runner.vae_encode(cond_latents) + runner.vae.to("cpu") + runner.dit.to(get_device()) + + for i, emb in enumerate(text_embeds["texts_pos"]): + text_embeds["texts_pos"][i] = emb.to(get_device()) + for i, emb in enumerate(text_embeds["texts_neg"]): + text_embeds["texts_neg"][i] = emb.to(get_device()) + + samples = generation_step(runner, text_embeds, cond_latents=cond_latents) + runner.dit.to("cpu") + del cond_latents + + # dump samples to the output directory + if get_sequence_parallel_rank() == 0: + for path, input, sample, ori_length in zip( + videos, input_videos, samples, ori_lengths + ): + if ori_length < sample.shape[0]: + sample = sample[:ori_length] + filename = os.path.join(tgt_path, os.path.basename(path)) + # color fix + input = ( + rearrange(input[:, None], "c t h w -> t c h w") + if input.ndim == 3 + else rearrange(input, "c t h w -> t c h w") + ) + if use_colorfix: + sample = wavelet_reconstruction( + sample.to("cpu"), input[: sample.size(0)].to("cpu") + ) + else: + sample = sample.to("cpu") + sample = ( + rearrange(sample[:, None], "t c h w -> t h w c") + if sample.ndim == 3 + else rearrange(sample, "t c h w -> t h w c") + ) + sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round() + sample = sample.to(torch.uint8).numpy() + + if sample.shape[0] == 1: + mediapy.write_image(filename, sample.squeeze(0)) + else: + mediapy.write_video( + filename, sample, fps=24 + ) + gc.collect() + torch.cuda.empty_cache() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--video_path", type=str, default="./test_videos") + parser.add_argument("--output_dir", type=str, default="./results") + parser.add_argument("--seed", type=int, default=666) + parser.add_argument("--res_h", type=int, default=720) + parser.add_argument("--res_w", type=int, default=1280) + parser.add_argument("--sp_size", type=int, default=1) + args = parser.parse_args() + + runner = configure_runner(args.sp_size) + generation_loop(runner, **vars(args)) diff --git a/projects/inference_seedvr2_7b.py b/projects/inference_seedvr2_7b.py new file mode 100644 index 0000000000000000000000000000000000000000..c4b73c25ce91bc0691a34e87d157edde488272cd --- /dev/null +++ b/projects/inference_seedvr2_7b.py @@ -0,0 +1,321 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. 
+# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +import os +import torch +import mediapy +from einops import rearrange +from omegaconf import OmegaConf +print(os.getcwd()) +import datetime +from tqdm import tqdm +from models.dit import na +import gc + +from data.image.transforms.divisible_crop import DivisibleCrop +from data.image.transforms.na_resize import NaResize +from data.video.transforms.rearrange import Rearrange +if os.path.exists("./projects/video_diffusion_sr/color_fix.py"): + from projects.video_diffusion_sr.color_fix import wavelet_reconstruction + use_colorfix=True +else: + use_colorfix = False + print('Note!!!!!! Color fix is not avaliable!') +from torchvision.transforms import Compose, Lambda, Normalize +from torchvision.io.video import read_video + + +from common.distributed import ( + get_device, + init_torch, +) + +from common.distributed.advanced import ( + get_data_parallel_rank, + get_data_parallel_world_size, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + init_sequence_parallel, +) + +from projects.video_diffusion_sr.infer import VideoDiffusionInfer +from common.config import load_config +from common.distributed.ops import sync_data +from common.seed import set_seed +from common.partition import partition_by_groups, partition_by_size +import argparse + +def configure_sequence_parallel(sp_size): + if sp_size > 1: + init_sequence_parallel(sp_size) + +def configure_runner(sp_size): + config_path = os.path.join('./configs_7b', 'main.yaml') + config = load_config(config_path) + runner = VideoDiffusionInfer(config) + OmegaConf.set_readonly(runner.config, False) + + init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600)) + configure_sequence_parallel(sp_size) + runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_7b.pth') + runner.configure_vae_model() + # Set memory limit. + if hasattr(runner.vae, "set_memory_limit"): + runner.vae.set_memory_limit(**runner.config.vae.memory_limit) + return runner + +def generation_step(runner, text_embeds_dict, cond_latents): + def _move_to_cuda(x): + return [i.to(get_device()) for i in x] + + noises = [torch.randn_like(latent) for latent in cond_latents] + aug_noises = [torch.randn_like(latent) for latent in cond_latents] + print(f"Generating with noise shape: {noises[0].size()}.") + noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0) + noises, aug_noises, cond_latents = list( + map(lambda x: _move_to_cuda(x), (noises, aug_noises, cond_latents)) + ) + cond_noise_scale = 0.0 + + def _add_noise(x, aug_noise): + t = ( + torch.tensor([1000.0], device=get_device()) + * cond_noise_scale + ) + shape = torch.tensor(x.shape[1:], device=get_device())[None] + t = runner.timestep_transform(t, shape) + print( + f"Timestep shifting from" + f" {1000.0 * cond_noise_scale} to {t}." 
+ ) + x = runner.schedule.forward(x, aug_noise, t) + return x + + conditions = [ + runner.get_condition( + noise, + task="sr", + latent_blur=_add_noise(latent_blur, aug_noise), + ) + for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents) + ] + + with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True): + video_tensors = runner.inference( + noises=noises, + conditions=conditions, + dit_offload=True, + **text_embeds_dict, + ) + + samples = [ + ( + rearrange(video[:, None], "c t h w -> t c h w") + if video.ndim == 3 + else rearrange(video, "c t h w -> t c h w") + ) + for video in video_tensors + ] + del video_tensors + + return samples + +def generation_loop(runner, video_path='./test_videos', output_dir='./results', batch_size=1, cfg_scale=1.0, cfg_rescale=0.0, sample_steps=1, seed=666, res_h=1280, res_w=720, sp_size=1): + + def _build_pos_and_neg_prompt(): + # read positive prompt + positive_text = "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \ + hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \ + skin pore detailing, hyper sharpness, perfect without deformations." + # read negative prompt + negative_text = "painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \ + CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \ + signature, jpeg artifacts, deformed, lowres, over-smooth" + return positive_text, negative_text + + def _build_test_prompts(video_path): + positive_text, negative_text = _build_pos_and_neg_prompt() + original_videos = [] + prompts = {} + video_list = os.listdir(video_path) + for f in video_list: + if f.endswith(".mp4"): + original_videos.append(f) + prompts[f] = positive_text + print(f"Total prompts to be generated: {len(original_videos)}") + return original_videos, prompts, negative_text + + def _extract_text_embeds(): + # Text encoder forward. 
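+ # No text encoder is run here: precomputed embeddings for the fixed positive and
+ # negative prompts are loaded from ./pos_emb.pt and ./neg_emb.pt, once per local
+ # video batch. Each entry has the form {"texts_pos": [Tensor], "texts_neg": [Tensor]},
+ # which generation_step forwards to runner.inference via **text_embeds_dict.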
+ positive_prompts_embeds = [] + for texts_pos in tqdm(original_videos_local): + text_pos_embeds = torch.load('pos_emb.pt') + text_neg_embeds = torch.load('neg_emb.pt') + + positive_prompts_embeds.append( + {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]} + ) + gc.collect() + torch.cuda.empty_cache() + return positive_prompts_embeds + + def cut_videos(videos, sp_size): + t = videos.size(1) + if t <= 4 * sp_size: + print(f"Cut input video size: {videos.size()}") + padding = [videos[:, -1].unsqueeze(1)] * (4 * sp_size - t + 1) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + return videos + if (t - 1) % (4 * sp_size) == 0: + return videos + else: + padding = [videos[:, -1].unsqueeze(1)] * ( + 4 * sp_size - ((t - 1) % (4 * sp_size)) + ) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + assert (videos.size(1) - 1) % (4 * sp_size) == 0 + return videos + + # classifier-free guidance + runner.config.diffusion.cfg.scale = cfg_scale + runner.config.diffusion.cfg.rescale = cfg_rescale + # sampling steps + runner.config.diffusion.timesteps.sampling.steps = sample_steps + runner.configure_diffusion() + + # set random seed + set_seed(seed, same_across_ranks=True) + os.makedirs(output_dir, exist_ok=True) + tgt_path = output_dir + + # get test prompts + original_videos, _, _ = _build_test_prompts(video_path) + + # divide the prompts into different groups + original_videos_group = partition_by_groups( + original_videos, + get_data_parallel_world_size() // get_sequence_parallel_world_size(), + ) + # store prompt mapping + original_videos_local = original_videos_group[ + get_data_parallel_rank() // get_sequence_parallel_world_size() + ] + original_videos_local = partition_by_size(original_videos_local, batch_size) + + # pre-extract the text embeddings + positive_prompts_embeds = _extract_text_embeds() + + video_transform = Compose( + [ + NaResize( + resolution=( + res_h * res_w + ) + ** 0.5, + mode="area", + # Upsample image, model only trained for high res. 
+ downsample_only=False, + ), + Lambda(lambda x: torch.clamp(x, 0.0, 1.0)), + DivisibleCrop((16, 16)), + Normalize(0.5, 0.5), + Rearrange("t c h w -> c t h w"), + ] + ) + + # generation loop + for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)): + # read condition latents + cond_latents = [] + for video in videos: + video = ( + read_video( + os.path.join(video_path, video), output_format="TCHW" + )[0] + / 255.0 + ) + print(f"Read video size: {video.size()}") + cond_latents.append(video_transform(video.to(get_device()))) + + ori_lengths = [video.size(1) for video in cond_latents] + input_videos = cond_latents + cond_latents = [cut_videos(video, sp_size) for video in cond_latents] + + runner.dit.to("cpu") + print(f"Encoding videos: {list(map(lambda x: x.size(), cond_latents))}") + runner.vae.to(get_device()) + cond_latents = runner.vae_encode(cond_latents) + runner.vae.to("cpu") + runner.dit.to(get_device()) + + for i, emb in enumerate(text_embeds["texts_pos"]): + text_embeds["texts_pos"][i] = emb.to(get_device()) + for i, emb in enumerate(text_embeds["texts_neg"]): + text_embeds["texts_neg"][i] = emb.to(get_device()) + + samples = generation_step(runner, text_embeds, cond_latents=cond_latents) + runner.dit.to("cpu") + del cond_latents + + # dump samples to the output directory + if get_sequence_parallel_rank() == 0: + for path, input, sample, ori_length in zip( + videos, input_videos, samples, ori_lengths + ): + if ori_length < sample.shape[0]: + sample = sample[:ori_length] + filename = os.path.join(tgt_path, os.path.basename(path)) + # color fix + input = ( + rearrange(input[:, None], "c t h w -> t c h w") + if input.ndim == 3 + else rearrange(input, "c t h w -> t c h w") + ) + if use_colorfix: + sample = wavelet_reconstruction( + sample.to("cpu"), input[: sample.size(0)].to("cpu") + ) + else: + sample = sample.to("cpu") + sample = ( + rearrange(sample[:, None], "t c h w -> t h w c") + if sample.ndim == 3 + else rearrange(sample, "t c h w -> t h w c") + ) + sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round() + sample = sample.to(torch.uint8).numpy() + + if sample.shape[0] == 1: + mediapy.write_image(filename, sample.squeeze(0)) + else: + mediapy.write_video( + filename, sample, fps=24 + ) + gc.collect() + torch.cuda.empty_cache() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--video_path", type=str, default="./test_videos") + parser.add_argument("--output_dir", type=str, default="./results") + parser.add_argument("--seed", type=int, default=666) + parser.add_argument("--res_h", type=int, default=720) + parser.add_argument("--res_w", type=int, default=1280) + parser.add_argument("--sp_size", type=int, default=1) + args = parser.parse_args() + + runner = configure_runner(args.sp_size) + generation_loop(runner, **vars(args)) diff --git a/projects/inference_seedvr_3b.py b/projects/inference_seedvr_3b.py new file mode 100644 index 0000000000000000000000000000000000000000..469a97d8dac0769d208be21943b5f7215b249380 --- /dev/null +++ b/projects/inference_seedvr_3b.py @@ -0,0 +1,323 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. 
+# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +import os +import torch +import mediapy +from einops import rearrange +from omegaconf import OmegaConf +print(os.getcwd()) +import datetime +from tqdm import tqdm +import gc + +from data.image.transforms.divisible_crop import DivisibleCrop +from data.image.transforms.na_resize import NaResize +from data.video.transforms.rearrange import Rearrange +if os.path.exists("./projects/video_diffusion_sr/color_fix.py"): + from projects.video_diffusion_sr.color_fix import wavelet_reconstruction + use_colorfix=True +else: + use_colorfix = False + print('Note!!!!!! Color fix is not avaliable!') +from torchvision.transforms import Compose, Lambda, Normalize +from torchvision.io.video import read_video +import argparse + +from common.distributed import ( + get_device, + init_torch, +) + +from common.distributed.advanced import ( + get_data_parallel_rank, + get_data_parallel_world_size, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + init_sequence_parallel, +) + +from projects.video_diffusion_sr.infer import VideoDiffusionInfer +from common.config import load_config +from common.distributed.ops import sync_data +from common.seed import set_seed +from common.partition import partition_by_groups, partition_by_size + + +def configure_sequence_parallel(sp_size): + if sp_size > 1: + init_sequence_parallel(sp_size) + +def configure_runner(sp_size): + config_path = os.path.join('./configs_3b', 'main.yaml') + config = load_config(config_path) + runner = VideoDiffusionInfer(config) + OmegaConf.set_readonly(runner.config, False) + + init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600)) + configure_sequence_parallel(sp_size) + runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr_ema_3b.pth') + runner.configure_vae_model() + # Set memory limit. + if hasattr(runner.vae, "set_memory_limit"): + runner.vae.set_memory_limit(**runner.config.vae.memory_limit) + return runner + +def generation_step(runner, text_embeds_dict, cond_latents): + def _move_to_cuda(x): + return [i.to(get_device()) for i in x] + + noises = [torch.randn_like(latent) for latent in cond_latents] + aug_noises = [torch.randn_like(latent) for latent in cond_latents] + print(f"Generating with noise shape: {noises[0].size()}.") + noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0) + noises, aug_noises, cond_latents = list( + map(lambda x: _move_to_cuda(x), (noises, aug_noises, cond_latents)) + ) + cond_noise_scale = 0.1 + + def _add_noise(x, aug_noise): + t = ( + torch.tensor([1000.0], device=get_device()) + * cond_noise_scale + ) + shape = torch.tensor(x.shape[1:], device=get_device())[None] + t = runner.timestep_transform(t, shape) + print( + f"Timestep shifting from" + f" {1000.0 * cond_noise_scale} to {t}." 
+ ) + x = runner.schedule.forward(x, aug_noise, t) + return x + + conditions = [ + runner.get_condition( + noise, + task="sr", + latent_blur=_add_noise(latent_blur, aug_noise), + ) + for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents) + ] + + with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True): + video_tensors = runner.inference( + noises=noises, + conditions=conditions, + dit_offload=True, + **text_embeds_dict, + ) + + samples = [ + ( + rearrange(video[:, None], "c t h w -> t c h w") + if video.ndim == 3 + else rearrange(video, "c t h w -> t c h w") + ) + for video in video_tensors + ] + del video_tensors + + return samples + +def generation_loop(runner, video_path='./test_videos', output_dir='./results', batch_size=1, cfg_scale=6.5, cfg_rescale=0.0, sample_steps=50, seed=666, res_h=1280, res_w=720, sp_size=1): + + def _build_pos_and_neg_prompt(): + # read positive prompt + positive_text = "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \ + hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \ + skin pore detailing, hyper sharpness, perfect without deformations." + # read negative prompt + negative_text = "painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \ + CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \ + signature, jpeg artifacts, deformed, lowres, over-smooth" + return positive_text, negative_text + + def _build_test_prompts(video_path): + positive_text, negative_text = _build_pos_and_neg_prompt() + original_videos = [] + prompts = {} + video_list = os.listdir(video_path) + for f in video_list: + if f.endswith(".mp4"): + original_videos.append(f) + prompts[f] = positive_text + print(f"Total prompts to be generated: {len(original_videos)}") + return original_videos, prompts, negative_text + + def _extract_text_embeds(): + # Text encoder forward. 
+ positive_prompts_embeds = [] + for texts_pos in tqdm(original_videos_local): + text_pos_embeds = torch.load('pos_emb.pt') + text_neg_embeds = torch.load('neg_emb.pt') + + positive_prompts_embeds.append( + {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]} + ) + gc.collect() + torch.cuda.empty_cache() + return positive_prompts_embeds + + def cut_videos(videos, sp_size): + t = videos.size(1) + if t <= 4 * sp_size: + print(f"Cut input video size: {videos.size()}") + padding = [videos[:, -1].unsqueeze(1)] * (4 * sp_size - t + 1) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + return videos + if (t - 1) % (4 * sp_size) == 0: + return videos + else: + padding = [videos[:, -1].unsqueeze(1)] * ( + 4 * sp_size - ((t - 1) % (4 * sp_size)) + ) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + assert (videos.size(1) - 1) % (4 * sp_size) == 0 + return videos + + # classifier-free guidance + runner.config.diffusion.cfg.scale = cfg_scale + runner.config.diffusion.cfg.rescale = cfg_rescale + # sampling steps + runner.config.diffusion.timesteps.sampling.steps = sample_steps + runner.configure_diffusion() + + # set random seed + set_seed(seed, same_across_ranks=True) + os.makedirs(output_dir, exist_ok=True) + tgt_path = output_dir + + # get test prompts + original_videos, _, _ = _build_test_prompts(video_path) + + # divide the prompts into different groups + original_videos_group = partition_by_groups( + original_videos, + get_data_parallel_world_size() // get_sequence_parallel_world_size(), + ) + # store prompt mapping + original_videos_local = original_videos_group[ + get_data_parallel_rank() // get_sequence_parallel_world_size() + ] + original_videos_local = partition_by_size(original_videos_local, batch_size) + + # pre-extract the text embeddings + positive_prompts_embeds = _extract_text_embeds() + + video_transform = Compose( + [ + NaResize( + resolution=( + res_h * res_w + ) + ** 0.5, + mode="area", + # Upsample image, model only trained for high res. 
+ downsample_only=False, + ), + Lambda(lambda x: torch.clamp(x, 0.0, 1.0)), + DivisibleCrop((16, 16)), + Normalize(0.5, 0.5), + Rearrange("t c h w -> c t h w"), + ] + ) + + # generation loop + for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)): + # read condition latents + cond_latents = [] + for video in videos: + video = ( + read_video( + os.path.join(video_path, video), output_format="TCHW" + )[0] + / 255.0 + ) + print(f"Read video size: {video.size()}") + cond_latents.append(video_transform(video.to(get_device()))) + + ori_lengths = [video.size(1) for video in cond_latents] + input_videos = cond_latents + cond_latents = [cut_videos(video, sp_size) for video in cond_latents] + + runner.dit.to("cpu") + print(f"Encoding videos: {list(map(lambda x: x.size(), cond_latents))}") + runner.vae.to(get_device()) + cond_latents = runner.vae_encode(cond_latents) + runner.vae.to("cpu") + runner.dit.to(get_device()) + + for i, emb in enumerate(text_embeds["texts_pos"]): + text_embeds["texts_pos"][i] = emb.to(get_device()) + for i, emb in enumerate(text_embeds["texts_neg"]): + text_embeds["texts_neg"][i] = emb.to(get_device()) + + samples = generation_step(runner, text_embeds, cond_latents=cond_latents) + runner.dit.to("cpu") + del cond_latents + + # dump samples to the output directory + if get_sequence_parallel_rank() == 0: + for path, input, sample, ori_length in zip( + videos, input_videos, samples, ori_lengths + ): + if ori_length < sample.shape[0]: + sample = sample[:ori_length] + filename = os.path.join(tgt_path, os.path.basename(path)) + # color fix + input = ( + rearrange(input[:, None], "c t h w -> t c h w") + if input.ndim == 3 + else rearrange(input, "c t h w -> t c h w") + ) + if use_colorfix: + sample = wavelet_reconstruction( + sample.to("cpu"), input[: sample.size(0)].to("cpu") + ) + else: + sample = sample.to("cpu") + sample = ( + rearrange(sample[:, None], "t c h w -> t h w c") + if sample.ndim == 3 + else rearrange(sample, "t c h w -> t h w c") + ) + sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round() + sample = sample.to(torch.uint8).numpy() + + if sample.shape[0] == 1: + mediapy.write_image(filename, sample.squeeze(0)) + else: + mediapy.write_video( + filename, sample, fps=24 + ) + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--video_path", type=str, default="./test_videos") + parser.add_argument("--output_dir", type=str, default="./results") + parser.add_argument("--cfg_scale", type=float, default=6.5) + parser.add_argument("--sample_steps", type=int, default=50) + parser.add_argument("--seed", type=int, default=666) + parser.add_argument("--res_h", type=int, default=720) + parser.add_argument("--res_w", type=int, default=1280) + parser.add_argument("--sp_size", type=int, default=1) + args = parser.parse_args() + + runner = configure_runner(args.sp_size) + generation_loop(runner, **vars(args)) diff --git a/projects/inference_seedvr_7b.py b/projects/inference_seedvr_7b.py new file mode 100644 index 0000000000000000000000000000000000000000..1408c9ca6f1b40ff522611f59937c968e4f15b44 --- /dev/null +++ b/projects/inference_seedvr_7b.py @@ -0,0 +1,324 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. 
+# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +import os +import torch +import mediapy +from einops import rearrange +from omegaconf import OmegaConf +print(os.getcwd()) +import datetime +from tqdm import tqdm +from models.dit import na +import gc + +from data.image.transforms.divisible_crop import DivisibleCrop +from data.image.transforms.na_resize import NaResize +from data.video.transforms.rearrange import Rearrange +if os.path.exists("./projects/video_diffusion_sr/color_fix.py"): + from projects.video_diffusion_sr.color_fix import wavelet_reconstruction + use_colorfix=True +else: + use_colorfix = False + print('Note!!!!!! Color fix is not avaliable!') +from torchvision.transforms import Compose, Lambda, Normalize +from torchvision.io.video import read_video +import argparse + + +from common.distributed import ( + get_device, + init_torch, +) + +from common.distributed.advanced import ( + get_data_parallel_rank, + get_data_parallel_world_size, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + init_sequence_parallel, +) + +from projects.video_diffusion_sr.infer import VideoDiffusionInfer +from common.config import load_config +from common.distributed.ops import sync_data +from common.seed import set_seed +from common.partition import partition_by_groups, partition_by_size + + +def configure_sequence_parallel(sp_size): + if sp_size > 1: + init_sequence_parallel(sp_size) + +def configure_runner(sp_size): + config_path = os.path.join('./configs_7b', 'main.yaml') + config = load_config(config_path) + runner = VideoDiffusionInfer(config) + OmegaConf.set_readonly(runner.config, False) + + init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600)) + configure_sequence_parallel(sp_size) + runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr_ema_7b.pth') + runner.configure_vae_model() + # Set memory limit. + if hasattr(runner.vae, "set_memory_limit"): + runner.vae.set_memory_limit(**runner.config.vae.memory_limit) + return runner + +def generation_step(runner, text_embeds_dict, cond_latents): + def _move_to_cuda(x): + return [i.to(get_device()) for i in x] + + noises = [torch.randn_like(latent) for latent in cond_latents] + aug_noises = [torch.randn_like(latent) for latent in cond_latents] + print(f"Generating with noise shape: {noises[0].size()}.") + noises, aug_noises, cond_latents = sync_data((noises, aug_noises, cond_latents), 0) + noises, aug_noises, cond_latents = list( + map(lambda x: _move_to_cuda(x), (noises, aug_noises, cond_latents)) + ) + cond_noise_scale = 0.1 + + def _add_noise(x, aug_noise): + t = ( + torch.tensor([1000.0], device=get_device()) + * cond_noise_scale + ) + shape = torch.tensor(x.shape[1:], device=get_device())[None] + t = runner.timestep_transform(t, shape) + print( + f"Timestep shifting from" + f" {1000.0 * cond_noise_scale} to {t}." 
+ ) + x = runner.schedule.forward(x, aug_noise, t) + return x + + conditions = [ + runner.get_condition( + noise, + task="sr", + latent_blur=_add_noise(latent_blur, aug_noise), + ) + for noise, aug_noise, latent_blur in zip(noises, aug_noises, cond_latents) + ] + + with torch.no_grad(), torch.autocast("cuda", torch.bfloat16, enabled=True): + video_tensors = runner.inference( + noises=noises, + conditions=conditions, + dit_offload=True, + **text_embeds_dict, + ) + + samples = [ + ( + rearrange(video[:, None], "c t h w -> t c h w") + if video.ndim == 3 + else rearrange(video, "c t h w -> t c h w") + ) + for video in video_tensors + ] + del video_tensors + + return samples + +def generation_loop(runner, video_path='./test_videos', output_dir='./results', batch_size=1, cfg_scale=6.5, cfg_rescale=0.0, sample_steps=50, seed=666, res_h=1280, res_w=720, sp_size=1): + + def _build_pos_and_neg_prompt(): + # read positive prompt + positive_text = "Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera, \ + hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing, \ + skin pore detailing, hyper sharpness, perfect without deformations." + # read negative prompt + negative_text = "painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, \ + CG Style, 3D render, unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, \ + signature, jpeg artifacts, deformed, lowres, over-smooth" + return positive_text, negative_text + + def _build_test_prompts(video_path): + positive_text, negative_text = _build_pos_and_neg_prompt() + original_videos = [] + prompts = {} + video_list = os.listdir(video_path) + for f in video_list: + if f.endswith(".mp4"): + original_videos.append(f) + prompts[f] = positive_text + print(f"Total prompts to be generated: {len(original_videos)}") + return original_videos, prompts, negative_text + + def _extract_text_embeds(): + # Text encoder forward. 
+ positive_prompts_embeds = [] + for texts_pos in tqdm(original_videos_local): + text_pos_embeds = torch.load('pos_emb.pt') + text_neg_embeds = torch.load('neg_emb.pt') + + positive_prompts_embeds.append( + {"texts_pos": [text_pos_embeds], "texts_neg": [text_neg_embeds]} + ) + gc.collect() + torch.cuda.empty_cache() + return positive_prompts_embeds + + def cut_videos(videos, sp_size): + t = videos.size(1) + if t <= 4 * sp_size: + print(f"Cut input video size: {videos.size()}") + padding = [videos[:, -1].unsqueeze(1)] * (4 * sp_size - t + 1) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + return videos + if (t - 1) % (4 * sp_size) == 0: + return videos + else: + padding = [videos[:, -1].unsqueeze(1)] * ( + 4 * sp_size - ((t - 1) % (4 * sp_size)) + ) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + assert (videos.size(1) - 1) % (4 * sp_size) == 0 + return videos + + # classifier-free guidance + runner.config.diffusion.cfg.scale = cfg_scale + runner.config.diffusion.cfg.rescale = cfg_rescale + # sampling steps + runner.config.diffusion.timesteps.sampling.steps = sample_steps + runner.configure_diffusion() + + # set random seed + set_seed(seed, same_across_ranks=True) + os.makedirs(output_dir, exist_ok=True) + tgt_path = output_dir + + # get test prompts + original_videos, _, _ = _build_test_prompts(video_path) + + # divide the prompts into different groups + original_videos_group = partition_by_groups( + original_videos, + get_data_parallel_world_size() // get_sequence_parallel_world_size(), + ) + # store prompt mapping + original_videos_local = original_videos_group[ + get_data_parallel_rank() // get_sequence_parallel_world_size() + ] + original_videos_local = partition_by_size(original_videos_local, batch_size) + + # pre-extract the text embeddings + positive_prompts_embeds = _extract_text_embeds() + + video_transform = Compose( + [ + NaResize( + resolution=( + res_h * res_w + ) + ** 0.5, + mode="area", + # Upsample image, model only trained for high res. 
+ downsample_only=False, + ), + Lambda(lambda x: torch.clamp(x, 0.0, 1.0)), + DivisibleCrop((16, 16)), + Normalize(0.5, 0.5), + Rearrange("t c h w -> c t h w"), + ] + ) + + # generation loop + for videos, text_embeds in tqdm(zip(original_videos_local, positive_prompts_embeds)): + # read condition latents + cond_latents = [] + for video in videos: + video = ( + read_video( + os.path.join(video_path, video), output_format="TCHW" + )[0] + / 255.0 + ) + print(f"Read video size: {video.size()}") + cond_latents.append(video_transform(video.to(get_device()))) + + ori_lengths = [video.size(1) for video in cond_latents] + input_videos = cond_latents + cond_latents = [cut_videos(video, sp_size) for video in cond_latents] + + runner.dit.to("cpu") + print(f"Encoding videos: {list(map(lambda x: x.size(), cond_latents))}") + runner.vae.to(get_device()) + cond_latents = runner.vae_encode(cond_latents) + runner.vae.to("cpu") + runner.dit.to(get_device()) + + for i, emb in enumerate(text_embeds["texts_pos"]): + text_embeds["texts_pos"][i] = emb.to(get_device()) + for i, emb in enumerate(text_embeds["texts_neg"]): + text_embeds["texts_neg"][i] = emb.to(get_device()) + + samples = generation_step(runner, text_embeds, cond_latents=cond_latents) + runner.dit.to("cpu") + del cond_latents + + # dump samples to the output directory + if get_sequence_parallel_rank() == 0: + for path, input, sample, ori_length in zip( + videos, input_videos, samples, ori_lengths + ): + if ori_length < sample.shape[0]: + sample = sample[:ori_length] + filename = os.path.join(tgt_path, os.path.basename(path)) + # color fix + input = ( + rearrange(input[:, None], "c t h w -> t c h w") + if input.ndim == 3 + else rearrange(input, "c t h w -> t c h w") + ) + if use_colorfix: + sample = wavelet_reconstruction( + sample.to("cpu"), input[: sample.size(0)].to("cpu") + ) + else: + sample = sample.to("cpu") + sample = ( + rearrange(sample[:, None], "t c h w -> t h w c") + if sample.ndim == 3 + else rearrange(sample, "t c h w -> t h w c") + ) + sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).round() + sample = sample.to(torch.uint8).numpy() + + if sample.shape[0] == 1: + mediapy.write_image(filename, sample.squeeze(0)) + else: + mediapy.write_video( + filename, sample, fps=24 + ) + gc.collect() + torch.cuda.empty_cache() + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--video_path", type=str, default="./test_videos") + parser.add_argument("--output_dir", type=str, default="./results") + parser.add_argument("--cfg_scale", type=float, default=6.5) + parser.add_argument("--sample_steps", type=int, default=50) + parser.add_argument("--seed", type=int, default=666) + parser.add_argument("--res_h", type=int, default=720) + parser.add_argument("--res_w", type=int, default=1280) + parser.add_argument("--sp_size", type=int, default=1) + args = parser.parse_args() + + runner = configure_runner(args.sp_size) + generation_loop(runner, **vars(args)) diff --git a/projects/video_diffusion_sr/color_fix.py b/projects/video_diffusion_sr/color_fix.py new file mode 100644 index 0000000000000000000000000000000000000000..efe804519873717eee01468439c416325eb8e192 --- /dev/null +++ b/projects/video_diffusion_sr/color_fix.py @@ -0,0 +1,113 @@ +import torch +from PIL import Image +from torch import Tensor +from torch.nn import functional as F + +from torchvision.transforms import ToTensor, ToPILImage + +def adain_color_fix(target: Image, source: Image): + # Convert images to tensors + to_tensor = ToTensor() + 
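+ # ToTensor() yields a (C, H, W) tensor; unsqueeze(0) adds a batch dimension so both
+ # tensors are 4D (1, C, H, W), as calc_mean_std below requires.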
target_tensor = to_tensor(target).unsqueeze(0) + source_tensor = to_tensor(source).unsqueeze(0) + + # Apply adaptive instance normalization + result_tensor = adaptive_instance_normalization(target_tensor, source_tensor) + + # Convert tensor back to image + to_image = ToPILImage() + result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) + + return result_image + +def wavelet_color_fix(target: Image, source: Image): + # Convert images to tensors + to_tensor = ToTensor() + target_tensor = to_tensor(target).unsqueeze(0) + source_tensor = to_tensor(source).unsqueeze(0) + + # Apply wavelet reconstruction + result_tensor = wavelet_reconstruction(target_tensor, source_tensor) + + # Convert tensor back to image + to_image = ToPILImage() + result_image = to_image(result_tensor.squeeze(0).clamp_(0.0, 1.0)) + + return result_image + +def calc_mean_std(feat: Tensor, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.view(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(b, c, 1, 1) + feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) + return feat_mean, feat_std + +def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): + """Adaptive instance normalization. + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. + """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + +def wavelet_blur(image: Tensor, radius: int): + """ + Apply wavelet blur to the input tensor. + """ + # input shape: (1, 3, H, W) + # convolution kernel + kernel_vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125], + [0.0625, 0.125, 0.0625], + ] + kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) + # add channel dimensions to the kernel to make it a 4D tensor + kernel = kernel[None, None] + # repeat the kernel across all input channels + kernel = kernel.repeat(3, 1, 1, 1) + image = F.pad(image, (radius, radius, radius, radius), mode='replicate') + # apply convolution + output = F.conv2d(image, kernel, groups=3, dilation=radius) + return output + +def wavelet_decomposition(image: Tensor, levels=5): + """ + Apply wavelet decomposition to the input tensor. + This function only returns the low frequency & the high frequency. + """ + high_freq = torch.zeros_like(image) + for i in range(levels): + radius = 2 ** i + low_freq = wavelet_blur(image, radius) + high_freq += (image - low_freq) + image = low_freq + + return high_freq, low_freq + +def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): + """ + Apply wavelet decomposition, so that the content will have the same color as the style. 
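+ The result keeps the high-frequency detail of content_feat and takes its
+ low-frequency band (coarse color and illumination) from style_feat.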
+ """ + # calculate the wavelet decomposition of the content feature + content_high_freq, content_low_freq = wavelet_decomposition(content_feat) + del content_low_freq + # calculate the wavelet decomposition of the style feature + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq + # reconstruct the content feature with the style's high frequency + return content_high_freq + style_low_freq \ No newline at end of file diff --git a/projects/video_diffusion_sr/infer.py b/projects/video_diffusion_sr/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..54bb5fba186f884dd52aed61672b6c675046e42f --- /dev/null +++ b/projects/video_diffusion_sr/infer.py @@ -0,0 +1,342 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. + +from typing import List, Optional, Tuple, Union +import torch +from einops import rearrange +from omegaconf import DictConfig, ListConfig +from torch import Tensor + +from common.config import create_object +from common.decorators import log_on_entry, log_runtime +from common.diffusion import ( + classifier_free_guidance_dispatcher, + create_sampler_from_config, + create_sampling_timesteps_from_config, + create_schedule_from_config, +) +from common.distributed import ( + get_device, + get_global_rank, +) + +from common.distributed.meta_init_utils import ( + meta_non_persistent_buffer_init_fn, +) +# from common.fs import download + +from models.dit_v2 import na + +class VideoDiffusionInfer(): + def __init__(self, config: DictConfig): + self.config = config + self.device = "cuda" + + def get_condition(self, latent: Tensor, latent_blur: Tensor, task: str) -> Tensor: + t, h, w, c = latent.shape + cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) + if task == "t2v" or t == 1: + # t2i or t2v generation. + if task == "sr": + cond[:, ..., :-1] = latent_blur[:] + cond[:, ..., -1:] = 1.0 + return cond + if task == "i2v": + # i2v generation. + cond[:1, ..., :-1] = latent[:1] + cond[:1, ..., -1:] = 1.0 + return cond + if task == "v2v": + # v2v frame extension. + cond[:2, ..., :-1] = latent[:2] + cond[:2, ..., -1:] = 1.0 + return cond + if task == "sr": + # sr generation. + cond[:, ..., :-1] = latent_blur[:] + cond[:, ..., -1:] = 1.0 + return cond + raise NotImplementedError + + @log_on_entry + @log_runtime + def configure_dit_model(self, device="cuda", checkpoint=None): + # Load dit checkpoint. + # For fast init & resume, + # when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP. + # otherwise, all ranks init DiT on meta device, then load_state_dict with assign=True. + + # Create dit model. 
+ with torch.device(self.device): + self.dit = create_object(self.config.dit.model) + self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint) + + if checkpoint: + state = torch.load(checkpoint, map_location=self.device, mmap=True) + loading_info = self.dit.load_state_dict(state, strict=True, assign=True) + print(f"Loading pretrained ckpt from {checkpoint}") + print(f"Loading info: {loading_info}") + self.dit = meta_non_persistent_buffer_init_fn(self.dit) + + # Print model size. + num_params = sum(p.numel() for p in self.dit.parameters() if p.requires_grad) + print(f"DiT trainable parameters: {num_params:,}") + + @log_on_entry + @log_runtime + def configure_vae_model(self): + # Create vae model. + dtype = getattr(torch, self.config.vae.dtype) + self.vae = create_object(self.config.vae.model) + self.vae.requires_grad_(False).eval() + self.vae.to(device=get_device(), dtype=dtype) + + # Load vae checkpoint. + state = torch.load( + self.config.vae.checkpoint, map_location=get_device(), mmap=True + ) + self.vae.load_state_dict(state) + + # Set causal slicing. + if hasattr(self.vae, "set_causal_slicing") and hasattr(self.config.vae, "slicing"): + self.vae.set_causal_slicing(**self.config.vae.slicing) + + # ------------------------------ Diffusion ------------------------------ # + + def configure_diffusion(self): + self.schedule = create_schedule_from_config( + config=self.config.diffusion.schedule, + device=get_device(), + ) + self.sampling_timesteps = create_sampling_timesteps_from_config( + config=self.config.diffusion.timesteps.sampling, + schedule=self.schedule, + device=get_device(), + ) + self.sampler = create_sampler_from_config( + config=self.config.diffusion.sampler, + schedule=self.schedule, + timesteps=self.sampling_timesteps, + ) + + # -------------------------------- Helper ------------------------------- # + + @torch.no_grad() + def vae_encode(self, samples: List[Tensor]) -> List[Tensor]: + use_sample = self.config.vae.get("use_sample", True) + latents = [] + if len(samples) > 0: + device = get_device() + dtype = getattr(torch, self.config.vae.dtype) + scale = self.config.vae.scaling_factor + shift = self.config.vae.get("shifting_factor", 0.0) + + if isinstance(scale, ListConfig): + scale = torch.tensor(scale, device=device, dtype=dtype) + if isinstance(shift, ListConfig): + shift = torch.tensor(shift, device=device, dtype=dtype) + + # Group samples of the same shape to batches if enabled. + if self.config.vae.grouping: + batches, indices = na.pack(samples) + else: + batches = [sample.unsqueeze(0) for sample in samples] + + # Vae process by each group. + for sample in batches: + sample = sample.to(device, dtype) + if hasattr(self.vae, "preprocess"): + sample = self.vae.preprocess(sample) + if use_sample: + latent = self.vae.encode(sample).latent + else: + # Deterministic vae encode, only used for i2v inference (optionally) + latent = self.vae.encode(sample).posterior.mode().squeeze(2) + latent = latent.unsqueeze(2) if latent.ndim == 4 else latent + latent = rearrange(latent, "b c ... -> b ... c") + latent = (latent - shift) * scale + latents.append(latent) + + # Ungroup back to individual latent with the original order. 
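+ # (na.pack grouped equal-shaped samples into batches above; na.unpack restores the
+ # original per-sample order, so one latent is returned per input regardless of grouping.)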
+ if self.config.vae.grouping: + latents = na.unpack(latents, indices) + else: + latents = [latent.squeeze(0) for latent in latents] + + return latents + + @torch.no_grad() + def vae_decode(self, latents: List[Tensor]) -> List[Tensor]: + samples = [] + if len(latents) > 0: + device = get_device() + dtype = getattr(torch, self.config.vae.dtype) + scale = self.config.vae.scaling_factor + shift = self.config.vae.get("shifting_factor", 0.0) + + if isinstance(scale, ListConfig): + scale = torch.tensor(scale, device=device, dtype=dtype) + if isinstance(shift, ListConfig): + shift = torch.tensor(shift, device=device, dtype=dtype) + + # Group latents of the same shape to batches if enabled. + if self.config.vae.grouping: + latents, indices = na.pack(latents) + else: + latents = [latent.unsqueeze(0) for latent in latents] + + # Vae process by each group. + for latent in latents: + latent = latent.to(device, dtype) + latent = latent / scale + shift + latent = rearrange(latent, "b ... c -> b c ...") + latent = latent.squeeze(2) + sample = self.vae.decode(latent).sample + if hasattr(self.vae, "postprocess"): + sample = self.vae.postprocess(sample) + samples.append(sample) + + # Ungroup back to individual sample with the original order. + if self.config.vae.grouping: + samples = na.unpack(samples, indices) + else: + samples = [sample.squeeze(0) for sample in samples] + + return samples + + def timestep_transform(self, timesteps: Tensor, latents_shapes: Tensor): + # Skip if not needed. + if not self.config.diffusion.timesteps.get("transform", False): + return timesteps + + # Compute resolution. + vt = self.config.vae.model.get("temporal_downsample_factor", 4) + vs = self.config.vae.model.get("spatial_downsample_factor", 8) + frames = (latents_shapes[:, 0] - 1) * vt + 1 + heights = latents_shapes[:, 1] * vs + widths = latents_shapes[:, 2] * vs + + # Compute shift factor. + def get_lin_function(x1, y1, x2, y2): + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + img_shift_fn = get_lin_function(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2) + vid_shift_fn = get_lin_function(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0) + shift = torch.where( + frames > 1, + vid_shift_fn(heights * widths * frames), + img_shift_fn(heights * widths), + ) + + # Shift timesteps. + timesteps = timesteps / self.schedule.T + timesteps = shift * timesteps / (1 + (shift - 1) * timesteps) + timesteps = timesteps * self.schedule.T + return timesteps + + @torch.no_grad() + def inference( + self, + noises: List[Tensor], + conditions: List[Tensor], + texts_pos: Union[List[str], List[Tensor], List[Tuple[Tensor]]], + texts_neg: Union[List[str], List[Tensor], List[Tuple[Tensor]]], + cfg_scale: Optional[float] = None, + dit_offload: bool = False, + ) -> List[Tensor]: + assert len(noises) == len(conditions) == len(texts_pos) == len(texts_neg) + batch_size = len(noises) + + # Return if empty. + if batch_size == 0: + return [] + + # Set cfg scale + if cfg_scale is None: + cfg_scale = self.config.diffusion.cfg.scale + + # Text embeddings. 
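+        # texts_pos / texts_neg may be raw strings (encoded via self.text_encode),
+        # tuples of per-component embedding tensors (each component flattened with
+        # na.flatten), or plain embedding tensors (flattened directly).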
+ assert type(texts_pos[0]) is type(texts_neg[0]) + if isinstance(texts_pos[0], str): + text_pos_embeds, text_pos_shapes = self.text_encode(texts_pos) + text_neg_embeds, text_neg_shapes = self.text_encode(texts_neg) + elif isinstance(texts_pos[0], tuple): + text_pos_embeds, text_pos_shapes = [], [] + text_neg_embeds, text_neg_shapes = [], [] + for pos in zip(*texts_pos): + emb, shape = na.flatten(pos) + text_pos_embeds.append(emb) + text_pos_shapes.append(shape) + for neg in zip(*texts_neg): + emb, shape = na.flatten(neg) + text_neg_embeds.append(emb) + text_neg_shapes.append(shape) + else: + text_pos_embeds, text_pos_shapes = na.flatten(texts_pos) + text_neg_embeds, text_neg_shapes = na.flatten(texts_neg) + + # Flatten. + latents, latents_shapes = na.flatten(noises) + latents_cond, _ = na.flatten(conditions) + + # Enter eval mode. + was_training = self.dit.training + self.dit.eval() + + # Sampling. + latents = self.sampler.sample( + x=latents, + f=lambda args: classifier_free_guidance_dispatcher( + pos=lambda: self.dit( + vid=torch.cat([args.x_t, latents_cond], dim=-1), + txt=text_pos_embeds, + vid_shape=latents_shapes, + txt_shape=text_pos_shapes, + timestep=args.t.repeat(batch_size), + ).vid_sample, + neg=lambda: self.dit( + vid=torch.cat([args.x_t, latents_cond], dim=-1), + txt=text_neg_embeds, + vid_shape=latents_shapes, + txt_shape=text_neg_shapes, + timestep=args.t.repeat(batch_size), + ).vid_sample, + scale=( + cfg_scale + if (args.i + 1) / len(self.sampler.timesteps) + <= self.config.diffusion.cfg.get("partial", 1) + else 1.0 + ), + rescale=self.config.diffusion.cfg.rescale, + ), + ) + + # Exit eval mode. + self.dit.train(was_training) + + # Unflatten. + latents = na.unflatten(latents, latents_shapes) + + if dit_offload: + self.dit.to("cpu") + + # Vae decode. + self.vae.to(get_device()) + samples = self.vae_decode(latents) + + if dit_offload: + self.dit.to(get_device()) + return samples \ No newline at end of file diff --git a/projects/video_diffusion_sr/utils.py b/projects/video_diffusion_sr/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ae48d2662d7ed8630579cb52b97fc0e256335a5a --- /dev/null +++ b/projects/video_diffusion_sr/utils.py @@ -0,0 +1,368 @@ +# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates +# // +# // Licensed under the Apache License, Version 2.0 (the "License"); +# // you may not use this file except in compliance with the License. +# // You may obtain a copy of the License at +# // +# // http://www.apache.org/licenses/LICENSE-2.0 +# // +# // Unless required by applicable law or agreed to in writing, software +# // distributed under the License is distributed on an "AS IS" BASIS, +# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# // See the License for the specific language governing permissions and +# // limitations under the License. 
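+
+# Parquet persistence utilities: an asynchronous save-and-copy helper plus
+# managers for dumping parquet shards and for repacking them into a configured
+# number of output files.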
+ +import os +import random +import threading +from abc import ABC +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from functools import partial +from itertools import chain +from typing import Any, Dict, List, Optional, Tuple, Union +import pyarrow as pa +import pyarrow.parquet as pq +from omegaconf import DictConfig + +from common.distributed import get_global_rank, get_world_size +from common.fs import copy, exists, listdir, mkdir, remove +from common.partition import partition_by_groups +from common.persistence.utils import get_local_path +from data.common.parquet_sampler import ( + IdentityParquetSampler, + ParquetSampler, + create_parquet_sampler, +) +from data.common.utils import filter_parquets, get_parquet_metadata + + +# Function to save a Parquet file and copy it to a target path +def save_and_copy( + pa_table, + local_path: str, + target_path: str, + row_group_size: int, + executor: ThreadPoolExecutor, + do_async: bool = False, + futures: List[Tuple[threading.Thread, str]] = [], +): + # Function to handle completion of the future + def _make_on_complete(local_path): + def _on_complete(future): + target_path = future.result() + remove(local_path) + # del future + print(f"Target path saved: {target_path}") + + return _on_complete + + # Function to write Parquet table and copy it + def _fn(pa_table, local_path, target_path, row_group_size): + pq.write_table( + pa_table, + local_path, + row_group_size=row_group_size, + ) + mkdir(os.path.dirname(target_path)) + copy(local_path, target_path) + return target_path + + # Submit the task to the executor + future = executor.submit(_fn, pa_table, local_path, target_path, row_group_size) + future.add_done_callback(_make_on_complete(local_path)) + futures.append(future) + + # If not asynchronous, wait for all futures to complete + if not do_async: + for future in as_completed(futures): + try: + future.result() + except Exception as exc: + print(f"Generated an exception: {exc}") + executor.shutdown(wait=True) + + +@dataclass +class FileListOutput: + existing_files: List[str] + source_files: List[Any] + target_files: List[str] + + +@dataclass +class PersistedParquet: + path: str + + # Method to save the Parquet file + def save( + self, + row_group_size: int, + executor: ThreadPoolExecutor, + pa_table: Optional[pa.Table] = None, + data_dict: Optional[Dict[str, List[Union[str, bytes]]]] = None, + is_last_file=False, + futures: List[threading.Thread] = [], + ): + assert (pa_table is None) != (data_dict is None) + local_path = get_local_path(self.path) + if not pa_table: + schema_dict = self.generate_schema_from_dict(data_dict) + pa_table = pa.Table.from_pydict(data_dict, schema=schema_dict) + save_and_copy( + pa_table, + local_path=local_path, + target_path=self.path, + row_group_size=row_group_size, + executor=executor, + do_async=not is_last_file, + futures=futures, + ) + + # Method to generate schema from a dictionary + def generate_schema_from_dict( + self, + data_dict: Dict[str, List[Union[str, bytes]]], + ): + schema_dict = {} + for key, value in data_dict.items(): + if isinstance(value[0], str): + schema_dict[key] = pa.string() + elif isinstance(value[0], bytes): + schema_dict[key] = pa.binary() + else: + raise ValueError(f"Unsupported data type for key '{key}': {type(value)}") + return pa.schema(schema_dict) + + +# Base class for managing Parquet files +class ParquetManager(ABC): + """ + Base class for the DumpingManager and RepackingManager. 
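+
+    Illustrative flow for the dumping case (a sketch, not a prescribed API;
+    `task_cfg` and `build_rows` below are hypothetical placeholders):
+
+        manager = DumpingManager(task=task_cfg, target_dir="path/to/output")
+        files = manager.configure_task_path(source_path="path/to/source", rsplit=2)
+        for i, (src, dst) in enumerate(zip(files.source_files, files.target_files)):
+            data_dict = build_rows(src)  # columns of str or bytes values
+            manager.save_parquet(
+                file_name=dst,
+                row_group_size=16,
+                data_dict=data_dict,
+                is_last_file=(i == len(files.target_files) - 1),
+            )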
+ """ + + def __init__( + self, + task: Optional[DictConfig] = None, + target_dir: str = ".", + ): + self.task = task + self.target_dir = target_dir.rstrip("/") + self.executor = ThreadPoolExecutor(max_workers=4) + self.futures = [] + + # Method to get list of Parquet files from source path + def get_parquet_files( + self, + source_path: str, + parquet_sampler: ParquetSampler = IdentityParquetSampler(), + path_mode: str = "dir", + ): + + # Helper function to flatten nested lists + def _flatten(paths): + if isinstance(paths, list): + if any(isinstance(i, list) for i in paths): + return list(chain(*paths)) + else: + return paths + else: + return [paths] + + file_paths = _flatten(source_path) + if path_mode == "dir": + file_paths = map(listdir, file_paths) + if isinstance(parquet_sampler.size, float): + file_paths = map(filter_parquets, file_paths) + file_paths = map(parquet_sampler, file_paths) + file_paths = list(chain(*file_paths)) + else: + file_paths = chain(*file_paths) + file_paths = parquet_sampler(filter_parquets(file_paths)) + + return file_paths + + # Method to save a Parquet file + def save_parquet( + self, + *, + file_name: str, + row_group_size: int, + pa_table: Optional[pa.Table] = None, + data_dict: Optional[Dict[str, List[Union[str, bytes]]]] = None, + override: bool = True, + is_last_file: bool = False, + ): + + persist = self._get_parquet(file_name) + if override or not exists(persist.path): + persist.save( + pa_table=pa_table, + data_dict=data_dict, + executor=self.executor, + row_group_size=row_group_size, + is_last_file=is_last_file, + futures=self.futures, + ) + + # Method to get a PersistedParquet object + def _get_parquet(self, file_name: str) -> PersistedParquet: + return PersistedParquet(file_name) + + +# Class to manage dumping of Parquet files +class DumpingManager(ParquetManager): + """ + Dumping manager handles parquet saving and resuming. 
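+
+    Target paths take the form
+    <target_dir>/epoch_<task.epoch>/<trailing source path parts>/<source file name>;
+    when task.skip_exists is set, source files whose targets already exist are
+    skipped so that an interrupted dump can be resumed.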
+    """
+
+    def __init__(
+        self,
+        task: DictConfig,
+        target_dir: str,
+    ):
+        super().__init__(task=task, target_dir=target_dir)
+
+    # Method to generate saving path
+    def generate_saving_path(self, file_path: str, rsplit: int):
+        part_list = file_path.rsplit("/", rsplit)
+        result_folder = "/".join(
+            [self.target_dir] + [f"epoch_{self.task.epoch}"] + part_list[-rsplit:-1]
+        )
+        result_file = "/".join([result_folder, part_list[-1]])
+        return result_folder, result_file
+
+    # Method to configure task paths
+    def configure_task_path(self, source_path: str, rsplit: int, path_mode: str = "dir"):
+
+        file_paths = self.get_parquet_files(
+            source_path=source_path,
+            path_mode=path_mode,
+        )
+
+        # Shuffle file paths
+        random.Random(0).shuffle(file_paths)
+
+        # Partition the file paths based on task configuration
+        full_source_files = partition_by_groups(file_paths, self.task.total_count)[self.task.index]
+        full_source_files = partition_by_groups(full_source_files, get_world_size())[
+            get_global_rank()
+        ]
+
+        if not full_source_files:
+            return FileListOutput([], [], [])
+
+        generate_saving_path = partial(self.generate_saving_path, rsplit=rsplit)
+        full_paths = map(generate_saving_path, full_source_files)
+        full_target_folders, full_target_files = map(list, zip(*full_paths))
+        full_target_folders = set(full_target_folders)
+
+        existing_file_paths = map(
+            lambda folder: listdir(folder) if exists(folder) else [], full_target_folders
+        )
+        existing_file_paths = chain(*existing_file_paths)
+        self.existing_files = list(
+            filter(
+                lambda path: path.endswith(".parquet") and path in full_target_files,
+                existing_file_paths,
+            )
+        )
+
+        filtered_pairs = list(
+            filter(
+                lambda pair: pair[1] not in self.existing_files,
+                zip(full_source_files, full_target_files),
+            )
+        )
+        if filtered_pairs:
+            filtered_source_files, filtered_target_files = map(list, zip(*filtered_pairs))
+        else:
+            filtered_source_files, filtered_target_files = [], []
+
+        # Skip existing file paths if specified
+        skip_exists = self.task.skip_exists
+        self.source_files = filtered_source_files if skip_exists else full_source_files
+        self.target_files = filtered_target_files if skip_exists else full_target_files
+
+        return FileListOutput(self.existing_files, self.source_files, self.target_files)
+
+
+class RepackingManager(ParquetManager):
+    """
+    Repacking manager handles parquet splitting and saving.
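+
+    Rows from the (optionally sampled) source parquets are enumerated as
+    (file_path, row_index) pairs, shuffled with a fixed seed, partitioned into
+    repackaging.num_files groups, and written out as zero-padded shards
+    (00000.parquet, 00001.parquet, ...) under target_dir, or under a
+    <total_count>_<index> subdirectory when a task config is given.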
+ """ + + def __init__( + self, + task: DictConfig, + target_dir: str, + repackaging: DictConfig, + ): + super().__init__(task=task, target_dir=target_dir) + self.repackaging = repackaging + + # Configure the task paths for repacking + def configure_task_path( + self, + source_path: str, + parquet_sampler: Optional[DictConfig] = None, + path_mode: str = "dir", + ): + + parquet_sampler = create_parquet_sampler(config=parquet_sampler) + file_paths = self.get_parquet_files( + source_path=source_path, + parquet_sampler=parquet_sampler, + path_mode=path_mode, + ) + + random.Random(0).shuffle(file_paths) + target_dir = self.target_dir + size = abs(parquet_sampler.size) + + if self.task: + # Partition the file paths based on task configuration + file_paths = partition_by_groups(file_paths, self.task.total_count)[self.task.index] + target_dir = os.path.join(target_dir, f"{self.task.total_count}_{self.task.index}") + + if size > 1: + size = len( + partition_by_groups(range(size), self.task.total_count)[self.task.index] + ) + + # Get metadata for each Parquet file + metadatas = get_parquet_metadata(file_paths, self.repackaging.num_processes) + + # Create a list of (file_path, row) tuples for each row in the files + target_items = [ + (file_path, row) + for file_path, metadata in zip(file_paths, metadatas) + for row in range(metadata.num_rows) + ] + + # Shuffle the target items + random.Random(0).shuffle(target_items) + + if size > 1: + target_items = target_items[:size] + + # Partition the items into groups for each target file + items_per_file = partition_by_groups(target_items, self.repackaging.num_files) + + # Generate target file paths + target_files = [ + os.path.join(target_dir, f"{str(i).zfill(5)}.parquet") + for i in range(self.repackaging.num_files) + ] + + existing_file_paths = listdir(target_dir) if exists(target_dir) else [] + self.existing_files = list( + filter( + lambda path: path.endswith(".parquet"), + existing_file_paths, + ) + ) + self.source_files = items_per_file + self.target_files = target_files + + return FileListOutput(self.existing_files, self.source_files, self.target_files)
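+
+
+# Illustrative use of RepackingManager (a sketch; `task_cfg` and `repack_cfg` are
+# hypothetical DictConfigs, with repack_cfg providing num_files and num_processes):
+#
+#     manager = RepackingManager(
+#         task=task_cfg,
+#         target_dir="path/to/repacked",
+#         repackaging=repack_cfg,
+#     )
+#     files = manager.configure_task_path(source_path="path/to/source")
+#     # files.source_files[i] holds the (parquet_path, row_index) pairs assigned to
+#     # files.target_files[i].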