Spaces:
Runtime error
Runtime error
| import os | |
| import tempfile | |
| from pathlib import Path | |
| import wandb | |
class PretrainedFromWandbMixin:
    """Mixin that lets ``from_pretrained`` load from a wandb artifact reference.

    List it *before* a base class that provides a ``from_pretrained``
    classmethod, e.g. ``class MyModel(PretrainedFromWandbMixin, BaseModel)``.
    Calls are intercepted here and then delegated to the next class in the MRO.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Initializes from a wandb artifact or delegates loading to the superclass.

        A ``str`` argument containing ``":"`` that is not an existing local
        directory is treated as a wandb artifact reference: the artifact is
        downloaded into a temporary directory and the superclass loads from
        there. Every other value (local directory, hub-style name, path-like
        object, ...) is forwarded to the superclass unchanged.

        Args:
            pretrained_model_name_or_path: wandb artifact reference, or
                anything the superclass ``from_pretrained`` accepts.
            *model_args: Forwarded positionally to the superclass loader.
            **kwargs: Forwarded as keyword arguments to the superclass loader.

        Returns:
            Whatever the superclass ``from_pretrained`` returns.
        """
        parent = super(PretrainedFromWandbMixin, cls)

        # Only strings can be artifact references; checking ``":" in`` on a
        # path-like object would raise TypeError, so such values are simply
        # delegated to the parent loader.
        is_artifact = (
            isinstance(pretrained_model_name_or_path, str)
            and ":" in pretrained_model_name_or_path
            and not os.path.isdir(pretrained_model_name_or_path)
        )
        if not is_artifact:
            return parent.from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )

        # Download into a throwaway directory to avoid accumulating artifact
        # copies. It is deleted when this block exits, so the parent loader
        # must read everything it needs before returning.
        with tempfile.TemporaryDirectory() as tmp_dir:
            if wandb.run is not None:
                # Go through the active run so it records the artifact it used.
                artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
            else:
                artifact = wandb.Api().artifact(pretrained_model_name_or_path)
            local_dir = artifact.download(tmp_dir)
            return parent.from_pretrained(local_dir, *model_args, **kwargs)