Spaces:
Runtime error
Runtime error
| """ | |
| Perplexity Metric: | |
| ------------------------------------------------------- | |
| Class for calculating perplexity from AttackResults | |
| """ | |
| import torch | |
| from textattack.attack_results import FailedAttackResult, SkippedAttackResult | |
| from textattack.metrics import Metric | |
| import textattack.shared.utils | |
class Perplexity(Metric):
    """Metric that scores attack quality by language-model perplexity.

    Concatenates the original (resp. perturbed) texts of all *successful*
    attack results and computes a strided perplexity with a pre-trained LM.

    Args:
        model_name (str): HuggingFace model id. ``"gpt2"`` (default) loads a
            causal GPT-2 LM head; any other name is loaded as a masked LM.
    """

    def __init__(self, model_name="gpt2"):
        self.all_metrics = {}
        self.original_candidates = []
        self.successful_candidates = []
        if model_name == "gpt2":
            from transformers import GPT2LMHeadModel, GPT2Tokenizer

            self.ppl_model = GPT2LMHeadModel.from_pretrained("gpt2")
            self.ppl_model.to(textattack.shared.utils.device)
            self.ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            self.ppl_model.eval()
            self.max_length = self.ppl_model.config.n_positions
        else:
            from transformers import AutoModelForMaskedLM, AutoTokenizer

            # NOTE(review): passing `labels` to a masked-LM head yields an MLM
            # loss, not a true causal perplexity — confirm intended semantics.
            self.ppl_model = AutoModelForMaskedLM.from_pretrained(model_name)
            self.ppl_tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.ppl_model.to(textattack.shared.utils.device)
            self.ppl_model.eval()
            self.max_length = self.ppl_model.config.max_position_embeddings
        # Window stride for the strided perplexity computation below.
        self.stride = 512

    def calculate(self, results):
        """Calculates average Perplexity on all successful attacks using a
        pre-trained small GPT-2 model.

        Args:
            results (``AttackResult`` objects):
                Attack results for each instance in dataset

        Returns:
            dict: ``avg_original_perplexity`` and ``avg_attack_perplexity``
            (``nan`` when there are no successful attacks).

        Example::

            >> import textattack
            >> import transformers
            >> model = transformers.AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
            >> tokenizer = transformers.AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
            >> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)
            >> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
            >> dataset = textattack.datasets.HuggingFaceDataset("glue", "sst2", split="train")
            >> attack_args = textattack.AttackArgs(
                num_examples=1,
                log_to_csv="log.csv",
                checkpoint_interval=5,
                checkpoint_dir="checkpoints",
                disable_stdout=True
            )
            >> attacker = textattack.Attacker(attack, dataset, attack_args)
            >> results = attacker.attack_dataset()
            >> ppl = textattack.metrics.quality_metrics.Perplexity().calculate(results)
        """
        self.results = results
        # Reset candidate lists so repeated `calculate` calls do not
        # double-count texts appended by earlier invocations (bug fix).
        self.original_candidates = []
        self.successful_candidates = []
        self.original_candidates_ppl = []
        self.successful_candidates_ppl = []

        for result in self.results:
            # Only successful attacks contribute to the metric.
            if isinstance(result, (FailedAttackResult, SkippedAttackResult)):
                continue
            self.original_candidates.append(
                result.original_result.attacked_text.text.lower()
            )
            self.successful_candidates.append(
                result.perturbed_result.attacked_text.text.lower()
            )

        ppl_orig = self.calc_ppl(self.original_candidates)
        ppl_attack = self.calc_ppl(self.successful_candidates)

        self.all_metrics["avg_original_perplexity"] = round(ppl_orig, 2)
        self.all_metrics["avg_attack_perplexity"] = round(ppl_attack, 2)
        return self.all_metrics

    def calc_ppl(self, texts):
        """Return perplexity of the space-joined ``texts`` under the LM.

        Uses the strided sliding-window scheme so each token is scored with
        up to ``self.max_length`` tokens of left context.

        Args:
            texts (list[str]): texts to score; may be empty.

        Returns:
            float: perplexity, or ``nan`` when there is nothing to score.
        """
        # Guard: with no texts (or zero tokens) the loop below never runs and
        # `end_loc` would be unbound, raising NameError (bug fix).
        if not texts:
            return float("nan")
        with torch.no_grad():
            text = " ".join(texts)
            eval_loss = []
            input_ids = torch.tensor(
                self.ppl_tokenizer.encode(text, add_special_tokens=True)
            ).unsqueeze(0)
            if input_ids.size(1) == 0:
                return float("nan")
            # Strided perplexity calculation from huggingface.co/transformers/perplexity.html
            for i in range(0, input_ids.size(1), self.stride):
                begin_loc = max(i + self.stride - self.max_length, 0)
                end_loc = min(i + self.stride, input_ids.size(1))
                trg_len = end_loc - i  # tokens actually scored in this window
                input_ids_t = input_ids[:, begin_loc:end_loc].to(
                    textattack.shared.utils.device
                )
                target_ids = input_ids_t.clone()
                # -100 masks context-only tokens out of the loss.
                target_ids[:, :-trg_len] = -100
                outputs = self.ppl_model(input_ids_t, labels=target_ids)
                # outputs[0] is the mean loss over scored tokens; re-weight by
                # window length so windows of unequal size average correctly.
                log_likelihood = outputs[0] * trg_len
                eval_loss.append(log_likelihood)

        return torch.exp(torch.stack(eval_loss).sum() / end_loc).item()