import copy
import logging

import torch
from omegaconf import OmegaConf

from utils.utils import instantiate_from_config

main_logger = logging.getLogger("main_logger")


def expand_conv_kernel(pretrained_dict):
    """Expand 2D conv weights from 4D to 5D by inserting a temporal axis of size 1."""
    for k, v in pretrained_dict.items():
        # the autoencoder (first_stage_model) stays 2D; only inflate UNet convs
        if v.dim() == 4 and not k.startswith("first_stage_model"):
            pretrained_dict[k] = v.unsqueeze(2)
    return pretrained_dict
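

# Illustrative sketch (not part of the original pipeline; the key name and
# tensor shapes below are assumptions): a 2D conv kernel of shape
# [out_ch, in_ch, kH, kW] becomes a pseudo-3D kernel [out_ch, in_ch, 1, kH, kW],
# so the inflated UNet reproduces the 2D model frame-by-frame at initialization.
#
#   w2d = torch.randn(320, 4, 3, 3)
#   inflated = expand_conv_kernel({"model.diffusion_model.conv_in.weight": w2d})
#   assert inflated["model.diffusion_model.conv_in.weight"].shape == (320, 4, 1, 3, 3)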


def print_state_dict(state_dict):
    print("====== Dumping State Dict ======")
    for k, v in state_dict.items():
        print(k, v.shape)


def load_from_pretrainedSD_checkpoint(
    model,
    pretrained_ckpt,
    expand_to_3d=True,
    adapt_keyname=False,
    echo_empty_params=False,
):
    sd_state_dict = torch.load(pretrained_ckpt, map_location="cpu")
    if "state_dict" in sd_state_dict:
        sd_state_dict = sd_state_dict["state_dict"]
    model_state_dict = model.state_dict()
    # drop EMA weights so the parameter counts below are precise
    for k in list(sd_state_dict.keys()):
        if k.startswith("model_ema"):
            del sd_state_dict[k]
    main_logger.info(
        f"Num of model params of Source: {len(sd_state_dict.keys())} VS. Target: {len(model_state_dict.keys())}"
    )
    # print_state_dict(model_state_dict)
    # print_state_dict(sd_state_dict)
    if adapt_keyname:
        # adapt to the standard 2D network: remap key names, since inserting
        # temporal-attention layers shifts the block indices
        mapping_dict = {
            "middle_block.2": "middle_block.3",
            "output_blocks.5.2": "output_blocks.5.3",
            "output_blocks.8.2": "output_blocks.8.3",
        }
        cnt = 0
        for k in list(sd_state_dict.keys()):
            for src_word, dst_word in mapping_dict.items():
                if src_word in k:
                    new_key = k.replace(src_word, dst_word)
                    sd_state_dict[new_key] = sd_state_dict[k]
                    del sd_state_dict[k]
                    cnt += 1
        main_logger.info(f"[renamed {cnt} Source keys to match Target model]")
    # keep only keys that exist in the target model
    pretrained_dict = {k: v for k, v in sd_state_dict.items() if k in model_state_dict}
    # record target parameters that receive no pretrained weights
    empty_paras = [k for k, v in model_state_dict.items() if k not in pretrained_dict]
    assert len(empty_paras) + len(pretrained_dict.keys()) == len(model_state_dict.keys())
    if expand_to_3d:
        # inflate 2D conv kernels for the 3D network
        pretrained_dict = expand_conv_kernel(pretrained_dict)

    # overwrite entries in the existing state dict
    model_state_dict.update(pretrained_dict)

    # load the new state dict
    try:
        model.load_state_dict(model_state_dict)
    except Exception:
        # on size mismatch, fall back to the model's original (randomly
        # initialized) parameters for the offending keys
        skipped = []
        model_dict_ori = model.state_dict()
        for n, p in model_state_dict.items():
            if p.shape != model_dict_ori[n].shape:
                model_state_dict[n] = model_dict_ori[n]
                main_logger.info(
                    f"Skip para: {n}, size={pretrained_dict[n].shape} in pretrained, {model_state_dict[n].shape} in current model"
                )
                skipped.append(n)
        main_logger.info(f"[INFO] Skipped {len(skipped)} parameters because of size mismatch!")
        model.load_state_dict(model_state_dict)
        empty_paras += skipped
    # count empty parameters in the UNet (model.diffusion_model) separately
    unet_empty_paras = [
        name for name in empty_paras if name.startswith("model.diffusion_model")
    ]
    main_logger.info(
        f"Pretrained parameters: {len(pretrained_dict.keys())} | Empty parameters: {len(empty_paras)} [Unet: {len(unet_empty_paras)}]"
    )
    if echo_empty_params:
        print("Printing empty parameters:")
        for k in empty_paras:
            print(k)
    return model, empty_paras
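

# Usage sketch (hedged; `video_model` is an illustrative variable and the
# checkpoint path mirrors the defaults used later in this file): inflate a 2D
# Stable Diffusion checkpoint into a 3D video model, leaving the new temporal
# layers untrained.
#
#   video_model, empty_paras = load_from_pretrainedSD_checkpoint(
#       video_model,
#       "models/ldm/text2img-large/model.ckpt",
#       expand_to_3d=True,
#       adapt_keyname=True,
#   )
#   # `empty_paras` names the parameters that kept their random initialization.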


# Below: written by Yingqing --------------------------------------------------------
def load_model_from_config(config, ckpt, verbose=False):
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        main_logger.info("missing keys:")
        main_logger.info(m)
    if len(u) > 0 and verbose:
        main_logger.info("unexpected keys:")
        main_logger.info(u)
    model.eval()
    return model


def init_and_load_ldm_model(config_path, ckpt_path, device=None):
    assert config_path.endswith(".yaml"), f"config_path = {config_path}"
    assert ckpt_path.endswith(".ckpt"), f"ckpt_path = {ckpt_path}"
    config = OmegaConf.load(config_path)
    model = load_model_from_config(config, ckpt_path)
    if device is not None:
        model = model.to(device)
    return model
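

# Usage sketch (hedged; the paths mirror the defaults of
# load_img_model_to_video_model below and are not guaranteed to exist locally):
#
#   ldm = init_and_load_ldm_model(
#       "configs/latent-diffusion/txt2img-1p4B-eval.yaml",
#       "models/ldm/text2img-large/model.ckpt",
#       device="cuda",
#   )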


def load_img_model_to_video_model(
    model,
    device=None,
    expand_to_3d=True,
    adapt_keyname=False,
    config_path="configs/latent-diffusion/txt2img-1p4B-eval.yaml",
    ckpt_path="models/ldm/text2img-large/model.ckpt",
):
    pretrained_ldm = init_and_load_ldm_model(config_path, ckpt_path, device)
    model, empty_paras = load_partial_weights(
        model,
        pretrained_ldm.state_dict(),
        expand_to_3d=expand_to_3d,
        adapt_keyname=adapt_keyname,
    )
    return model, empty_paras


def load_partial_weights(
    model, pretrained_dict, expand_to_3d=True, adapt_keyname=False
):
    model2 = copy.deepcopy(model)
    model_dict = model.state_dict()
    model_dict_ori = copy.deepcopy(model_dict)
    main_logger.info("[Load pretrained LDM weights]")
    main_logger.info(
        f"Num of parameters of source model: {len(pretrained_dict.keys())} VS. target model: {len(model_dict.keys())}"
    )

    if adapt_keyname:
        # adapt to Menghan's standard 2D network: remap key names, since
        # inserting temporal-attention layers shifts the block indices
        mapping_dict = {
            "middle_block.2": "middle_block.3",
            "output_blocks.5.2": "output_blocks.5.3",
            "output_blocks.8.2": "output_blocks.8.3",
        }
        cnt = 0
        newpretrained_dict = copy.deepcopy(pretrained_dict)
        for k, v in newpretrained_dict.items():
            for src_word, dst_word in mapping_dict.items():
                if src_word in k:
                    new_key = k.replace(src_word, dst_word)
                    pretrained_dict[new_key] = v
                    pretrained_dict.pop(k)
                    cnt += 1
        main_logger.info(f"--renamed {cnt} source keys to match target model.")

    # keep only keys that exist in the target model
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # record target parameters that receive no pretrained weights
    empty_paras = [k for k, v in model_dict.items() if k not in pretrained_dict]
    main_logger.info(
        f"Pretrained parameters: {len(pretrained_dict.keys())} | Empty parameters: {len(empty_paras)}"
    )
    # disable info
    # main_logger.info(f'Empty parameters: {empty_paras} ')
    assert len(empty_paras) + len(pretrained_dict.keys()) == len(model_dict.keys())

    if expand_to_3d:
        # inflate 2D conv kernels for Yingqing's 2D-inflation network
        pretrained_dict = expand_conv_kernel(pretrained_dict)

    # overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)

    # load the new state dict
    try:
        model2.load_state_dict(model_dict)
    except Exception:
        # on size mismatch, fall back to the model's original parameters
        skipped = []
        for n, p in model_dict.items():
            if p.shape != model_dict_ori[n].shape:
                model_dict[n] = model_dict_ori[n]
                main_logger.info(
                    f"Skip para: {n}, size={pretrained_dict[n].shape} in pretrained, {model_dict[n].shape} in current model"
                )
                skipped.append(n)
        main_logger.info(f"[INFO] Skipped {len(skipped)} parameters because of size mismatch!")
        model2.load_state_dict(model_dict)
        empty_paras += skipped
    main_logger.info(f"Empty parameters: {len(empty_paras)} ")
    main_logger.info("Finished.")
    return model2, empty_paras
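

# Example sketch (hedged; `video_model` is an illustrative variable, and the
# checkpoint path follows the defaults above):
#
#   sd = torch.load(
#       "models/ldm/text2img-large/model.ckpt", map_location="cpu"
#   )["state_dict"]
#   video_model, empty_paras = load_partial_weights(
#       video_model, sd, expand_to_3d=True, adapt_keyname=False
#   )
#   # The returned model is a deep copy; the input `video_model` is unchanged.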


def load_autoencoder(model, config_path=None, ckpt_path=None, device=None):
    if config_path is None:
        config_path = "configs/latent-diffusion/txt2img-1p4B-eval.yaml"
    if ckpt_path is None:
        ckpt_path = "models/ldm/text2img-large/model.ckpt"
    pretrained_ldm = init_and_load_ldm_model(config_path, ckpt_path, device)
    autoencoder_dict = {}
    for n, p in pretrained_ldm.state_dict().items():
        if n.startswith("first_stage_model"):
            autoencoder_dict[n] = p
    model_dict = model.state_dict()
    model_dict.update(autoencoder_dict)
    main_logger.info(f"Load [{len(autoencoder_dict)}] autoencoder parameters!")
    model.load_state_dict(model_dict)
    return model
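

# Usage sketch (hedged; `video_model` is illustrative): overwrite only the
# first_stage_model (autoencoder) weights in place, leaving the UNet untouched.
#
#   video_model = load_autoencoder(video_model, device="cuda")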