LLMwatermark / normalizers.py
"""文本基础规范化器,用于减轻针对水印的简单攻击。
这个实现不太可能是所有可能的Unicode标准中的所有漏洞的完整列表,
它代表了我们在撰写时的最佳努力。
这些规范化器可以作为独立的规范化器使用。它们可以被制作成符合HF分词器标准的规范化器,
但这将需要涉及tokenizers.NormalizedString的有限Rust接口。
"""
from collections import defaultdict
from functools import cache
import re
import unicodedata
import homoglyphs as hg
def normalization_strategy_lookup(strategy_name: str) -> object:
    if strategy_name == "unicode":
        return UnicodeSanitizer()
    elif strategy_name == "homoglyphs":
        return HomoglyphCanonizer()
    elif strategy_name == "truecase":
        return TrueCaser()
    else:
        raise ValueError(f"Unknown normalization strategy: {strategy_name}")
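
# Usage sketch (illustrative, not part of the module API): each strategy name
# returns a callable that maps raw text to normalized text, e.g.
#   normalization_strategy_lookup("unicode")("some\u00a0text")
# Runnable demos for the individual normalizers appear after each class below.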
class HomoglyphCanonizer:
    """Attempts to detect homoglyph attacks and find a consistent canonical form.

    This is done at the ISO category level. It could also be done at the language level
    (see the commented-out code).
    """
def __init__(self):
self.homoglyphs = None
def __call__(self, homoglyphed_str: str) -> str:
# find canon:
target_category, all_categories = self._categorize_text(homoglyphed_str)
homoglyph_table = self._select_canon_category_and_load(target_category, all_categories)
return self._sanitize_text(target_category, homoglyph_table, homoglyphed_str)
    def _categorize_text(self, text: str) -> tuple[str, tuple[str, ...]]:
iso_categories = defaultdict(int)
# self.iso_languages = defaultdict(int)
for char in text:
iso_categories[hg.Categories.detect(char)] += 1
# for lang in hg.Languages.detect(char):
# self.iso_languages[lang] += 1
target_category = max(iso_categories, key=iso_categories.get)
all_categories = tuple(iso_categories)
return target_category, all_categories
@cache
def _select_canon_category_and_load(
        self, target_category: str, all_categories: tuple[str, ...]
) -> dict:
        homoglyph_table = hg.Homoglyphs(
            categories=(target_category, "COMMON")
        )  # alphabet is loaded from file here
        source_alphabet = hg.Categories.get_alphabet(all_categories)
        restricted_table = homoglyph_table.get_restricted_table(
            source_alphabet, homoglyph_table.alphabet
        )  # table is loaded from file here
return restricted_table
def _sanitize_text(
self, target_category: str, homoglyph_table: dict, homoglyphed_str: str
) -> str:
sanitized_text = ""
for char in homoglyphed_str:
# langs = hg.Languages.detect(char)
cat = hg.Categories.detect(char)
if target_category in cat or "COMMON" in cat or len(cat) == 0:
sanitized_text += char
else:
sanitized_text += list(homoglyph_table[char])[0]
return sanitized_text
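
# Minimal demo of the canonizer above (an illustrative sketch, not part of the
# original module; requires the `homoglyphs` package). The leading letter below
# is Cyrillic 'Н' (U+041D), a homoglyph of Latin 'H'; the exact replacement
# chosen depends on the package's homoglyph tables.
if __name__ == "__main__":
    _canonizer = HomoglyphCanonizer()
    print(_canonizer("\u041dello world"))  # expected: "Hello world" with a Latin 'H'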
class UnicodeSanitizer:
    """Regex-based removal of invisible or otherwise problematic Unicode characters, by ruleset."""

    def __init__(self, ruleset="whitespaces"):
if ruleset == "whitespaces":
"""Documentation:
\u00A0: Non-breaking space
\u1680: Ogham space mark
\u180E: Mongolian vowel separator
\u2000-\u200B: Various space characters, including en space, em space, thin space, hair space, zero-width space, and zero-width non-joiner
\u200C\u200D: Zero-width non-joiner and zero-width joiner
\u200E,\u200F: Left-to-right-mark, Right-to-left-mark
\u2060: Word joiner
\u2063: Invisible separator
\u202F: Narrow non-breaking space
\u205F: Medium mathematical space
\u3000: Ideographic space
\uFEFF: Zero-width non-breaking space
\uFFA0: Halfwidth hangul filler
\uFFF9\uFFFA\uFFFB: Interlinear annotation characters
\uFE00-\uFE0F: Variation selectors
            \u202A-\u202E: Bidirectional embedding, override, and pop formatting characters
\u3164: Korean hangul filler.
"""
self.pattern = re.compile(
r"[\u00A0\u1680\u180E\u2000-\u200B\u200C\u200D\u200E\u200F\u2060\u2063\u202F\u205F\u3000\uFEFF\uFFA0\uFFF9\uFFFA\uFFFB"
r"\uFE00\uFE01\uFE02\uFE03\uFE04\uFE05\uFE06\uFE07\uFE08\uFE09\uFE0A\uFE0B\uFE0C\uFE0D\uFE0E\uFE0F\u3164\u202A\u202B\u202C\u202D"
r"\u202E\u202F]"
)
elif ruleset == "IDN.blacklist":
"""Documentation:
[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF]: Matches any whitespace characters in the Unicode character
set that are included in the IDN blacklist.
            \uFFF9-\uFFFB: Matches the interlinear annotation characters, which are not allowed
            in domain names.
\uD800-\uDB7F: Matches the first part of a surrogate pair. Surrogate pairs are used to represent characters in the Unicode character
set that cannot be represented by a single 16-bit value. The first part of a surrogate pair is in the range U+D800 to U+DBFF,
and the second part is in the range U+DC00 to U+DFFF.
            \uDB80-\uDBFF][\uDC00-\uDFFF]?: Matches a high surrogate in the private-use range,
            optionally followed by a low surrogate in the range U+DC00 to U+DFFF.
[\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]: Matches certain invalid UTF-16 sequences which should not appear in IDNs.
"""
self.pattern = re.compile(
r"[\u00A0\u1680\u180E\u2000-\u200B\u202F\u205F\u2060\u2063\uFEFF\uFFF9-\uFFFB\uD800-\uDB7F\uDB80-\uDBFF]"
r"[\uDC00-\uDFFF]?|[\uDB40\uDC20-\uDB40\uDC7F][\uDC00-\uDFFF]"
)
else:
"""Documentation:
This is a simple restriction to "no-unicode", using only ascii characters. Control characters are included.
"""
self.pattern = re.compile(r"[^\x00-\x7F]+")
def __call__(self, text: str) -> str:
text = unicodedata.normalize("NFC", text) # canon forms
text = self.pattern.sub(" ", text) # pattern match
text = re.sub(" +", " ", text) # collapse whitespaces
text = "".join(
c for c in text if unicodedata.category(c) != "Cc"
) # 删除所有剩余的不可打印字符
return text
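
# Minimal demo of the sanitizer above (an illustrative sketch, not part of the
# original module). Under the default "whitespaces" ruleset, the zero-width
# space (U+200B) and no-break space (U+00A0) below are replaced by plain spaces.
if __name__ == "__main__":
    _sanitizer = UnicodeSanitizer()
    print(_sanitizer("zero\u200bwidth and non\u00a0breaking spaces"))
    # expected: "zero width and non breaking spaces"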
class TrueCaser:
    """True-casing: a capitalization normalization that returns text to its original case.

    This defends against attacks that randomize capitalization, like spOngBoB text.
    A simple part-of-speech tagger is used here.
    """

    uppercase_pos = ["PROPN"]  # POS tags that should be capitalized
def __init__(self, backend="spacy"):
if backend == "spacy":
import spacy
self.nlp = spacy.load("en_core_web_sm")
self.normalize_fn = self._spacy_truecasing
else:
from nltk import pos_tag, word_tokenize # noqa
import nltk
nltk.download("punkt")
nltk.download("averaged_perceptron_tagger")
nltk.download("universal_tagset")
self.normalize_fn = self._nltk_truecasing
def __call__(self, random_capitalized_string: str) -> str:
truecased_str = self.normalize_fn(random_capitalized_string)
return truecased_str
def _spacy_truecasing(self, random_capitalized_string: str):
doc = self.nlp(random_capitalized_string.lower())
POS = self.uppercase_pos
truecased_str = "".join(
[
w.text_with_ws.capitalize() if w.pos_ in POS or w.is_sent_start else w.text_with_ws
for w in doc
]
)
return truecased_str
    def _nltk_truecasing(self, random_capitalized_string: str):
        from nltk import pos_tag, word_tokenize  # required data is downloaded in __init__
POS = ["NNP", "NNPS"]
tagged_text = pos_tag(word_tokenize(random_capitalized_string.lower()))
truecased_str = " ".join([w.capitalize() if p in POS else w for (w, p) in tagged_text])
return truecased_str
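
# Minimal end-to-end demo (an illustrative sketch, not part of the original
# module). The spacy backend additionally requires the en_core_web_sm model;
# exact output casing depends on the tagger's POS decisions.
if __name__ == "__main__":
    _truecaser = normalization_strategy_lookup("truecase")
    print(_truecaser("tHiS iS a sPoNgEbOb sTrInG"))  # expected: "This is a spongebob string"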