Spaces:
Runtime error
Runtime error
| import random | |
| from datetime import datetime | |
| import os | |
| from collections import OrderedDict | |
| from typing import TYPE_CHECKING, Union | |
| import torch | |
| import yaml | |
| from jobs.process.BaseProcess import BaseProcess | |
| if TYPE_CHECKING: | |
| from jobs import TrainJob, BaseJob, ExtensionJob | |
| from torch.utils.tensorboard import SummaryWriter | |
| from tqdm import tqdm | |
class BaseTrainProcess(BaseProcess):
    """Common setup shared by training processes.

    Construction seeds the random number generators (when a seed is
    configured), resolves the save and log folders, optionally creates a
    TensorBoard ``SummaryWriter`` and writes the job's raw config to
    ``<save_root>/config.yaml``. Subclasses implement the actual training
    by overriding ``run``.

    NOTE(review): ``get_conf`` and ``name`` are used below but not defined
    in this class; presumably they are provided by ``BaseProcess`` -- confirm.
    """

    # Annotation-only declarations for type checkers; values are assigned in
    # __init__ (``process_id``/``config``/``job`` presumably by BaseProcess).
    process_id: int
    config: OrderedDict
    job: Union['TrainJob', 'BaseJob', 'ExtensionJob']
    # Stay None until a TensorBoard log dir is configured / a subclass
    # attaches a tqdm bar.
    writer: Union['SummaryWriter', None]
    progress_bar: Union['tqdm', None]

    def __init__(
            self,
            process_id: int,
            job,
            config: OrderedDict
    ):
        """Initialise seeds, folders, TensorBoard logging and the config dump.

        Args:
            process_id: Forwarded unchanged to ``BaseProcess.__init__``.
            job: Owning job. Its ``training_seed``, ``training_folder`` and
                ``log_dir`` attributes, when present, are used as fallbacks
                for the config keys of the same name; its ``raw_config`` is
                written to disk.
            config: Forwarded unchanged to ``BaseProcess.__init__``.

        Raises:
            TypeError: From ``os.path.join`` when no ``training_folder`` is
                resolved (the value would be ``None``).
        """
        super().__init__(process_id, job, config)

        # Start with no progress bar and no writer so self.print() and
        # setup_tensorboard() always find the attributes defined.
        self.progress_bar = None
        self.writer = None

        # Each setting is read via get_conf with the job's attribute of the
        # same name (or None when the job has no such attribute) as default.
        self.training_seed = self.get_conf(
            'training_seed', getattr(self.job, 'training_seed', None)
        )
        # Only seed when a seed was explicitly provided.
        if self.training_seed is not None:
            torch.manual_seed(self.training_seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(self.training_seed)
            random.seed(self.training_seed)

        self.training_folder = self.get_conf(
            'training_folder', getattr(self.job, 'training_folder', None)
        )
        self.save_root = os.path.join(self.training_folder, self.name)

        self.step = 0
        self.first_step = 0

        self.log_dir = self.get_conf(
            'log_dir', getattr(self.job, 'log_dir', None)
        )

        self.setup_tensorboard()
        self.save_training_config()

    def run(self):
        """Run the process.

        Implement the training in a child class and be sure to call
        ``super().run()`` first.
        """
        super().run()

    def print(self, *args):
        """Print ``args`` without corrupting an active progress bar.

        With a progress bar attached, the arguments are converted with
        ``str``, joined by single spaces and emitted through the bar's
        ``write``; otherwise they go to the builtin ``print``.
        """
        if self.progress_bar is None:
            print(*args)
            return
        # Deliberately no progress_bar.update() here: tqdm's update() adds
        # one iteration to the counter, so logging a message would advance
        # the displayed progress. tqdm.write() already redraws the bar.
        self.progress_bar.write(' '.join(map(str, args)))

    def setup_tensorboard(self):
        """Create ``self.writer`` when a log directory is configured.

        The run is written to ``<log_dir>/<name>_<YYYYmmdd-HHMMSS>``. When
        ``self.log_dir`` is falsy nothing happens and ``self.writer`` keeps
        its current value.
        """
        if not self.log_dir:
            return
        # Imported lazily so tensorboard is only needed when logging is on.
        from torch.utils.tensorboard import SummaryWriter

        time_str = datetime.now().strftime('%Y%m%d-%H%M%S')
        summary_dir = os.path.join(self.log_dir, f"{self.name}_{time_str}")
        self.writer = SummaryWriter(summary_dir)

    def save_training_config(self):
        """Write the job's raw config to ``<save_root>/config.yaml``.

        Creates ``self.save_root`` if needed and overwrites any existing
        ``config.yaml`` in it.
        """
        os.makedirs(self.save_root, exist_ok=True)
        config_path = os.path.join(self.save_root, 'config.yaml')
        with open(config_path, 'w', encoding='utf-8') as f:
            yaml.dump(self.job.raw_config, f)