Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Utility to export a training checkpoint to a lightweight release checkpoint. | |
| """ | |
| from pathlib import Path | |
| import typing as tp | |
| from omegaconf import OmegaConf | |
| import torch | |
| from audiocraft import __version__ | |
| def export_encodec(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): | |
| """Export only the best state from the given EnCodec checkpoint. This | |
| should be used if you trained your own EnCodec model. | |
| """ | |
| pkg = torch.load(checkpoint_path, 'cpu') | |
| new_pkg = { | |
| 'best_state': pkg['best_state']['model'], | |
| 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), | |
| 'version': __version__, | |
| 'exported': True, | |
| } | |
| Path(out_file).parent.mkdir(exist_ok=True, parents=True) | |
| torch.save(new_pkg, out_file) | |
| return out_file | |
| def export_pretrained_compression_model(pretrained_encodec: str, out_file: tp.Union[Path, str]): | |
| """Export a compression model (potentially EnCodec) from a pretrained model. | |
| This is required for packaging the audio tokenizer along a MusicGen or AudioGen model. | |
| Do not include the //pretrained/ prefix. For instance if you trained a model | |
| with `facebook/encodec_32khz`, just put that as a name. Same for `dac_44khz`. | |
| In that case, this will not actually include a copy of the model, simply the reference | |
| to the model used. | |
| """ | |
| if Path(pretrained_encodec).exists(): | |
| pkg = torch.load(pretrained_encodec) | |
| assert 'best_state' in pkg | |
| assert 'xp.cfg' in pkg | |
| assert 'version' in pkg | |
| assert 'exported' in pkg | |
| else: | |
| pkg = { | |
| 'pretrained': pretrained_encodec, | |
| 'exported': True, | |
| 'version': __version__, | |
| } | |
| Path(out_file).parent.mkdir(exist_ok=True, parents=True) | |
| torch.save(pkg, out_file) | |
| def export_lm(checkpoint_path: tp.Union[Path, str], out_file: tp.Union[Path, str]): | |
| """Export only the best state from the given MusicGen or AudioGen checkpoint. | |
| """ | |
| pkg = torch.load(checkpoint_path, 'cpu') | |
| if pkg['fsdp_best_state']: | |
| best_state = pkg['fsdp_best_state']['model'] | |
| else: | |
| assert pkg['best_state'] | |
| best_state = pkg['best_state']['model'] | |
| new_pkg = { | |
| 'best_state': best_state, | |
| 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']), | |
| 'version': __version__, | |
| 'exported': True, | |
| } | |
| Path(out_file).parent.mkdir(exist_ok=True, parents=True) | |
| torch.save(new_pkg, out_file) | |
| return out_file | |

