Commit
·
e7be476
1
Parent(s):
0bcab0b
log current step
Browse files
cosmos_transfer1/diffusion/diffusion/modules/res_sampler.py
CHANGED
|
@@ -30,6 +30,7 @@ import torch
|
|
| 30 |
|
| 31 |
from cosmos_transfer1.diffusion.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported
|
| 32 |
from cosmos_transfer1.diffusion.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported
|
|
|
|
| 33 |
from cosmos_transfer1.utils.ddp_config import make_freezable
|
| 34 |
|
| 35 |
COMMON_SOLVER_OPTIONS = Literal["2ab", "2mid", "1euler"]
|
|
@@ -250,6 +251,7 @@ def differential_equation_solver(
|
|
| 250 |
def step_fn(
|
| 251 |
i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
|
| 252 |
) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
|
|
|
|
| 253 |
input_x_B_StateShape, x0_preds = state
|
| 254 |
sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1]
|
| 255 |
|
|
|
|
| 30 |
|
| 31 |
from cosmos_transfer1.diffusion.diffusion.functional.multi_step import get_multi_step_fn, is_multi_step_fn_supported
|
| 32 |
from cosmos_transfer1.diffusion.diffusion.functional.runge_kutta import get_runge_kutta_fn, is_runge_kutta_fn_supported
|
| 33 |
+
from cosmos_transfer1.utils import log
|
| 34 |
from cosmos_transfer1.utils.ddp_config import make_freezable
|
| 35 |
|
| 36 |
COMMON_SOLVER_OPTIONS = Literal["2ab", "2mid", "1euler"]
|
|
|
|
| 251 |
def step_fn(
|
| 252 |
i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
|
| 253 |
) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
|
| 254 |
+
log.info(f"Step [{i_th}/{num_step}]")
|
| 255 |
input_x_B_StateShape, x0_preds = state
|
| 256 |
sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1]
|
| 257 |
|