Spaces:
Runtime error
Runtime error
| # Copyright (c) 2025 NVIDIA CORPORATION. | |
| # Licensed under the MIT license. | |
| # Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license. | |
| # LICENSE is in incl_licenses directory. | |
| """ AutoResume callback. | |
| A transformer trainer callback for interfacing with ADLR's AutoResume SDK. | |
| Copyright 2024 NVIDIA CORPORATION. | |
| """ | |
| import os | |
| import sys | |
| import torch | |
| import transformers | |
| from transformers.utils import logging | |
| logger = logging.get_logger("transformers") | |
| def rank_print(*s): | |
| if not torch.distributed.is_initialized(): | |
| rank = 0 | |
| else: | |
| rank = torch.distributed.get_rank() | |
| print(rank, *s) | |
| sys.path.append(os.environ.get("SUBMIT_SCRIPTS", ".")) | |
| try: | |
| logger.info("Importing AutoResume lib...") | |
| from userlib.auto_resume import AutoResume | |
| AutoResume.init() | |
| logger.info("Found AutoResume SDK!") | |
| except: | |
| logger.warn("Did not find AutoResume SDK!") | |
| AutoResume = None | |
| class AutoResumeCallback(transformers.TrainerCallback): | |
| """ | |
| A [`TrainerCallback`] that handles autoresume. | |
| Args: | |
| interval: interval (in number of iterations) between checks as to | |
| whether to suspend. | |
| """ | |
| def __init__(self, interval: int = 50): | |
| self.interval = interval | |
| def on_step_end(self, args, state, control, **kwargs): | |
| if state.global_step % self.interval == 0: | |
| rank_print("AutoResumeHook: Checking whether to suspend...") | |
| # Check whether to suspend the job. | |
| should_preempt = AutoResume is not None and AutoResume.termination_requested() | |
| if should_preempt: | |
| if state.is_local_process_zero: | |
| logger.warn(f"AutoResumeHook: Request resume...") | |
| if AutoResume is not None: | |
| AutoResume.request_resume() | |
| control.should_training_stop = True | |
| control.should_save = True | |