from typing import List

from tokenizers import NormalizedString, PreTokenizedString
from tokenizers.pre_tokenizers import PreTokenizer
from transformers import PreTrainedTokenizerFast

try:
    from clang import cindex
except ModuleNotFoundError as e:
    raise ModuleNotFoundError(
        "VulBERTa Clang tokenizer requires `libclang`. Please install it via `pip install libclang`."
    ) from e


class ClangPreTokenizer:
    """Pre-tokenizer that splits C/C++ source into Clang lexer tokens."""

    cidx = cindex.Index.create()

    def clang_split(
        self,
        i: int,
        normalized_string: NormalizedString,
    ) -> List[NormalizedString]:
        tok = []
        # Parse the input as an in-memory C file so Clang can lex it,
        # without touching the filesystem.
        tu = self.cidx.parse(
            "tmp.c",
            args=[""],
            unsaved_files=[("tmp.c", str(normalized_string.original))],
            options=0,
        )
        for t in tu.get_tokens(extent=tu.cursor.extent):
            spelling = t.spelling.strip()
            if spelling == "":
                continue
            tok.append(NormalizedString(spelling))
        return tok

    def pre_tokenize(self, pretok: PreTokenizedString):
        pretok.split(self.clang_split)


class VulBERTaTokenizer(PreTrainedTokenizerFast):
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(
            *args,
            **kwargs,
        )
        # Custom Python pre-tokenizers cannot be serialized with the
        # tokenizer, so attach the Clang splitter at construction time.
        self._tokenizer.pre_tokenizer = PreTokenizer.custom(ClangPreTokenizer())
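

if __name__ == "__main__":
    # Minimal usage sketch. Assumptions not taken from the code above: a
    # trained `tokenizer.json` exists on disk, and the special-token names
    # below are illustrative rather than those of any real VulBERTa checkpoint.
    tokenizer = VulBERTaTokenizer(
        tokenizer_file="tokenizer.json",  # hypothetical path to a trained tokenizer
        unk_token="<unk>",
        pad_token="<pad>",
    )
    encoding = tokenizer("int main() { return 0; }")
    print(encoding.input_ids)
    # Because `PreTokenizer.custom` wraps a Python object, this tokenizer is
    # not expected to survive `save_pretrained`/reload round-trips intact; the
    # pre-tokenizer is re-attached in `__init__` for exactly that reason.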