Spaces:
Runtime error
Runtime error
refactor generation utils
Browse files- app.py +9 -57
- utils/__init__.py +5 -3
- utils/generation.py +58 -0
app.py
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
|
| 3 |
import datasets
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
| 6 |
-
from threading import Thread
|
| 7 |
|
| 8 |
from utils.tree_utils import parse_functions, get_docstrings, grab_before_comments, line_chr2char, node_str_idx, replace_function
|
| 9 |
from utils.html_utils import make_iframe, construct_embed
|
|
|
|
| 10 |
PIPE = None
|
| 11 |
|
| 12 |
intro_text = """
|
|
@@ -99,35 +99,6 @@ def _make_pipeline(model_cp = "Vipitis/santacoder-finetuned-Shadertoys-fine"): #
|
|
| 99 |
print(f"loaded model {model_cp} as a pipline")
|
| 100 |
return pipe
|
| 101 |
|
| 102 |
-
def _run_generation(model_ctx:str, pipe, gen_kwargs:dict):
    """
    Stream generated text for a context, yielding the accumulated output after every new chunk.

    Args:
        model_ctx (str): The context to start generation from.
        pipe (Pipeline): The pipeline to use for generation; its ``tokenizer`` and ``model`` attributes are used.
        gen_kwargs (dict): The generation kwargs, forwarded to ``pipe.model.generate``.
    Yields:
        str: All text generated so far (the streamer is configured to skip the prompt).
    Returns:
        str: The complete generated text. Because this is a generator, the value is only
            reachable through ``StopIteration.value`` or ``yield from``.
    """
    # Tokenize the model_context and put the tensors on the model's device,
    # otherwise generate() fails with a device mismatch when the model is not on the CPU.
    model_inputs = pipe.tokenizer(model_ctx, return_tensors="pt").to(pipe.model.device)

    # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
    # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
    streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15.0)
    generate_kwargs = dict(model_inputs, streamer=streamer, **gen_kwargs)
    t = Thread(target=pipe.model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer, and update the model output.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output

    # The loop above only finishes once the streamer was told the stream is over,
    # so the worker is expected to be done (or nearly done); reap it instead of leaving it dangling.
    t.join()
    return model_output
|
| 130 |
-
|
| 131 |
def process_retn(retn):
    """
    Keep only the part of ``retn`` that precedes its first semicolon.

    Args:
        retn (str): Text that may contain one or more ";" characters.
    Returns:
        str: The text before the first ";" (or all of it when there is none),
            with surrounding whitespace stripped.
    """
    head, _sep, _tail = retn.partition(";")
    return head.strip()
|
| 133 |
|
|
@@ -167,7 +138,7 @@ def alter_return(orig_code, func_idx, temperature, max_new_tokens, top_p, repeti
|
|
| 167 |
else:
|
| 168 |
raise gr.Error(f"func_idx must be int or str, not {type(func_idx)}")
|
| 169 |
|
| 170 |
-
generation_kwargs =
|
| 171 |
|
| 172 |
retrns = []
|
| 173 |
retrn_start_idx = orig_code.find("return")
|
|
@@ -189,14 +160,6 @@ def alter_return(orig_code, func_idx, temperature, max_new_tokens, top_p, repeti
|
|
| 189 |
return altered_code
|
| 190 |
|
| 191 |
|
| 192 |
-
def _combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty):
    """
    Bundle the four generation settings into a single dict of keyword arguments.

    Args:
        temperature: Stored under the "temperature" key.
        max_new_tokens: Stored under the "max_new_tokens" key.
        top_p: Stored under the "top_p" key.
        repetition_penalty: Stored under the "repetition_penalty" key.
    Returns:
        dict: A new dict with exactly these four entries, in the order listed above.
    """
    return {
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
    }
|
| 199 |
-
|
| 200 |
def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2, max_new_tokens=512, top_p=.95, repetition_penalty=1.2, pipeline=PIPE):
|
| 201 |
"""
|
| 202 |
Replaces the body of a function with a generated one.
|
|
@@ -223,27 +186,16 @@ def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2,
|
|
| 223 |
func_node = funcs_list[func_id]
|
| 224 |
print(f"using for generation: {func_node=}")
|
| 225 |
|
| 226 |
-
generation_kwargs =
|
|
|
|
|
|
|
| 227 |
|
| 228 |
-
func_start_idx = line_chr2char(old_code, func_node.start_point[0], func_node.start_point[1])
|
| 229 |
-
identifier_str = func_node.child_by_field_name("type").text.decode() + " " + func_node.child_by_field_name("declarator").text.decode() #func_start_idx:body_start_idx?
|
| 230 |
body_node = func_node.child_by_field_name("body")
|
| 231 |
body_start_idx, body_end_idx = node_str_idx(body_node)
|
| 232 |
-
model_context = identifier_str # base case
|
| 233 |
-
|
| 234 |
-
docstring = get_docstrings(func_node) #might be empty?
|
| 235 |
-
if docstring:
|
| 236 |
-
model_context = model_context + "\n" + docstring
|
| 237 |
-
model_context = grab_before_comments(func_node) + model_context #prepend comments
|
| 238 |
-
if prompt != "":
|
| 239 |
-
model_context = f"//avialable functions: {','.join([n.child_by_field_name('declarator').text.decode() for n in funcs_list])}\n" + model_context #prepend available functions
|
| 240 |
-
model_context = "//Title: " + prompt + "\n" + model_context #prepend user prompt/title
|
| 241 |
-
model_context = "//Language: Shadertoy GLSL fragment shader\n" + model_context #prepend system prompt, language hint
|
| 242 |
-
print(f"{model_context=}")
|
| 243 |
# generation = pipeline(model_context, return_full_text=False, **generation_kwargs)[0]["generated_text"]
|
| 244 |
-
generation =
|
| 245 |
for i in generation:
|
| 246 |
-
print(f"{i=}")
|
| 247 |
yield model_context + i #fix in between, do all the stuff in the end?
|
| 248 |
generation = i[:] #seems to work
|
| 249 |
print(f"{generation=}")
|
|
@@ -253,7 +205,7 @@ def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2,
|
|
| 253 |
first_gened_func = parse_functions(ctx_with_generation)[0] # truncate generation to a single function?
|
| 254 |
except IndexError:
|
| 255 |
print("generation wasn't a full function.")
|
| 256 |
-
altered_code = old_code[:
|
| 257 |
return altered_code
|
| 258 |
altered_code = replace_function(func_node, first_gened_func)
|
| 259 |
yield altered_code #yield once so it updates? -> works... gg but doesn't seem to do it for the dropdown
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
|
| 3 |
import datasets
|
| 4 |
import numpy as np
|
| 5 |
import torch
|
|
|
|
| 6 |
|
| 7 |
from utils.tree_utils import parse_functions, get_docstrings, grab_before_comments, line_chr2char, node_str_idx, replace_function
|
| 8 |
from utils.html_utils import make_iframe, construct_embed
|
| 9 |
+
from utils.generation import combine_generation_kwargs, stream_generation, construct_model_context
|
| 10 |
PIPE = None
|
| 11 |
|
| 12 |
intro_text = """
|
|
|
|
| 99 |
print(f"loaded model {model_cp} as a pipline")
|
| 100 |
return pipe
|
| 101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
def process_retn(retn):
    """
    Keep only the part of ``retn`` that precedes its first semicolon.

    Args:
        retn (str): Text that may contain one or more ";" characters.
    Returns:
        str: The text before the first ";" (or all of it when there is none),
            with surrounding whitespace stripped.
    """
    head, _sep, _tail = retn.partition(";")
    return head.strip()
|
| 104 |
|
|
|
|
| 138 |
else:
|
| 139 |
raise gr.Error(f"func_idx must be int or str, not {type(func_idx)}")
|
| 140 |
|
| 141 |
+
generation_kwargs = combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
|
| 142 |
|
| 143 |
retrns = []
|
| 144 |
retrn_start_idx = orig_code.find("return")
|
|
|
|
| 160 |
return altered_code
|
| 161 |
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
def alter_body(old_code, func_id, funcs_list: list, prompt="", temperature=0.2, max_new_tokens=512, top_p=.95, repetition_penalty=1.2, pipeline=PIPE):
|
| 164 |
"""
|
| 165 |
Replaces the body of a function with a generated one.
|
|
|
|
| 186 |
func_node = funcs_list[func_id]
|
| 187 |
print(f"using for generation: {func_node=}")
|
| 188 |
|
| 189 |
+
generation_kwargs = combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty)
|
| 190 |
+
model_context = construct_model_context(func_node, prompt=prompt)
|
| 191 |
+
print(f"{model_context=}")
|
| 192 |
|
|
|
|
|
|
|
| 193 |
body_node = func_node.child_by_field_name("body")
|
| 194 |
body_start_idx, body_end_idx = node_str_idx(body_node)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
# generation = pipeline(model_context, return_full_text=False, **generation_kwargs)[0]["generated_text"]
|
| 196 |
+
generation = stream_generation(model_context, pipeline, generation_kwargs)
|
| 197 |
for i in generation:
|
| 198 |
+
# print(f"{i=}")
|
| 199 |
yield model_context + i #fix in between, do all the stuff in the end?
|
| 200 |
generation = i[:] #seems to work
|
| 201 |
print(f"{generation=}")
|
|
|
|
| 205 |
first_gened_func = parse_functions(ctx_with_generation)[0] # truncate generation to a single function?
|
| 206 |
except IndexError:
|
| 207 |
print("generation wasn't a full function.")
|
| 208 |
+
altered_code = old_code[:body_start_idx] + generation + "//the generation didn't complete the function!\n" + old_code[body_end_idx:] #needs a newline to break out of the comment.
|
| 209 |
return altered_code
|
| 210 |
altered_code = replace_function(func_node, first_gened_func)
|
| 211 |
yield altered_code #yield once so it updates? -> works... gg but doesn't seem to do it for the dropdown
|
utils/__init__.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
| 1 |
-
from .tree_utils import (parse_functions, get_docstrings, grab_before_comments, line_chr2char)
|
| 2 |
from .html_utils import (make_iframe, make_script, construct_embed)
|
|
|
|
| 3 |
|
| 4 |
-
tree_funcs = ["parse_functions", "get_docstrings", "grab_before_comments", "line_chr2char"]
|
| 5 |
html_funcs = ["make_iframe", "make_script", "construct_embed"]
|
|
|
|
| 6 |
|
| 7 |
-
__all__ = tree_funcs + html_funcs
|
|
|
|
| 1 |
+
from .tree_utils import (parse_functions, get_docstrings, grab_before_comments, line_chr2char, replace_function, get_root, node_str_idx, give_tree)
|
| 2 |
from .html_utils import (make_iframe, make_script, construct_embed)
|
| 3 |
+
from .generation import (combine_generation_kwargs, stream_generation, construct_model_context)
|
| 4 |
|
| 5 |
+
tree_funcs = ["parse_functions", "get_docstrings", "grab_before_comments", "line_chr2char", "replace_function", "get_root", "node_str_idx", "give_tree"]
|
| 6 |
html_funcs = ["make_iframe", "make_script", "construct_embed"]
|
| 7 |
+
gen_funcs = ["combine_generation_kwargs", "stream_generation", "construct_model_context"]
|
| 8 |
|
| 9 |
+
__all__ = tree_funcs + html_funcs + gen_funcs
|
utils/generation.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import TextIteratorStreamer
|
| 2 |
+
from threading import Thread
|
| 3 |
+
from utils.tree_utils import get_docstrings, grab_before_comments
|
| 4 |
+
|
| 5 |
+
def combine_generation_kwargs(temperature, max_new_tokens, top_p, repetition_penalty):
    """
    Bundle the four generation settings into a single dict of keyword arguments.

    Args:
        temperature: Stored under the "temperature" key.
        max_new_tokens: Stored under the "max_new_tokens" key.
        top_p: Stored under the "top_p" key.
        repetition_penalty: Stored under the "repetition_penalty" key.
    Returns:
        dict: A new dict with exactly these four entries, in the order listed above.
    """
    return {
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
    }
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def stream_generation(prompt:str, pipe, gen_kwargs:dict):
    """
    Stream generated text for a prompt, yielding the accumulated output after every new chunk.

    Args:
        prompt (str): The context to start generation from.
        pipe (Pipeline): The pipeline to use for generation; its ``tokenizer`` and ``model`` attributes are used.
        gen_kwargs (dict): The generation kwargs, forwarded to ``pipe.model.generate``.
    Yields:
        str: All text generated so far (the streamer is configured to skip the prompt).
    Returns:
        str: The complete generated text. Because this is a generator, the value is only
            reachable through ``StopIteration.value`` or ``yield from``.
    """
    # Tokenize the prompt and put the tensors on the model's device,
    # otherwise generate() fails with a device mismatch when the model is not on the CPU.
    model_inputs = pipe.tokenizer(prompt, return_tensors="pt").to(pipe.model.device)

    # Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
    # in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
    streamer = TextIteratorStreamer(pipe.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15.0)
    generate_kwargs = dict(model_inputs, streamer=streamer, **gen_kwargs)
    t = Thread(target=pipe.model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer, and update the model output.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output

    # The loop above only finishes once the streamer was told the stream is over,
    # so the worker is expected to be done (or nearly done); reap it instead of leaving it dangling.
    t.join()
    return model_output
|
| 45 |
+
|
| 46 |
+
def construct_model_context(func_node, prompt="") -> str:
    """
    Constructs the model context from a function node.

    The resulting text is assembled top to bottom as:
    language hint line, optional "//Title:" line, whatever ``grab_before_comments``
    returns for the node, the function signature, and (if any) the docstring text
    from ``get_docstrings``.

    Args:
        func_node: Parsed function node exposing ``child_by_field_name``; its "type" and
            "declarator" children are read. (presumably a tree-sitter node -- confirm against caller)
        prompt (str): Optional user title. Empty string or ``None`` adds no title line.
    Returns:
        str: The text to feed to the model as generation context.
    """
    return_type = func_node.child_by_field_name("type").text.decode()
    declarator = func_node.child_by_field_name("declarator").text.decode()
    model_context = f"{return_type} {declarator}"  # base case: just the signature

    docstring = get_docstrings(func_node)  # may be empty
    if docstring:
        model_context = model_context + "\n" + docstring
    model_context = grab_before_comments(func_node) + model_context  # prepend comments

    # Truthiness check (instead of != "") so a None prompt is skipped rather than
    # crashing on str + None.
    if prompt:
        model_context = "//Title: " + prompt + "\n" + model_context  # prepend user prompt/title
    model_context = "//Language: Shadertoy GLSL fragment shader\n" + model_context  # prepend system prompt, language hint
    return model_context
|