[#1] the name of the artifact should just be seq2seq
Browse files- main_train.py +1 -1
main_train.py
CHANGED
|
@@ -46,7 +46,7 @@ def main():
|
|
| 46 |
if not config['fast_dev_run'] and trainer.current_epoch == config['max_epochs'] - 1:
|
| 47 |
ckpt_path = ROOT_DIR / "model.ckpt"
|
| 48 |
trainer.save_checkpoint(str(ckpt_path))
|
| 49 |
-
artifact = wandb.Artifact(name=
|
| 50 |
artifact.add_file(str(ckpt_path))
|
| 51 |
run.log_artifact(artifact, aliases=["latest", config['ver']])
|
| 52 |
os.remove(str(ckpt_path)) # make sure you remove it after you are done with uploading it
|
|
|
|
| 46 |
if not config['fast_dev_run'] and trainer.current_epoch == config['max_epochs'] - 1:
|
| 47 |
ckpt_path = ROOT_DIR / "model.ckpt"
|
| 48 |
trainer.save_checkpoint(str(ckpt_path))
|
| 49 |
+
artifact = wandb.Artifact(name="seq2seq", type="model", metadata=config)
|
| 50 |
artifact.add_file(str(ckpt_path))
|
| 51 |
run.log_artifact(artifact, aliases=["latest", config['ver']])
|
| 52 |
os.remove(str(ckpt_path)) # make sure you remove it after you are done with uploading it
|