Spaces:
Running
on
L40S
Running
on
L40S
| from collections.abc import Sequence | |
| import random | |
| from typing import Optional | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| import transformers | |
| # If the watermark is not detected, consider the use case. It could be because of | |
| # the nature of the task (e.g., factual responses are lower entropy) or it could | |
| # be due to another factor. | |
# Hugging Face Hub identifier used to load both the tokenizer and the model.
_MODEL_IDENTIFIER = 'hf-internal-testing/tiny-random-gpt2'

# Default prompts; the UI creates one editable textbox per entry.
# Annotated as a variable-length tuple: `tuple[str]` would declare a tuple of
# exactly one element, which contradicts the three values stored here.
_PROMPTS: tuple[str, ...] = (
    'prompt 1',
    'prompt 2',
    'prompt 3',
)

# Maps a generated response text to whether it was produced with watermarking.
_CORRECT_ANSWERS: dict[str, bool] = {}
# Use the first CUDA device when CUDA is available; otherwise run on the CPU.
_TORCH_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# SynthID Text watermarking configuration used for the watermarked generations.
# The 30 integer keys are listed in their original order, ten per row.
# NOTE(review): these keys are hard-coded in the source; presumably acceptable
# for a demo, but confirm before reusing them anywhere the watermark matters.
_WATERMARK_CONFIG = transformers.generation.SynthIDTextWatermarkingConfig(
    ngram_len=5,
    keys=[
        654, 400, 836, 123, 340, 443, 597, 160, 57, 29,
        590, 639, 13, 715, 468, 990, 966, 226, 324, 585,
        118, 504, 421, 521, 129, 669, 732, 225, 90, 960,
    ],
    sampling_table_size=2**16,
    sampling_table_seed=0,
    context_history_size=1024,
)
# Load the tokenizer and model once, at import time.
tokenizer = transformers.AutoTokenizer.from_pretrained(_MODEL_IDENTIFIER)
# Reuse the EOS token id as the pad token id.
# NOTE(review): presumably because this tokenizer defines no pad token of its
# own -- confirm against the model card.
tokenizer.pad_token_id = tokenizer.eos_token_id

# `.to()` returns the module itself, so loading and device placement chain.
model = transformers.AutoModelForCausalLM.from_pretrained(
    _MODEL_IDENTIFIER
).to(_TORCH_DEVICE)
def generate_outputs(
    prompts: Sequence[str],
    watermarking_config: Optional[
        transformers.generation.SynthIDTextWatermarkingConfig
    ] = None,
) -> Sequence[str]:
    """Sample one output per prompt, optionally with a watermark applied.

    Args:
      prompts: The prompt texts to complete, processed as a single batch.
      watermarking_config: Watermarking configuration forwarded to
        `model.generate`. `None` (the default) generates without one.

    Returns:
      The decoded output sequences returned by `model.generate`, one per
      prompt and in prompt order, with special tokens removed. An empty list
      is returned when `prompts` is empty.
    """
    # Nothing to tokenize; avoid handing an empty batch to the tokenizer.
    if not prompts:
        return []

    # Prompts are user-editable and may tokenize to different lengths, so the
    # batch must be padded to form a rectangular tensor. Decoder-only models
    # continue from the last position, hence padding goes on the left.
    tokenized_prompts = tokenizer(
        list(prompts),
        return_tensors='pt',
        padding=True,
        padding_side='left',
    ).to(_TORCH_DEVICE)

    output_sequences = model.generate(
        **tokenized_prompts,
        watermarking_config=watermarking_config,
        do_sample=True,
        max_length=500,
        top_k=40,
    )

    # The pad token is the EOS token, so skip special tokens to keep padding
    # out of the text shown to the user.
    return tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
# UI definition: prompts -> generate -> pick the watermarked ones -> reveal.
# NOTE(review): indentation was lost in the source this was recovered from;
# each button is assumed to live inside the column declared just above it.
with gr.Blocks() as demo:
    # One editable textbox per default prompt.
    prompt_inputs = [
        gr.Textbox(value=prompt, lines=4, label='Prompt')
        for prompt in _PROMPTS
    ]
    generate_btn = gr.Button('Generate')

    # Per-session ground truth (response text -> is watermarked). Kept in
    # session state instead of the module-level `_CORRECT_ANSWERS` dict so
    # that answers from earlier rounds or other visitors never leak into the
    # current session's reveal step.
    answers_state = gr.State({})

    with gr.Column(visible=False) as generations_col:
        generations_grp = gr.CheckboxGroup(
            label='All generations, in random order',
            info='Select the generations you think are watermarked!',
        )
        reveal_btn = gr.Button('Reveal', visible=False)

    with gr.Column(visible=False) as detections_col:
        revealed_grp = gr.CheckboxGroup(
            label='Ground truth for all generations',
            info=(
                'Watermarked generations are checked, and your selection are '
                'marked as correct or incorrect in the text.'
            ),
        )
        # TODO: no click handler is attached to this button in the visible
        # code; it is only made visible by `reveal`.
        detect_btn = gr.Button('Detect', visible=False)

    def generate(*prompts: str) -> dict:
        """Generate plain and watermarked responses and show them shuffled.

        Args:
          *prompts: Current text of each prompt textbox.

        Returns:
          Component updates that hide the generate button, show the shuffled
          responses as checkbox choices, show the reveal button, and store
          the ground truth for this session.
        """
        standard = generate_outputs(prompts=prompts)
        watermarked = generate_outputs(
            prompts=prompts,
            watermarking_config=_WATERMARK_CONFIG,
        )

        responses = list(standard) + list(watermarked)
        random.shuffle(responses)

        # A response counts as watermarked when its text equals any
        # watermarked output; dict order preserves the shuffled order.
        watermarked_texts = set(watermarked)
        answers = {
            response: response in watermarked_texts
            for response in responses
        }

        return {
            generate_btn: gr.Button(visible=False),
            generations_col: gr.Column(visible=True),
            # Offer de-duplicated choices so they match the ground truth keys.
            generations_grp: gr.CheckboxGroup(choices=list(answers)),
            reveal_btn: gr.Button(visible=True),
            answers_state: answers,
        }

    generate_btn.click(
        generate,
        inputs=prompt_inputs,
        outputs=[
            generate_btn,
            generations_col,
            generations_grp,
            reveal_btn,
            answers_state,
        ],
    )

    def reveal(
        user_selections: list[str],
        correct_answers: dict[str, bool],
    ) -> dict:
        """Grade the user's picks and display the ground truth.

        Args:
          user_selections: Responses the user ticked as watermarked.
          correct_answers: This session's mapping from response text to
            whether it was generated with watermarking.

        Returns:
          Component updates that hide the reveal button and show every
          response prefixed with 'Correct!' or 'Incorrect.', with the
          watermarked ones checked.
        """
        selected = set(user_selections or ())
        choices: list[str] = []
        value: list[str] = []

        for response, is_watermarked in (correct_answers or {}).items():
            # The guess is right exactly when "ticked" matches "watermarked".
            if is_watermarked == (response in selected):
                choice = f'Correct! {response}'
            else:
                choice = f'Incorrect. {response}'
            choices.append(choice)
            if is_watermarked:
                value.append(choice)

        return {
            reveal_btn: gr.Button(visible=False),
            detections_col: gr.Column(visible=True),
            revealed_grp: gr.CheckboxGroup(choices=choices, value=value),
            detect_btn: gr.Button(visible=True),
        }

    reveal_btn.click(
        reveal,
        inputs=[generations_grp, answers_state],
        outputs=[
            reveal_btn,
            detections_col,
            revealed_grp,
            detect_btn,
        ],
    )


if __name__ == '__main__':
    demo.launch()