# Copyright 2025 Bytedance Ltd. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0
import os
import xml.etree.ElementTree as ET
import subprocess
import logging

import pyarrow.fs as pf
import torch.distributed as dist

logger = logging.getLogger(__name__)


def get_parquet_data_paths(data_dir_list, num_sampled_data_paths, rank=0, world_size=1):
    """Collect parquet file paths from each directory, repeating the file list as
    needed so that exactly `num_sampled_data_paths[i]` paths are drawn from
    `data_dir_list[i]`.

    When world_size > 1, the directories are sharded across ranks so each rank only
    lists its own chunk, and the per-rank results are combined with all_gather_object.
    """
    num_data_dirs = len(data_dir_list)
    if world_size > 1:
        # Shard the directory list across ranks so each rank only lists its chunk.
        chunk_size = (num_data_dirs + world_size - 1) // world_size
        start_idx = rank * chunk_size
        end_idx = min(start_idx + chunk_size, num_data_dirs)
        local_data_dir_list = data_dir_list[start_idx:end_idx]
        local_num_sampled_data_paths = num_sampled_data_paths[start_idx:end_idx]
    else:
        local_data_dir_list = data_dir_list
        local_num_sampled_data_paths = num_sampled_data_paths

    local_data_paths = []
    for data_dir, num_data_path in zip(local_data_dir_list, local_num_sampled_data_paths):
        if data_dir.startswith("hdfs://"):
            files = hdfs_ls_cmd(data_dir)
            data_paths_per_dir = [file for file in files if file.endswith(".parquet")]
        else:
            files = os.listdir(data_dir)
            data_paths_per_dir = [
                os.path.join(data_dir, name) for name in files if name.endswith(".parquet")
            ]
        if not data_paths_per_dir:
            raise ValueError(f"no .parquet files found under {data_dir}")
        # Repeat the file list enough times to cover the requested sample count,
        # then truncate to exactly num_data_path entries.
        repeat = num_data_path // len(data_paths_per_dir)
        data_paths_per_dir = data_paths_per_dir * (repeat + 1)
        local_data_paths.extend(data_paths_per_dir[:num_data_path])

    if world_size > 1:
        # Gather every rank's local paths so all ranks see the full combined list.
        gather_list = [None] * world_size
        dist.all_gather_object(gather_list, local_data_paths)
        combined_chunks = []
        for chunk_list in gather_list:
            if chunk_list is not None:
                combined_chunks.extend(chunk_list)
    else:
        combined_chunks = local_data_paths
    return combined_chunks
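

# Usage sketch under torch.distributed (illustrative only; the HDFS directories
# and sample counts below are hypothetical, and the process group is assumed to
# already be initialized). Each rank lists only its shard of directories, then
# every rank receives the full combined path list:
#
#   paths = get_parquet_data_paths(
#       ["hdfs://cluster/data_a", "hdfs://cluster/data_b"],
#       [100, 50],
#       rank=dist.get_rank(),
#       world_size=dist.get_world_size(),
#   )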


# NOTE: customize this function for your cluster
def get_hdfs_host():
    return "hdfs://xxx"


# NOTE: customize this function for your cluster
def get_hdfs_block_size():
    return 134217728  # 128 MB


# NOTE: customize this function for your cluster
def get_hdfs_extra_conf():
    return None


def init_arrow_pf_fs(parquet_file_path):
    """Return a pyarrow filesystem matching the path scheme (HDFS or local)."""
    if parquet_file_path.startswith("hdfs://"):
        fs = pf.HadoopFileSystem(
            host=get_hdfs_host(),
            port=0,
            buffer_size=get_hdfs_block_size(),
            extra_conf=get_hdfs_extra_conf(),
        )
    else:
        fs = pf.LocalFileSystem()
    return fs
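

# Minimal read sketch (illustrative; the path below is hypothetical and
# pyarrow.parquet is assumed to be available; depending on the cluster setup,
# the hdfs:// scheme prefix may need to be stripped before opening):
#
#   import pyarrow.parquet as pq
#   path = "hdfs://user/data/part-00000.parquet"
#   fs = init_arrow_pf_fs(path)
#   with fs.open_input_file(path) as f:
#       table = pq.read_table(f)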


def hdfs_ls_cmd(dir_path):
    # Shell out to `hdfs dfs -ls` and keep only the listing lines that contain a
    # full hdfs:// URI, dropping the permissions/size columns that precede the path.
    result = subprocess.run(["hdfs", "dfs", "-ls", dir_path], capture_output=True, text=True).stdout
    return ["hdfs://" + i.split("hdfs://")[-1].strip() for i in result.split("\n") if "hdfs://" in i]
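

if __name__ == "__main__":
    # Local single-process sanity check (sketch): the directory below is a
    # placeholder; point it at a folder that actually contains .parquet files.
    example_dir = "/tmp/example_parquet_dir"
    if os.path.isdir(example_dir):
        paths = get_parquet_data_paths([example_dir], [4])
        print(f"sampled {len(paths)} parquet paths: {paths}")
    else:
        print(f"{example_dir} does not exist; edit this sketch to point at real data")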