Spaces:
Runtime error
Runtime error
feat: load from bucket
Browse files- src/dalle_mini/model/utils.py +26 -3
- tools/train/train.py +22 -9
src/dalle_mini/model/utils.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import os
|
| 2 |
import tempfile
|
|
|
|
| 3 |
|
| 4 |
import wandb
|
| 5 |
|
|
@@ -8,11 +9,13 @@ class PretrainedFromWandbMixin:
|
|
| 8 |
@classmethod
|
| 9 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 10 |
"""
|
| 11 |
-
Initializes from a wandb artifact, or delegates loading to the superclass.
|
| 12 |
"""
|
| 13 |
with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
|
| 14 |
-
if
|
| 15 |
-
pretrained_model_name_or_path
|
|
|
|
|
|
|
| 16 |
):
|
| 17 |
# wandb artifact
|
| 18 |
if wandb.run is not None:
|
|
@@ -20,7 +23,27 @@ class PretrainedFromWandbMixin:
|
|
| 20 |
else:
|
| 21 |
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
|
| 22 |
pretrained_model_name_or_path = artifact.download(tmp_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
return super(PretrainedFromWandbMixin, cls).from_pretrained(
|
| 25 |
pretrained_model_name_or_path, *model_args, **kwargs
|
| 26 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import tempfile
|
| 3 |
+
from pathlib import Path
|
| 4 |
|
| 5 |
import wandb
|
| 6 |
|
|
|
|
| 9 |
@classmethod
|
| 10 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 11 |
"""
|
| 12 |
+
Initializes from a wandb artifact, google bucket path or delegates loading to the superclass.
|
| 13 |
"""
|
| 14 |
with tempfile.TemporaryDirectory() as tmp_dir: # avoid multiple artifact copies
|
| 15 |
+
if (
|
| 16 |
+
":" in pretrained_model_name_or_path
|
| 17 |
+
and not os.path.isdir(pretrained_model_name_or_path)
|
| 18 |
+
and not pretrained_model_name_or_path.startswith("gs")
|
| 19 |
):
|
| 20 |
# wandb artifact
|
| 21 |
if wandb.run is not None:
|
|
|
|
| 23 |
else:
|
| 24 |
artifact = wandb.Api().artifact(pretrained_model_name_or_path)
|
| 25 |
pretrained_model_name_or_path = artifact.download(tmp_dir)
|
| 26 |
+
if artifact.metadata.get("bucket_path"):
|
| 27 |
+
pretrained_model_name_or_path = artifact.metadata["bucket_path"]
|
| 28 |
+
|
| 29 |
+
if pretrained_model_name_or_path.startswith("gs://"):
|
| 30 |
+
copy_blobs(pretrained_model_name_or_path, tmp_dir)
|
| 31 |
+
pretrained_model_name_or_path = tmp_dir
|
| 32 |
|
| 33 |
return super(PretrainedFromWandbMixin, cls).from_pretrained(
|
| 34 |
pretrained_model_name_or_path, *model_args, **kwargs
|
| 35 |
)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def copy_blobs(source_path, dest_path):
|
| 39 |
+
assert source_path.startswith("gs://")
|
| 40 |
+
from google.cloud import storage
|
| 41 |
+
|
| 42 |
+
bucket_path = Path(source_path[5:])
|
| 43 |
+
bucket, dir_path = str(bucket_path).split("/", 1)
|
| 44 |
+
client = storage.Client()
|
| 45 |
+
bucket = client.bucket(bucket)
|
| 46 |
+
blobs = client.list_blobs(bucket, prefix=f"{dir_path}/")
|
| 47 |
+
for blob in blobs:
|
| 48 |
+
dest_name = str(Path(dest_path) / Path(blob.name).name)
|
| 49 |
+
blob.download_to_filename(dest_name)
|
tools/train/train.py
CHANGED
|
@@ -135,8 +135,21 @@ class ModelArguments:
|
|
| 135 |
else:
|
| 136 |
artifact = wandb.Api().artifact(state_artifact)
|
| 137 |
artifact_dir = artifact.download(tmp_dir)
|
| 138 |
-
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
|
| 142 |
@dataclass
|
|
@@ -788,9 +801,7 @@ def main():
|
|
| 788 |
|
| 789 |
else:
|
| 790 |
# load opt_state
|
| 791 |
-
|
| 792 |
-
opt_state = from_bytes(opt_state_shape, opt_state_file.read())
|
| 793 |
-
opt_state_file.close()
|
| 794 |
|
| 795 |
# restore other attributes
|
| 796 |
attr_state = {
|
|
@@ -1060,7 +1071,7 @@ def main():
|
|
| 1060 |
client = storage.Client()
|
| 1061 |
bucket = client.bucket(bucket)
|
| 1062 |
for filename in Path(output_dir).glob("*"):
|
| 1063 |
-
blob_name = str(Path(dir_path) / filename.name)
|
| 1064 |
blob = bucket.blob(blob_name)
|
| 1065 |
blob.upload_from_filename(str(filename))
|
| 1066 |
tmp_dir.cleanup()
|
|
@@ -1068,7 +1079,7 @@ def main():
|
|
| 1068 |
# save state
|
| 1069 |
opt_state = jax.device_get(state.opt_state)
|
| 1070 |
if use_bucket:
|
| 1071 |
-
blob_name = str(Path(dir_path) / "opt_state.msgpack")
|
| 1072 |
blob = bucket.blob(blob_name)
|
| 1073 |
blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
|
| 1074 |
else:
|
|
@@ -1088,10 +1099,10 @@ def main():
|
|
| 1088 |
metadata["num_params"] = num_params
|
| 1089 |
if eval_metrics is not None:
|
| 1090 |
metadata["eval"] = eval_metrics
|
| 1091 |
-
if use_bucket:
|
| 1092 |
-
metadata["bucket_path"] = bucket_path
|
| 1093 |
|
| 1094 |
# create model artifact
|
|
|
|
|
|
|
| 1095 |
artifact = wandb.Artifact(
|
| 1096 |
name=f"model-{wandb.run.id}",
|
| 1097 |
type="DalleBart_model",
|
|
@@ -1113,6 +1124,8 @@ def main():
|
|
| 1113 |
wandb.run.log_artifact(artifact)
|
| 1114 |
|
| 1115 |
# create state artifact
|
|
|
|
|
|
|
| 1116 |
artifact_state = wandb.Artifact(
|
| 1117 |
name=f"state-{wandb.run.id}",
|
| 1118 |
type="DalleBart_state",
|
|
|
|
| 135 |
else:
|
| 136 |
artifact = wandb.Api().artifact(state_artifact)
|
| 137 |
artifact_dir = artifact.download(tmp_dir)
|
| 138 |
+
if artifact.metadata.get("bucket_path"):
|
| 139 |
+
self.restore_state = artifact.metadata["bucket_path"]
|
| 140 |
+
else:
|
| 141 |
+
self.restore_state = Path(artifact_dir) / "opt_state.msgpack"
|
| 142 |
+
|
| 143 |
+
if self.restore_state.startswith("gs://"):
|
| 144 |
+
bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
|
| 145 |
+
bucket, blob_name = str(bucket_path).split("/", 1)
|
| 146 |
+
client = storage.Client()
|
| 147 |
+
bucket = client.bucket(bucket)
|
| 148 |
+
blob = bucket.blob(blob_name)
|
| 149 |
+
return blob.download_as_bytes()
|
| 150 |
+
|
| 151 |
+
with Path(self.restore_state).open("rb") as f:
|
| 152 |
+
return f.read()
|
| 153 |
|
| 154 |
|
| 155 |
@dataclass
|
|
|
|
| 801 |
|
| 802 |
else:
|
| 803 |
# load opt_state
|
| 804 |
+
opt_state = from_bytes(opt_state_shape, model_args.get_opt_state())
|
|
|
|
|
|
|
| 805 |
|
| 806 |
# restore other attributes
|
| 807 |
attr_state = {
|
|
|
|
| 1071 |
client = storage.Client()
|
| 1072 |
bucket = client.bucket(bucket)
|
| 1073 |
for filename in Path(output_dir).glob("*"):
|
| 1074 |
+
blob_name = str(Path(dir_path) / "model" / filename.name)
|
| 1075 |
blob = bucket.blob(blob_name)
|
| 1076 |
blob.upload_from_filename(str(filename))
|
| 1077 |
tmp_dir.cleanup()
|
|
|
|
| 1079 |
# save state
|
| 1080 |
opt_state = jax.device_get(state.opt_state)
|
| 1081 |
if use_bucket:
|
| 1082 |
+
blob_name = str(Path(dir_path) / "state" / "opt_state.msgpack")
|
| 1083 |
blob = bucket.blob(blob_name)
|
| 1084 |
blob.upload_from_file(io.BytesIO(to_bytes(opt_state)))
|
| 1085 |
else:
|
|
|
|
| 1099 |
metadata["num_params"] = num_params
|
| 1100 |
if eval_metrics is not None:
|
| 1101 |
metadata["eval"] = eval_metrics
|
|
|
|
|
|
|
| 1102 |
|
| 1103 |
# create model artifact
|
| 1104 |
+
if use_bucket:
|
| 1105 |
+
metadata["bucket_path"] = f"gs://{bucket_path}/model"
|
| 1106 |
artifact = wandb.Artifact(
|
| 1107 |
name=f"model-{wandb.run.id}",
|
| 1108 |
type="DalleBart_model",
|
|
|
|
| 1124 |
wandb.run.log_artifact(artifact)
|
| 1125 |
|
| 1126 |
# create state artifact
|
| 1127 |
+
if use_bucket:
|
| 1128 |
+
metadata["bucket_path"] = f"gs://{bucket_path}/state"
|
| 1129 |
artifact_state = wandb.Artifact(
|
| 1130 |
name=f"state-{wandb.run.id}",
|
| 1131 |
type="DalleBart_state",
|