Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import difflib | |
| import re | |
| from utils import verify_diff, apply_diff_from_output | |
# Maps each brush name offered in the UI to the commit message placed after
# the <MSG> tag of the diff-model prompt. Messages carry no trailing period:
# make_prompt appends the sentence terminator itself.
commit_message_per_brush = {
    "Annotate Type": "Annotate types of the variables",
    "Reformat": "Reformat the code using pep8",
    "Add Docstrings": "Add docstrings to all the functions",
    "Add Comments": "Add inline comments to all the functions",
}
@st.cache_resource
def load_model_and_tokenizer(model_name: str = "CarperAI/diff-codegen-350M-v2"):
    """Load the tokenizer and causal-LM model for *model_name*.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching the model weights would be reloaded on each rerun.
    ``st.cache_resource`` keeps one (tokenizer, model) pair per distinct
    ``model_name`` for the lifetime of the server process.
    NOTE(review): ``st.cache_resource`` requires Streamlit >= 1.18 — confirm
    the Space's pinned version.

    Args:
        model_name: Hugging Face Hub repository id (or local path) passed to
            ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return tokenizer, model
def make_prompt(code: str, task: str, filename: str = "main.py") -> str:
    """Build a diff-model prompt asking for *task* to be applied to *code*.

    The CarperAI diff models are prompted with::

        <NME> {file name}
        <BEF> {file contents before the change}
        <MSG> {commit message}
        <DFF>

    and continue with the diff after the ``<DFF>`` tag, so the tag is
    included to cue diff generation.

    Args:
        code: Source text of the file to be changed.
        task: Brush name; must be a key of ``commit_message_per_brush``.
        filename: File name shown to the model in the ``<NME>`` field.

    Returns:
        The prompt string.

    Raises:
        KeyError: If *task* is not a known brush.
    """
    # Strip any trailing period so the message never ends in "..".
    message = commit_message_per_brush[task].rstrip(".")
    return f"<NME> {filename}\n<BEF> {code}\n<MSG> {message}.\n<DFF>"
def generate_diff(
    code: str,
    max_new_tokens: int = 256,
    temperature: float = 0.8,
    top_p: float = 0.85,
) -> str:
    """Sample a continuation of the prompt *code* and return the decoded text.

    Uses the module-level ``tokenizer`` and ``model`` globals. The returned
    text is the full sequence (prompt followed by the generated tokens),
    decoded with special tokens skipped.

    Args:
        code: Full prompt text, e.g. as built by ``make_prompt``.
        max_new_tokens: Upper bound on generated tokens, independent of the
            prompt length.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded prompt plus generated continuation.
    """
    # Calling the tokenizer directly also yields an attention mask for generate().
    inputs = tokenizer(code, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        # max_new_tokens (not max_length) so the generation budget does not
        # shrink, or vanish entirely, as the prompt gets longer.
        max_new_tokens=max_new_tokens,
        # temperature/top_p only take effect when sampling is enabled.
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        # This tokenizer may define no pad token; reuse EOS to avoid the warning.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def postprocess_output(generated_output: str):
    """Pass the raw model output through ``utils.verify_diff`` and return its result."""
    checked = verify_diff(generated_output)
    return checked
# Streamlit re-executes this script top-to-bottom on every interaction; the
# form batches its inputs so the model only runs when "Submit" is pressed.
st.title("Code Brush")
st.write("A tool to brush up your code")

# Module-level globals read by generate_diff().
tokenizer, model = load_model_and_tokenizer()

with st.form("my_form"):
    text = st.text_area(
        "Enter your code here",
        height=150,
        value="def greet(input_name):\n return f'Hello, {input_name}'",
    )
    # Options come straight from the brush table so the UI and the prompt
    # messages cannot drift apart.
    brush_type = st.selectbox("Brush Type", list(commit_message_per_brush))
    submit_button = st.form_submit_button("Submit")

if submit_button:
    if not text.strip():
        st.warning("Please enter some code to brush.")
    else:
        st.write("## Diff:")
        with st.spinner("Generating diff..."):
            # Use a distinct name so the generate_diff function is not
            # shadowed by its own result.
            raw_output = generate_diff(make_prompt(text, brush_type))
        after_file = apply_diff_from_output(raw_output)
        diff_processed = postprocess_output(raw_output)
        st.write(after_file)
        st.write(diff_processed)