Spaces:
Runtime error
Runtime error
| import transformers | |
| from transformers import AutoTokenizer | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForCausalLM, | |
| ) | |
| from transformers import pipeline, set_seed, LogitsProcessor | |
| from transformers.generation.logits_process import TopPLogitsWarper, TopKLogitsWarper | |
| import torch | |
| from scipy.special import gamma, gammainc, gammaincc, betainc | |
| from scipy.optimize import fminbound | |
| import numpy as np | |
| import os | |
# Hugging Face access token read from the environment; None when HF_TOKEN is unset.
hf_token = os.environ.get('HF_TOKEN')

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def hash_tokens(input_ids: torch.LongTensor, key: int):
    """Fold a sequence of token ids into a single integer seed.

    Starting from ``key``, each token updates the running value as
    ``value * 35317 + token`` reduced modulo ``2**64 - 1``.

    Args:
        input_ids: iterable whose elements expose ``.item()`` (e.g. a 1-D
            tensor of token ids).
        key: initial value of the hash (the secret key).

    Returns:
        The final hash value as a Python int in ``[0, 2**64 - 2]``
        (``key`` itself when ``input_ids`` is empty).
    """
    multiplier = 35317
    modulus = 2 ** 64 - 1
    state = key
    for token in input_ids:
        state = state * multiplier + token.item()
        state %= modulus
    return state
class WatermarkingLogitsProcessor(LogitsProcessor):
    """Shared state for the watermarking logits processors.

    Holds one ``torch.Generator`` per batch element, the secret key, the
    per-element messages and the hashing window size.
    """

    def __init__(self, n, key, messages, window_size, *args, **kwargs):
        """Store the watermark parameters and create the per-row generators.

        Args:
            n: length of the pseudo-random sequence drawn at every step.
            key: integer secret key used to seed the generators.
            messages: one integer message per batch element; its length
                defines the batch size.
            window_size: number of trailing tokens hashed to reseed a
                generator at each step. When falsy, every generator is
                seeded once here with ``key``.
            *args, **kwargs: forwarded to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        self.batch_size = len(messages)
        self.n = n
        self.key = key
        self.window_size = window_size
        self.messages = messages

        # Without a hashing window the random stream depends on the key only,
        # so each generator is seeded a single time, up front.
        seed_now = not window_size
        self.generators = []
        for _ in range(self.batch_size):
            rng = torch.Generator(device=device)
            if seed_now:
                rng.manual_seed(key)
            self.generators.append(rng)
class WatermarkingAaronsonLogitsProcessor(WatermarkingLogitsProcessor):
    """Logits processor that replaces the scores by ``log(u) / exp(score)``.

    For every batch row, ``n`` uniform variables ``u`` are drawn from that
    row's generator, their logarithm is cyclically shifted by the row's
    message, and the result is divided by ``exp(scores)``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return the watermarked scores for one decoding step.

        Args:
            input_ids: token ids generated so far, indexed as
                ``input_ids[b, -window_size:]``.
            scores: 2-D tensor ``(B, V)`` of next-token scores.

        Returns:
            A tensor with the same shape as ``scores`` containing
            ``log(u) / exp(scores)``.
        """
        B, V = scores.shape
        r = torch.zeros_like(scores)
        for b in range(B):
            rng = self.generators[b]
            if self.window_size:
                # Reseed from the hash of the last `window_size` tokens.
                window = input_ids[b, -self.window_size:]
                rng.manual_seed(hash_tokens(window, self.key))
            log_u = torch.rand(self.n, generator=rng, device=rng.device).log()
            # Always draw n values so the pseudo-random stream stays in sync
            # with the decoder, but keep only the first V after the shift.
            # `r` has V columns, so assigning all n values would fail with a
            # shape mismatch whenever n != V.
            r[b] = log_u.roll(-self.messages[b])[:V]
        # The watermark picks argmax u^(1/p), i.e. argmax log(u) / p. With
        # exp(scores) standing in for p (assumes scores are log-probabilities
        # up to an additive constant, which only rescales every entry of a
        # row by the same positive factor), the value returned is log(u) / p.
        return r / scores.exp()
class WatermarkingKirchenbauerLogitsProcessor(WatermarkingLogitsProcessor):
    """Logits processor that adds a constant bias to a pseudo-random "green list".

    At every step a random permutation of ``n`` indices is drawn per batch
    row; the first ``int(gamma * n)`` indices receive a bias of ``delta``.
    The bias vector is cyclically shifted by the row's message before being
    added to the scores.
    """

    def __init__(self, *args, gamma=0.5, delta=4.0, **kwargs):
        """Store the green-list fraction ``gamma`` and the bias ``delta``."""
        super().__init__(*args, **kwargs)
        self.gamma = gamma
        self.delta = delta

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Add the green-list bias to ``scores`` in place and return it."""
        num_rows, num_cols = scores.shape
        for row in range(num_rows):
            rng = self.generators[row]
            if self.window_size:
                # Reseed from the hash of the last `window_size` tokens.
                context = input_ids[row, -self.window_size:]
                rng.manual_seed(hash_tokens(context, self.key))
            permutation = torch.randperm(self.n, generator=rng, device=rng.device)
            green_count = int(self.gamma * self.n)  # gamma * n
            bias = torch.zeros(self.n).to(scores.device)
            bias[permutation[:green_count]] = self.delta
            # Shift by the message, then keep only the entries matching the scores.
            shifted_bias = bias.roll(-self.messages[row])[:num_cols]
            scores[row] += shifted_bias  # add bias to greenlist words
        return scores
class Watermarker(object):
    """Embed and detect (optionally multi-bit) watermarks with a causal LM.

    ``embed`` generates text with a watermarking logits processor and
    ``detect`` recomputes the same pseudo-random sequences on a text to
    recover the most likely message and a p-value.
    """

    def __init__(self, modelname="facebook/opt-350m", window_size = 0, payload_bits = 0, logits_processor = None, *args, **kwargs):
        """Load the tokenizer and model and store the watermark settings.

        Args:
            modelname: name passed to ``from_pretrained`` for both the
                tokenizer and the causal LM.
            window_size: number of previous tokens hashed to seed the random
                generator at each position; when 0 a single stream seeded
                with the key is used.
            payload_bits: number of message bits; messages live in
                ``[0, 2**payload_bits)``.
            logits_processor: optional list of extra logits processors that
                run before the watermarking processor.
            *args, **kwargs: accepted and ignored.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(modelname, use_auth_token=hf_token)
        self.model = AutoModelForCausalLM.from_pretrained(modelname, use_auth_token=hf_token).to(device)
        self.model.eval()
        self.window_size = window_size
        # preprocessing wrappers
        self.logits_processor = logits_processor or []
        self.payload_bits = payload_bits
        # Length of the pseudo-random sequences: large enough to cover both
        # the vocabulary and every possible message shift.
        self.V = max(2**payload_bits, self.model.config.vocab_size)
        self.generator = torch.Generator(device=device)

    def embed(self, key=42, messages=None, prompt="", max_length=30, method='aaronson'):
        """Generate one text per message, continuing ``prompt``.

        Args:
            key: integer secret key of the watermark.
            messages: list of integer messages, one per generated text.
                Defaults to ``[1234]``.
            prompt: text prompt shared by all batch elements.
            max_length: ``max_length`` passed to ``generate``.
            method: ``'aaronson'`` or ``'kirchenbauer'`` for watermarked
                generation, ``'greedy'`` or ``'sampling'`` for plain
                generation without a watermark.

        Returns:
            List of decoded texts, one per message.

        Raises:
            ValueError: if ``payload_bits`` is set and a message is outside
                ``[0, 2**payload_bits)``.
            Exception: if ``method`` is unknown.
        """
        if messages is None:
            messages = [1234]
        B = len(messages)  # batch size

        # check that every message fits in the payload capacity
        if self.payload_bits:
            capacity = 2 ** self.payload_bits
            if not all(0 <= message < capacity for message in messages):
                raise ValueError(
                    'messages must lie in [0, %d) for payload_bits=%d'
                    % (capacity, self.payload_bits))

        watermark_args = dict(n=self.V, key=key, messages=messages,
                              window_size=self.window_size)
        if method == 'aaronson':
            # greedy search: the processor's argmax is the watermarked choice
            do_sample = False
            processors = self.logits_processor + [
                WatermarkingAaronsonLogitsProcessor(**watermark_args)]
        elif method == 'kirchenbauer':
            # sampling from the biased distribution
            do_sample = True
            processors = self.logits_processor + [
                WatermarkingKirchenbauerLogitsProcessor(**watermark_args)]
        elif method == 'greedy':
            # plain greedy search, no watermark
            do_sample = False
            processors = self.logits_processor
        elif method == 'sampling':
            # plain sampling, no watermark
            do_sample = True
            processors = self.logits_processor
        else:
            raise Exception('Unknown method %s' % method)

        # tokenize prompt
        inputs = self.tokenizer([prompt] * B, return_tensors="pt")
        generated_ids = self.model.generate(inputs.input_ids.to(device),
                                            max_length=max_length,
                                            do_sample=do_sample,
                                            logits_processor=processors)
        decoded_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        return decoded_texts

    def detect(self, attacked_texts, key=42, method='aaronson', gamma=0.5, prompts=None):
        """Score texts against the watermark and decode the embedded message.

        Args:
            attacked_texts: list of texts to test.
            key: integer secret key used at embedding time.
            method: one of ``'aaronson'``, ``'aaronson_simplified'``,
                ``'aaronson_neyman_pearson'`` or ``'kirchenbauer'``.
            gamma: green-list fraction (``'kirchenbauer'`` only).
            prompts: list of prompts, one per text; the first
                ``len(tokenized prompt)`` positions of each text are excluded
                from the score. Defaults to empty prompts.

        Returns:
            ``(cdfs, ms)``: for each text, the p-value (corrected for the
            ``2**payload_bits`` candidate messages) and the decoded message.

        Raises:
            Exception: if ``method`` is unknown.
        """
        aaronson_like = {'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson'}
        # Fail fast: previously an unknown method was only detected once a
        # token was actually scored.
        if method not in aaronson_like and method != 'kirchenbauer':
            raise Exception('Unknown method %s' % method)

        if prompts is None:
            prompts = [""] * len(attacked_texts)

        generator = self.generator

        cdfs = []
        ms = []

        MAX = 2**self.payload_bits

        # tokenize input
        inputs = self.tokenizer(attacked_texts, return_tensors="pt", padding=True, return_attention_mask=True)
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_masks = inputs["attention_mask"].to(self.model.device)

        B, seq_len = input_ids.shape

        if method == 'aaronson_neyman_pearson':
            # compute logits; no gradients are needed for detection
            with torch.no_grad():
                outputs = self.model(input_ids, return_dict=True)
            logits = outputs['logits']
            # TODO: reapply self.logits_processor to the logits so that the
            # distribution matches the one used at generation time.
            probs = logits.softmax(dim=-1)
            # ps[b, i] = model probability of token i+1 given tokens 0..i
            ps = torch.gather(probs, 2, input_ids[:,1:,None]).squeeze_(-1)

        V = self.V

        # Z[b, j] accumulates the score of text b under message shift j
        Z = torch.zeros(size=(B, V), dtype=torch.float32, device=device)

        # keep a history of contexts we have already seen,
        # to exclude them from score aggregation and allow
        # correct p-value computation under H0
        history = [set() for _ in range(B)]

        attention_masks_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True, return_attention_mask=True)["attention_mask"]
        prompts_length = torch.sum(attention_masks_prompts, dim=1)

        for b in range(B):
            # prompt tokens are not scored
            attention_masks[b, :prompts_length[b]] = 0
            if not self.window_size:
                generator.manual_seed(key)
            # We can go from seq_len - prompt_len, need to change +1 to + prompt_len
            for i in range(seq_len-1):
                if self.window_size:
                    window = input_ids[b, max(0, i-self.window_size+1):i+1]
                    seed = hash_tokens(window, key)
                    if seed not in history[b]:
                        generator.manual_seed(seed)
                        history[b].add(seed)
                    else:
                        # context already seen: ignore the token
                        attention_masks[b, i+1] = 0

                if not attention_masks[b,i+1]:
                    continue

                token = int(input_ids[b,i+1])

                if method in aaronson_like:
                    R = torch.rand(V, generator = generator, device = generator.device)
                    if method == 'aaronson':
                        r = -(1-R).log()
                    else:
                        r = -R.log()
                else:  # 'kirchenbauer'
                    r = torch.zeros(V, device=device)
                    vocab_permutation = torch.randperm(V, generator = generator, device=generator.device)
                    greenlist = vocab_permutation[:int(gamma * V)]
                    r[greenlist] = 1

                if method == 'aaronson_neyman_pearson':
                    # Neyman-Pearson weighting by the token probability
                    Z[b] += r.roll(-token) * (1/ps[b,i] - 1)
                else:
                    # independent of probs
                    Z[b] += r.roll(-token)

        for b in range(B):
            if method in {'aaronson', 'kirchenbauer'}:
                m = int(torch.argmax(Z[b,:MAX]))
            else:
                m = int(torch.argmin(Z[b,:MAX]))

            S = Z[b, m].item()

            # Number of tokens that contributed to Z[b]. Position 0 is never
            # scored, and prompt masking may already have zeroed it, so the
            # count is taken over positions 1.. rather than sum(mask) - 1.
            scored = attention_masks[b, 1:].bool()
            k = int(scored.sum().item())

            if k == 0:
                # no token was scored: there is no evidence of a watermark
                cdfs.append(1.0)
                ms.append(m)
                continue

            if method == 'aaronson':
                cdf = gammaincc(k, S)
            elif method == 'aaronson_simplified':
                cdf = gammainc(k, S)
            elif method == 'aaronson_neyman_pearson':
                # Chernoff bound, using the probabilities of the scored tokens
                p_scored = ps[b][scored]
                ratio = p_scored / (1 - p_scored)
                E = (1/ratio).sum()

                if S > E:
                    cdf = 1.0
                else:
                    # to compute p-value we must solve for c*:
                    # (1/(c* + ps/(1-ps))).sum() = S
                    func = lambda c : (((1 / (c + ratio)).sum() - S)**2).item()
                    c1 = (k / S - torch.min(ratio)).item()
                    print("max = ", c1)
                    c = fminbound(func, 0, c1)
                    print("solved c = ", c)
                    print("solved s = ", ((1/(c + ratio)).sum()).item())
                    # upper bound
                    cdf = torch.exp(torch.sum(-torch.log(1 + c / ratio)) + c * S)
            else:  # 'kirchenbauer'
                cdf = betainc(S, k - S + 1, gamma)

            # correct for having tested MAX candidate messages
            if cdf > min(1 / MAX, 1e-5):
                cdf = 1 - (1 - cdf)**MAX # true value
            else:
                cdf = cdf * MAX # numerically stable upper bound
            cdfs.append(float(cdf))
            ms.append(m)

        return cdfs, ms