Spaces:
Sleeping
Sleeping
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from transformers_interpret import SequenceClassificationExplainer | |
| import torch | |
| import pandas as pd | |
| class EmotionDetection: | |
| """ | |
| Emotion Detection on text data. | |
| Attributes: | |
| tokenizer: An instance of Hugging Face Tokenizer | |
| model: An instance of Hugging Face Model | |
| explainer: An instance of SequenceClassificationExplainer from Transformers interpret | |
| """ | |
| def __init__(self): | |
| hub_location = 'cardiffnlp/twitter-roberta-base-emotion' | |
| self.tokenizer = AutoTokenizer.from_pretrained(hub_location) | |
| self.model = AutoModelForSequenceClassification.from_pretrained(hub_location) | |
| self.explainer = SequenceClassificationExplainer(self.model, self.tokenizer) | |
| def justify(self, text): | |
| """ | |
| Get html annotation for displaying emotion justification over text. | |
| Parameters: | |
| text (str): The user input string to emotion justification | |
| Returns: | |
| html (hmtl): html object for plotting emotion prediction justification | |
| """ | |
| word_attributions = self.explainer(text) | |
| html = self.explainer.visualize("example.html") | |
| return html | |
| def classify(self, text): | |
| """ | |
| Recognize Emotion in text. | |
| Parameters: | |
| text (str): The user input string to perform emotion classification on | |
| Returns: | |
| predictions (str): The predicted probabilities for emotion classes | |
| """ | |
| tokens = self.tokenizer.encode_plus(text, add_special_tokens=False, return_tensors='pt') | |
| outputs = self.model(**tokens) | |
| probs = torch.nn.functional.softmax(outputs[0], dim=-1) | |
| probs = probs.mean(dim=0).detach().numpy() | |
| labels = list(self.model.config.id2label.values()) | |
| preds = pd.Series(probs, index=labels, name='Predicted Probability') | |
| return preds | |
| def run(self, text): | |
| """ | |
| Classify and Justify Emotion in text. | |
| Parameters: | |
| text (str): The user input string to perform emotion classification on | |
| Returns: | |
| predictions (str): The predicted probabilities for emotion classes | |
| html (hmtl): html object for plotting emotion prediction justification | |
| """ | |
| preds = self.classify(text) | |
| html = self.justify(text) | |
| return preds, html |