| """ | |
| Util function for Model Wrapper | |
| --------------------------------------------------------------------- | |
| """ | |
| import glob | |
| import os | |
| import torch | |
| def load_cached_state_dict(model_folder_path): | |
| # Take the first model matching the pattern *model.bin. | |
| model_path_list = glob.glob(os.path.join(model_folder_path, "*model.bin")) | |
| if not model_path_list: | |
| raise FileNotFoundError( | |
| f"model.bin not found in model folder {model_folder_path}." | |
| ) | |
| model_path = model_path_list[0] | |
| state_dict = torch.load(model_path, map_location=torch.device("cpu")) | |
| return state_dict | |