feat(train): more custom x-axis
tools/train/train.py (CHANGED, +28 -35)
@@ -395,17 +395,16 @@ class TrainState(train_state.TrainState):
 
 
 class MetricsLogger:
-    def __init__(self, state):
-        self.step = int(state.step)
+    def __init__(self, step):
+        self.step = step
         self.time = time.perf_counter()
+        self.state_dict = {}
 
-    def get_all_train_metrics(self, train_metrics, state):
-        """Make a dict of training metrics to be logged"""
-        metrics = train_metrics
-        # get state parameters
-        state_dict = {
-            k.split("_")[-1]: getattr(state, k)
-            for k in ["epoch", "train_time", "train_samples"]
+    def update_state_metrics(self, state):
+        """Update internal state metrics (logged at each call to be used as x-axis)"""
+        self.state_dict = {
+            f'train/{k.split("_")[-1]}': getattr(state, k)
+            for k in ["step", "epoch", "train_time", "train_samples"]
         }
         # timing metrics
         new_step = int(state.step)
@@ -414,19 +413,15 @@ class MetricsLogger:
             time_per_step = (new_time - self.time) / (new_step - self.step)
             self.step = new_step
             self.time = new_time
-            state_dict["time_per_step"] = time_per_step
-        return {**metrics, **state_dict}
+            self.state_dict["train/time_per_step"] = time_per_step
 
-    @staticmethod
-    def log(metrics, step=None, prefix=None):
+    def log(self, metrics, prefix=None):
         if jax.process_index() == 0:
             log_metrics = {
                 f"{prefix}/{k}" if prefix is not None else k: v
                 for k, v in metrics.items()
             }
-            if step is not None:
-                log_metrics["train/step"] = step
-            wandb.log(log_metrics)
+            wandb.log({**log_metrics, **self.state_dict})
 
 
 def main():
@@ -878,9 +873,9 @@ def main():
         return state, metrics
 
     # Define eval fn
-    def eval_step(params, batch):
+    def eval_step(state, batch):
         batch, labels = batch.pop("labels")
-        logits = model(**batch, params=params, train=False)[0]
+        logits = model(**batch, params=state.params, train=False)[0]
         loss = loss_fn(logits, labels)
         return loss
 
@@ -893,7 +888,7 @@ def main():
     )
     p_eval_step = pjit(
         eval_step,
-        in_axis_resources=(param_spec, batch_spec),
+        in_axis_resources=(state_spec, batch_spec),
         out_axis_resources=None,
     )
 
@@ -913,10 +908,14 @@ def main():
         range(state.epoch, num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0
     )
 
-    metrics_logger = MetricsLogger(state)
+    # init variables
+    last_time = time.perf_counter()
+    train_metrics = None
+    step = int(state.step)
+    metrics_logger = MetricsLogger(step)
+
     if jax.process_index() == 0:
         # set default x-axis as 'train/step'
-        metrics_logger.log({}, step=state.step)
         wandb.define_metric("*", step_metric="train/step")
 
     # add interesting config parameters
@@ -950,7 +949,7 @@ def main():
             # freeze batch to pass safely to JAX transforms
             batch = freeze(batch)
            # accumulate losses async
-            eval_loss.append(p_eval_step(state.params, batch))
+            eval_loss.append(p_eval_step(state, batch))
 
            # get the mean of the loss
            eval_loss = jnp.stack(eval_loss)
@@ -958,7 +957,7 @@ def main():
        eval_metrics = {"loss": eval_loss}
 
        # log metrics
-        metrics_logger.log(eval_metrics, step=state.step, prefix="eval")
+        metrics_logger.log(eval_metrics, prefix="eval")
 
        # Print metrics and update progress bar
        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
@@ -1036,16 +1035,12 @@ def main():
            )
            wandb.run.log_artifact(artifact_state)
 
-    # init variables
-    last_time = time.perf_counter()
-    train_metrics = None
-    step = int(state.step)
-
    with maps.mesh(mesh.devices, mesh.axis_names):
        for epoch in epochs:
            state.replace(epoch=epoch)
            # ======================== Training ================================
-            metrics_logger.log({"train/epoch": epoch}, step=step)
+            metrics_logger.update_state_metrics(state)
+            metrics_logger.log({})
 
            # Generate an epoch by shuffling sampling indices from the train dataset
            train_loader = dataset.dataloader(
@@ -1086,10 +1081,8 @@ def main():
                step += 1
 
                if step % training_args.logging_steps == 0 and jax.process_index() == 0:
-                    all_metrics = metrics_logger.get_all_train_metrics(
-                        train_metrics, state
-                    )
-                    metrics_logger.log(all_metrics, step=step, prefix="train")
+                    metrics_logger.update_state_metrics(state)
+                    metrics_logger.log(train_metrics, prefix="train")
 
                eval_metrics = None
                if step % training_args.eval_steps == 0:
@@ -1100,8 +1093,8 @@ def main():
 
            # log final train metrics
            if train_metrics is not None:
-                all_metrics = metrics_logger.get_all_train_metrics(train_metrics, state)
-                metrics_logger.log(all_metrics, step=step, prefix="train")
+                metrics_logger.update_state_metrics(state)
+                metrics_logger.log(train_metrics, prefix="train")
 
            epochs.write(
                f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"