| """Basic model. Predicts tags for every token""" | |
| from typing import Dict, Optional, List, Any | |
| import numpy | |
| import torch | |
| import torch.nn.functional as F | |
| from allennlp.data import Vocabulary | |
| from allennlp.models.model import Model | |
| from allennlp.modules import TimeDistributed, TextFieldEmbedder | |
| from allennlp.nn import InitializerApplicator, RegularizerApplicator | |
| from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits | |
| from allennlp.training.metrics import CategoricalAccuracy | |
| from overrides import overrides | |
| from torch.nn.modules.linear import Linear | |
class Seq2Labels(Model):
    """
    This ``Seq2Labels`` model embeds a sequence of text with a ``TextFieldEmbedder``, then
    predicts a pair of tags for each token in the sequence: an edit label and an
    error-detection tag.

    Parameters
    ----------
    vocab : ``Vocabulary``, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    text_field_embedder : ``TextFieldEmbedder``, required
        Used to embed the ``tokens`` ``TextField`` we get as input to the model.
    predictor_dropout : ``float``, optional (default=``0.0``)
        Dropout probability applied to the encoded text before the label projection layer.
    labels_namespace : ``str``, optional (default=``labels``)
        Vocabulary namespace of the per-token edit labels.
    detect_namespace : ``str``, optional (default=``d_tags``)
        Vocabulary namespace of the per-token detection (correct/incorrect) tags.
    verbose_metrics : ``bool``, optional (default = False)
        If true, metrics will be returned per label class in addition
        to the overall statistics.
    label_smoothing : ``float``, optional (default=``0.0``)
        Label smoothing applied to the cross-entropy loss over the edit labels.
    confidence : ``float``, optional (default=``0.0``)
        Bias added at inference time to the probability of the first edit-label class.
    del_confidence : ``float``, optional (default=``0.0``)
        Bias added at inference time to the probability of the second edit-label class.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self, vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 predictor_dropout=0.0,
                 labels_namespace: str = "labels",
                 detect_namespace: str = "d_tags",
                 verbose_metrics: bool = False,
                 label_smoothing: float = 0.0,
                 confidence: float = 0.0,
                 del_confidence: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(Seq2Labels, self).__init__(vocab, regularizer)
        self.label_namespaces = [labels_namespace,
                                 detect_namespace]
        self.text_field_embedder = text_field_embedder
        self.num_labels_classes = self.vocab.get_vocab_size(labels_namespace)
        self.num_detect_classes = self.vocab.get_vocab_size(detect_namespace)
        self.label_smoothing = label_smoothing
        self.confidence = confidence
        self.del_conf = del_confidence
        self.incorr_index = self.vocab.get_token_index("INCORRECT",
                                                       namespace=detect_namespace)
        self._verbose_metrics = verbose_metrics
        self.predictor_dropout = TimeDistributed(torch.nn.Dropout(predictor_dropout))
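        # Two per-token projection heads: one over the edit-label vocabulary and one over the
        # detection-tag vocabulary, both sized from the output dim of the 'bert' token embedder.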
        self.tag_labels_projection_layer = TimeDistributed(
            Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_labels_classes))
        self.tag_detect_projection_layer = TimeDistributed(
            Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_detect_classes))

        self.metrics = {"accuracy": CategoricalAccuracy()}
        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                tokens: Dict[str, torch.LongTensor],
                labels: torch.LongTensor = None,
                d_tags: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        tokens : Dict[str, torch.LongTensor], required
            The output of ``TextField.as_array()``, which should typically be passed directly to a
            ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
            tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
            Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
            for the ``TokenIndexers`` when you created the ``TextField`` representing your
            sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
            which knows how to combine different word representations into a single vector per
            token in your input.
        labels : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold edit labels of shape
            ``(batch_size, num_tokens)``.
        d_tags : torch.LongTensor, optional (default = None)
            A torch tensor representing the sequence of integer gold detection tags of shape
            ``(batch_size, num_tokens)``.
        metadata : ``List[Dict[str, Any]]``, optional, (default = None)
            Metadata containing the original words in the sentence to be tagged under a 'words' key.

        Returns
        -------
        An output dictionary consisting of:
        logits_labels : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, label_vocab_size)`` representing
            unnormalised log probabilities of the edit-label classes.
        class_probabilities_labels : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, label_vocab_size)`` representing
            a distribution over the edit-label classes per word.
        logits_d_tags : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, detect_vocab_size)`` representing
            unnormalised log probabilities of the detection-tag classes.
        class_probabilities_d_tags : torch.FloatTensor
            A tensor of shape ``(batch_size, num_tokens, detect_vocab_size)`` representing
            a distribution over the detection-tag classes per word.
        max_error_probability : torch.FloatTensor
            A tensor of shape ``(batch_size,)`` holding, for each sentence, the highest
            per-token probability of the ``INCORRECT`` detection tag.
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        encoded_text = self.text_field_embedder(tokens)
        batch_size, sequence_length, _ = encoded_text.size()
        mask = get_text_field_mask(tokens)
        logits_labels = self.tag_labels_projection_layer(self.predictor_dropout(encoded_text))
        logits_d = self.tag_detect_projection_layer(encoded_text)

        class_probabilities_labels = F.softmax(logits_labels, dim=-1).view(
            [batch_size, sequence_length, self.num_labels_classes])
        class_probabilities_d = F.softmax(logits_d, dim=-1).view(
            [batch_size, sequence_length, self.num_detect_classes])
        error_probs = class_probabilities_d[:, :, self.incorr_index] * mask
        incorr_prob = torch.max(error_probs, dim=-1)[0]
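        # Inference-time bias: add `confidence` to the first label class and `del_confidence`
        # to the second, shifting the argmax toward or away from them depending on the sign.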
        probability_change = [self.confidence, self.del_conf] + [0] * (self.num_labels_classes - 2)
        class_probabilities_labels += torch.FloatTensor(probability_change).repeat(
            (batch_size, sequence_length, 1)).to(class_probabilities_labels.device)

        output_dict = {"logits_labels": logits_labels,
                       "logits_d_tags": logits_d,
                       "class_probabilities_labels": class_probabilities_labels,
                       "class_probabilities_d_tags": class_probabilities_d,
                       "max_error_probability": incorr_prob}
        if labels is not None and d_tags is not None:
            loss_labels = sequence_cross_entropy_with_logits(logits_labels, labels, mask,
                                                             label_smoothing=self.label_smoothing)
            loss_d = sequence_cross_entropy_with_logits(logits_d, d_tags, mask)
            for metric in self.metrics.values():
                metric(logits_labels, labels, mask.float())
                metric(logits_d, d_tags, mask.float())
            output_dict["loss"] = loss_labels + loss_d

        if metadata is not None:
            output_dict["words"] = [x["words"] for x in metadata]
        return output_dict

    @overrides
    def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does a simple position-wise argmax over each token, converts indices to string labels,
        and adds the resulting tags to the output dictionary under each label namespace.
        """
        for label_namespace in self.label_namespaces:
            all_predictions = output_dict[f'class_probabilities_{label_namespace}']
            all_predictions = all_predictions.cpu().data.numpy()
            if all_predictions.ndim == 3:
                predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])]
            else:
                predictions_list = [all_predictions]
            all_tags = []
            for predictions in predictions_list:
                argmax_indices = numpy.argmax(predictions, axis=-1)
                tags = [self.vocab.get_token_from_index(x, namespace=label_namespace)
                        for x in argmax_indices]
                all_tags.append(tags)
            output_dict[f'{label_namespace}'] = all_tags
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics_to_return = {metric_name: metric.get_metric(reset) for
                             metric_name, metric in self.metrics.items()}
        return metrics_to_return
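

if __name__ == "__main__":
    # Hypothetical smoke-test sketch, not part of the original module: a minimal example that
    # assumes AllenNLP 0.9-style APIs (BasicTextFieldEmbedder, Embedding) and a toy vocabulary,
    # just to show how the model's forward and decode steps fit together. The label strings,
    # sizes, and the Embedding-based stand-in for a real BERT embedder are illustrative only.
    from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
    from allennlp.modules.token_embedders import Embedding

    toy_vocab = Vocabulary()
    for tag in ["$KEEP", "$DELETE"]:          # toy edit labels
        toy_vocab.add_token_to_namespace(tag, namespace="labels")
    for tag in ["CORRECT", "INCORRECT"]:      # detection tags
        toy_vocab.add_token_to_namespace(tag, namespace="d_tags")

    # The model reads the output dim of the embedder registered under the 'bert' key,
    # so the toy embedder is registered under that same key.
    toy_embedder = BasicTextFieldEmbedder({"bert": Embedding(num_embeddings=100, embedding_dim=16)})
    model = Seq2Labels(vocab=toy_vocab, text_field_embedder=toy_embedder)

    # A batch of 2 "sentences" of 5 random (non-padding) token ids.
    batch = {"bert": torch.randint(1, 100, (2, 5))}
    output = model.decode(model(batch))
    print(output["labels"])                   # per-token edit labels
    print(output["max_error_probability"])    # per-sentence max INCORRECT probability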