File size: 8,360 Bytes
c99a3a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
"""文本基础规范化器,用于减轻针对水印的简单攻击。

这个实现不太可能是所有可能的Unicode标准中的所有漏洞的完整列表,
它代表了我们在撰写时的最佳努力。

这些规范化器可以作为独立的规范化器使用。它们可以被制作成符合HF分词器标准的规范化器,
但这将需要涉及tokenizers.NormalizedString的有限Rust接口。
"""

from collections import defaultdict
from functools import cache

import re
import unicodedata
import homoglyphs as hg


def normalization_strategy_lookup(strategy_name: str) -> object:
    """Build the normalizer registered under ``strategy_name``.

    Recognized names are ``"unicode"``, ``"homoglyphs"`` and ``"truecase"``
    (compared case-sensitively).  A fresh normalizer instance is constructed
    on every call.  Any other name yields ``None``.
    """
    # Factories are wrapped in lambdas so a class is only looked up (and
    # instantiated) when its name is actually requested.
    factories = (
        ("unicode", lambda: UnicodeSanitizer()),
        ("homoglyphs", lambda: HomoglyphCanonizer()),
        ("truecase", lambda: TrueCaser()),
    )
    for known_name, build in factories:
        if strategy_name == known_name:
            return build()
    return None


class HomoglyphCanonizer:
    """Detect homoglyph attacks and map text onto one consistent canonical form.

    Canonicalization is done at the level of the ISO categories reported by
    ``hg.Categories``: the most frequent category in the text becomes the
    target, and characters from other categories are replaced by look-alikes
    taken from the target alphabet.  The same idea could be applied at the
    language level via ``hg.Languages.detect`` instead.
    """

    def __init__(self):
        # Not read anywhere in this class; kept so existing attribute access
        # on instances keeps working.
        self.homoglyphs = None

    def __call__(self, homoglyphed_str: str) -> str:
        """Return ``homoglyphed_str`` with foreign-category homoglyphs replaced.

        An empty string is returned unchanged: there is no dominant category
        to detect, and ``max()`` over an empty mapping would raise.
        """
        if homoglyphed_str == "":
            return ""
        # find canon:
        target_category, all_categories = self._categorize_text(homoglyphed_str)
        homoglyph_table = self._select_canon_category_and_load(target_category, all_categories)
        return self._sanitize_text(target_category, homoglyph_table, homoglyphed_str)

    def _categorize_text(self, text: str) -> tuple:
        """Count category occurrences and pick the dominant one.

        Returns:
            ``(target_category, all_categories)`` where ``target_category`` is
            the most frequent value returned by ``hg.Categories.detect`` over
            the characters of ``text`` and ``all_categories`` is a tuple of
            every distinct value seen (hashable, so it can key the cache).

        Raises:
            ValueError: if ``text`` is empty.
        """
        iso_categories = defaultdict(int)
        for char in text:
            iso_categories[hg.Categories.detect(char)] += 1
        target_category = max(iso_categories, key=iso_categories.get)
        all_categories = tuple(iso_categories)
        return target_category, all_categories

    # A static method so the cache is keyed only on the two category
    # arguments.  Caching a bound method would also key on ``self`` and keep
    # every instance alive for the lifetime of the cache.
    @staticmethod
    @cache
    def _select_canon_category_and_load(
        target_category: str, all_categories: tuple[str, ...]
    ) -> dict:
        """Load the table mapping source characters to target-alphabet homoglyphs.

        The result depends only on the arguments, so it is memoized and shared
        across instances.
        """
        homoglyph_table = hg.Homoglyphs(
            categories=(target_category, "COMMON")
        )  # alphabet is loaded from file here

        source_alphabet = hg.Categories.get_alphabet(all_categories)
        restricted_table = homoglyph_table.get_restricted_table(
            source_alphabet, homoglyph_table.alphabet
        )  # table is loaded from file here
        return restricted_table

    def _sanitize_text(
        self, target_category: str, homoglyph_table: dict, homoglyphed_str: str
    ) -> str:
        """Replace every character outside the target/COMMON categories.

        Characters whose category is empty or missing are kept as-is, and so
        are characters for which ``homoglyph_table`` offers no replacement.
        """
        sanitized_chars = []
        for char in homoglyphed_str:
            cat = hg.Categories.detect(char)
            # Test for a missing category first: the membership tests below
            # would fail on a value that is not a container.
            # NOTE(review): assumes detect() yields a falsy value when it
            # cannot classify a character -- confirm against the hg module.
            if not cat or target_category in cat or "COMMON" in cat:
                sanitized_chars.append(char)
                continue
            candidates = homoglyph_table.get(char)
            # Best effort: keep the original character when the table has no
            # look-alike rather than failing on a missing or empty entry.
            sanitized_chars.append(next(iter(candidates)) if candidates else char)
        return "".join(sanitized_chars)


class UnicodeSanitizer:
    """Replace invisible, spacing and other suspicious code points with spaces.

    Args:
        ruleset: selects the pattern of characters to replace.
            ``"whitespaces"`` (default) targets Unicode spaces, zero-width
            and directional marks, fillers and variation selectors;
            ``"IDN.blacklist"`` targets a blacklist of spaces, annotation
            characters, surrogates and tag characters; any other value
            restricts the text to ASCII.
    """

    def __init__(self, ruleset="whitespaces"):
        if ruleset == "whitespaces":
            # Characters matched by this ruleset:
            #   U+00A0          no-break space
            #   U+1680          Ogham space mark
            #   U+180E          Mongolian vowel separator
            #   U+2000-U+200B   fixed-width spaces (en, em, thin, hair, ...)
            #                   and the zero-width space
            #   U+200C, U+200D  zero-width non-joiner / zero-width joiner
            #   U+200E, U+200F  left-to-right mark / right-to-left mark
            #   U+2060          word joiner
            #   U+2063          invisible separator
            #   U+202F          narrow no-break space
            #   U+205F          medium mathematical space
            #   U+3000          ideographic space
            #   U+FEFF          zero-width no-break space
            #   U+FFA0          halfwidth Hangul filler
            #   U+FFF9-U+FFFB   interlinear annotation characters
            #   U+FE00-U+FE0F   variation selectors
            #   U+202A-U+202E   directional embedding / override characters
            #   U+3164          Hangul filler
            self.pattern = re.compile(
                r"[\u00A0\u1680\u180E\u2000-\u200B\u200C\u200D\u200E\u200F\u2060\u2063\u202F\u205F\u3000\uFEFF\uFFA0\uFFF9\uFFFA\uFFFB"
                r"\uFE00\uFE01\uFE02\uFE03\uFE04\uFE05\uFE06\uFE07\uFE08\uFE09\uFE0A\uFE0B\uFE0C\uFE0D\uFE0E\uFE0F\u3164\u202A\u202B\u202C\u202D"
                r"\u202E\u202F]"
            )
        elif ruleset == "IDN.blacklist":
            # First alternative: one blacklisted space/format character,
            # interlinear annotation character (U+FFF9-U+FFFB) or high
            # surrogate (U+D800-U+DBFF), optionally followed by a low
            # surrogate (U+DC00-U+DFFF).
            #
            # Second alternative: the tag characters U+E0020-U+E007F.  The
            # previous spelling, [\uDB40\uDC20-\uDB40\uDC7F], wrote them as
            # UTF-16 surrogate pairs; Python's re reads that as separate code
            # units, giving the reversed range \uDC20-\uDB40 and a re.error at
            # compile time.  Python strings hold code points, so the range is
            # written directly.
            # NOTE(review): other astral-plane characters are not matched by
            # this pattern -- confirm whether the blacklist should cover them.
            self.pattern = re.compile(
                r"[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF\uFFF9-\uFFFB\uD800-\uDB7F\uDB80-\uDBFF]"
                r"[\uDC00-\uDFFF]?|[\U000E0020-\U000E007F]"
            )
        else:
            # Simple restriction to "no unicode": every run of non-ASCII
            # characters is matched.  ASCII control characters are not matched
            # here (they are dropped later in __call__).
            self.pattern = re.compile(r"[^\x00-\x7F]+")

    def __call__(self, text: str) -> str:
        """Return ``text`` normalized to NFC with all pattern matches blanked.

        Each match of the ruleset pattern becomes a single space, runs of
        spaces are collapsed to one, and finally every character in Unicode
        category ``Cc`` (control characters, which includes tab and newline)
        is removed.
        """
        text = unicodedata.normalize("NFC", text)  # canonical composed forms
        text = self.pattern.sub(" ", text)  # blank out matched characters
        text = re.sub(" +", " ", text)  # collapse runs of spaces
        # Drop all remaining non-printable (control) characters.
        text = "".join(c for c in text if unicodedata.category(c) != "Cc")
        return text


class TrueCaser:
    """True-casing: normalize capitalization back to a canonical form.

    This defends against attacks that randomize letter case, such as writing
    text like spOngBoB.

    A simple part-of-speech tagger decides which words get capitalized.

    Args:
        backend: ``"spacy"`` (default) tags with the ``en_core_web_sm`` spaCy
            pipeline; any other value uses NLTK.
    """

    uppercase_pos = ["PROPN"]  # POS tags whose words should be capitalized

    # NLTK resources fetched for the NLTK backend, and whether that has
    # already been done in this process.
    _nltk_resources = ("punkt", "averaged_perceptron_tagger", "universal_tagset")
    _nltk_ready = False

    def __init__(self, backend="spacy"):
        if backend == "spacy":
            import spacy

            self.nlp = spacy.load("en_core_web_sm")
            self.normalize_fn = self._spacy_truecasing
        else:
            self._ensure_nltk_resources()
            self.normalize_fn = self._nltk_truecasing

    def __call__(self, random_capitalized_string: str) -> str:
        """Return the true-cased version of ``random_capitalized_string``."""
        truecased_str = self.normalize_fn(random_capitalized_string)
        return truecased_str

    @classmethod
    def _ensure_nltk_resources(cls):
        """Fetch the NLTK resources once per process.

        Doing this once avoids repeating the download calls on every
        true-casing call.
        """
        if cls._nltk_ready:
            return
        import nltk

        for resource in cls._nltk_resources:
            nltk.download(resource)
        cls._nltk_ready = True

    def _spacy_truecasing(self, random_capitalized_string: str):
        """Lower-case the input, then capitalize proper nouns and sentence starts.

        Tokens are re-joined with their original trailing whitespace
        (``text_with_ws``).
        """
        doc = self.nlp(random_capitalized_string.lower())
        POS = self.uppercase_pos
        truecased_str = "".join(
            w.text_with_ws.capitalize() if w.pos_ in POS or w.is_sent_start else w.text_with_ws
            for w in doc
        )
        return truecased_str

    def _nltk_truecasing(self, random_capitalized_string: str):
        """Lower-case the input, then capitalize tokens tagged ``NNP``/``NNPS``.

        Tokens are re-joined with single spaces, so the original spacing is
        not preserved.
        """
        # Imported lazily so the spaCy backend does not require NLTK.
        from nltk import pos_tag, word_tokenize

        self._ensure_nltk_resources()
        POS = ["NNP", "NNPS"]

        tagged_text = pos_tag(word_tokenize(random_capitalized_string.lower()))
        truecased_str = " ".join(w.capitalize() if p in POS else w for (w, p) in tagged_text)
        return truecased_str