# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import gpustat
import jsonlines
import pytorch_lightning as pl
import torch
from transformers import AutoTokenizer, BertTokenizer

from train_func import CustomDataset, CustomDataModule, CustomModel
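# Entry point for fine-tuning and test-set prediction with a sentence-similarity
# classifier (the output file name suggests the QBQTC benchmark). The dataset,
# data module, and LightningModule are assumed to live in train_func.py.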
if __name__ == '__main__':
    my_parser = argparse.ArgumentParser()
    my_parser.add_argument(
        "--model_path", default="./weights/Erlangshen-MegatronBert-1.3B-Similarity", type=str, required=False)
    my_parser.add_argument(
        "--model_name", default="IDEA-CCNL/Erlangshen-MegatronBert-1.3B-Similarity", type=str, required=False)
    my_parser.add_argument("--max_seq_length", default=64, type=int, required=False)
    my_parser.add_argument("--batch_size", default=32, type=int, required=False)
    my_parser.add_argument("--val_batch_size", default=64, type=int, required=False)
    my_parser.add_argument("--num_epochs", default=10, type=int, required=False)
    my_parser.add_argument("--learning_rate", default=4e-5, type=float, required=False)
| my_parser.add_argument("--warmup_proportion", default=0.2, type=int, required=False) | |
| my_parser.add_argument("--warmup_step", default=2, type=int, required=False) | |
| my_parser.add_argument("--num_labels", default=3, type=int, required=False) | |
| my_parser.add_argument("--cate_performance", default=False, type=bool, required=False) | |
| my_parser.add_argument("--use_original_pooler", default=True, type=bool, required=False) | |
| my_parser.add_argument("--model_output_path", default='./pl_model', type=str, required=False) | |
| my_parser.add_argument("--mode", type=str, choices=['Train', 'Test'], required=True) | |
| my_parser.add_argument("--predict_model_path", default='./pl_model/', type=str, required=False) | |
| my_parser.add_argument("--test_output_path", default='./submissions', type=str, required=False) | |
| my_parser.add_argument("--optimizer", default='AdamW', type=str, required=False) # ['Adam', 'AdamW'] | |
| # ['StepLR', 'CosineWarmup', 'CosineAnnealingLR'] | |
| my_parser.add_argument("--scheduler", default='CosineWarmup', type=str, required=False) | |
| my_parser.add_argument("--loss_function", default='LSCE_correction', type=str, | |
| required=False) # ['CE', 'Focal', 'LSCE_correction'] | |
| args = my_parser.parse_args() | |
| print(args) | |
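    # Print current GPU utilization so batch-size/memory headroom can be
    # sanity-checked before training starts.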
    gpustat.print_gpustat()
    if 'Erlangshen' in args.model_name:
        tokenizer = BertTokenizer.from_pretrained(args.model_name, cache_dir=args.model_path)
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.model_path)
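    # Seed Python, NumPy, and PyTorch RNGs so runs are reproducible.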
    seed = 1919
    pl.seed_everything(seed)
    dm = CustomDataModule(
        args=args,
        tokenizer=tokenizer,
    )
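    # Keep only the best checkpoint by the metric chosen via metric_index
    # (0=val_loss, 1=val_acc, 2=val_f1); these metric names are assumed to match
    # what CustomModel logs in its validation step.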
    metric_index = 2
    checkpoint = pl.callbacks.ModelCheckpoint(
        save_top_k=1,
        verbose=True,
        monitor=['val_loss', 'val_acc', 'val_f1'][metric_index],
        mode=['min', 'max', 'max'][metric_index]
    )
    lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="step")
    callbacks = [checkpoint, lr_monitor]
    logger = pl.loggers.TensorBoardLogger(save_dir=os.getcwd(),
                                          name='lightning_logs/' + args.model_name.split('/')[-1])
    trainer = pl.Trainer(
        progress_bar_refresh_rate=50,
        logger=logger,
        gpus=-1 if torch.cuda.is_available() else None,
        # amp_level applies only to the apex backend; with the native backend,
        # precision=16 alone enables mixed-precision training
        amp_backend='native',
        precision=16,
        callbacks=callbacks,
        gradient_clip_val=1.0,
        max_epochs=args.num_epochs,
        # accelerator='ddp',
        # plugins='ddp_sharded',
    )
    if args.mode == 'Train':
        print('Only Train')
        model = CustomModel(
            args=args,
        )
        trainer.fit(model, dm)
    # Predict test, save results to json
    if args.mode == 'Test':
        print('Only Test')
        test_loader = torch.utils.data.DataLoader(
            CustomDataset('test.json', tokenizer, args.max_seq_length, 'test'),
            batch_size=args.val_batch_size,
            num_workers=4,
            shuffle=False,
            pin_memory=True,
            drop_last=False
        )
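        # load_from_checkpoint is a classmethod that rebuilds the model from the
        # saved weights; args is passed again because CustomModel presumably
        # takes it as a constructor argument.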
        model = CustomModel.load_from_checkpoint(args.predict_model_path, args=args)
        predict_results = trainer.predict(model, test_loader, return_predictions=True)
        path = os.path.join(
            args.test_output_path,
            args.model_name.split('/')[-1].replace('-', '_'))
        file_path = os.path.join(path, 'qbqtc_predict.json')
        if not os.path.exists(path):
            os.makedirs(path)
        if os.path.exists(file_path):
            print('JSON file already exists; it will be overwritten with these results')
        with jsonlines.open(file_path, 'w') as jsonf:
            for predict_res in predict_results:
                for i, p in zip(predict_res['id'], predict_res['logits']):
                    jsonf.write({"id": i, "label": str(p)})
        print('Json saved:', file_path)
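
# Usage sketch (hypothetical script name; substitute the real file name and an
# actual checkpoint path):
#   python finetune_qbqtc.py --mode Train
#   python finetune_qbqtc.py --mode Test --predict_model_path ./pl_model/<best>.ckpt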