Spaces:
Runtime error
Runtime error
| from transformers import Trainer | |
| import torch.nn.functional as F | |
| from typing import Optional | |
| import os | |
| import torch | |
| from transformers.utils import WEIGHTS_NAME | |
| import json | |
class VideoBaseTrainer(Trainer):
    """Trainer that writes a plain PyTorch checkpoint with an augmented config.

    Overrides ``Trainer._save`` so that each checkpoint directory contains:

    * ``config.json`` -- ``self.model.config.to_dict()`` plus an extra
      ``"model"`` key holding the model's class name;
    * ``WEIGHTS_NAME`` -- the state dict serialized with ``torch.save``;
    * ``training_args.bin`` -- ``self.args`` serialized with ``torch.save``.
    """

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save config, weights and training arguments to ``output_dir``.

        Args:
            output_dir: Destination directory. Defaults to
                ``self.args.output_dir`` and is created if it does not exist.
            state_dict: Weights to save. Defaults to
                ``self.model.state_dict()``.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        if state_dict is None:
            state_dict = self.model.state_dict()

        # Record the concrete model class name in the saved config
        # (presumably so a loader can pick the matching architecture --
        # confirm against the loading code).
        model_config = self.model.config.to_dict()
        model_config['model'] = self.model.__class__.__name__
        # Dump the augmented dict rather than a fresh config.to_dict();
        # otherwise the 'model' key added above would be lost.
        config_path = os.path.join(output_dir, "config.json")
        with open(config_path, "w", encoding="utf-8") as file:
            json.dump(model_config, file)

        torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))