Upload tokenization_hy.py with huggingface_hub

Changed files: tokenization_hy.py (+296, -0)

tokenization_hy.py (ADDED)
import base64
import logging
import os
import unicodedata
from typing import Collection, Dict, List, Set, Tuple, Union

import tiktoken
from transformers import PreTrainedTokenizer, AddedToken

logger = logging.getLogger(__name__)


VOCAB_FILES_NAMES = {"vocab_file": "hy.tiktoken"}

PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
# PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
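# A minimal illustration of how this pattern pre-splits text before BPE is applied
# (a sketch only; it assumes the third-party `regex` package, which supports the
# \p{...} classes used above, much like tiktoken's own regex engine):
#
#   >>> import regex
#   >>> regex.findall(PAT_STR, "Hello, world's 123!")
#   ['Hello', ',', ' world', "'s", ' ', '1', '2', '3', '!']
#
# The active pattern splits numbers into single digits; the commented-out variant
# above would instead group them into runs of up to three digits (\p{N}{1,3}).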
ENDOFTEXT = "<|endoftext|>"
STARTOFTEXT = "<|startoftext|>"
BOSTOKEN = "<|bos|>"
EOSTOKEN = "<|eos|>"
PADTOKEN = "<|pad|>"

# As the default behavior is changed to allow special tokens in regular texts,
# the surface forms of the special tokens need to be as different as possible
# to minimize the impact.
EXTRAS = tuple(f"<|extra_{i}|>" for i in range(205))
# Changed to use the actual index to avoid misconfiguration with vocabulary expansion.


SPECIAL_START_ID = 127957


def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
    # Each line holds a base64-encoded token and a rank; ranks are re-assigned
    # from the line order, and duplicate token bytes are skipped.
    dic = {}
    rank = 0
    with open(tiktoken_bpe_file, "rb") as f:
        for line in f:
            if line:
                token, _ = line.split()
                decoded = base64.b64decode(token)
                if decoded in dic:
                    continue
                dic[decoded] = rank
                rank += 1
    global SPECIAL_START_ID
    SPECIAL_START_ID = rank
    return dic
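# For reference, each line of the tiktoken vocab file loaded above is expected to
# look like "<base64-encoded token bytes> <rank>". An illustrative (hypothetical)
# entry:
#
#     aGVsbG8= 1234    # aGVsbG8= is base64.b64encode(b"hello")
#
# The rank column itself is ignored by _load_tiktoken_bpe; only the line order matters.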

SPECIAL_TOKENS = tuple(
    enumerate(
        (
            (
                ENDOFTEXT,
                STARTOFTEXT,
                BOSTOKEN,
                EOSTOKEN,
                PADTOKEN,
            )
            + EXTRAS
        ),
        start=SPECIAL_START_ID,
    )
)
# NOTE: Unused Token ID starts from 127962
SPECIAL_TOKENS_SET = set(t for i, t in SPECIAL_TOKENS)
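# Note that SPECIAL_TOKENS is materialized here, at import time, using the
# SPECIAL_START_ID constant defined above (127957); the reassignment of
# SPECIAL_START_ID inside _load_tiktoken_bpe only happens later, when the
# tokenizer is constructed, and therefore does not shift these ids.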

class HYTokenizer(PreTrainedTokenizer):
    """HunYuan tokenizer."""

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        errors="replace",
        extra_vocab_file=None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # How to handle errors when decoding UTF-8 byte sequences;
        # use "ignore" if you are doing streaming inference.
        self.errors = errors

        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)  # type: Dict[bytes, int]
        self.special_tokens = {
            token: index
            for index, token in SPECIAL_TOKENS
        }

        # Try to load an extra vocab from file.
        if extra_vocab_file is not None:
            used_ids = set(self.mergeable_ranks.values()) | set(self.special_tokens.values())
            extra_mergeable_ranks = _load_tiktoken_bpe(extra_vocab_file)
            for token, index in extra_mergeable_ranks.items():
                if token in self.mergeable_ranks:
                    logger.info(f"extra token {token} exists, skipping")
                    continue
                if index in used_ids:
                    logger.info(f"the index {index} for extra token {token} exists, skipping")
                    continue
                self.mergeable_ranks[token] = index
            # The index space may be sparse after this, but tiktoken.Encoding handles that.

        enc = tiktoken.Encoding(
            "HunYuan",
            pat_str=PAT_STR,
            mergeable_ranks=self.mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        assert (
            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
        ), f"{len(self.mergeable_ranks)} + {len(self.special_tokens)} != {enc.n_vocab} in encoding"

        self.decoder = {
            v: k for k, v in self.mergeable_ranks.items()
        }  # type: Dict[int, Union[bytes, str]]
        self.decoder.update({v: k for k, v in self.special_tokens.items()})

        self.tokenizer = enc  # type: tiktoken.Encoding

        self.eod_id = self.tokenizer.eot_token
        self.bod_id = self.special_tokens[STARTOFTEXT]
        self.bos_id = self.special_tokens[BOSTOKEN]
        self.eos_id = self.special_tokens[EOSTOKEN]
        self.pad_id = self.special_tokens[PADTOKEN]

    def __getstate__(self):
        # Support pickling: the tiktoken.Encoding object cannot be pickled, so drop it.
        state = self.__dict__.copy()
        del state["tokenizer"]
        return state

    def __setstate__(self, state):
        # The tokenizer is not Python-native; rebuild it instead of restoring it.
        self.__dict__.update(state)
        enc = tiktoken.Encoding(
            "HunYuan",
            pat_str=PAT_STR,
            mergeable_ranks=self.mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        self.tokenizer = enc

    def __len__(self) -> int:
        return self.tokenizer.n_vocab

    def get_vocab(self) -> Dict[bytes, int]:
        return self.mergeable_ranks

    def convert_tokens_to_ids(
        self, tokens: Union[bytes, str, List[Union[bytes, str]]]
    ) -> List[int]:
        ids = []
        if isinstance(tokens, (str, bytes)):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.mergeable_ranks.get(tokens)
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                ids.append(self.mergeable_ranks.get(token))
        return ids

    def _add_tokens(
        self,
        new_tokens: Union[List[str], List[AddedToken]],
        special_tokens: bool = False,
    ) -> int:
        if not special_tokens and new_tokens:
            raise ValueError("Adding regular tokens is not supported")
        for token in new_tokens:
            surface_form = token.content if isinstance(token, AddedToken) else token
            if surface_form not in SPECIAL_TOKENS_SET:
                raise ValueError("Adding unknown special tokens is not supported")
        return 0

    def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
        """
        Save only the vocabulary of the tokenizer (the BPE merge ranks).

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        file_path = os.path.join(save_directory, "hunyuan.tiktoken")
        with open(file_path, "w", encoding="utf8") as w:
            for k, v in self.mergeable_ranks.items():
                line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
                w.write(line)
        return (file_path,)
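    # NOTE: save_vocabulary writes "hunyuan.tiktoken", while VOCAB_FILES_NAMES expects
    # the vocab file to be named "hy.tiktoken"; the saved file may need to be renamed
    # for it to be picked up automatically on reload.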

    def tokenize(
        self,
        text: str,
        allowed_special: Union[Set, str] = "all",
        disallowed_special: Union[Collection, str] = (),
        **kwargs,
    ) -> List[Union[bytes, str]]:
        """
        Converts a string into a sequence of tokens.

        Args:
            text (`str`):
                The sequence to be encoded.
            allowed_special (`Literal["all"]` or `set`):
                The surface forms of the tokens to be encoded as special tokens in regular texts.
                Defaults to "all".
            disallowed_special (`Literal["all"]` or `Collection`):
                The surface forms of the tokens that should not appear in regular texts and trigger errors.
                Defaults to an empty tuple.
            kwargs (additional keyword arguments, *optional*):
                Will be passed to the underlying model-specific encode method.

        Returns:
            `List[bytes|str]`: The list of tokens.
        """
        tokens = []
        text = unicodedata.normalize("NFC", text)

        # This implementation takes a detour: text -> token id -> token surface form.
        for t in self.tokenizer.encode(
            text, allowed_special=allowed_special, disallowed_special=disallowed_special
        ):
            tokens.append(self.decoder[t])
        return tokens

    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
        """
        Converts a sequence of tokens into a single string.
        """
        text = ""
        temp = b""
        for t in tokens:
            if isinstance(t, str):
                if temp:
                    text += temp.decode("utf-8", errors=self.errors)
                    temp = b""
                text += t
            elif isinstance(t, bytes):
                temp += t
            else:
                raise TypeError("token should only be of type bytes or str")
        if temp:
            text += temp.decode("utf-8", errors=self.errors)
        return text

    @property
    def vocab_size(self):
        return self.tokenizer.n_vocab

    def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
        """Converts an id to a token, special tokens included."""
        if index in self.decoder:
            return self.decoder[index]
        raise ValueError("unknown id")

    def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
        """Converts a token to an id using the vocab, special tokens included."""
        if token in self.special_tokens:
            return self.special_tokens[token]
        if token in self.mergeable_ranks:
            return self.mergeable_ranks[token]
        raise ValueError("unknown token")

    def _tokenize(self, text: str, **kwargs):
        """
        Converts a string into a sequence of tokens (string), using the tokenizer. Splits into words for
        word-based vocabularies or into sub-words for sub-word-based vocabularies (BPE/SentencePiece/WordPiece).

        Does NOT take care of added tokens.
        """
        raise NotImplementedError

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        errors: str = None,
        **kwargs,
    ) -> str:
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        if skip_special_tokens:
            token_ids = [i for i in token_ids if i < self.eod_id]
        return self.tokenizer.decode(token_ids, errors=errors or self.errors)
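    # Note: with skip_special_tokens=True, _decode drops every id >= eod_id, which
    # assumes the regular vocabulary occupies only the ids below the first special id.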

# tests
if __name__ == "__main__":
    tokenizer = HYTokenizer.from_pretrained('/hy')
    text = '你好,世界'
    tokens = tokenizer.tokenize(text)
    print(tokens)
    ids = tokenizer.convert_tokens_to_ids(tokens)
    print(ids)
    text2 = tokenizer.convert_tokens_to_string(tokens)
    print(text2)
    ids2 = tokenizer.convert_tokens_to_ids(tokens)
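
A minimal usage sketch for this file (assumptions: the local directory name below is hypothetical, and it must contain this tokenization_hy.py, the hy.tiktoken vocab file, and a tokenizer_config.json whose auto_map points at HYTokenizer):

    from transformers import AutoTokenizer

    # trust_remote_code is required so that transformers loads the custom
    # HYTokenizer class shipped alongside the model instead of a built-in one.
    tokenizer = AutoTokenizer.from_pretrained("./hy-snapshot", trust_remote_code=True)

    text = "你好,世界"
    ids = tokenizer(text)["input_ids"]
    print(ids)                    # token ids produced by the tiktoken BPE
    print(tokenizer.decode(ids))  # should round-trip back to the original text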

