| """ | |
| BERT-Attack: | |
| ============================================================ | |
| (BERT-Attack: Adversarial Attack Against BERT Using BERT) | |
| .. warning:: | |
| This attack is super slow | |
| (see https://github.com/QData/TextAttack/issues/586) | |
| Consider using smaller values for "max_candidates". | |
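For example, the following sketch (which assumes the built attack's
transformation keeps ``max_candidates`` as a mutable attribute) shrinks the
candidate pool after building the recipe:

.. code-block:: python

    attack = BERTAttackLi2020.build(model_wrapper)
    attack.transformation.max_candidates = 12  # down from the paper's K=48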
| """ | |
| from textattack import Attack | |
| from textattack.constraints.overlap import MaxWordsPerturbed | |
| from textattack.constraints.pre_transformation import ( | |
| RepeatModification, | |
| StopwordModification, | |
| ) | |
| from textattack.constraints.semantics.sentence_encoders import UniversalSentenceEncoder | |
| from textattack.goal_functions import UntargetedClassification | |
| from textattack.search_methods import GreedyWordSwapWIR | |
| from textattack.transformations import WordSwapMaskedLM | |
| from .attack_recipe import AttackRecipe | |
class BERTAttackLi2020(AttackRecipe):
    """Li, L., Ma, R., Guo, Q., Xue, X., Qiu, X. (2020).

    BERT-ATTACK: Adversarial Attack Against BERT Using BERT

    https://arxiv.org/abs/2004.09984

    This is "attack mode" 1 from the paper: word-level replacement.
    """

    @staticmethod
    def build(model_wrapper):
        # [from correspondence with the author]
        # Candidate size K is set to 48 for all data-sets.
        transformation = WordSwapMaskedLM(method="bert-attack", max_candidates=48)
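        # (With method="bert-attack", replacement candidates are drawn from
        # the masked language model's top-K predictions at the target
        # position; the paper also describes combining sub-token predictions
        # for words that tokenize into multiple sub-words.)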
        #
        # Don't modify the same word twice or stopwords.
        #
        constraints = [RepeatModification(), StopwordModification()]
        # "We only take ε percent of the most important words since we tend to keep
        # perturbations minimum."
        #
        # [from correspondence with the author]
        # "Word percentage allowed to change is set to 0.4 for most data-sets, this
        # parameter is trivial since most attacks only need a few changes. This
        # epsilon is only used to avoid too much queries on those very hard samples."
        constraints.append(MaxWordsPerturbed(max_percent=0.4))
        # "As used in TextFooler (Jin et al., 2019), we also use Universal Sentence
        # Encoder (Cer et al., 2018) to measure the semantic consistency between the
        # adversarial sample and the original sequence. To balance between semantic
        # preservation and attack success rate, we set up a threshold of semantic
        # similarity score to filter the less similar examples."
        #
        # [from correspondence with author]
        # "Over the full texts, after generating all the adversarial samples, we filter
        # out low USE score samples. Thus the success rate is lower but the USE score
        # can be higher. (actually USE score is not a golden metric, so we simply
        # measure the USE score over the final texts for a comparison with TextFooler).
        # For datasets like IMDB, we set a higher threshold between 0.4-0.7; for
        # datasets like MNLI, we set threshold between 0-0.2."
        #
        # Since the threshold in the real world can't be determined from the
        # training data, the TextAttack implementation uses a fixed threshold
        # of 0.2, chosen as the fairest setting.
        use_constraint = UniversalSentenceEncoder(
            threshold=0.2,
            metric="cosine",
            compare_against_original=True,
            window_size=None,
        )
        constraints.append(use_constraint)
        #
        # Goal is untargeted classification.
        #
        goal_function = UntargetedClassification(model_wrapper)
        #
        # "We first select the words in the sequence which have a high significance
        # influence on the final output logit. Let S = [w0, ··· , wi ··· ] denote
        # the input sentence, and oy(S) denote the logit output by the target model
        # for correct label y, the importance score Iwi is defined as
        # Iwi = oy(S) − oy(S\wi), where S\wi = [w0, ··· , wi−1, [MASK], wi+1, ···]
        # is the sentence after replacing wi with [MASK]. Then we rank all the words
        # according to the ranking score Iwi in descending order to create word list
        # L."
        search_method = GreedyWordSwapWIR(wir_method="unk")

        return Attack(goal_function, constraints, transformation, search_method)
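# Example usage (an illustrative sketch, not part of the recipe; the model
# name, dataset, and wrapper choice below are assumptions):
#
#     import transformers
#     from textattack import Attacker
#     from textattack.datasets import HuggingFaceDataset
#     from textattack.models.wrappers import HuggingFaceModelWrapper
#
#     model = transformers.AutoModelForSequenceClassification.from_pretrained(
#         "textattack/bert-base-uncased-imdb"
#     )
#     tokenizer = transformers.AutoTokenizer.from_pretrained(
#         "textattack/bert-base-uncased-imdb"
#     )
#     model_wrapper = HuggingFaceModelWrapper(model, tokenizer)
#     attack = BERTAttackLi2020.build(model_wrapper)
#     dataset = HuggingFaceDataset("imdb", split="test")
#     Attacker(attack, dataset).attack_dataset()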