import gc
from typing import Optional, Iterator, Callable

import torch
from datasets import load_dataset
from litgpt.tokenizer import Tokenizer
from transformers import AutoTokenizer

def batch_text_iterator(kind: str,
                        path: str,
                        name: Optional[str] = None,
                        data_dir: Optional[str] = None,
                        data_files: Optional[str] = None,
                        keep_in_memory: bool = False,
                        revision: Optional[str] = None,
                        split: str = 'train',
                        num_proc: Optional[int] = None,
                        format: Optional[Callable | str] = None) -> Iterator[str]:
    # Despite the Optional default, `format` is effectively required: it must be
    # either a callable mapping a dataset row to text, or a str.format template.
    assert isinstance(format, str) or callable(format), f'{path=} {format=}'
    assert kind == 'base'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    if callable(format):
        for row in dataset:
            text = format(row)
            yield text
    else:
        # Template string: interpolate the row's fields, e.g. '{instruction}\n{output}'.
        for row in dataset:
            text = format.format(**row)
            yield text

    # Drop the dataset reference and force a collection to release its memory.
    del dataset
    gc.collect()
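
# Example usage (a sketch: the dataset path, name, and template below are
# illustrative assumptions, not values taken from this project's configs):
#
#   it = batch_text_iterator(kind='base',
#                            path='wikimedia/wikipedia',
#                            name='20231101.en',
#                            format='{text}')
#   text = next(it)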

def batch_chat_iterator(kind: str,
                        path: str,
                        name: Optional[str] = None,
                        data_dir: Optional[str] = None,
                        data_files: Optional[str] = None,
                        keep_in_memory: bool = False,
                        revision: Optional[str] = None,
                        split: str = 'train',
                        num_proc: Optional[int] = None,
                        field: Optional[str] = None,
                        transform: Optional[Callable] = None) -> Iterator[list[dict[str, str]]]:
    assert kind == 'instruct'

    dataset = load_dataset(path=path,
                           name=name,
                           data_dir=data_dir,
                           data_files=data_files,
                           keep_in_memory=keep_in_memory,
                           revision=revision,
                           split=split,
                           trust_remote_code=True,
                           num_proc=num_proc)

    if callable(transform):
        # A transform normalizes a row (or one field of it) into chat messages.
        for row in dataset:
            if field:
                messages = transform(row[field])
            else:
                messages = transform(row)
            yield messages
    else:
        # Without a transform, the rows must already hold messages under `field`.
        for row in dataset:
            if field:
                messages = row[field]
            else:
                raise ValueError(f'{field=}: `field` is required when no transform is given')
            yield messages

    del dataset
    gc.collect()
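
# Example usage (a sketch: the dataset, split, and field below are illustrative
# assumptions):
#
#   it = batch_chat_iterator(kind='instruct',
#                            path='HuggingFaceH4/ultrachat_200k',
#                            split='train_sft',
#                            field='messages')
#   messages = next(it)  # a list of {'role': ..., 'content': ...} dicts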

def tokenize_text_fn(dataset_config: dict, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    # `hf_tokenizer` is unused here; presumably it is accepted so that all
    # tokenize_*_fn variants share one signature.
    for text in batch_text_iterator(**dataset_config):
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=True)
        yield text_ids

def tokenize_chat_fn(dataset_config: dict, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    for messages in batch_chat_iterator(**dataset_config):
        # Render the conversation with the HF chat template, then tokenize with
        # the litgpt tokenizer. eos=False here, since the rendered template is
        # expected to carry its own end-of-turn markers.
        text: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)
        text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=False)
        yield text_ids
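
# For reference, apply_chat_template(..., tokenize=False) renders the messages
# into a single prompt string; with a ChatML-style template (an assumption about
# the tokenizer in use) the result looks roughly like:
#
#   <|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>\n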

def tokenize_fn(dataset_config: dict, min_len: int, max_len: int, hf_tokenizer: AutoTokenizer, tokenizer: Tokenizer) -> Iterator[torch.Tensor]:
    if dataset_config['kind'] == 'base':
        for text in batch_text_iterator(**dataset_config):
            try:
                text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=True)
            except Exception as e:
                print(f'Skip base row: {e=} {type(text)=} {text=}')
                continue
            # Keep only sequences inside the configured token-length window.
            if min_len <= len(text_ids) <= max_len:
                yield text_ids
    elif dataset_config['kind'] == 'instruct':
        for messages in batch_chat_iterator(**dataset_config):
            try:
                text: str = hf_tokenizer.apply_chat_template(messages, tokenize=False)
                text_ids: torch.Tensor = tokenizer.encode(text, bos=False, eos=False)
            except Exception as e:
                print(f'Skip instruct row: {e=} {type(messages)=} {messages=}')
                continue
            if min_len <= len(text_ids) <= max_len:
                yield text_ids
    else:
        raise ValueError(dataset_config['kind'])
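
# Example usage (a sketch: the tokenizer directory and dataset config are
# illustrative assumptions):
#
#   hf_tokenizer = AutoTokenizer.from_pretrained('checkpoints/tokenizer')
#   tokenizer = Tokenizer('checkpoints/tokenizer')
#   config = {'kind': 'base', 'path': 'wikimedia/wikipedia',
#             'name': '20231101.en', 'format': '{text}'}
#   for text_ids in tokenize_fn(config, min_len=4, max_len=2048,
#                               hf_tokenizer=hf_tokenizer, tokenizer=tokenizer):
#       ...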