import torch
from torch.utils.data._utils.collate import default_collate
from dataclasses import dataclass
from typing import Dict, List
from .base import (
    _CONFIG_MODEL_TYPE,
    _CONFIG_TOKENIZER_TYPE)
from fengshen.models.roformer import RoFormerForSequenceClassification
from fengshen.models.longformer import LongformerForSequenceClassification
from fengshen.models.zen1 import ZenForSequenceClassification
from transformers import (
    BertConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from transformers.models.auto.tokenization_auto import get_tokenizer_config
from transformers.pipelines.base import PipelineException, GenericTensor
from transformers import TextClassificationPipeline as HuggingfacePipe
import pytorch_lightning as pl
from fengshen.data.universal_datamodule import UniversalDataModule
from fengshen.utils.universal_checkpoint import UniversalCheckpoint
from fengshen.models.model_utils import add_module_args
import torchmetrics

_model_dict = {
    'fengshen-roformer': RoFormerForSequenceClassification,
    # 'fengshen-megatron_t5': T5EncoderModel,  TODO: implement T5EncoderForSequenceClassification
    'fengshen-longformer': LongformerForSequenceClassification,
    'fengshen-zen1': ZenForSequenceClassification,
    'huggingface-auto': AutoModelForSequenceClassification,
}

_tokenizer_dict = {}

_ATTR_PREPARE_INPUT = '_prepare_inputs_for_sequence_classification'


class _taskModel(pl.LightningModule):
    @staticmethod
    def add_model_specific_args(parent_args):
        _ = parent_args.add_argument_group('text classification task model')
        return parent_args
    def __init__(self, args, model):
        super().__init__()
        self.model = model
        self.acc_metrics = torchmetrics.Accuracy()
        self.save_hyperparameters(args)

    def setup(self, stage) -> None:
        if stage == 'fit':
            train_loader = self.trainer._data_connector._train_dataloader_source.dataloader()
            # Calculate total steps
            if self.trainer.max_epochs > 0:
                world_size = self.trainer.world_size
                tb_size = self.hparams.train_batchsize * max(1, world_size)
                ab_size = self.trainer.accumulate_grad_batches
                self.total_steps = (len(train_loader.dataset) *
                                    self.trainer.max_epochs // tb_size) // ab_size
            else:
                self.total_steps = self.trainer.max_steps // self.trainer.accumulate_grad_batches
            print('Total steps: {}'.format(self.total_steps))

    def training_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        loss, _ = outputs[0], outputs[1]
        self.log('train_loss', loss)
        return loss

    def comput_metrix(self, logits, labels):
        y_pred = torch.argmax(logits, dim=-1)
        y_pred = y_pred.view(size=(-1,))
        y_true = labels.view(size=(-1,)).long()
        acc = self.acc_metrics(y_pred.long(), y_true.long())
        return acc

    def validation_step(self, batch, batch_idx):
        outputs = self.model(**batch)
        loss, logits = outputs[0], outputs[1]
        acc = self.comput_metrix(logits, batch['labels'])
        self.log('val_loss', loss)
        self.log('val_acc', acc)

    def predict_step(self, batch, batch_idx):
        output = self.model(**batch)
        return output.logits

    def configure_optimizers(self):
        from fengshen.models.model_utils import configure_optimizers
        return configure_optimizers(self)


class _Collator:
    tokenizer = None
    texta_name = 'sentence'
    textb_name = 'sentence2'
    label_name = 'label'
    max_length = 512
    model_type = 'huggingface-auto'

    def __call__(self, samples):
        sample_list = []
        for item in samples:
            if self.textb_name in item and item[self.textb_name] != '':
                if self.model_type != 'fengshen-roformer':
                    encode_dict = self.tokenizer.encode_plus(
                        [item[self.texta_name], item[self.textb_name]],
                        max_length=self.max_length,
                        padding='max_length',
                        truncation='longest_first')
                else:
                    encode_dict = self.tokenizer.encode_plus(
                        [item[self.texta_name] + '[SEP]' + item[self.textb_name]],
                        max_length=self.max_length,
                        padding='max_length',
                        truncation='longest_first')
            else:
                encode_dict = self.tokenizer.encode_plus(
                    item[self.texta_name],
                    max_length=self.max_length,
                    padding='max_length',
                    truncation='longest_first')
            sample = {}
            for k, v in encode_dict.items():
                sample[k] = torch.tensor(v)
            if self.label_name in item:
                sample['labels'] = torch.tensor(item[self.label_name]).long()
            sample_list.append(sample)
        return default_collate(sample_list)
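

# Illustrative sketch (not part of the original pipeline): how _Collator is meant to be
# used on raw samples. The tokenizer checkpoint and the sample texts below are placeholder
# assumptions; any BERT-style tokenizer behaves the same way.
def _collator_usage_sketch():
    from transformers import AutoTokenizer
    collator = _Collator()
    collator.tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    collator.max_length = 32
    # Each raw sample is a dict keyed by texta_name / textb_name / label_name.
    # The collator tokenizes each sample, pads it to max_length and stacks the
    # results into one batch via default_collate.
    batch = collator([
        {'sentence': 'the service was great', 'label': 1},
        {'sentence': 'the food was cold', 'label': 0},
    ])
    # batch['input_ids'] has shape (2, 32); batch['labels'] has shape (2,)
    return batch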


class TextClassificationPipeline(HuggingfacePipe):
    @staticmethod
    def add_pipeline_specific_args(parent_args):
        parser = parent_args.add_argument_group('SequenceClassificationPipeline')
        parser.add_argument('--texta_name', default='sentence', type=str)
        parser.add_argument('--textb_name', default='sentence2', type=str)
        parser.add_argument('--label_name', default='label', type=str)
        parser.add_argument('--max_length', default=512, type=int)
        parser.add_argument('--device', default=-1, type=int)
        parser = _taskModel.add_model_specific_args(parent_args)
        parser = UniversalDataModule.add_data_specific_args(parent_args)
        parser = UniversalCheckpoint.add_argparse_args(parent_args)
        parser = pl.Trainer.add_argparse_args(parent_args)
        parser = add_module_args(parent_args)
        return parent_args

    def __init__(self,
                 model: str = None,
                 args=None,
                 **kwargs):
        self.args = args
        self.model_name = model
        self.model_type = 'huggingface-auto'
        # BertConfig is only used for compatibility here: we just need to read
        # fengshen_model_type from it, so any Config class would do.
        config = BertConfig.from_pretrained(model)
        if hasattr(config, _CONFIG_MODEL_TYPE):
            self.model_type = config.fengshen_model_type
        if self.model_type not in _model_dict:
            raise PipelineException('text-classification', self.model_name,
                                    'model type not in _model_dict')
        # Load the model and reuse its config
        self.model = _model_dict[self.model_type].from_pretrained(model)
        self.config = self.model.config
        # Load the tokenizer: get_tokenizer_config returns a dict, so look the
        # fengshen tokenizer type up by key and fall back to AutoTokenizer.
        tokenizer_config = get_tokenizer_config(model, **kwargs)
        self.tokenizer = None
        if _CONFIG_TOKENIZER_TYPE in tokenizer_config:
            if tokenizer_config[_CONFIG_TOKENIZER_TYPE] in _tokenizer_dict:
                self.tokenizer = _tokenizer_dict[tokenizer_config[_CONFIG_TOKENIZER_TYPE]].from_pretrained(
                    model)
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(model)
        # Set up the data-processing module (collator)
        c = _Collator()
        c.tokenizer = self.tokenizer
        c.model_type = self.model_type
        if args is not None:
            c.texta_name = self.args.texta_name
            c.textb_name = self.args.textb_name
            c.label_name = self.args.label_name
            c.max_length = self.args.max_length
        self.collator = c
        device = -1 if args is None else args.device
        super().__init__(model=self.model,
                         tokenizer=self.tokenizer,
                         framework='pt',
                         device=device,
                         **kwargs)

    def train(self,
              datasets: Dict):
        """
        Args:
            datasets is a dict like
                {
                    test: Dataset()
                    validation: Dataset()
                    train: Dataset()
                }
        """
        checkpoint_callback = UniversalCheckpoint(self.args)
        trainer = pl.Trainer.from_argparse_args(self.args,
                                                callbacks=[checkpoint_callback])
        data_model = UniversalDataModule(
            datasets=datasets,
            tokenizer=self.tokenizer,
            collate_fn=self.collator,
            args=self.args)
        model = _taskModel(self.args, self.model)
        trainer.fit(model, data_model)
        return

    def preprocess(self, inputs, **tokenizer_kwargs) -> Dict[str, GenericTensor]:
        # If the model exposes its own preprocessing hook, delegate to it.
        if hasattr(self.model, _ATTR_PREPARE_INPUT):
            return getattr(self.model, _ATTR_PREPARE_INPUT)(inputs, self.tokenizer, **tokenizer_kwargs)
        samples = []
        if isinstance(inputs, str):
            samples.append({self.collator.texta_name: inputs})
        else:
            # __call__ has already validated the input type, so a plain else is enough here.
            for i in inputs:
                samples.append({self.collator.texta_name: i})
        return self.collator(samples)


Pipeline = TextClassificationPipeline
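

# Illustrative usage sketch (not part of the original module). The checkpoint name is a
# placeholder assumption for any sequence-classification checkpoint loadable by
# AutoModelForSequenceClassification; with args=None the collator defaults above are used
# and inference runs on CPU.
def _pipeline_usage_sketch():
    pipe = TextClassificationPipeline(model='bert-base-uncased')
    # A single string goes through preprocess() above; the HuggingFace base pipeline
    # then post-processes the logits into a list like [{'label': ..., 'score': ...}].
    return pipe('a quick smoke-test sentence')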