import unittest

from model.base_model import SummModel
from model import SUPPORTED_SUMM_MODELS
from pipeline import assemble_model_pipeline
from evaluation.base_metric import SummMetric
from evaluation import SUPPORTED_EVALUATION_METRICS
from dataset.st_dataset import SummInstance, SummDataset
from dataset import SUPPORTED_SUMM_DATASETS
from dataset.dataset_loaders import ScisummnetDataset, ArxivDataset

from helpers import print_with_color, retrieve_random_test_instances

import random
import time
from typing import List, Union, Tuple
import sys
import re
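
# NOTE (assumption): `model`, `dataset`, `evaluation`, `pipeline`, and `helpers` are
# SummerTime's own top-level packages, so this test presumably has to be run from the
# repository root (or with that directory on PYTHONPATH) for these imports to resolve.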

class IntegrationTests(unittest.TestCase):
    def get_prediction(
        self, model: SummModel, dataset: SummDataset, test_instances: List[SummInstance]
    ) -> Tuple[Union[List[str], List[List[str]]], Union[List[str], List[List[str]]]]:
        """
        Get summary predictions for the given model and dataset instances.

        :param SummModel `model`: Model for the summarization task.
        :param SummDataset `dataset`: Dataset for the summarization task.
        :param List[SummInstance] `test_instances`: Instances from `dataset` to summarize.
        :returns: Tuple containing the list of summary predictions and the list of targets
            corresponding to each instance in `test_instances`.
        """
        src = (
            [ins.source[0] for ins in test_instances]
            if isinstance(dataset, ScisummnetDataset)
            else [ins.source for ins in test_instances]
        )
        tgt = [ins.summary for ins in test_instances]
        query = (
            [ins.query for ins in test_instances] if dataset.is_query_based else None
        )
        prediction = model.summarize(src, query)
        return prediction, tgt

    def get_eval_dict(self, metric: SummMetric, prediction: List[str], tgt: List[str]):
        """
        Run an evaluation metric on summary predictions.

        :param SummMetric `metric`: Evaluation metric.
        :param List[str] `prediction`: Summary prediction instances.
        :param List[str] `tgt`: Target summary instances from the dataset.
        """
        score_dict = metric.evaluate(prediction, tgt)
        return score_dict
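
    # test_all walks every supported dataset, assembles the model pipelines compatible
    # with it, summarizes one randomly drawn training instance per pipeline, and scores
    # the output with every initialized evaluation metric.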
    def test_all(self):
        """
        Runs an integration test on all compatible dataset + model + evaluation metric
        pipelines supported by SummerTime.
        """
        # "35" / "32" below are presumably ANSI color codes (magenta / green).
        print_with_color("\nInitializing all evaluation metrics...", "35")
        evaluation_metrics = []
        for eval_cls in SUPPORTED_EVALUATION_METRICS:
            # TODO: Temporarily skipping Rouge/RougeWE metrics to avoid a local bug.
            # if eval_cls in [Rouge, RougeWe]:
            #     continue
            print(eval_cls)
            evaluation_metrics.append(eval_cls())

        print_with_color("\n\nBeginning integration tests...", "35")
        for dataset_cls in SUPPORTED_SUMM_DATASETS:
            # TODO: Temporarily skipping Arxiv (size/time constraints).
            if dataset_cls in [ArxivDataset]:
                continue
            dataset = dataset_cls()
            if dataset.train_set is not None:
                dataset_instances = list(dataset.train_set)
                print(
                    f"\n{dataset.dataset_name} has a training set of {len(dataset_instances)} examples"
                )
                print_with_color(
                    f"Initializing all matching model pipelines for {dataset.dataset_name} dataset...",
                    "35",
                )
                # matching_model_instances = assemble_model_pipeline(dataset_cls, list(filter(lambda m: m != PegasusModel, SUPPORTED_SUMM_MODELS)))
                matching_model_instances = assemble_model_pipeline(
                    dataset_cls, SUPPORTED_SUMM_MODELS
                )
                for model, model_name in matching_model_instances:
                    test_instances = retrieve_random_test_instances(
                        dataset_instances=dataset_instances, num_instances=1
                    )
                    print_with_color(
                        f"{'#' * 20} Testing: {dataset.dataset_name} dataset, {model_name} model {'#' * 20}",
                        "35",
                    )
                    prediction, tgt = self.get_prediction(
                        model, dataset, test_instances
                    )
                    print(f"Prediction: {prediction}\nTarget: {tgt}\n")
                    for metric in evaluation_metrics:
                        print_with_color(f"{metric.metric_name} metric", "35")
                        score_dict = self.get_eval_dict(metric, prediction, tgt)
                        print(score_dict)
                    print_with_color(
                        f"{'#' * 20} Test for {dataset.dataset_name} dataset, {model_name} model COMPLETE {'#' * 20}\n\n",
                        "32",
                    )
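
# Example invocation (sketch; assumes this file lives at tests/integration_test.py, as
# the usage message below suggests):
#   python tests/integration_test.py        # seed taken from the current time
#   python tests/integration_test.py 42     # fixed seed for reproducible instance sampling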

if __name__ == "__main__":
    if len(sys.argv) > 2 or (
        len(sys.argv) == 2 and not re.match(r"^\d+$", sys.argv[1])
    ):
        print("Usage: python tests/integration_test.py [seed]", file=sys.stderr)
        sys.exit(1)
    # Consume the optional seed argument before unittest.main() parses sys.argv.
    seed = int(time.time()) if len(sys.argv) == 1 else int(sys.argv.pop())
    random.seed(seed)
    print_with_color(f"(to reproduce) random seeded with {seed}\n", "32")
    unittest.main()