Spaces:
Runtime error
Runtime error
| from typing import Any, Callable, Dict, Optional, List | |
| import torch | |
| import torch.nn as nn | |
| from .gaussian_diffusion import GaussianDiffusion | |
| from .k_diffusion import karras_sample, karras_sample_addition_condition | |
# Default Karras-sampler settings.
# NOTE(review): these constants are not referenced in this module's visible
# code; presumably callers pass them as the karras_steps / sigma_min /
# sigma_max / s_churn arguments of the sampling functions here — confirm.
DEFAULT_KARRAS_STEPS: int = 64
DEFAULT_KARRAS_SIGMA_MIN: float = 1e-3
DEFAULT_KARRAS_SIGMA_MAX: float = 160  # stored as the int 160, as before
DEFAULT_KARRAS_S_CHURN: float = 0.0
def uncond_guide_model(
    model: Callable[..., torch.Tensor], scale: float
) -> Callable[..., torch.Tensor]:
    """Wrap ``model`` so every call applies classifier-free-style guidance.

    The returned callable keeps only the first half of the incoming batch
    ``x_t`` and concatenates that half with itself. It then runs ``model``
    once on the result, forwarding ``ts`` and any keyword arguments unchanged.
    The first half of the model output is blended with the second half as
    ``second + scale * (first - second)``. The blended tensor is concatenated
    with itself before being returned.

    NOTE(review): the kwargs are forwarded as-is, so presumably the caller has
    already stacked conditional and unconditional inputs along dim 0 — confirm.

    Args:
        model: callable invoked as ``model(x, ts, **kwargs)``.
        scale: guidance weight; at ``1.0`` the blend mathematically reduces
            to the first half of the model output.

    Returns:
        A callable with the signature ``(x_t, ts, **kwargs)``.
    """

    def model_fn(x_t, ts, **kwargs):
        n_half = len(x_t) // 2
        base = x_t[:n_half]
        doubled = torch.cat((base, base), dim=0)

        prediction = model(doubled, ts, **kwargs)
        cond_pred, uncond_pred = torch.chunk(prediction, 2, dim=0)

        guided = uncond_pred + scale * (cond_pred - uncond_pred)
        return torch.cat((guided, guided), dim=0)

    return model_fn
def sample_latents(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: GaussianDiffusion,
    model_kwargs: Dict[str, Any],
    guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
    initial_noise: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Draw latent samples from ``model``, optionally with guidance.

    Args:
        batch_size: number of latents requested. The sample shape is
            ``(batch_size, model.d_latent)``.
        model: denoiser. It must expose ``d_latent``. When ``device`` is None
            it must also have at least one parameter, whose device is used.
            If it defines ``cached_model_kwargs``, that hook preprocesses
            ``model_kwargs`` first.
        diffusion: passed to ``karras_sample`` on the Karras path, or used
            for ``p_sample_loop`` otherwise.
        model_kwargs: conditioning values forwarded to the sampler. This
            function builds a new dict rather than writing into the
            caller's one.
        guidance_scale: guidance weight. Both 0.0 and 1.0 mean "no guidance".
            Any other value appends an all-zeros copy of every kwarg along
            dim 0 and enables guided sampling.
        clip_denoised, progress: forwarded to the chosen sampler.
        use_fp16: enables ``torch.autocast`` around the sampling call.
        use_karras: use ``karras_sample`` (True) or
            ``diffusion.p_sample_loop`` (False).
        karras_steps, sigma_min, sigma_max, s_churn, initial_noise:
            forwarded to ``karras_sample`` only; ignored on the other path.
        device: defaults to the device of the model's first parameter.

    Returns:
        The final samples. On the Karras path the per-step sequence returned
        by ``karras_sample`` is discarded.
        NOTE(review): on the non-Karras guided path ``p_sample_loop`` runs on
        ``2 * batch_size`` rows, so presumably the result has that many rows
        (the guided half duplicated); callers may need to slice — confirm.
    """
    if device is None:
        device = next(model.parameters()).device

    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)

    # A single predicate for both sampler paths. Previously the non-Karras
    # path also wrapped the model at scale 0.0 without doubling the kwargs,
    # which fed the model a doubled batch against single-batch kwargs.
    use_guidance = guidance_scale != 1.0 and guidance_scale != 0.0

    if use_guidance:
        # Build a fresh dict so the caller's mapping is never mutated.
        # In-place writes would make repeated calls keep growing the batch.
        model_kwargs = {
            key: torch.cat([value, torch.zeros_like(value)], dim=0)
            for key, value in model_kwargs.items()
        }

    sample_shape = (batch_size, model.d_latent)
    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            # karras_sample returns (samples, sequence); only samples are kept.
            samples, _ = karras_sample(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                guidance_scale=guidance_scale,
                progress=progress,
                initial_noise=initial_noise,
            )
        else:
            internal_batch_size = batch_size
            if use_guidance:
                model = uncond_guide_model(model, guidance_scale)
                # The wrapped model expects the batch as two stacked halves.
                internal_batch_size *= 2
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )
    return samples
def sample_latents_with_additional_latent(
    *,
    batch_size: int,
    model: nn.Module,
    diffusion: GaussianDiffusion,
    model_kwargs: Dict[str, Any],
    text_guidance_scale: float,
    image_guidance_scale: float,
    clip_denoised: bool,
    use_fp16: bool,
    use_karras: bool,
    karras_steps: int,
    sigma_min: float,
    sigma_max: float,
    s_churn: float,
    device: Optional[torch.device] = None,
    progress: bool = False,
    condition_latent: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Sample latents conditioned on ``model_kwargs`` plus an extra latent.

    When either guidance scale is active (not 0.0 and not 1.0), the batch is
    laid out as three stacked copies along dim 0:

    * ``model_kwargs`` values: ``[v, zeros, zeros]``
    * ``condition_latent``: ``[c, c, zeros]`` (only if it is not None)

    NOTE(review): presumably these are the (text + latent, latent-only,
    unconditional) branches consumed by ``karras_sample_addition_condition``
    — confirm against that function.

    Args:
        batch_size: number of latents requested. The sample shape is
            ``(batch_size, model.d_latent)``.
        model: denoiser. It must expose ``d_latent``. When ``device`` is None
            it must also have at least one parameter, whose device is used.
            If it defines ``cached_model_kwargs``, that hook preprocesses
            ``model_kwargs`` first.
        diffusion: passed to the Karras sampler or used for ``p_sample_loop``.
        model_kwargs: conditioning values. This function builds a new dict
            rather than writing into the caller's one.
        text_guidance_scale, image_guidance_scale: guidance weights. The
            Karras sampler receives both; the non-Karras path only uses
            ``text_guidance_scale``.
        clip_denoised, progress: forwarded to the chosen sampler.
        use_fp16: enables ``torch.autocast`` around the sampling call.
        use_karras: use ``karras_sample_addition_condition`` (True) or
            ``diffusion.p_sample_loop`` (False).
        karras_steps, sigma_min, sigma_max, s_churn: Karras path only.
        device: defaults to the device of the model's first parameter.
        condition_latent: extra conditioning latent. It is forwarded to the
            Karras sampler only and may be None.

    Returns:
        The final samples. On the Karras path the per-step sequence is
        discarded.
    """
    if device is None:
        device = next(model.parameters()).device

    if hasattr(model, "cached_model_kwargs"):
        model_kwargs = model.cached_model_kwargs(batch_size, model_kwargs)

    text_guided = text_guidance_scale != 1.0 and text_guidance_scale != 0.0
    image_guided = image_guidance_scale != 1.0 and image_guidance_scale != 0.0

    if text_guided or image_guided:
        # Fresh dict: never mutate the caller's model_kwargs in place.
        tripled_kwargs = {}
        for key, value in model_kwargs.items():
            zeros = torch.zeros_like(value)
            tripled_kwargs[key] = torch.cat([value, zeros, zeros], dim=0)
        model_kwargs = tripled_kwargs

        # condition_latent defaults to None; torch.zeros_like(None) would
        # raise here even though the non-Karras path never uses it.
        if condition_latent is not None:
            condition_latent = torch.cat(
                [condition_latent, condition_latent, torch.zeros_like(condition_latent)],
                dim=0,
            )

    sample_shape = (batch_size, model.d_latent)
    with torch.autocast(device_type=device.type, enabled=use_fp16):
        if use_karras:
            # The sampler returns (samples, sequence); only samples are kept.
            samples, _ = karras_sample_addition_condition(
                diffusion=diffusion,
                model=model,
                shape=sample_shape,
                steps=karras_steps,
                clip_denoised=clip_denoised,
                model_kwargs=model_kwargs,
                device=device,
                sigma_min=sigma_min,
                sigma_max=sigma_max,
                s_churn=s_churn,
                text_guidance_scale=text_guidance_scale,
                image_guidance_scale=image_guidance_scale,
                progress=progress,
                condition_latent=condition_latent,
            )
        else:
            # NOTE(review): this path looks inconsistent with the layout built
            # above. model_kwargs may have been tripled, while
            # uncond_guide_model splits its model output into two halves.
            # image_guidance_scale is ignored, and condition_latent is never
            # forwarded. Behaviour is kept as-is; confirm before relying on it.
            internal_batch_size = batch_size
            if text_guidance_scale != 1.0:
                model = uncond_guide_model(model, text_guidance_scale)
                internal_batch_size *= 2
            samples = diffusion.p_sample_loop(
                model,
                shape=(internal_batch_size, *sample_shape[1:]),
                model_kwargs=model_kwargs,
                device=device,
                clip_denoised=clip_denoised,
                progress=progress,
            )
    return samples