import torch
import numpy as np
from safetensors import safe_open


def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """
    Sinusoidal embedding of the guidance scale values `w`.
    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            guidance scale values to embed, shape `(batch_size,)`
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
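
# Example (sketch, not part of this module): embed a batch of guidance scales so
# they can be fed to a guidance-distilled UNet. The embedding_dim of 256 below is
# a hypothetical value; in practice it should match the UNet's `time_cond_proj_dim`.
#
#   w = torch.full((4,), 8.0)
#   w_embedding = guidance_scale_embedding(w, embedding_dim=256)
#   assert w_embedding.shape == (4, 256)
#   # noise_pred = unet(latents, t, timestep_cond=w_embedding, ...).sample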


def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
    return x[(...,) + (None,) * dims_to_append]
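
# Example (sketch): broadcast per-sample scalars against a batch of latents.
#
#   x = torch.randn(4)                  # shape (4,)
#   x4d = append_dims(x, 4)             # shape (4, 1, 1, 1)
#   latents = torch.randn(4, 3, 8, 8)
#   scaled = x4d * latents              # broadcasts over the trailing dims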


# From LCMScheduler.get_scalings_for_boundary_condition_discrete
def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
    # The original snippet hard-coded `timestep / 0.1`, which equals
    # `timestep * timestep_scaling` at the default of 10.0; use the parameter
    # so that a non-default scaling actually takes effect.
    scaled_timestep = timestep * timestep_scaling
    c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
    c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
    return c_skip, c_out
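
# Sanity check (sketch): at timestep 0 the consistency-model boundary condition
# holds exactly, i.e. c_skip == 1 and c_out == 0, so the parameterized model
# reduces to the identity at the data end of the trajectory.
#
#   c_skip, c_out = scalings_for_boundary_conditions(torch.tensor(0.0))
#   assert torch.isclose(c_skip, torch.tensor(1.0))
#   assert torch.isclose(c_out, torch.tensor(0.0))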


# Compare LCMScheduler.step, Step 4
def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
    if prediction_type == "epsilon":
        sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
        alphas = extract_into_tensor(alphas, timesteps, sample.shape)
        pred_x_0 = (sample - sigmas * model_output) / alphas
    elif prediction_type == "v_prediction":
        sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
        alphas = extract_into_tensor(alphas, timesteps, sample.shape)
        pred_x_0 = alphas * sample - sigmas * model_output
    else:
        raise ValueError(
            f"Prediction type {prediction_type} currently not supported.")
    return pred_x_0
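
# Example (sketch): the `alphas`/`sigmas` lookup tables are assumed to be derived
# from a scheduler's alphas_cumprod, as in LCM training scripts; the linspace
# schedule below is only a stand-in for illustration.
#
#   alphas_cumprod = torch.linspace(0.999, 0.001, 1000)
#   alphas = alphas_cumprod.sqrt()
#   sigmas = (1.0 - alphas_cumprod).sqrt()
#   sample = torch.randn(2, 4, 8, 8)
#   noise_pred = torch.randn_like(sample)
#   t = torch.tensor([10, 500])
#   x0 = predicted_origin(noise_pred, t, sample, "epsilon", alphas, sigmas)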


def scale_for_loss(timesteps, sample, prediction_type, alphas, sigmas):
    if prediction_type == "epsilon":
        sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
        alphas = extract_into_tensor(alphas, timesteps, sample.shape)
        sample = sample * alphas / sigmas
    else:
        raise ValueError(
            f"Prediction type {prediction_type} currently not supported.")
    return sample


def extract_into_tensor(a, t, x_shape):
    """Gather values from `a` at indices `t` and reshape them to broadcast against `x_shape`."""
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


class DDIMSolver:
    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        # DDIM sampling parameters
        step_ratio = timesteps // ddim_timesteps
        self.ddim_timesteps = (
            np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
        # self.ddim_timesteps = (torch.linspace(100**2, 1000**2, 30)**0.5).round().numpy().astype(np.int64) - 1
        self.ddim_timesteps_prev = np.asarray(
            [0] + self.ddim_timesteps[:-1].tolist())
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist())
        # convert to torch tensors
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.ddim_timesteps_prev = torch.from_numpy(
            self.ddim_timesteps_prev).long()
        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(
            self.ddim_alpha_cumprods_prev)

    def to(self, device):
        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        alpha_cumprod_prev = extract_into_tensor(
            self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev
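
# Example (sketch): build the solver from a scheduler's alphas_cumprod and take one
# deterministic DDIM step from a predicted x0 / noise pair. Note that `ddim_step`
# takes an index into `ddim_timesteps`, not a raw timestep; the linspace schedule
# below is only a stand-in for illustration.
#
#   alphas_cumprod = torch.linspace(0.999, 0.001, 1000)
#   solver = DDIMSolver(alphas_cumprod.numpy(), timesteps=1000, ddim_timesteps=50)
#   pred_x0 = torch.randn(2, 4, 8, 8)
#   pred_noise = torch.randn_like(pred_x0)
#   index = torch.tensor([3, 3])
#   x_prev = solver.ddim_step(pred_x0, pred_noise, index)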


def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)
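
# Example (sketch): maintain an EMA ("target") copy of a student model, as is done
# for the target network in consistency-distillation training loops.
#
#   student = torch.nn.Linear(8, 8)
#   target = torch.nn.Linear(8, 8)
#   target.load_state_dict(student.state_dict())
#   # ... after each optimizer step on `student`:
#   update_ema(target.parameters(), student.parameters(), rate=0.99)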


def convert_lcm_lora(unet, path, alpha=1.0):
    """Merge LCM-LoRA weights loaded from `path` directly into `unet` in place."""
    if path.endswith(("ckpt",)):
        state_dict = torch.load(path, map_location="cpu")
    else:
        state_dict = {}
        with safe_open(path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)

    num_alpha = 0  # counted for reference only; not used by the merge below
    for key in state_dict.keys():
        if "alpha" in key:
            num_alpha += 1
    lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
    # Remap kohya-style "lora_unet_*" keys to diffusers-style module paths.
    updated_state_dict = {}
    for key in lora_keys:
        lora_name = key.split(".")[0]
        if lora_name.startswith("lora_unet_"):
            diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
            if "input.blocks" in diffusers_name:
                diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
            else:
                diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
            if "middle.block" in diffusers_name:
                diffusers_name = diffusers_name.replace("middle.block", "mid_block")
            else:
                diffusers_name = diffusers_name.replace("mid.block", "mid_block")
            if "output.blocks" in diffusers_name:
                diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
            else:
                diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
            diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
            diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
            diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
            diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
            diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
            diffusers_name = diffusers_name.replace("proj.in", "proj_in")
            diffusers_name = diffusers_name.replace("proj.out", "proj_out")
            diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
            diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
            updated_state_dict[diffusers_name] = state_dict[key]
            up_diffusers_name = diffusers_name.replace(".down.", ".up.")
            up_key = key.replace("lora_down.weight", "lora_up.weight")
            updated_state_dict[up_diffusers_name] = state_dict[up_key]
    state_dict = updated_state_dict
    # Merge each LoRA pair (W += scale * alpha * up @ down) into the base weights.
    num_lora = 0
    for key in state_dict:
        if "up." in key:
            continue
        up_key = key.replace(".down.", ".up.")
        # Strip the LoRA-specific parts of the key to recover the module path in the UNet.
        model_key = key.replace("processor.", "").replace("_lora", "").replace(
            "down.", "").replace("up.", "").replace(".lora", "")
        model_key = model_key.replace("to_out.", "to_out.0.")
        layer_infos = model_key.split(".")[:-1]

        curr_layer = unet
        while len(layer_infos) > 0:
            temp_name = layer_infos.pop(0)
            curr_layer = curr_layer.__getattr__(temp_name)

        weight_down = state_dict[key].to(
            curr_layer.weight.data.device, curr_layer.weight.data.dtype)
        weight_up = state_dict[up_key].to(
            curr_layer.weight.data.device, curr_layer.weight.data.dtype)
        # The fixed 1/8 factor plays the role of the LoRA network_alpha / rank scaling
        # (assumed constant across layers here rather than read from the "alpha" entries).
        if weight_up.ndim == 2:
            curr_layer.weight.data += 1 / 8 * alpha * torch.mm(weight_up, weight_down)
        else:
            assert weight_up.ndim == 4
            curr_layer.weight.data += 1 / 8 * alpha * torch.mm(
                weight_up.flatten(start_dim=1),
                weight_down.flatten(start_dim=1)).reshape(curr_layer.weight.data.shape)
        num_lora += 1
    return unet
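
# Example (sketch): fuse an LCM-LoRA checkpoint into a Stable Diffusion UNet.
# The repo id and filename below are illustrative only; any safetensors LCM-LoRA
# file with kohya-style "lora_unet_*" keys should work.
#
#   from diffusers import UNet2DConditionModel
#   unet = UNet2DConditionModel.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", subfolder="unet")
#   unet = convert_lcm_lora(unet, "lcm_lora_weights.safetensors", alpha=1.0)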