import os
import random
from abc import ABC
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from functools import partial
from itertools import chain
from typing import Any, Dict, List, Optional, Union

import pyarrow as pa
import pyarrow.parquet as pq
from omegaconf import DictConfig

from common.distributed import get_global_rank, get_world_size
from common.fs import copy, exists, listdir, mkdir, remove
from common.partition import partition_by_groups
from common.persistence.utils import get_local_path
from data.common.parquet_sampler import (
    IdentityParquetSampler,
    ParquetSampler,
    create_parquet_sampler,
)
from data.common.utils import filter_parquets, get_parquet_metadata


def save_and_copy(
    pa_table: pa.Table,
    local_path: str,
    target_path: str,
    row_group_size: int,
    executor: ThreadPoolExecutor,
    do_async: bool = False,
    futures: Optional[List[Future]] = None,
):
    """
    Write `pa_table` to `local_path`, then copy it to `target_path` on the
    executor. When `do_async` is False, block until all pending futures
    complete and shut the executor down.
    """
    # Avoid the shared-mutable-default pitfall: callers that want to track
    # pending writes across calls must pass in their own list.
    if futures is None:
        futures = []

    def _make_on_complete(local_path):
        def _on_complete(future):
            # The local file is only a staging copy; drop it once the copy
            # to the target filesystem has finished.
            target_path = future.result()
            remove(local_path)
            print(f"Target path saved: {target_path}")

        return _on_complete

    def _fn(pa_table, local_path, target_path, row_group_size):
        pq.write_table(
            pa_table,
            local_path,
            row_group_size=row_group_size,
        )
        mkdir(os.path.dirname(target_path))
        copy(local_path, target_path)
        return target_path

    future = executor.submit(_fn, pa_table, local_path, target_path, row_group_size)
    future.add_done_callback(_make_on_complete(local_path))
    futures.append(future)

    if not do_async:
        # Synchronous mode: drain every outstanding write before returning.
        for future in as_completed(futures):
            try:
                future.result()
            except Exception as exc:
                print(f"Generated an exception: {exc}")
        executor.shutdown(wait=True)
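
# Illustrative usage sketch (not part of the module's API): the paths and
# bucket below are hypothetical, and common.fs is assumed to route the copy
# to the right backend for the target scheme.
#
#     executor = ThreadPoolExecutor(max_workers=4)
#     futures: List[Future] = []
#     table = pa.table({"id": ["a", "b"], "payload": [b"x", b"y"]})
#     save_and_copy(
#         table,
#         local_path="/tmp/00000.parquet",
#         target_path="s3://my-bucket/dataset/00000.parquet",
#         row_group_size=1024,
#         executor=executor,
#         do_async=False,
#         futures=futures,
#     )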


@dataclass
class FileListOutput:
    """
    File lists returned by `configure_task_path`: target files that already
    exist, remaining source inputs, and their corresponding target paths.
    """

    existing_files: List[str]
    source_files: List[Any]
    target_files: List[str]


@dataclass
class PersistedParquet:
    path: str

    def save(
        self,
        row_group_size: int,
        executor: ThreadPoolExecutor,
        pa_table: Optional[pa.Table] = None,
        data_dict: Optional[Dict[str, List[Union[str, bytes]]]] = None,
        is_last_file: bool = False,
        futures: Optional[List[Future]] = None,
    ):
        # Exactly one of pa_table / data_dict must be provided.
        assert (pa_table is None) != (data_dict is None)
        local_path = get_local_path(self.path)
        if pa_table is None:  # `not pa_table` would also trip on an empty table
            schema = self.generate_schema_from_dict(data_dict)
            pa_table = pa.Table.from_pydict(data_dict, schema=schema)
        save_and_copy(
            pa_table,
            local_path=local_path,
            target_path=self.path,
            row_group_size=row_group_size,
            executor=executor,
            do_async=not is_last_file,
            futures=futures,
        )

    def generate_schema_from_dict(
        self,
        data_dict: Dict[str, List[Union[str, bytes]]],
    ) -> pa.Schema:
        # Infer a string/binary schema from the first element of each column.
        schema_dict = {}
        for key, value in data_dict.items():
            if isinstance(value[0], str):
                schema_dict[key] = pa.string()
            elif isinstance(value[0], bytes):
                schema_dict[key] = pa.binary()
            else:
                raise ValueError(
                    f"Unsupported data type for key '{key}': {type(value[0])}"
                )
        return pa.schema(schema_dict)
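
# Minimal sketch of the dict-based path (column names are hypothetical):
#
#     parquet = PersistedParquet("s3://my-bucket/dataset/00000.parquet")
#     parquet.save(
#         row_group_size=512,
#         executor=ThreadPoolExecutor(max_workers=1),
#         data_dict={"uid": ["a", "b"], "blob": [b"\x00", b"\x01"]},
#         is_last_file=True,
#     )
#
# generate_schema_from_dict infers pa.string() for "uid" and pa.binary() for
# "blob" from the first element of each column.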


class ParquetManager(ABC):
    """
    Base class for the DumpingManager and RepackingManager.
    """

    def __init__(
        self,
        task: Optional[DictConfig] = None,
        target_dir: str = ".",
    ):
        self.task = task
        self.target_dir = target_dir.rstrip("/")
        # Shared executor and pending-write futures used by save_parquet.
        self.executor = ThreadPoolExecutor(max_workers=4)
        self.futures: List[Future] = []

    def get_parquet_files(
        self,
        source_path: Union[str, List[str]],
        # IdentityParquetSampler is assumed stateless, so a shared default
        # instance is safe here.
        parquet_sampler: ParquetSampler = IdentityParquetSampler(),
        path_mode: str = "dir",
    ):
        def _flatten(paths):
            # Accept a single path, a list of paths, or a list of lists.
            if isinstance(paths, list):
                if any(isinstance(i, list) for i in paths):
                    return list(chain(*paths))
                return paths
            return [paths]

        file_paths = _flatten(source_path)
        if path_mode == "dir":
            file_paths = map(listdir, file_paths)
            if isinstance(parquet_sampler.size, float):
                # Fractional sample size: filter and sample each directory
                # independently, then flatten.
                file_paths = map(filter_parquets, file_paths)
                file_paths = map(parquet_sampler, file_paths)
                file_paths = list(chain(*file_paths))
            else:
                # Absolute sample size: flatten all directories first, then
                # filter and sample the combined list once.
                file_paths = chain(*file_paths)
                file_paths = parquet_sampler(filter_parquets(file_paths))

        return file_paths

    def save_parquet(
        self,
        *,
        file_name: str,
        row_group_size: int,
        pa_table: Optional[pa.Table] = None,
        data_dict: Optional[Dict[str, List[Union[str, bytes]]]] = None,
        override: bool = True,
        is_last_file: bool = False,
    ):
        persist = self._get_parquet(file_name)
        # Skip the write when the target already exists, unless overriding.
        if override or not exists(persist.path):
            persist.save(
                pa_table=pa_table,
                data_dict=data_dict,
                executor=self.executor,
                row_group_size=row_group_size,
                is_last_file=is_last_file,
                futures=self.futures,
            )

    def _get_parquet(self, file_name: str) -> PersistedParquet:
        return PersistedParquet(file_name)


class DumpingManager(ParquetManager):
    """
    Dumping manager handles parquet saving and resuming.
    """

    def __init__(
        self,
        task: DictConfig,
        target_dir: str,
    ):
        super().__init__(task=task, target_dir=target_dir)

    def generate_saving_path(self, file_path: str, rsplit: int):
        # Keep the last `rsplit` components of the source path and mirror
        # them under target_dir/epoch_<n>/.
        part_list = file_path.rsplit("/", rsplit)
        result_folder = "/".join(
            [self.target_dir] + [f"epoch_{self.task.epoch}"] + part_list[-rsplit:-1]
        )
        result_file = "/".join([result_folder, part_list[-1]])
        return result_folder, result_file
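
    # For example (hypothetical paths), with rsplit=2 and task.epoch=0:
    #     generate_saving_path("/data/shard_03/part-0.parquet", rsplit=2)
    # returns
    #     ("<target_dir>/epoch_0/shard_03",
    #      "<target_dir>/epoch_0/shard_03/part-0.parquet")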

    def configure_task_path(self, source_path: str, rsplit: int, path_mode: str = "dir"):
        file_paths = self.get_parquet_files(
            source_path=source_path,
            path_mode=path_mode,
        )

        # Deterministic shuffle so every task/rank derives the same ordering.
        random.Random(0).shuffle(file_paths)

        # Two-level sharding: first across tasks, then across ranks within
        # this task.
        full_source_files = partition_by_groups(file_paths, self.task.total_count)[self.task.index]
        full_source_files = partition_by_groups(full_source_files, get_world_size())[
            get_global_rank()
        ]

        if not full_source_files:
            return FileListOutput([], [], [])

        generate_saving_path = partial(self.generate_saving_path, rsplit=rsplit)
        full_paths = map(generate_saving_path, full_source_files)
        full_target_folders, full_target_files = map(list, zip(*full_paths))
        full_target_folders = set(full_target_folders)

        # Collect already-written parquet files under the target folders so
        # an interrupted run can resume where it left off.
        existing_file_paths = map(
            lambda folder: listdir(folder) if exists(folder) else [], full_target_folders
        )
        existing_file_paths = chain(*existing_file_paths)
        self.existing_files = list(
            filter(
                lambda path: path.endswith(".parquet") and path in full_target_files,
                existing_file_paths,
            )
        )

        # Drop source/target pairs whose target file already exists.
        filtered_pairs = list(
            filter(
                lambda pair: pair[1] not in self.existing_files,
                zip(full_source_files, full_target_files),
            )
        )
        if filtered_pairs:
            filtered_source_files, filtered_target_files = map(list, zip(*filtered_pairs))
        else:
            filtered_source_files, filtered_target_files = [], []

        skip_exists = self.task.skip_exists
        self.source_files = filtered_source_files if skip_exists else full_source_files
        self.target_files = filtered_target_files if skip_exists else full_target_files

        return FileListOutput(self.existing_files, self.source_files, self.target_files)
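
# Illustrative usage sketch for DumpingManager (paths and config fields are
# hypothetical; `cfg` is assumed to be the job's DictConfig):
#
#     manager = DumpingManager(task=cfg.task, target_dir="s3://my-bucket/dumps")
#     out = manager.configure_task_path("/data/source", rsplit=2)
#     for src, dst in zip(out.source_files, out.target_files):
#         table = pq.read_table(src)
#         manager.save_parquet(
#             file_name=dst,
#             row_group_size=1024,
#             pa_table=table,
#             is_last_file=(dst == out.target_files[-1]),
#         )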


class RepackingManager(ParquetManager):
    """
    Repacking manager handles parquet splitting and saving.
    """

    def __init__(
        self,
        task: DictConfig,
        target_dir: str,
        repackaging: DictConfig,
    ):
        super().__init__(task=task, target_dir=target_dir)
        self.repackaging = repackaging

    def configure_task_path(
        self,
        source_path: str,
        parquet_sampler: Optional[DictConfig] = None,
        path_mode: str = "dir",
    ):
        parquet_sampler = create_parquet_sampler(config=parquet_sampler)
        file_paths = self.get_parquet_files(
            source_path=source_path,
            parquet_sampler=parquet_sampler,
            path_mode=path_mode,
        )

        # Deterministic shuffle so repeated runs produce the same packing.
        random.Random(0).shuffle(file_paths)
        target_dir = self.target_dir
        size = abs(parquet_sampler.size)

        if self.task:
            # Shard the inputs across tasks and give each task its own
            # output subdirectory.
            file_paths = partition_by_groups(file_paths, self.task.total_count)[self.task.index]
            target_dir = os.path.join(target_dir, f"{self.task.total_count}_{self.task.index}")

            if size > 1:
                # An absolute sample size is split across tasks as well.
                size = len(
                    partition_by_groups(range(size), self.task.total_count)[self.task.index]
                )

        metadatas = get_parquet_metadata(file_paths, self.repackaging.num_processes)

        # Expand every source file into (file_path, row_index) items so rows
        # can be redistributed across the new files.
        target_items = [
            (file_path, row)
            for file_path, metadata in zip(file_paths, metadatas)
            for row in range(metadata.num_rows)
        ]

        random.Random(0).shuffle(target_items)

        if size > 1:
            target_items = target_items[:size]

        # Spread the (file, row) items evenly over the requested number of
        # output files.
        items_per_file = partition_by_groups(target_items, self.repackaging.num_files)

        target_files = [
            os.path.join(target_dir, f"{str(i).zfill(5)}.parquet")
            for i in range(self.repackaging.num_files)
        ]

        existing_file_paths = listdir(target_dir) if exists(target_dir) else []
        self.existing_files = list(
            filter(
                lambda path: path.endswith(".parquet"),
                existing_file_paths,
            )
        )
        self.source_files = items_per_file
        self.target_files = target_files

        return FileListOutput(self.existing_files, self.source_files, self.target_files)
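
# Illustrative usage sketch for RepackingManager (config values are
# hypothetical; `cfg` is assumed to be the job's DictConfig, with e.g.
# cfg.repackaging.num_files=128 and cfg.repackaging.num_processes=8):
#
#     manager = RepackingManager(
#         task=cfg.task,
#         target_dir="s3://my-bucket/repacked",
#         repackaging=cfg.repackaging,
#     )
#     out = manager.configure_task_path("/data/source")
#     # out.source_files[i] holds the (source_file, row_index) pairs destined
#     # for out.target_files[i].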