from abc import abstractmethod
from pprint import pformat
from time import sleep
from typing import List, Tuple, Optional, Union, Generator

from datasets import (
    Dataset,
    DatasetDict,
    DatasetInfo,
    concatenate_datasets,
    load_dataset,
)

# Default values for retrying dataset download
DEFAULT_NUMBER_OF_RETRIES_ALLOWED = 5
DEFAULT_WAIT_SECONDS_BEFORE_RETRY = 5

# Default ratio for creating missing val/test splits
TEST_OR_VAL_SPLIT_RATIO = 0.1

class SummInstance:
    """
    Basic instance for summarization tasks
    """

    def __init__(
        self, source: Union[List[str], str], summary: str, query: Optional[str] = None
    ):
        """
        Create a summarization instance

        :rtype: object
        :param source: either `List[str]` or `str`, depending on the dataset itself; string joining may be
            needed to fit into specific models. For example, the same document could be simply a `str`, or
            a `List[str]` of the sentences in that document
        :param summary: a string summary that serves as ground truth
        :param query: Optional, applies when a string query is present
        """
        self.source = source
        self.summary = summary
        self.query = query

    def __repr__(self):
        instance_dict = {"source": self.source, "summary": self.summary}
        if self.query:
            instance_dict["query"] = self.query

        return str(instance_dict)

    def __str__(self):
        instance_dict = {"source": self.source, "summary": self.summary}
        if self.query:
            instance_dict["query"] = self.query

        return pformat(instance_dict, indent=1)
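
# A minimal usage sketch (illustrative only, not part of the library API):
#
#     instance = SummInstance(
#         source=["The first sentence.", "The second sentence."],
#         summary="A short summary.",
#         query="What happened?",  # optional, for query-based datasets
#     )
#     print(instance)  # pprint-formatted dict with source/summary/query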


class SummDataset:
    """
    Dataset class for summarization, which takes into account the following tasks:

    * Single-document summarization
    * Multi-document/Dialogue summarization
    * Query-based summarization
    """

    def __init__(
        self,
        dataset_args: Optional[Tuple[str, ...]] = None,
        splitseed: Optional[int] = None,
    ):
        """Create dataset information from the huggingface Dataset class

        :rtype: object
        :param dataset_args: a tuple containing arguments to be passed on to the '_load_dataset_safe' method.
            Only required for datasets loaded from the Huggingface library.
            The arguments differ for each dataset and comprise one or more strings
        :param splitseed: a number to instantiate the random generator used to generate val/test splits
            for the datasets without them
        """

        # Load dataset from huggingface, use default huggingface arguments
        if self.huggingface_dataset:
            dataset = self._load_dataset_safe(*dataset_args)
        # Load non-huggingface dataset, use custom dataset builder
        else:
            dataset = self._load_dataset_safe(path=self.builder_script_path)

        info_set = self._get_dataset_info(dataset)

        # Standardise any 'val' or 'dev' split to a 'validation' split
        if "val" in dataset:
            dataset["validation"] = dataset.pop("val")
        elif "dev" in dataset:
            dataset["validation"] = dataset.pop("dev")

        # If the validation or test split is missing, generate it
        assert (
            "train" in dataset or "validation" in dataset or "test" in dataset
        ), "At least one of the train/validation/test splits must be present!"
        if not ("validation" in dataset and "test" in dataset):
            dataset = self._generate_missing_val_test_splits(dataset, splitseed)

        self.description = info_set.description
        self.citation = info_set.citation
        self.homepage = info_set.homepage

        # Process each split into SummInstance objects; splits that are absent
        # or None are kept as None so the accessors can report them as missing
        self._train_set = (
            self._process_data(dataset["train"])
            if dataset.get("train") is not None
            else None
        )
        self._validation_set = (
            self._process_data(dataset["validation"])
            if dataset.get("validation") is not None
            else None
        )
        self._test_set = (
            self._process_data(dataset["test"])
            if dataset.get("test") is not None
            else None
        )
    @property
    def train_set(self) -> Union[Generator[SummInstance, None, None], List]:
        if self._train_set is not None:
            return self._train_set
        else:
            print(
                f"{self.dataset_name} does not contain a train set, empty list returned"
            )
            return list()

    @property
    def validation_set(self) -> Union[Generator[SummInstance, None, None], List]:
        if self._validation_set is not None:
            return self._validation_set
        else:
            print(
                f"{self.dataset_name} does not contain a validation set, empty list returned"
            )
            return list()

    @property
    def test_set(self) -> Union[Generator[SummInstance, None, None], List]:
        if self._test_set is not None:
            return self._test_set
        else:
            print(
                f"{self.dataset_name} does not contain a test set, empty list returned"
            )
            return list()
    def _load_dataset_safe(self, *args, **kwargs) -> Dataset:
        """
        A wrapper around the huggingface 'load_dataset()' function for a more robust download.
        The original 'load_dataset()' function occasionally fails when it cannot reach a server,
        especially after multiple requests. This method tackles that problem by attempting the
        download multiple times, waiting before each retry.

        The wrapper passes all arguments and keyword arguments to the 'load_dataset' function
        with no alteration.

        :rtype: Dataset
        :param args: non-keyword arguments to be passed on to the 'load_dataset' function
        :param kwargs: keyword arguments to be passed on to the 'load_dataset' function
        """

        tries = DEFAULT_NUMBER_OF_RETRIES_ALLOWED
        wait_time = DEFAULT_WAIT_SECONDS_BEFORE_RETRY

        for i in range(tries):
            try:
                dataset = load_dataset(*args, **kwargs)
            except ConnectionError:
                if i < tries - 1:  # i is zero indexed
                    sleep(wait_time)
                    continue
                else:
                    raise RuntimeError(
                        "Wait for a minute and attempt downloading the dataset again. "
                        "The server hosting the dataset occasionally times out."
                    )
            break

        return dataset
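
    # Example call (dataset name chosen for illustration):
    #
    #     dataset = self._load_dataset_safe("cnn_dailymail", "3.0.0")
    #
    # behaves exactly like load_dataset("cnn_dailymail", "3.0.0"), but retries
    # up to DEFAULT_NUMBER_OF_RETRIES_ALLOWED times on ConnectionError, sleeping
    # DEFAULT_WAIT_SECONDS_BEFORE_RETRY seconds between attempts.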
    def _get_dataset_info(self, data_dict: DatasetDict) -> DatasetInfo:
        """
        Get the information set from the dataset.
        The information set contains: dataset name, description, version, citation and licence

        :param data_dict: DatasetDict
        :rtype: DatasetInfo
        """
        return data_dict["train"].info
    @abstractmethod
    def _process_data(self, dataset: Dataset) -> Generator[SummInstance, None, None]:
        """
        Abstract method to process the data contained within each dataset.
        Each dataset class processes its own information differently due to the diversity in domains.
        This method processes the data contained in the dataset
        and puts each data instance into a SummInstance object,
        which has the following properties: [source, summary, query (optional)]

        :param dataset: a train/validation/test dataset
        :rtype: a generator yielding SummInstance objects
        """
        return
    def _generate_missing_val_test_splits(
        self, dataset_dict: DatasetDict, seed: Optional[int]
    ) -> DatasetDict:
        """
        Create the train, val and test splits from a dataset that lacks them.
        The generated sets are 'train: ~.80', 'validation: ~.10', and 'test: ~.10' in size.
        The splits are randomized for each object unless a seed is provided for the random generator.

        :param dataset_dict: Arrow DatasetDict, usually containing only the train set
        :param seed: seed for the random generator to shuffle the dataset
        :rtype: Arrow DatasetDict containing the three splits
        """

        # Return dataset if no train set is available for splitting
        if "train" not in dataset_dict:
            if "validation" not in dataset_dict:
                dataset_dict["validation"] = None
            if "test" not in dataset_dict:
                dataset_dict["test"] = None

            return dataset_dict

        # Create a 'test' split from 'train' if no 'test' set is available
        if "test" not in dataset_dict:
            dataset_traintest_split = dataset_dict["train"].train_test_split(
                test_size=TEST_OR_VAL_SPLIT_RATIO, seed=seed
            )
            dataset_dict["train"] = dataset_traintest_split["train"]
            dataset_dict["test"] = dataset_traintest_split["test"]

        # Create a 'validation' split from the remaining 'train' set if no 'validation' set is available
        if "validation" not in dataset_dict:
            dataset_trainval_split = dataset_dict["train"].train_test_split(
                test_size=TEST_OR_VAL_SPLIT_RATIO, seed=seed
            )
            dataset_dict["train"] = dataset_trainval_split["train"]
            dataset_dict["validation"] = dataset_trainval_split["test"]

        return dataset_dict
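
    # Split arithmetic: with TEST_OR_VAL_SPLIT_RATIO = 0.1 applied twice, a
    # train-only dataset ends up as test = 10%, validation = 9% (10% of the
    # remaining 90%), and train = 81% of the original examples, i.e. the
    # ~.80/~.10/~.10 proportions described in the docstring.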
    def _concatenate_dataset_dicts(
        self, dataset_dicts: List[DatasetDict]
    ) -> DatasetDict:
        """
        Concatenate dataset dicts with matching splits and columns into one.

        :param dataset_dicts: A list of DatasetDicts
        :rtype: DatasetDict containing the combined data
        """

        # Ensure all dataset dicts have the same splits (order-insensitive)
        setsofsplits = set(
            frozenset(dataset_dict.keys()) for dataset_dict in dataset_dicts
        )
        if len(setsofsplits) > 1:
            raise ValueError("Splits must match for all datasets")

        # Concatenate all datasets into one according to the splits
        temp_dict = {}
        for split in setsofsplits.pop():
            split_set = [dataset_dict[split] for dataset_dict in dataset_dicts]
            temp_dict[split] = concatenate_datasets(split_set)

        return DatasetDict(temp_dict)
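
    # Example (hypothetical variables): combining two already-loaded dataset
    # dicts that share the same splits, e.g. {"train", "validation", "test"}:
    #
    #     combined = self._concatenate_dataset_dicts([dataset_dict_a, dataset_dict_b])
    #
    # Each split of `combined` then holds the rows of both inputs, in order.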
    @classmethod
    def generate_basic_description(cls) -> str:
        """
        Automatically generate the basic description string based on the class attributes

        :rtype: string containing the description
        :param cls: class object
        """

        basic_description = (
            f": {cls.dataset_name} is a "
            f"{'query-based ' if cls.is_query_based else ''}"
            f"{'dialogue ' if cls.is_dialogue_based else ''}"
            f"{'multi-document' if cls.is_multi_document else 'single-document'} "
            f"summarization dataset."
        )

        return basic_description
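
    # For a hypothetical subclass with dataset_name = "SAMSum",
    # is_query_based = False, is_dialogue_based = True and
    # is_multi_document = False, this produces:
    #
    #     ": SAMSum is a dialogue single-document summarization dataset."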
    def show_description(self):
        """
        Print the description of the dataset.
        """
        print(self.dataset_name, ":\n", self.description)
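

if __name__ == "__main__":
    # A minimal sketch of a concrete subclass, shown here as a usage example.
    # The class attributes and the CNN/DailyMail column names ("article",
    # "highlights") are assumptions for illustration, not definitions made by
    # this module; subclasses are expected to set the attributes read by
    # __init__ and generate_basic_description, and to implement _process_data.
    class ExampleCnnDmDataset(SummDataset):
        dataset_name = "CNN/DailyMail"
        huggingface_dataset = True
        is_query_based = False
        is_dialogue_based = False
        is_multi_document = False

        def __init__(self):
            super().__init__(dataset_args=("cnn_dailymail", "3.0.0"))

        def _process_data(self, data: Dataset) -> Generator[SummInstance, None, None]:
            # Wrap each row's article/summary pair in a SummInstance
            for instance in data:
                yield SummInstance(
                    source=instance["article"], summary=instance["highlights"]
                )

    cnn_dm = ExampleCnnDmDataset()
    print(cnn_dm.generate_basic_description())
    for instance in cnn_dm.train_set:
        print(instance)
        break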