Spaces:
Runtime error
Runtime error
feat: use_artifact if run existing
Browse files
src/dalle_mini/model/configuration.py
CHANGED
|
@@ -18,7 +18,7 @@ import warnings
|
|
| 18 |
from transformers.configuration_utils import PretrainedConfig
|
| 19 |
from transformers.utils import logging
|
| 20 |
|
| 21 |
-
from .
|
| 22 |
|
| 23 |
logger = logging.get_logger(__name__)
|
| 24 |
|
|
|
|
| 18 |
from transformers.configuration_utils import PretrainedConfig
|
| 19 |
from transformers.utils import logging
|
| 20 |
|
| 21 |
+
from .utils import PretrainedFromWandbMixin
|
| 22 |
|
| 23 |
logger = logging.get_logger(__name__)
|
| 24 |
|
src/dalle_mini/model/modeling.py
CHANGED
|
@@ -46,7 +46,7 @@ from transformers.models.bart.modeling_flax_bart import (
|
|
| 46 |
from transformers.utils import logging
|
| 47 |
|
| 48 |
from .configuration import DalleBartConfig
|
| 49 |
-
from .
|
| 50 |
|
| 51 |
logger = logging.get_logger(__name__)
|
| 52 |
|
|
|
|
| 46 |
from transformers.utils import logging
|
| 47 |
|
| 48 |
from .configuration import DalleBartConfig
|
| 49 |
+
from .utils import PretrainedFromWandbMixin
|
| 50 |
|
| 51 |
logger = logging.get_logger(__name__)
|
| 52 |
|
src/dalle_mini/model/tokenizer.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
from transformers import BartTokenizer
|
| 3 |
from transformers.utils import logging
|
| 4 |
|
| 5 |
-
from .
|
| 6 |
|
| 7 |
logger = logging.get_logger(__name__)
|
| 8 |
|
|
|
|
| 2 |
from transformers import BartTokenizer
|
| 3 |
from transformers.utils import logging
|
| 4 |
|
| 5 |
+
from .utils import PretrainedFromWandbMixin
|
| 6 |
|
| 7 |
logger = logging.get_logger(__name__)
|
| 8 |
|
src/dalle_mini/model/{wandb_pretrained.py → utils.py}
RENAMED
|
@@ -13,7 +13,10 @@ class PretrainedFromWandbMixin:
|
|
| 13 |
pretrained_model_name_or_path
|
| 14 |
):
|
| 15 |
# wandb artifact
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
| 17 |
pretrained_model_name_or_path = artifact.download()
|
| 18 |
|
| 19 |
return super(PretrainedFromWandbMixin, cls).from_pretrained(
|
|
|
|
| 13 |
pretrained_model_name_or_path
|
| 14 |
):
|
| 15 |
# wandb artifact
|
| 16 |
+
if wandb.run is not None:
|
| 17 |
+
artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
|
| 18 |
+
else:
|
| 19 |
+
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
|
| 20 |
pretrained_model_name_or_path = artifact.download()
|
| 21 |
|
| 22 |
return super(PretrainedFromWandbMixin, cls).from_pretrained(
|