Antoine Chaffin committed
Commit ed02397 · Parent(s): 31f8227
Initial commit

Files changed:
- app.py +80 -0
- requirements.txt +4 -0
- watermark.py +291 -0

app.py (ADDED)
@@ -0,0 +1,80 @@
import torch
import argparse
import os
import numpy as np

from watermark import Watermarker
import time
import gradio as gr

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

parser = argparse.ArgumentParser(description='Generative Text Watermarking demo')
parser.add_argument('--model', '-m', type=str, default="facebook/opt-350m", help='Language model')
# parser.add_argument('--model', '-m', type=str, default="meta-llama/Llama-2-7b-chat-hf", help='Language model')
parser.add_argument('--key', '-k', type=int, default=42,
                    help='The seed of the pseudo random number generator')

args = parser.parse_args()

USERS = ['Alice', 'Bob', 'Charlie', 'Dan']
EMBED_METHODS = ['aaronson', 'kirchenbauer', 'sampling', 'greedy']
DETECT_METHODS = ['aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson', 'kirchenbauer']
PAYLOAD_BITS = 2

def embed(user, max_length, window_size, method, prompt):
    uid = USERS.index(user)

    watermarker = Watermarker(modelname=args.model,
                              window_size=window_size, payload_bits=PAYLOAD_BITS)

    watermarked_texts = watermarker.embed(key=args.key, messages=[uid],
                                          max_length=max_length, method=method, prompt=prompt)
    print("watermarked_texts: ", watermarked_texts)

    return watermarked_texts[0]

def detect(attacked_text, window_size, method, prompt):
    watermarker = Watermarker(modelname=args.model,
                              window_size=window_size, payload_bits=PAYLOAD_BITS)

    pvalues, messages = watermarker.detect([attacked_text], key=args.key, method=method, prompts=[prompt])
    print("messages: ", messages)
    print("p-values: ", pvalues)
    user = USERS[messages[0]]
    pf = pvalues[0]
    label = 'The user detected is {:s} with pvalue of {:.3e}'.format(user, pf)

    return label


with gr.Blocks() as demo:
    gr.Markdown("""# LLM generation watermarking
    This Space lets you try different watermarking schemes for LLM generation.\n
    It leverages the upgrades introduced in the paper, which reduce the gap between the empirical and theoretical false positive detection rates and add the ability to embed a message of n bits. Here we use this capacity to embed the identity of the user generating the text, but it could also be used to identify different versions of a model, or simply to convey a secret message.\n
    Simply select a user name, then set the maximum text length, the watermarking window size and the prompt. The Aaronson and Kirchenbauer watermarking schemes are available, along with traditional sampling and greedy search without watermarking.\n
    Once the text is generated, you can optionally apply some attacks to it (e.g., remove words), select the associated detection method and run the detection. Please note that the detection is non-blind: it requires the original prompt to be known, so the prompt must be left untouched.\n
    For Aaronson, the original detection function is available, along with the Neyman-Pearson and Simplified Score versions.""")
    with gr.Row():
        user = gr.Dropdown(choices=USERS, value=USERS[0], label="User")
        text_length = gr.Number(minimum=1, maximum=512, value=256, step=1, precision=0, label="Max text length")
        window_size = gr.Number(minimum=0, maximum=10, value=0, step=1, precision=0, label="Watermarking window size")
        embed_method = gr.Dropdown(choices=EMBED_METHODS, value=EMBED_METHODS[0], label="Sampling method")
        prompt = gr.Textbox(label="prompt")
    with gr.Row():
        btn1 = gr.Button("Embed")
    with gr.Row():
        watermarked_text = gr.Textbox(label="Generated text")
        detect_method = gr.Dropdown(choices=DETECT_METHODS, value=DETECT_METHODS[0], label="Detection method")
    with gr.Row():
        btn2 = gr.Button("Detect")
    with gr.Row():
        detection_label = gr.Label(label="Detection result")

    btn1.click(fn=embed, inputs=[user, text_length, window_size, embed_method, prompt], outputs=[watermarked_text], api_name="watermark")
    btn2.click(fn=detect, inputs=[watermarked_text, window_size, detect_method, prompt], outputs=[detection_label], api_name="detect")

    demo.launch()
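
Since the two click handlers are registered with api_name="watermark" and api_name="detect", they are also callable programmatically. A minimal sketch using gradio_client (not listed in requirements.txt, so it needs a separate install), assuming the app is running locally on Gradio's default port:

    from gradio_client import Client

    client = Client("http://127.0.0.1:7860")  # local `python app.py` instance
    text = client.predict("Alice", 256, 0, "aaronson", "Once upon a time",
                          api_name="/watermark")
    label = client.predict(text, 0, "aaronson", "Once upon a time",
                           api_name="/detect")
    print(label)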
    	
requirements.txt (ADDED)
@@ -0,0 +1,4 @@
torch
transformers
scipy
numpy
    	
watermark.py (ADDED)
@@ -0,0 +1,291 @@
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from transformers import pipeline, set_seed, LogitsProcessor
from transformers.generation.logits_process import TopPLogitsWarper, TopKLogitsWarper
import torch
from scipy.special import gamma, gammainc, gammaincc, betainc
from scipy.optimize import fminbound
import numpy as np

import os

hf_token = os.getenv('HF_TOKEN')


device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

def hash_tokens(input_ids: torch.LongTensor, key: int):
    seed = key
    salt = 35317
    for i in input_ids:
        seed = (seed * salt + i.item()) % (2 ** 64 - 1)
    return seed

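# Note: hash_tokens is a salted multiplicative rolling hash of a token window;
# anyone holding `key` can recompute the same seed from the text alone, so the
# detector can replay the generator's pseudo-randomness without the model
# (except for the Neyman-Pearson variant, which also needs the logits).
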
class WatermarkingLogitsProcessor(LogitsProcessor):
    def __init__(self, n, key, messages, window_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_size = len(messages)
        self.generators = [torch.Generator(device=device) for _ in range(self.batch_size)]

        self.n = n
        self.key = key
        self.window_size = window_size
        if not self.window_size:
            for b in range(self.batch_size):
                self.generators[b].manual_seed(self.key)

        self.messages = messages

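# Note: window_size == 0 seeds each generator once with `key`, giving a single
# fixed pseudo-random stream indexed by position; window_size > 0 reseeds at
# every step from a hash of the last window_size tokens, which lets detection
# resynchronize after local edits to the text.
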
class WatermarkingAaronsonLogitsProcessor(WatermarkingLogitsProcessor):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # get random uniform variables
        B, V = scores.shape

        r = torch.zeros((B, self.n), dtype=scores.dtype, device=scores.device)
        for b in range(B):
            if self.window_size:
                window = input_ids[b, -self.window_size:]
                seed = hash_tokens(window, self.key)
                self.generators[b].manual_seed(seed)
            r[b] = torch.rand(self.n, generator=self.generators[b], device=self.generators[b].device).log().roll(-self.messages[b])
        # generate n but keep only V, as we want to keep the pseudo-random sequences in sync with the decoder
        r = r[:, :V]

        # modify law as r^(1/p): since a logits processor takes and returns logits,
        # we return log(r^(1/p)) = log(r) / p, with p recovered via scores.exp()
        return r / scores.exp()

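# Note: with do_sample=False in generate(), the decoder takes the argmax of
# log(u) / p, i.e. of u^(1/p) with u uniform, which is Aaronson's watermarked
# sampling (marginally distributed like ordinary sampling from p). The
# roll(-message) shifts the pseudo-random sequence so that the payload can be
# recovered at detection time as the shift maximizing the score.
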
class WatermarkingKirchenbauerLogitsProcessor(WatermarkingLogitsProcessor):
    def __init__(self, *args,
                 gamma=0.5,
                 delta=4.0,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.gamma = gamma
        self.delta = delta

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        B, V = scores.shape

        for b in range(B):
            if self.window_size:
                window = input_ids[b, -self.window_size:]
                seed = hash_tokens(window, self.key)
                self.generators[b].manual_seed(seed)
            vocab_permutation = torch.randperm(self.n, generator=self.generators[b], device=self.generators[b].device)
            greenlist = vocab_permutation[:int(self.gamma * self.n)]  # gamma * n
            bias = torch.zeros(self.n).to(scores.device)
            bias[greenlist] = self.delta
            bias = bias.roll(-self.messages[b])[:V]
            scores[b] += bias  # add bias to greenlist words

        return scores

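# Note: this is the Kirchenbauer et al. scheme: a keyed permutation picks a
# "greenlist" of gamma * n tokens whose logits get a +delta bias before
# sampling; detection counts greenlist hits and applies a binomial tail test
# (the betainc call in Watermarker.detect below).
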
class Watermarker(object):
    def __init__(self, modelname="facebook/opt-350m", window_size=0, payload_bits=0, logits_processor=None, *args, **kwargs):
        self.tokenizer = AutoTokenizer.from_pretrained(modelname, use_auth_token=hf_token)
        self.model = AutoModelForCausalLM.from_pretrained(modelname, use_auth_token=hf_token).to(device)
        self.model.eval()
        self.window_size = window_size

        # preprocessing wrappers
        self.logits_processor = logits_processor or []

        self.payload_bits = payload_bits
        self.V = max(2**payload_bits, self.model.config.vocab_size)
        self.generator = torch.Generator(device=device)


    def embed(self, key=42, messages=[1234], prompt="", max_length=30, method='aaronson'):

        B = len(messages)  # batch size
        length = max_length

        # check that every message fits within the payload capacity
        if self.payload_bits:
            assert min([message >= 0 and message < 2**self.payload_bits for message in messages])

        # tokenize prompt
        inputs = self.tokenizer([prompt] * B, return_tensors="pt")

        if method == 'aaronson':
            # generate with greedy search
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
                                                logits_processor=self.logits_processor + [
                                                    WatermarkingAaronsonLogitsProcessor(n=self.V,
                                                                                        key=key,
                                                                                        messages=messages,
                                                                                        window_size=self.window_size)])
        elif method == 'kirchenbauer':
            # use sampling
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
                                                logits_processor=self.logits_processor + [
                                                    WatermarkingKirchenbauerLogitsProcessor(n=self.V,
                                                                                            key=key,
                                                                                            messages=messages,
                                                                                            window_size=self.window_size)])
        elif method == 'greedy':
            # generate with greedy search
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
                                                logits_processor=self.logits_processor)
        elif method == 'sampling':
            # generate with plain sampling
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
                                                logits_processor=self.logits_processor)
        else:
            raise Exception('Unknown method %s' % method)
        decoded_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

        return decoded_texts

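    # Note on the branches above: 'aaronson' must use do_sample=False because
    # all randomness comes from the keyed generator inside the processor, while
    # 'kirchenbauer' keeps ordinary sampling over the biased logits; 'greedy'
    # and 'sampling' are the unwatermarked baselines.
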
    def detect(self, attacked_texts, key=42, method='aaronson', gamma=0.5, prompts=None):
        if prompts is None:
            prompts = [""] * len(attacked_texts)

        generator = self.generator

        #print("attacked_texts = ", attacked_texts)

        cdfs = []
        ms = []

        MAX = 2**self.payload_bits

        # tokenize input
        inputs = self.tokenizer(attacked_texts, return_tensors="pt", padding=True, return_attention_mask=True)

        input_ids = inputs["input_ids"].to(self.model.device)
        attention_masks = inputs["attention_mask"].to(self.model.device)

        B, T = input_ids.shape

        if method == 'aaronson_neyman_pearson':
            # compute logits
            outputs = self.model.forward(input_ids, return_dict=True)
            logits = outputs['logits']
            # TODO
            # reapply logits processors to get same distribution
            #for i in range(T):
            #    for processor in self.logits_processor:
            #        logits[:,i] = processor(input_ids[:, :i], logits[:, i])

            probs = logits.softmax(dim=-1)
            # ps[b, i] = model probability of the observed token at position i+1
            ps = torch.gather(probs, 2, input_ids[:, 1:, None]).squeeze_(-1)


        seq_len = input_ids.shape[1]
        length = seq_len

        V = self.V

        Z = torch.zeros(size=(B, V), dtype=torch.float32, device=device)


        # keep a history of contexts we have already seen,
        # to exclude them from score aggregation and allow
        # correct p-value computation under H0
        history = [set() for _ in range(B)]

        attention_masks_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True, return_attention_mask=True)["attention_mask"]
        prompts_length = torch.sum(attention_masks_prompts, dim=1)
        for b in range(B):
            attention_masks[b, :prompts_length[b]] = 0
            if not self.window_size:
                generator.manual_seed(key)
            # We can go from seq_len - prompt_len, need to change +1 to + prompt_len
            for i in range(seq_len - 1):

                if self.window_size:
                    window = input_ids[b, max(0, i - self.window_size + 1):i+1]
                    #print("window = ", window)
                    seed = hash_tokens(window, key)
                    if seed not in history[b]:
                        generator.manual_seed(seed)
                        history[b].add(seed)
                    else:
                        # ignore the token
                        attention_masks[b, i+1] = 0

                if not attention_masks[b, i+1]:
                    continue

                token = int(input_ids[b, i+1])

                if method in {'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson'}:
                    R = torch.rand(V, generator=generator, device=generator.device)

                if method == 'aaronson':
                    r = -(1 - R).log()
                elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
                    r = -R.log()
                elif method == 'kirchenbauer':
                    r = torch.zeros(V, device=device)
                    vocab_permutation = torch.randperm(V, generator=generator, device=generator.device)
                    greenlist = vocab_permutation[:int(gamma * V)]
                    r[greenlist] = 1
                else:
                    raise Exception('Unknown method %s' % method)

                if method in {'aaronson', 'aaronson_simplified', 'kirchenbauer'}:
                    # independent of probs
                    Z[b] += r.roll(-token)
                elif method == 'aaronson_neyman_pearson':
                    # Neyman-Pearson
                    Z[b] += r.roll(-token) * (1/ps[b, i] - 1)

        for b in range(B):
            if method in {'aaronson', 'kirchenbauer'}:
                m = torch.argmax(Z[b, :MAX])
            elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
                m = torch.argmin(Z[b, :MAX])

            i = int(m)
            S = Z[b, i].item()
            m = i

            # actual sequence length
            k = torch.sum(attention_masks[b]).item() - 1

            if method == 'aaronson':
                cdf = gammaincc(k, S)
            elif method == 'aaronson_simplified':
                cdf = gammainc(k, S)
            elif method == 'aaronson_neyman_pearson':
                # Chernoff bound
                ratio = ps[b, :k] / (1 - ps[b, :k])
                E = (1/ratio).sum()

                if S > E:
                    cdf = 1.0
                else:
                    # to compute the p-value we must solve for c*:
                    # (1/(c* + ps/(1-ps))).sum() = S
                    func = lambda c: (((1 / (c + ratio)).sum() - S)**2).item()
                    c1 = (k / S - torch.min(ratio)).item()
                    print("max = ", c1)
                    c = fminbound(func, 0, c1)
                    print("solved c = ", c)
                    print("solved s = ", ((1/(c + ratio)).sum()).item())
                    # upper bound
                    cdf = torch.exp(torch.sum(-torch.log(1 + c / ratio)) + c * S)
            elif method == 'kirchenbauer':
                cdf = betainc(S, k - S + 1, gamma)

            # correct for having tested all MAX candidate messages
            if cdf > min(1 / MAX, 1e-5):
                cdf = 1 - (1 - cdf)**MAX  # true value
            else:
                cdf = cdf * MAX  # numerically stable upper bound
            cdfs.append(float(cdf))
            ms.append(m)

        return cdfs, ms
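
For reference, a minimal round trip through the Watermarker class defined above (the model name and prompt are only examples; this downloads facebook/opt-350m on first run):

    from watermark import Watermarker

    # 2 payload bits -> messages in {0, 1, 2, 3} (the user ids in app.py)
    wm = Watermarker(modelname="facebook/opt-350m", window_size=0, payload_bits=2)

    texts = wm.embed(key=42, messages=[3], prompt="Once upon a time",
                     max_length=64, method='aaronson')

    # detection is non-blind: the original prompt must be supplied unchanged
    pvalues, messages = wm.detect(texts, key=42, method='aaronson',
                                  prompts=["Once upon a time"])
    print(messages[0], pvalues[0])  # expect message 3 with a small p-value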