import sys
import os
import argparse
import time
import json
import gc

import torch
import torchaudio
import numpy as np
from omegaconf import OmegaConf
from safetensors.torch import save_file

from codeclm.models import builders
from codeclm.models import CodecLM
from codeclm.trainer.codec_song_pl import CodecLM_PL
from third_party.demucs.models.pretrained import get_model_from_yaml

# Load the model configuration and switch it to inference mode.
cfg_path = "/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-LeVo/ckpt/songgeneration_base/config.yaml"
cfg = OmegaConf.load(cfg_path)
cfg.mode = 'inference'

# Export the audio tokenizer encoder weights to fp16 safetensors.
# audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg)
# model = audio_tokenizer.model.model
# weights = {k: v.half() for k, v in model.state_dict().items() if isinstance(v, torch.Tensor) and v.numel() > 0}
# save_file(weights, '/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-LeVo/ckpt/encoder_fp16.safetensors')
# print(weights)

# Export the separate (vocal/BGM) tokenizer: keep the RVQ embedding and BEST-RQ
# weights in fp16, leave everything else in its original precision.
# seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg)
# model = seperate_tokenizer.model.model
# weights = {}
# for k, v in model.state_dict().items():
#     if k.startswith("rvq_bestrq_bgm_emb") or k.startswith("rvq_bestrq_emb") or k.startswith("bestrq"):
#         weights[k] = v.half()
#     else:
#         weights[k] = v
# # weights = {k: v.half() for k, v in model.state_dict().items() if isinstance(v, torch.Tensor) and v.numel() > 0}
# save_file(weights, '/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-LeVo/ckpt/encoder_fp16.safetensors')
# print(weights.keys())

# Convert the audio LM checkpoint to fp16 and save it back as model.pt.
# Note: this assumes every value in the checkpoint is a tensor (a raw state_dict).
ckpt_path = "/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-WX/ckpt/songgeneration_new_small/model_32.pt"
# audiolm = builders.get_lm_model(cfg)
checkpoint = torch.load(ckpt_path, map_location='cpu')
audiolm_state_dict = {k: v.half() for k, v in checkpoint.items()}
torch.save(audiolm_state_dict, "/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-WX/ckpt/songgeneration_new_small/model.pt")
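
# Optional sanity check (a minimal sketch, not part of the original script):
# reload the converted checkpoint and confirm every tensor is fp16. The path
# below assumes the same output location used above.
# fp16_ckpt = torch.load("/apdcephfs_cq11/share_300883980/tanwei/SongGeneration-WX/ckpt/songgeneration_new_small/model.pt", map_location='cpu')
# assert all(v.dtype == torch.float16 for v in fp16_ckpt.values() if isinstance(v, torch.Tensor))
# print(f"checked {len(fp16_ckpt)} entries, all tensors are fp16")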