Carlexxx committed
Commit 0344c73 · 1 Parent(s): 05f2657

feat: Implement self-contained specialist managers

common/README.md ADDED
@@ -0,0 +1,138 @@
+ # 🛠️ helpers/ - Third-Party AI Tools Adapted for ADUC-SDR
+
+ This folder contains adapted implementations of third-party AI models and utilities, which serve as low-level "specialists" or "tools" for the ADUC-SDR architecture.
+
+ **IMPORTANT:** The contents of this folder were authored by their respective original creators and developers. This folder is **NOT PART** of the main ADUC-SDR project in terms of its innovative architecture. It serves as a repository for the **direct, modified dependencies** that the `DeformesXDEngines` (the stages of the ADUC-SDR "rocket") invoke to perform specific tasks (image, video, and audio generation).
+
+ The modifications made to the files here aim primarily at:
+ 1. **Interface Adaptation:** Standardizing the interfaces so they fit the ADUC-SDR orchestration flow.
+ 2. **Resource Management:** Integrating model load/unload logic (GPU management) and configuration via YAML files.
+ 3. **Pipeline Optimization:** Adjusting the pipelines to accept more efficient input formats (e.g., pre-encoded tensors instead of media paths, skipping redundant encode/decode steps).
+
+ ---
+
+ ## 📄 Licensing
+
+ The original content of the projects listed below is licensed under the **Apache 2.0 License**, or another license specified by the original authors. All modifications to these files, and their use within the `helpers/` structure of the ADUC-SDR project, comply with the terms of the **Apache 2.0 License**.
+
+ The projects' original licenses can be found at their respective sources or in the `incl_licenses/` subdirectories inside each adapted module.
+
+ ---
+
+ ## 🛠️ Helper APIs and Usage Guide
+
+ This section details how each helper (specialist agent) should be used within the ADUC-SDR ecosystem. All agents are instantiated as **singletons** in `hardware_manager.py` to guarantee centralized management of GPU resources.
+
+ ### **gemini_helpers.py (GeminiAgent)**
+
+ * **Purpose:** Acts as the "Adaptive Synthesis Oracle", responsible for all natural-language-processing tasks, such as storyboard creation, prompt generation, and narrative decision-making.
+ * **Singleton Instance:** `gemini_agent_singleton`
+ * **Constructor:** `GeminiAgent()`
+     * Reads `configs/gemini_config.yaml` for the model name, inference parameters, and prompt-template paths. The API key is read from the `GEMINI_API_KEY` environment variable.
+ * **Public Methods (usage sketch after this list):**
+     * `generate_storyboard(prompt: str, num_keyframes: int, ref_image_paths: list[str])`
+         * **Inputs:**
+             * `prompt`: The overall idea for the film (string).
+             * `num_keyframes`: The number of scenes to generate (int).
+             * `ref_image_paths`: List of paths to the reference images (list[str]).
+         * **Output:** `tuple[list[str], str]` (a tuple containing the storyboard as a list of strings and a textual report of the operation).
+     * `select_keyframes_from_pool(storyboard: list, base_image_paths: list[str], pool_image_paths: list[str])`
+         * **Inputs:**
+             * `storyboard`: The generated storyboard as a list of strings.
+             * `base_image_paths`: Base reference images (list[str]).
+             * `pool_image_paths`: The "image bank" to select from (list[str]).
+         * **Output:** `tuple[list[str], str]` (a tuple containing the list of selected image paths and a textual report).
+     * `get_anticipatory_keyframe_prompt(...)`
+         * **Inputs:** Narrative and visual context for generating an image prompt.
+         * **Output:** `tuple[str, str]` (a tuple containing the prompt generated for the image model and a textual report).
+     * `get_initial_motion_prompt(...)`
+         * **Inputs:** Narrative and visual context for the first video transition.
+         * **Output:** `tuple[str, str]` (a tuple containing the generated motion prompt and a textual report).
+     * `get_transition_decision(...)`
+         * **Inputs:** Narrative and visual context for an intermediate video transition.
+         * **Output:** `tuple[dict, str]` (a tuple containing a `{"transition_type": "...", "motion_prompt": "..."}` dictionary and a textual report).
+     * `generate_audio_prompts(...)`
+         * **Inputs:** Global narrative context.
+         * **Output:** `tuple[dict, str]` (a tuple containing a `{"music_prompt": "...", "sfx_prompt": "..."}` dictionary and a textual report).
+
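A minimal usage sketch of the storyboard call (hypothetical prompt and file paths; assumes the singleton is importable from `gemini_helpers` and that `GEMINI_API_KEY` is set):

```python
from gemini_helpers import gemini_agent_singleton

storyboard, report = gemini_agent_singleton.generate_storyboard(
    prompt="A lighthouse keeper discovers a glowing whale at dawn",  # film idea
    num_keyframes=4,                                                 # scenes to generate
    ref_image_paths=["refs/keeper.png", "refs/lighthouse.png"],      # visual references
)
print(report)
for i, scene in enumerate(storyboard, start=1):
    print(f"Scene {i}: {scene}")
```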
+ ### **flux_kontext_helpers.py (FluxPoolManager)**
+
+ * **Purpose:** Specialist in high-quality image (keyframe) generation using the FluxKontext pipeline. Manages a pool of workers to optimize the use of multiple GPUs.
+ * **Singleton Instance:** `flux_kontext_singleton`
+ * **Constructor:** `FluxPoolManager(device_ids: list[str], flux_config_file: str)`
+     * Reads `configs/flux_config.yaml`.
+ * **Public Method (usage sketch after this list):**
+     * `generate_image(prompt: str, reference_images: list[Image.Image], width: int, height: int, seed: int = 42, callback: callable = None)`
+         * **Inputs:**
+             * `prompt`: Text prompt guiding the generation (string).
+             * `reference_images`: List of `PIL.Image` objects used as visual references.
+             * `width`, `height`: Output image dimensions (int).
+             * `seed`: Seed for reproducibility (int).
+             * `callback`: Optional callback function for monitoring progress.
+         * **Output:** `PIL.Image.Image` (the generated image object).
+
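An illustrative call (placeholder paths, prompt, and dimensions; assumes the singleton was created by `hardware_manager.py`):

```python
from PIL import Image
from flux_kontext_helpers import flux_kontext_singleton

reference = Image.open("keyframes/scene_01.png")  # placeholder path
keyframe = flux_kontext_singleton.generate_image(
    prompt="wide shot of the lighthouse at dawn, volumetric light",
    reference_images=[reference],
    width=1024,
    height=576,
    seed=42,
)
keyframe.save("keyframes/scene_02.png")
```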
+ ### **dreamo_helpers.py (DreamOAgent)**
+
+ * **Purpose:** Specialist in high-quality image (keyframe) generation using the DreamO pipeline, with advanced reference-based editing and styling capabilities.
+ * **Singleton Instance:** `dreamo_agent_singleton`
+ * **Constructor:** `DreamOAgent(device_id: str = None)`
+     * Reads `configs/dreamo_config.yaml`.
+ * **Public Method (usage sketch after this list):**
+     * `generate_image(prompt: str, reference_images: list[Image.Image], width: int, height: int)`
+         * **Inputs:**
+             * `prompt`: Text prompt guiding the generation (string).
+             * `reference_images`: List of `PIL.Image` objects used as visual references. The internal logic assigns the first image as `style` and the remaining ones as `ip`.
+             * `width`, `height`: Output image dimensions (int).
+         * **Output:** `PIL.Image.Image` (the generated image object).
+
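Because the reference ordering matters here, a short sketch (placeholder paths and prompt):

```python
from PIL import Image
from dreamo_helpers import dreamo_agent_singleton

image = dreamo_agent_singleton.generate_image(
    prompt="the keeper walking down the spiral staircase, oil-painting look",
    reference_images=[
        Image.open("refs/style_oil.png"),   # first image -> treated as `style`
        Image.open("refs/keeper.png"),      # remaining images -> treated as `ip`
    ],
    width=1024,
    height=1024,
)
```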
+ ### **ltx_manager_helpers.py (LtxPoolManager)**
+
+ * **Purpose:** Specialist in generating video fragments in latent space using the LTX-Video pipeline. Manages a pool of workers to optimize the use of multiple GPUs.
+ * **Singleton Instance:** `ltx_manager_singleton`
+ * **Constructor:** `LtxPoolManager(device_ids: list[str], ltx_model_config_file: str, ltx_global_config_file: str)`
+     * Reads `ltx_global_config_file` and `ltx_model_config_file` to configure the pipeline.
+ * **Public Method (usage sketch after this list):**
+     * `generate_latent_fragment(**kwargs)`
+         * **Inputs:** Dictionary of keyword arguments (`kwargs`) containing all LTX pipeline parameters, including:
+             * `height`, `width`: Video dimensions (int).
+             * `video_total_frames`: Total number of frames to generate (int).
+             * `video_fps`: Frames per second (int).
+             * `motion_prompt`: Motion prompt (string).
+             * `conditioning_items_data`: List of `LatentConditioningItem` objects containing the conditioning latent tensors.
+             * `guidance_scale`, `stg_scale`, `num_inference_steps`, etc.
+         * **Output:** `tuple[torch.Tensor, tuple]` (a tuple containing the generated latent tensor and the padding values that were used).
+
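A sketch of a fragment call (all values are placeholders; `build_conditioning_items` is a hypothetical upstream step that produces the `LatentConditioningItem` list):

```python
from ltx_manager_helpers import ltx_manager_singleton

conditioning_items = build_conditioning_items()  # hypothetical: list[LatentConditioningItem]

latents, padding = ltx_manager_singleton.generate_latent_fragment(
    height=576,
    width=1024,
    video_total_frames=97,
    video_fps=24,
    motion_prompt="slow dolly-in toward the lighthouse door",
    conditioning_items_data=conditioning_items,
    guidance_scale=3.0,
    num_inference_steps=30,
)
print(latents.shape, padding)
```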
+ ### **mmaudio_helper.py (MMAudioAgent)**
+
+ * **Purpose:** Specialist in generating audio for a given video fragment.
+ * **Singleton Instance:** `mmaudio_agent_singleton`
+ * **Constructor:** `MMAudioAgent(workspace_dir: str, mmaudio_config_file: str, device_id: str = None)`
+     * Reads `configs/mmaudio_config.yaml`.
+ * **Public Method (usage sketch after this list):**
+     * `generate_audio_for_video(video_path: str, prompt: str, negative_prompt: str, duration_seconds: float)`
+         * **Inputs:**
+             * `video_path`: Path to the silent video file (string).
+             * `prompt`: Text prompt guiding the audio generation (string).
+             * `negative_prompt`: Negative prompt for the audio (string).
+             * `duration_seconds`: Exact duration of the video (float).
+         * **Output:** `str` (the path to the new video file with the audio track muxed in).
+
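And a final muxing sketch (placeholder paths and prompts):

```python
from mmaudio_helper import mmaudio_agent_singleton

final_video = mmaudio_agent_singleton.generate_audio_for_video(
    video_path="fragments/fragment_03_silent.mp4",
    prompt="crashing waves, distant foghorn, ambient wind",
    negative_prompt="music, speech",
    duration_seconds=4.0,
)
print(f"Video with audio written to: {final_video}")
```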
+ ---
+
+ ## 🔗 Original Projects and Attributions
+ (The attributions and licenses section remains the same as previously defined.)
+
+ ### DreamO
+ * **Original Repository:** [https://github.com/bytedance/DreamO](https://github.com/bytedance/DreamO)
+ ...
+
+ ### LTX-Video
+ * **Original Repository:** [https://github.com/Lightricks/LTX-Video](https://github.com/Lightricks/LTX-Video)
+ ...
+
+ ### MMAudio
+ * **Original Repository:** [https://github.com/hkchengrex/MMAudio](https://github.com/hkchengrex/MMAudio)
+ ...
+
+ ### SeedVR2-3B
+ * **Original Space:** [https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B/tree/main](https://huggingface.co/spaces/ByteDance-Seed/SeedVR2-3B/tree/main)
common/__init__.py ADDED
File without changes
common/cache.py ADDED
@@ -0,0 +1,47 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ from typing import Callable
+
+
+ class Cache:
+     """Caching reusable args for faster inference"""
+
+     def __init__(self, disable=False, prefix="", cache=None):
+         self.cache = cache if cache is not None else {}
+         self.disable = disable
+         self.prefix = prefix
+
+     def __call__(self, key: str, fn: Callable):
+         if self.disable:
+             return fn()
+
+         key = self.prefix + key
+         try:
+             result = self.cache[key]
+         except KeyError:
+             result = fn()
+             self.cache[key] = result
+         return result
+
+     def namespace(self, namespace: str):
+         return Cache(
+             disable=self.disable,
+             prefix=self.prefix + namespace + ".",
+             cache=self.cache,
+         )
+
+     def get(self, key: str):
+         key = self.prefix + key
+         return self.cache[key]
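A quick usage sketch of the class above (standalone, illustrative names): the factory callable only runs on a cache miss, and `namespace()` returns a view over the same dict with a key prefix.

```python
cache = Cache()

calls = []
def encode(text):
    calls.append(text)                 # counts how often we actually compute
    return f"embedding({text})"

v1 = cache("prompt", lambda: encode("hello"))   # miss: computes and stores
v2 = cache("prompt", lambda: encode("hello"))   # hit: returns the stored value
assert v1 == v2 and len(calls) == 1

vae_cache = cache.namespace("vae")              # same dict, "vae." key prefix
vae_cache("first_frame", lambda: encode("frame0"))
assert cache.get("vae.first_frame") == "embedding(frame0)"
```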
common/config.py ADDED
@@ -0,0 +1,110 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Configuration utility functions
+ """
+
+ import importlib
+ from typing import Any, Callable, List, Union
+ from omegaconf import DictConfig, ListConfig, OmegaConf
+
+ OmegaConf.register_new_resolver("eval", eval)
+
+
+ def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
+     """
+     Load a configuration. Will resolve inheritance.
+     """
+     config = OmegaConf.load(path)
+     if argv is not None:
+         config_argv = OmegaConf.from_dotlist(argv)
+         config = OmegaConf.merge(config, config_argv)
+     config = resolve_recursive(config, resolve_inheritance)
+     return config
+
+
+ def resolve_recursive(
+     config: Any,
+     resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
+ ) -> Any:
+     config = resolver(config)
+     if isinstance(config, DictConfig):
+         for k in config.keys():
+             v = config.get(k)
+             if isinstance(v, (DictConfig, ListConfig)):
+                 config[k] = resolve_recursive(v, resolver)
+     if isinstance(config, ListConfig):
+         for i in range(len(config)):
+             v = config.get(i)
+             if isinstance(v, (DictConfig, ListConfig)):
+                 config[i] = resolve_recursive(v, resolver)
+     return config
+
+
+ def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
+     """
+     Recursively resolve inheritance if the config contains:
+     __inherit__: path/to/parent.yaml or a ListConfig of such paths.
+     """
+     if isinstance(config, DictConfig):
+         inherit = config.pop("__inherit__", None)
+
+         if inherit:
+             inherit_list = inherit if isinstance(inherit, ListConfig) else [inherit]
+
+             parent_config = None
+             for parent_path in inherit_list:
+                 assert isinstance(parent_path, str)
+                 parent_config = (
+                     load_config(parent_path)
+                     if parent_config is None
+                     else OmegaConf.merge(parent_config, load_config(parent_path))
+                 )
+
+             if len(config.keys()) > 0:
+                 config = OmegaConf.merge(parent_config, config)
+             else:
+                 config = parent_config
+     return config
+
+
+ def import_item(path: str, name: str) -> Any:
+     """
+     Import a python item. Example: import_item("path.to.file", "MyClass") -> MyClass
+     """
+     return getattr(importlib.import_module(path), name)
+
+
+ def create_object(config: DictConfig) -> Any:
+     """
+     Create an object from config.
+     The config is expected to contain the following:
+     __object__:
+         path: path.to.module
+         name: MyClass
+         args: as_config | as_params (default to as_config)
+     """
+     item = import_item(
+         path=config.__object__.path,
+         name=config.__object__.name,
+     )
+     args = config.__object__.get("args", "as_config")
+     if args == "as_config":
+         return item(config)
+     if args == "as_params":
+         config = OmegaConf.to_object(config)
+         config.pop("__object__")
+         return item(**config)
+     raise NotImplementedError(f"Unknown args type: {args}")
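A runnable sketch of `create_object` with `args: as_params`, using a standard-library class as the target so no project modules are needed; the equivalent YAML appears in the comment. (`__inherit__` works the same way via `load_config` on YAML files on disk.)

```python
from omegaconf import OmegaConf
from common.config import create_object

# Equivalent YAML:
#   __object__:
#     path: collections
#     name: OrderedDict
#     args: as_params
#   a: 1
#   b: 2
config = OmegaConf.create({
    "__object__": {"path": "collections", "name": "OrderedDict", "args": "as_params"},
    "a": 1,
    "b": 2,
})
obj = create_object(config)   # imports collections.OrderedDict, calls it with a=1, b=2
assert obj == {"a": 1, "b": 2}
```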
common/decorators.py ADDED
@@ -0,0 +1,147 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Decorators.
+ """
+
+ import functools
+ import threading
+ import time
+ from typing import Callable
+ import torch
+
+ from common.distributed import barrier_if_distributed, get_global_rank, get_local_rank
+ from common.logger import get_logger
+
+ logger = get_logger(__name__)
+
+
+ def log_on_entry(func: Callable) -> Callable:
+     """
+     Functions with this decorator will log the function name at entry.
+     When using multiple decorators, this must be applied innermost to properly capture the name.
+     """
+
+     def log_on_entry_wrapper(*args, **kwargs):
+         logger.info(f"Entering {func.__name__}")
+         return func(*args, **kwargs)
+
+     return log_on_entry_wrapper
+
+
+ def barrier_on_entry(func: Callable) -> Callable:
+     """
+     Functions with this decorator will start executing when all ranks are ready to enter.
+     """
+
+     def barrier_on_entry_wrapper(*args, **kwargs):
+         barrier_if_distributed()
+         return func(*args, **kwargs)
+
+     return barrier_on_entry_wrapper
+
+
+ def _conditional_execute_wrapper_factory(execute: bool, func: Callable) -> Callable:
+     """
+     Helper function for local_rank_zero_only and global_rank_zero_only.
+     """
+
+     def conditional_execute_wrapper(*args, **kwargs):
+         # Only execute if needed.
+         result = func(*args, **kwargs) if execute else None
+         # All GPUs must wait.
+         barrier_if_distributed()
+         # Return results.
+         return result
+
+     return conditional_execute_wrapper
+
+
+ def _asserted_wrapper_factory(condition: bool, func: Callable, err_msg: str = "") -> Callable:
+     """
+     Helper function for some functions with special constraints,
+     especially functions called by other global_rank_zero_only / local_rank_zero_only ones,
+     in case they are wrongly invoked in other scenarios.
+     """
+
+     def asserted_execute_wrapper(*args, **kwargs):
+         assert condition, err_msg
+         result = func(*args, **kwargs)
+         return result
+
+     return asserted_execute_wrapper
+
+
+ def local_rank_zero_only(func: Callable) -> Callable:
+     """
+     Functions with this decorator will only execute on local rank zero.
+     """
+     return _conditional_execute_wrapper_factory(get_local_rank() == 0, func)
+
+
+ def global_rank_zero_only(func: Callable) -> Callable:
+     """
+     Functions with this decorator will only execute on global rank zero.
+     """
+     return _conditional_execute_wrapper_factory(get_global_rank() == 0, func)
+
+
+ def assert_only_global_rank_zero(func: Callable) -> Callable:
+     """
+     Functions with this decorator are only accessible to processes with global rank zero.
+     """
+     return _asserted_wrapper_factory(
+         get_global_rank() == 0, func, err_msg="Not accessible to processes with global_rank != 0"
+     )
+
+
+ def assert_only_local_rank_zero(func: Callable) -> Callable:
+     """
+     Functions with this decorator are only accessible to processes with local rank zero.
+     """
+     return _asserted_wrapper_factory(
+         get_local_rank() == 0, func, err_msg="Not accessible to processes with local_rank != 0"
+     )
+
+
+ def new_thread(func: Callable) -> Callable:
+     """
+     Functions with this decorator will run in a new thread.
+     The function will return the thread, which can be joined to wait for completion.
+     """
+
+     def new_thread_wrapper(*args, **kwargs):
+         thread = threading.Thread(target=func, args=args, kwargs=kwargs)
+         thread.start()
+         return thread
+
+     return new_thread_wrapper
+
+
+ def log_runtime(func: Callable) -> Callable:
+     """
+     Functions with this decorator will log the runtime.
+     """
+
+     @functools.wraps(func)
+     def wrapped(*args, **kwargs):
+         torch.distributed.barrier()
+         start = time.perf_counter()
+         result = func(*args, **kwargs)
+         torch.distributed.barrier()
+         logger.info(f"Completed {func.__name__} in {time.perf_counter() - start:.3f} seconds.")
+         return result
+
+     return wrapped
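A small sketch combining two of the decorators above (single-process case, where the barrier helpers are no-ops because `torch.distributed` is not initialized):

```python
from common.decorators import log_on_entry, new_thread

@new_thread
@log_on_entry          # innermost, so the real function name is logged
def preload_model(name: str):
    print(f"{name} ready")

thread = preload_model("flux")   # returns immediately with the running thread
thread.join()                    # block until the background work completes
```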
common/diffusion/__init__.py ADDED
@@ -0,0 +1,56 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Diffusion package.
+ """
+
+ from .config import (
+     create_sampler_from_config,
+     create_sampling_timesteps_from_config,
+     create_schedule_from_config,
+ )
+ from .samplers.base import Sampler
+ from .samplers.euler import EulerSampler
+ from .schedules.base import Schedule
+ from .schedules.lerp import LinearInterpolationSchedule
+ from .timesteps.base import SamplingTimesteps, Timesteps
+ from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps
+ from .types import PredictionType, SamplingDirection
+ from .utils import classifier_free_guidance, classifier_free_guidance_dispatcher, expand_dims
+
+ __all__ = [
+     # Configs
+     "create_sampler_from_config",
+     "create_sampling_timesteps_from_config",
+     "create_schedule_from_config",
+     # Schedules
+     "Schedule",
+     "LinearInterpolationSchedule",
+     # Samplers
+     "Sampler",
+     "EulerSampler",
+     # Timesteps
+     "Timesteps",
+     "SamplingTimesteps",
+     "UniformTrailingSamplingTimesteps",
+     # Types
+     "PredictionType",
+     "SamplingDirection",
+     # Utils
+     "classifier_free_guidance",
+     "classifier_free_guidance_dispatcher",
+     "expand_dims",
+ ]
common/diffusion/config.py ADDED
@@ -0,0 +1,74 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Utility functions for creating schedules and samplers from config.
+ """
+
+ import torch
+ from omegaconf import DictConfig
+
+ from .samplers.base import Sampler
+ from .samplers.euler import EulerSampler
+ from .schedules.base import Schedule
+ from .schedules.lerp import LinearInterpolationSchedule
+ from .timesteps.base import SamplingTimesteps
+ from .timesteps.sampling.trailing import UniformTrailingSamplingTimesteps
+
+
+ def create_schedule_from_config(
+     config: DictConfig,
+     device: torch.device,
+     dtype: torch.dtype = torch.float32,
+ ) -> Schedule:
+     """
+     Create a schedule from configuration.
+     """
+     if config.type == "lerp":
+         return LinearInterpolationSchedule(T=config.get("T", 1.0))
+
+     raise NotImplementedError
+
+
+ def create_sampler_from_config(
+     config: DictConfig,
+     schedule: Schedule,
+     timesteps: SamplingTimesteps,
+ ) -> Sampler:
+     """
+     Create a sampler from configuration.
+     """
+     if config.type == "euler":
+         return EulerSampler(
+             schedule=schedule,
+             timesteps=timesteps,
+             prediction_type=config.prediction_type,
+         )
+     raise NotImplementedError
+
+
+ def create_sampling_timesteps_from_config(
+     config: DictConfig,
+     schedule: Schedule,
+     device: torch.device,
+     dtype: torch.dtype = torch.float32,
+ ) -> SamplingTimesteps:
+     if config.type == "uniform_trailing":
+         return UniformTrailingSamplingTimesteps(
+             T=schedule.T,
+             steps=config.steps,
+             shift=config.get("shift", 1.0),
+             device=device,
+         )
+     raise NotImplementedError
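A sketch wiring the three factories together from inline configs (a YAML file loaded with `OmegaConf.load` would work identically):

```python
import torch
from omegaconf import OmegaConf
from common.diffusion import (
    create_sampler_from_config,
    create_sampling_timesteps_from_config,
    create_schedule_from_config,
)

device = torch.device("cpu")
schedule = create_schedule_from_config(OmegaConf.create({"type": "lerp"}), device=device)
timesteps = create_sampling_timesteps_from_config(
    OmegaConf.create({"type": "uniform_trailing", "steps": 50, "shift": 3.0}),
    schedule=schedule,
    device=device,
)
sampler = create_sampler_from_config(
    OmegaConf.create({"type": "euler", "prediction_type": "v_lerp"}),
    schedule=schedule,
    timesteps=timesteps,
)
```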
common/diffusion/samplers/base.py ADDED
@@ -0,0 +1,108 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Sampler base class.
+ """
+
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Callable
+ import torch
+ from tqdm import tqdm
+
+ from ..schedules.base import Schedule
+ from ..timesteps.base import SamplingTimesteps
+ from ..types import PredictionType, SamplingDirection
+ from ..utils import assert_schedule_timesteps_compatible
+
+
+ @dataclass
+ class SamplerModelArgs:
+     x_t: torch.Tensor
+     t: torch.Tensor
+     i: int
+
+
+ class Sampler(ABC):
+     """
+     Samplers are ODE/SDE solvers.
+     """
+
+     def __init__(
+         self,
+         schedule: Schedule,
+         timesteps: SamplingTimesteps,
+         prediction_type: PredictionType,
+         return_endpoint: bool = True,
+     ):
+         assert_schedule_timesteps_compatible(
+             schedule=schedule,
+             timesteps=timesteps,
+         )
+         self.schedule = schedule
+         self.timesteps = timesteps
+         self.prediction_type = prediction_type
+         self.return_endpoint = return_endpoint
+
+     @abstractmethod
+     def sample(
+         self,
+         x: torch.Tensor,
+         f: Callable[[SamplerModelArgs], torch.Tensor],
+     ) -> torch.Tensor:
+         """
+         Generate a new sample given the initial sample x and score function f.
+         """
+
+     def get_next_timestep(
+         self,
+         t: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Get the next sample timestep.
+         Support multiple different timesteps t in a batch.
+         If no more steps, return out of bound value -1 or T+1.
+         """
+         T = self.timesteps.T
+         steps = len(self.timesteps)
+         curr_idx = self.timesteps.index(t)
+         next_idx = curr_idx + 1
+         bound = -1 if self.timesteps.direction == SamplingDirection.backward else T + 1
+
+         s = self.timesteps[next_idx.clamp_max(steps - 1)]
+         s = s.where(next_idx < steps, bound)
+         return s
+
+     def get_endpoint(
+         self,
+         pred: torch.Tensor,
+         x_t: torch.Tensor,
+         t: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Get to the endpoint of the probability flow.
+         """
+         x_0, x_T = self.schedule.convert_from_pred(pred, self.prediction_type, x_t, t)
+         return x_0 if self.timesteps.direction == SamplingDirection.backward else x_T
+
+     def get_progress_bar(self):
+         """
+         Get progress bar for sampling.
+         """
+         return tqdm(
+             iterable=range(len(self.timesteps) - (0 if self.return_endpoint else 1)),
+             dynamic_ncols=True,
+             desc=self.__class__.__name__,
+         )
common/diffusion/samplers/euler.py ADDED
@@ -0,0 +1,89 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+
+ """
+ Euler ODE solver.
+ """
+
+ from typing import Callable
+ import torch
+ from einops import rearrange
+ from torch.nn import functional as F
+
+ from models.dit_v2 import na
+
+ from ..types import PredictionType
+ from ..utils import expand_dims
+ from .base import Sampler, SamplerModelArgs
+
+
+ class EulerSampler(Sampler):
+     """
+     The Euler method is the simplest ODE solver.
+     <https://en.wikipedia.org/wiki/Euler_method>
+     """
+
+     def sample(
+         self,
+         x: torch.Tensor,
+         f: Callable[[SamplerModelArgs], torch.Tensor],
+     ) -> torch.Tensor:
+         timesteps = self.timesteps.timesteps
+         progress = self.get_progress_bar()
+         i = 0
+         for t, s in zip(timesteps[:-1], timesteps[1:]):
+             pred = f(SamplerModelArgs(x, t, i))
+             x = self.step_to(pred, x, t, s)
+             i += 1
+             progress.update()
+
+         if self.return_endpoint:
+             t = timesteps[-1]
+             pred = f(SamplerModelArgs(x, t, i))
+             x = self.get_endpoint(pred, x, t)
+             progress.update()
+         return x
+
+     def step(
+         self,
+         pred: torch.Tensor,
+         x_t: torch.Tensor,
+         t: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Step to the next timestep.
+         """
+         return self.step_to(pred, x_t, t, self.get_next_timestep(t))
+
+     def step_to(
+         self,
+         pred: torch.Tensor,
+         x_t: torch.Tensor,
+         t: torch.Tensor,
+         s: torch.Tensor,
+     ) -> torch.Tensor:
+         """
+         Steps from x_t at timestep t to x_s at timestep s. Returns x_s.
+         """
+         t = expand_dims(t, x_t.ndim)
+         s = expand_dims(s, x_t.ndim)
+         T = self.schedule.T
+         # Step from x_t to x_s.
+         pred_x_0, pred_x_T = self.schedule.convert_from_pred(pred, self.prediction_type, x_t, t)
+         pred_x_s = self.schedule.forward(pred_x_0, pred_x_T, s.clamp(0, T))
+         # Clamp x_s to x_0 and x_T if s is out of bound.
+         pred_x_s = pred_x_s.where(s >= 0, pred_x_0)
+         pred_x_s = pred_x_s.where(s <= T, pred_x_T)
+         return pred_x_s
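To make the control flow concrete, here is a toy end-to-end run with a stand-in for the denoiser (a real model would be called inside `f`):

```python
import torch
from common.diffusion import (
    EulerSampler,
    LinearInterpolationSchedule,
    UniformTrailingSamplingTimesteps,
)
from common.diffusion.samplers.base import SamplerModelArgs

schedule = LinearInterpolationSchedule(T=1.0)
timesteps = UniformTrailingSamplingTimesteps(T=1.0, steps=8)
sampler = EulerSampler(schedule, timesteps, prediction_type="v_lerp")

x_T = torch.randn(1, 4, 8, 8)   # start from pure noise

def f(args: SamplerModelArgs) -> torch.Tensor:
    # Stand-in velocity prediction (x_T - x_0); a real network goes here.
    return torch.zeros_like(args.x_t)

x_0 = sampler.sample(x=x_T, f=f)   # integrates backward from t=1 to the endpoint
```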
common/diffusion/schedules/base.py ADDED
@@ -0,0 +1,131 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Schedule base class.
+ """
+
+ from abc import ABC, abstractmethod, abstractproperty
+ from typing import Tuple, Union
+ import torch
+
+ from ..types import PredictionType
+ from ..utils import expand_dims
+
+
+ class Schedule(ABC):
+     """
+     Diffusion schedules are uniquely defined by T, A, B:
+
+         x_t = A(t) * x_0 + B(t) * x_T, where t in [0, T]
+
+     Schedules can be continuous or discrete.
+     """
+
+     @abstractproperty
+     def T(self) -> Union[int, float]:
+         """
+         Maximum timestep inclusive.
+         Schedule is continuous if float, discrete if int.
+         """
+
+     @abstractmethod
+     def A(self, t: torch.Tensor) -> torch.Tensor:
+         """
+         Interpolation coefficient A.
+         Returns tensor with the same shape as t.
+         """
+
+     @abstractmethod
+     def B(self, t: torch.Tensor) -> torch.Tensor:
+         """
+         Interpolation coefficient B.
+         Returns tensor with the same shape as t.
+         """
+
+     # ----------------------------------------------------
+
+     def snr(self, t: torch.Tensor) -> torch.Tensor:
+         """
+         Signal to noise ratio.
+         Returns tensor with the same shape as t.
+         """
+         return (self.A(t) ** 2) / (self.B(t) ** 2)
+
+     def isnr(self, snr: torch.Tensor) -> torch.Tensor:
+         """
+         Inverse signal to noise ratio.
+         Returns tensor with the same shape as snr.
+         Subclass may implement.
+         """
+         raise NotImplementedError
+
+     # ----------------------------------------------------
+
+     def is_continuous(self) -> bool:
+         """
+         Whether the schedule is continuous.
+         """
+         return isinstance(self.T, float)
+
+     def forward(self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+         """
+         Diffusion forward function.
+         """
+         t = expand_dims(t, x_0.ndim)
+         return self.A(t) * x_0 + self.B(t) * x_T
+
+     def convert_from_pred(
+         self, pred: torch.Tensor, pred_type: PredictionType, x_t: torch.Tensor, t: torch.Tensor
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Convert from prediction. Return predicted x_0 and x_T.
+         """
+         t = expand_dims(t, x_t.ndim)
+         A_t = self.A(t)
+         B_t = self.B(t)
+
+         if pred_type == PredictionType.x_T:
+             pred_x_T = pred
+             pred_x_0 = (x_t - B_t * pred_x_T) / A_t
+         elif pred_type == PredictionType.x_0:
+             pred_x_0 = pred
+             pred_x_T = (x_t - A_t * pred_x_0) / B_t
+         elif pred_type == PredictionType.v_cos:
+             pred_x_0 = A_t * x_t - B_t * pred
+             pred_x_T = A_t * pred + B_t * x_t
+         elif pred_type == PredictionType.v_lerp:
+             pred_x_0 = (x_t - B_t * pred) / (A_t + B_t)
+             pred_x_T = (x_t + A_t * pred) / (A_t + B_t)
+         else:
+             raise NotImplementedError
+
+         return pred_x_0, pred_x_T
+
+     def convert_to_pred(
+         self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor, pred_type: PredictionType
+     ) -> torch.FloatTensor:
+         """
+         Convert to prediction target given x_0 and x_T.
+         """
+         if pred_type == PredictionType.x_T:
+             return x_T
+         if pred_type == PredictionType.x_0:
+             return x_0
+         if pred_type == PredictionType.v_cos:
+             t = expand_dims(t, x_0.ndim)
+             return self.A(t) * x_T - self.B(t) * x_0
+         if pred_type == PredictionType.v_lerp:
+             return x_T - x_0
+         raise NotImplementedError
common/diffusion/schedules/lerp.py ADDED
@@ -0,0 +1,55 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Linear interpolation schedule (lerp).
+ """
+
+ from typing import Union
+ import torch
+
+ from .base import Schedule
+
+
+ class LinearInterpolationSchedule(Schedule):
+     """
+     Linear interpolation schedule (lerp) is proposed by flow matching and rectified flow.
+     It leads to straighter probability flow theoretically. It is also used by Stable Diffusion 3.
+     <https://arxiv.org/abs/2209.03003>
+     <https://arxiv.org/abs/2210.02747>
+
+         x_t = (1 - t) * x_0 + t * x_T
+
+     Can be either continuous or discrete.
+     """
+
+     def __init__(self, T: Union[int, float] = 1.0):
+         self._T = T
+
+     @property
+     def T(self) -> Union[int, float]:
+         return self._T
+
+     def A(self, t: torch.Tensor) -> torch.Tensor:
+         return 1 - (t / self.T)
+
+     def B(self, t: torch.Tensor) -> torch.Tensor:
+         return t / self.T
+
+     # ----------------------------------------------------
+
+     def isnr(self, snr: torch.Tensor) -> torch.Tensor:
+         t = self.T / (1 + snr**0.5)
+         t = t if self.is_continuous() else t.round().int()
+         return t
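A quick numeric check of the identities above: under the lerp schedule, `convert_from_pred` inverts `convert_to_pred` for the prediction types consistent with this schedule (`v_cos` is excluded, since it assumes a cosine schedule where A(t)^2 + B(t)^2 = 1):

```python
import torch
from common.diffusion.schedules.lerp import LinearInterpolationSchedule
from common.diffusion.types import PredictionType

schedule = LinearInterpolationSchedule(T=1.0)
x_0, x_T = torch.randn(2, 3), torch.randn(2, 3)
t = torch.tensor([0.25, 0.75])

x_t = schedule.forward(x_0, x_T, t)   # (1 - t) * x_0 + t * x_T
for pred_type in (PredictionType.x_0, PredictionType.x_T, PredictionType.v_lerp):
    pred = schedule.convert_to_pred(x_0, x_T, t, pred_type)
    rec_x_0, rec_x_T = schedule.convert_from_pred(pred, pred_type, x_t, t)
    assert torch.allclose(rec_x_0, x_0, atol=1e-5)
    assert torch.allclose(rec_x_T, x_T, atol=1e-5)
```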
common/diffusion/timesteps/base.py ADDED
@@ -0,0 +1,72 @@
+ from abc import ABC, abstractmethod
+ from typing import Sequence, Union
+ import torch
+
+ from ..types import SamplingDirection
+
+
+ class Timesteps(ABC):
+     """
+     Timesteps base class.
+     """
+
+     def __init__(self, T: Union[int, float]):
+         assert T > 0
+         self._T = T
+
+     @property
+     def T(self) -> Union[int, float]:
+         """
+         Maximum timestep inclusive.
+         int if discrete, float if continuous.
+         """
+         return self._T
+
+     def is_continuous(self) -> bool:
+         """
+         Whether the schedule is continuous.
+         """
+         return isinstance(self.T, float)
+
+
+ class SamplingTimesteps(Timesteps):
+     """
+     Sampling timesteps.
+     It defines the discretization of sampling steps.
+     """
+
+     def __init__(
+         self,
+         T: Union[int, float],
+         timesteps: torch.Tensor,
+         direction: SamplingDirection,
+     ):
+         assert timesteps.ndim == 1
+         super().__init__(T)
+         self.timesteps = timesteps
+         self.direction = direction
+
+     def __len__(self) -> int:
+         """
+         Number of sampling steps.
+         """
+         return len(self.timesteps)
+
+     def __getitem__(self, idx: Union[int, torch.IntTensor]) -> torch.Tensor:
+         """
+         The timestep at the sampling step.
+         Returns a scalar tensor if idx is int,
+         or tensor of the same size if idx is a tensor.
+         """
+         return self.timesteps[idx]
+
+     def index(self, t: torch.Tensor) -> torch.Tensor:
+         """
+         Find index by t.
+         Return index of the same shape as t.
+         Index is -1 if t not found in timesteps.
+         """
+         i, j = t.reshape(-1, 1).eq(self.timesteps).nonzero(as_tuple=True)
+         idx = torch.full_like(t, fill_value=-1, dtype=torch.int)
+         idx.view(-1)[i] = j.int()
+         return idx
common/diffusion/timesteps/sampling/trailing.py ADDED
@@ -0,0 +1,49 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ from typing import Union
+
+ import torch
+
+ from ...types import SamplingDirection
+ from ..base import SamplingTimesteps
+
+
+ class UniformTrailingSamplingTimesteps(SamplingTimesteps):
+     """
+     Uniform trailing sampling timesteps.
+     Defined in (https://arxiv.org/abs/2305.08891)
+
+     Shift is proposed in SD3 for RF schedule.
+     Defined in (https://arxiv.org/pdf/2403.03206) eq.23
+     """
+
+     def __init__(
+         self,
+         T: Union[int, float],
+         steps: int,
+         shift: float = 1.0,
+         device: torch.device = "cpu",
+     ):
+         # Create trailing timesteps.
+         timesteps = torch.arange(1.0, 0.0, -1.0 / steps, device=device)
+
+         # Shift timesteps.
+         timesteps = shift * timesteps / (1 + (shift - 1) * timesteps)
+
+         # Scale to T range.
+         if isinstance(T, float):
+             timesteps = timesteps * T
+         else:
+             timesteps = timesteps.mul(T + 1).sub(1).round().int()
+
+         super().__init__(T=T, timesteps=timesteps, direction=SamplingDirection.backward)
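For intuition, with `steps=4` the trailing grid includes T and excludes 0, and `shift > 1` warps the same grid toward high noise (the SD3 eq. 23 behavior):

```python
import torch
from common.diffusion.timesteps.sampling.trailing import UniformTrailingSamplingTimesteps

ts = UniformTrailingSamplingTimesteps(T=1.0, steps=4)
print(ts.timesteps)           # tensor([1.0000, 0.7500, 0.5000, 0.2500])

ts_shifted = UniformTrailingSamplingTimesteps(T=1.0, steps=4, shift=3.0)
print(ts_shifted.timesteps)   # tensor([1.0000, 0.9000, 0.7500, 0.5000])

# index() maps timestep values back to grid positions (-1 when absent).
print(ts.index(torch.tensor([0.75, 0.33])))   # tensor([ 1, -1], dtype=torch.int32)
```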
common/diffusion/types.py ADDED
@@ -0,0 +1,59 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Type definitions.
+ """
+
+ from enum import Enum
+
+
+ class PredictionType(str, Enum):
+     """
+     x_0:
+         Predict data sample.
+     x_T:
+         Predict noise sample.
+         Proposed by DDPM (https://arxiv.org/abs/2006.11239)
+         Proved problematic by zsnr paper (https://arxiv.org/abs/2305.08891)
+     v_cos:
+         Predict velocity dx/dt based on the cosine schedule (A_t * x_T - B_t * x_0).
+         Proposed by progressive distillation (https://arxiv.org/abs/2202.00512)
+     v_lerp:
+         Predict velocity dx/dt based on the lerp schedule (x_T - x_0).
+         Proposed by rectified flow (https://arxiv.org/abs/2209.03003)
+     """
+
+     x_0 = "x_0"
+     x_T = "x_T"
+     v_cos = "v_cos"
+     v_lerp = "v_lerp"
+
+
+ class SamplingDirection(str, Enum):
+     """
+     backward: Sample from x_T to x_0 for data generation.
+     forward: Sample from x_0 to x_T for noise inversion.
+     """
+
+     backward = "backward"
+     forward = "forward"
+
+     @staticmethod
+     def reverse(direction):
+         if direction == SamplingDirection.backward:
+             return SamplingDirection.forward
+         if direction == SamplingDirection.forward:
+             return SamplingDirection.backward
+         raise NotImplementedError
common/diffusion/utils.py ADDED
@@ -0,0 +1,84 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Utility functions.
+ """
+
+ from typing import Callable
+ import torch
+
+
+ def expand_dims(tensor: torch.Tensor, ndim: int):
+     """
+     Expand tensor to target ndim. New dims are added to the right.
+     For example, if the tensor shape was (8,), target ndim is 4, return (8, 1, 1, 1).
+     """
+     shape = tensor.shape + (1,) * (ndim - tensor.ndim)
+     return tensor.reshape(shape)
+
+
+ def assert_schedule_timesteps_compatible(schedule, timesteps):
+     """
+     Check if schedule and timesteps are compatible.
+     """
+     if schedule.T != timesteps.T:
+         raise ValueError("Schedule and timesteps must have the same T.")
+     if schedule.is_continuous() != timesteps.is_continuous():
+         raise ValueError("Schedule and timesteps must have the same continuity.")
+
+
+ def classifier_free_guidance(
+     pos: torch.Tensor,
+     neg: torch.Tensor,
+     scale: float,
+     rescale: float = 0.0,
+ ):
+     """
+     Apply classifier-free guidance.
+     """
+     # Classifier-free guidance (https://arxiv.org/abs/2207.12598)
+     cfg = neg + scale * (pos - neg)
+
+     # Classifier-free guidance rescale (https://arxiv.org/pdf/2305.08891.pdf)
+     if rescale != 0.0:
+         pos_std = pos.std(dim=list(range(1, pos.ndim)), keepdim=True)
+         cfg_std = cfg.std(dim=list(range(1, cfg.ndim)), keepdim=True)
+         factor = pos_std / cfg_std
+         factor = rescale * factor + (1 - rescale)
+         cfg *= factor
+
+     return cfg
+
+
+ def classifier_free_guidance_dispatcher(
+     pos: Callable,
+     neg: Callable,
+     scale: float,
+     rescale: float = 0.0,
+ ):
+     """
+     Optionally execute models depending on classifier-free guidance scale.
+     """
+     # If scale is 1, no need to execute neg model.
+     if scale == 1.0:
+         return pos()
+
+     # Otherwise, execute both pos and neg models and apply cfg.
+     return classifier_free_guidance(
+         pos=pos(),
+         neg=neg(),
+         scale=scale,
+         rescale=rescale,
+     )
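A small sketch of the dispatcher's short-circuit behavior (random tensors as stand-ins for model outputs):

```python
import torch
from common.diffusion.utils import classifier_free_guidance_dispatcher

pos_out, neg_out = torch.randn(2, 4), torch.randn(2, 4)

def neg_never():
    raise AssertionError("the negative branch must not run when scale == 1.0")

# scale == 1.0 short-circuits: only the positive branch is evaluated.
out = classifier_free_guidance_dispatcher(pos=lambda: pos_out, neg=neg_never, scale=1.0)
assert torch.equal(out, pos_out)

# scale > 1 runs both branches and applies neg + scale * (pos - neg).
out = classifier_free_guidance_dispatcher(
    pos=lambda: pos_out, neg=lambda: neg_out, scale=7.5, rescale=0.7
)
```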
common/distributed/__init__.py ADDED
@@ -0,0 +1,37 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Distributed package.
+ """
+
+ from .basic import (
+     barrier_if_distributed,
+     convert_to_ddp,
+     get_device,
+     get_global_rank,
+     get_local_rank,
+     get_world_size,
+     init_torch,
+ )
+
+ __all__ = [
+     "barrier_if_distributed",
+     "convert_to_ddp",
+     "get_device",
+     "get_global_rank",
+     "get_local_rank",
+     "get_world_size",
+     "init_torch",
+ ]
common/distributed/advanced.py ADDED
@@ -0,0 +1,208 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Advanced distributed functions for sequence parallel.
+ """
+
+ from typing import Optional, List
+ import torch
+ import torch.distributed as dist
+ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
+ from torch.distributed.fsdp import ShardingStrategy
+
+ from .basic import get_global_rank, get_world_size
+
+
+ _DATA_PARALLEL_GROUP = None
+ _SEQUENCE_PARALLEL_GROUP = None
+ _SEQUENCE_PARALLEL_CPU_GROUP = None
+ _MODEL_SHARD_CPU_INTER_GROUP = None
+ _MODEL_SHARD_CPU_INTRA_GROUP = None
+ _MODEL_SHARD_INTER_GROUP = None
+ _MODEL_SHARD_INTRA_GROUP = None
+ _SEQUENCE_PARALLEL_GLOBAL_RANKS = None
+
+
+ def get_data_parallel_group() -> Optional[dist.ProcessGroup]:
+     """
+     Get data parallel process group.
+     """
+     return _DATA_PARALLEL_GROUP
+
+
+ def get_sequence_parallel_group() -> Optional[dist.ProcessGroup]:
+     """
+     Get sequence parallel process group.
+     """
+     return _SEQUENCE_PARALLEL_GROUP
+
+
+ def get_sequence_parallel_cpu_group() -> Optional[dist.ProcessGroup]:
+     """
+     Get sequence parallel CPU process group.
+     """
+     return _SEQUENCE_PARALLEL_CPU_GROUP
+
+
+ def get_data_parallel_rank() -> int:
+     """
+     Get data parallel rank.
+     """
+     group = get_data_parallel_group()
+     return dist.get_rank(group) if group else get_global_rank()
+
+
+ def get_data_parallel_world_size() -> int:
+     """
+     Get data parallel world size.
+     """
+     group = get_data_parallel_group()
+     return dist.get_world_size(group) if group else get_world_size()
+
+
+ def get_sequence_parallel_rank() -> int:
+     """
+     Get sequence parallel rank.
+     """
+     group = get_sequence_parallel_group()
+     return dist.get_rank(group) if group else 0
+
+
+ def get_sequence_parallel_world_size() -> int:
+     """
+     Get sequence parallel world size.
+     """
+     group = get_sequence_parallel_group()
+     return dist.get_world_size(group) if group else 1
+
+
+ def get_model_shard_cpu_intra_group() -> Optional[dist.ProcessGroup]:
+     """
+     Get the CPU intra process group of model sharding.
+     """
+     return _MODEL_SHARD_CPU_INTRA_GROUP
+
+
+ def get_model_shard_cpu_inter_group() -> Optional[dist.ProcessGroup]:
+     """
+     Get the CPU inter process group of model sharding.
+     """
+     return _MODEL_SHARD_CPU_INTER_GROUP
+
+
+ def get_model_shard_intra_group() -> Optional[dist.ProcessGroup]:
+     """
+     Get the GPU intra process group of model sharding.
+     """
+     return _MODEL_SHARD_INTRA_GROUP
+
+
+ def get_model_shard_inter_group() -> Optional[dist.ProcessGroup]:
+     """
+     Get the GPU inter process group of model sharding.
+     """
+     return _MODEL_SHARD_INTER_GROUP
+
+
+ def init_sequence_parallel(sequence_parallel_size: int):
+     """
+     Initialize sequence parallel.
+     """
+     global _DATA_PARALLEL_GROUP
+     global _SEQUENCE_PARALLEL_GROUP
+     global _SEQUENCE_PARALLEL_CPU_GROUP
+     global _SEQUENCE_PARALLEL_GLOBAL_RANKS
+     assert dist.is_initialized()
+     world_size = dist.get_world_size()
+     rank = dist.get_rank()
+     data_parallel_size = world_size // sequence_parallel_size
+     for i in range(data_parallel_size):
+         start_rank = i * sequence_parallel_size
+         end_rank = (i + 1) * sequence_parallel_size
+         ranks = range(start_rank, end_rank)
+         group = dist.new_group(ranks)
+         cpu_group = dist.new_group(ranks, backend="gloo")
+         if rank in ranks:
+             _SEQUENCE_PARALLEL_GROUP = group
+             _SEQUENCE_PARALLEL_CPU_GROUP = cpu_group
+             _SEQUENCE_PARALLEL_GLOBAL_RANKS = list(ranks)
+
+
+ def init_model_shard_group(
+     *,
+     sharding_strategy: ShardingStrategy,
+     device_mesh: Optional[DeviceMesh] = None,
+ ):
+     """
+     Initialize process group of model sharding.
+     """
+     global _MODEL_SHARD_INTER_GROUP
+     global _MODEL_SHARD_INTRA_GROUP
+     global _MODEL_SHARD_CPU_INTER_GROUP
+     global _MODEL_SHARD_CPU_INTRA_GROUP
+     assert dist.is_initialized()
+     world_size = dist.get_world_size()
+     if device_mesh is not None:
+         num_shards_per_group = device_mesh.shape[1]
+     elif sharding_strategy == ShardingStrategy.NO_SHARD:
+         num_shards_per_group = 1
+     elif sharding_strategy in [
+         ShardingStrategy.HYBRID_SHARD,
+         ShardingStrategy._HYBRID_SHARD_ZERO2,
+     ]:
+         num_shards_per_group = torch.cuda.device_count()
+     else:
+         num_shards_per_group = world_size
+     num_groups = world_size // num_shards_per_group
+     device_mesh = (num_groups, num_shards_per_group)
+
+     gpu_mesh_2d = init_device_mesh("cuda", device_mesh, mesh_dim_names=("inter", "intra"))
+     cpu_mesh_2d = init_device_mesh("cpu", device_mesh, mesh_dim_names=("inter", "intra"))
+
+     _MODEL_SHARD_INTER_GROUP = gpu_mesh_2d.get_group("inter")
+     _MODEL_SHARD_INTRA_GROUP = gpu_mesh_2d.get_group("intra")
+     _MODEL_SHARD_CPU_INTER_GROUP = cpu_mesh_2d.get_group("inter")
+     _MODEL_SHARD_CPU_INTRA_GROUP = cpu_mesh_2d.get_group("intra")
+
+
+ def get_sequence_parallel_global_ranks() -> List[int]:
+     """
+     Get all global ranks of the sequence parallel process group
+     that the caller rank belongs to.
+     """
+     if _SEQUENCE_PARALLEL_GLOBAL_RANKS is None:
+         return [dist.get_rank()]
+     return _SEQUENCE_PARALLEL_GLOBAL_RANKS
+
+
+ def get_next_sequence_parallel_rank() -> int:
+     """
+     Get the next global rank of the sequence parallel process group
+     that the caller rank belongs to.
+     """
+     sp_global_ranks = get_sequence_parallel_global_ranks()
+     sp_rank = get_sequence_parallel_rank()
+     sp_size = get_sequence_parallel_world_size()
+     return sp_global_ranks[(sp_rank + 1) % sp_size]
+
+
+ def get_prev_sequence_parallel_rank() -> int:
+     """
+     Get the previous global rank of the sequence parallel process group
+     that the caller rank belongs to.
+     """
+     sp_global_ranks = get_sequence_parallel_global_ranks()
+     sp_rank = get_sequence_parallel_rank()
+     sp_size = get_sequence_parallel_world_size()
+     return sp_global_ranks[(sp_rank + sp_size - 1) % sp_size]
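The intended initialization order, as a sketch (8 GPUs split into ring groups of 4; must run under a distributed launcher such as `torchrun`):

```python
from common.distributed import init_torch
from common.distributed.advanced import (
    get_next_sequence_parallel_rank,
    get_prev_sequence_parallel_rank,
    init_sequence_parallel,
)

init_torch()                 # sets the device and creates the NCCL process group
init_sequence_parallel(4)    # world_size 8 -> SP groups [0..3] and [4..7]

# Ring neighbors inside this rank's sequence-parallel group.
nxt = get_next_sequence_parallel_rank()
prv = get_prev_sequence_parallel_rank()
```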
common/distributed/basic.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Distributed basic functions.
+ """
+
+ import os
+ from datetime import timedelta
+ import torch
+ import torch.distributed as dist
+ from torch.nn.parallel import DistributedDataParallel
+
+
+ def get_global_rank() -> int:
+     """
+     Get the global rank, the global index of the GPU.
+     """
+     return int(os.environ.get("RANK", "0"))
+
+
+ def get_local_rank() -> int:
+     """
+     Get the local rank, the local index of the GPU.
+     """
+     return int(os.environ.get("LOCAL_RANK", "0"))
+
+
+ def get_world_size() -> int:
+     """
+     Get the world size, the total number of GPUs.
+     """
+     return int(os.environ.get("WORLD_SIZE", "1"))
+
+
+ def get_device() -> torch.device:
+     """
+     Get the current rank's device.
+     """
+     return torch.device("cuda", get_local_rank())
+
+
+ def barrier_if_distributed(*args, **kwargs):
+     """
+     Synchronize all processes when running in a distributed context.
+     """
+     if dist.is_initialized():
+         return dist.barrier(*args, **kwargs)
+
+
+ def init_torch(cudnn_benchmark=True, timeout=timedelta(seconds=600)):
+     """
+     Common PyTorch initialization configuration.
+     """
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.allow_tf32 = True
+     torch.backends.cudnn.benchmark = cudnn_benchmark
+     torch.cuda.set_device(get_local_rank())
+     dist.init_process_group(
+         backend="nccl",
+         rank=get_global_rank(),
+         world_size=get_world_size(),
+         timeout=timeout,
+     )
+
+
+ def convert_to_ddp(module: torch.nn.Module, **kwargs) -> DistributedDataParallel:
+     return DistributedDataParallel(
+         module=module,
+         device_ids=[get_local_rank()],
+         output_device=get_local_rank(),
+         **kwargs,
+     )
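
A minimal usage sketch, assuming the script is launched with `torchrun` (which sets the `RANK`, `LOCAL_RANK` and `WORLD_SIZE` variables these helpers read):

```python
import torch

from common.distributed.basic import convert_to_ddp, get_device, init_torch

init_torch()  # enables TF32, pins the CUDA device, joins the NCCL process group
model = torch.nn.Linear(8, 8).to(get_device())  # stand-in for a real model
model = convert_to_ddp(model)                   # DistributedDataParallel wrapper
```
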
common/distributed/meta_init_utils.py ADDED
@@ -0,0 +1,41 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ import torch
+ from rotary_embedding_torch import RotaryEmbedding
+ from torch import nn
+ from torch.distributed.fsdp._common_utils import _is_fsdp_flattened
+
+ __all__ = ["meta_non_persistent_buffer_init_fn"]
+
+
+ def meta_non_persistent_buffer_init_fn(module: nn.Module) -> nn.Module:
+     """
+     Used for materializing `non-persistent tensor buffers` when resuming a model.
+
+     Since non-persistent tensor buffers are not saved in the state_dict,
+     a user who initializes the model on the meta device must materialize those buffers manually.
+
+     Currently, only `rope.dummy` is this special case.
+     """
+     with torch.no_grad():
+         for submodule in module.modules():
+             if not isinstance(submodule, RotaryEmbedding):
+                 continue
+             for buffer_name, buffer in submodule.named_buffers(recurse=False):
+                 if buffer.is_meta and "dummy" in buffer_name:
+                     materialized_buffer = torch.zeros_like(buffer, device="cpu")
+                     setattr(submodule, buffer_name, materialized_buffer)
+     assert not any(b.is_meta for n, b in module.named_buffers())
+     return module
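
A hedged sketch of where this hook fits; the constructor is a placeholder, and the pattern shown is meta-device init followed by manual buffer materialization:

```python
import torch

from common.distributed.meta_init_utils import meta_non_persistent_buffer_init_fn

# Build the model on the meta device: no memory is allocated yet.
with torch.device("meta"):
    model = build_model()  # placeholder constructor, not a real API here

# Loading the state_dict restores persistent tensors; non-persistent buffers
# (e.g. `rope.dummy`) stay on meta and must be materialized by hand.
model = meta_non_persistent_buffer_init_fn(model)
```
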
common/distributed/ops.py ADDED
@@ -0,0 +1,494 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Distributed ops for supporting sequence parallelism.
+ """
+
+ from collections import defaultdict
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+ import torch
+ import torch.distributed as dist
+ from torch import Tensor
+
+ from common.cache import Cache
+ from common.distributed.advanced import (
+     get_sequence_parallel_group,
+     get_sequence_parallel_rank,
+     get_sequence_parallel_world_size,
+ )
+
+ from .basic import get_device
+
+ _SEQ_DATA_BUF = defaultdict(lambda: [None, None, None])
+ _SEQ_DATA_META_SHAPES = defaultdict()
+ _SEQ_DATA_META_DTYPES = defaultdict()
+ _SEQ_DATA_ASYNC_COMMS = defaultdict(list)
+ _SYNC_BUFFER = defaultdict(dict)
+
+
+ def single_all_to_all(
+     local_input: Tensor,
+     scatter_dim: int,
+     gather_dim: int,
+     group: dist.ProcessGroup,
+     async_op: bool = False,
+ ):
+     """
+     A function to do all-to-all on a tensor.
+     """
+     seq_world_size = dist.get_world_size(group)
+     prev_scatter_dim = scatter_dim
+     if scatter_dim != 0:
+         local_input = local_input.transpose(0, scatter_dim)
+         if gather_dim == 0:
+             gather_dim = scatter_dim
+         scatter_dim = 0
+
+     inp_shape = list(local_input.shape)
+     inp_shape[scatter_dim] = inp_shape[scatter_dim] // seq_world_size
+     input_t = local_input.reshape(
+         [seq_world_size, inp_shape[scatter_dim]] + inp_shape[scatter_dim + 1 :]
+     ).contiguous()
+     output = torch.empty_like(input_t)
+     comm = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)
+     if async_op:
+         # let the caller's code transpose & reshape
+         return output, comm, prev_scatter_dim
+
+     # the first dim is seq_world_size, so we can split it directly
+     output = torch.cat(output.split(1), dim=gather_dim + 1).squeeze(0)
+     if prev_scatter_dim:
+         output = output.transpose(0, prev_scatter_dim).contiguous()
+     return output
+
+
+ def _all_to_all(
+     local_input: Tensor,
+     scatter_dim: int,
+     gather_dim: int,
+     group: dist.ProcessGroup,
+ ):
+     seq_world_size = dist.get_world_size(group)
+     input_list = [
+         t.contiguous() for t in torch.tensor_split(local_input, seq_world_size, scatter_dim)
+     ]
+     output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)]
+     dist.all_to_all(output_list, input_list, group=group)
+     return torch.cat(output_list, dim=gather_dim).contiguous()
+
+
+ class SeqAllToAll(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx: Any,
+         group: dist.ProcessGroup,
+         local_input: Tensor,
+         scatter_dim: int,
+         gather_dim: int,
+         async_op: bool,
+     ) -> Tensor:
+         ctx.group = group
+         ctx.scatter_dim = scatter_dim
+         ctx.gather_dim = gather_dim
+         ctx.async_op = async_op
+         if async_op:
+             output, comm, prev_scatter_dim = single_all_to_all(
+                 local_input, scatter_dim, gather_dim, group, async_op=async_op
+             )
+             ctx.prev_scatter_dim = prev_scatter_dim
+             return output, comm
+
+         return _all_to_all(local_input, scatter_dim, gather_dim, group)
+
+     @staticmethod
+     def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None, None]:
+         if ctx.async_op:
+             input_t = torch.cat(grad_output[0].split(1), dim=ctx.gather_dim + 1).squeeze(0)
+             if ctx.prev_scatter_dim:
+                 input_t = input_t.transpose(0, ctx.prev_scatter_dim)
+         else:
+             input_t = grad_output[0]
+         return (
+             None,
+             _all_to_all(input_t, ctx.gather_dim, ctx.scatter_dim, ctx.group),
+             None,
+             None,
+             None,
+         )
+
+
+ class Slice(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx: Any, group: dist.ProcessGroup, local_input: Tensor, dim: int) -> Tensor:
+         ctx.group = group
+         ctx.rank = dist.get_rank(group)
+         seq_world_size = dist.get_world_size(group)
+         ctx.seq_world_size = seq_world_size
+         ctx.dim = dim
+         dim_size = local_input.shape[dim]
+         return local_input.split(dim_size // seq_world_size, dim=dim)[ctx.rank].contiguous()
+
+     @staticmethod
+     def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None]:
+         dim_size = list(grad_output.size())
+         split_size = dim_size[0]
+         dim_size[0] = dim_size[0] * ctx.seq_world_size
+         output = torch.empty(dim_size, dtype=grad_output.dtype, device=torch.cuda.current_device())
+         dist._all_gather_base(output, grad_output, group=ctx.group)
+         return (None, torch.cat(output.split(split_size), dim=ctx.dim), None)
+
+
+ class Gather(torch.autograd.Function):
+     @staticmethod
+     def forward(
+         ctx: Any,
+         group: dist.ProcessGroup,
+         local_input: Tensor,
+         dim: int,
+         grad_scale: Optional[bool] = False,
+     ) -> Tensor:
+         ctx.group = group
+         ctx.rank = dist.get_rank(group)
+         ctx.dim = dim
+         ctx.grad_scale = grad_scale
+         seq_world_size = dist.get_world_size(group)
+         ctx.seq_world_size = seq_world_size
+         dim_size = list(local_input.size())
+         split_size = dim_size[0]
+         ctx.part_size = dim_size[dim]
+         dim_size[0] = dim_size[0] * seq_world_size
+         output = torch.empty(dim_size, dtype=local_input.dtype, device=torch.cuda.current_device())
+         dist._all_gather_base(output, local_input.contiguous(), group=ctx.group)
+         return torch.cat(output.split(split_size), dim=dim)
+
+     @staticmethod
+     def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
+         if ctx.grad_scale:
+             grad_output = grad_output * ctx.seq_world_size
+         return (
+             None,
+             grad_output.split(ctx.part_size, dim=ctx.dim)[ctx.rank].contiguous(),
+             None,
+             None,
+         )
+
+
+ def gather_seq_scatter_heads_qkv(
+     qkv_tensor: Tensor,
+     *,
+     seq_dim: int,
+     qkv_shape: Optional[Tensor] = None,
+     cache: Cache = Cache(disable=True),
+     restore_shape: bool = True,
+ ):
+     """
+     A func to sync the split qkv tensor. The last dim of qkv_tensor must
+     be the projection_idx, which we split into 3 parts. After splitting,
+     the gather idx becomes projection_idx + 1.
+     seq_dim: gather_dim for the all-to-all comm
+     restore_shape: if True, the output has the same number of dims as the input
+     """
+     group = get_sequence_parallel_group()
+     if not group:
+         return qkv_tensor
+     world = get_sequence_parallel_world_size()
+     orig_shape = qkv_tensor.shape
+     scatter_dim = qkv_tensor.dim()
+     bef_all2all_shape = list(orig_shape)
+     qkv_proj_dim = bef_all2all_shape[-1]
+     bef_all2all_shape = bef_all2all_shape[:-1] + [3, qkv_proj_dim // 3]
+     qkv_tensor = qkv_tensor.view(bef_all2all_shape)
+     qkv_tensor = SeqAllToAll.apply(group, qkv_tensor, scatter_dim, seq_dim, False)
+     if restore_shape:
+         out_shape = list(orig_shape)
+         out_shape[seq_dim] *= world
+         out_shape[-1] = qkv_proj_dim // world
+         qkv_tensor = qkv_tensor.view(out_shape)
+
+     # remove padding
+     if qkv_shape is not None:
+         unpad_dim_size = cache(
+             "unpad_dim_size", lambda: torch.sum(torch.prod(qkv_shape, dim=-1)).item()
+         )
+         if unpad_dim_size % world != 0:
+             padding_size = qkv_tensor.size(seq_dim) - unpad_dim_size
+             qkv_tensor = _unpad_tensor(qkv_tensor, seq_dim, padding_size)
+     return qkv_tensor
+
+
+ def slice_inputs(x: Tensor, dim: int, padding: bool = True):
+     """
+     A func to slice the input sequence in sequence parallel.
+     """
+     group = get_sequence_parallel_group()
+     if group is None:
+         return x
+     sp_rank = get_sequence_parallel_rank()
+     sp_world = get_sequence_parallel_world_size()
+     dim_size = x.shape[dim]
+     unit = (dim_size + sp_world - 1) // sp_world
+     if padding and dim_size % sp_world:
+         padding_size = sp_world - (dim_size % sp_world)
+         x = _pad_tensor(x, dim, padding_size)
+     slc = [slice(None)] * len(x.shape)
+     slc[dim] = slice(unit * sp_rank, unit * (sp_rank + 1))
+     return x[slc]
+
+
+ def remove_seqeunce_parallel_padding(x: Tensor, dim: int, unpad_dim_size: int):
+     """
+     A func to remove the padding part of the tensor based on its original shape.
+     """
+     group = get_sequence_parallel_group()
+     if group is None:
+         return x
+     sp_world = get_sequence_parallel_world_size()
+     if unpad_dim_size % sp_world == 0:
+         return x
+     padding_size = sp_world - (unpad_dim_size % sp_world)
+     assert (padding_size + unpad_dim_size) % sp_world == 0
+     return _unpad_tensor(x, dim=dim, padding_size=padding_size)
+
+
+ def gather_heads_scatter_seq(x: Tensor, head_dim: int, seq_dim: int) -> Tensor:
+     """
+     A func to sync the attention result with all-to-all in sequence parallel.
+     """
+     group = get_sequence_parallel_group()
+     if not group:
+         return x
+     dim_size = x.size(seq_dim)
+     sp_world = get_sequence_parallel_world_size()
+     if dim_size % sp_world != 0:
+         padding_size = sp_world - (dim_size % sp_world)
+         x = _pad_tensor(x, seq_dim, padding_size)
+     return SeqAllToAll.apply(group, x, seq_dim, head_dim, False)
+
+
+ def gather_seq_scatter_heads(x: Tensor, seq_dim: int, head_dim: int) -> Tensor:
+     """
+     A func to sync the embedding input with all-to-all in sequence parallel.
+     """
+     group = get_sequence_parallel_group()
+     if not group:
+         return x
+     return SeqAllToAll.apply(group, x, head_dim, seq_dim, False)
+
+
+ def scatter_heads(x: Tensor, dim: int) -> Tensor:
+     """
+     A func to split heads before attention in sequence parallel.
+     """
+     group = get_sequence_parallel_group()
+     if not group:
+         return x
+     return Slice.apply(group, x, dim)
+
+
+ def gather_heads(x: Tensor, dim: int, grad_scale: Optional[bool] = False) -> Tensor:
+     """
+     A func to gather heads for the attention result in sequence parallel.
+     """
+     group = get_sequence_parallel_group()
+     if not group:
+         return x
+     return Gather.apply(group, x, dim, grad_scale)
+
+
+ def gather_outputs(
+     x: Tensor,
+     *,
+     gather_dim: int,
+     padding_dim: Optional[int] = None,
+     unpad_shape: Optional[Tensor] = None,
+     cache: Cache = Cache(disable=True),
+     scale_grad=True,
+ ):
+     """
+     A func to gather the outputs of the model result in sequence parallel.
+     """
+     group = get_sequence_parallel_group()
+     if not group:
+         return x
+     x = Gather.apply(group, x, gather_dim, scale_grad)
+     if padding_dim is not None:
+         unpad_dim_size = cache(
+             "unpad_dim_size", lambda: torch.sum(torch.prod(unpad_shape, dim=1)).item()
+         )
+         x = remove_seqeunce_parallel_padding(x, padding_dim, unpad_dim_size)
+     return x
+
+
+ def _pad_tensor(x: Tensor, dim: int, padding_size: int):
+     shape = list(x.shape)
+     shape[dim] = padding_size
+     pad = torch.zeros(shape, dtype=x.dtype, device=x.device)
+     return torch.cat([x, pad], dim=dim)
+
+
+ def _unpad_tensor(x: Tensor, dim: int, padding_size):
+     slc = [slice(None)] * len(x.shape)
+     slc[dim] = slice(0, -padding_size)
+     return x[slc]
+
+
+ def _broadcast_data(data, shape, dtype, src, group, async_op):
+     comms = []
+     if isinstance(data, (list, tuple)):
+         for i, sub_shape in enumerate(shape):
+             comms += _broadcast_data(data[i], sub_shape, dtype[i], src, group, async_op)
+     elif isinstance(data, dict):
+         for key, sub_data in data.items():
+             comms += _broadcast_data(sub_data, shape[key], dtype[key], src, group, async_op)
+     elif isinstance(data, Tensor):
+         comms.append(dist.broadcast(data, src=src, group=group, async_op=async_op))
+     return comms
+
+
+ def _traverse(data: Any, op: Callable) -> Union[None, List, Dict, Any]:
+     if isinstance(data, (list, tuple)):
+         return [_traverse(sub_data, op) for sub_data in data]
+     elif isinstance(data, dict):
+         return {key: _traverse(sub_data, op) for key, sub_data in data.items()}
+     elif isinstance(data, Tensor):
+         return op(data)
+     else:
+         return None
+
+
+ def _get_shapes(data):
+     return _traverse(data, op=lambda x: x.shape)
+
+
+ def _get_dtypes(data):
+     return _traverse(data, op=lambda x: x.dtype)
+
+
+ def _construct_broadcast_buffer(shapes, dtypes, device):
+     if isinstance(shapes, torch.Size):
+         return torch.empty(shapes, dtype=dtypes, device=device)
+
+     if isinstance(shapes, (list, tuple)):
+         buffer = []
+         for i, sub_shape in enumerate(shapes):
+             buffer.append(_construct_broadcast_buffer(sub_shape, dtypes[i], device))
+     elif isinstance(shapes, dict):
+         buffer = {}
+         for key, sub_shape in shapes.items():
+             buffer[key] = _construct_broadcast_buffer(sub_shape, dtypes[key], device)
+     else:
+         return None
+     return buffer
+
+
+ class SPDistForward:
+     """A forward tool to sync different results across the sp group.
+
+     Args:
+         name: a distinct str used to key the saved meta info and async comms
+         comm_shape: set this arg to True if different ranks have different shapes
+         device: the device for the current rank, can be empty
+     """
+
+     def __init__(
+         self,
+         name: str,
+         comm_shape: bool,
+         device: torch.device = None,
+     ):
+         self.name = name
+         self.comm_shape = comm_shape
+         if device:
+             self.device = device
+         else:
+             self.device = get_device()
+
+     def __call__(self, inputs) -> Any:
+         group = get_sequence_parallel_group()
+         if not group:
+             yield inputs
+         else:
+             device = self.device
+             sp_world = get_sequence_parallel_world_size()
+             sp_rank = get_sequence_parallel_rank()
+             for local_step in range(sp_world):
+                 src_rank = dist.get_global_rank(group, local_step)
+                 is_src = sp_rank == local_step
+                 local_shapes = []
+                 local_dtypes = []
+                 if local_step == 0:
+                     local_result = inputs
+                     _SEQ_DATA_BUF[self.name][-1] = local_result
+                     local_shapes = _get_shapes(local_result)
+                     local_dtypes = _get_dtypes(local_result)
+                     if self.comm_shape:
+                         group_shapes_lists = [None] * sp_world
+                         dist.all_gather_object(group_shapes_lists, local_shapes, group=group)
+                         _SEQ_DATA_META_SHAPES[self.name] = group_shapes_lists
+                     else:
+                         _SEQ_DATA_META_SHAPES[self.name] = [local_shapes] * sp_world
+                     _SEQ_DATA_META_DTYPES[self.name] = local_dtypes
+                 shapes = _SEQ_DATA_META_SHAPES[self.name][local_step]
+                 dtypes = _SEQ_DATA_META_DTYPES[self.name]
+                 buf_id = local_step % 2
+                 if local_step == 0:
+                     sync_data = (
+                         local_result
+                         if is_src
+                         else _construct_broadcast_buffer(shapes, dtypes, device)
+                     )
+                     _broadcast_data(sync_data, shapes, dtypes, src_rank, group, False)
+                     _SEQ_DATA_BUF[self.name][buf_id] = sync_data
+
+                 # wait for async comm ops
+                 if _SEQ_DATA_ASYNC_COMMS[self.name]:
+                     for comm in _SEQ_DATA_ASYNC_COMMS[self.name]:
+                         comm.wait()
+                 # before returning the synced result, start the async broadcast for the next batch
+                 if local_step < sp_world - 1:
+                     next_buf_id = 1 - buf_id
+                     shapes = _SEQ_DATA_META_SHAPES[self.name][local_step + 1]
+                     src_rank = dist.get_global_rank(group, local_step + 1)
+                     is_src = sp_rank == local_step + 1
+                     next_sync_data = (
+                         _SEQ_DATA_BUF[self.name][-1]
+                         if is_src
+                         else _construct_broadcast_buffer(shapes, dtypes, device)
+                     )
+                     _SEQ_DATA_ASYNC_COMMS[self.name] = _broadcast_data(
+                         next_sync_data, shapes, dtypes, src_rank, group, True
+                     )
+                     _SEQ_DATA_BUF[self.name][next_buf_id] = next_sync_data
+                 yield _SEQ_DATA_BUF[self.name][buf_id]
+
+
+ sync_inputs = SPDistForward(name="bef_fwd", comm_shape=True)
+
+
+ def sync_data(data, sp_idx, name="tmp"):
+     group = get_sequence_parallel_group()
+     if group is None:
+         return data
+     # if sp_idx in _SYNC_BUFFER[name]:
+     #     return _SYNC_BUFFER[name][sp_idx]
+     sp_rank = get_sequence_parallel_rank()
+     src_rank = dist.get_global_rank(group, sp_idx)
+     objects = [data] if sp_rank == sp_idx else [None]
+     dist.broadcast_object_list(objects, src=src_rank, group=group)
+     # _SYNC_BUFFER[name] = {sp_idx: objects[0]}
+     return objects[0]
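
To make the intended layout concrete, a rough sketch of how an attention block could use these ops under sequence parallelism. The projection and attention calls are placeholders and the dim indices assume a `[batch, seq, heads, head_dim]` layout; this is not the project's actual model code:

```python
# x arrives sharded along the sequence: [batch, seq_local, heads, head_dim].
qkv = to_qkv(x)  # placeholder projection; the last dim packs q, k and v

# all-to-all: gather the full sequence, scatter heads across the sp group
qkv = gather_seq_scatter_heads_qkv(qkv, seq_dim=1)

out = attention(qkv)  # placeholder: sees the full sequence, a subset of heads

# inverse all-to-all: gather all heads back, re-scatter the sequence
out = gather_heads_scatter_seq(out, head_dim=2, seq_dim=1)
```
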
common/logger.py ADDED
@@ -0,0 +1,44 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Logging utility functions.
+ """
+
+ import logging
+ import sys
+ from typing import Optional
+
+ from common.distributed import get_global_rank, get_local_rank, get_world_size
+
+ _default_handler = logging.StreamHandler(sys.stdout)
+ _default_handler.setFormatter(
+     logging.Formatter(
+         "%(asctime)s "
+         + (f"[Rank:{get_global_rank()}]" if get_world_size() > 1 else "")
+         + (f"[LocalRank:{get_local_rank()}]" if get_world_size() > 1 else "")
+         + "[%(threadName).12s][%(name)s][%(levelname).5s] "
+         + "%(message)s"
+     )
+ )
+
+
+ def get_logger(name: Optional[str] = None) -> logging.Logger:
+     """
+     Get a logger.
+     """
+     logger = logging.getLogger(name)
+     logger.addHandler(_default_handler)
+     logger.setLevel(logging.INFO)
+     return logger
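
Usage sketch; the output line is illustrative, and the rank tags only appear when `WORLD_SIZE > 1`:

```python
from common.logger import get_logger

logger = get_logger(__name__)
logger.info("loading checkpoint")
# e.g. "2025-01-01 12:00:00,000 [Rank:1][LocalRank:1][MainThread][__main__][INFO] loading checkpoint"
```
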
common/partition.py ADDED
@@ -0,0 +1,59 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ """
+ Partition utility functions.
+ """
+
+ from typing import Any, List
+
+
+ def partition_by_size(data: List[Any], size: int) -> List[List[Any]]:
+     """
+     Partition a list by size.
+     When indivisible, the last group contains fewer items than the target size.
+
+     Examples:
+         - data: [1,2,3,4,5]
+         - size: 2
+         - return: [[1,2], [3,4], [5]]
+     """
+     assert size > 0
+     return [data[i : (i + size)] for i in range(0, len(data), size)]
+
+
+ def partition_by_groups(data: List[Any], groups: int) -> List[List[Any]]:
+     """
+     Partition a list by groups.
+     When indivisible, some groups may have more items than others.
+
+     Examples:
+         - data: [1,2,3,4,5]
+         - groups: 2
+         - return: [[1,3,5], [2,4]]
+     """
+     assert groups > 0
+     return [data[i::groups] for i in range(groups)]
+
+
+ def shift_list(data: List[Any], n: int) -> List[Any]:
+     """
+     Rotate a list by n elements.
+
+     Examples:
+         - data: [1,2,3,4,5]
+         - n: 3
+         - return: [4,5,1,2,3]
+     """
+     return data[(n % len(data)) :] + data[: (n % len(data))]
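
The docstring examples above translate directly into assertions; note that `partition_by_groups` interleaves round-robin rather than slicing contiguous chunks:

```python
from common.partition import partition_by_groups, partition_by_size, shift_list

assert partition_by_size([1, 2, 3, 4, 5], size=2) == [[1, 2], [3, 4], [5]]
assert partition_by_groups([1, 2, 3, 4, 5], groups=2) == [[1, 3, 5], [2, 4]]
assert shift_list([1, 2, 3, 4, 5], n=3) == [4, 5, 1, 2, 3]
```
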
common/seed.py ADDED
@@ -0,0 +1,30 @@
+ # // Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ # //
+ # // Licensed under the Apache License, Version 2.0 (the "License");
+ # // you may not use this file except in compliance with the License.
+ # // You may obtain a copy of the License at
+ # //
+ # // http://www.apache.org/licenses/LICENSE-2.0
+ # //
+ # // Unless required by applicable law or agreed to in writing, software
+ # // distributed under the License is distributed on an "AS IS" BASIS,
+ # // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # // See the License for the specific language governing permissions and
+ # // limitations under the License.
+
+ import random
+ from typing import Optional
+ import numpy as np
+ import torch
+
+ from common.distributed import get_global_rank
+
+
+ def set_seed(seed: Optional[int], same_across_ranks: bool = False):
+     """Function that sets the seed for pseudo-random number generators."""
+     if seed is not None:
+         seed += get_global_rank() if not same_across_ranks else 0
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+
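
A short usage note, with the behavior read straight from the function above:

```python
from common.seed import set_seed

set_seed(42)                          # each rank seeds with 42 + its global rank
set_seed(42, same_across_ranks=True)  # every rank seeds with exactly 42
```
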
common/utils.py ADDED
@@ -0,0 +1,232 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ import re
+
+ import cv2
+ import numpy as np
+ import torch
+ from torchvision.utils import make_grid
+
+
+ # from basicsr
+ def img2tensor(imgs, bgr2rgb=True, float32=True):
+     """Numpy array to tensor.
+
+     Args:
+         imgs (list[ndarray] | ndarray): Input images.
+         bgr2rgb (bool): Whether to change bgr to rgb.
+         float32 (bool): Whether to change to float32.
+
+     Returns:
+         list[tensor] | tensor: Tensor images. If the result has only
+             one element, the tensor is returned directly.
+     """
+
+     def _totensor(img, bgr2rgb, float32):
+         if img.shape[2] == 3 and bgr2rgb:
+             if img.dtype == 'float64':
+                 img = img.astype('float32')
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+         img = torch.from_numpy(img.transpose(2, 0, 1))
+         if float32:
+             img = img.float()
+         return img
+
+     if isinstance(imgs, list):
+         return [_totensor(img, bgr2rgb, float32) for img in imgs]
+     return _totensor(imgs, bgr2rgb, float32)
+
+
+ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+     """Convert torch Tensors into image numpy arrays.
+
+     After clamping to [min, max], values will be normalized to [0, 1].
+
+     Args:
+         tensor (Tensor or list[Tensor]): Accept shapes:
+             1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+             2) 3D Tensor of shape (3/1 x H x W);
+             3) 2D Tensor of shape (H x W).
+             Tensor channel should be in RGB order.
+         rgb2bgr (bool): Whether to change rgb to bgr.
+         out_type (numpy type): output types. If ``np.uint8``, transform outputs
+             to uint8 type with range [0, 255]; otherwise, float type with
+             range [0, 1]. Default: ``np.uint8``.
+         min_max (tuple[int]): min and max values for clamp.
+
+     Returns:
+         (ndarray or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
+             shape (H x W). The channel order is BGR.
+     """
+     if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+         raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
+
+     if torch.is_tensor(tensor):
+         tensor = [tensor]
+     result = []
+     for _tensor in tensor:
+         _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+         _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+         n_dim = _tensor.dim()
+         if n_dim == 4:
+             img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+             img_np = img_np.transpose(1, 2, 0)
+             if rgb2bgr:
+                 img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+         elif n_dim == 3:
+             img_np = _tensor.numpy()
+             img_np = img_np.transpose(1, 2, 0)
+             if img_np.shape[2] == 1:  # gray image
+                 img_np = np.squeeze(img_np, axis=2)
+             else:
+                 if rgb2bgr:
+                     img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+         elif n_dim == 2:
+             img_np = _tensor.numpy()
+         else:
+             raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
+         if out_type == np.uint8:
+             # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+             img_np = (img_np * 255.0).round()
+         img_np = img_np.astype(out_type)
+         result.append(img_np)
+     if len(result) == 1:
+         result = result[0]
+     return result
+
+
+ def resize_numpy_image_area(image, area=512 * 512):
+     h, w = image.shape[:2]
+     k = math.sqrt(area / (h * w))
+     h = int(h * k) - (int(h * k) % 16)
+     w = int(w * k) - (int(w * k) % 16)
+     image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
+     return image
+
+ def resize_numpy_image_long(image, long_edge=768):
+     h, w = image.shape[:2]
+     if max(h, w) <= long_edge:
+         return image
+     k = long_edge / max(h, w)
+     h = int(h * k)
+     w = int(w * k)
+     image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
+     return image
+
+
+ # reference: https://github.com/huggingface/diffusers/pull/9295/files
+ def convert_flux_lora_to_diffusers(old_state_dict):
+     new_state_dict = {}
+     orig_keys = list(old_state_dict.keys())
+
+     def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
+         down_weight = sds_sd.pop(sds_key)
+         up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight"))
+
+         # calculate dims if not provided
+         num_splits = len(ait_keys)
+         if dims is None:
+             dims = [up_weight.shape[0] // num_splits] * num_splits
+         else:
+             assert sum(dims) == up_weight.shape[0]
+
+         # make ai-toolkit weight
+         ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
+         ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
+
+         # down_weight is copied to each split
+         ait_sd.update({k: down_weight for k in ait_down_keys})
+
+         # up_weight is split to each split
+         ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
+
+     for old_key in orig_keys:
+         # Handle double_blocks
+         if 'double_blocks' in old_key:
+             block_num = re.search(r"double_blocks_(\d+)", old_key).group(1)
+             new_key = f"transformer.transformer_blocks.{block_num}"
+
+             if "proj_lora1" in old_key:
+                 new_key += ".attn.to_out.0"
+             elif "proj_lora2" in old_key:
+                 new_key += ".attn.to_add_out"
+             elif "qkv_lora2" in old_key and "up" not in old_key:
+                 handle_qkv(
+                     old_state_dict,
+                     new_state_dict,
+                     old_key,
+                     [
+                         f"transformer.transformer_blocks.{block_num}.attn.add_q_proj",
+                         f"transformer.transformer_blocks.{block_num}.attn.add_k_proj",
+                         f"transformer.transformer_blocks.{block_num}.attn.add_v_proj",
+                     ],
+                 )
+                 # continue
+             elif "qkv_lora1" in old_key and "up" not in old_key:
+                 handle_qkv(
+                     old_state_dict,
+                     new_state_dict,
+                     old_key,
+                     [
+                         f"transformer.transformer_blocks.{block_num}.attn.to_q",
+                         f"transformer.transformer_blocks.{block_num}.attn.to_k",
+                         f"transformer.transformer_blocks.{block_num}.attn.to_v",
+                     ],
+                 )
+                 # continue
+
+             if "down" in old_key:
+                 new_key += ".lora_A.weight"
+             elif "up" in old_key:
+                 new_key += ".lora_B.weight"
+
+         # Handle single_blocks
+         elif 'single_blocks' in old_key:
+             block_num = re.search(r"single_blocks_(\d+)", old_key).group(1)
+             new_key = f"transformer.single_transformer_blocks.{block_num}"
+
+             if "proj_lora" in old_key:
+                 new_key += ".proj_out"
+             elif "qkv_lora" in old_key and "up" not in old_key:
+                 handle_qkv(
+                     old_state_dict,
+                     new_state_dict,
+                     old_key,
+                     [
+                         f"transformer.single_transformer_blocks.{block_num}.attn.to_q",
+                         f"transformer.single_transformer_blocks.{block_num}.attn.to_k",
+                         f"transformer.single_transformer_blocks.{block_num}.attn.to_v",
+                     ],
+                 )
+
+             if "down" in old_key:
+                 new_key += ".lora_A.weight"
+             elif "up" in old_key:
+                 new_key += ".lora_B.weight"
+
+         else:
+             # Handle other potential key patterns here
+             new_key = old_key
+
+         # Since we already handle qkv above.
+         if "qkv" not in old_key and 'embedding' not in old_key:
+             new_state_dict[new_key] = old_state_dict.pop(old_key)
+
+     # if len(old_state_dict) > 0:
+     #     raise ValueError(f"`old_state_dict` should be empty at this point but has: {list(old_state_dict.keys())}.")
+
+     return new_state_dict
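
Finally, a hedged round-trip sketch of the image helpers above (the array size is arbitrary):

```python
import numpy as np

from common.utils import img2tensor, resize_numpy_image_area, tensor2img

img = (np.random.rand(720, 1280, 3) * 255).astype(np.uint8)  # HWC, BGR, uint8
img = resize_numpy_image_area(img, area=512 * 512)           # ~512*512 px, dims multiples of 16
t = img2tensor(img) / 255.0                                  # CHW float RGB tensor in [0, 1]
back = tensor2img(t)                                         # HWC BGR uint8 again
```
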