Spaces:
Runtime error
Runtime error
feat: add bucket reference to artifact
Browse files- tools/train/train.py +9 -4
tools/train/train.py
CHANGED
|
@@ -135,11 +135,12 @@ class ModelArguments:
|
|
| 135 |
artifact = wandb.run.use_artifact(state_artifact)
|
| 136 |
else:
|
| 137 |
artifact = wandb.Api().artifact(state_artifact)
|
| 138 |
-
artifact_dir = artifact.download(tmp_dir)
|
| 139 |
if artifact.metadata.get("bucket_path"):
|
|
|
|
| 140 |
self.restore_state = artifact.metadata["bucket_path"]
|
| 141 |
else:
|
| 142 |
-
|
|
|
|
| 143 |
|
| 144 |
if self.restore_state.startswith("gs://"):
|
| 145 |
bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
|
|
@@ -1130,7 +1131,9 @@ def main():
|
|
| 1130 |
type="DalleBart_model",
|
| 1131 |
metadata=metadata,
|
| 1132 |
)
|
| 1133 |
-
if
|
|
|
|
|
|
|
| 1134 |
for filename in [
|
| 1135 |
"config.json",
|
| 1136 |
"flax_model.msgpack",
|
|
@@ -1153,7 +1156,9 @@ def main():
|
|
| 1153 |
type="DalleBart_state",
|
| 1154 |
metadata=metadata,
|
| 1155 |
)
|
| 1156 |
-
if
|
|
|
|
|
|
|
| 1157 |
artifact_state.add_file(
|
| 1158 |
f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
|
| 1159 |
)
|
|
|
|
| 135 |
artifact = wandb.run.use_artifact(state_artifact)
|
| 136 |
else:
|
| 137 |
artifact = wandb.Api().artifact(state_artifact)
|
|
|
|
| 138 |
if artifact.metadata.get("bucket_path"):
|
| 139 |
+
# we will read directly file contents
|
| 140 |
self.restore_state = artifact.metadata["bucket_path"]
|
| 141 |
else:
|
| 142 |
+
artifact_dir = artifact.download(tmp_dir)
|
| 143 |
+
self.restore_state = str(Path(artifact_dir) / "opt_state.msgpack")
|
| 144 |
|
| 145 |
if self.restore_state.startswith("gs://"):
|
| 146 |
bucket_path = Path(self.restore_state[5:]) / "opt_state.msgpack"
|
|
|
|
| 1131 |
type="DalleBart_model",
|
| 1132 |
metadata=metadata,
|
| 1133 |
)
|
| 1134 |
+
if use_bucket:
|
| 1135 |
+
artifact.add_reference(metadata["bucket_path"])
|
| 1136 |
+
else:
|
| 1137 |
for filename in [
|
| 1138 |
"config.json",
|
| 1139 |
"flax_model.msgpack",
|
|
|
|
| 1156 |
type="DalleBart_state",
|
| 1157 |
metadata=metadata,
|
| 1158 |
)
|
| 1159 |
+
if use_bucket:
|
| 1160 |
+
artifact_state.add_reference(metadata["bucket_path"])
|
| 1161 |
+
else:
|
| 1162 |
artifact_state.add_file(
|
| 1163 |
f"{Path(training_args.output_dir) / 'opt_state.msgpack'}"
|
| 1164 |
)
|