Spaces:
Sleeping
Sleeping
import gzip
import os
import shutil
import zlib

# Name of the sub-directory (next to this module) that holds the model weights.
model_path = 'final_models'


def _download_compressed_weights(name, compressed_weights_path):
    """Download the gzipped weights for ``name`` to ``compressed_weights_path``.

    The payload is streamed into a ``.part`` file and only renamed to its final
    name after the transfer finished with a successful HTTP status, so an
    error page or an interrupted transfer is never mistaken for a valid
    archive. Network/HTTP failures are reported on stdout instead of raised;
    the caller detects failure by the archive still being absent.
    """
    print("Downloading", os.path.abspath(compressed_weights_path))
    # Imported here so `requests` is only required when a download is needed.
    import requests

    url = f'https://github.com/automl/PFNs4BO/raw/main/pfns4bo/final_models/{name + ".gz"}'
    partial_path = compressed_weights_path + '.part'
    os.makedirs(os.path.dirname(compressed_weights_path), exist_ok=True)
    try:
        with requests.get(url, allow_redirects=True, stream=True, timeout=60) as r:
            r.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial_path, compressed_weights_path)
    except requests.RequestException as err:
        print("Download failed:", err)
    finally:
        # Only present if the transfer did not complete.
        if os.path.exists(partial_path):
            os.remove(partial_path)


def _decompress_weights(compressed_weights_path, weights_path):
    """Gunzip ``compressed_weights_path`` into ``weights_path``, keeping the archive.

    Uses the standard library instead of shelling out to ``gzip``, so it works
    without an external binary and with any characters in the path. Output
    goes through a ``.part`` file so a corrupt archive never leaves a
    truncated weights file behind. Failures are reported on stdout, not raised.
    """
    partial_path = weights_path + '.part'
    try:
        with gzip.open(compressed_weights_path, 'rb') as src, open(partial_path, 'wb') as dst:
            shutil.copyfileobj(src, dst)
        os.replace(partial_path, weights_path)
    except (OSError, EOFError, zlib.error) as err:
        print("Failed to unzip", compressed_weights_path, "-", err)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)


def prepare_models():
    """Make sure every known model weights file exists next to this module.

    For each model file name, looks for ``<module dir>/<model_path>/<name>``.
    If it is missing, the ``<name>.gz`` archive in the same directory is
    decompressed, after first being downloaded from the PFNs4BO GitHub
    repository when the archive is missing too. Progress and failures are
    printed; nothing is returned.
    """
    pfns4bo_dir = os.path.dirname(__file__)
    model_names = ['hebo_morebudget_9_unused_features_3_userpriorperdim2_8.pt',
                   'model_sampled_warp_simple_mlp_for_hpob_46.pt',
                   'model_hebo_morebudget_9_unused_features_3.pt',]

    for name in model_names:
        weights_path = os.path.join(pfns4bo_dir, model_path, name)
        compressed_weights_path = os.path.join(pfns4bo_dir, model_path, name + '.gz')
        if not os.path.exists(weights_path):
            if not os.path.exists(compressed_weights_path):
                _download_compressed_weights(name, compressed_weights_path)
            if os.path.exists(compressed_weights_path):
                print("Unzipping", name)
                _decompress_weights(compressed_weights_path, weights_path)
            else:
                print("Failed to find", compressed_weights_path)
                print("Make sure you have an internet connection to download the model automatically..")
        if os.path.exists(weights_path):
            print("Successfully located model at", weights_path)
# Maps each public model attribute name to the path of its weights file,
# located in the `model_path` directory next to this module.
model_dict = {
    attr_name: os.path.join(os.path.dirname(__file__), model_path, file_name)
    for attr_name, file_name in (
        ('hebo_plus_userprior_model', 'hebo_morebudget_9_unused_features_3_userpriorperdim2_8.pt'),
        ('hebo_plus_model', 'model_hebo_morebudget_9_unused_features_3.pt'),
        ('bnn_model', 'model_sampled_warp_simple_mlp_for_hpob_46.pt'),
    )
}
def __getattr__(name):
    """Resolve model attributes of this module on first access.

    For a key of ``model_dict`` the stored weights path is returned; when no
    file exists at that path yet, ``prepare_models()`` is run first to
    unzip/download the weights. Any other name raises ``AttributeError``.
    """
    if name not in model_dict:
        raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

    requested_path = model_dict[name]
    if not os.path.exists(requested_path):
        print("Can't find", os.path.abspath(requested_path), "thus unzipping/downloading models now.")
        print("This might take a while..")
        prepare_models()
    # The path is returned whether or not preparation succeeded.
    return requested_path