Spaces:
Runtime error
Runtime error
improved docstring extraction
Browse files- utils/__init__.py +2 -2
- utils/generation.py +2 -6
- utils/tree_utils.py +27 -5
utils/__init__.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
-
from .tree_utils import (parse_functions, get_docstrings, grab_before_comments, line_chr2char, replace_function, get_root, node_str_idx, give_tree)
|
| 2 |
from .html_utils import (make_iframe, make_script, construct_embed)
|
| 3 |
from .generation import (combine_generation_kwargs, stream_generation, construct_model_context)
|
| 4 |
|
| 5 |
-
tree_funcs = ["parse_functions", "get_docstrings", "grab_before_comments", "line_chr2char", "replace_function", "get_root", "node_str_idx", "give_tree"]
|
| 6 |
html_funcs = ["make_iframe", "make_script", "construct_embed"]
|
| 7 |
gen_funcs = ["combine_generation_kwargs", "stream_generation", "construct_model_context"]
|
| 8 |
|
|
|
|
| 1 |
+
from .tree_utils import (parse_functions, get_docstrings, grab_before_comments, line_chr2char, replace_function, get_root, node_str_idx, give_tree, full_func_head, has_docstrings)
|
| 2 |
from .html_utils import (make_iframe, make_script, construct_embed)
|
| 3 |
from .generation import (combine_generation_kwargs, stream_generation, construct_model_context)
|
| 4 |
|
| 5 |
+
tree_funcs = ["parse_functions", "get_docstrings", "grab_before_comments", "line_chr2char", "replace_function", "get_root", "node_str_idx", "give_tree", "full_func_head", "has_docstrings"]
|
| 6 |
html_funcs = ["make_iframe", "make_script", "construct_embed"]
|
| 7 |
gen_funcs = ["combine_generation_kwargs", "stream_generation", "construct_model_context"]
|
| 8 |
|
utils/generation.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from transformers import TextIteratorStreamer
|
| 2 |
from threading import Thread
|
| 3 |
-
from .tree_utils import
|
| 4 |
|
| 5 |
def combine_generation_kwargs(temperature=2.0, max_new_tokens=512, top_p=0.95, repetition_penalty=1.2):
|
| 6 |
"""
|
|
@@ -47,11 +47,7 @@ def construct_model_context(func_node, prompt="") -> str:
|
|
| 47 |
"""
|
| 48 |
Constructs the model context from a function node.
|
| 49 |
"""
|
| 50 |
-
model_context = func_node
|
| 51 |
-
docstring = get_docstrings(func_node) #might be empty?
|
| 52 |
-
if docstring:
|
| 53 |
-
model_context = model_context + "\n" + docstring
|
| 54 |
-
model_context = grab_before_comments(func_node) + model_context #prepend comments
|
| 55 |
if prompt != "":
|
| 56 |
model_context = "//Title: " + prompt + "\n" + model_context #prepend user prompt/title
|
| 57 |
model_context = "//Language: Shadertoy GLSL fragment shader\n" + model_context #prepend system prompt, language hint
|
|
|
|
| 1 |
from transformers import TextIteratorStreamer
|
| 2 |
from threading import Thread
|
| 3 |
+
from .tree_utils import full_func_head, grab_before_comments
|
| 4 |
|
| 5 |
def combine_generation_kwargs(temperature=2.0, max_new_tokens=512, top_p=0.95, repetition_penalty=1.2):
|
| 6 |
"""
|
|
|
|
| 47 |
"""
|
| 48 |
Constructs the model context from a function node.
|
| 49 |
"""
|
| 50 |
+
model_context = grab_before_comments(func_node) + full_func_head(func_node) # (identifier + docstrings)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
if prompt != "":
|
| 52 |
model_context = "//Title: " + prompt + "\n" + model_context #prepend user prompt/title
|
| 53 |
model_context = "//Language: Shadertoy GLSL fragment shader\n" + model_context #prepend system prompt, language hint
|
utils/tree_utils.py
CHANGED
|
@@ -56,13 +56,28 @@ def get_docstrings(func_node):
|
|
| 56 |
returns the docstring of a function node
|
| 57 |
"""
|
| 58 |
docstring = ""
|
| 59 |
-
for node in func_node.
|
| 60 |
-
if node.type == "comment"
|
| 61 |
-
docstring += node.text.decode()
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
return docstring
|
| 65 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
def grab_before_comments(func_node):
|
| 68 |
"""
|
|
@@ -80,6 +95,13 @@ def grab_before_comments(func_node):
|
|
| 80 |
return precomment
|
| 81 |
return precomment
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
def line_chr2char(text, line_idx, chr_idx):
|
| 84 |
"""
|
| 85 |
returns the character index at the given line and character index.
|
|
|
|
| 56 |
returns the docstring of a function node
|
| 57 |
"""
|
| 58 |
docstring = ""
|
| 59 |
+
for node in func_node.children:
|
| 60 |
+
if node.type == "comment": #comment in like the declarator
|
| 61 |
+
docstring += node.text.decode()
|
| 62 |
+
elif node.type == "compound_statement": #body below here
|
| 63 |
+
for body_node in node.children:
|
| 64 |
+
if body_node.type == "comment" or body_node.type == "{":
|
| 65 |
+
docstring += " " * body_node.start_point[1] #add in indentation
|
| 66 |
+
docstring += body_node.text.decode() + "\n"
|
| 67 |
+
else:
|
| 68 |
+
return docstring
|
| 69 |
return docstring
|
| 70 |
|
| 71 |
+
def full_func_head(func_node):
    """
    returns function head including docstrings before any real body code

    The head is the text of ``func_node`` from its first line up to, but not
    including, the line on which the first body child that is neither a
    ``comment`` nor the opening ``{`` starts.

    func_node: a syntax-tree node exposing ``child_by_field_name``, ``text``
        (bytes) and ``start_point`` (row, column); presumably a tree-sitter
        ``function_definition`` node -- confirm against the callers.

    Returns the head as a str. The result is empty when the first real body
    child starts on the same row as the function itself. If the body holds
    nothing but ``{`` and comments, the whole function text is returned.
    """
    source_lines = func_node.text.decode().split("\n")

    cursor = func_node.child_by_field_name("body").walk()
    cursor.goto_first_child()
    while cursor.node.type in ("comment", "{"):
        # goto_next_sibling() reports False when there is nowhere left to go;
        # ignoring that would leave the cursor parked on the same node and
        # spin here forever.
        if not cursor.goto_next_sibling():
            # Nothing but "{" and comments in the body: all of it is head.
            return "\n".join(source_lines)

    # Row of the first real body node, made relative to the function's own
    # first row so it can be used to slice the function's lines.
    head_line_count = cursor.node.start_point[0] - func_node.start_point[0]
    return "\n".join(source_lines[:head_line_count])
|
| 81 |
|
| 82 |
def grab_before_comments(func_node):
|
| 83 |
"""
|
|
|
|
| 95 |
return precomment
|
| 96 |
return precomment
|
| 97 |
|
| 98 |
+
def has_docstrings(func_node):
    """
    returns whether a function node has a docstring

    A function counts as documented when get_docstrings() yields anything
    other than a lone "{" (ignoring surrounding whitespace), or when
    grab_before_comments() returns a non-empty string for it.
    """
    inner_docs = get_docstrings(func_node).strip()
    if inner_docs != "{":
        return True
    # Only consult the comments ahead of the function when nothing was found inside.
    return grab_before_comments(func_node) != ""
|
| 103 |
+
|
| 104 |
+
|
| 105 |
def line_chr2char(text, line_idx, chr_idx):
|
| 106 |
"""
|
| 107 |
returns the character index at the given line and character index.
|