feat(train): restore opt_state efficiently
tools/train/train.py  CHANGED  (+40 -35)
@@ -42,7 +42,7 @@ from flax.training.common_utils import onehot, stack_forest
 from jax.experimental import PartitionSpec, maps
 from jax.experimental.pjit import pjit
 from tqdm import tqdm
-from transformers import AutoTokenizer, HfArgumentParser
+from transformers import HfArgumentParser
 
 import wandb
 from dalle_mini.data import Dataset
@@ -375,23 +375,6 @@ class TrainState(train_state.TrainState):
     train_time: float = 0.0  # total time the model trained
     train_samples: int = 0  # number of samples seen
 
-    def restore_state(self, artifact_dir):
-        # restore optimizer state
-        with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
-            new_opt_state = from_bytes(self.opt_state, f.read())
-
-        # restore other parameters
-        with (Path(artifact_dir) / "training_state.json").open("r") as f:
-            training_state = json.load(f)
-
-        # replace state
-        return self.replace(
-            opt_state=new_opt_state,
-            step=training_state["step"],
-            train_time=training_state["train_time"],
-            train_samples=training_state["train_samples"],
-        )
-
 
 class MetricsLogger:
     def __init__(self, state):
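Note: the removed restore_state also shows why the old approach was costly. flax.serialization.from_bytes deserializes into a target pytree of matching structure, so self.opt_state had to be fully materialized on device before the checkpoint could even be read back. A minimal standalone sketch of that API, with toy values rather than the real training state:

    import jax.numpy as jnp
    from flax.serialization import to_bytes, from_bytes

    # from_bytes needs a target pytree of matching structure; it returns a
    # new pytree of that structure filled with the deserialized values.
    target = {"count": jnp.zeros((), jnp.int32), "mu": jnp.zeros((2, 3))}
    blob = to_bytes({"count": jnp.array(7), "mu": jnp.ones((2, 3))})
    restored = from_bytes(target, blob)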
@@ -528,7 +511,7 @@ def main():
 
     # Load tokenizer
     if model_args.tokenizer_name is not None:
-        tokenizer = AutoTokenizer.from_pretrained(
+        tokenizer = DalleBartTokenizer.from_pretrained(
             model_args.tokenizer_name, use_fast=True
         )
     else:
@@ -648,8 +631,7 @@ def main():
     )
 
     # get opt_state shape without actual init
-
-    opt_state_shape = jax.eval_shape(lambda x: optimizer.init(x), param_shape)
+    opt_state_shape = jax.eval_shape(lambda x: optimizer.init(x), model.params)
 
     # get PartitionSpec for model params
     param_spec = set_partitions(model.params)
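Note: jax.eval_shape traces the function abstractly, so optimizer.init never allocates real buffers; the result is the opt_state pytree with jax.ShapeDtypeStruct leaves, which is enough both to build a partition spec and to serve as a deserialization target. Passing model.params directly works because eval_shape only reads the shapes and dtypes of its arguments. A standalone sketch with a stand-in optimizer and toy shapes:

    import jax
    import jax.numpy as jnp
    import optax

    optimizer = optax.adamw(learning_rate=1e-4)
    params = {"w": jnp.zeros((1024, 1024)), "b": jnp.zeros((1024,))}

    # Same pytree structure as optimizer.init(params), but with
    # ShapeDtypeStruct leaves instead of allocated arrays.
    opt_state_shape = jax.eval_shape(lambda p: optimizer.init(p), params)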
@@ -692,28 +674,51 @@ def main():
         tx=optimizer,
     )
 
+    opt_state, attr_state = None, None
+    if training_args.resume_from_checkpoint is not None:
+        # restore opt_state
+        with (Path(artifact_dir) / "opt_state.msgpack").open("rb") as f:
+            opt_state = from_bytes(opt_state_shape, f.read())
+        # need to freeze dict for pjit
+        opt_state = jax.tree_map(
+            lambda x: freeze(x) if isinstance(x, dict) else x,
+            opt_state,
+            is_leaf=lambda x: isinstance(x, (dict, optax.EmptyState)),
+        )
+        # restore other attributes
+        with (Path(artifact_dir) / "training_state.json").open("r") as f:
+            attr_state = json.load(f)
+
     # create training state
-    def init_state(params):
-        state = TrainState.create(
-            apply_fn=model.__call__,
-            tx=optimizer,
-            params=freeze(params),
-            dropout_rng=dropout_rng,
-        )
+    def init_state(params, opt_state):
+        if training_args.resume_from_checkpoint is None:
+            state = TrainState.create(
+                apply_fn=model.__call__,
+                tx=optimizer,
+                params=freeze(params),
+                dropout_rng=dropout_rng,
+            )
+        else:
+            state = TrainState(
+                apply_fn=model.__call__,
+                tx=optimizer,
+                params=freeze(params),
+                opt_state=opt_state,
+                dropout_rng=dropout_rng,
+                **attr_state,
+            )
         return state
 
     with maps.mesh(mesh.devices, mesh.axis_names):
         state = pjit(
             init_state,
-            in_axis_resources=(param_spec,),
+            in_axis_resources=(param_spec, opt_state_spec),
             out_axis_resources=state_spec,
-            donate_argnums=(0,),
-        )(freeze(model.params))
+            donate_argnums=(0, 1),
+        )(freeze(model.params), opt_state)
 
-
-
-    # we currently ignore partial epoch training: see https://github.com/borisdayma/dalle-mini/issues/105
-    state = state.restore_state(artifact_dir)
+    # free memory from large parameters
+    del model._params, opt_state
 
     # label smoothed cross entropy
     def loss_fn(logits, labels):
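Note: the jax.tree_map pass with freeze is needed because from_bytes returns plain Python dicts, while the live state uses flax FrozenDicts; dict and FrozenDict are distinct pytree node types, so pjit would otherwise see a tree structure that does not match its partition specs. A small sketch of the mismatch, using a hypothetical two-level tree rather than the real opt_state:

    import jax
    from flax.core.frozen_dict import freeze

    restored = {"mu": {"w": 0.0}}              # plain dicts, as from_bytes returns
    spec_like = freeze({"mu": {"w": "spec"}})  # FrozenDicts, as the live state uses

    # Distinct node types give distinct tree structures until we freeze:
    assert (jax.tree_util.tree_structure(restored)
            != jax.tree_util.tree_structure(spec_like))
    assert (jax.tree_util.tree_structure(freeze(restored))
            == jax.tree_util.tree_structure(spec_like))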
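Note on donate_argnums=(0, 1): donation tells XLA it may reuse the input buffers of params and opt_state for the corresponding outputs, so the restored host-side copies do not stay resident next to the newly sharded state; the del afterwards drops the last Python references. A toy illustration of the mechanism with jax.jit (names and shapes are illustrative; donation is a hint that some backends ignore):

    import functools

    import jax
    import jax.numpy as jnp

    @functools.partial(jax.jit, donate_argnums=(0,))
    def sgd_step(params, grads):
        # With donation, XLA may write the new params into the old buffer.
        return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

    params = {"w": jnp.ones((1024, 1024))}
    grads = {"w": jnp.full((1024, 1024), 0.5)}
    params = sgd_step(params, grads)  # the donated input must not be reused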