import json
import logging
import tempfile
import uuid
from typing import Optional, Union, Dict, List, Any

import pyarrow as pa
import pyarrow.parquet as pq
from huggingface_hub import CommitScheduler
from huggingface_hub.hf_api import HfApi

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(name)s %(levelname)s:%(message)s')
logger = logging.getLogger(__name__)
def load_scheduler():
    return ParquetScheduler(
        repo_id="hannahcyberey/Refusal-Steering-Logs",
        every=10,
        private=True,
        squash_history=False,
        schema={
            "session_id": {"_type": "Value", "dtype": "string"},
            "prompt": {"_type": "Value", "dtype": "string"},
            "steering": {"_type": "Value", "dtype": "bool"},
            "coeff": {"_type": "Value", "dtype": "float64"},
            "top_p": {"_type": "Value", "dtype": "float64"},
            "temperature": {"_type": "Value", "dtype": "float64"},
            "output": {"_type": "Value", "dtype": "string"},
            "upvote": {"_type": "Value", "dtype": "bool"},
            "timestamp": {"_type": "Value", "dtype": "string"},
        },
    )
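
# Hedged usage sketch: one way an app handler could log a generation through the
# scheduler above. The helper name `log_generation` and its arguments are
# illustrative, not part of the original app; the keys follow the schema in
# `load_scheduler`, and `push_to_hub` fills any missing fields (e.g. `upvote`)
# with None before committing.
def log_generation(scheduler, session_id, prompt, output, steering, coeff, top_p, temperature):
    from datetime import datetime, timezone
    scheduler.append({
        "session_id": session_id,
        "prompt": prompt,
        "steering": steering,
        "coeff": coeff,
        "top_p": top_p,
        "temperature": temperature,
        "output": output,
        "timestamp": datetime.now(timezone.utc).isoformat(),
    })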
class ParquetScheduler(CommitScheduler):
    """
    Commit scheduler that uploads appended rows to the Hub as Parquet files.

    Reference: https://huggingface.co/spaces/Wauplin/space_to_dataset_saver

    Usage:
        Configure the scheduler with a repo id. Once started, you can add data to be
        uploaded to the Hub. Each `.append` call results in one row in the final dataset.

    List of possible dtypes:
    https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Value

    ```py
    # Start scheduler
    >>> scheduler = ParquetScheduler(
    ...     repo_id="my-parquet-dataset",
    ...     schema={
    ...         "prompt": {"_type": "Value", "dtype": "string"},
    ...         "negative_prompt": {"_type": "Value", "dtype": "string"},
    ...         "guidance_scale": {"_type": "Value", "dtype": "int64"},
    ...         "image": {"_type": "Image"},
    ...     },
    ... )

    # Append some data to be uploaded
    >>> scheduler.append({...})
    ```
    """
    def __init__(
        self,
        *,
        repo_id: str,
        schema: Dict[str, Dict[str, str]],
        every: Union[int, float] = 5,  # number of minutes between commits
        path_in_repo: Optional[str] = "data",
        repo_type: Optional[str] = "dataset",
        revision: Optional[str] = None,
        private: bool = False,
        token: Optional[str] = None,
        allow_patterns: Union[List[str], str, None] = None,
        ignore_patterns: Union[List[str], str, None] = None,
        squash_history: bool = False,
        hf_api: Optional[HfApi] = None,
    ) -> None:
        super().__init__(
            repo_id=repo_id,
            folder_path="dummy",  # required by the base class but unused: `push_to_hub` is overridden
            every=every,
            path_in_repo=path_in_repo,
            repo_type=repo_type,
            revision=revision,
            private=private,
            token=token,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            squash_history=squash_history,
            hf_api=hf_api,
        )
        self._rows: List[Dict[str, Any]] = []
        self._schema = schema
    def append(self, row: Dict[str, Any]) -> None:
        """Add a new row to be uploaded with the next scheduled commit."""
        with self.lock:  # the lock is shared with the background commit thread
            self._rows.append(row)
    def push_to_hub(self):
        # Move pending rows out of the shared buffer under the lock so that
        # concurrent `append` calls are not lost
        with self.lock:
            rows = self._rows
            self._rows = []

        if not rows:
            return
        logger.info("Got %d item(s) to commit.", len(rows))

        # Complete rows: fill any missing schema fields with None
        for row in rows:
            for feature in self._schema:
                if feature not in row:
                    row[feature] = None

        # Export items to Arrow format
        table = pa.Table.from_pylist(rows)

        # Add metadata (used by the `datasets` library to recover feature types)
        table = table.replace_schema_metadata(
            {"huggingface": json.dumps({"info": {"features": self._schema}})}
        )

        # Write to a temporary parquet file
        archive_file = tempfile.NamedTemporaryFile()
        pq.write_table(table, archive_file.name)
        # Upload as a new shard; each commit adds one parquet file to the repo
        self.api.upload_file(
            repo_id=self.repo_id,
            repo_type=self.repo_type,
            revision=self.revision,
            # Place shards under the folder configured via `path_in_repo`
            # ("data" by default)
            path_in_repo=f"{self.path_in_repo}/{uuid.uuid4()}.parquet",
            path_or_fileobj=archive_file.name,
        )
        logger.info("Commit completed.")

        # Cleanup: closing the NamedTemporaryFile deletes it
        archive_file.close()
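
# Minimal end-to-end sketch of the scheduler lifecycle. Assumes a valid Hugging
# Face token is available in the environment and write access to the
# (hypothetical) repo id below; creating the scheduler creates the dataset repo
# if it does not exist.
if __name__ == "__main__":
    scheduler = ParquetScheduler(
        repo_id="my-username/demo-logs",  # hypothetical repo id
        every=10,
        schema={
            "prompt": {"_type": "Value", "dtype": "string"},
            "output": {"_type": "Value", "dtype": "string"},
        },
    )
    scheduler.append({"prompt": "Hello", "output": "Hi there!"})
    # `trigger()` (from CommitScheduler) forces an immediate commit instead of
    # waiting for the timer; `.result()` blocks until it finishes.
    scheduler.trigger().result()
    scheduler.stop()
    # The shards can then be read back with feature types intact, e.g.
    # datasets.load_dataset("my-username/demo-logs").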