Carlexxx committed
Commit c3bf719 · 1 Parent(s): a45010b

feat: Implement self-contained specialist managers

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. common/__init__.py +0 -0
  2. common/cache.py +47 -0
  3. common/config.py +110 -0
  4. common/decorators.py +147 -0
  5. common/diffusion/__init__.py +56 -0
  6. common/diffusion/config.py +74 -0
  7. common/diffusion/samplers/base.py +108 -0
  8. common/diffusion/samplers/euler.py +89 -0
  9. common/diffusion/schedules/base.py +131 -0
  10. common/diffusion/schedules/lerp.py +55 -0
  11. common/diffusion/timesteps/base.py +72 -0
  12. common/diffusion/timesteps/sampling/trailing.py +49 -0
  13. common/diffusion/types.py +59 -0
  14. common/diffusion/utils.py +84 -0
  15. common/distributed/__init__.py +37 -0
  16. common/distributed/advanced.py +208 -0
  17. common/distributed/basic.py +84 -0
  18. common/distributed/meta_init_utils.py +41 -0
  19. common/distributed/ops.py +494 -0
  20. common/logger.py +44 -0
  21. common/partition.py +59 -0
  22. common/seed.py +30 -0
  23. configs_3b/main.yaml +88 -0
  24. configs_7b/main.yaml +85 -0
  25. data/image/transforms/area_resize.py +135 -0
  26. data/image/transforms/divisible_crop.py +40 -0
  27. data/image/transforms/na_resize.py +50 -0
  28. data/image/transforms/side_resize.py +54 -0
  29. data/video/transforms/rearrange.py +24 -0
  30. models/dit/attention.py +46 -0
  31. models/dit/blocks/__init__.py +25 -0
  32. models/dit/blocks/mmdit_window_block.py +233 -0
  33. models/dit/embedding.py +62 -0
  34. models/dit/mlp.py +62 -0
  35. models/dit/mm.py +67 -0
  36. models/dit/modulation.py +97 -0
  37. models/dit/na.py +241 -0
  38. models/dit/nablocks/__init__.py +25 -0
  39. models/dit/nablocks/mmsr_block.py +248 -0
  40. models/dit/nadit.py +350 -0
  41. models/dit/normalization.py +63 -0
  42. models/dit/patch.py +112 -0
  43. models/dit/rope.py +101 -0
  44. models/dit/window.py +83 -0
  45. models/dit_v2/attention.py +46 -0
  46. models/dit_v2/embedding.py +62 -0
  47. models/dit_v2/mlp.py +62 -0
  48. models/dit_v2/mm.py +74 -0
  49. models/dit_v2/modulation.py +102 -0
  50. models/dit_v2/na.py +241 -0
common/__init__.py ADDED
File without changes
common/cache.py ADDED
@@ -0,0 +1,47 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

from typing import Callable


class Cache:
    """Caching reusable args for faster inference"""

    def __init__(self, disable=False, prefix="", cache=None):
        self.cache = cache if cache is not None else {}
        self.disable = disable
        self.prefix = prefix

    def __call__(self, key: str, fn: Callable):
        if self.disable:
            return fn()

        key = self.prefix + key
        try:
            result = self.cache[key]
        except KeyError:
            result = fn()
            self.cache[key] = result
        return result

    def namespace(self, namespace: str):
        return Cache(
            disable=self.disable,
            prefix=self.prefix + namespace + ".",
            cache=self.cache,
        )

    def get(self, key: str):
        key = self.prefix + key
        return self.cache[key]
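A minimal usage sketch of the Cache class above (not part of the commit; compute_freqs is a hypothetical expensive callable):

from common.cache import Cache

def compute_freqs():
    # Stand-in for an expensive computation; runs only on the first lookup.
    return [i * 0.5 for i in range(4)]

cache = Cache()
rope_cache = cache.namespace("rope")                 # keys are stored as "rope.<key>"
freqs = rope_cache("freqs", compute_freqs)           # computed once and stored
assert rope_cache("freqs", compute_freqs) is freqs   # second call hits the shared dict
assert cache.get("rope.freqs") is freqs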
common/config.py ADDED
@@ -0,0 +1,110 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Configuration utility functions
"""

import importlib
from typing import Any, Callable, List, Union
from omegaconf import DictConfig, ListConfig, OmegaConf

OmegaConf.register_new_resolver("eval", eval)


def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
    """
    Load a configuration. Will resolve inheritance.
    """
    config = OmegaConf.load(path)
    if argv is not None:
        config_argv = OmegaConf.from_dotlist(argv)
        config = OmegaConf.merge(config, config_argv)
    config = resolve_recursive(config, resolve_inheritance)
    return config


def resolve_recursive(
    config: Any,
    resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
) -> Any:
    config = resolver(config)
    if isinstance(config, DictConfig):
        for k in config.keys():
            v = config.get(k)
            if isinstance(v, (DictConfig, ListConfig)):
                config[k] = resolve_recursive(v, resolver)
    if isinstance(config, ListConfig):
        for i in range(len(config)):
            v = config.get(i)
            if isinstance(v, (DictConfig, ListConfig)):
                config[i] = resolve_recursive(v, resolver)
    return config


def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
    """
    Recursively resolve inheritance if the config contains:
    __inherit__: path/to/parent.yaml or a ListConfig of such paths.
    """
    if isinstance(config, DictConfig):
        inherit = config.pop("__inherit__", None)

        if inherit:
            inherit_list = inherit if isinstance(inherit, ListConfig) else [inherit]

            parent_config = None
            for parent_path in inherit_list:
                assert isinstance(parent_path, str)
                parent_config = (
                    load_config(parent_path)
                    if parent_config is None
                    else OmegaConf.merge(parent_config, load_config(parent_path))
                )

            if len(config.keys()) > 0:
                config = OmegaConf.merge(parent_config, config)
            else:
                config = parent_config
    return config


def import_item(path: str, name: str) -> Any:
    """
    Import a python item. Example: import_item("path.to.file", "MyClass") -> MyClass
    """
    return getattr(importlib.import_module(path), name)


def create_object(config: DictConfig) -> Any:
    """
    Create an object from config.
    The config is expected to contain the following:
    __object__:
        path: path.to.module
        name: MyClass
        args: as_config | as_params (default to as_config)
    """
    item = import_item(
        path=config.__object__.path,
        name=config.__object__.name,
    )
    args = config.__object__.get("args", "as_config")
    if args == "as_config":
        return item(config)
    if args == "as_params":
        config = OmegaConf.to_object(config)
        config.pop("__object__")
        return item(**config)
    raise NotImplementedError(f"Unknown args type: {args}")
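A sketch of how load_config and create_object compose (not part of the commit; the file names and my_pkg.models.MyModel are hypothetical):

# child.yaml:
#   __inherit__: parent.yaml
#   __object__:
#     path: my_pkg.models
#     name: MyModel
#     args: as_params
#   hidden_dim: 1024
from common.config import create_object, load_config

config = load_config("child.yaml", argv=["hidden_dim=2048"])  # dotlist overrides are merged in
model = create_object(config)  # imports my_pkg.models.MyModel, calls MyModel(hidden_dim=2048, ...)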
common/decorators.py ADDED
@@ -0,0 +1,147 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Decorators.
"""

import functools
import threading
import time
from typing import Callable
import torch

from common.distributed import barrier_if_distributed, get_global_rank, get_local_rank
from common.logger import get_logger

logger = get_logger(__name__)


def log_on_entry(func: Callable) -> Callable:
    """
    Functions with this decorator will log the function name at entry.
    When using multiple decorators, this must be applied innermost to properly capture the name.
    """

    def log_on_entry_wrapper(*args, **kwargs):
        logger.info(f"Entering {func.__name__}")
        return func(*args, **kwargs)

    return log_on_entry_wrapper


def barrier_on_entry(func: Callable) -> Callable:
    """
    Functions with this decorator will start executing when all ranks are ready to enter.
    """

    def barrier_on_entry_wrapper(*args, **kwargs):
        barrier_if_distributed()
        return func(*args, **kwargs)

    return barrier_on_entry_wrapper


def _conditional_execute_wrapper_factory(execute: bool, func: Callable) -> Callable:
    """
    Helper function for local_rank_zero_only and global_rank_zero_only.
    """

    def conditional_execute_wrapper(*args, **kwargs):
        # Only execute if needed.
        result = func(*args, **kwargs) if execute else None
        # All GPUs must wait.
        barrier_if_distributed()
        # Return results.
        return result

    return conditional_execute_wrapper


def _asserted_wrapper_factory(condition: bool, func: Callable, err_msg: str = "") -> Callable:
    """
    Helper function for some functions with special constraints,
    especially functions called by other global_rank_zero_only / local_rank_zero_only ones,
    in case they are wrongly invoked in other scenarios.
    """

    def asserted_execute_wrapper(*args, **kwargs):
        assert condition, err_msg
        result = func(*args, **kwargs)
        return result

    return asserted_execute_wrapper


def local_rank_zero_only(func: Callable) -> Callable:
    """
    Functions with this decorator will only execute on local rank zero.
    """
    return _conditional_execute_wrapper_factory(get_local_rank() == 0, func)


def global_rank_zero_only(func: Callable) -> Callable:
    """
    Functions with this decorator will only execute on global rank zero.
    """
    return _conditional_execute_wrapper_factory(get_global_rank() == 0, func)


def assert_only_global_rank_zero(func: Callable) -> Callable:
    """
    Functions with this decorator are only accessible to processes with global rank zero.
    """
    return _asserted_wrapper_factory(
        get_global_rank() == 0, func, err_msg="Not accessible to processes with global_rank != 0"
    )


def assert_only_local_rank_zero(func: Callable) -> Callable:
    """
    Functions with this decorator are only accessible to processes with local rank zero.
    """
    return _asserted_wrapper_factory(
        get_local_rank() == 0, func, err_msg="Not accessible to processes with local_rank != 0"
    )


def new_thread(func: Callable) -> Callable:
    """
    Functions with this decorator will run in a new thread.
    The function will return the thread, which can be joined to wait for completion.
    """

    def new_thread_wrapper(*args, **kwargs):
        thread = threading.Thread(target=func, args=args, kwargs=kwargs)
        thread.start()
        return thread

    return new_thread_wrapper


def log_runtime(func: Callable) -> Callable:
    """
    Functions with this decorator will log the runtime.
    """

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        torch.distributed.barrier()
        start = time.perf_counter()
        result = func(*args, **kwargs)
        torch.distributed.barrier()
        logger.info(f"Completed {func.__name__} in {time.perf_counter() - start:.3f} seconds.")
        return result

    return wrapped
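A short sketch of how these decorators stack (download_assets is hypothetical; assumes a distributed context is already initialized):

from common.decorators import global_rank_zero_only, log_on_entry, log_runtime

@log_runtime            # outermost: barriers plus wall-clock timing around the whole call
@global_rank_zero_only  # only global rank 0 runs the body; other ranks wait at the barrier
@log_on_entry           # innermost, so the original function name is what gets logged
def download_assets():
    ...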
common/diffusion/__init__.py ADDED
@@ -0,0 +1,56 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Diffusion package.
"""

from .config import (
    create_sampler_from_config,
    create_sampling_timesteps_from_config,
    create_schedule_from_config,
)
from .samplers.base import Sampler
from .samplers.euler import EulerSampler
from .schedules.base import Schedule
from .schedules.lerp import LinearInterpolationSchedule
from .timesteps.base import SamplingTimesteps, Timesteps
from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps
from .types import PredictionType, SamplingDirection
from .utils import classifier_free_guidance, classifier_free_guidance_dispatcher, expand_dims

__all__ = [
    # Configs
    "create_sampler_from_config",
    "create_sampling_timesteps_from_config",
    "create_schedule_from_config",
    # Schedules
    "Schedule",
    "DiscreteVariancePreservingSchedule",
    "LinearInterpolationSchedule",
    # Samplers
    "Sampler",
    "EulerSampler",
    # Timesteps
    "Timesteps",
    "SamplingTimesteps",
    # Types
    "PredictionType",
    "SamplingDirection",
    "UniformTrailingSamplingTimesteps",
    # Utils
    "classifier_free_guidance",
    "classifier_free_guidance_dispatcher",
    "expand_dims",
]
common/diffusion/config.py ADDED
@@ -0,0 +1,74 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Utility functions for creating schedules and samplers from config.
"""

import torch
from omegaconf import DictConfig

from .samplers.base import Sampler
from .samplers.euler import EulerSampler
from .schedules.base import Schedule
from .schedules.lerp import LinearInterpolationSchedule
from .timesteps.base import SamplingTimesteps
from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps


def create_schedule_from_config(
    config: DictConfig,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> Schedule:
    """
    Create a schedule from configuration.
    """
    if config.type == "lerp":
        return LinearInterpolationSchedule(T=config.get("T", 1.0))

    raise NotImplementedError


def create_sampler_from_config(
    config: DictConfig,
    schedule: Schedule,
    timesteps: SamplingTimesteps,
) -> Sampler:
    """
    Create a sampler from configuration.
    """
    if config.type == "euler":
        return EulerSampler(
            schedule=schedule,
            timesteps=timesteps,
            prediction_type=config.prediction_type,
        )
    raise NotImplementedError


def create_sampling_timesteps_from_config(
    config: DictConfig,
    schedule: Schedule,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> SamplingTimesteps:
    if config.type == "uniform_trailing":
        return UniformTrailingSamplingTimesteps(
            T=schedule.T,
            steps=config.steps,
            shift=config.get("shift", 1.0),
            device=device,
        )
    raise NotImplementedError
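A sketch wiring the three factories together (not from the commit; the config keys shown are assumptions, not taken from configs_*/main.yaml):

import torch
from omegaconf import OmegaConf

from common.diffusion import (
    create_sampler_from_config,
    create_sampling_timesteps_from_config,
    create_schedule_from_config,
)

cfg = OmegaConf.create({
    "schedule": {"type": "lerp", "T": 1.0},
    "timesteps": {"type": "uniform_trailing", "steps": 50, "shift": 3.0},
    "sampler": {"type": "euler", "prediction_type": "v_lerp"},
})
device = torch.device("cpu")
schedule = create_schedule_from_config(cfg.schedule, device)
timesteps = create_sampling_timesteps_from_config(cfg.timesteps, schedule, device)
sampler = create_sampler_from_config(cfg.sampler, schedule, timesteps)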
common/diffusion/samplers/base.py ADDED
@@ -0,0 +1,108 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Sampler base class.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable
import torch
from tqdm import tqdm

from ..schedules.base import Schedule
from ..timesteps.base import SamplingTimesteps
from ..types import PredictionType, SamplingDirection
from ..utils import assert_schedule_timesteps_compatible


@dataclass
class SamplerModelArgs:
    x_t: torch.Tensor
    t: torch.Tensor
    i: int


class Sampler(ABC):
    """
    Samplers are ODE/SDE solvers.
    """

    def __init__(
        self,
        schedule: Schedule,
        timesteps: SamplingTimesteps,
        prediction_type: PredictionType,
        return_endpoint: bool = True,
    ):
        assert_schedule_timesteps_compatible(
            schedule=schedule,
            timesteps=timesteps,
        )
        self.schedule = schedule
        self.timesteps = timesteps
        self.prediction_type = prediction_type
        self.return_endpoint = return_endpoint

    @abstractmethod
    def sample(
        self,
        x: torch.Tensor,
        f: Callable[[SamplerModelArgs], torch.Tensor],
    ) -> torch.Tensor:
        """
        Generate a new sample given the initial sample x and score function f.
        """

    def get_next_timestep(
        self,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Get the next sample timestep.
        Support multiple different timesteps t in a batch.
        If no more steps, return out of bound value -1 or T+1.
        """
        T = self.timesteps.T
        steps = len(self.timesteps)
        curr_idx = self.timesteps.index(t)
        next_idx = curr_idx + 1
        bound = -1 if self.timesteps.direction == SamplingDirection.backward else T + 1

        s = self.timesteps[next_idx.clamp_max(steps - 1)]
        s = s.where(next_idx < steps, bound)
        return s

    def get_endpoint(
        self,
        pred: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Get to the endpoint of the probability flow.
        """
        x_0, x_T = self.schedule.convert_from_pred(pred, self.prediction_type, x_t, t)
        return x_0 if self.timesteps.direction == SamplingDirection.backward else x_T

    def get_progress_bar(self):
        """
        Get progress bar for sampling.
        """
        return tqdm(
            iterable=range(len(self.timesteps) - (0 if self.return_endpoint else 1)),
            dynamic_ncols=True,
            desc=self.__class__.__name__,
        )
common/diffusion/samplers/euler.py ADDED
@@ -0,0 +1,89 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.


"""
Euler ODE solver.
"""

from typing import Callable
import torch
from einops import rearrange
from torch.nn import functional as F

from models.dit_v2 import na

from ..types import PredictionType
from ..utils import expand_dims
from .base import Sampler, SamplerModelArgs


class EulerSampler(Sampler):
    """
    The Euler method is the simplest ODE solver.
    <https://en.wikipedia.org/wiki/Euler_method>
    """

    def sample(
        self,
        x: torch.Tensor,
        f: Callable[[SamplerModelArgs], torch.Tensor],
    ) -> torch.Tensor:
        timesteps = self.timesteps.timesteps
        progress = self.get_progress_bar()
        i = 0
        for t, s in zip(timesteps[:-1], timesteps[1:]):
            pred = f(SamplerModelArgs(x, t, i))
            x = self.step_to(pred, x, t, s)
            i += 1
            progress.update()

        if self.return_endpoint:
            t = timesteps[-1]
            pred = f(SamplerModelArgs(x, t, i))
            x = self.get_endpoint(pred, x, t)
            progress.update()
        return x

    def step(
        self,
        pred: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor,
    ) -> torch.Tensor:
        """
        Step to the next timestep.
        """
        return self.step_to(pred, x_t, t, self.get_next_timestep(t))

    def step_to(
        self,
        pred: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor,
        s: torch.Tensor,
    ) -> torch.Tensor:
        """
        Steps from x_t at timestep t to x_s at timestep s. Returns x_s.
        """
        t = expand_dims(t, x_t.ndim)
        s = expand_dims(s, x_t.ndim)
        T = self.schedule.T
        # Step from x_t to x_s.
        pred_x_0, pred_x_T = self.schedule.convert_from_pred(pred, self.prediction_type, x_t, t)
        pred_x_s = self.schedule.forward(pred_x_0, pred_x_T, s.clamp(0, T))
        # Clamp x_s to x_0 and x_T if s is out of bound.
        pred_x_s = pred_x_s.where(s >= 0, pred_x_0)
        pred_x_s = pred_x_s.where(s <= T, pred_x_T)
        return pred_x_s
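A minimal end-to-end sketch of EulerSampler.sample, reusing the objects built in the previous sketch; f is a placeholder for the denoising model:

import torch

x = torch.randn(1, 4)  # stands in for the initial noise x_T

def f(args):
    # Placeholder model: a zero v_lerp prediction leaves x unchanged under the lerp schedule.
    return torch.zeros_like(args.x_t)

out = sampler.sample(x, f)  # steps through the trailing timesteps, then jumps to the endpoint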
common/diffusion/schedules/base.py ADDED
@@ -0,0 +1,131 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Schedule base class.
"""

from abc import ABC, abstractmethod, abstractproperty
from typing import Tuple, Union
import torch

from ..types import PredictionType
from ..utils import expand_dims


class Schedule(ABC):
    """
    Diffusion schedules are uniquely defined by T, A, B:

        x_t = A(t) * x_0 + B(t) * x_T, where t in [0, T]

    Schedules can be continuous or discrete.
    """

    @abstractproperty
    def T(self) -> Union[int, float]:
        """
        Maximum timestep inclusive.
        Schedule is continuous if float, discrete if int.
        """

    @abstractmethod
    def A(self, t: torch.Tensor) -> torch.Tensor:
        """
        Interpolation coefficient A.
        Returns tensor with the same shape as t.
        """

    @abstractmethod
    def B(self, t: torch.Tensor) -> torch.Tensor:
        """
        Interpolation coefficient B.
        Returns tensor with the same shape as t.
        """

    # ----------------------------------------------------

    def snr(self, t: torch.Tensor) -> torch.Tensor:
        """
        Signal to noise ratio.
        Returns tensor with the same shape as t.
        """
        return (self.A(t) ** 2) / (self.B(t) ** 2)

    def isnr(self, snr: torch.Tensor) -> torch.Tensor:
        """
        Inverse signal to noise ratio.
        Returns tensor with the same shape as snr.
        Subclass may implement.
        """
        raise NotImplementedError

    # ----------------------------------------------------

    def is_continuous(self) -> bool:
        """
        Whether the schedule is continuous.
        """
        return isinstance(self.T, float)

    def forward(self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Diffusion forward function.
        """
        t = expand_dims(t, x_0.ndim)
        return self.A(t) * x_0 + self.B(t) * x_T

    def convert_from_pred(
        self, pred: torch.Tensor, pred_type: PredictionType, x_t: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert from prediction. Return predicted x_0 and x_T.
        """
        t = expand_dims(t, x_t.ndim)
        A_t = self.A(t)
        B_t = self.B(t)

        if pred_type == PredictionType.x_T:
            pred_x_T = pred
            pred_x_0 = (x_t - B_t * pred_x_T) / A_t
        elif pred_type == PredictionType.x_0:
            pred_x_0 = pred
            pred_x_T = (x_t - A_t * pred_x_0) / B_t
        elif pred_type == PredictionType.v_cos:
            pred_x_0 = A_t * x_t - B_t * pred
            pred_x_T = A_t * pred + B_t * x_t
        elif pred_type == PredictionType.v_lerp:
            pred_x_0 = (x_t - B_t * pred) / (A_t + B_t)
            pred_x_T = (x_t + A_t * pred) / (A_t + B_t)
        else:
            raise NotImplementedError

        return pred_x_0, pred_x_T

    def convert_to_pred(
        self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor, pred_type: PredictionType
    ) -> torch.FloatTensor:
        """
        Convert to prediction target given x_0 and x_T.
        """
        if pred_type == PredictionType.x_T:
            return x_T
        if pred_type == PredictionType.x_0:
            return x_0
        if pred_type == PredictionType.v_cos:
            t = expand_dims(t, x_0.ndim)
            return self.A(t) * x_T - self.B(t) * x_0
        if pred_type == PredictionType.v_lerp:
            return x_T - x_0
        raise NotImplementedError
common/diffusion/schedules/lerp.py ADDED
@@ -0,0 +1,55 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Linear interpolation schedule (lerp).
"""

from typing import Union
import torch

from .base import Schedule


class LinearInterpolationSchedule(Schedule):
    """
    Linear interpolation schedule (lerp) is proposed by flow matching and rectified flow.
    It leads to straighter probability flow theoretically. It is also used by Stable Diffusion 3.
    <https://arxiv.org/abs/2209.03003>
    <https://arxiv.org/abs/2210.02747>

        x_t = (1 - t) * x_0 + t * x_T

    Can be either continuous or discrete.
    """

    def __init__(self, T: Union[int, float] = 1.0):
        self._T = T

    @property
    def T(self) -> Union[int, float]:
        return self._T

    def A(self, t: torch.Tensor) -> torch.Tensor:
        return 1 - (t / self.T)

    def B(self, t: torch.Tensor) -> torch.Tensor:
        return t / self.T

    # ----------------------------------------------------

    def isnr(self, snr: torch.Tensor) -> torch.Tensor:
        t = self.T / (1 + snr**0.5)
        t = t if self.is_continuous() else t.round().int()
        return t
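A quick numeric sanity check of the schedule algebra (a sketch, not part of the commit): convert_to_pred followed by convert_from_pred should recover x_0 and x_T under the lerp schedule.

import torch
from common.diffusion.schedules.lerp import LinearInterpolationSchedule
from common.diffusion.types import PredictionType

schedule = LinearInterpolationSchedule(T=1.0)
x_0, x_T = torch.randn(2, 3), torch.randn(2, 3)
t = torch.tensor([0.25, 0.75])

x_t = schedule.forward(x_0, x_T, t)                               # (1 - t) * x_0 + t * x_T
v = schedule.convert_to_pred(x_0, x_T, t, PredictionType.v_lerp)  # x_T - x_0
rec_x_0, rec_x_T = schedule.convert_from_pred(v, PredictionType.v_lerp, x_t, t)
assert torch.allclose(rec_x_0, x_0, atol=1e-5) and torch.allclose(rec_x_T, x_T, atol=1e-5)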
common/diffusion/timesteps/base.py ADDED
@@ -0,0 +1,72 @@
from abc import ABC, abstractmethod
from typing import Sequence, Union
import torch

from ..types import SamplingDirection


class Timesteps(ABC):
    """
    Timesteps base class.
    """

    def __init__(self, T: Union[int, float]):
        assert T > 0
        self._T = T

    @property
    def T(self) -> Union[int, float]:
        """
        Maximum timestep inclusive.
        int if discrete, float if continuous.
        """
        return self._T

    def is_continuous(self) -> bool:
        """
        Whether the schedule is continuous.
        """
        return isinstance(self.T, float)


class SamplingTimesteps(Timesteps):
    """
    Sampling timesteps.
    It defines the discretization of sampling steps.
    """

    def __init__(
        self,
        T: Union[int, float],
        timesteps: torch.Tensor,
        direction: SamplingDirection,
    ):
        assert timesteps.ndim == 1
        super().__init__(T)
        self.timesteps = timesteps
        self.direction = direction

    def __len__(self) -> int:
        """
        Number of sampling steps.
        """
        return len(self.timesteps)

    def __getitem__(self, idx: Union[int, torch.IntTensor]) -> torch.Tensor:
        """
        The timestep at the sampling step.
        Returns a scalar tensor if idx is int,
        or tensor of the same size if idx is a tensor.
        """
        return self.timesteps[idx]

    def index(self, t: torch.Tensor) -> torch.Tensor:
        """
        Find index by t.
        Return index of the same shape as t.
        Index is -1 if t not found in timesteps.
        """
        i, j = t.reshape(-1, 1).eq(self.timesteps).nonzero(as_tuple=True)
        idx = torch.full_like(t, fill_value=-1, dtype=torch.int)
        idx.view(-1)[i] = j.int()
        return idx
common/diffusion/timesteps/sampling/trailing.py ADDED
@@ -0,0 +1,49 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

import torch

from ...types import SamplingDirection
from ..base import SamplingTimesteps


class UniformTrailingSamplingTimesteps(SamplingTimesteps):
    """
    Uniform trailing sampling timesteps.
    Defined in (https://arxiv.org/abs/2305.08891)

    Shift is proposed in SD3 for RF schedule.
    Defined in (https://arxiv.org/pdf/2403.03206) eq.23
    """

    def __init__(
        self,
        T: int,
        steps: int,
        shift: float = 1.0,
        device: torch.device = "cpu",
    ):
        # Create trailing timesteps.
        timesteps = torch.arange(1.0, 0.0, -1.0 / steps, device=device)

        # Shift timesteps.
        timesteps = shift * timesteps / (1 + (shift - 1) * timesteps)

        # Scale to T range.
        if isinstance(T, float):
            timesteps = timesteps * T
        else:
            timesteps = timesteps.mul(T + 1).sub(1).round().int()

        super().__init__(T=T, timesteps=timesteps, direction=SamplingDirection.backward)
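For intuition, a sketch of what the constructor above produces for a small continuous case (not part of the commit; values are approximate):

import torch
from common.diffusion.timesteps.sampling.trailing import UniformTrailingSamplingTimesteps

ts = UniformTrailingSamplingTimesteps(T=1.0, steps=4, shift=3.0)
# The trailing grid [1.00, 0.75, 0.50, 0.25] is warped toward 1 by the SD3 shift
# shift * t / (1 + (shift - 1) * t), giving roughly [1.00, 0.90, 0.75, 0.50].
print(ts.timesteps)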
common/diffusion/types.py ADDED
@@ -0,0 +1,59 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Type definitions.
"""

from enum import Enum


class PredictionType(str, Enum):
    """
    x_0:
        Predict data sample.
    x_T:
        Predict noise sample.
        Proposed by DDPM (https://arxiv.org/abs/2006.11239)
        Proved problematic by zsnr paper (https://arxiv.org/abs/2305.08891)
    v_cos:
        Predict velocity dx/dt based on the cosine schedule (A_t * x_T - B_t * x_0).
        Proposed by progressive distillation (https://arxiv.org/abs/2202.00512)
    v_lerp:
        Predict velocity dx/dt based on the lerp schedule (x_T - x_0).
        Proposed by rectified flow (https://arxiv.org/abs/2209.03003)
    """

    x_0 = "x_0"
    x_T = "x_T"
    v_cos = "v_cos"
    v_lerp = "v_lerp"


class SamplingDirection(str, Enum):
    """
    backward: Sample from x_T to x_0 for data generation.
    forward: Sample from x_0 to x_T for noise inversion.
    """

    backward = "backward"
    forward = "forward"

    @staticmethod
    def reverse(direction):
        if direction == SamplingDirection.backward:
            return SamplingDirection.forward
        if direction == SamplingDirection.forward:
            return SamplingDirection.backward
        raise NotImplementedError
common/diffusion/utils.py ADDED
@@ -0,0 +1,84 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Utility functions.
"""

from typing import Callable
import torch


def expand_dims(tensor: torch.Tensor, ndim: int):
    """
    Expand tensor to target ndim. New dims are added to the right.
    For example, if the tensor shape was (8,), target ndim is 4, return (8, 1, 1, 1).
    """
    shape = tensor.shape + (1,) * (ndim - tensor.ndim)
    return tensor.reshape(shape)


def assert_schedule_timesteps_compatible(schedule, timesteps):
    """
    Check if schedule and timesteps are compatible.
    """
    if schedule.T != timesteps.T:
        raise ValueError("Schedule and timesteps must have the same T.")
    if schedule.is_continuous() != timesteps.is_continuous():
        raise ValueError("Schedule and timesteps must have the same continuity.")


def classifier_free_guidance(
    pos: torch.Tensor,
    neg: torch.Tensor,
    scale: float,
    rescale: float = 0.0,
):
    """
    Apply classifier-free guidance.
    """
    # Classifier-free guidance (https://arxiv.org/abs/2207.12598)
    cfg = neg + scale * (pos - neg)

    # Classifier-free guidance rescale (https://arxiv.org/pdf/2305.08891.pdf)
    if rescale != 0.0:
        pos_std = pos.std(dim=list(range(1, pos.ndim)), keepdim=True)
        cfg_std = cfg.std(dim=list(range(1, cfg.ndim)), keepdim=True)
        factor = pos_std / cfg_std
        factor = rescale * factor + (1 - rescale)
        cfg *= factor

    return cfg


def classifier_free_guidance_dispatcher(
    pos: Callable,
    neg: Callable,
    scale: float,
    rescale: float = 0.0,
):
    """
    Optionally execute models depending on classifier-free guidance scale.
    """
    # If scale is 1, no need to execute neg model.
    if scale == 1.0:
        return pos()

    # Otherwise, execute both pos and neg models and apply cfg.
    return classifier_free_guidance(
        pos=pos(),
        neg=neg(),
        scale=scale,
        rescale=rescale,
    )
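A sketch of how the dispatcher is typically called at inference time (cond_model, uncond_model, and the tensors are hypothetical closures/inputs):

pred = classifier_free_guidance_dispatcher(
    pos=lambda: cond_model(x_t, t, text_embeds),    # conditional branch
    neg=lambda: uncond_model(x_t, t, null_embeds),  # skipped entirely when scale == 1.0
    scale=7.5,
    rescale=0.5,
)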
common/distributed/__init__.py ADDED
@@ -0,0 +1,37 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Distributed package.
"""

from .basic import (
    barrier_if_distributed,
    convert_to_ddp,
    get_device,
    get_global_rank,
    get_local_rank,
    get_world_size,
    init_torch,
)

__all__ = [
    "barrier_if_distributed",
    "convert_to_ddp",
    "get_device",
    "get_global_rank",
    "get_local_rank",
    "get_world_size",
    "init_torch",
]
common/distributed/advanced.py ADDED
@@ -0,0 +1,208 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Advanced distributed functions for sequence parallel.
"""

from typing import Optional, List
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import ShardingStrategy

from .basic import get_global_rank, get_world_size


_DATA_PARALLEL_GROUP = None
_SEQUENCE_PARALLEL_GROUP = None
_SEQUENCE_PARALLEL_CPU_GROUP = None
_MODEL_SHARD_CPU_INTER_GROUP = None
_MODEL_SHARD_CPU_INTRA_GROUP = None
_MODEL_SHARD_INTER_GROUP = None
_MODEL_SHARD_INTRA_GROUP = None
_SEQUENCE_PARALLEL_GLOBAL_RANKS = None


def get_data_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Get data parallel process group.
    """
    return _DATA_PARALLEL_GROUP


def get_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
    """
    Get sequence parallel process group.
    """
    return _SEQUENCE_PARALLEL_GROUP


def get_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]:
    """
    Get sequence parallel CPU process group.
    """
    return _SEQUENCE_PARALLEL_CPU_GROUP


def get_data_parallel_rank() -> int:
    """
    Get data parallel rank.
    """
    group = get_data_parallel_group()
    return dist.get_rank(group) if group else get_global_rank()


def get_data_parallel_world_size() -> int:
    """
    Get data parallel world size.
    """
    group = get_data_parallel_group()
    return dist.get_world_size(group) if group else get_world_size()


def get_sequence_parallel_rank() -> int:
    """
    Get sequence parallel rank.
    """
    group = get_sequence_parallel_group()
    return dist.get_rank(group) if group else 0


def get_sequence_parallel_world_size() -> int:
    """
    Get sequence parallel world size.
    """
    group = get_sequence_parallel_group()
    return dist.get_world_size(group) if group else 1


def get_model_shard_cpu_intra_group() -> Optional[dist.ProcessGroup]:
    """
    Get the CPU intra process group of model sharding.
    """
    return _MODEL_SHARD_CPU_INTRA_GROUP


def get_model_shard_cpu_inter_group() -> Optional[dist.ProcessGroup]:
    """
    Get the CPU inter process group of model sharding.
    """
    return _MODEL_SHARD_CPU_INTER_GROUP


def get_model_shard_intra_group() -> Optional[dist.ProcessGroup]:
    """
    Get the GPU intra process group of model sharding.
    """
    return _MODEL_SHARD_INTRA_GROUP


def get_model_shard_inter_group() -> Optional[dist.ProcessGroup]:
    """
    Get the GPU inter process group of model sharding.
    """
    return _MODEL_SHARD_INTER_GROUP


def init_sequence_parallel(sequence_parallel_size: int):
    """
    Initialize sequence parallel.
    """
    global _DATA_PARALLEL_GROUP
    global _SEQUENCE_PARALLEL_GROUP
    global _SEQUENCE_PARALLEL_CPU_GROUP
    global _SEQUENCE_PARALLEL_GLOBAL_RANKS
    assert dist.is_initialized()
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    data_parallel_size = world_size // sequence_parallel_size
    for i in range(data_parallel_size):
        start_rank = i * sequence_parallel_size
        end_rank = (i + 1) * sequence_parallel_size
        ranks = range(start_rank, end_rank)
        group = dist.new_group(ranks)
        cpu_group = dist.new_group(ranks, backend="gloo")
        if rank in ranks:
            _SEQUENCE_PARALLEL_GROUP = group
            _SEQUENCE_PARALLEL_CPU_GROUP = cpu_group
            _SEQUENCE_PARALLEL_GLOBAL_RANKS = list(ranks)


def init_model_shard_group(
    *,
    sharding_strategy: ShardingStrategy,
    device_mesh: Optional[DeviceMesh] = None,
):
    """
    Initialize process group of model sharding.
    """
    global _MODEL_SHARD_INTER_GROUP
    global _MODEL_SHARD_INTRA_GROUP
    global _MODEL_SHARD_CPU_INTER_GROUP
    global _MODEL_SHARD_CPU_INTRA_GROUP
    assert dist.is_initialized()
    world_size = dist.get_world_size()
    if device_mesh is not None:
        num_shards_per_group = device_mesh.shape[1]
    elif sharding_strategy == ShardingStrategy.NO_SHARD:
        num_shards_per_group = 1
    elif sharding_strategy in [
        ShardingStrategy.HYBRID_SHARD,
        ShardingStrategy._HYBRID_SHARD_ZERO2,
    ]:
        num_shards_per_group = torch.cuda.device_count()
    else:
        num_shards_per_group = world_size
    num_groups = world_size // num_shards_per_group
    device_mesh = (num_groups, num_shards_per_group)

    gpu_mesh_2d = init_device_mesh("cuda", device_mesh, mesh_dim_names=("inter", "intra"))
    cpu_mesh_2d = init_device_mesh("cpu", device_mesh, mesh_dim_names=("inter", "intra"))

    _MODEL_SHARD_INTER_GROUP = gpu_mesh_2d.get_group("inter")
    _MODEL_SHARD_INTRA_GROUP = gpu_mesh_2d.get_group("intra")
    _MODEL_SHARD_CPU_INTER_GROUP = cpu_mesh_2d.get_group("inter")
    _MODEL_SHARD_CPU_INTRA_GROUP = cpu_mesh_2d.get_group("intra")


def get_sequence_parallel_global_ranks() -> List[int]:
    """
    Get all global ranks of the sequence parallel process group
    that the caller rank belongs to.
    """
    if _SEQUENCE_PARALLEL_GLOBAL_RANKS is None:
        return [dist.get_rank()]
    return _SEQUENCE_PARALLEL_GLOBAL_RANKS


def get_next_sequence_parallel_rank() -> int:
    """
    Get the next global rank of the sequence parallel process group
    that the caller rank belongs to.
    """
    sp_global_ranks = get_sequence_parallel_global_ranks()
    sp_rank = get_sequence_parallel_rank()
    sp_size = get_sequence_parallel_world_size()
    return sp_global_ranks[(sp_rank + 1) % sp_size]


def get_prev_sequence_parallel_rank() -> int:
    """
    Get the previous global rank of the sequence parallel process group
    that the caller rank belongs to.
    """
    sp_global_ranks = get_sequence_parallel_global_ranks()
    sp_rank = get_sequence_parallel_rank()
    sp_size = get_sequence_parallel_world_size()
    return sp_global_ranks[(sp_rank + sp_size - 1) % sp_size]
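A worked example of the grouping logic in init_sequence_parallel (illustration only, not code from the commit):

# world_size = 8, sequence_parallel_size = 4  ->  data_parallel_size = 2
# sequence-parallel groups: ranks [0, 1, 2, 3] and ranks [4, 5, 6, 7]
# get_next_sequence_parallel_rank() on rank 3 returns 0 (wraps within its group)
# get_prev_sequence_parallel_rank() on rank 4 returns 7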
common/distributed/basic.py ADDED
@@ -0,0 +1,84 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Distributed basic functions.
"""

import os
from datetime import timedelta
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel


def get_global_rank() -> int:
    """
    Get the global rank, the global index of the GPU.
    """
    return int(os.environ.get("RANK", "0"))


def get_local_rank() -> int:
    """
    Get the local rank, the local index of the GPU.
    """
    return int(os.environ.get("LOCAL_RANK", "0"))


def get_world_size() -> int:
    """
    Get the world size, the total number of GPUs.
    """
    return int(os.environ.get("WORLD_SIZE", "1"))


def get_device() -> torch.device:
    """
    Get current rank device.
    """
    return torch.device("cuda", get_local_rank())


def barrier_if_distributed(*args, **kwargs):
    """
    Synchronizes all processes if under distributed context.
    """
    if dist.is_initialized():
        return dist.barrier(*args, **kwargs)


def init_torch(cudnn_benchmark=True, timeout=timedelta(seconds=600)):
    """
    Common PyTorch initialization configuration.
    """
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.cuda.set_device(get_local_rank())
    dist.init_process_group(
        backend="nccl",
        rank=get_global_rank(),
        world_size=get_world_size(),
        timeout=timeout,
    )


def convert_to_ddp(module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
    return DistributedDataParallel(
        module=module,
        device_ids=[get_local_rank()],
        output_device=get_local_rank(),
        **kwargs,
    )
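A typical launch-time sketch using the helpers above (not part of the commit; assumes the script is started via torchrun so RANK, LOCAL_RANK, and WORLD_SIZE are set, and MyModel is a placeholder nn.Module):

import torch
from common.distributed import convert_to_ddp, get_device, init_torch

init_torch()                        # TF32 flags, per-rank CUDA device, NCCL process group
model = MyModel().to(get_device())  # placeholder module
model = convert_to_ddp(model)       # wrap with DistributedDataParallel on the local GPU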
common/distributed/meta_init_utils.py ADDED
@@ -0,0 +1,41 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

import torch
from rotary_embedding_torch import RotaryEmbedding
from torch import nn
from torch.distributed.fsdp._common_utils import _is_fsdp_flattened

__all__ = ["meta_non_persistent_buffer_init_fn"]


def meta_non_persistent_buffer_init_fn(module: nn.Module) -> nn.Module:
    """
    Used for materializing `non-persistent tensor buffers` while model resuming.

    Since non-persistent tensor buffers are not saved in state_dict,
    when initializing model with meta device, user should materialize those buffers manually.

    Currently, only `rope.dummy` is this special case.
    """
    with torch.no_grad():
        for submodule in module.modules():
            if not isinstance(submodule, RotaryEmbedding):
                continue
            for buffer_name, buffer in submodule.named_buffers(recurse=False):
                if buffer.is_meta and "dummy" in buffer_name:
                    materialized_buffer = torch.zeros_like(buffer, device="cpu")
                    setattr(submodule, buffer_name, materialized_buffer)
    assert not any(b.is_meta for n, b in module.named_buffers())
    return module
common/distributed/ops.py ADDED
@@ -0,0 +1,494 @@
# // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# //
# // Licensed under the Apache License, Version 2.0 (the "License");
# // you may not use this file except in compliance with the License.
# // You may obtain a copy of the License at
# //
# // http://www.apache.org/licenses/LICENSE-2.0
# //
# // Unless required by applicable law or agreed to in writing, software
# // distributed under the License is distributed on an "AS IS" BASIS,
# // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# // See the License for the specific language governing permissions and
# // limitations under the License.

"""
Distributed ops for supporting sequence parallel.
"""

from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch import Tensor

from common.cache import Cache
from common.distributed.advanced import (
    get_sequence_parallel_group,
    get_sequence_parallel_rank,
    get_sequence_parallel_world_size,
)

from .basic import get_device

_SEQ_DATA_BUF = defaultdict(lambda: [None, None, None])
_SEQ_DATA_META_SHAPES = defaultdict()
_SEQ_DATA_META_DTYPES = defaultdict()
_SEQ_DATA_ASYNC_COMMS = defaultdict(list)
_SYNC_BUFFER = defaultdict(dict)


def single_all_to_all(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: dist.ProcessGroup,
    async_op: bool = False,
):
    """
    A function to do all-to-all on a tensor
    """
    seq_world_size = dist.get_world_size(group)
    prev_scatter_dim = scatter_dim
    if scatter_dim != 0:
        local_input = local_input.transpose(0, scatter_dim)
        if gather_dim == 0:
            gather_dim = scatter_dim
        scatter_dim = 0

    inp_shape = list(local_input.shape)
    inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
    input_t = local_input.reshape(
        [seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]
    ).contiguous()
    output = torch.empty_like(input_t)
    comm = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)
    if async_op:
        # let user's code transpose & reshape
        return output, comm, prev_scatter_dim

    # first dim is seq_world_size, so we can split it directly
    output = torch.cat(output.split(1), dim=gather_dim + 1).squeeze(0)
    if prev_scatter_dim:
        output = output.transpose(0, prev_scatter_dim).contiguous()
    return output


def _all_to_all(
    local_input: Tensor,
    scatter_dim: int,
    gather_dim: int,
    group: dist.ProcessGroup,
):
    seq_world_size = dist.get_world_size(group)
    input_list = [
        t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)
    ]
    output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_dim).contiguous()


class SeqAllToAll(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        scatter_dim: int,
        gather_dim: int,
        async_op: bool,
    ) -> Tensor:
        ctx.group = group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim
        ctx.async_op = async_op
        if async_op:
            output, comm, prev_scatter_dim = single_all_to_all(
                local_input, scatter_dim, gather_dim, group, async_op=async_op
            )
            ctx.prev_scatter_dim = prev_scatter_dim
            return output, comm

        return _all_to_all(local_input, scatter_dim, gather_dim, group)

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
        if ctx.async_op:
            input_t = torch.cat(grad_output[0].split(1), dim=ctx.gather_dim + 1).squeeze(0)
            if ctx.prev_scatter_dim:
                input_t = input_t.transpose(0, ctx.prev_scatter_dim)
        else:
            input_t = grad_output[0]
        return (
            None,
            _all_to_all(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group),
            None,
            None,
            None,
        )


class Slice(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int) -> Tensor:
        ctx.group = group
        ctx.rank = dist.get_rank(group)
        seq_world_size = dist.get_world_size(group)
        ctx.seq_world_size = seq_world_size
        ctx.dim = dim
        dim_size = local_input.shape[dim]
        return local_input.split(dim_size // seq_world_size, dim=dim)[ctx.rank].contiguous()

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None]:
        dim_size = list(grad_output.size())
        split_size = dim_size[0]
        dim_size[0] = dim_size[0] * ctx.seq_world_size
        output = torch.empty(dim_size, dtype=grad_output.dtype, device=torch.cuda.current_device())
        dist._all_gather_base(output, grad_output, group=ctx.group)
        return (None, torch.cat(output.split(split_size), dim=ctx.dim), None)


class Gather(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        group: dist.ProcessGroup,
        local_input: Tensor,
        dim: int,
        grad_scale: Optional[bool] = False,
    ) -> Tensor:
        ctx.group = group
        ctx.rank = dist.get_rank(group)
        ctx.dim = dim
        ctx.grad_scale = grad_scale
        seq_world_size = dist.get_world_size(group)
        ctx.seq_world_size = seq_world_size
        dim_size = list(local_input.size())
        split_size = dim_size[0]
        ctx.part_size = dim_size[dim]
        dim_size[0] = dim_size[0] * seq_world_size
        output = torch.empty(dim_size, dtype=local_input.dtype, device=torch.cuda.current_device())
        dist._all_gather_base(output, local_input.contiguous(), group=ctx.group)
        return torch.cat(output.split(split_size), dim=dim)

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]:
        if ctx.grad_scale:
            grad_output = grad_output * ctx.seq_world_size
        return (
            None,
+ grad_output.split(ctx.part_size, dim=ctx.dim)[ctx.rank].contiguous(),
183
+ None,
184
+ None,
185
+ )
186
+
187
+
188
+ def gather_seq_scatter_heads_qkv(
189
+ qkv_tensor: Tensor,
190
+ *,
191
+ seq_dim: int,
192
+ qkv_shape: Optional[Tensor] = None,
193
+ cache: Cache = Cache(disable=True),
194
+ restore_shape: bool = True,
195
+ ):
196
+ """
197
+ A func to sync the split qkv tensor
198
+ qkv_tensor: the tensor we want to do all-to-all with. The last dim must
199
+ be the projection_idx, which we will split into 3 parts. After
200
+ splitting, the gather idx will be projection_idx + 1
201
+ seq_dim: gather_dim for all2all comm
202
+ restore_shape: if True, the output keeps the same number of dims as the input
203
+ """
204
+ group = get_sequence_parallel_group()
205
+ if not group:
206
+ return qkv_tensor
207
+ world = get_sequence_parallel_world_size()
208
+ orig_shape = qkv_tensor.shape
209
+ scatter_dim = qkv_tensor.dim()
210
+ bef_all2all_shape = list(orig_shape)
211
+ qkv_proj_dim = bef_all2all_shape[-1]
212
+ bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3]
213
+ qkv_tensor = qkv_tensor.view(bef_all2all_shape)
214
+ qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, False)
215
+ if restore_shape:
216
+ out_shape = list(orig_shape)
217
+ out_shape[seq_dim] *= world
218
+ out_shape[-1] = qkv_proj_dim // world
219
+ qkv_tensor = qkv_tensor.view(out_shape)
220
+
221
+ # remove padding
222
+ if qkv_shape is not None:
223
+ unpad_dim_size = cache(
224
+ "unpad_dim_size", lambda: torch.sum(torch.prod(qkv_shape, dim=-1)).item()
225
+ )
226
+ if unpad_dim_size % world != 0:
227
+ padding_size = qkv_tensor.size(seq_dim) - unpad_dim_size
228
+ qkv_tensor = _unpad_tensor(qkv_tensor, seq_dim, padding_size)
229
+ return qkv_tensor
230
+
231
+
232
+ def slice_inputs(x: Tensor, dim: int, padding: bool = True):
233
+ """
234
+ A func to slice the input sequence in sequence parallel
235
+ """
236
+ group = get_sequence_parallel_group()
237
+ if group is None:
238
+ return x
239
+ sp_rank = get_sequence_parallel_rank()
240
+ sp_world = get_sequence_parallel_world_size()
241
+ dim_size = x.shape[dim]
242
+ unit = (dim_size + sp_world - 1) // sp_world
243
+ if padding and dim_size % sp_world:
244
+ padding_size = sp_world - (dim_size % sp_world)
245
+ x = _pad_tensor(x, dim, padding_size)
246
+ slc = [slice(None)] * len(x.shape)
247
+ slc[dim] = slice(unit * sp_rank, unit * (sp_rank + 1))
248
+ return x[slc]
249
+
250
+
251
+ def remove_seqeunce_parallel_padding(x: Tensor, dim: int, unpad_dim_size: int):
252
+ """
253
+ A func to remove the padding part of the tensor based on its original shape
254
+ """
255
+ group = get_sequence_parallel_group()
256
+ if group is None:
257
+ return x
258
+ sp_world = get_sequence_parallel_world_size()
259
+ if unpad_dim_size % sp_world == 0:
260
+ return x
261
+ padding_size = sp_world - (unpad_dim_size % sp_world)
262
+ assert (padding_size + unpad_dim_size) % sp_world == 0
263
+ return _unpad_tensor(x, dim=dim, padding_size=padding_size)
264
+
265
+
266
+ def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor:
267
+ """
268
+ A func to sync attention result with alltoall in sequence parallel
269
+ """
270
+ group = get_sequence_parallel_group()
271
+ if not group:
272
+ return x
273
+ dim_size = x.size(seq_dim)
274
+ sp_world = get_sequence_parallel_world_size()
275
+ if dim_size % sp_world != 0:
276
+ padding_size = sp_world - (dim_size % sp_world)
277
+ x = _pad_tensor(x, seq_dim, padding_size)
278
+ return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
279
+
280
+
281
+ def gather_seq_scatter_heads(x: Tensor, seq_dim: int, head_dim: int) -> Tensor:
282
+ """
283
+ A func to sync embedding input with alltoall in sequence parallel
284
+ """
285
+ group = get_sequence_parallel_group()
286
+ if not group:
287
+ return x
288
+ return SeqAllToAll.apply(group, x, head_dim, seq_dim, False)
289
+
290
+
291
+ def scatter_heads(x: Tensor, dim: int) -> Tensor:
292
+ """
293
+ A func to split heads before attention in sequence parallel
294
+ """
295
+ group = get_sequence_parallel_group()
296
+ if not group:
297
+ return x
298
+ return Slice.apply(group, x, dim)
299
+
300
+
301
+ def gather_heads(x: Tensor, dim: int, grad_scale: Optional[bool] = False) -> Tensor:
302
+ """
303
+ A func to gather heads for the attention result in sequence parallel
304
+ """
305
+ group = get_sequence_parallel_group()
306
+ if not group:
307
+ return x
308
+ return Gather.apply(group, x, dim, grad_scale)
309
+
310
+
311
+ def gather_outputs(
312
+ x: Tensor,
313
+ *,
314
+ gather_dim: int,
315
+ padding_dim: Optional[int] = None,
316
+ unpad_shape: Optional[Tensor] = None,
317
+ cache: Cache = Cache(disable=True),
318
+ scale_grad=True,
319
+ ):
320
+ """
321
+ A func to gather the outputs for the model result in sequence parallel
322
+ """
323
+ group = get_sequence_parallel_group()
324
+ if not group:
325
+ return x
326
+ x = Gather.apply(group, x, gather_dim, scale_grad)
327
+ if padding_dim is not None:
328
+ unpad_dim_size = cache(
329
+ "unpad_dim_size", lambda: torch.sum(torch.prod(unpad_shape, dim=1)).item()
330
+ )
331
+ x = remove_seqeunce_parallel_padding(x, padding_dim, unpad_dim_size)
332
+ return x
333
+
334
+
335
+ def _pad_tensor(x: Tensor, dim: int, padding_size: int):
336
+ shape = list(x.shape)
337
+ shape[dim] = padding_size
338
+ pad = torch.zeros(shape, dtype=x.dtype, device=x.device)
339
+ return torch.cat([x, pad], dim=dim)
340
+
341
+
342
+ def _unpad_tensor(x: Tensor, dim: int, padding_size):
343
+ slc = [slice(None)] * len(x.shape)
344
+ slc[dim] = slice(0, -padding_size)
345
+ return x[slc]
346
+
347
+
348
+ def _broadcast_data(data, shape, dtype, src, group, async_op):
349
+ comms = []
350
+ if isinstance(data, (list, tuple)):
351
+ for i, sub_shape in enumerate(shape):
352
+ comms += _broadcast_data(data[i], sub_shape, dtype[i], src, group, async_op)
353
+ elif isinstance(data, dict):
354
+ for key, sub_data in data.items():
355
+ comms += _broadcast_data(sub_data, shape[key], dtype[key], src, group, async_op)
356
+ elif isinstance(data, Tensor):
357
+ comms.append(dist.broadcast(data, src=src, group=group, async_op=async_op))
358
+ return comms
359
+
360
+
361
+ def _traverse(data: Any, op: Callable) -> Union[None, List, Dict, Any]:
362
+ if isinstance(data, (list, tuple)):
363
+ return [_traverse(sub_data, op) for sub_data in data]
364
+ elif isinstance(data, dict):
365
+ return {key: _traverse(sub_data, op) for key, sub_data in data.items()}
366
+ elif isinstance(data, Tensor):
367
+ return op(data)
368
+ else:
369
+ return None
370
+
371
+
372
+ def _get_shapes(data):
373
+ return _traverse(data, op=lambda x: x.shape)
374
+
375
+
376
+ def _get_dtypes(data):
377
+ return _traverse(data, op=lambda x: x.dtype)
378
+
379
+
380
+ def _construct_broadcast_buffer(shapes, dtypes, device):
381
+ if isinstance(shapes, torch.Size):
382
+ return torch.empty(shapes, dtype=dtypes, device=device)
383
+
384
+ if isinstance(shapes, (list, tuple)):
385
+ buffer = []
386
+ for i, sub_shape in enumerate(shapes):
387
+ buffer.append(_construct_broadcast_buffer(sub_shape, dtypes[i], device))
388
+ elif isinstance(shapes, dict):
389
+ buffer = {}
390
+ for key, sub_shape in shapes.items():
391
+ buffer[key] = _construct_broadcast_buffer(sub_shape, dtypes[key], device)
392
+ else:
393
+ return None
394
+ return buffer
395
+
396
+
397
+ class SPDistForward:
398
+ """A forward tool to sync different result across sp group
399
+
400
+ Args:
401
+ module: a function or module to process users input
402
+ sp_step: current training step to judge which rank to broadcast its result to all
403
+ name: a distinct str to save meta and async comm
404
+ comm_shape: if different ranks have different shape, mark this arg to True
405
+ device: the device for current rank, can be empty
406
+ """
407
+
408
+ def __init__(
409
+ self,
410
+ name: str,
411
+ comm_shape: bool,
412
+ device: torch.device = None,
413
+ ):
414
+ self.name = name
415
+ self.comm_shape = comm_shape
416
+ if device:
417
+ self.device = device
418
+ else:
419
+ self.device = get_device()
420
+
421
+ def __call__(self, inputs) -> Any:
422
+ group = get_sequence_parallel_group()
423
+ if not group:
424
+ yield inputs
425
+ else:
426
+ device = self.device
427
+ sp_world = get_sequence_parallel_world_size()
428
+ sp_rank = get_sequence_parallel_rank()
429
+ for local_step in range(sp_world):
430
+ src_rank = dist.get_global_rank(group, local_step)
431
+ is_src = sp_rank == local_step
432
+ local_shapes = []
433
+ local_dtypes = []
434
+ if local_step == 0:
435
+ local_result = inputs
436
+ _SEQ_DATA_BUF[self.name][-1] = local_result
437
+ local_shapes = _get_shapes(local_result)
438
+ local_dtypes = _get_dtypes(local_result)
439
+ if self.comm_shape:
440
+ group_shapes_lists = [None] * sp_world
441
+ dist.all_gather_object(group_shapes_lists, local_shapes, group=group)
442
+ _SEQ_DATA_META_SHAPES[self.name] = group_shapes_lists
443
+ else:
444
+ _SEQ_DATA_META_SHAPES[self.name] = [local_shapes] * sp_world
445
+ _SEQ_DATA_META_DTYPES[self.name] = local_dtypes
446
+ shapes = _SEQ_DATA_META_SHAPES[self.name][local_step]
447
+ dtypes = _SEQ_DATA_META_DTYPES[self.name]
448
+ buf_id = local_step % 2
449
+ if local_step == 0:
450
+ sync_data = (
451
+ local_result
452
+ if is_src
453
+ else _construct_broadcast_buffer(shapes, dtypes, device)
454
+ )
455
+ _broadcast_data(sync_data, shapes, dtypes, src_rank, group, False)
456
+ _SEQ_DATA_BUF[self.name][buf_id] = sync_data
457
+
458
+ # wait for async comm ops
459
+ if _SEQ_DATA_ASYNC_COMMS[self.name]:
460
+ for comm in _SEQ_DATA_ASYNC_COMMS[self.name]:
461
+ comm.wait()
462
+ # before return the sync result, do async broadcast for next batch
463
+ if local_step < sp_world - 1:
464
+ next_buf_id = 1 - buf_id
465
+ shapes = _SEQ_DATA_META_SHAPES[self.name][local_step + 1]
466
+ src_rank = dist.get_global_rank(group, local_step + 1)
467
+ is_src = sp_rank == local_step + 1
468
+ next_sync_data = (
469
+ _SEQ_DATA_BUF[self.name][-1]
470
+ if is_src
471
+ else _construct_broadcast_buffer(shapes, dtypes, device)
472
+ )
473
+ _SEQ_DATA_ASYNC_COMMS[self.name] = _broadcast_data(
474
+ next_sync_data, shapes, dtypes, src_rank, group, True
475
+ )
476
+ _SEQ_DATA_BUF[self.name][next_buf_id] = next_sync_data
477
+ yield _SEQ_DATA_BUF[self.name][buf_id]
478
+
479
+
480
+ sync_inputs = SPDistForward(name="bef_fwd", comm_shape=True)
481
+
482
+
483
+ def sync_data(data, sp_idx, name="tmp"):
484
+ group = get_sequence_parallel_group()
485
+ if group is None:
486
+ return data
487
+ # if sp_idx in _SYNC_BUFFER[name]:
488
+ # return _SYNC_BUFFER[name][sp_idx]
489
+ sp_rank = get_sequence_parallel_rank()
490
+ src_rank = dist.get_global_rank(group, sp_idx)
491
+ objects = [data] if sp_rank == sp_idx else [None]
492
+ dist.broadcast_object_list(objects, src=src_rank, group=group)
493
+ # _SYNC_BUFFER[name] = {sp_idx: objects[0]}
494
+ return objects[0]
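For orientation, a minimal usage sketch of the two all-to-all helpers above (an illustrative example, not part of the commit): when no sequence-parallel group has been initialized, both calls are identity ops, so the snippet runs in a single process.

import torch
from common.distributed.ops import gather_seq_scatter_heads, gather_heads_scatter_seq

x = torch.randn(2, 8, 16, 64)  # (batch, local_seq, heads, head_dim)
# Inside an sp group this gathers the full sequence and scatters heads; without one it is a no-op.
y = gather_seq_scatter_heads(x, seq_dim=1, head_dim=2)
# The inverse all-to-all used after attention.
z = gather_heads_scatter_seq(y, head_dim=2, seq_dim=1)
assert z.shape == x.shape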
common/logger.py ADDED
@@ -0,0 +1,44 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ """
16
+ Logging utility functions.
17
+ """
18
+
19
+ import logging
20
+ import sys
21
+ from typing import Optional
22
+
23
+ from common.distributed import get_global_rank, get_local_rank, get_world_size
24
+
25
+ _default_handler = logging.StreamHandler(sys.stdout)
26
+ _default_handler.setFormatter(
27
+ logging.Formatter(
28
+ "%(asctime)s "
29
+ + (f"[Rank:{get_global_rank()}]" if get_world_size() > 1 else "")
30
+ + (f"[LocalRank:{get_local_rank()}]" if get_world_size() > 1 else "")
31
+ + "[%(threadName).12s][%(name)s][%(levelname).5s] "
32
+ + "%(message)s"
33
+ )
34
+ )
35
+
36
+
37
+ def get_logger(name: Optional[str] = None) -> logging.Logger:
38
+ """
39
+ Get a logger.
40
+ """
41
+ logger = logging.getLogger(name)
42
+ logger.addHandler(_default_handler)
43
+ logger.setLevel(logging.INFO)
44
+ return logger
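A quick usage sketch (assuming a single-process run in which the rank helpers from common.distributed fall back to rank 0):

from common.logger import get_logger

log = get_logger(__name__)
log.info("specialist manager initialized")  # the [Rank:i] prefix only appears on multi-rank runs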
common/partition.py ADDED
@@ -0,0 +1,59 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ """
16
+ Partition utility functions.
17
+ """
18
+
19
+ from typing import Any, List
20
+
21
+
22
+ def partition_by_size(data: List[Any], size: int) -> List[List[Any]]:
23
+ """
24
+ Partition a list by size.
25
+ When indivisible, the last group contains fewer items than the target size.
26
+
27
+ Examples:
28
+ - data: [1,2,3,4,5]
29
+ - size: 2
30
+ - return: [[1,2], [3,4], [5]]
31
+ """
32
+ assert size > 0
33
+ return [data[i : (i + size)] for i in range(0, len(data), size)]
34
+
35
+
36
+ def partition_by_groups(data: List[Any], groups: int) -> List[List[Any]]:
37
+ """
38
+ Partition a list by groups.
39
+ When indivisible, some groups may have more items than others.
40
+
41
+ Examples:
42
+ - data: [1,2,3,4,5]
43
+ - groups: 2
44
+ - return: [[1,3,5], [2,4]]
45
+ """
46
+ assert groups > 0
47
+ return [data[i::groups] for i in range(groups)]
48
+
49
+
50
+ def shift_list(data: List[Any], n: int) -> List[Any]:
51
+ """
52
+ Rotate a list by n elements.
53
+
54
+ Examples:
55
+ - data: [1,2,3,4,5]
56
+ - n: 3
57
+ - return: [4,5,1,2,3]
58
+ """
59
+ return data[(n % len(data)) :] + data[: (n % len(data))]
common/seed.py ADDED
@@ -0,0 +1,30 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ import random
16
+ from typing import Optional
17
+ import numpy as np
18
+ import torch
19
+
20
+ from common.distributed import get_global_rank
21
+
22
+
23
+ def set_seed(seed: Optional[int], same_across_ranks: bool = False):
24
+ """Function that sets the seed for pseudo-random number generators."""
25
+ if seed is not None:
26
+ seed += get_global_rank() if not same_across_ranks else 0
27
+ random.seed(seed)
28
+ np.random.seed(seed)
29
+ torch.manual_seed(seed)
30
+
configs_3b/main.yaml ADDED
@@ -0,0 +1,88 @@
1
+ __object__:
2
+ path: projects.video_diffusion_sr.train
3
+ name: VideoDiffusionTrainer
4
+
5
+ dit:
6
+ model:
7
+ __object__:
8
+ path: models.dit_v2.nadit
9
+ name: NaDiT
10
+ args: as_params
11
+ vid_in_channels: 33
12
+ vid_out_channels: 16
13
+ vid_dim: 2560
14
+ vid_out_norm: fusedrms
15
+ txt_in_dim: 5120
16
+ txt_in_norm: fusedln
17
+ txt_dim: ${.vid_dim}
18
+ emb_dim: ${eval:'6 * ${.vid_dim}'}
19
+ heads: 20
20
+ head_dim: 128 # llm-like
21
+ expand_ratio: 4
22
+ norm: fusedrms
23
+ norm_eps: 1.0e-05
24
+ ada: single
25
+ qk_bias: False
26
+ qk_norm: fusedrms
27
+ patch_size: [ 1,2,2 ]
28
+ num_layers: 32 # llm-like
29
+ mm_layers: 10
30
+ mlp_type: swiglu
31
+ msa_type: None
32
+ block_type: ${eval:'${.num_layers} * ["mmdit_sr"]'} # space-full
33
+ window: ${eval:'${.num_layers} * [(4,3,3)]'} # space-full
34
+ window_method: ${eval:'${.num_layers} // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]'} # space-full
35
+ rope_type: mmrope3d
36
+ rope_dim: 128
37
+ compile: False
38
+ gradient_checkpoint: True
39
+ fsdp:
40
+ sharding_strategy: _HYBRID_SHARD_ZERO2
41
+
42
+ ema:
43
+ decay: 0.9998
44
+
45
+ vae:
46
+ model:
47
+ __inherit__: models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml
48
+ freeze_encoder: False
49
+ # gradient_checkpoint: True
50
+ slicing:
51
+ split_size: 4
52
+ memory_device: same
53
+ memory_limit:
54
+ conv_max_mem: 0.5
55
+ norm_max_mem: 0.5
56
+ checkpoint: ./ckpts/ema_vae.pth
57
+ scaling_factor: 0.9152
58
+ compile: False
59
+ grouping: False
60
+ dtype: bfloat16
61
+
62
+ diffusion:
63
+ schedule:
64
+ type: lerp
65
+ T: 1000.0
66
+ sampler:
67
+ type: euler
68
+ prediction_type: v_lerp
69
+ timesteps:
70
+ training:
71
+ type: logitnormal
72
+ loc: 0.0
73
+ scale: 1.0
74
+ sampling:
75
+ type: uniform_trailing
76
+ steps: 50
77
+ transform: True
78
+ loss:
79
+ type: v_lerp
80
+ cfg:
81
+ scale: 7.5
82
+ rescale: 0
83
+
84
+ condition:
85
+ i2v: 0.0
86
+ v2v: 0.0
87
+ sr: 1.0
88
+ noise_scale: 0.25
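The ${eval:'...'} entries are OmegaConf interpolations that rely on a custom resolver, and the __object__/__inherit__ keys are consumed by the project's own config builder (not shown in this diff). A hedged sketch of how the interpolations expand, assuming the trainer registers an "eval" resolver:

from omegaconf import OmegaConf

# Assumption: the training entrypoint registers an "eval" resolver roughly like this.
OmegaConf.register_new_resolver("eval", eval, replace=True)

cfg = OmegaConf.load("configs_3b/main.yaml")
print(cfg.dit.model.emb_dim)                    # ${eval:'6 * ${.vid_dim}'} should resolve to 15360
print(cfg.diffusion.timesteps.sampling.steps)   # 50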
configs_7b/main.yaml ADDED
@@ -0,0 +1,85 @@
1
+ __object__:
2
+ path: projects.video_diffusion_sr.train
3
+ name: VideoDiffusionTrainer
4
+
5
+ dit:
6
+ model:
7
+ __object__:
8
+ path: models.dit.nadit
9
+ name: NaDiT
10
+ args: as_params
11
+ vid_in_channels: 33
12
+ vid_out_channels: 16
13
+ vid_dim: 3072
14
+ txt_in_dim: 5120
15
+ txt_dim: ${.vid_dim}
16
+ emb_dim: ${eval:'6 * ${.vid_dim}'}
17
+ heads: 24
18
+ head_dim: 128 # llm-like
19
+ expand_ratio: 4
20
+ norm: fusedrms
21
+ norm_eps: 1e-5
22
+ ada: single
23
+ qk_bias: False
24
+ qk_rope: True
25
+ qk_norm: fusedrms
26
+ patch_size: [ 1,2,2 ]
27
+ num_layers: 36 # llm-like
28
+ shared_mlp: False
29
+ shared_qkv: False
30
+ mlp_type: normal
31
+ block_type: ${eval:'${.num_layers} * ["mmdit_sr"]'} # space-full
32
+ window: ${eval:'${.num_layers} * [(4,3,3)]'} # space-full
33
+ window_method: ${eval:'${.num_layers} // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]'} # space-full
34
+ compile: False
35
+ gradient_checkpoint: True
36
+ fsdp:
37
+ sharding_strategy: _HYBRID_SHARD_ZERO2
38
+
39
+ ema:
40
+ decay: 0.9998
41
+
42
+ vae:
43
+ model:
44
+ __inherit__: models/video_vae_v3/s8_c16_t4_inflation_sd3.yaml
45
+ freeze_encoder: False
46
+ # gradient_checkpoint: True
47
+ slicing:
48
+ split_size: 4
49
+ memory_device: same
50
+ memory_limit:
51
+ conv_max_mem: 0.5
52
+ norm_max_mem: 0.5
53
+ checkpoint: ./ckpts/ema_vae.pth
54
+ scaling_factor: 0.9152
55
+ compile: False
56
+ grouping: False
57
+ dtype: bfloat16
58
+
59
+ diffusion:
60
+ schedule:
61
+ type: lerp
62
+ T: 1000.0
63
+ sampler:
64
+ type: euler
65
+ prediction_type: v_lerp
66
+ timesteps:
67
+ training:
68
+ type: logitnormal
69
+ loc: 0.0
70
+ scale: 1.0
71
+ sampling:
72
+ type: uniform_trailing
73
+ steps: 50
74
+ transform: True
75
+ loss:
76
+ type: v_lerp
77
+ cfg:
78
+ scale: 7.5
79
+ rescale: 0
80
+
81
+ condition:
82
+ i2v: 0.0
83
+ v2v: 0.0
84
+ sr: 1.0
85
+ noise_scale: 0.25
data/image/transforms/area_resize.py ADDED
@@ -0,0 +1,135 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ import math
16
+ import random
17
+ from typing import Union
18
+ import torch
19
+ from PIL import Image
20
+ from torchvision.transforms import functional as TVF
21
+ from torchvision.transforms.functional import InterpolationMode
22
+
23
+
24
+ class AreaResize:
25
+ def __init__(
26
+ self,
27
+ max_area: float,
28
+ downsample_only: bool = False,
29
+ interpolation: InterpolationMode = InterpolationMode.BICUBIC,
30
+ ):
31
+ self.max_area = max_area
32
+ self.downsample_only = downsample_only
33
+ self.interpolation = interpolation
34
+
35
+ def __call__(self, image: Union[torch.Tensor, Image.Image]):
36
+
37
+ if isinstance(image, torch.Tensor):
38
+ height, width = image.shape[-2:]
39
+ elif isinstance(image, Image.Image):
40
+ width, height = image.size
41
+ else:
42
+ raise NotImplementedError
43
+
44
+ scale = math.sqrt(self.max_area / (height * width))
45
+
46
+ # keep original height and width for small pictures.
47
+ scale = 1 if scale >= 1 and self.downsample_only else scale
48
+
49
+ resized_height, resized_width = round(height * scale), round(width * scale)
50
+
51
+ return TVF.resize(
52
+ image,
53
+ size=(resized_height, resized_width),
54
+ interpolation=self.interpolation,
55
+ )
56
+
57
+
58
+ class AreaRandomCrop:
59
+ def __init__(
60
+ self,
61
+ max_area: float,
62
+ ):
63
+ self.max_area = max_area
64
+
65
+ def get_params(self, input_size, output_size):
66
+ """Get parameters for ``crop`` for a random crop.
67
+
68
+ Args:
69
+ input_size (tuple): Size (h, w) of the image to be cropped.
70
+ output_size (tuple): Expected output size of the crop.
71
+
72
+ Returns:
73
+ tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
74
+ """
75
+ # w, h = _get_image_size(img)
76
+ h, w = input_size
77
+ th, tw = output_size
78
+ if w <= tw and h <= th:
79
+ return 0, 0, h, w
80
+
81
+ i = random.randint(0, h - th)
82
+ j = random.randint(0, w - tw)
83
+ return i, j, th, tw
84
+
85
+ def __call__(self, image: Union[torch.Tensor, Image.Image]):
86
+ if isinstance(image, torch.Tensor):
87
+ height, width = image.shape[-2:]
88
+ elif isinstance(image, Image.Image):
89
+ width, height = image.size
90
+ else:
91
+ raise NotImplementedError
92
+
93
+ resized_height = math.sqrt(self.max_area / (width / height))
94
+ resized_width = (width / height) * resized_height
95
+
96
+ # print('>>>>>>>>>>>>>>>>>>>>>')
97
+ # print((height, width))
98
+ # print( (resized_height, resized_width))
99
+
100
+ resized_height, resized_width = round(resized_height), round(resized_width)
101
+ i, j, h, w = self.get_params((height, width), (resized_height, resized_width))
102
+ image = TVF.crop(image, i, j, h, w)
103
+ return image
104
+
105
+ class ScaleResize:
106
+ def __init__(
107
+ self,
108
+ scale: float,
109
+ ):
110
+ self.scale = scale
111
+
112
+ def __call__(self, image: Union[torch.Tensor, Image.Image]):
113
+ if isinstance(image, torch.Tensor):
114
+ height, width = image.shape[-2:]
115
+ interpolation_mode = InterpolationMode.BILINEAR
116
+ antialias = True if image.ndim == 4 else "warn"
117
+ elif isinstance(image, Image.Image):
118
+ width, height = image.size
119
+ interpolation_mode = InterpolationMode.LANCZOS
120
+ antialias = "warn"
121
+ else:
122
+ raise NotImplementedError
123
+
124
+ scale = self.scale
125
+
126
+ # keep original height and width for small pictures
127
+
128
+ resized_height, resized_width = round(height * scale), round(width * scale)
129
+ image = TVF.resize(
130
+ image,
131
+ size=(resized_height, resized_width),
132
+ interpolation=interpolation_mode,
133
+ antialias=antialias,
134
+ )
135
+ return image
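A small sketch of the area-targeting behaviour of AreaResize above (values chosen for illustration):

import torch
from data.image.transforms.area_resize import AreaResize

resize = AreaResize(max_area=720 * 720)
img = torch.rand(3, 1080, 1920)
out = resize(img)
print(out.shape)  # torch.Size([3, 540, 960]); 540 * 960 == 720 * 720, aspect ratio preserved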
data/image/transforms/divisible_crop.py ADDED
@@ -0,0 +1,40 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from typing import Union
16
+ import torch
17
+ from PIL import Image
18
+ from torchvision.transforms import functional as TVF
19
+
20
+
21
+ class DivisibleCrop:
22
+ def __init__(self, factor):
23
+ if not isinstance(factor, tuple):
24
+ factor = (factor, factor)
25
+
26
+ self.height_factor, self.width_factor = factor[0], factor[1]
27
+
28
+ def __call__(self, image: Union[torch.Tensor, Image.Image]):
29
+ if isinstance(image, torch.Tensor):
30
+ height, width = image.shape[-2:]
31
+ elif isinstance(image, Image.Image):
32
+ width, height = image.size
33
+ else:
34
+ raise NotImplementedError
35
+
36
+ cropped_height = height - (height % self.height_factor)
37
+ cropped_width = width - (width % self.width_factor)
38
+
39
+ image = TVF.center_crop(img=image, output_size=(cropped_height, cropped_width))
40
+ return image
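And the companion crop, which trims spatial dims down to the nearest multiple of the given factor (illustrative sketch):

import torch
from data.image.transforms.divisible_crop import DivisibleCrop

crop = DivisibleCrop(16)
img = torch.rand(3, 1080, 1920)
print(crop(img).shape)  # torch.Size([3, 1072, 1920]); 1080 is center-cropped down to 1072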
data/image/transforms/na_resize.py ADDED
@@ -0,0 +1,50 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from typing import Literal
16
+ from torchvision.transforms import CenterCrop, Compose, InterpolationMode, Resize
17
+
18
+ from .area_resize import AreaResize
19
+ from .side_resize import SideResize
20
+
21
+
22
+ def NaResize(
23
+ resolution: int,
24
+ mode: Literal["area", "side", "square"],
25
+ downsample_only: bool,
26
+ interpolation: InterpolationMode = InterpolationMode.BICUBIC,
27
+ ):
28
+ if mode == "area":
29
+ return AreaResize(
30
+ max_area=resolution**2,
31
+ downsample_only=downsample_only,
32
+ interpolation=interpolation,
33
+ )
34
+ if mode == "side":
35
+ return SideResize(
36
+ size=resolution,
37
+ downsample_only=downsample_only,
38
+ interpolation=interpolation,
39
+ )
40
+ if mode == "square":
41
+ return Compose(
42
+ [
43
+ Resize(
44
+ size=resolution,
45
+ interpolation=interpolation,
46
+ ),
47
+ CenterCrop(resolution),
48
+ ]
49
+ )
50
+ raise ValueError(f"Unknown resize mode: {mode}")
data/image/transforms/side_resize.py ADDED
@@ -0,0 +1,54 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from typing import Union
16
+ import torch
17
+ from PIL import Image
18
+ from torchvision.transforms import InterpolationMode
19
+ from torchvision.transforms import functional as TVF
20
+
21
+
22
+ class SideResize:
23
+ def __init__(
24
+ self,
25
+ size: int,
26
+ downsample_only: bool = False,
27
+ interpolation: InterpolationMode = InterpolationMode.BICUBIC,
28
+ ):
29
+ self.size = size
30
+ self.downsample_only = downsample_only
31
+ self.interpolation = interpolation
32
+
33
+ def __call__(self, image: Union[torch.Tensor, Image.Image]):
34
+ """
35
+ Args:
36
+ image (PIL Image or Tensor): Image to be scaled.
37
+
38
+ Returns:
39
+ PIL Image or Tensor: Rescaled image.
40
+ """
41
+ if isinstance(image, torch.Tensor):
42
+ height, width = image.shape[-2:]
43
+ elif isinstance(image, Image.Image):
44
+ width, height = image.size
45
+ else:
46
+ raise NotImplementedError
47
+
48
+ if self.downsample_only and min(width, height) < self.size:
49
+ # keep original height and width for small pictures.
50
+ size = min(width, height)
51
+ else:
52
+ size = self.size
53
+
54
+ return TVF.resize(image, size, self.interpolation)
data/video/transforms/rearrange.py ADDED
@@ -0,0 +1,24 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from einops import rearrange
16
+
17
+
18
+ class Rearrange:
19
+ def __init__(self, pattern: str, **kwargs):
20
+ self.pattern = pattern
21
+ self.kwargs = kwargs
22
+
23
+ def __call__(self, x):
24
+ return rearrange(x, self.pattern, **self.kwargs)
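Usage sketch for the einops wrapper (the pattern is chosen for illustration):

import torch
from data.video.transforms.rearrange import Rearrange

to_cthw = Rearrange("t h w c -> c t h w")
clip = torch.rand(16, 256, 256, 3)   # frames, height, width, channels
print(to_cthw(clip).shape)           # torch.Size([3, 16, 256, 256])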
models/dit/attention.py ADDED
@@ -0,0 +1,46 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+
18
+ from flash_attn import flash_attn_varlen_func
19
+
20
+ from torch import nn
21
+
22
+ class TorchAttention(nn.Module):
23
+ def tflops(self, args, kwargs, output) -> float:
24
+ assert len(args) == 0 or len(args) >= 2, "query and key should both be provided via args / kwargs"
25
+ q = kwargs.get("query") or args[0]
26
+ k = kwargs.get("key") or args[1]
27
+ b, h, sq, d = q.shape
28
+ b, h, sk, d = k.shape
29
+ return b * h * (4 * d * (sq / 1e6) * (sk / 1e6))
30
+
31
+ def forward(self, *args, **kwargs):
32
+ return F.scaled_dot_product_attention(*args, **kwargs)
33
+
34
+
35
+ class FlashAttentionVarlen(nn.Module):
36
+ def tflops(self, args, kwargs, output) -> float:
37
+ cu_seqlens_q = kwargs["cu_seqlens_q"]
38
+ cu_seqlens_k = kwargs["cu_seqlens_k"]
39
+ _, h, d = output.shape
40
+ seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]) / 1e6
41
+ seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) / 1e6
42
+ return h * (4 * d * (seqlens_q * seqlens_k).sum())
43
+
44
+ def forward(self, *args, **kwargs):
45
+ kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled()
46
+ return flash_attn_varlen_func(*args, **kwargs)
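The tflops hooks above only do bookkeeping; as a worked example of the formula (4 multiply-adds per query-key-dim triple, reported in TFLOPs) with illustrative shapes:

b, h, sq, sk, d = 2, 24, 4096, 4096, 128
tflops = b * h * (4 * d * (sq / 1e6) * (sk / 1e6))
print(round(tflops, 2))  # ~0.41 TFLOPs for one attention forward at these sizes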
models/dit/blocks/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from .mmdit_window_block import MMWindowTransformerBlock
16
+
17
+ dit_blocks = {
18
+ "mmdit_window": MMWindowTransformerBlock,
19
+ }
20
+
21
+
22
+ def get_block(block_type: str):
23
+ if block_type in dit_blocks:
24
+ return dit_blocks[block_type]
25
+ raise NotImplementedError(f"{block_type} is not supported")
models/dit/blocks/mmdit_window_block.py ADDED
@@ -0,0 +1,233 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from typing import Tuple, Union
16
+ import torch
17
+ from einops import rearrange
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+ from torch.nn.modules.utils import _triple
21
+
22
+ from common.distributed.ops import (
23
+ gather_heads,
24
+ gather_heads_scatter_seq,
25
+ gather_seq_scatter_heads_qkv,
26
+ scatter_heads,
27
+ )
28
+
29
+ from ..attention import TorchAttention
30
+ from ..mlp import get_mlp
31
+ from ..mm import MMArg, MMModule
32
+ from ..modulation import ada_layer_type
33
+ from ..normalization import norm_layer_type
34
+ from ..rope import RotaryEmbedding3d
35
+
36
+
37
+ class MMWindowAttention(nn.Module):
38
+ def __init__(
39
+ self,
40
+ vid_dim: int,
41
+ txt_dim: int,
42
+ heads: int,
43
+ head_dim: int,
44
+ qk_bias: bool,
45
+ qk_rope: bool,
46
+ qk_norm: norm_layer_type,
47
+ qk_norm_eps: float,
48
+ window: Union[int, Tuple[int, int, int]],
49
+ window_method: str,
50
+ shared_qkv: bool,
51
+ ):
52
+ super().__init__()
53
+ dim = MMArg(vid_dim, txt_dim)
54
+ inner_dim = heads * head_dim
55
+ qkv_dim = inner_dim * 3
56
+
57
+ self.window = _triple(window)
58
+ self.window_method = window_method
59
+ assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
60
+
61
+ self.head_dim = head_dim
62
+ self.proj_qkv = MMModule(nn.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_qkv)
63
+ self.proj_out = MMModule(nn.Linear, inner_dim, dim, shared_weights=shared_qkv)
64
+ self.norm_q = MMModule(qk_norm, dim=head_dim, eps=qk_norm_eps, elementwise_affine=True)
65
+ self.norm_k = MMModule(qk_norm, dim=head_dim, eps=qk_norm_eps, elementwise_affine=True)
66
+ self.rope = RotaryEmbedding3d(dim=head_dim // 2) if qk_rope else None
67
+ self.attn = TorchAttention()
68
+
69
+ def forward(
70
+ self,
71
+ vid: torch.FloatTensor, # b T H W c
72
+ txt: torch.FloatTensor, # b L c
73
+ txt_mask: torch.BoolTensor, # b L
74
+ ) -> Tuple[
75
+ torch.FloatTensor,
76
+ torch.FloatTensor,
77
+ ]:
78
+ # Project q, k, v.
79
+ vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
80
+ vid_qkv = gather_seq_scatter_heads_qkv(vid_qkv, seq_dim=2)
81
+ _, T, H, W, _ = vid_qkv.shape
82
+ _, L, _ = txt.shape
83
+
84
+ if self.window_method == "win":
85
+ nt, nh, nw = self.window
86
+ tt, hh, ww = T // nt, H // nh, W // nw
87
+ elif self.window_method == "win_by_size":
88
+ tt, hh, ww = self.window
89
+ tt, hh, ww = (
90
+ tt if tt > 0 else T,
91
+ hh if hh > 0 else H,
92
+ ww if ww > 0 else W,
93
+ )
94
+ nt, nh, nw = T // tt, H // hh, W // ww
95
+ else:
96
+ raise NotImplementedError
97
+
98
+ vid_qkv = rearrange(vid_qkv, "b T H W (o h d) -> o b h (T H W) d", o=3, d=self.head_dim)
99
+ txt_qkv = rearrange(txt_qkv, "b L (o h d) -> o b h L d", o=3, d=self.head_dim)
100
+ txt_qkv = scatter_heads(txt_qkv, dim=2)
101
+
102
+ vid_q, vid_k, vid_v = vid_qkv.unbind()
103
+ txt_q, txt_k, txt_v = txt_qkv.unbind()
104
+
105
+ vid_q, txt_q = self.norm_q(vid_q, txt_q)
106
+ vid_k, txt_k = self.norm_k(vid_k, txt_k)
107
+
108
+ if self.rope:
109
+ vid_q, vid_k = self.rope(vid_q, vid_k, (T, H, W))
110
+
111
+ def vid_window(v):
112
+ return rearrange(
113
+ v,
114
+ "b h (nt tt nh hh nw ww) d -> b h (nt nh nw) (tt hh ww) d",
115
+ hh=hh,
116
+ ww=ww,
117
+ tt=tt,
118
+ nh=nh,
119
+ nw=nw,
120
+ nt=nt,
121
+ )
122
+
123
+ def txt_window(t):
124
+ return rearrange(t, "b h L d -> b h 1 L d").expand(-1, -1, nt * nh * nw, -1, -1)
125
+
126
+ # Process video attention.
127
+ vid_msk = F.pad(txt_mask, (tt * hh * ww, 0), value=True)
128
+ vid_msk = rearrange(vid_msk, "b l -> b 1 1 1 l").expand(-1, 1, 1, tt * hh * ww, -1)
129
+ vid_out = self.attn(
130
+ vid_window(vid_q),
131
+ torch.cat([vid_window(vid_k), txt_window(txt_k)], dim=-2),
132
+ torch.cat([vid_window(vid_v), txt_window(txt_v)], dim=-2),
133
+ vid_msk,
134
+ )
135
+ vid_out = rearrange(
136
+ vid_out,
137
+ "b h (nt nh nw) (tt hh ww) d -> b (nt tt) (nh hh) (nw ww) (h d)",
138
+ hh=hh,
139
+ ww=ww,
140
+ tt=tt,
141
+ nh=nh,
142
+ nw=nw,
143
+ )
144
+ vid_out = gather_heads_scatter_seq(vid_out, head_dim=4, seq_dim=2)
145
+
146
+ # Process text attention.
147
+ txt_msk = F.pad(txt_mask, (T * H * W, 0), value=True)
148
+ txt_msk = rearrange(txt_msk, "b l -> b 1 1 l").expand(-1, 1, L, -1)
149
+ txt_out = self.attn(
150
+ txt_q,
151
+ torch.cat([vid_k, txt_k], dim=-2),
152
+ torch.cat([vid_v, txt_v], dim=-2),
153
+ txt_msk,
154
+ )
155
+ txt_out = rearrange(txt_out, "b h L d -> b L (h d)")
156
+ txt_out = gather_heads(txt_out, dim=2)
157
+
158
+ # Project output.
159
+ vid_out, txt_out = self.proj_out(vid_out, txt_out)
160
+ return vid_out, txt_out
161
+
162
+
163
+ class MMWindowTransformerBlock(nn.Module):
164
+ def __init__(
165
+ self,
166
+ *,
167
+ vid_dim: int,
168
+ txt_dim: int,
169
+ emb_dim: int,
170
+ heads: int,
171
+ head_dim: int,
172
+ expand_ratio: int,
173
+ norm: norm_layer_type,
174
+ norm_eps: float,
175
+ ada: ada_layer_type,
176
+ qk_bias: bool,
177
+ qk_rope: bool,
178
+ qk_norm: norm_layer_type,
179
+ window: Union[int, Tuple[int, int, int]],
180
+ window_method: str,
181
+ shared_qkv: bool,
182
+ shared_mlp: bool,
183
+ mlp_type: str,
184
+ **kwargs,
185
+ ):
186
+ super().__init__()
187
+ dim = MMArg(vid_dim, txt_dim)
188
+ self.attn_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False)
189
+ self.attn = MMWindowAttention(
190
+ vid_dim=vid_dim,
191
+ txt_dim=txt_dim,
192
+ heads=heads,
193
+ head_dim=head_dim,
194
+ qk_bias=qk_bias,
195
+ qk_rope=qk_rope,
196
+ qk_norm=qk_norm,
197
+ qk_norm_eps=norm_eps,
198
+ window=window,
199
+ window_method=window_method,
200
+ shared_qkv=shared_qkv,
201
+ )
202
+ self.mlp_norm = MMModule(norm, dim=dim, eps=norm_eps, elementwise_affine=False)
203
+ self.mlp = MMModule(
204
+ get_mlp(mlp_type),
205
+ dim=dim,
206
+ expand_ratio=expand_ratio,
207
+ shared_weights=shared_mlp,
208
+ )
209
+ self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"])
210
+
211
+ def forward(
212
+ self,
213
+ vid: torch.FloatTensor,
214
+ txt: torch.FloatTensor,
215
+ txt_mask: torch.BoolTensor,
216
+ emb: torch.FloatTensor,
217
+ ) -> Tuple[
218
+ torch.FloatTensor,
219
+ torch.FloatTensor,
220
+ ]:
221
+ vid_attn, txt_attn = self.attn_norm(vid, txt)
222
+ vid_attn, txt_attn = self.ada(vid_attn, txt_attn, emb=emb, layer="attn", mode="in")
223
+ vid_attn, txt_attn = self.attn(vid_attn, txt_attn, txt_mask=txt_mask)
224
+ vid_attn, txt_attn = self.ada(vid_attn, txt_attn, emb=emb, layer="attn", mode="out")
225
+ vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)
226
+
227
+ vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
228
+ vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, emb=emb, layer="mlp", mode="in")
229
+ vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp)
230
+ vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, emb=emb, layer="mlp", mode="out")
231
+ vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn)
232
+
233
+ return vid_mlp, txt_mlp
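To make the window bookkeeping in MMWindowAttention concrete, here is the arithmetic of the "win_by_size" branch on an illustrative latent grid (the grid size below is an assumption, not taken from the configs):

T, H, W = 8, 30, 45        # post-patchify latent grid (hypothetical)
tt, hh, ww = 4, 3, 3       # window size, e.g. the (4,3,3) entries used in the configs
nt, nh, nw = T // tt, H // hh, W // ww
print(nt * nh * nw)        # 300 windows
print(tt * hh * ww)        # 36 video tokens per window; text tokens are appended to every window's K/V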
models/dit/embedding.py ADDED
@@ -0,0 +1,62 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from typing import Optional, Union
16
+ import torch
17
+ from diffusers.models.embeddings import get_timestep_embedding
18
+ from torch import nn
19
+
20
+
21
+ def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]):
22
+ return emb1 if emb2 is None else emb1 + emb2
23
+
24
+
25
+ class TimeEmbedding(nn.Module):
26
+ def __init__(
27
+ self,
28
+ sinusoidal_dim: int,
29
+ hidden_dim: int,
30
+ output_dim: int,
31
+ ):
32
+ super().__init__()
33
+ self.sinusoidal_dim = sinusoidal_dim
34
+ self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim)
35
+ self.proj_hid = nn.Linear(hidden_dim, hidden_dim)
36
+ self.proj_out = nn.Linear(hidden_dim, output_dim)
37
+ self.act = nn.SiLU()
38
+
39
+ def forward(
40
+ self,
41
+ timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],
42
+ device: torch.device,
43
+ dtype: torch.dtype,
44
+ ) -> torch.FloatTensor:
45
+ if not torch.is_tensor(timestep):
46
+ timestep = torch.tensor([timestep], device=device, dtype=dtype)
47
+ if timestep.ndim == 0:
48
+ timestep = timestep[None]
49
+
50
+ emb = get_timestep_embedding(
51
+ timesteps=timestep,
52
+ embedding_dim=self.sinusoidal_dim,
53
+ flip_sin_to_cos=False,
54
+ downscale_freq_shift=0,
55
+ )
56
+ emb = emb.to(dtype)
57
+ emb = self.proj_in(emb)
58
+ emb = self.act(emb)
59
+ emb = self.proj_hid(emb)
60
+ emb = self.act(emb)
61
+ emb = self.proj_out(emb)
62
+ return emb
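A short usage sketch of the timestep embedder (dims are illustrative; requires diffusers for get_timestep_embedding):

import torch
from models.dit.embedding import TimeEmbedding

emb = TimeEmbedding(sinusoidal_dim=256, hidden_dim=1024, output_dim=512)
out = emb(torch.tensor([500.0]), device=torch.device("cpu"), dtype=torch.float32)
print(out.shape)  # torch.Size([1, 512])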
models/dit/mlp.py ADDED
@@ -0,0 +1,62 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from typing import Optional
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+
21
+ def get_mlp(mlp_type: Optional[str] = "normal"):
22
+ if mlp_type == "normal":
23
+ return MLP
24
+ elif mlp_type == "swiglu":
25
+ return SwiGLUMLP
26
+
27
+
28
+ class MLP(nn.Module):
29
+ def __init__(
30
+ self,
31
+ dim: int,
32
+ expand_ratio: int,
33
+ ):
34
+ super().__init__()
35
+ self.proj_in = nn.Linear(dim, dim * expand_ratio)
36
+ self.act = nn.GELU("tanh")
37
+ self.proj_out = nn.Linear(dim * expand_ratio, dim)
38
+
39
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
40
+ x = self.proj_in(x)
41
+ x = self.act(x)
42
+ x = self.proj_out(x)
43
+ return x
44
+
45
+
46
+ class SwiGLUMLP(nn.Module):
47
+ def __init__(
48
+ self,
49
+ dim: int,
50
+ expand_ratio: int,
51
+ multiple_of: int = 256,
52
+ ):
53
+ super().__init__()
54
+ hidden_dim = int(2 * dim * expand_ratio / 3)
55
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
56
+ self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False)
57
+ self.proj_out = nn.Linear(hidden_dim, dim, bias=False)
58
+ self.proj_in = nn.Linear(dim, hidden_dim, bias=False)
59
+
60
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
61
+ x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
62
+ return x
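Worked example of the SwiGLU hidden-width rounding above, using the 3B config's vid_dim=2560 and expand_ratio=4:

dim, expand_ratio, multiple_of = 2560, 4, 256
hidden = int(2 * dim * expand_ratio / 3)                             # 6826
hidden = multiple_of * ((hidden + multiple_of - 1) // multiple_of)   # rounded up to 6912
print(hidden)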
models/dit/mm.py ADDED
@@ -0,0 +1,67 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any, Callable, Dict, List, Tuple
17
+ import torch
18
+ from torch import nn
19
+
20
+
21
+ @dataclass
22
+ class MMArg:
23
+ vid: Any
24
+ txt: Any
25
+
26
+
27
+ def get_args(key: str, args: List[Any]) -> List[Any]:
28
+ return [getattr(v, key) if isinstance(v, MMArg) else v for v in args]
29
+
30
+
31
+ def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
32
+ return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()}
33
+
34
+
35
+ class MMModule(nn.Module):
36
+ def __init__(
37
+ self,
38
+ module: Callable[..., nn.Module],
39
+ *args,
40
+ shared_weights: bool = False,
41
+ **kwargs,
42
+ ):
43
+ super().__init__()
44
+ self.shared_weights = shared_weights
45
+ if self.shared_weights:
46
+ assert get_args("vid", args) == get_args("txt", args)
47
+ assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs)
48
+ self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
49
+ else:
50
+ self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
51
+ self.txt = module(*get_args("txt", args), **get_kwargs("txt", kwargs))
52
+
53
+ def forward(
54
+ self,
55
+ vid: torch.FloatTensor,
56
+ txt: torch.FloatTensor,
57
+ *args,
58
+ **kwargs,
59
+ ) -> Tuple[
60
+ torch.FloatTensor,
61
+ torch.FloatTensor,
62
+ ]:
63
+ vid_module = self.vid if not self.shared_weights else self.all
64
+ txt_module = self.txt if not self.shared_weights else self.all
65
+ vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
66
+ txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
67
+ return vid, txt
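Sketch of how MMArg/MMModule keep one video branch and one text branch behind a single call (the dims here are illustrative):

import torch
from torch import nn
from models.dit.mm import MMArg, MMModule

proj = MMModule(nn.Linear, MMArg(2560, 1024), 512)   # vid: Linear(2560->512), txt: Linear(1024->512)
vid = torch.randn(2, 16, 2560)
txt = torch.randn(2, 8, 1024)
vid_out, txt_out = proj(vid, txt)
print(vid_out.shape, txt_out.shape)  # (2, 16, 512) and (2, 8, 512)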
models/dit/modulation.py ADDED
@@ -0,0 +1,97 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from typing import Callable, List, Optional
16
+ import torch
17
+ from einops import rearrange
18
+ from torch import nn
19
+
20
+ from common.cache import Cache
21
+ from common.distributed.ops import slice_inputs
22
+
23
+ # (dim: int, emb_dim: int)
24
+ ada_layer_type = Callable[[int, int], nn.Module]
25
+
26
+
27
+ def get_ada_layer(ada_layer: str) -> ada_layer_type:
28
+ if ada_layer == "single":
29
+ return AdaSingle
30
+ raise NotImplementedError(f"{ada_layer} is not supported")
31
+
32
+
33
+ def expand_dims(x: torch.Tensor, dim: int, ndim: int):
34
+ """
35
+ Expand tensor "x" to "ndim" by adding empty dims at "dim".
36
+ Example: x is (b d), target ndim is 5, add dim at 1, return (b 1 1 1 d).
37
+ """
38
+ shape = x.shape
39
+ shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:]
40
+ return x.reshape(shape)
41
+
42
+
43
+ class AdaSingle(nn.Module):
44
+ def __init__(
45
+ self,
46
+ dim: int,
47
+ emb_dim: int,
48
+ layers: List[str],
49
+ ):
50
+ assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim"
51
+ super().__init__()
52
+ self.dim = dim
53
+ self.emb_dim = emb_dim
54
+ self.layers = layers
55
+ for l in layers:
56
+ self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5))
57
+ self.register_parameter(f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1))
58
+ self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5))
59
+
60
+ def forward(
61
+ self,
62
+ hid: torch.FloatTensor, # b ... c
63
+ emb: torch.FloatTensor, # b d
64
+ layer: str,
65
+ mode: str,
66
+ cache: Cache = Cache(disable=True),
67
+ branch_tag: str = "",
68
+ hid_len: Optional[torch.LongTensor] = None, # b
69
+ ) -> torch.FloatTensor:
70
+ idx = self.layers.index(layer)
71
+ emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
72
+ emb = expand_dims(emb, 1, hid.ndim + 1)
73
+
74
+ if hid_len is not None:
75
+ emb = cache(
76
+ f"emb_repeat_{idx}_{branch_tag}",
77
+ lambda: slice_inputs(
78
+ torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
79
+ dim=0,
80
+ ),
81
+ )
82
+
83
+ shiftA, scaleA, gateA = emb.unbind(-1)
84
+ shiftB, scaleB, gateB = (
85
+ getattr(self, f"{layer}_shift"),
86
+ getattr(self, f"{layer}_scale"),
87
+ getattr(self, f"{layer}_gate"),
88
+ )
89
+
90
+ if mode == "in":
91
+ return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
92
+ if mode == "out":
93
+ return hid.mul_(gateA + gateB)
94
+ raise NotImplementedError
95
+
96
+ def extra_repr(self) -> str:
97
+ return f"dim={self.dim}, emb_dim={self.emb_dim}, layers={self.layers}"
models/dit/na.py ADDED
@@ -0,0 +1,241 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from itertools import chain
16
+ from typing import Callable, Dict, List, Tuple
17
+ import einops
18
+ import torch
19
+
20
+
21
+ def flatten(
22
+ hid: List[torch.FloatTensor], # List of (*** c)
23
+ ) -> Tuple[
24
+ torch.FloatTensor, # (L c)
25
+ torch.LongTensor, # (b n)
26
+ ]:
27
+ assert len(hid) > 0
28
+ shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid])
29
+ hid = torch.cat([x.flatten(0, -2) for x in hid])
30
+ return hid, shape
31
+
32
+
33
+ def unflatten(
34
+ hid: torch.FloatTensor, # (L c) or (L ... c)
35
+ hid_shape: torch.LongTensor, # (b n)
36
+ ) -> List[torch.Tensor]: # List of (*** c) or (*** ... c)
37
+ hid_len = hid_shape.prod(-1)
38
+ hid = hid.split(hid_len.tolist())
39
+ hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)]
40
+ return hid
41
+
42
+
43
+ def concat(
44
+ vid: torch.FloatTensor, # (VL ... c)
45
+ txt: torch.FloatTensor, # (TL ... c)
46
+ vid_len: torch.LongTensor, # (b)
47
+ txt_len: torch.LongTensor, # (b)
48
+ ) -> torch.FloatTensor: # (L ... c)
49
+ vid = torch.split(vid, vid_len.tolist())
50
+ txt = torch.split(txt, txt_len.tolist())
51
+ return torch.cat(list(chain(*zip(vid, txt))))
52
+
53
+
54
+ def concat_idx(
55
+ vid_len: torch.LongTensor, # (b)
56
+ txt_len: torch.LongTensor, # (b)
57
+ ) -> Tuple[
58
+ Callable,
59
+ Callable,
60
+ ]:
61
+ device = vid_len.device
62
+ vid_idx = torch.arange(vid_len.sum(), device=device)
63
+ txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
64
+ tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len)
65
+ src_idx = torch.argsort(tgt_idx)
66
+ return (
67
+ lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx),
68
+ lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]),
69
+ )
70
+
71
+
72
+ def unconcat(
73
+ all: torch.FloatTensor, # (L ... c)
74
+ vid_len: torch.LongTensor, # (b)
75
+ txt_len: torch.LongTensor, # (b)
76
+ ) -> Tuple[
77
+ torch.FloatTensor, # (VL ... c)
78
+ torch.FloatTensor, # (TL ... c)
79
+ ]:
80
+ interleave_len = list(chain(*zip(vid_len.tolist(), txt_len.tolist())))
81
+ all = all.split(interleave_len)
82
+ vid = torch.cat(all[0::2])
83
+ txt = torch.cat(all[1::2])
84
+ return vid, txt
85
+
86
+
87
+ def repeat_concat(
88
+ vid: torch.FloatTensor, # (VL ... c)
89
+ txt: torch.FloatTensor, # (TL ... c)
90
+ vid_len: torch.LongTensor, # (n*b)
91
+ txt_len: torch.LongTensor, # (b)
92
+ txt_repeat: List, # (n)
93
+ ) -> torch.FloatTensor: # (L ... c)
94
+ vid = torch.split(vid, vid_len.tolist())
95
+ txt = torch.split(txt, txt_len.tolist())
96
+ txt = [[x] * n for x, n in zip(txt, txt_repeat)]
97
+ txt = list(chain(*txt))
98
+ return torch.cat(list(chain(*zip(vid, txt))))
99
+
100
+
101
+ def repeat_concat_idx(
102
+ vid_len: torch.LongTensor, # (n*b)
103
+ txt_len: torch.LongTensor, # (b)
104
+ txt_repeat: torch.LongTensor, # (n)
105
+ ) -> Tuple[
106
+ Callable,
107
+ Callable,
108
+ ]:
109
+ device = vid_len.device
110
+ vid_idx = torch.arange(vid_len.sum(), device=device)
111
+ txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
112
+ txt_repeat_list = txt_repeat.tolist()
113
+ tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat)
114
+ src_idx = torch.argsort(tgt_idx)
115
+ txt_idx_len = len(tgt_idx) - len(vid_idx)
116
+ repeat_txt_len = (txt_len * txt_repeat).tolist()
117
+
118
+ def unconcat_coalesce(all):
119
+ """
120
+ Un-concat vid & txt, and coalesce the repeated txt.
121
+ e.g. vid [0 1 2 3 4 5 6 7 8] -> 3 splits -> [0 1 2] [3 4 5] [6 7 8]
122
+ txt [9 10]
123
+ repeat_concat ==> [0 1 2 9 10 3 4 5 9 10 6 7 8 9 10]
124
+ 1. argsort re-index ==> [0 1 2 3 4 5 6 7 8 9 9 9 10 10 10]
125
+ split ==> vid_out [0 1 2 3 4 5 6 7 8] txt_out [9 9 9 10 10 10]
126
+ 2. reshape & mean for each sample to coalesce the repeated txt.
127
+ """
128
+ vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len])
129
+ txt_out_coalesced = []
130
+ for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list):
131
+ txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1)
132
+ txt_out_coalesced.append(txt)
133
+ return vid_out, torch.cat(txt_out_coalesced)
134
+
135
+ # Note: The backward of torch.index_select is non-deterministic when the index contains repeats;
136
+ # the differences may accumulate, as with torch.repeat_interleave, so we use plain indexing here.
137
+ return (
138
+ lambda vid, txt: torch.cat([vid, txt])[tgt_idx],
139
+ lambda all: unconcat_coalesce(all),
140
+ )
141
+
142
+
143
+ def rearrange(
144
+ hid: torch.FloatTensor, # (L c)
145
+ hid_shape: torch.LongTensor, # (b n)
146
+ pattern: str,
147
+ **kwargs: Dict[str, int],
148
+ ) -> Tuple[
149
+ torch.FloatTensor,
150
+ torch.LongTensor,
151
+ ]:
152
+ return flatten([einops.rearrange(h, pattern, **kwargs) for h in unflatten(hid, hid_shape)])
153
+
154
+
155
+ def rearrange_idx(
156
+ hid_shape: torch.LongTensor, # (b n)
157
+ pattern: str,
158
+ **kwargs: Dict[str, int],
159
+ ) -> Tuple[Callable, Callable, torch.LongTensor]:
160
+ hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
161
+ tgt_idx, tgt_shape = rearrange(hid_idx, hid_shape, pattern, **kwargs)
162
+ tgt_idx = tgt_idx.squeeze(-1)
163
+ src_idx = torch.argsort(tgt_idx)
164
+ return (
165
+ lambda hid: torch.index_select(hid, 0, tgt_idx),
166
+ lambda hid: torch.index_select(hid, 0, src_idx),
167
+ tgt_shape,
168
+ )
169
+
170
+
171
+ def repeat(
172
+ hid: torch.FloatTensor, # (L c)
173
+ hid_shape: torch.LongTensor, # (b n)
174
+ pattern: str,
175
+ **kwargs: Dict[str, torch.LongTensor], # (b)
176
+ ) -> Tuple[
177
+ torch.FloatTensor,
178
+ torch.LongTensor,
179
+ ]:
180
+ hid = unflatten(hid, hid_shape)
181
+ kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
182
+ return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)])
183
+
184
+
185
+ def pack(
186
+ samples: List[torch.Tensor], # List of (h w c).
187
+ ) -> Tuple[
188
+ List[torch.Tensor], # groups [(b1 h1 w1 c1), (b2 h2 w2 c2)]
189
+ List[List[int]], # reversal indices.
190
+ ]:
191
+ batches = {}
192
+ indices = {}
193
+ for i, sample in enumerate(samples):
194
+ shape = sample.shape
195
+ batches[shape] = batches.get(shape, [])
196
+ indices[shape] = indices.get(shape, [])
197
+ batches[shape].append(sample)
198
+ indices[shape].append(i)
199
+
200
+ batches = list(map(torch.stack, batches.values()))
201
+ indices = list(indices.values())
202
+ return batches, indices
203
+
204
+
205
+ def unpack(
206
+ batches: List[torch.Tensor],
207
+ indices: List[List[int]],
208
+ ) -> List[torch.Tensor]:
209
+ samples = [None] * (max(chain(*indices)) + 1)
210
+ for batch, index in zip(batches, indices):
211
+ for sample, i in zip(batch.unbind(), index):
212
+ samples[i] = sample
213
+ return samples
214
+
215
+
216
+ def window(
217
+ hid: torch.FloatTensor, # (L c)
218
+ hid_shape: torch.LongTensor, # (b n)
219
+ window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
220
+ ):
221
+ hid = unflatten(hid, hid_shape)
222
+ hid = list(map(window_fn, hid))
223
+ hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device)
224
+ hid, hid_shape = flatten(list(chain(*hid)))
225
+ return hid, hid_shape, hid_windows
226
+
227
+
228
+ def window_idx(
229
+ hid_shape: torch.LongTensor, # (b n)
230
+ window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
231
+ ):
232
+ hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
233
+ tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn)
234
+ tgt_idx = tgt_idx.squeeze(-1)
235
+ src_idx = torch.argsort(tgt_idx)
236
+ return (
237
+ lambda hid: torch.index_select(hid, 0, tgt_idx),
238
+ lambda hid: torch.index_select(hid, 0, src_idx),
239
+ tgt_shape,
240
+ tgt_windows,
241
+ )
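
A minimal usage sketch of the flatten/unflatten packing above (assumes the repo root is on PYTHONPATH; shapes are illustrative): variable-resolution samples are packed into one (L, c) sequence plus a per-sample shape table, and recovered losslessly.

    import torch
    from models.dit.na import flatten, unflatten

    vid_a = torch.randn(2, 4, 4, 16)    # (T, H, W, c)
    vid_b = torch.randn(1, 8, 6, 16)    # different resolution, same channel dim

    packed, shapes = flatten([vid_a, vid_b])
    print(packed.shape)                 # (80, 16) = (2*4*4 + 1*8*6, 16)
    print(shapes)                       # tensor([[2, 4, 4], [1, 8, 6]])

    restored = unflatten(packed, shapes)
    assert all(torch.equal(x, y) for x, y in zip(restored, [vid_a, vid_b]))
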
models/dit/nablocks/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from .mmsr_block import NaMMSRTransformerBlock
16
+
17
+ nadit_blocks = {
18
+ "mmdit_sr": NaMMSRTransformerBlock,
19
+ }
20
+
21
+
22
+ def get_nablock(block_type: str):
23
+ if block_type in nadit_blocks:
24
+ return nadit_blocks[block_type]
25
+ raise NotImplementedError(f"{block_type} is not supported")
models/dit/nablocks/mmsr_block.py ADDED
@@ -0,0 +1,248 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from typing import Tuple, Union
16
+ import torch
17
+ from einops import rearrange
18
+ from torch.nn import functional as F
19
+
20
+ # from ..cache import Cache
21
+ from common.cache import Cache
22
+ from common.distributed.ops import gather_heads_scatter_seq, gather_seq_scatter_heads_qkv
23
+
24
+ from .. import na
25
+ from ..attention import FlashAttentionVarlen
26
+ from ..blocks.mmdit_window_block import MMWindowAttention, MMWindowTransformerBlock
27
+ from ..mm import MMArg
28
+ from ..modulation import ada_layer_type
29
+ from ..normalization import norm_layer_type
30
+ from ..rope import NaRotaryEmbedding3d
31
+ from ..window import get_window_op
32
+
33
+
34
+ class NaSwinAttention(MMWindowAttention):
35
+ def __init__(
36
+ self,
37
+ vid_dim: int,
38
+ txt_dim: int,
39
+ heads: int,
40
+ head_dim: int,
41
+ qk_bias: bool,
42
+ qk_rope: bool,
43
+ qk_norm: norm_layer_type,
44
+ qk_norm_eps: float,
45
+ window: Union[int, Tuple[int, int, int]],
46
+ window_method: str,
47
+ shared_qkv: bool,
48
+ **kwargs,
49
+ ):
50
+ super().__init__(
51
+ vid_dim=vid_dim,
52
+ txt_dim=txt_dim,
53
+ heads=heads,
54
+ head_dim=head_dim,
55
+ qk_bias=qk_bias,
56
+ qk_rope=qk_rope,
57
+ qk_norm=qk_norm,
58
+ qk_norm_eps=qk_norm_eps,
59
+ window=window,
60
+ window_method=window_method,
61
+ shared_qkv=shared_qkv,
62
+ )
63
+ self.rope = NaRotaryEmbedding3d(dim=head_dim // 2) if qk_rope else None
64
+ self.attn = FlashAttentionVarlen()
65
+ self.window_op = get_window_op(window_method)
66
+
67
+ def forward(
68
+ self,
69
+ vid: torch.FloatTensor, # l c
70
+ txt: torch.FloatTensor, # l c
71
+ vid_shape: torch.LongTensor, # b 3
72
+ txt_shape: torch.LongTensor, # b 1
73
+ cache: Cache,
74
+ ) -> Tuple[
75
+ torch.FloatTensor,
76
+ torch.FloatTensor,
77
+ ]:
78
+
79
+ vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
80
+ vid_qkv = gather_seq_scatter_heads_qkv(
81
+ vid_qkv,
82
+ seq_dim=0,
83
+ qkv_shape=vid_shape,
84
+ cache=cache.namespace("vid"),
85
+ )
86
+ txt_qkv = gather_seq_scatter_heads_qkv(
87
+ txt_qkv,
88
+ seq_dim=0,
89
+ qkv_shape=txt_shape,
90
+ cache=cache.namespace("txt"),
91
+ )
92
+
93
+ # Re-organize the input sequence for window attention.
94
+ cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3")
95
+
96
+ def make_window(x: torch.Tensor):
97
+ t, h, w, _ = x.shape
98
+ window_slices = self.window_op((t, h, w), self.window)
99
+ return [x[st, sh, sw] for (st, sh, sw) in window_slices]
100
+
101
+ window_partition, window_reverse, window_shape, window_count = cache_win(
102
+ "win_transform",
103
+ lambda: na.window_idx(vid_shape, make_window),
104
+ )
105
+ vid_qkv_win = window_partition(vid_qkv)
106
+
107
+ vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim)
108
+ txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim)
109
+
110
+ vid_q, vid_k, vid_v = vid_qkv_win.unbind(1)
111
+ txt_q, txt_k, txt_v = txt_qkv.unbind(1)
112
+
113
+ vid_q, txt_q = self.norm_q(vid_q, txt_q)
114
+ vid_k, txt_k = self.norm_k(vid_k, txt_k)
115
+
116
+ txt_len = cache("txt_len", lambda: txt_shape.prod(-1))
117
+
118
+ vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
119
+ txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
120
+ all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win)
121
+ concat_win, unconcat_win = cache_win(
122
+ "mm_pnp", lambda: na.repeat_concat_idx(vid_len_win, txt_len, window_count)
123
+ )
124
+
125
+ # window rope
126
+ if self.rope:
127
+ vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
128
+
129
+ out = self.attn(
130
+ q=concat_win(vid_q, txt_q).bfloat16(),
131
+ k=concat_win(vid_k, txt_k).bfloat16(),
132
+ v=concat_win(vid_v, txt_v).bfloat16(),
133
+ cu_seqlens_q=cache_win(
134
+ "vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
135
+ ),
136
+ cu_seqlens_k=cache_win(
137
+ "vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
138
+ ),
139
+ max_seqlen_q=cache_win("vid_max_seqlen_q", lambda: all_len_win.max().item()),
140
+ max_seqlen_k=cache_win("vid_max_seqlen_k", lambda: all_len_win.max().item()),
141
+ ).type_as(vid_q)
142
+
143
+ # Un-concat and coalesce (average) the repeated text tokens, see repeat_concat_idx.
144
+ vid_out, txt_out = unconcat_win(out)
145
+
146
+ vid_out = rearrange(vid_out, "l h d -> l (h d)")
147
+ txt_out = rearrange(txt_out, "l h d -> l (h d)")
148
+ vid_out = window_reverse(vid_out)
149
+
150
+ vid_out = gather_heads_scatter_seq(vid_out, head_dim=1, seq_dim=0)
151
+ txt_out = gather_heads_scatter_seq(txt_out, head_dim=1, seq_dim=0)
152
+
153
+ vid_out, txt_out = self.proj_out(vid_out, txt_out)
154
+
155
+ return vid_out, txt_out
156
+
157
+
158
+ class NaMMSRTransformerBlock(MMWindowTransformerBlock):
159
+ def __init__(
160
+ self,
161
+ *,
162
+ vid_dim: int,
163
+ txt_dim: int,
164
+ emb_dim: int,
165
+ heads: int,
166
+ head_dim: int,
167
+ expand_ratio: int,
168
+ norm: norm_layer_type,
169
+ norm_eps: float,
170
+ ada: ada_layer_type,
171
+ qk_bias: bool,
172
+ qk_rope: bool,
173
+ qk_norm: norm_layer_type,
174
+ shared_qkv: bool,
175
+ shared_mlp: bool,
176
+ mlp_type: str,
177
+ **kwargs,
178
+ ):
179
+ super().__init__(
180
+ vid_dim=vid_dim,
181
+ txt_dim=txt_dim,
182
+ emb_dim=emb_dim,
183
+ heads=heads,
184
+ head_dim=head_dim,
185
+ expand_ratio=expand_ratio,
186
+ norm=norm,
187
+ norm_eps=norm_eps,
188
+ ada=ada,
189
+ qk_bias=qk_bias,
190
+ qk_rope=qk_rope,
191
+ qk_norm=qk_norm,
192
+ shared_qkv=shared_qkv,
193
+ shared_mlp=shared_mlp,
194
+ mlp_type=mlp_type,
195
+ **kwargs,
196
+ )
197
+
198
+ self.attn = NaSwinAttention(
199
+ vid_dim=vid_dim,
200
+ txt_dim=txt_dim,
201
+ heads=heads,
202
+ head_dim=head_dim,
203
+ qk_bias=qk_bias,
204
+ qk_rope=qk_rope,
205
+ qk_norm=qk_norm,
206
+ qk_norm_eps=norm_eps,
207
+ shared_qkv=shared_qkv,
208
+ **kwargs,
209
+ )
210
+
211
+ def forward(
212
+ self,
213
+ vid: torch.FloatTensor, # l c
214
+ txt: torch.FloatTensor, # l c
215
+ vid_shape: torch.LongTensor, # b 3
216
+ txt_shape: torch.LongTensor, # b 1
217
+ emb: torch.FloatTensor,
218
+ cache: Cache,
219
+ ) -> Tuple[
220
+ torch.FloatTensor,
221
+ torch.FloatTensor,
222
+ torch.LongTensor,
223
+ torch.LongTensor,
224
+ ]:
225
+ hid_len = MMArg(
226
+ cache("vid_len", lambda: vid_shape.prod(-1)),
227
+ cache("txt_len", lambda: txt_shape.prod(-1)),
228
+ )
229
+ ada_kwargs = {
230
+ "emb": emb,
231
+ "hid_len": hid_len,
232
+ "cache": cache,
233
+ "branch_tag": MMArg("vid", "txt"),
234
+ }
235
+
236
+ vid_attn, txt_attn = self.attn_norm(vid, txt)
237
+ vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs)
238
+ vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache)
239
+ vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs)
240
+ vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt)
241
+
242
+ vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn)
243
+ vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs)
244
+ vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp)
245
+ vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs)
246
+ vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn)
247
+
248
+ return vid_mlp, txt_mlp, vid_shape, txt_shape
models/dit/nadit.py ADDED
@@ -0,0 +1,350 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Union, Callable
17
+ import torch
18
+ from torch import nn
19
+
20
+ from common.cache import Cache
21
+ from common.distributed.ops import slice_inputs
22
+
23
+ from . import na
24
+ from .embedding import TimeEmbedding
25
+ from .modulation import get_ada_layer
26
+ from .nablocks import get_nablock
27
+ from .normalization import get_norm_layer
28
+ from .patch import NaPatchIn, NaPatchOut
29
+
30
+ # Stub: gradient checkpointing is not needed for inference, so just call the module.
31
+ def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs):
32
+ return module(*args, **kwargs)
33
+
34
+ @dataclass
35
+ class NaDiTOutput:
36
+ vid_sample: torch.Tensor
37
+
38
+
39
+ class NaDiT(nn.Module):
40
+ """
41
+ Native Resolution Diffusion Transformer (NaDiT)
42
+ """
43
+
44
+ gradient_checkpointing = False
45
+
46
+ def __init__(
47
+ self,
48
+ vid_in_channels: int,
49
+ vid_out_channels: int,
50
+ vid_dim: int,
51
+ txt_in_dim: Optional[int],
52
+ txt_dim: Optional[int],
53
+ emb_dim: int,
54
+ heads: int,
55
+ head_dim: int,
56
+ expand_ratio: int,
57
+ norm: Optional[str],
58
+ norm_eps: float,
59
+ ada: str,
60
+ qk_bias: bool,
61
+ qk_rope: bool,
62
+ qk_norm: Optional[str],
63
+ patch_size: Union[int, Tuple[int, int, int]],
64
+ num_layers: int,
65
+ block_type: Union[str, Tuple[str]],
66
+ shared_qkv: bool = False,
67
+ shared_mlp: bool = False,
68
+ mlp_type: str = "normal",
69
+ window: Optional[Tuple] = None,
70
+ window_method: Optional[Tuple[str]] = None,
71
+ temporal_window_size: int = None,
72
+ temporal_shifted: bool = False,
73
+ **kwargs,
74
+ ):
75
+ ada = get_ada_layer(ada)
76
+ norm = get_norm_layer(norm)
77
+ qk_norm = get_norm_layer(qk_norm)
78
+ if isinstance(block_type, str):
79
+ block_type = [block_type] * num_layers
80
+ elif len(block_type) != num_layers:
81
+ raise ValueError("The length of ``block_type`` must equal ``num_layers``.")
82
+ super().__init__()
83
+ self.vid_in = NaPatchIn(
84
+ in_channels=vid_in_channels,
85
+ patch_size=patch_size,
86
+ dim=vid_dim,
87
+ )
88
+ self.txt_in = (
89
+ nn.Linear(txt_in_dim, txt_dim)
90
+ if txt_in_dim and txt_in_dim != txt_dim
91
+ else nn.Identity()
92
+ )
93
+ self.emb_in = TimeEmbedding(
94
+ sinusoidal_dim=256,
95
+ hidden_dim=max(vid_dim, txt_dim),
96
+ output_dim=emb_dim,
97
+ )
98
+
99
+ if window is None or isinstance(window[0], int):
100
+ window = [window] * num_layers
101
+ if window_method is None or isinstance(window_method, str):
102
+ window_method = [window_method] * num_layers
103
+ if temporal_window_size is None or isinstance(temporal_window_size, int):
104
+ temporal_window_size = [temporal_window_size] * num_layers
105
+ if temporal_shifted is None or isinstance(temporal_shifted, bool):
106
+ temporal_shifted = [temporal_shifted] * num_layers
107
+
108
+ self.blocks = nn.ModuleList(
109
+ [
110
+ get_nablock(block_type[i])(
111
+ vid_dim=vid_dim,
112
+ txt_dim=txt_dim,
113
+ emb_dim=emb_dim,
114
+ heads=heads,
115
+ head_dim=head_dim,
116
+ expand_ratio=expand_ratio,
117
+ norm=norm,
118
+ norm_eps=norm_eps,
119
+ ada=ada,
120
+ qk_bias=qk_bias,
121
+ qk_rope=qk_rope,
122
+ qk_norm=qk_norm,
123
+ shared_qkv=shared_qkv,
124
+ shared_mlp=shared_mlp,
125
+ mlp_type=mlp_type,
126
+ window=window[i],
127
+ window_method=window_method[i],
128
+ temporal_window_size=temporal_window_size[i],
129
+ temporal_shifted=temporal_shifted[i],
130
+ **kwargs,
131
+ )
132
+ for i in range(num_layers)
133
+ ]
134
+ )
135
+ self.vid_out = NaPatchOut(
136
+ out_channels=vid_out_channels,
137
+ patch_size=patch_size,
138
+ dim=vid_dim,
139
+ )
140
+
141
+ self.need_txt_repeat = block_type[0] in [
142
+ "mmdit_stwin",
143
+ "mmdit_stwin_spatial",
144
+ "mmdit_stwin_3d_spatial",
145
+ ]
146
+
147
+ def set_gradient_checkpointing(self, enable: bool):
148
+ self.gradient_checkpointing = enable
149
+
150
+ def forward(
151
+ self,
152
+ vid: torch.FloatTensor, # l c
153
+ txt: torch.FloatTensor, # l c
154
+ vid_shape: torch.LongTensor, # b 3
155
+ txt_shape: torch.LongTensor, # b 1
156
+ timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b
157
+ disable_cache: bool = True, # for test
158
+ ):
159
+ # Text input.
160
+ if txt_shape.size(-1) == 1 and self.need_txt_repeat:
161
+ txt, txt_shape = na.repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
162
+ # slice vid after patching in when using sequence parallelism
163
+ txt = slice_inputs(txt, dim=0)
164
+ txt = self.txt_in(txt)
165
+
166
+ # Video input.
167
+ # Sequence parallel slicing is done inside patching class.
168
+ vid, vid_shape = self.vid_in(vid, vid_shape)
169
+
170
+ # Embedding input.
171
+ emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
172
+
173
+ # Body
174
+ cache = Cache(disable=disable_cache)
175
+ for i, block in enumerate(self.blocks):
176
+ vid, txt, vid_shape, txt_shape = gradient_checkpointing(
177
+ enabled=(self.gradient_checkpointing and self.training),
178
+ module=block,
179
+ vid=vid,
180
+ txt=txt,
181
+ vid_shape=vid_shape,
182
+ txt_shape=txt_shape,
183
+ emb=emb,
184
+ cache=cache,
185
+ )
186
+
187
+ vid, vid_shape = self.vid_out(vid, vid_shape, cache)
188
+ return NaDiTOutput(vid_sample=vid)
189
+
190
+
191
+ class NaDiTUpscaler(nn.Module):
192
+ """
193
+ Native Resolution Diffusion Transformer (NaDiT)
194
+ """
195
+
196
+ gradient_checkpointing = False
197
+
198
+ def __init__(
199
+ self,
200
+ vid_in_channels: int,
201
+ vid_out_channels: int,
202
+ vid_dim: int,
203
+ txt_in_dim: Optional[int],
204
+ txt_dim: Optional[int],
205
+ emb_dim: int,
206
+ heads: int,
207
+ head_dim: int,
208
+ expand_ratio: int,
209
+ norm: Optional[str],
210
+ norm_eps: float,
211
+ ada: str,
212
+ qk_bias: bool,
213
+ qk_rope: bool,
214
+ qk_norm: Optional[str],
215
+ patch_size: Union[int, Tuple[int, int, int]],
216
+ num_layers: int,
217
+ block_type: Union[str, Tuple[str]],
218
+ shared_qkv: bool = False,
219
+ shared_mlp: bool = False,
220
+ mlp_type: str = "normal",
221
+ window: Optional[Tuple] = None,
222
+ window_method: Optional[Tuple[str]] = None,
223
+ temporal_window_size: int = None,
224
+ temporal_shifted: bool = False,
225
+ **kwargs,
226
+ ):
227
+ ada = get_ada_layer(ada)
228
+ norm = get_norm_layer(norm)
229
+ qk_norm = get_norm_layer(qk_norm)
230
+ if isinstance(block_type, str):
231
+ block_type = [block_type] * num_layers
232
+ elif len(block_type) != num_layers:
233
+ raise ValueError("The length of ``block_type`` must equal ``num_layers``.")
234
+ super().__init__()
235
+ self.vid_in = NaPatchIn(
236
+ in_channels=vid_in_channels,
237
+ patch_size=patch_size,
238
+ dim=vid_dim,
239
+ )
240
+ self.txt_in = (
241
+ nn.Linear(txt_in_dim, txt_dim)
242
+ if txt_in_dim and txt_in_dim != txt_dim
243
+ else nn.Identity()
244
+ )
245
+ self.emb_in = TimeEmbedding(
246
+ sinusoidal_dim=256,
247
+ hidden_dim=max(vid_dim, txt_dim),
248
+ output_dim=emb_dim,
249
+ )
250
+
251
+ self.emb_scale = TimeEmbedding(
252
+ sinusoidal_dim=256,
253
+ hidden_dim=max(vid_dim, txt_dim),
254
+ output_dim=emb_dim,
255
+ )
256
+
257
+ if window is None or isinstance(window[0], int):
258
+ window = [window] * num_layers
259
+ if window_method is None or isinstance(window_method, str):
260
+ window_method = [window_method] * num_layers
261
+ if temporal_window_size is None or isinstance(temporal_window_size, int):
262
+ temporal_window_size = [temporal_window_size] * num_layers
263
+ if temporal_shifted is None or isinstance(temporal_shifted, bool):
264
+ temporal_shifted = [temporal_shifted] * num_layers
265
+
266
+ self.blocks = nn.ModuleList(
267
+ [
268
+ get_nablock(block_type[i])(
269
+ vid_dim=vid_dim,
270
+ txt_dim=txt_dim,
271
+ emb_dim=emb_dim,
272
+ heads=heads,
273
+ head_dim=head_dim,
274
+ expand_ratio=expand_ratio,
275
+ norm=norm,
276
+ norm_eps=norm_eps,
277
+ ada=ada,
278
+ qk_bias=qk_bias,
279
+ qk_rope=qk_rope,
280
+ qk_norm=qk_norm,
281
+ shared_qkv=shared_qkv,
282
+ shared_mlp=shared_mlp,
283
+ mlp_type=mlp_type,
284
+ window=window[i],
285
+ window_method=window_method[i],
286
+ temporal_window_size=temporal_window_size[i],
287
+ temporal_shifted=temporal_shifted[i],
288
+ **kwargs,
289
+ )
290
+ for i in range(num_layers)
291
+ ]
292
+ )
293
+ self.vid_out = NaPatchOut(
294
+ out_channels=vid_out_channels,
295
+ patch_size=patch_size,
296
+ dim=vid_dim,
297
+ )
298
+
299
+ self.need_txt_repeat = block_type[0] in [
300
+ "mmdit_stwin",
301
+ "mmdit_stwin_spatial",
302
+ "mmdit_stwin_3d_spatial",
303
+ ]
304
+
305
+ def set_gradient_checkpointing(self, enable: bool):
306
+ self.gradient_checkpointing = enable
307
+
308
+ def forward(
309
+ self,
310
+ vid: torch.FloatTensor, # l c
311
+ txt: torch.FloatTensor, # l c
312
+ vid_shape: torch.LongTensor, # b 3
313
+ txt_shape: torch.LongTensor, # b 1
314
+ timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], # b
315
+ downscale: Union[int, float, torch.IntTensor, torch.FloatTensor], # b
316
+ disable_cache: bool = False, # for test
317
+ ):
318
+
319
+ # Text input.
320
+ if txt_shape.size(-1) == 1 and self.need_txt_repeat:
321
+ txt, txt_shape = na.repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
322
+ # slice vid after patching in when using sequence parallelism
323
+ txt = slice_inputs(txt, dim=0)
324
+ txt = self.txt_in(txt)
325
+
326
+ # Video input.
327
+ # Sequence parallel slicing is done inside patching class.
328
+ vid, vid_shape = self.vid_in(vid, vid_shape)
329
+
330
+ # Embedding input.
331
+ emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype)
332
+ emb_scale = self.emb_scale(downscale, device=vid.device, dtype=vid.dtype)
333
+ emb = emb + emb_scale
334
+
335
+ # Body
336
+ cache = Cache(disable=disable_cache)
337
+ for i, block in enumerate(self.blocks):
338
+ vid, txt, vid_shape, txt_shape = gradient_checkpointing(
339
+ enabled=(self.gradient_checkpointing and self.training),
340
+ module=block,
341
+ vid=vid,
342
+ txt=txt,
343
+ vid_shape=vid_shape,
344
+ txt_shape=txt_shape,
345
+ emb=emb,
346
+ cache=cache,
347
+ )
348
+
349
+ vid, vid_shape = self.vid_out(vid, vid_shape, cache)
350
+ return NaDiTOutput(vid_sample=vid)
models/dit/normalization.py ADDED
@@ -0,0 +1,63 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from typing import Callable, Optional
16
+ from diffusers.models.normalization import RMSNorm
17
+ from torch import nn
18
+
19
+ # (dim: int, eps: float, elementwise_affine: bool)
20
+ norm_layer_type = Callable[[int, float, bool], nn.Module]
21
+
22
+
23
+ def get_norm_layer(norm_type: Optional[str]) -> norm_layer_type:
24
+
25
+ def _norm_layer(dim: int, eps: float, elementwise_affine: bool):
26
+ if norm_type is None:
27
+ return nn.Identity()
28
+
29
+ if norm_type == "layer":
30
+ return nn.LayerNorm(
31
+ normalized_shape=dim,
32
+ eps=eps,
33
+ elementwise_affine=elementwise_affine,
34
+ )
35
+
36
+ if norm_type == "rms":
37
+ return RMSNorm(
38
+ dim=dim,
39
+ eps=eps,
40
+ elementwise_affine=elementwise_affine,
41
+ )
42
+
43
+ if norm_type == "fusedln":
44
+ from apex.normalization import FusedLayerNorm
45
+
46
+ return FusedLayerNorm(
47
+ normalized_shape=dim,
48
+ elementwise_affine=elementwise_affine,
49
+ eps=eps,
50
+ )
51
+
52
+ if norm_type == "fusedrms":
53
+ from apex.normalization import FusedRMSNorm
54
+
55
+ return FusedRMSNorm(
56
+ normalized_shape=dim,
57
+ elementwise_affine=elementwise_affine,
58
+ eps=eps,
59
+ )
60
+
61
+ raise NotImplementedError(f"{norm_type} is not supported")
62
+
63
+ return _norm_layer
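
A minimal sketch of the factory above (assumes diffusers is installed and the repo root is importable); the returned callable has the fixed (dim, eps, elementwise_affine) signature used throughout the blocks:

    import torch
    from models.dit.normalization import get_norm_layer

    norm = get_norm_layer("rms")(dim=128, eps=1e-6, elementwise_affine=True)
    x = torch.randn(2, 10, 128)
    print(norm(x).shape)                # torch.Size([2, 10, 128])
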
models/dit/patch.py ADDED
@@ -0,0 +1,112 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from typing import Tuple, Union
16
+ import torch
17
+ from einops import rearrange
18
+ from torch import nn
19
+ from torch.nn.modules.utils import _triple
20
+
21
+ from common.cache import Cache
22
+ from common.distributed.ops import gather_outputs, slice_inputs
23
+
24
+ from . import na
25
+
26
+
27
+ class PatchIn(nn.Module):
28
+ def __init__(
29
+ self,
30
+ in_channels: int,
31
+ patch_size: Union[int, Tuple[int, int, int]],
32
+ dim: int,
33
+ ):
34
+ super().__init__()
35
+ t, h, w = _triple(patch_size)
36
+ self.patch_size = t, h, w
37
+ self.proj = nn.Linear(in_channels * t * h * w, dim)
38
+
39
+ def forward(
40
+ self,
41
+ vid: torch.Tensor,
42
+ ) -> torch.Tensor:
43
+ t, h, w = self.patch_size
44
+ vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w)
45
+ vid = self.proj(vid)
46
+ return vid
47
+
48
+
49
+ class PatchOut(nn.Module):
50
+ def __init__(
51
+ self,
52
+ out_channels: int,
53
+ patch_size: Union[int, Tuple[int, int, int]],
54
+ dim: int,
55
+ ):
56
+ super().__init__()
57
+ t, h, w = _triple(patch_size)
58
+ self.patch_size = t, h, w
59
+ self.proj = nn.Linear(dim, out_channels * t * h * w)
60
+
61
+ def forward(
62
+ self,
63
+ vid: torch.Tensor,
64
+ ) -> torch.Tensor:
65
+ t, h, w = self.patch_size
66
+ vid = self.proj(vid)
67
+ vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w)
68
+ return vid
69
+
70
+
71
+ class NaPatchIn(PatchIn):
72
+ def forward(
73
+ self,
74
+ vid: torch.Tensor, # l c
75
+ vid_shape: torch.LongTensor,
76
+ ) -> torch.Tensor:
77
+ t, h, w = self.patch_size
78
+ if not (t == h == w == 1):
79
+ vid, vid_shape = na.rearrange(
80
+ vid, vid_shape, "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w
81
+ )
82
+ # slice vid after patching in when using sequence parallelism
83
+ vid = slice_inputs(vid, dim=0)
84
+ vid = self.proj(vid)
85
+ return vid, vid_shape
86
+
87
+
88
+ class NaPatchOut(PatchOut):
89
+ def forward(
90
+ self,
91
+ vid: torch.FloatTensor, # l c
92
+ vid_shape: torch.LongTensor,
93
+ cache: Cache = Cache(disable=True),
94
+ ) -> Tuple[
95
+ torch.FloatTensor,
96
+ torch.LongTensor,
97
+ ]:
98
+ t, h, w = self.patch_size
99
+ vid = self.proj(vid)
100
+ # gather vid before patching out when enabling sequence parallelism
101
+ vid = gather_outputs(
102
+ vid,
103
+ gather_dim=0,
104
+ padding_dim=0,
105
+ unpad_shape=vid_shape,
106
+ cache=cache.namespace("vid"),
107
+ )
108
+ if not (t == h == w == 1):
109
+ vid, vid_shape = na.rearrange(
110
+ vid, vid_shape, "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w
111
+ )
112
+ return vid, vid_shape
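
A minimal round-trip sketch of the dense PatchIn/PatchOut pair above (the Na* variants add flattening and sequence-parallel slicing on top); channel and patch sizes are illustrative:

    import torch
    from models.dit.patch import PatchIn, PatchOut

    patch_in = PatchIn(in_channels=16, patch_size=(1, 2, 2), dim=256)
    patch_out = PatchOut(out_channels=16, patch_size=(1, 2, 2), dim=256)

    vid = torch.randn(1, 16, 4, 32, 32)     # (b, c, T, H, W) latent video
    tokens = patch_in(vid)                  # (1, 4, 16, 16, 256) patch tokens
    recon = patch_out(tokens)               # (1, 16, 4, 32, 32) back to latent layout
    print(tokens.shape, recon.shape)
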
models/dit/rope.py ADDED
@@ -0,0 +1,101 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from functools import lru_cache
16
+ from typing import Tuple
17
+ import torch
18
+ from einops import rearrange
19
+ from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
20
+ from torch import nn
21
+
22
+ from common.cache import Cache
23
+
24
+
25
+ class RotaryEmbeddingBase(nn.Module):
26
+ def __init__(self, dim: int, rope_dim: int):
27
+ super().__init__()
28
+ self.rope = RotaryEmbedding(
29
+ dim=dim // rope_dim,
30
+ freqs_for="pixel",
31
+ max_freq=256,
32
+ )
33
+ # 1. Set model.requires_grad_(True) after model creation will make
34
+ # the `requires_grad=False` for rope freqs no longer hold.
35
+ # 2. Even if we don't set requires_grad_(True) explicitly,
36
+ # FSDP is not memory efficient when handling fsdp_wrap
37
+ # with mixed requires_grad=True/False.
38
+ # With above consideration, it is easier just remove the freqs
39
+ # out of nn.Parameters when `learned_freq=False`
40
+ freqs = self.rope.freqs
41
+ del self.rope.freqs
42
+ self.rope.register_buffer("freqs", freqs.data)
43
+
44
+ @lru_cache(maxsize=128)
45
+ def get_axial_freqs(self, *dims):
46
+ return self.rope.get_axial_freqs(*dims)
47
+
48
+
49
+ class RotaryEmbedding3d(RotaryEmbeddingBase):
50
+ def __init__(self, dim: int):
51
+ super().__init__(dim, rope_dim=3)
52
+
53
+ def forward(
54
+ self,
55
+ q: torch.FloatTensor, # b h l d
56
+ k: torch.FloatTensor, # b h l d
57
+ size: Tuple[int, int, int],
58
+ ) -> Tuple[
59
+ torch.FloatTensor,
60
+ torch.FloatTensor,
61
+ ]:
62
+ T, H, W = size
63
+ freqs = self.get_axial_freqs(T, H, W)
64
+ q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W)
65
+ k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W)
66
+ q = apply_rotary_emb(freqs, q)
67
+ k = apply_rotary_emb(freqs, k)
68
+ q = rearrange(q, "b h T H W d -> b h (T H W) d")
69
+ k = rearrange(k, "b h T H W d -> b h (T H W) d")
70
+ return q, k
71
+
72
+
73
+ class NaRotaryEmbedding3d(RotaryEmbedding3d):
74
+ def forward(
75
+ self,
76
+ q: torch.FloatTensor, # L h d
77
+ k: torch.FloatTensor, # L h d
78
+ shape: torch.LongTensor,
79
+ cache: Cache,
80
+ ) -> Tuple[
81
+ torch.FloatTensor,
82
+ torch.FloatTensor,
83
+ ]:
84
+ freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape))
85
+ q = rearrange(q, "L h d -> h L d")
86
+ k = rearrange(k, "L h d -> h L d")
87
+ q = apply_rotary_emb(freqs, q.float()).to(q.dtype)
88
+ k = apply_rotary_emb(freqs, k.float()).to(k.dtype)
89
+ q = rearrange(q, "h L d -> L h d")
90
+ k = rearrange(k, "h L d -> L h d")
91
+ return q, k
92
+
93
+ def get_freqs(
94
+ self,
95
+ shape: torch.LongTensor,
96
+ ) -> torch.Tensor:
97
+ freq_list = []
98
+ for f, h, w in shape.tolist():
99
+ freqs = self.get_axial_freqs(f, h, w)
100
+ freq_list.append(freqs.view(-1, freqs.size(-1)))
101
+ return torch.cat(freq_list, dim=0)
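
A minimal sketch of the batched 3D RoPE above (assumes rotary_embedding_torch is installed; head and grid sizes are illustrative). The sequence length must factorize as T*H*W so positions can be recovered per axis:

    import torch
    from models.dit.rope import RotaryEmbedding3d

    head_dim = 64
    rope = RotaryEmbedding3d(dim=head_dim // 2)   # same dim choice as NaSwinAttention
    T, H, W = 4, 6, 8
    q = torch.randn(2, 8, T * H * W, head_dim)    # (b, heads, L, d)
    k = torch.randn(2, 8, T * H * W, head_dim)
    q, k = rope(q, k, size=(T, H, W))
    print(q.shape)                                # torch.Size([2, 8, 192, 64])
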
models/dit/window.py ADDED
@@ -0,0 +1,83 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from math import ceil
16
+ from typing import Tuple
17
+ import math
18
+
19
+ def get_window_op(name: str):
20
+ if name == "720pwin_by_size_bysize":
21
+ return make_720Pwindows_bysize
22
+ if name == "720pswin_by_size_bysize":
23
+ return make_shifted_720Pwindows_bysize
24
+ raise ValueError(f"Unknown windowing method: {name}")
25
+
26
+
27
+ # -------------------------------- Windowing -------------------------------- #
28
+ def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
29
+ t, h, w = size
30
+ resized_nt, resized_nh, resized_nw = num_windows
31
+ # Compute the window grid at the 720p reference resolution (45x80 latent grid).
32
+ scale = math.sqrt((45 * 80) / (h * w))
33
+ resized_h, resized_w = round(h * scale), round(w * scale)
34
+ wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
35
+ wt = ceil(min(t, 30) / resized_nt) # window size.
36
+ nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # number of windows.
37
+ return [
38
+ (
39
+ slice(it * wt, min((it + 1) * wt, t)),
40
+ slice(ih * wh, min((ih + 1) * wh, h)),
41
+ slice(iw * ww, min((iw + 1) * ww, w)),
42
+ )
43
+ for iw in range(nw)
44
+ if min((iw + 1) * ww, w) > iw * ww
45
+ for ih in range(nh)
46
+ if min((ih + 1) * wh, h) > ih * wh
47
+ for it in range(nt)
48
+ if min((it + 1) * wt, t) > it * wt
49
+ ]
50
+
51
+ def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
52
+ t, h, w = size
53
+ resized_nt, resized_nh, resized_nw = num_windows
54
+ # Compute the window grid at the 720p reference resolution (45x80 latent grid).
55
+ scale = math.sqrt((45 * 80) / (h * w))
56
+ resized_h, resized_w = round(h * scale), round(w * scale)
57
+ wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
58
+ wt = ceil(min(t, 30) / resized_nt) # window size.
59
+
60
+ st, sh, sw = ( # shift size.
61
+ 0.5 if wt < t else 0,
62
+ 0.5 if wh < h else 0,
63
+ 0.5 if ww < w else 0,
64
+ )
65
+ nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # number of windows.
66
+ nt, nh, nw = ( # number of windows (one extra along each shifted axis).
67
+ nt + 1 if st > 0 else 1,
68
+ nh + 1 if sh > 0 else 1,
69
+ nw + 1 if sw > 0 else 1,
70
+ )
71
+ return [
72
+ (
73
+ slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)),
74
+ slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)),
75
+ slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)),
76
+ )
77
+ for iw in range(nw)
78
+ if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0)
79
+ for ih in range(nh)
80
+ if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0)
81
+ for it in range(nt)
82
+ if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0)
83
+ ]
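
A minimal sketch of the windowing above: num_windows counts windows along (t, h, w) at the 720p reference grid, and the returned slice triples tile the actual latent grid. Sizes are illustrative:

    from models.dit.window import make_720Pwindows_bysize

    windows = make_720Pwindows_bysize(size=(16, 45, 80), num_windows=(2, 3, 5))
    print(len(windows))     # 30 windows = 2 * 3 * 5
    print(windows[0])       # (slice(0, 8), slice(0, 15), slice(0, 16))
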
models/dit_v2/attention.py ADDED
@@ -0,0 +1,46 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+
18
+ from flash_attn import flash_attn_varlen_func
19
+
20
+ from torch import nn
21
+
22
+ class TorchAttention(nn.Module):
23
+ def tflops(self, args, kwargs, output) -> float:
24
+ assert len(args) == 0 or len(args) > 2, "query and key must both be provided via args or kwargs"
25
+ q = kwargs.get("query") or args[0]
26
+ k = kwargs.get("key") or args[1]
27
+ b, h, sq, d = q.shape
28
+ b, h, sk, d = k.shape
29
+ return b * h * (4 * d * (sq / 1e6) * (sk / 1e6))
30
+
31
+ def forward(self, *args, **kwargs):
32
+ return F.scaled_dot_product_attention(*args, **kwargs)
33
+
34
+
35
+ class FlashAttentionVarlen(nn.Module):
36
+ def tflops(self, args, kwargs, output) -> float:
37
+ cu_seqlens_q = kwargs["cu_seqlens_q"]
38
+ cu_seqlens_k = kwargs["cu_seqlens_k"]
39
+ _, h, d = output.shape
40
+ seqlens_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]) / 1e6
41
+ seqlens_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]) / 1e6
42
+ return h * (4 * d * (seqlens_q * seqlens_k).sum())
43
+
44
+ def forward(self, *args, **kwargs):
45
+ kwargs["deterministic"] = torch.are_deterministic_algorithms_enabled()
46
+ return flash_attn_varlen_func(*args, **kwargs)
models/dit_v2/embedding.py ADDED
@@ -0,0 +1,62 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from typing import Optional, Union
16
+ import torch
17
+ from diffusers.models.embeddings import get_timestep_embedding
18
+ from torch import nn
19
+
20
+
21
+ def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]):
22
+ return emb1 if emb2 is None else emb1 + emb2
23
+
24
+
25
+ class TimeEmbedding(nn.Module):
26
+ def __init__(
27
+ self,
28
+ sinusoidal_dim: int,
29
+ hidden_dim: int,
30
+ output_dim: int,
31
+ ):
32
+ super().__init__()
33
+ self.sinusoidal_dim = sinusoidal_dim
34
+ self.proj_in = nn.Linear(sinusoidal_dim, hidden_dim)
35
+ self.proj_hid = nn.Linear(hidden_dim, hidden_dim)
36
+ self.proj_out = nn.Linear(hidden_dim, output_dim)
37
+ self.act = nn.SiLU()
38
+
39
+ def forward(
40
+ self,
41
+ timestep: Union[int, float, torch.IntTensor, torch.FloatTensor],
42
+ device: torch.device,
43
+ dtype: torch.dtype,
44
+ ) -> torch.FloatTensor:
45
+ if not torch.is_tensor(timestep):
46
+ timestep = torch.tensor([timestep], device=device, dtype=dtype)
47
+ if timestep.ndim == 0:
48
+ timestep = timestep[None]
49
+
50
+ emb = get_timestep_embedding(
51
+ timesteps=timestep,
52
+ embedding_dim=self.sinusoidal_dim,
53
+ flip_sin_to_cos=False,
54
+ downscale_freq_shift=0,
55
+ )
56
+ emb = emb.to(dtype)
57
+ emb = self.proj_in(emb)
58
+ emb = self.act(emb)
59
+ emb = self.proj_hid(emb)
60
+ emb = self.act(emb)
61
+ emb = self.proj_out(emb)
62
+ return emb
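
A minimal sketch of the timestep embedding above (assumes diffusers is installed; dims are illustrative and follow the sinusoidal -> MLP -> output projection path):

    import torch
    from models.dit_v2.embedding import TimeEmbedding

    emb_in = TimeEmbedding(sinusoidal_dim=256, hidden_dim=1024, output_dim=6 * 1024)
    emb = emb_in(torch.tensor([500, 250]), device="cpu", dtype=torch.float32)
    print(emb.shape)        # torch.Size([2, 6144]), one embedding per timestep
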
models/dit_v2/mlp.py ADDED
@@ -0,0 +1,62 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from typing import Optional
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+
20
+
21
+ def get_mlp(mlp_type: Optional[str] = "normal"):
22
+ if mlp_type == "normal":
23
+ return MLP
24
+ elif mlp_type == "swiglu":
25
+ return SwiGLUMLP
26
+
27
+
28
+ class MLP(nn.Module):
29
+ def __init__(
30
+ self,
31
+ dim: int,
32
+ expand_ratio: int,
33
+ ):
34
+ super().__init__()
35
+ self.proj_in = nn.Linear(dim, dim * expand_ratio)
36
+ self.act = nn.GELU("tanh")
37
+ self.proj_out = nn.Linear(dim * expand_ratio, dim)
38
+
39
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
40
+ x = self.proj_in(x)
41
+ x = self.act(x)
42
+ x = self.proj_out(x)
43
+ return x
44
+
45
+
46
+ class SwiGLUMLP(nn.Module):
47
+ def __init__(
48
+ self,
49
+ dim: int,
50
+ expand_ratio: int,
51
+ multiple_of: int = 256,
52
+ ):
53
+ super().__init__()
54
+ hidden_dim = int(2 * dim * expand_ratio / 3)
55
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
56
+ self.proj_in_gate = nn.Linear(dim, hidden_dim, bias=False)
57
+ self.proj_out = nn.Linear(hidden_dim, dim, bias=False)
58
+ self.proj_in = nn.Linear(dim, hidden_dim, bias=False)
59
+
60
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
61
+ x = self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
62
+ return x
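
A minimal sketch contrasting the two MLP variants above; with dim=512 and expand_ratio=4, SwiGLU's hidden width becomes 1536 (the 2/3 scaling rounded up to a multiple of 256), keeping parameters roughly comparable to the plain 2048-wide GELU MLP:

    import torch
    from models.dit_v2.mlp import get_mlp

    x = torch.randn(2, 10, 512)
    mlp = get_mlp("normal")(dim=512, expand_ratio=4)
    swiglu = get_mlp("swiglu")(dim=512, expand_ratio=4)
    print(mlp(x).shape, swiglu(x).shape)    # both torch.Size([2, 10, 512])
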
models/dit_v2/mm.py ADDED
@@ -0,0 +1,74 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Any, Callable, Dict, List, Tuple
17
+ import torch
18
+ from torch import nn
19
+
20
+
21
+ @dataclass
22
+ class MMArg:
23
+ vid: Any
24
+ txt: Any
25
+
26
+
27
+ def get_args(key: str, args: List[Any]) -> List[Any]:
28
+ return [getattr(v, key) if isinstance(v, MMArg) else v for v in args]
29
+
30
+
31
+ def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
32
+ return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()}
33
+
34
+
35
+ class MMModule(nn.Module):
36
+ def __init__(
37
+ self,
38
+ module: Callable[..., nn.Module],
39
+ *args,
40
+ shared_weights: bool = False,
41
+ vid_only: bool = False,
42
+ **kwargs,
43
+ ):
44
+ super().__init__()
45
+ self.shared_weights = shared_weights
46
+ self.vid_only = vid_only
47
+ if self.shared_weights:
48
+ assert get_args("vid", args) == get_args("txt", args)
49
+ assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs)
50
+ self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
51
+ else:
52
+ self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
53
+ self.txt = (
54
+ module(*get_args("txt", args), **get_kwargs("txt", kwargs))
55
+ if not vid_only
56
+ else None
57
+ )
58
+
59
+ def forward(
60
+ self,
61
+ vid: torch.FloatTensor,
62
+ txt: torch.FloatTensor,
63
+ *args,
64
+ **kwargs,
65
+ ) -> Tuple[
66
+ torch.FloatTensor,
67
+ torch.FloatTensor,
68
+ ]:
69
+ vid_module = self.vid if not self.shared_weights else self.all
70
+ vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
71
+ if not self.vid_only:
72
+ txt_module = self.txt if not self.shared_weights else self.all
73
+ txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
74
+ return vid, txt
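
A minimal sketch of the MMArg/MMModule pattern above: one sub-module per modality, with MMArg routing per-modality constructor and forward arguments (dims are illustrative):

    import torch
    from torch import nn
    from models.dit_v2.mm import MMArg, MMModule

    # Separate LayerNorms for video tokens (dim 128) and text tokens (dim 64).
    norm = MMModule(nn.LayerNorm, MMArg(128, 64))
    vid, txt = torch.randn(2, 16, 128), torch.randn(2, 8, 64)
    vid, txt = norm(vid, txt)
    print(vid.shape, txt.shape)
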
models/dit_v2/modulation.py ADDED
@@ -0,0 +1,102 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from typing import Callable, List, Optional
16
+ import torch
17
+ from einops import rearrange
18
+ from torch import nn
19
+
20
+ from common.cache import Cache
21
+ from common.distributed.ops import slice_inputs
22
+
23
+ # (dim: int, emb_dim: int)
24
+ ada_layer_type = Callable[[int, int], nn.Module]
25
+
26
+
27
+ def get_ada_layer(ada_layer: str) -> ada_layer_type:
28
+ if ada_layer == "single":
29
+ return AdaSingle
30
+ raise NotImplementedError(f"{ada_layer} is not supported")
31
+
32
+
33
+ def expand_dims(x: torch.Tensor, dim: int, ndim: int):
34
+ """
35
+ Expand tensor "x" to "ndim" by adding empty dims at "dim".
36
+ Example: x is (b d), target ndim is 5, add dim at 1, return (b 1 1 1 d).
37
+ """
38
+ shape = x.shape
39
+ shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:]
40
+ return x.reshape(shape)
41
+
42
+
43
+ class AdaSingle(nn.Module):
44
+ def __init__(
45
+ self,
46
+ dim: int,
47
+ emb_dim: int,
48
+ layers: List[str],
49
+ modes: List[str] = ["in", "out"],
50
+ ):
51
+ assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim"
52
+ super().__init__()
53
+ self.dim = dim
54
+ self.emb_dim = emb_dim
55
+ self.layers = layers
56
+ for l in layers:
57
+ if "in" in modes:
58
+ self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim) / dim**0.5))
59
+ self.register_parameter(
60
+ f"{l}_scale", nn.Parameter(torch.randn(dim) / dim**0.5 + 1)
61
+ )
62
+ if "out" in modes:
63
+ self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim) / dim**0.5))
64
+
65
+ def forward(
66
+ self,
67
+ hid: torch.FloatTensor, # b ... c
68
+ emb: torch.FloatTensor, # b d
69
+ layer: str,
70
+ mode: str,
71
+ cache: Cache = Cache(disable=True),
72
+ branch_tag: str = "",
73
+ hid_len: Optional[torch.LongTensor] = None, # b
74
+ ) -> torch.FloatTensor:
75
+ idx = self.layers.index(layer)
76
+ emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
77
+ emb = expand_dims(emb, 1, hid.ndim + 1)
78
+
79
+ if hid_len is not None:
80
+ emb = cache(
81
+ f"emb_repeat_{idx}_{branch_tag}",
82
+ lambda: slice_inputs(
83
+ torch.cat([e.repeat(l, *([1] * e.ndim)) for e, l in zip(emb, hid_len)]),
84
+ dim=0,
85
+ ),
86
+ )
87
+
88
+ shiftA, scaleA, gateA = emb.unbind(-1)
89
+ shiftB, scaleB, gateB = (
90
+ getattr(self, f"{layer}_shift", None),
91
+ getattr(self, f"{layer}_scale", None),
92
+ getattr(self, f"{layer}_gate", None),
93
+ )
94
+
95
+ if mode == "in":
96
+ return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
97
+ if mode == "out":
98
+ return hid.mul_(gateA + gateB)
99
+ raise NotImplementedError
100
+
101
+ def extra_repr(self) -> str:
102
+ return f"dim={self.dim}, emb_dim={self.emb_dim}, layers={self.layers}"
models/dit_v2/na.py ADDED
@@ -0,0 +1,241 @@
1
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # //
3
+ # // Licensed under the Apache License, Version 2.0 (the "License");
4
+ # // you may not use this file except in compliance with the License.
5
+ # // You may obtain a copy of the License at
6
+ # //
7
+ # // http://www.apache.org/licenses/LICENSE-2.0
8
+ # //
9
+ # // Unless required by applicable law or agreed to in writing, software
10
+ # // distributed under the License is distributed on an "AS IS" BASIS,
11
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # // See the License for the specific language governing permissions and
13
+ # // limitations under the License.
14
+
15
+ from itertools import chain
16
+ from typing import Callable, Dict, List, Tuple
17
+ import einops
18
+ import torch
19
+
20
+
21
+ def flatten(
22
+ hid: List[torch.FloatTensor], # List of (*** c)
23
+ ) -> Tuple[
24
+ torch.FloatTensor, # (L c)
25
+ torch.LongTensor, # (b n)
26
+ ]:
27
+ assert len(hid) > 0
28
+ shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid])
29
+ hid = torch.cat([x.flatten(0, -2) for x in hid])
30
+ return hid, shape
31
+
32
+
33
+ def unflatten(
34
+ hid: torch.FloatTensor, # (L c) or (L ... c)
35
+ hid_shape: torch.LongTensor, # (b n)
36
+ ) -> List[torch.Tensor]: # List of (*** c) or (*** ... c)
37
+ hid_len = hid_shape.prod(-1)
38
+ hid = hid.split(hid_len.tolist())
39
+ hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)]
40
+ return hid
41
+
42
+
43
+ def concat(
44
+ vid: torch.FloatTensor, # (VL ... c)
45
+ txt: torch.FloatTensor, # (TL ... c)
46
+ vid_len: torch.LongTensor, # (b)
47
+ txt_len: torch.LongTensor, # (b)
48
+ ) -> torch.FloatTensor: # (L ... c)
49
+ vid = torch.split(vid, vid_len.tolist())
50
+ txt = torch.split(txt, txt_len.tolist())
51
+ return torch.cat(list(chain(*zip(vid, txt))))
52
+
53
+
54
+ def concat_idx(
55
+ vid_len: torch.LongTensor, # (b)
56
+ txt_len: torch.LongTensor, # (b)
57
+ ) -> Tuple[
58
+ Callable,
59
+ Callable,
60
+ ]:
61
+ device = vid_len.device
62
+ vid_idx = torch.arange(vid_len.sum(), device=device)
63
+ txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
64
+ tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len)
65
+ src_idx = torch.argsort(tgt_idx)
66
+ return (
67
+ lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx),
68
+ lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]),
69
+ )
70
+
71
+
72
+ def unconcat(
73
+ all: torch.FloatTensor, # (L ... c)
74
+ vid_len: torch.LongTensor, # (b)
75
+ txt_len: torch.LongTensor, # (b)
76
+ ) -> Tuple[
77
+ torch.FloatTensor, # (VL ... c)
78
+ torch.FloatTensor, # (TL ... c)
79
+ ]:
80
+ interleave_len = list(chain(*zip(vid_len.tolist(), txt_len.tolist())))
81
+ all = all.split(interleave_len)
82
+ vid = torch.cat(all[0::2])
83
+ txt = torch.cat(all[1::2])
84
+ return vid, txt
85
+
86
+
87
+ def repeat_concat(
88
+ vid: torch.FloatTensor, # (VL ... c)
89
+ txt: torch.FloatTensor, # (TL ... c)
90
+ vid_len: torch.LongTensor, # (n*b)
91
+ txt_len: torch.LongTensor, # (b)
92
+ txt_repeat: List, # (n)
93
+ ) -> torch.FloatTensor: # (L ... c)
94
+ vid = torch.split(vid, vid_len.tolist())
95
+ txt = torch.split(txt, txt_len.tolist())
96
+ txt = [[x] * n for x, n in zip(txt, txt_repeat)]
97
+ txt = list(chain(*txt))
98
+ return torch.cat(list(chain(*zip(vid, txt))))
99
+
100
+
101
+ def repeat_concat_idx(
102
+ vid_len: torch.LongTensor, # (n*b)
103
+ txt_len: torch.LongTensor, # (b)
104
+ txt_repeat: torch.LongTensor, # (n)
105
+ ) -> Tuple[
106
+ Callable,
107
+ Callable,
108
+ ]:
109
+ device = vid_len.device
110
+ vid_idx = torch.arange(vid_len.sum(), device=device)
111
+ txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
112
+ txt_repeat_list = txt_repeat.tolist()
113
+ tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat)
114
+ src_idx = torch.argsort(tgt_idx)
115
+ txt_idx_len = len(tgt_idx) - len(vid_idx)
116
+ repeat_txt_len = (txt_len * txt_repeat).tolist()
117
+
118
+ def unconcat_coalesce(all):
119
+ """
120
+ Un-concat vid & txt, and coalesce the repeated txt.
121
+ e.g. vid [0 1 2 3 4 5 6 7 8] -> 3 splits -> [0 1 2] [3 4 5] [6 7 8]
122
+ txt [9 10]
123
+ repeat_concat ==> [0 1 2 9 10 3 4 5 9 10 6 7 8 9 10]
124
+ 1. argsort re-index ==> [0 1 2 3 4 5 6 7 8 9 9 9 10 10 10]
125
+ split ==> vid_out [0 1 2 3 4 5 6 7 8] txt_out [9 9 9 10 10 10]
126
+ 2. reshape & mean for each sample to coalesce the repeated txt.
127
+ """
128
+ vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len])
129
+ txt_out_coalesced = []
130
+ for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list):
131
+ txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1)
132
+ txt_out_coalesced.append(txt)
133
+ return vid_out, torch.cat(txt_out_coalesced)
134
+
135
+     # Note: The backward pass of torch.index_select is non-deterministic when the index contains
136
+     # repeated entries, and the differences can accumulate (as with torch.repeat_interleave), so plain tensor indexing is used here instead.
137
+ return (
138
+ lambda vid, txt: torch.cat([vid, txt])[tgt_idx],
139
+ lambda all: unconcat_coalesce(all),
140
+ )
141
+
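The repeat_concat_idx helper above builds index maps for sharing one text sequence across several video chunks and for averaging the repeated copies back on the way out, as described in the unconcat_coalesce docstring. A small self-contained check that mirrors that example (illustrative only, not part of the commit; assumes the repository root is importable):
# Illustrative check of repeat_concat_idx.
import torch
from models.dit_v2.na import repeat_concat_idx

vid_len = torch.tensor([3, 3, 3])        # three video chunks of 3 tokens each
txt_len = torch.tensor([2])              # one text sequence of 2 tokens
txt_repeat = torch.tensor([3])           # the text is repeated for every chunk

vid = torch.arange(9, dtype=torch.float32).unsqueeze(-1)      # tokens 0..8
txt = torch.arange(9, 11, dtype=torch.float32).unsqueeze(-1)  # tokens 9, 10

do_concat, undo_concat = repeat_concat_idx(vid_len, txt_len, txt_repeat)
mixed = do_concat(vid, txt)          # [0 1 2 9 10 3 4 5 9 10 6 7 8 9 10]
vid_out, txt_out = undo_concat(mixed)
assert torch.equal(vid_out, vid)     # video tokens restored in order
assert torch.equal(txt_out, txt)     # repeated text copies coalesced back to [9, 10]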
142
+
143
+ def rearrange(
144
+ hid: torch.FloatTensor, # (L c)
145
+ hid_shape: torch.LongTensor, # (b n)
146
+ pattern: str,
147
+ **kwargs: Dict[str, int],
148
+ ) -> Tuple[
149
+ torch.FloatTensor,
150
+ torch.LongTensor,
151
+ ]:
152
+ return flatten([einops.rearrange(h, pattern, **kwargs) for h in unflatten(hid, hid_shape)])
153
+
154
+
155
+ def rearrange_idx(
156
+ hid_shape: torch.LongTensor, # (b n)
157
+ pattern: str,
158
+ **kwargs: Dict[str, int],
159
+ ) -> Tuple[Callable, Callable, torch.LongTensor]:
160
+ hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
161
+ tgt_idx, tgt_shape = rearrange(hid_idx, hid_shape, pattern, **kwargs)
162
+ tgt_idx = tgt_idx.squeeze(-1)
163
+ src_idx = torch.argsort(tgt_idx)
164
+ return (
165
+ lambda hid: torch.index_select(hid, 0, tgt_idx),
166
+ lambda hid: torch.index_select(hid, 0, src_idx),
167
+ tgt_shape,
168
+ )
169
+
170
+
171
+ def repeat(
172
+ hid: torch.FloatTensor, # (L c)
173
+ hid_shape: torch.LongTensor, # (b n)
174
+ pattern: str,
175
+ **kwargs: Dict[str, torch.LongTensor], # (b)
176
+ ) -> Tuple[
177
+ torch.FloatTensor,
178
+ torch.LongTensor,
179
+ ]:
180
+ hid = unflatten(hid, hid_shape)
181
+ kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
182
+ return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)])
183
+
184
+
185
+ def pack(
186
+ samples: List[torch.Tensor], # List of (h w c).
187
+ ) -> Tuple[
188
+ List[torch.Tensor], # groups [(b1 h1 w1 c1), (b2 h2 w2 c2)]
189
+     List[List[int]],  # per-group indices for restoring the original sample order.
190
+ ]:
191
+ batches = {}
192
+ indices = {}
193
+ for i, sample in enumerate(samples):
194
+ shape = sample.shape
195
+ batches[shape] = batches.get(shape, [])
196
+ indices[shape] = indices.get(shape, [])
197
+ batches[shape].append(sample)
198
+ indices[shape].append(i)
199
+
200
+ batches = list(map(torch.stack, batches.values()))
201
+ indices = list(indices.values())
202
+ return batches, indices
203
+
204
+
205
+ def unpack(
206
+ batches: List[torch.Tensor],
207
+ indices: List[List[int]],
208
+ ) -> List[torch.Tensor]:
209
+ samples = [None] * (max(chain(*indices)) + 1)
210
+ for batch, index in zip(batches, indices):
211
+ for sample, i in zip(batch.unbind(), index):
212
+ samples[i] = sample
213
+ return samples
214
+
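pack/unpack above group same-shaped samples into stacked batches and later restore the original per-sample order. A short round-trip sketch (illustrative only, not part of the commit):
# Illustrative round trip through pack/unpack.
import torch
from models.dit_v2.na import pack, unpack

samples = [torch.randn(4, 4, 8), torch.randn(2, 6, 8), torch.randn(4, 4, 8)]
batches, indices = pack(samples)     # groups of shape (2, 4, 4, 8) and (1, 2, 6, 8); indices [[0, 2], [1]]
restored = unpack(batches, indices)  # original order and shapes
assert all(torch.equal(a, b) for a, b in zip(samples, restored))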
215
+
216
+ def window(
217
+ hid: torch.FloatTensor, # (L c)
218
+ hid_shape: torch.LongTensor, # (b n)
219
+ window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
220
+ ):
221
+ hid = unflatten(hid, hid_shape)
222
+ hid = list(map(window_fn, hid))
223
+ hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device)
224
+ hid, hid_shape = flatten(list(chain(*hid)))
225
+ return hid, hid_shape, hid_windows
226
+
227
+
228
+ def window_idx(
229
+ hid_shape: torch.LongTensor, # (b n)
230
+ window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
231
+ ):
232
+ hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
233
+ tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn)
234
+ tgt_idx = tgt_idx.squeeze(-1)
235
+ src_idx = torch.argsort(tgt_idx)
236
+ return (
237
+ lambda hid: torch.index_select(hid, 0, tgt_idx),
238
+ lambda hid: torch.index_select(hid, 0, src_idx),
239
+ tgt_shape,
240
+ tgt_windows,
241
+ )
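Most of na.py operates on flattened variable-shape sequences plus a per-sample shape tensor. A minimal end-to-end sketch of the flatten/concat round trip (illustrative only, not part of the commit; the sample shapes and the assumption that the repository root is importable are mine):
# Illustrative: flatten variable-shape samples, interleave video and text per sample, then invert.
import torch
from models.dit_v2.na import flatten, unflatten, concat, unconcat

vids = [torch.randn(4, 6, 16), torch.randn(3, 5, 16)]  # per-sample (h, w, c) of different sizes
txts = [torch.randn(7, 16), torch.randn(9, 16)]        # per-sample (t, c)

vid, vid_shape = flatten(vids)   # (24 + 15, 16), shapes [[4, 6], [3, 5]]
txt, txt_shape = flatten(txts)   # (7 + 9, 16), shapes [[7], [9]]
vid_len = vid_shape.prod(-1)     # tensor([24, 15])
txt_len = txt_shape.prod(-1)     # tensor([7, 9])

mixed = concat(vid, txt, vid_len, txt_len)      # per-sample interleaving: vid0, txt0, vid1, txt1
vid2, txt2 = unconcat(mixed, vid_len, txt_len)  # exact inverse of concat
assert torch.equal(vid2, vid) and torch.equal(txt2, txt)

restored = unflatten(vid2, vid_shape)           # back to [(4, 6, 16), (3, 5, 16)]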