# Spaces: Sleeping — Hugging Face Spaces UI status banner captured by the
# scraper; not part of the script itself.
| import fire | |
| CONFIG = { | |
| "preserve_insertion_order": False | |
| } | |
| CMD_SRC_KWARGS = """ | |
| SELECT ('hf://datasets/{src}/' || lo.arguments['splits']['{split}']) AS path, function | |
| FROM ( | |
| SELECT unnest(li.loading_codes) AS lo, li.function[4:] as function | |
| FROM ( | |
| SELECT unnest(libraries) as li | |
| FROM read_json('https://datasets-server.huggingface.co/compatible-libraries?dataset={src}') | |
| ) WHERE li.function[:3] = 'pl.' | |
| ) WHERE lo.config_name='{config}'; | |
| """.strip() | |
| CMD_SRC = """ | |
| CREATE VIEW src AS SELECT * FROM {function}('{path}'); | |
| """.strip() | |
| CMD_DST = """ | |
| COPY ({query}) to 'tmp' (FORMAT PARQUET, ROW_GROUP_SIZE_BYTES '100MB', ROW_GROUPS_PER_FILE 5, PER_THREAD_OUTPUT true); | |
| """.strip() | |
| CMD_SRC_DRY_RUN = CMD_SRC[:-1] + " LIMIT 5;" | |
| CMD_DST_DRY_RUN = "{query};" | |
| def sql(src: str, dst: str, query: str, config: str = "default", split: str = "train", private: bool = False, dry_run: bool = False): | |
| import os | |
| import duckdb | |
| from contextlib import nullcontext | |
| from huggingface_hub import CommitScheduler | |
| class CommitAndCleanScheduler(CommitScheduler): | |
| def push_to_hub(self): | |
| for path in self.folder_path.with_name("tmp").glob(self.allow_patterns): | |
| with path.open("rb") as f: | |
| footer = f.read(4) and f.seek(-4, os.SEEK_END) and f.read(4) | |
| if footer == b"PAR1": | |
| path.rename(self.folder_path / path.name) | |
| super().push_to_hub() | |
| for path in self.last_uploaded: | |
| path.unlink(missing_ok=True) | |
| with nullcontext() if dry_run else CommitAndCleanScheduler(repo_id=dst, repo_type="dataset", folder_path="dst", path_in_repo="data", allow_patterns="*.parquet", every=0.1, private=private): | |
| con = duckdb.connect(":memory:", config=CONFIG) | |
| src_kwargs = con.sql(CMD_SRC_KWARGS.format(src=src, config=config, split=split)).df().to_dict(orient="records") | |
| if not src_kwargs: | |
| raise ValueError(f'Invalid --config "{config}" for dataset "{src}", please select a valid dataset config/subset.') | |
| con.sql((CMD_SRC_DRY_RUN if dry_run else CMD_SRC).format(**src_kwargs[0])) | |
| if dry_run: | |
| print(f"Sample data from '{src}' that would be written to dataset '{dst}':\n") | |
| else: | |
| con.sql("PRAGMA enable_progress_bar;") | |
| result = con.sql((CMD_DST_DRY_RUN if dry_run else CMD_DST).format(query=query.rstrip("\n ;"))) | |
| if dry_run: | |
| print(result.df().to_markdown()) | |
| else: | |
| print("done") | |
| if __name__ == '__main__': | |
| fire.Fire(sql) | |