import os
import pathlib
import rwkv_world_tokenizer
from typing import List, Tuple, Callable


def add_tokenizer_argument(parser) -> None:
    """Adds an optional positional `tokenizer` argument to the given argparse parser; defaults to 'auto'."""
    parser.add_argument(
        'tokenizer',
        help='Tokenizer to use; supported tokenizers: auto (guess from n_vocab), 20B, world',
        nargs='?',
        type=str,
        default='auto'
    )


def get_tokenizer(tokenizer_name: str, n_vocab: int) -> Tuple[
    Callable[[List[int]], str],
    Callable[[str], List[int]]
]:
    """Returns a (decode, encode) pair: decode maps token ids to a string, encode maps a string to token ids."""
    if tokenizer_name == 'auto':
        if n_vocab == 50277:
            tokenizer_name = '20B'
        elif n_vocab == 65536:
            tokenizer_name = 'world'
        else:
            raise ValueError(f'Cannot guess the tokenizer from n_vocab value of {n_vocab}')
    parent: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent

    if tokenizer_name == 'world':
        print('Loading World v20230424 tokenizer')
        return rwkv_world_tokenizer.get_world_tokenizer_v20230424()
    elif tokenizer_name == '20B':
        print('Loading 20B tokenizer')
        import tokenizers
        # 20B_tokenizer.json is expected to sit next to this script
        tokenizer: tokenizers.Tokenizer = tokenizers.Tokenizer.from_file(str(parent / '20B_tokenizer.json'))
        return tokenizer.decode, lambda x: tokenizer.encode(x).ids
    else:
        raise ValueError(f'Unknown tokenizer {tokenizer_name}')
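
# Usage sketch (illustrative only, not part of the module itself): assumes an
# argparse-based script and a known vocabulary size; the value 50277 below is just
# an example that selects the 20B tokenizer through the 'auto' path.
#
#     import argparse
#
#     parser = argparse.ArgumentParser()
#     add_tokenizer_argument(parser)
#     args = parser.parse_args()
#
#     decode, encode = get_tokenizer(args.tokenizer, n_vocab=50277)
#     token_ids = encode('Hello, world!')
#     print(decode(token_ids))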