"""Preprocessing utilities for TinyChart conversation data: prompt assembly, tokenization, and label masking."""
from typing import Dict, Optional, Sequence, List
import copy

import transformers
import torch

from tinychart.data.process import register_preprocess
from tinychart.mm_utils import tokenizer_image_token
from tinychart import conversation as conversation_lib
from tinychart.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \
    DEFAULT_IM_END_TOKEN
def preprocess_default(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False
) -> Dict:
    """Build tokenized inputs and masked labels for a batch of conversations.

    Each element of ``sources`` is a list of turn dicts with ``"from"`` and
    ``"value"`` keys. Returns ``dict(input_ids=..., labels=...)`` where labels
    are a deep copy of input_ids with the header and human turns masked to
    IGNORE_INDEX by ``_mask_targets``.

    NOTE: ``_add_speaker_and_signal`` mutates each ``sentence["value"]`` in
    place (adds the "### role: " prefix and trailing newline), so the second
    pass below, which re-tokenizes ``s["value"]``, sees the decorated turns —
    this ordering is essential and must not be changed.
    """
    conversations = []
    for source in sources:
        # System header; identical for every source in this call.
        header = f"{conversation_lib.default_conversation.system}\n\n"
        conversation = _add_speaker_and_signal(header, source)
        conversations.append(conversation)
    # tokenize conversations
    def get_tokenize_len(prompts):
        # Per-prompt token counts, with image placeholders expanded by
        # tokenizer_image_token.
        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
    if has_image:
        input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
    else:
        conversations_tokenized = _tokenize_fn(conversations, tokenizer)
        input_ids = conversations_tokenized["input_ids"]
    # Labels start as a copy of the inputs and are masked in place below.
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        # ``header`` intentionally leaks from the loop above; it is the same
        # string for every source, so reusing the last value is safe here.
        if has_image:
            tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
        else:
            tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source], tokenizer)["input_ids_lens"]
        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)
    return dict(input_ids=input_ids, labels=targets)
| def _tokenize_fn(strings: Sequence[str], | |
| tokenizer: transformers.PreTrainedTokenizer) -> Dict: | |
| """Tokenize a list of strings.""" | |
| tokenized_list = [ | |
| tokenizer( | |
| text, | |
| return_tensors="pt", | |
| padding="longest", | |
| max_length=tokenizer.model_max_length, | |
| truncation=True, | |
| ) for text in strings | |
| ] | |
| input_ids = labels = [ | |
| tokenized.input_ids[0] for tokenized in tokenized_list | |
| ] | |
| input_ids_lens = labels_lens = [ | |
| tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() | |
| for tokenized in tokenized_list | |
| ] | |
| return dict( | |
| input_ids=input_ids, | |
| labels=labels, | |
| input_ids_lens=input_ids_lens, | |
| labels_lens=labels_lens, | |
| ) | |
| def _add_speaker_and_signal(header, source, get_conversation=True): | |
| """Add speaker and start/end signal on each round.""" | |
| BEGIN_SIGNAL = "### " | |
| END_SIGNAL = "\n" | |
| conversation = header | |
| for sentence in source: | |
| from_str = sentence["from"] | |
| if from_str.lower() == "human": | |
| from_str = conversation_lib.default_conversation.roles[0] | |
| elif from_str.lower() == "gpt": | |
| from_str = conversation_lib.default_conversation.roles[1] | |
| else: | |
| from_str = 'unknown' | |
| sentence["value"] = (BEGIN_SIGNAL + from_str + ": " + | |
| sentence["value"] + END_SIGNAL) | |
| if get_conversation: | |
| conversation += sentence["value"] | |
| conversation += BEGIN_SIGNAL | |
| return conversation | |
def _mask_targets(target, tokenized_lens, speakers):
    """Mask the header and every human turn of ``target`` with IGNORE_INDEX.

    Mutates ``target`` (a 1-D tensor of label ids) in place.
    ``tokenized_lens[0]`` is the header length; each remaining entry is the
    token length of the corresponding turn in ``speakers``.
    """
    offset = tokenized_lens[0]
    # The system header never contributes to the loss.
    target[:offset] = IGNORE_INDEX
    for span, who in zip(tokenized_lens[1:], speakers):
        if who == "human":
            # +2 presumably skips the leading "### " signal tokens of the
            # turn so they stay supervised — TODO(review): confirm offset.
            target[offset + 2:offset + span] = IGNORE_INDEX
        offset += span