import os
import sys
import uuid
import cv2
import glob
import torch
import logging
from textwrap import indent
import torch.nn as nn
from diffusers import FluxPipeline
from tqdm import tqdm
from ovi.distributed_comms.parallel_states import get_sequence_parallel_state, nccl_info
from ovi.utils.model_loading_utils import init_fusion_score_model_ovi, init_text_model, init_mmaudio_vae, init_wan_vae_2_2, load_fusion_checkpoint
from ovi.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from diffusers import FlowMatchEulerDiscreteScheduler
from ovi.utils.fm_solvers import (FlowDPMSolverMultistepScheduler,
                                  get_sampling_sigmas, retrieve_timesteps)
import traceback
from omegaconf import OmegaConf
from ovi.utils.processing_utils import clean_text, preprocess_image_tensor, snap_hw_to_multiple_of_32, scale_hw_to_area_divisible


DEFAULT_CONFIG = OmegaConf.load('ovi/configs/inference/inference_fusion.yaml')


class OviFusionEngine:
    def __init__(self, config=DEFAULT_CONFIG, device=0, target_dtype=torch.bfloat16):
        self.device = device
        self.target_dtype = target_dtype
        meta_init = True
        self.cpu_offload = config.get("cpu_offload", False) or config.get("mode") == "t2i2v"
        if self.cpu_offload:
            logging.info("CPU offloading is enabled. Initializing all models aside from VAEs on CPU")

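        # Fusion score model: a single network with video and audio branches.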
        model, video_config, audio_config = init_fusion_score_model_ovi(rank=device, meta_init=meta_init)
        if not meta_init:
            model = model.to(dtype=target_dtype).to(device=device if not self.cpu_offload else "cpu").eval()

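        # Video VAE (Wan 2.2), kept frozen in bfloat16.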
        vae_model_video = init_wan_vae_2_2(config.ckpt_dir, rank=device)
        vae_model_video.model.requires_grad_(False).eval()
        vae_model_video.model = vae_model_video.model.bfloat16()
        self.vae_model_video = vae_model_video

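        # Audio VAE (MMAudio), also frozen in bfloat16.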
        vae_model_audio = init_mmaudio_vae(config.ckpt_dir, rank=device)
        vae_model_audio.requires_grad_(False).eval()
        self.vae_model_audio = vae_model_audio.bfloat16()

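        # Text encoder; optionally offloaded to CPU between uses.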
        self.text_model = init_text_model(config.ckpt_dir, rank=device)
        if config.get("shard_text_model", False):
            raise NotImplementedError("Sharding text model is not implemented yet.")
        if self.cpu_offload:
            self.offload_to_cpu(self.text_model.model)

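        # Load the fused checkpoint into the (possibly meta-initialized) score model.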
        checkpoint_path = os.path.join(config.ckpt_dir, "Ovi", "model.safetensors")
        if not os.path.exists(checkpoint_path):
            raise RuntimeError(f"No fusion checkpoint found in {config.ckpt_dir}")

        load_fusion_checkpoint(model, checkpoint_path=checkpoint_path, from_meta=meta_init)

        if meta_init:
            model = model.to(dtype=target_dtype).to(device=device if not self.cpu_offload else "cpu").eval()
        model.set_rope_params()
        self.model = model

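        # Optional Flux image model for t2i2v mode (text -> first frame -> video).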
        self.image_model = None
        if config.get("mode") == "t2i2v":
            logging.info("Loading Flux Krea for first frame generation...")
            self.image_model = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-Krea-dev", torch_dtype=torch.bfloat16)
            self.image_model.enable_model_cpu_offload(gpu_id=self.device)

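        # Latent geometry: channel counts come from the model configs; sequence lengths are fixed.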
        self.audio_latent_channel = audio_config.get("in_dim")
        self.video_latent_channel = video_config.get("in_dim")
        self.audio_latent_length = 157
        self.video_latent_length = 31

        logging.info(f"OVI Fusion Engine initialized, cpu_offload={self.cpu_offload}. GPU VRAM allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB, reserved: {torch.cuda.memory_reserved(device)/1e9:.2f} GB")

    @torch.inference_mode()
    def generate(self,
                 text_prompt,
                 image_path=None,
                 video_frame_height_width=None,
                 seed=100,
                 solver_name="unipc",
                 sample_steps=50,
                 shift=5.0,
                 video_guidance_scale=5.0,
                 audio_guidance_scale=4.0,
                 slg_layer=9,
                 video_negative_prompt="",
                 audio_negative_prompt=""
                 ):
        params = {
            "Text Prompt": text_prompt,
            "Image Path": image_path if image_path else "None (T2V mode)",
            "Frame Height Width": video_frame_height_width,
            "Seed": seed,
            "Solver": solver_name,
            "Sample Steps": sample_steps,
            "Shift": shift,
            "Video Guidance Scale": video_guidance_scale,
            "Audio Guidance Scale": audio_guidance_scale,
            "SLG Layer": slg_layer,
            "Video Negative Prompt": video_negative_prompt,
            "Audio Negative Prompt": audio_negative_prompt,
        }

        pretty = "\n".join(f"{k:>24}: {v}" for k, v in params.items())
        logging.info("\n========== Generation Parameters ==========\n"
                     f"{pretty}\n"
                     "==========================================")

        try:
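            # Build one flow-matching scheduler per modality so video and audio
            # can be stepped independently with the same solver settings.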
            scheduler_video, timesteps_video = self.get_scheduler_time_steps(
                sampling_steps=sample_steps,
                device=self.device,
                solver_name=solver_name,
                shift=shift
            )
            scheduler_audio, timesteps_audio = self.get_scheduler_time_steps(
                sampling_steps=sample_steps,
                device=self.device,
                solver_name=solver_name,
                shift=shift
            )

            is_t2v = image_path is None
            is_i2v = not is_t2v

            first_frame = None
            image = None
            if is_i2v and not self.image_model:
                first_frame = preprocess_image_tensor(image_path, self.device, self.target_dtype)
            else:
                assert video_frame_height_width is not None, "If mode=t2v or t2i2v, video_frame_height_width must be provided."
                video_h, video_w = video_frame_height_width
                video_h, video_w = snap_hw_to_multiple_of_32(video_h, video_w, area=720 * 720)
                video_latent_h, video_latent_w = video_h // 16, video_w // 16
                if self.image_model is not None:
                    image_h, image_w = scale_hw_to_area_divisible(video_h, video_w, area=1024 * 1024)
                    image = self.image_model(
                        clean_text(text_prompt),
                        height=image_h,
                        width=image_w,
                        guidance_scale=4.5,
                        generator=torch.Generator().manual_seed(seed)
                    ).images[0]
                    first_frame = preprocess_image_tensor(image, self.device, self.target_dtype)
                    is_i2v = True
                else:
                    logging.info(f"Pure T2V mode: calculated video latent size: {video_latent_h} x {video_latent_w}")

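            # Encode prompts (positive, video-negative, audio-negative) with the text
            # encoder, moving it on/off the GPU when cpu_offload is enabled.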
            if self.cpu_offload:
                self.text_model.model = self.text_model.model.to(self.device)
            text_embeddings = self.text_model([text_prompt, video_negative_prompt, audio_negative_prompt], self.text_model.device)
            text_embeddings = [emb.to(self.target_dtype).to(self.device) for emb in text_embeddings]

            if self.cpu_offload:
                self.offload_to_cpu(self.text_model.model)

            text_embeddings_audio_pos = text_embeddings[0]
            text_embeddings_video_pos = text_embeddings[0]

            text_embeddings_video_neg = text_embeddings[1]
            text_embeddings_audio_neg = text_embeddings[2]

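            # For i2v, encode the conditioning frame into video latent space; its
            # spatial shape defines the latent height/width for generation.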
            if is_i2v:
                with torch.no_grad():
                    latents_images = self.vae_model_video.wrapped_encode(first_frame[:, :, None]).to(self.target_dtype).squeeze(0)
                latents_images = latents_images.to(self.target_dtype)
                video_latent_h, video_latent_w = latents_images.shape[2], latents_images.shape[3]

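            # Sample the initial video and audio noise latents with a seeded generator
            # so runs are reproducible.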
            video_noise = torch.randn((self.video_latent_channel, self.video_latent_length, video_latent_h, video_latent_w), device=self.device, dtype=self.target_dtype, generator=torch.Generator(device=self.device).manual_seed(seed))
            audio_noise = torch.randn((self.audio_latent_length, self.audio_latent_channel), device=self.device, dtype=self.target_dtype, generator=torch.Generator(device=self.device).manual_seed(seed))

            max_seq_len_audio = audio_noise.shape[0]
            _patch_size_h, _patch_size_w = self.model.video_model.patch_size[1], self.model.video_model.patch_size[2]
            max_seq_len_video = video_noise.shape[1] * video_noise.shape[2] * video_noise.shape[3] // (_patch_size_h * _patch_size_w)

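            # Joint denoising loop: each step runs a positive (conditioned) and a
            # negative pass, then applies classifier-free guidance separately to the
            # video and audio predictions before stepping both schedulers.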
            if self.cpu_offload:
                self.model = self.model.to(self.device)
            with torch.amp.autocast('cuda', enabled=self.target_dtype != torch.float32, dtype=self.target_dtype):
                for i, (t_v, t_a) in tqdm(enumerate(zip(timesteps_video, timesteps_audio))):
                    timestep_input = torch.full((1,), t_v, device=self.device)

                    if is_i2v:
                        # Keep the first latent frame clamped to the encoded input image.
                        video_noise[:, :1] = latents_images

                    pos_forward_args = {
                        'audio_context': [text_embeddings_audio_pos],
                        'vid_context': [text_embeddings_video_pos],
                        'vid_seq_len': max_seq_len_video,
                        'audio_seq_len': max_seq_len_audio,
                        'first_frame_is_clean': is_i2v
                    }

                    pred_vid_pos, pred_audio_pos = self.model(
                        vid=[video_noise],
                        audio=[audio_noise],
                        t=timestep_input,
                        **pos_forward_args
                    )

                    neg_forward_args = {
                        'audio_context': [text_embeddings_audio_neg],
                        'vid_context': [text_embeddings_video_neg],
                        'vid_seq_len': max_seq_len_video,
                        'audio_seq_len': max_seq_len_audio,
                        'first_frame_is_clean': is_i2v,
                        'slg_layer': slg_layer
                    }

                    pred_vid_neg, pred_audio_neg = self.model(
                        vid=[video_noise],
                        audio=[audio_noise],
                        t=timestep_input,
                        **neg_forward_args
                    )

                    # Classifier-free guidance, applied per modality.
                    pred_video_guided = pred_vid_neg[0] + video_guidance_scale * (pred_vid_pos[0] - pred_vid_neg[0])
                    pred_audio_guided = pred_audio_neg[0] + audio_guidance_scale * (pred_audio_pos[0] - pred_audio_neg[0])

                    video_noise = scheduler_video.step(
                        pred_video_guided.unsqueeze(0), t_v, video_noise.unsqueeze(0), return_dict=False
                    )[0].squeeze(0)

                    audio_noise = scheduler_audio.step(
                        pred_audio_guided.unsqueeze(0), t_a, audio_noise.unsqueeze(0), return_dict=False
                    )[0].squeeze(0)

            if self.cpu_offload:
                self.offload_to_cpu(self.model)

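            # Re-clamp the conditioning frame, then decode latents back to waveform audio
            # and pixel-space video with the respective VAEs.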
            if is_i2v:
                video_noise[:, :1] = latents_images

            audio_latents_for_vae = audio_noise.unsqueeze(0).transpose(1, 2)
            generated_audio = self.vae_model_audio.wrapped_decode(audio_latents_for_vae)
            generated_audio = generated_audio.squeeze().cpu().float().numpy()

            video_latents_for_vae = video_noise.unsqueeze(0)
            generated_video = self.vae_model_video.wrapped_decode(video_latents_for_vae)
            generated_video = generated_video.squeeze(0).cpu().float().numpy()

            return generated_video, generated_audio, image

        except Exception:
            logging.error(traceback.format_exc())
            return None

    def offload_to_cpu(self, model):
        model = model.cpu()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        return model

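    # Scheduler factory: returns (scheduler, timesteps) for the requested solver.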
    def get_scheduler_time_steps(self, sampling_steps, solver_name='unipc', device=0, shift=5.0):
        torch.manual_seed(4)

        if solver_name == 'unipc':
            sample_scheduler = FlowUniPCMultistepScheduler(
                num_train_timesteps=1000,
                shift=1,
                use_dynamic_shifting=False)
            sample_scheduler.set_timesteps(
                sampling_steps, device=device, shift=shift)
            timesteps = sample_scheduler.timesteps
        elif solver_name == 'dpm++':
            sample_scheduler = FlowDPMSolverMultistepScheduler(
                num_train_timesteps=1000,
                shift=1,
                use_dynamic_shifting=False)
            sampling_sigmas = get_sampling_sigmas(sampling_steps, shift=shift)
            timesteps, _ = retrieve_timesteps(
                sample_scheduler,
                device=device,
                sigmas=sampling_sigmas)
        elif solver_name == 'euler':
            sample_scheduler = FlowMatchEulerDiscreteScheduler(
                shift=shift
            )
            timesteps, sampling_steps = retrieve_timesteps(
                sample_scheduler,
                sampling_steps,
                device=device,
            )
        else:
            raise NotImplementedError("Unsupported solver.")

        return sample_scheduler, timesteps