# train.py — entry point for training the spelling-correction model.
import argparse
import os

from dataset.util import load_dataset
from dataset.vocab import Vocab

if __name__ == '__main__':
    description = '''
    train.py:
        Usage: python train.py --model tfmwtr --start_epoch n --data_path ./data --dataset binhvq
        Params:
            --start_epoch n
                n = 0: training from beginning
                n > 0: continue training from the nth epoch
            --model
                tfmwtr - Transformer with Tokenization Repair
            --data_path: default to ./data
            --dataset: default to 'binhvq'
    '''
    # RawDescriptionHelpFormatter preserves the line breaks in the usage text
    # above; the default formatter would collapse it into a single paragraph.
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument('--model', type=str, default='tfmwtr')
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--data_path', type=str, default='./data')
    parser.add_argument('--dataset', type=str, default='binhvq')
    args = parser.parse_args()

    # Resolve the dataset directory, vocabulary pickle, and checkpoint
    # directory under the data root.
    dataset_path = os.path.join(args.data_path, args.dataset)
    vocab_path = os.path.join(dataset_path, f'{args.dataset}.vocab.pkl')
    vocab = Vocab()
    vocab.load_vocab_dict(vocab_path)
    checkpoint_dir = os.path.join(args.data_path, f'checkpoints/{args.model}')

    # Only the validation split is loaded here; Trainer receives
    # dataset_path + args.dataset and presumably derives the training-split
    # files itself — NOTE(review): confirm against Trainer's implementation.
    valid_incorrect_file = f'{args.dataset}.valid.noise'
    valid_correct_file = f'{args.dataset}.valid'
    valid_length_file = f'{args.dataset}.length.valid'
    valid_data = load_dataset(base_path=dataset_path,
                              corr_file=valid_correct_file,
                              incorr_file=valid_incorrect_file,
                              length_file=valid_length_file)

    # These imports were placed late in the original script (after argument
    # parsing and data loading); kept here to preserve that import order —
    # NOTE(review): likely to defer loading heavy model deps; confirm.
    from dataset.autocorrect_dataset import SpellCorrectDataset
    from models.trainer import Trainer
    from models.model import ModelWrapper

    valid_dataset = SpellCorrectDataset(dataset=valid_data)
    model_wrapper = ModelWrapper(args.model, vocab)
    trainer = Trainer(model_wrapper, dataset_path, args.dataset, valid_dataset)
    trainer.load_checkpoint(checkpoint_dir, args.dataset, args.start_epoch)
    trainer.train()