Spaces:
Runtime error
Runtime error
feat: load data first
Browse files- tools/train/train.py +3 -3
tools/train/train.py
CHANGED
|
@@ -375,9 +375,6 @@ def main():
|
|
| 375 |
datasets.utils.logging.set_verbosity_error()
|
| 376 |
transformers.utils.logging.set_verbosity_error()
|
| 377 |
|
| 378 |
-
logger.info(f"Local TPUs: {jax.local_device_count()}")
|
| 379 |
-
assert jax.local_device_count() == 8, "TPUs in use, please check running processes"
|
| 380 |
-
|
| 381 |
# Set the verbosity to info of the Transformers logger (on main process only):
|
| 382 |
logger.info(f"Training/evaluation parameters {training_args}")
|
| 383 |
|
|
@@ -388,6 +385,9 @@ def main():
|
|
| 388 |
do_eval=training_args.do_eval,
|
| 389 |
)
|
| 390 |
|
|
|
|
|
|
|
|
|
|
| 391 |
# Set up wandb run
|
| 392 |
if jax.process_index() == 0:
|
| 393 |
wandb.init(
|
|
|
|
| 375 |
datasets.utils.logging.set_verbosity_error()
|
| 376 |
transformers.utils.logging.set_verbosity_error()
|
| 377 |
|
|
|
|
|
|
|
|
|
|
| 378 |
# Set the verbosity to info of the Transformers logger (on main process only):
|
| 379 |
logger.info(f"Training/evaluation parameters {training_args}")
|
| 380 |
|
|
|
|
| 385 |
do_eval=training_args.do_eval,
|
| 386 |
)
|
| 387 |
|
| 388 |
+
logger.info(f"Local TPUs: {jax.local_device_count()}")
|
| 389 |
+
assert jax.local_device_count() == 8, "TPUs in use, please check running processes"
|
| 390 |
+
|
| 391 |
# Set up wandb run
|
| 392 |
if jax.process_index() == 0:
|
| 393 |
wandb.init(
|