import os, json
import torch
import gradio as gr
import spaces
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import login, hf_hub_download
import pyvene as pv
from threading import Thread
from typing import Iterator
import torch.nn.functional as F

HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 128  # smaller default to save memory
MAX_INPUT_TOKEN_LENGTH = 4096
| css = """ | |
| #alert-message textarea { | |
| background-color: #e8f4ff; | |
| border: 1px solid #cce5ff; | |
| color: #084298; | |
| font-size: 1.1em; | |
| padding: 12px; | |
| border-radius: 4px; | |
| font-weight: 500; | |
| } | |
| .concept-help { | |
| font-size: 0.9em; | |
| color: #666; | |
| margin-top: 4px; | |
| font-style: italic; | |
| } | |
| """ | |
def load_jsonl(jsonl_path):
    """Load a JSONL file into a list of dicts, one object per line."""
    jsonl_data = []
    with open(jsonl_path, "r") as f:
        for line in f:
            jsonl_data.append(json.loads(line))
    return jsonl_data
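# The dictionary metadata loaded later is one JSON object per line; this app
# only relies on its "concept" field, e.g. (illustrative): {"concept": "ethics"}.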
class Steer(pv.SourcelessIntervention):
    """Steer the model by adding a vector to the residual-stream activations."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        # Rows of this projection form the dictionary of concept directions.
        self.proj = torch.nn.Linear(
            self.embed_dim, kwargs["latent_dim"], bias=False)
        self.subspace_generator = kwargs["subspace_generator"]

    def steer(self, base, source=None, subspaces=None):
        if subspaces["steer"]["subspace_gen_inputs"] is not None:
            # Generate the steering direction on the fly from free-form text.
            raw_steering_vec = self.subspace_generator(
                subspaces["steer"]["subspace_gen_inputs"]["input_ids"],
                subspaces["steer"]["subspace_gen_inputs"]["attention_mask"],
            )[0]
        else:
            # Look up a precomputed direction in the dictionary.
            raw_steering_vec = self.proj.weight[subspaces["steer"]["idx"]]
        steering_vec = torch.tensor(subspaces["steer"]["mag"]) * \
            raw_steering_vec.unsqueeze(dim=0)
        return base + steering_vec

    def forward(self, base, source=None, subspaces=None):
        if subspaces is None:
            return base
        if subspaces["detect"] is not None:
            if subspaces["detect"]["subspace_gen_inputs"] is not None:
                # Generate the detection direction on the fly.
                raw_detection_vec = self.subspace_generator(
                    subspaces["detect"]["subspace_gen_inputs"]["input_ids"],
                    subspaces["detect"]["subspace_gen_inputs"]["attention_mask"],
                )[0].unsqueeze(dim=-1)
            else:
                raw_detection_vec = self.proj.weight[subspaces["detect"]["idx"]].unsqueeze(dim=-1)
            # Project activations onto the detection direction:
            # (batch, seq, dim) @ (dim, 1) -> (batch, seq)
            detection_latent = torch.matmul(base, raw_detection_vec.to(base.dtype)).squeeze(dim=-1)
            # Strongest activation across the sequence for the first batch item.
            max_latent = torch.max(detection_latent, dim=-1).values[0]
            if max_latent > torch.tensor(subspaces["detect"]["mag"]):
                # The concept fired above the threshold, so steer.
                return self.steer(base, source, subspaces)
            return base
        # No detector configured: always steer.
        return self.steer(base, source, subspaces)
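# For reference, `generate` below drives this intervention with a payload of
# the following shape (values here are illustrative, not defaults):
#
#   subspaces = [{
#       "detect": {                      # optional; None means "always steer"
#           "idx": 42,                   # dictionary row to project onto
#           "mag": 25.0,                 # firing threshold for the max activation
#           "subspace_gen_inputs": None, # or tokenized text for on-the-fly directions
#       },
#       "steer": {
#           "idx": 7,                    # dictionary row to add
#           "mag": 175.0,                # scale of the added vector
#           "subspace_gen_inputs": None,
#       },
#   }]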
class RegressionWrapper(torch.nn.Module):
    """Map the last-token hidden state of a base model to a unit-length vector."""
    def __init__(self, base_model, hidden_size, output_dim):
        super().__init__()
        self.base_model = base_model
        self.regression_head = torch.nn.Linear(hidden_size, output_dim)

    def forward(self, input_ids, attention_mask):
        # Use the inner transformer (skipping the LM head) to get hidden states.
        outputs = self.base_model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        last_hiddens = outputs.hidden_states[-1]
        last_token_representations = last_hiddens[:, -1]
        preds = self.regression_head(last_token_representations)
        # L2-normalize so the output is a direction, not a magnitude.
        preds = F.normalize(preds, p=2, dim=-1)
        return preds
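# A minimal usage sketch (assuming the globals defined in the CUDA branch
# below): tokenize a free-form topic and read off a unit-length direction.
#
#   enc = base_tokenizer("ethics", return_tensors="pt").to("cuda")
#   direction = subspace_generator(enc["input_ids"], enc["attention_mask"])[0]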
# Check GPU availability.
if not torch.cuda.is_available():
    print("Warning: running on CPU; generation will be slow.")

# Load the model and the concept dictionary.
model_id = "google/gemma-2-2b-it"
pv_model = None
tokenizer = None
concept_list = []
concept_id_map = {}

if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Download the steering-vector dictionary trained on layer 20 of the residual stream.
    weight_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
    meta_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl")
    params = torch.load(weight_path).cuda()

    md = load_jsonl(meta_path)
    concept_list = [item["concept"] for item in md]
    # Reindex the concepts: one concept is missing from the dictionary, so
    # positions in `md` do not line up with the original concept IDs.
    concept_id_map = {item["concept"]: reindex for reindex, item in enumerate(md)}

    # Load the subspace generator, which turns free-form text into a steering direction.
    base_tokenizer = AutoTokenizer.from_pretrained(
        "google/gemma-2-2b", model_max_length=512)
    config = AutoConfig.from_pretrained("google/gemma-2-2b")
    base_model = AutoModelForCausalLM.from_config(config)
    subspace_generator_weight_path = hf_hub_download(
        repo_id="pyvene/gemma-reft-2b-it-res-generator", filename="l20/weight.pt")
    hidden_size = base_model.config.hidden_size
    subspace_generator = RegressionWrapper(
        base_model, hidden_size, hidden_size).bfloat16().to("cuda")
    subspace_generator.load_state_dict(torch.load(subspace_generator_weight_path))
    print(f"Loaded subspace generator weights from {subspace_generator_weight_path}")
    _ = subspace_generator.eval()

    # Attach the Steer intervention to the output of layer 20.
    steer = Steer(
        embed_dim=params.shape[0], latent_dim=params.shape[1],
        subspace_generator=subspace_generator)
    steer.proj.weight.data = params.float()
    pv_model = pv.IntervenableModel({
        "component": "model.layers[20].output",
        "intervention": steer}, model=model)

terminators = [tokenizer.eos_token_id] if tokenizer else []
def generate(
    message: str,
    chat_history: list[dict],
    detection_list: list[dict],
    steering_list: list[dict],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> Iterator[str]:
    # Keep only the last 4 turns of history.
    start_idx = max(0, len(chat_history) - 4)
    recent_history = chat_history[start_idx:]

    # Build the message list in chat format.
    messages = []
    for rh in recent_history:
        messages.append({"role": rh["role"], "content": rh["content"]})
    messages.append({"role": "user", "content": message})

    input_ids = torch.tensor([tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True)]).cuda()

    # Trim the prompt if it exceeds the input budget.
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        yield "[Truncated prior text]\n"

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = {
        "base": {"input_ids": input_ids},
        "unit_locations": None,
        "max_new_tokens": max_new_tokens,
        "intervene_on_prompt": True,
        "subspaces": [
            {
                "detect": {
                    "idx": int(detection_list[0]["idx"]),
                    # Slider values are rescaled to the intervention's internal units.
                    "mag": detection_list[0]["internal_mag"] * 50,
                    "subspace_gen_inputs": base_tokenizer(
                        detection_list[0]["subspace_gen_text"], return_tensors="pt").to("cuda")
                        if detection_list[0]["subspace_gen_text"] is not None else None,
                } if detection_list else None,
                "steer": {
                    "idx": int(steering_list[0]["idx"]),
                    "mag": steering_list[0]["internal_mag"] * 50,
                    "subspace_gen_inputs": base_tokenizer(
                        steering_list[0]["subspace_gen_text"], return_tensors="pt").to("cuda")
                        if steering_list[0]["subspace_gen_text"] is not None else None,
                },
            }
        ] if steering_list else None,  # if steering is not provided, we do not steer
        "streamer": streamer,
        "do_sample": True,
    }

    # Run generation in a background thread and stream tokens as they arrive.
    t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
    t.start()
    partial_text = []
    for token_str in streamer:
        partial_text.append(token_str)
        yield "".join(partial_text)
def filter_concepts(search_text: str):
    if not search_text.strip():
        return concept_list[:500]
    filtered = [c for c in concept_list if search_text.lower() in c.lower()]
    return filtered[:500]

def add_concept_to_list(selected_concept, user_slider_val, current_list):
    if not selected_concept:
        return current_list
    selected_concept_text = None
    if selected_concept.startswith("[New] "):
        # A custom concept: strip the "[New] " prefix and keep the raw text so
        # the subspace generator can build a direction for it on the fly.
        selected_concept_text = selected_concept[len("[New] "):]
        idx = 0
    else:
        idx = concept_id_map[selected_concept]
    internal_mag = user_slider_val
    new_entry = {
        "text": selected_concept,
        "idx": idx,
        "display_mag": user_slider_val,
        "internal_mag": internal_mag,
        "subspace_gen_text": selected_concept_text,
    }
    # Only one concept is active at a time, so replace the list.
    current_list = [new_entry]
    return current_list
def update_dropdown_choices(search_text, is_detection=False):
    filtered = filter_concepts(search_text)
    if not filtered:
        alert_message = (
            "Good news! Based on the topic you provided, we will automatically generate a detector for you!"
        ) if is_detection else (
            "Good news! Based on the topic you provided, we will automatically generate a steering vector. Try it out by starting a chat!"
        )
        return gr.update(
            choices=[],
            value=None,
            interactive=True,
        ), gr.Textbox(
            label="No matching topics found",
            value=alert_message,
            lines=3,
            interactive=False,
            visible=True,
            elem_id="alert-message",
        )
    return gr.update(
        choices=filtered,
        value=filtered[0],
        interactive=True,
        visible=True,
    ), gr.Textbox(visible=False)
with gr.Blocks(css=css, fill_height=True) as demo:
    selected_detection = gr.State([])
    selected_subspaces = gr.State([])

    with gr.Row(min_height=500, equal_height=True):
        # Left side: chat area.
        with gr.Column(scale=7):
            gr.Markdown("""# Conditionally Steer AI Responses Based on Topics""")
            gr.Markdown(
                """This is an experimental chatbot that you can steer using topics you care about:

- Step 1: Choose a topic (e.g., "Google") to detect.
- Step 2: Choose a topic (e.g., "ethics") you want the model to discuss whenever the first topic comes up.

We intervene on Gemma-2-2B-it by adding steering vectors to the residual stream at layer 20.""")
            chat_interface = gr.ChatInterface(
                fn=generate,
                chatbot=gr.Chatbot(),
                textbox=gr.Textbox(
                    placeholder="List some search engines with their pros and cons",
                    container=True, scale=7, submit_btn=True),
                additional_inputs=[selected_detection, selected_subspaces],
            )
        # Right side: concept detection and steering.
        with gr.Column(scale=3):
            gr.Markdown("""#### Step 1: Choose a topic the model needs to recognize.""")
            with gr.Group():
                detect_search = gr.Textbox(
                    label="Search for topics to detect",
                    placeholder="Try: 'Google'",
                    lines=1,
                )
                detect_msg = gr.TextArea(visible=False)
                detect_dropdown = gr.Dropdown(
                    label="Choose a topic to detect (Click to see more!)",
                    interactive=True,
                    allow_custom_value=False,
                )
                detect_threshold = gr.Slider(
                    label="Detection sensitivity",
                    minimum=0,
                    maximum=1,
                    step=0.1,
                    value=0.5,
                )
            gr.Markdown("---")
            gr.Markdown("""#### Step 2: Choose another topic the model should discuss when it detects the topic above.""")
            with gr.Group():
                search_box = gr.Textbox(
                    label="Search topics to steer",
                    placeholder="Try: 'ethics'",
                    lines=1,
                )
                msg = gr.TextArea(visible=False)
                concept_dropdown = gr.Dropdown(
                    label="Choose a topic to steer the model (Click to see more!)",
                    interactive=True,
                    allow_custom_value=False,
                )
                concept_magnitude = gr.Slider(
                    label="Steering intensity",
                    minimum=-5,
                    maximum=5,
                    step=0.1,
                    value=3.5,
                )
    # Wire up events for detection.
    detect_search.input(
        lambda x: update_dropdown_choices(x, is_detection=True),
        [detect_search],
        [detect_dropdown, detect_msg]
    ).then(
        add_concept_to_list,
        [detect_dropdown, detect_threshold, selected_detection],
        [selected_detection]
    )
    detect_dropdown.select(
        add_concept_to_list,
        [detect_dropdown, detect_threshold, selected_detection],
        [selected_detection]
    )
    detect_threshold.input(
        add_concept_to_list,
        [detect_dropdown, detect_threshold, selected_detection],
        [selected_detection]
    )

    # Wire up events for steering.
    search_box.input(
        lambda x: update_dropdown_choices(x, is_detection=False),
        [search_box],
        [concept_dropdown, msg]
    ).then(
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces]
    )
    concept_dropdown.select(
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces]
    )
    concept_magnitude.input(
        add_concept_to_list,
        [concept_dropdown, concept_magnitude, selected_subspaces],
        [selected_subspaces]
    )

demo.launch(share=True, height=1000)