# Convert an original OmniGen checkpoint (https://huggingface.co/Shitao/OmniGen-v1)
# into the diffusers OmniGenPipeline format.
| import argparse | |
| import os | |
| import torch | |
| from huggingface_hub import snapshot_download | |
| from safetensors.torch import load_file | |
| from transformers import AutoTokenizer | |
| from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel | |
def main(args):
    """Convert an original OmniGen checkpoint into a diffusers ``OmniGenPipeline``.

    Loads the transformer weights from ``model.safetensors``, renames them to the
    diffusers layout, wraps them with the matching VAE/tokenizer/scheduler, and
    saves the assembled pipeline to ``args.dump_path``.
    """
    # checkpoint from https://huggingface.co/Shitao/OmniGen-v1
    # If the given path is not a local directory, treat it as a Hub repo id and download it.
    if not os.path.exists(args.origin_ckpt_path):
        print("Model not found, downloading...")
        cache_folder = os.getenv("HF_HUB_CACHE")
        args.origin_ckpt_path = snapshot_download(
            repo_id=args.origin_ckpt_path,
            cache_dir=cache_folder,
            ignore_patterns=["flax_model.msgpack", "rust_model.ot", "tf_model.h5", "model.pt"],
        )
        print(f"Downloaded model to {args.origin_ckpt_path}")

    state_dict = load_file(os.path.join(args.origin_ckpt_path, "model.safetensors"), device="cpu")

    # One-to-one key renames from the original naming scheme to the diffusers one.
    key_map = {
        "pos_embed": "patch_embedding.pos_embed",
        "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
        "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
        "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
        "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
        "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
        "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
        "final_layer.linear.weight": "proj_out.weight",
        "final_layer.linear.bias": "proj_out.bias",
        "time_token.mlp.0.weight": "time_token.linear_1.weight",
        "time_token.mlp.0.bias": "time_token.linear_1.bias",
        "time_token.mlp.2.weight": "time_token.linear_2.weight",
        "time_token.mlp.2.bias": "time_token.linear_2.bias",
        "t_embedder.mlp.0.weight": "t_embedder.linear_1.weight",
        "t_embedder.mlp.0.bias": "t_embedder.linear_1.bias",
        "t_embedder.mlp.2.weight": "t_embedder.linear_2.weight",
        "t_embedder.mlp.2.bias": "t_embedder.linear_2.bias",
        "llm.embed_tokens.weight": "embed_tokens.weight",
    }

    converted = {}
    for name, tensor in state_dict.items():
        if name in key_map:
            converted[key_map[name]] = tensor
        elif "qkv" in name:
            # Fused qkv projection -> separate q/k/v projections expected by diffusers.
            layer_idx = name.split(".")[2]
            q_weight, k_weight, v_weight = tensor.chunk(3)
            converted[f"layers.{layer_idx}.self_attn.to_q.weight"] = q_weight
            converted[f"layers.{layer_idx}.self_attn.to_k.weight"] = k_weight
            converted[f"layers.{layer_idx}.self_attn.to_v.weight"] = v_weight
        elif "o_proj" in name:
            layer_idx = name.split(".")[2]
            converted[f"layers.{layer_idx}.self_attn.to_out.0.weight"] = tensor
        else:
            # Remaining keys only need the leading "llm." prefix stripped.
            converted[name[4:]] = tensor

    transformer = OmniGenTransformer2DModel(
        rope_scaling={
            "long_factor": [
                1.0299999713897705,
                1.0499999523162842,
                1.0499999523162842,
                1.0799999237060547,
                1.2299998998641968,
                1.2299998998641968,
                1.2999999523162842,
                1.4499999284744263,
                1.5999999046325684,
                1.6499998569488525,
                1.8999998569488525,
                2.859999895095825,
                3.68999981880188,
                5.419999599456787,
                5.489999771118164,
                5.489999771118164,
                9.09000015258789,
                11.579999923706055,
                15.65999984741211,
                15.769999504089355,
                15.789999961853027,
                18.360000610351562,
                21.989999771118164,
                23.079999923706055,
                30.009998321533203,
                32.35000228881836,
                32.590003967285156,
                35.56000518798828,
                39.95000457763672,
                53.840003967285156,
                56.20000457763672,
                57.95000457763672,
                59.29000473022461,
                59.77000427246094,
                59.920005798339844,
                61.190006256103516,
                61.96000671386719,
                62.50000762939453,
                63.3700065612793,
                63.48000717163086,
                63.48000717163086,
                63.66000747680664,
                63.850006103515625,
                64.08000946044922,
                64.760009765625,
                64.80001068115234,
                64.81001281738281,
                64.81001281738281,
            ],
            "short_factor": [
                1.05,
                1.05,
                1.05,
                1.1,
                1.1,
                1.1,
                1.2500000000000002,
                1.2500000000000002,
                1.4000000000000004,
                1.4500000000000004,
                1.5500000000000005,
                1.8500000000000008,
                1.9000000000000008,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.000000000000001,
                2.1000000000000005,
                2.1000000000000005,
                2.2,
                2.3499999999999996,
                2.3499999999999996,
                2.3499999999999996,
                2.3499999999999996,
                2.3999999999999995,
                2.3999999999999995,
                2.6499999999999986,
                2.6999999999999984,
                2.8999999999999977,
                2.9499999999999975,
                3.049999999999997,
                3.049999999999997,
                3.049999999999997,
            ],
            "type": "su",
        },
        patch_size=2,
        in_channels=4,
        pos_embed_max_size=192,
    )
    # strict=True ensures every converted key matched the diffusers module exactly.
    transformer.load_state_dict(converted, strict=True)
    transformer.to(torch.bfloat16)

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)
    vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32)
    tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)

    pipeline = OmniGenPipeline(tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler)
    pipeline.save_pretrained(args.dump_path)
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument( | |
| "--origin_ckpt_path", | |
| default="Shitao/OmniGen-v1", | |
| type=str, | |
| required=False, | |
| help="Path to the checkpoint to convert.", | |
| ) | |
| parser.add_argument( | |
| "--dump_path", default="OmniGen-v1-diffusers", type=str, required=False, help="Path to the output pipeline." | |
| ) | |
| args = parser.parse_args() | |
| main(args) | |