import os
import sys
import logging
import torch
from tqdm import tqdm
from omegaconf import OmegaConf
from ovi.utils.io_utils import save_video
from ovi.utils.processing_utils import format_prompt_for_filename, validate_and_process_user_prompt
from ovi.utils.utils import get_arguments
from ovi.distributed_comms.util import get_world_size, get_local_rank, get_global_rank
from ovi.distributed_comms.parallel_states import initialize_sequence_parallel_state, get_sequence_parallel_state, nccl_info
from ovi.ovi_fusion_engine import OviFusionEngine
def _init_logging(rank):
# logging
if rank == 0:
# set format
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)])
else:
logging.basicConfig(level=logging.ERROR)
def main(config, args):
world_size = get_world_size()
global_rank = get_global_rank()
local_rank = get_local_rank()
device = local_rank
torch.cuda.set_device(local_rank)
sp_size = config.get("sp_size", 1)
assert sp_size <= world_size and world_size % sp_size == 0, "sp_size must be less than or equal to world_size and world_size must be divisible by sp_size."
_init_logging(global_rank)
if world_size > 1:
torch.distributed.init_process_group(
backend="nccl",
init_method="env://",
rank=global_rank,
world_size=world_size)
else:
assert sp_size == 1, f"When world_size is 1, sp_size must also be 1, but got {sp_size}."
## TODO: assert not sharding t5 etc...
initialize_sequence_parallel_state(sp_size)
logging.info(f"Using SP: {get_sequence_parallel_state()}, SP_SIZE: {sp_size}")
args.local_rank = local_rank
args.device = device
target_dtype = torch.bfloat16
# validate inputs before loading model to not waste time if input is not valid
text_prompt = config.get("text_prompt")
image_path = config.get("image_path", None)
assert config.get("mode") in ["t2v", "i2v", "t2i2v"], f"Invalid mode {config.get('mode')}, must be one of ['t2v', 'i2v', 't2i2v']"
text_prompts, image_paths = validate_and_process_user_prompt(text_prompt, image_path, mode=config.get("mode"))
if config.get("mode") != "i2v":
logging.info(f"mode: {config.get('mode')}, setting all image_paths to None")
image_paths = [None] * len(text_prompts)
else:
assert all(p is not None and os.path.isfile(p) for p in image_paths), f"In i2v mode, all image paths must be provided.{image_paths}"
logging.info("Loading OVI Fusion Engine...")
ovi_engine = OviFusionEngine(config=config, device=device, target_dtype=target_dtype)
logging.info("OVI Fusion Engine loaded!")
output_dir = config.get("output_dir", "./outputs")
os.makedirs(output_dir, exist_ok=True)
# Load CSV data
all_eval_data = list(zip(text_prompts, image_paths))
# Get SP configuration
use_sp = get_sequence_parallel_state()
if use_sp:
sp_size = nccl_info.sp_size
sp_rank = nccl_info.rank_within_group
sp_group_id = global_rank // sp_size
num_sp_groups = world_size // sp_size
else:
# No SP: treat each GPU as its own group
sp_size = 1
sp_rank = 0
sp_group_id = global_rank
num_sp_groups = world_size
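    # Worked example (derived from the arithmetic above; values illustrative):
    # with world_size=8 and sp_size=4, ranks 0-3 form SP group 0 and ranks 4-7
    # form SP group 1, so num_sp_groups = 2 and each group consumes every
    # second sample from the evaluation list.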
# Data distribution - by SP groups
total_files = len(all_eval_data)
require_sample_padding = False
if total_files == 0:
logging.error(f"ERROR: No evaluation files found")
this_rank_eval_data = []
else:
# Pad to match number of SP groups
remainder = total_files % num_sp_groups
if require_sample_padding and remainder != 0:
pad_count = num_sp_groups - remainder
all_eval_data += [all_eval_data[0]] * pad_count
# Distribute across SP groups
this_rank_eval_data = all_eval_data[sp_group_id :: num_sp_groups]
for _, (text_prompt, image_path) in tqdm(enumerate(this_rank_eval_data)):
video_frame_height_width = config.get("video_frame_height_width", None)
seed = config.get("seed", 100)
solver_name = config.get("solver_name", "unipc")
sample_steps = config.get("sample_steps", 50)
shift = config.get("shift", 5.0)
video_guidance_scale = config.get("video_guidance_scale", 4.0)
audio_guidance_scale = config.get("audio_guidance_scale", 3.0)
slg_layer = config.get("slg_layer", 11)
video_negative_prompt = config.get("video_negative_prompt", "")
audio_negative_prompt = config.get("audio_negative_prompt", "")
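        # Notes on the sampling knobs (my reading, not from upstream docs):
        # video/audio_guidance_scale are per-modality classifier-free guidance
        # strengths, and slg_layer is presumably the transformer layer index
        # used for skip-layer guidance. The seed is offset per repeat below so
        # that each_example_n_times > 1 yields distinct generations.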
for idx in range(config.get("each_example_n_times", 1)):
generated_video, generated_audio, generated_image = ovi_engine.generate(text_prompt=text_prompt,
image_path=image_path,
video_frame_height_width=video_frame_height_width,
seed=seed+idx,
solver_name=solver_name,
sample_steps=sample_steps,
shift=shift,
video_guidance_scale=video_guidance_scale,
audio_guidance_scale=audio_guidance_scale,
slg_layer=slg_layer,
video_negative_prompt=video_negative_prompt,
audio_negative_prompt=audio_negative_prompt)
if sp_rank == 0:
formatted_prompt = format_prompt_for_filename(text_prompt)
output_path = os.path.join(output_dir, f"{formatted_prompt}_{'x'.join(map(str, video_frame_height_width))}_{seed+idx}_{global_rank}.mp4")
save_video(output_path, generated_video, generated_audio, fps=24, sample_rate=16000)
if generated_image is not None:
generated_image.save(output_path.replace('.mp4', '.png'))
if __name__ == "__main__":
args = get_arguments()
config = OmegaConf.load(args.config_file)
main(config=config,args=args) |
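
# Example launch (a sketch, not from the repo's docs): assuming get_arguments()
# exposes a --config_file flag (args.config_file is read above) and this script
# is saved as inference.py -- both names are illustrative:
#
#   torchrun --nproc_per_node=8 inference.py --config_file configs/inference.yaml
#
# with a YAML along these lines (keys taken from the config.get() calls above):
#
#   mode: t2v
#   text_prompt: "a dog chasing a ball on the beach"
#   output_dir: ./outputs
#   sp_size: 4          # must evenly divide the number of launched ranks
#   seed: 100
#   sample_steps: 50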