import os
import gc
import copy
import time

import torch
import warnings
import transformers
import numpy as np

from typing import Dict, Optional, Sequence

from omnilmm import conversation as conversation_lib

IGNORE_INDEX = -100
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ) for text in strings
    ]
    # Note: input_ids and labels alias the same tensors here; callers that
    # need to modify labels independently must copy them first (see
    # omni_preprocess below).
    input_ids = labels = [
        tokenized.input_ids[0] for tokenized in tokenized_list
    ]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )
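
# Illustrative note (not from the original source): for a hypothetical tokenizer
# `tok`, `_tokenize_fn` returns per-string 1-D tensors plus their unpadded lengths:
#
#     out = _tokenize_fn(["hi there"], tok)
#     out["input_ids"][0]       # 1-D LongTensor of token ids for "hi there"
#     out["input_ids_lens"][0]  # number of non-pad tokens (typically the full length,
#                               # since a single string gets no padding)
#
# As noted above, out["labels"] points to the same tensors as out["input_ids"].
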
def omni_preprocess(sources,
                    tokenizer: transformers.PreTrainedTokenizer,
                    generation=False):
    system_content = ('You are an artificial intelligence assistant, which gives '
                      'helpful, detailed, and polite answers to the human\'s questions.')
    ignore_index = -100

    response_template = '\n<|assistant|>\n'
    instruction_template = '\n<|user|>\n'
    response_token_ids = tokenizer.encode(
        response_template, add_special_tokens=False)
    instruction_token_ids = tokenizer.encode(
        instruction_template, add_special_tokens=False)

    batch_input_ids = []
    batch_labels = []
    for i in range(len(sources)):
        # Normalize each conversation turn to {'role': ..., 'content': ...}.
        new_source = []
        prev_role = 'unexpected'
        for conv_turn in sources[i]:
            role = conv_turn['from'] if 'from' in conv_turn else conv_turn['role']
            content = conv_turn['value'] if 'value' in conv_turn else conv_turn['content']

            role = 'user' if role == 'human' else role
            role = 'assistant' if role == 'gpt' else role

            assert role in ['user', 'assistant']
            assert role != prev_role, f'role={role}, prev_role={prev_role}'
            prev_role = role

            new_turn = {
                'role': role,
                'content': content
            }
            new_source.append(new_turn)
        if new_source[0]['role'] != 'system':
            new_source.insert(0, {'role': 'system', 'content': system_content})

        # TODO: this automatically adds '\n' to the end
        res_text = tokenizer.apply_chat_template(
            new_source, tokenize=False, add_generation_prompt=generation)
        if not generation:
            res_text = res_text.strip()

        conversations_tokenized = _tokenize_fn([res_text], tokenizer)
        res_input_ids = conversations_tokenized["input_ids"][0]

        # labels and input_ids returned by _tokenize_fn reference the same
        # tensor, so copy before masking.
        res_labels = copy.deepcopy(conversations_tokenized["labels"][0])

        response_token_ids_idxs = []
        human_token_ids_idxs = []

        # Find the indexes right after each assistant-response marker.
        for assistant_idx in np.where(res_labels == response_token_ids[0])[0]:
            if (response_token_ids == res_labels[assistant_idx: assistant_idx + len(
                    response_token_ids)].tolist()):
                response_token_ids_idxs.append(
                    assistant_idx + len(response_token_ids))

        if len(response_token_ids_idxs) == 0:
            warnings.warn(
                f"Could not find response key `{response_template}` in the "
                f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
                f'Raw text is @===>{res_text}<===@ '
                f'Raw source is @===>{new_source}<===@ '
                f"This instance will be ignored in loss calculation. "
                f"Note: if this happens often, consider increasing `max_seq_length`."
            )
            res_labels[:] = ignore_index

        # Find the indexes of the start of each user-instruction marker.
        human_token_ids = instruction_token_ids
        for human_idx in np.where(res_labels == human_token_ids[0])[0]:
            if human_token_ids == res_labels[human_idx: human_idx + len(human_token_ids)].tolist():
                human_token_ids_idxs.append(human_idx)

        if len(human_token_ids_idxs) == 0:
            warnings.warn(
                f"Could not find instruction key `{instruction_template}` in the "
                f'following instance: @===>{tokenizer.decode(res_input_ids)}<===@ '
                f'Raw text is @===>{res_text}<===@ '
                f'Raw source is @===>{new_source}<===@ '
                f"This instance will be ignored in loss calculation. "
                f"Note: if this happens often, consider increasing `max_seq_length`."
            )
            res_labels[:] = ignore_index

        # Make the PyTorch loss function ignore all non-response tokens.
        for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
            if idx != 0:
                res_labels[start:end] = ignore_index
            else:
                res_labels[:end] = ignore_index

        # Mask a trailing user turn that has no assistant response after it.
        if len(response_token_ids_idxs) < len(human_token_ids_idxs):
            res_labels[human_token_ids_idxs[-1]:] = ignore_index

        batch_input_ids.append(res_input_ids)
        batch_labels.append(res_labels)

    return dict(input_ids=batch_input_ids, labels=batch_labels)
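

# Usage sketch (not part of the original file). It assumes a tokenizer whose chat
# template emits the '\n<|user|>\n' / '\n<|assistant|>\n' markers expected above;
# "path/to/tokenizer" is a placeholder, not a real checkpoint name.
if __name__ == "__main__":
    demo_tokenizer = transformers.AutoTokenizer.from_pretrained("path/to/tokenizer")
    demo_sources = [[
        {'from': 'human', 'value': 'What is in the image?\n<image>'},
        {'from': 'gpt', 'value': 'A cat sitting on a windowsill.'},
    ]]
    batch = omni_preprocess(demo_sources, demo_tokenizer, generation=False)
    # batch['input_ids'][0] holds the tokenized conversation; batch['labels'][0]
    # equals IGNORE_INDEX everywhere except the assistant-response tokens.
    print(batch['input_ids'][0].shape,
          (batch['labels'][0] != IGNORE_INDEX).sum().item())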