import os
import re

os.environ["TOKENIZERS_PARALLELISM"] = "true"

import frontmatter
import gradio as gr
import json
import spaces
import torch
from normalize import normalize
from transformers import AutoTokenizer
from modeling_nova import NovaTokenizer, NovaForCausalLM
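
# normalize and modeling_nova are local modules shipped alongside this app
# (Nova's assembly normalizer and its custom tokenizer/model classes).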


def fix_assembly_tabs(asm_text):
    """
    Fix assembly code formatting by ensuring proper tab placement.
    Expected format: address:TABhex_bytesWHITESPACEinstructionWHITESPACEoperands
    """
    lines = asm_text.split("\n")
    fixed_lines = []
    for line in lines:
        line = line.rstrip()  # Remove trailing whitespace
        if not line.strip():  # Skip empty lines
            fixed_lines.append(line)
            continue
        # Check if this looks like an assembly instruction line
        # Pattern: optional_spaces + hex_address + colon + hex_bytes + instruction + operands
        asm_pattern = r"^(\s*)([0-9a-f]+):\s*([0-9a-f\s]+?)\s+(\w+)(\s+.*)?$"
        match = re.match(asm_pattern, line, re.IGNORECASE)
        if match:
            indent, address, hex_bytes, instruction, operands = match.groups()
            operands = operands or ""
            # Clean up hex bytes (remove extra spaces)
            hex_bytes = re.sub(r"\s+", " ", hex_bytes.strip())
            # Reconstruct with proper tab formatting
            # Format: indent + address + ":" + TAB + hex_bytes + TAB + instruction + operands
            fixed_line = f"{indent}{address}:\t{hex_bytes}\t{instruction}{operands}"
            fixed_lines.append(fixed_line)
        else:
            # Not an assembly instruction line, keep as is
            fixed_lines.append(line)
    return "\n".join(fixed_lines)
| print("Downloading model") | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| "lt-asset/nova-6.7b-bcr", trust_remote_code=True | |
| ) | |
| tokenizer.pad_token = tokenizer.eos_token | |
| tokenizer.pad_token_id = tokenizer.eos_token_id | |
| nova_tokenizer = NovaTokenizer(tokenizer) | |
| model = NovaForCausalLM.from_pretrained( | |
| "lt-asset/nova-6.7b-bcr", torch_dtype=torch.bfloat16, device_map="auto" | |
| ).eval() | |
| examples = json.load(open("humaneval_decompile_nova_6.7b.json", "r")) | |


# On ZeroGPU Spaces, spaces.GPU requests a GPU for the duration of the call;
# it has no effect on other hardware.
@spaces.GPU
def predict(type, input_asm, _c_source):
| if "<func0>:" not in input_asm: | |
| # Needs normalizing | |
| # Add a bogus function header if needed. | |
| first_line = input_asm.split("\n")[0] | |
| if "<" not in first_line or ">" not in first_line: | |
| print("Adding synthetic function header") | |
| input_asm = "<func0>:\n" + input_asm | |
| # Fix tab formatting in assembly code | |
| input_asm = fix_assembly_tabs(input_asm) | |
| # Normalizing | |
| normalized_asm = normalize(input_asm) | |
| print(f"Normalized asm: {normalized_asm}") | |
| else: | |
| normalized_asm = input_asm | |
| prompt_before = f"# This is the assembly code with {type} optimization:\n<func0>:" | |
| asm = normalized_asm.strip() | |
| assert asm.startswith("<func0>:") | |
| asm = asm[len("<func0>:") :] | |
| prompt_after = "\nWhat is the source code?\n" | |
| inputs = prompt_before + asm + prompt_after | |
| print("Inputs:", inputs) | |
| # 0 for non-assembly code characters and 1 for assembly characters, required by nova tokenizer | |
| char_types = "0" * len(prompt_before) + "1" * len(asm) + "0" * len(prompt_after) | |
| tokenizer_output = nova_tokenizer.encode(inputs, "", char_types) | |
| input_ids = torch.LongTensor(tokenizer_output["input_ids"].tolist()).unsqueeze(0) | |
| print("Input IDs:", input_ids.shape) | |
| nova_attention_mask = torch.LongTensor( | |
| tokenizer_output["nova_attention_mask"] | |
| ).unsqueeze(0) | |
| output = model.generate( | |
| inputs=input_ids.cuda(), | |
| max_new_tokens=512, | |
| temperature=0.2, | |
| top_p=0.95, | |
| num_return_sequences=1, | |
| do_sample=True, | |
| nova_attention_mask=nova_attention_mask.cuda(), | |
| no_mask_idx=torch.LongTensor([tokenizer_output["no_mask_idx"]]).cuda(), | |
| pad_token_id=tokenizer.pad_token_id, | |
| eos_token_id=tokenizer.eos_token_id, | |
| ) | |
| print("Output 1:", output) | |
| output = tokenizer.decode( | |
| output[0][input_ids.size(1) :], | |
| skip_special_tokens=True, | |
| clean_up_tokenization_spaces=True, | |
| ) | |
| print("Output 2:", output) | |
| return output | |


example = """ 0: f3 0f 1e fa endbr64
 4: 55 push %rbp
 5: 48 89 e5 mov %rsp,%rbp
 8: 89 7d fc mov %edi,-0x4(%rbp)
 b: 8b 45 fc mov -0x4(%rbp),%eax
 e: 83 c0 2a add $0x2a,%eax
 11: 5d pop %rbp
 12: c3 ret
"""

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Text(label="Optimization Type", value="O0"),
        gr.Text(label="Assembly Code (Normalized or not)", value=example),
        gr.Text(label="Original C Code"),
    ],
    outputs=gr.Text(label="Raw Nova Output"),
    description=frontmatter.load("README.md").content,
    examples=[[ex["type"], ex["normalized_asm"], ex["c_func"]] for ex in examples],
)
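
# show_error=True surfaces exceptions from predict() as an alert in the UI.
# When run locally, the demo is served on Gradio's default port (7860) unless
# GRADIO_SERVER_PORT overrides it; on Spaces, the platform handles serving.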
demo.launch(show_error=True)