Spaces:
Running
on
Zero
Running
on
Zero
| import torch | |
| from safetensors.torch import load_file as safetensors_load_file | |
| from safetensors.torch import save_file as safetensors_save_file | |
| import sys, argparse | |
| def load_ckpt(ckpt_filepath): | |
| print(f"Loading model from {ckpt_filepath}") | |
| if ckpt_filepath.endswith(".safetensors"): | |
| state_dict = safetensors_load_file(ckpt_filepath, device="cpu") | |
| ckpt = None | |
| else: | |
| ckpt = torch.load(ckpt_filepath, map_location="cpu") | |
| state_dict = ckpt["state_dict"] | |
| return ckpt, state_dict | |
| def save_ckpt(ckpt, ckpt_state_dict, ckpt_filepath): | |
| if ckpt_filepath.endswith(".safetensors"): | |
| safetensors_save_file(ckpt_state_dict, ckpt_filepath) | |
| else: | |
| if ckpt is not None: | |
| torch.save(ckpt, ckpt_filepath) | |
| else: | |
| torch.save(ckpt_state_dict, ckpt_filepath) | |
| print(f"Saved to {ckpt_filepath}") | |
| def load_two_models(base_ckpt_filepath, te_ckpt_filepath): | |
| base_ckpt, base_state_dict = load_ckpt(base_ckpt_filepath) | |
| _, te_state_dict = load_ckpt(te_ckpt_filepath) | |
| # Other fields in sd_ckpt are also needed when saving the checkpoint. | |
| # So return the whole sd_ckpt. | |
| return base_ckpt, base_state_dict, te_state_dict | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--base_ckpt", type=str, required=True, help="Path to the base checkpoint") | |
| parser.add_argument("--te_ckpt", type=str, required=True, help="Path to the checkpoint providing text encoder") | |
| parser.add_argument("--out_ckpt", type=str, required=True, help="Path to the output checkpoint") | |
| args = parser.parse_args() | |
| base_ckpt, base_state_dict, te_state_dict = load_two_models(args.base_ckpt, args.te_ckpt) | |
| # base_state_dict = sd_ckpt["state_dict"] | |
| repl_count = 0 | |
| for k in base_state_dict: | |
| if k.startswith("cond_stage_model."): | |
| if k not in te_state_dict: | |
| print(f"!!!! '{k}' not in TE checkpoint") | |
| continue | |
| if base_state_dict[k].shape != te_state_dict[k].shape: | |
| print(f"!!!! '{k}' shape mismatch: {base_state_dict[k].shape} vs {te_state_dict[k].shape} !!!!") | |
| continue | |
| print(k) | |
| base_state_dict[k] = te_state_dict[k] | |
| repl_count += 1 | |
| if repl_count > 0: | |
| print(f"{repl_count} parameters replaced") | |
| save_ckpt(base_ckpt, base_state_dict, args.out_ckpt) | |
| else: | |
| print("ERROR: No parameter replaced") | |