Spaces:
Runtime error
Runtime error
| import tree_sitter | |
| from tree_sitter import Language, Parser | |
# Shared-library location used both for building and for loading the grammar.
_LIBRARY_PATH = "./build/my-languages.so"

# Build the grammar found in ./tree-sitter-glsl into the shared library,
# then load the "glsl" language from it.
Language.build_library(_LIBRARY_PATH, ["./tree-sitter-glsl"])
GLSL_LANGUAGE = Language(_LIBRARY_PATH, "glsl")

# Module-level parser shared by the helper functions in this file.
parser = Parser()
parser.set_language(GLSL_LANGUAGE)
def replace_function(old_func_node, new_func_node):
    """
    returns the source (as bytes) of the tree that contains old_func_node,
    with the bytes of old_func_node replaced by the bytes of new_func_node.

    old_func_node: node to cut out; may sit at any depth of its tree
    new_func_node: node whose ``text`` is pasted in instead (it can come
        from a different tree, only its ``text`` is read)
    """
    # Walk up to the root so the text we slice matches the node's offsets.
    root = old_func_node
    while root.parent is not None:
        root = root.parent
    source = root.text
    # start_byte/end_byte are absolute offsets into the parsed source, while
    # root.text starts at root.start_byte, so rebase before slicing.
    base = root.start_byte
    before = source[: old_func_node.start_byte - base]
    after = source[old_func_node.end_byte - base:]
    return before + new_func_node.text + after
def get_root(node):
    """
    returns the topmost ancestor of the given node, i.e. the first node on
    the parent chain whose ``parent`` is None (the node itself if it has none)
    """
    current = node
    while current.parent is not None:
        current = current.parent
    return current
def node_str_idx(node):
    """
    returns the (start, end) byte offsets of a node, taken from its
    ``start_byte`` and ``end_byte`` attributes.

    These are offsets into the encoded source bytes, not character indexes;
    the two only coincide while the text before the node is pure ASCII, so
    use them to slice bytes rather than a decoded str.
    """
    return node.start_byte, node.end_byte
def give_tree(func_node):
    """
    returns a freshly parsed tree built from the text of the direct parent
    of the given function node, using the module-level parser
    """
    enclosing_source = func_node.parent.text
    reparsed_tree = parser.parse(enclosing_source)
    return reparsed_tree
def parse_functions(in_code):
    """
    parses the given code string with the module-level parser and returns
    the direct children of the root node whose type is "function_definition",
    in source order, as their actual nodes.
    """
    source_bytes = bytes(in_code, "utf8")
    top_level_nodes = parser.parse(source_bytes).root_node.children
    return [
        candidate
        for candidate in top_level_nodes
        if candidate.type == "function_definition"
    ]
def get_docstrings(func_node):
    """
    returns the docstring of a function node as one string.

    Comments that are direct children of the function are appended as-is.
    Inside the compound statement (the body) the opening "{" and the
    comments that follow it are appended one per line, each prefixed with as
    many spaces as its start column; the first other node ends the search.
    """
    pieces = []
    for child in func_node.children:
        kind = child.type
        if kind == "comment":
            # comment sitting next to the declarator
            pieces.append(child.text.decode())
            continue
        if kind != "compound_statement":
            continue
        for inner in child.children:
            if inner.type not in ("comment", "{"):
                # first real body node: the docstring is complete
                return "".join(pieces)
            indent = " " * inner.start_point[1]
            pieces.append(indent + inner.text.decode() + "\n")
    return "".join(pieces)
def full_func_head(func_node) -> str:
    """
    returns function head including docstrings before any real body code.

    The head runs from the start of the function node to the end of the last
    "{" or comment node that leads its body. If the body does not start with
    such a node the head stops where the body starts, and if the node has no
    "body" field the whole node text is returned.
    """
    body = func_node.child_by_field_name("body")
    if body is None:
        # no body means there is no body code to cut off
        return func_node.text.decode()
    head_end = body.start_byte
    for child in body.children:
        if child.type not in ("comment", "{"):
            break
        head_end = child.end_byte
    # end_byte is a byte offset, so cut the bytes first and decode afterwards;
    # slicing the decoded str would drift on non-ASCII characters.
    return func_node.text[: head_end - func_node.start_byte].decode()
def grab_before_comments(func_node):
    """
    returns the comments that happen just before a function node.

    The siblings of the function are scanned in order. Comments are collected
    (one per line, each followed by a newline) for as long as every next node
    starts on the line right after the previous comment ended; any gap
    discards what was collected. The scan stops at the function node itself.
    """
    collected = []
    last_comment_line = 0
    for node in func_node.parent.children:
        if node.start_point[0] != last_comment_line + 1:
            # not directly below the previous comment: start over
            collected.clear()
        if node.type == "comment":
            collected.append(node.text.decode() + "\n")
            # remember where the comment ENDS, so a multi-line /* ... */
            # block still counts as directly above the next line
            last_comment_line = node.end_point[0]
        elif node == func_node:
            break
    return "".join(collected)
def has_docstrings(func_node):
    """
    returns True when get_docstrings finds more than just the opening "{"
    for the function node, or when there are comments right before it
    """
    inner_doc = get_docstrings(func_node).strip()
    if inner_doc != "{":
        return True
    # only look at the preceding comments when the body had none
    return grab_before_comments(func_node) != ""
def line_chr2char(text, line_idx, chr_idx):
    """
    ## prefer start_byte and end_byte of the node instead!
    returns the index into text of the character at column chr_idx of line
    line_idx (both zero based, lines separated by "\\n").
    raises IndexError when text has fewer lines than line_idx asks to skip.
    """
    rows = text.split("\n")
    skipped = 0
    for i in range(line_idx):
        try:
            row = rows[i]
        except IndexError as e:
            raise IndexError(f"{i=} of {line_idx=} does not exist in {text=}") from e
        # +1 accounts for the newline that split() removed
        skipped += len(row) + 1
    return skipped + chr_idx