import enum
import copy
from functools import reduce
from typing import Dict, List, Tuple

import numpy as np

from utils.common.log import logger
from ..datasets.ab_dataset import ABDataset
from ..dataloader import FastDataLoader, InfiniteDataLoader, build_dataloader
from data import get_dataset

class DatasetMetaInfo:
    def __init__(self, name,
                 known_classes_name_idx_map, unknown_class_idx):

        assert unknown_class_idx not in known_classes_name_idx_map.keys()

        self.name = name
        self.unknown_class_idx = unknown_class_idx
        self.known_classes_name_idx_map = known_classes_name_idx_map

    def num_classes(self):
        # all known classes plus the single "unknown" class
        return len(self.known_classes_name_idx_map) + 1
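
# Usage sketch (illustrative values only, not taken from any real dataset config):
#   meta = DatasetMetaInfo('CIFAR10-partial', known_classes_name_idx_map={'cat': 0, 'dog': 1}, unknown_class_idx=2)
#   meta.num_classes()  # -> 3 (2 known classes + 1 unknown class)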

class MergedDataset:
    def __init__(self, datasets: List[ABDataset]):
        self.datasets = datasets
        self.datasets_len = [len(i) for i in self.datasets]
        logger.info(f'create MergedDataset: len of datasets {self.datasets_len}')
        self.datasets_cum_len = np.cumsum(self.datasets_len)

    def __getitem__(self, idx):
        # map a global index to (dataset index, local index) via the cumulative lengths
        for i, cum_len in enumerate(self.datasets_cum_len):
            if idx < cum_len:
                return self.datasets[i][idx - sum(self.datasets_len[0: i])]

    def __len__(self):
        return sum(self.datasets_len)
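
# Minimal sketch of the global-to-local index mapping (plain lists stand in for
# ABDataset here, since MergedDataset only relies on __len__ and __getitem__):
#   merged = MergedDataset([[('a', 0), ('b', 1)], [('c', 2)]])
#   len(merged)  # -> 3
#   merged[2]    # -> ('c', 2): global index 2 falls into the second dataset at local index 0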

class IndexReturnedDataset:
    def __init__(self, dataset: ABDataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        res = self.dataset[idx]

        if isinstance(res, (tuple, list)):
            return (*res, idx)
        else:
            return res, idx

    def __len__(self):
        return len(self.dataset)
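
# Sketch: the wrapper appends the sample index to whatever the wrapped dataset
# returns (assuming a hypothetical `base_dataset` whose items are (x, y) pairs):
#   ds = IndexReturnedDataset(base_dataset)
#   x, y, idx = ds[10]  # base_dataset[10] plus the index 10 itself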

# class Scenario:
#     def __init__(self, config,
#                  source_datasets_meta_info: Dict[str, DatasetMetaInfo], target_datasets_meta_info: Dict[str, DatasetMetaInfo],
#                  target_source_map: Dict[str, Dict[str, str]],
#                  target_domains_order: List[str],
#                  source_datasets: Dict[str, Dict[str, ABDataset]], target_datasets: Dict[str, Dict[str, ABDataset]]):
#
#         self.__config = config
#         self.__source_datasets_meta_info = source_datasets_meta_info
#         self.__target_datasets_meta_info = target_datasets_meta_info
#         self.__target_source_map = target_source_map
#         self.__target_domains_order = target_domains_order
#         self.__source_datasets = source_datasets
#         self.__target_datasets = target_datasets
#
#     # 1. basic
#     def get_config(self):
#         return copy.deepcopy(self.__config)
#
#     def get_task_type(self):
#         return list(self.__source_datasets.values())[0]['train'].task_type
#
#     def get_num_classes(self):
#         known_classes_idx = []
#         unknown_classes_idx = []
#         for v in self.__source_datasets_meta_info.values():
#             known_classes_idx += list(v.known_classes_name_idx_map.values())
#             unknown_classes_idx += [v.unknown_class_idx]
#         for v in self.__target_datasets_meta_info.values():
#             known_classes_idx += list(v.known_classes_name_idx_map.values())
#             unknown_classes_idx += [v.unknown_class_idx]
#         unknown_classes_idx = [i for i in unknown_classes_idx if i is not None]
#         # print(known_classes_idx, unknown_classes_idx)
#         res = len(set(known_classes_idx)), len(set(unknown_classes_idx)), len(set(known_classes_idx + unknown_classes_idx))
#         # print(res)
#         assert res[0] + res[1] == res[2]
#         return res
#
#     def build_dataloader(self, dataset: ABDataset, batch_size: int, num_workers: int, infinite: bool, shuffle_when_finite: bool):
#         if infinite:
#             dataloader = InfiniteDataLoader(
#                 dataset, None, batch_size, num_workers=num_workers)
#         else:
#             dataloader = FastDataLoader(
#                 dataset, batch_size, num_workers, shuffle=shuffle_when_finite)
#         return dataloader
#
#     def build_sub_dataset(self, dataset: ABDataset, indexes: List[int]):
#         from ..data.datasets.dataset_split import _SplitDataset
#         dataset.dataset = _SplitDataset(dataset.dataset, indexes)
#         return dataset
#
#     def build_index_returned_dataset(self, dataset: ABDataset):
#         return IndexReturnedDataset(dataset)
#
#     # 2. source
#     def get_source_datasets_meta_info(self):
#         return self.__source_datasets_meta_info
#
#     def get_source_datasets_name(self):
#         return list(self.__source_datasets.keys())
#
#     def get_merged_source_dataset(self, split):
#         source_train_datasets = {n: d[split] for n, d in self.__source_datasets.items()}
#         return MergedDataset(list(source_train_datasets.values()))
#
#     def get_source_datasets(self, split):
#         source_train_datasets = {n: d[split] for n, d in self.__source_datasets.items()}
#         return source_train_datasets
#
#     # 3. target **domain**
#     # (do we need such API `get_ith_target_domain()`?)
#     def get_target_domains_meta_info(self):
#         return self.__source_datasets_meta_info
#
#     def get_target_domains_order(self):
#         return self.__target_domains_order
#
#     def get_corr_source_datasets_name_of_target_domain(self, target_domain_name):
#         return self.__target_source_map[target_domain_name]
#
#     def get_limited_target_train_dataset(self):
#         if len(self.__target_domains_order) > 1:
#             raise RuntimeError('this API is only for pass-in scenario in user-defined online DA algorithm')
#         return list(self.__target_datasets.values())[0]['train']
#
#     def get_target_domains_iterator(self, split):
#         for target_domain_index, target_domain_name in enumerate(self.__target_domains_order):
#             target_dataset = self.__target_datasets[target_domain_name]
#             target_domain_meta_info = self.__target_datasets_meta_info[target_domain_name]
#             yield target_domain_index, target_domain_name, target_dataset[split], target_domain_meta_info
#
#     # 4. permission management
#     def get_sub_scenario(self, source_datasets_name, source_splits, target_domains_order, target_splits):
#         def get_split(dataset, splits):
#             res = {}
#             for s, d in dataset.items():
#                 if s in splits:
#                     res[s] = d
#             return res
#
#         return Scenario(
#             config=self.__config,
#             source_datasets_meta_info={k: v for k, v in self.__source_datasets_meta_info.items() if k in source_datasets_name},
#             target_datasets_meta_info={k: v for k, v in self.__target_datasets_meta_info.items() if k in target_domains_order},
#             target_source_map={k: v for k, v in self.__target_source_map.items() if k in target_domains_order},
#             target_domains_order=target_domains_order,
#             source_datasets={k: get_split(v, source_splits) for k, v in self.__source_datasets.items() if k in source_datasets_name},
#             target_datasets={k: get_split(v, target_splits) for k, v in self.__target_datasets.items() if k in target_domains_order}
#         )
#
#     def get_only_source_sub_scenario_for_exp_tracker(self):
#         return self.get_sub_scenario(self.get_source_datasets_name(), ['train', 'val', 'test'], [], [])
#
#     def get_only_source_sub_scenario_for_alg(self):
#         return self.get_sub_scenario(self.get_source_datasets_name(), ['train'], [], [])
#
#     def get_one_da_sub_scenario_for_alg(self, target_domain_name):
#         return self.get_sub_scenario(self.get_corr_source_datasets_name_of_target_domain(target_domain_name),
#                                      ['train', 'val'], [target_domain_name], ['train'])

# class Scenario:
#     def __init__(self, config,
#                  offline_source_datasets_meta_info: Dict[str, DatasetMetaInfo],
#                  offline_source_datasets: Dict[str, ABDataset],
#                  online_datasets_meta_info: List[Tuple[Dict[str, DatasetMetaInfo], DatasetMetaInfo]],
#                  online_datasets: Dict[str, ABDataset],
#                  target_domains_order: List[str],
#                  target_source_map: Dict[str, Dict[str, str]],
#                  num_classes: int):
#
#         self.config = config
#         self.offline_source_datasets_meta_info = offline_source_datasets_meta_info
#         self.offline_source_datasets = offline_source_datasets
#         self.online_datasets_meta_info = online_datasets_meta_info
#         self.online_datasets = online_datasets
#         self.target_domains_order = target_domains_order
#         self.target_source_map = target_source_map
#         self.num_classes = num_classes
#
#     def get_offline_source_datasets(self, split):
#         return {n: d[split] for n, d in self.offline_source_datasets.items()}
#
#     def get_offline_source_merged_dataset(self, split):
#         return MergedDataset([d[split] for d in self.offline_source_datasets.values()])
#
#     def get_online_current_corresponding_source_datasets(self, domain_index, split):
#         cur_target_domain_name = self.target_domains_order[domain_index]
#         cur_source_datasets_name = list(self.target_source_map[cur_target_domain_name].keys())
#         cur_source_datasets = {n: self.online_datasets[n + '|' + cur_target_domain_name][split] for n in cur_source_datasets_name}
#         return cur_source_datasets
#
#     def get_online_current_corresponding_merged_source_dataset(self, domain_index, split):
#         cur_target_domain_name = self.target_domains_order[domain_index]
#         cur_source_datasets_name = list(self.target_source_map[cur_target_domain_name].keys())
#         cur_source_datasets = {n: self.online_datasets[n + '|' + cur_target_domain_name][split] for n in cur_source_datasets_name}
#         return MergedDataset([d for d in cur_source_datasets.values()])
#
#     def get_online_current_target_dataset(self, domain_index, split):
#         cur_target_domain_name = self.target_domains_order[domain_index]
#         return self.online_datasets[cur_target_domain_name][split]
#
#     def build_dataloader(self, dataset: ABDataset, batch_size: int, num_workers: int,
#                          infinite: bool, shuffle_when_finite: bool, to_iterator: bool):
#         if infinite:
#             dataloader = InfiniteDataLoader(
#                 dataset, None, batch_size, num_workers=num_workers)
#         else:
#             dataloader = FastDataLoader(
#                 dataset, batch_size, num_workers, shuffle=shuffle_when_finite)
#         if to_iterator:
#             dataloader = iter(dataloader)
#         return dataloader
#
#     def build_sub_dataset(self, dataset: ABDataset, indexes: List[int]):
#         from data.datasets.dataset_split import _SplitDataset
#         dataset.dataset = _SplitDataset(dataset.dataset, indexes)
#         return dataset
#
#     def build_index_returned_dataset(self, dataset: ABDataset):
#         return IndexReturnedDataset(dataset)
#
#     def get_config(self):
#         return copy.deepcopy(self.config)
#
#     def get_task_type(self):
#         return list(self.online_datasets.values())[0]['train'].task_type
#
#     def get_num_classes(self):
#         return self.num_classes

class Scenario:
    def __init__(self, config, all_datasets_ignore_classes_map, all_datasets_idx_map, target_domains_order, target_source_map,
                 all_datasets_e2e_class_to_idx_map,
                 num_classes):

        self.config = config
        self.all_datasets_ignore_classes_map = all_datasets_ignore_classes_map
        self.all_datasets_idx_map = all_datasets_idx_map
        self.target_domains_order = target_domains_order
        self.target_source_map = target_source_map
        self.all_datasets_e2e_class_to_idx_map = all_datasets_e2e_class_to_idx_map
        self.num_classes = num_classes

        self.cur_domain_index = 0

        logger.info(f'[scenario build] # classes: {num_classes}')
        logger.debug(f'[scenario build] idx map: {all_datasets_idx_map}')

    def to_json(self):
        return dict(
            config=self.config, all_datasets_ignore_classes_map=self.all_datasets_ignore_classes_map,
            all_datasets_idx_map=self.all_datasets_idx_map, target_domains_order=self.target_domains_order,
            target_source_map=self.target_source_map,
            all_datasets_e2e_class_to_idx_map=self.all_datasets_e2e_class_to_idx_map,
            num_classes=self.num_classes
        )

    def __str__(self):
        return f'Scenario({self.to_json()})'

    def get_offline_datasets(self, transform=None):
        # make source datasets which contain all the unioned classes
        res_offline_train_source_datasets_map = {}

        from .. import get_dataset
        data_dirs = self.config['data_dirs']
        source_datasets_name = self.config['source_datasets_name']

        res_source_datasets_map = {d: {split: get_dataset(d.split('|')[0], data_dirs[d.split('|')[0]], split,
                                                          transform,
                                                          self.all_datasets_ignore_classes_map[d], self.all_datasets_idx_map[d])
                                       for split in ['train', 'val', 'test']}
                                   for d in self.all_datasets_ignore_classes_map.keys() if d.split('|')[0] in source_datasets_name}

        for source_dataset_name in self.config['source_datasets_name']:
            source_datasets = [v for k, v in res_source_datasets_map.items() if source_dataset_name in k]

            # how to merge idx map?
            # 35 79 97
            idx_maps = [d['train'].idx_map for d in source_datasets]
            ignore_classes_list = [d['train'].ignore_classes for d in source_datasets]

            # the same raw class must map to the same global index in every per-target view
            union_idx_map = {}
            for idx_map in idx_maps:
                for k, v in idx_map.items():
                    if k not in union_idx_map:
                        union_idx_map[k] = v
                    else:
                        assert union_idx_map[k] == v

            # a class is ignored in the union only if every per-target view ignores it
            union_ignore_classes = reduce(lambda res, cur: res & set(cur), ignore_classes_list, set(ignore_classes_list[0]))
            assert len(union_ignore_classes) + len(union_idx_map) == len(source_datasets[0]['train'].raw_classes)

            logger.info(f'[scenario build] {source_dataset_name} has {len(union_idx_map)} classes in offline training')

            d = source_dataset_name
            res_offline_train_source_datasets_map[d] = {split: get_dataset(d, data_dirs[d], split,
                                                                           transform,
                                                                           union_ignore_classes, union_idx_map)
                                                        for split in ['train', 'val', 'test']}

        return res_offline_train_source_datasets_map
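
    # Sketch of the idx-map merge above (hypothetical per-target views of one source dataset):
    #   view A: idx_map = {'cat': 0, 'dog': 1},  ignore_classes = ['bird']
    #   view B: idx_map = {'cat': 0, 'bird': 2}, ignore_classes = ['dog']
    #   union:  idx_map = {'cat': 0, 'dog': 1, 'bird': 2}, ignore_classes = {} (a class stays
    #   ignored only if every view ignores it), and len(ignore) + len(idx_map) == 3 raw classes.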

    def get_offline_datasets_args(self):
        # same per-source class union as get_offline_datasets(), but return the
        # get_dataset() keyword arguments instead of the built datasets
        res_offline_train_source_datasets_map = {}

        from .. import get_dataset
        data_dirs = self.config['data_dirs']
        source_datasets_name = self.config['source_datasets_name']

        res_source_datasets_map = {d: {split: get_dataset(d.split('|')[0], data_dirs[d.split('|')[0]], split,
                                                          None,
                                                          self.all_datasets_ignore_classes_map[d], self.all_datasets_idx_map[d])
                                       for split in ['train', 'val', 'test']}
                                   for d in self.all_datasets_ignore_classes_map.keys() if d.split('|')[0] in source_datasets_name}

        for source_dataset_name in self.config['source_datasets_name']:
            source_datasets = [v for k, v in res_source_datasets_map.items() if source_dataset_name in k]

            # how to merge idx map?
            # 35 79 97
            idx_maps = [d['train'].idx_map for d in source_datasets]
            ignore_classes_list = [d['train'].ignore_classes for d in source_datasets]

            union_idx_map = {}
            for idx_map in idx_maps:
                for k, v in idx_map.items():
                    if k not in union_idx_map:
                        union_idx_map[k] = v
                    else:
                        assert union_idx_map[k] == v

            union_ignore_classes = reduce(lambda res, cur: res & set(cur), ignore_classes_list, set(ignore_classes_list[0]))
            assert len(union_ignore_classes) + len(union_idx_map) == len(source_datasets[0]['train'].raw_classes)

            logger.info(f'[scenario build] {source_dataset_name} has {len(union_idx_map)} classes in offline training')

            d = source_dataset_name
            res_offline_train_source_datasets_map[d] = {split: dict(dataset_name=d, root_dir=data_dirs[d], split=split,
                                                                    transform=None,
                                                                    ignore_classes=union_ignore_classes, idx_map=union_idx_map)
                                                        for split in ['train', 'val', 'test']}

        return res_offline_train_source_datasets_map

    # for d in source_datasets_name:
    #     source_dataset_with_max_num_classes = None
    #     for ed_name, ed in res_source_datasets_map.items():
    #         if not ed_name.startswith(d):
    #             continue
    #         if source_dataset_with_max_num_classes is None:
    #             source_dataset_with_max_num_classes = ed
    #             res_offline_train_source_datasets_map_names[d] = ed_name
    #         if len(ed['train'].ignore_classes) < len(source_dataset_with_max_num_classes['train'].ignore_classes):
    #             source_dataset_with_max_num_classes = ed
    #             res_offline_train_source_datasets_map_names[d] = ed_name
    #     res_offline_train_source_datasets_map[d] = source_dataset_with_max_num_classes
    # return res_offline_train_source_datasets_map

    def get_online_ith_domain_datasets_args_for_inference(self, domain_index):
        target_dataset_name = self.target_domains_order[domain_index]

        # dataset_name: Any, root_dir: Any, split: Any, transform: Any | None = None, ignore_classes: Any = [], idx_map: Any | None = None
        if 'MM-CityscapesDet' in self.target_domains_order or 'CityscapesDet' in self.target_domains_order or 'BaiduPersonDet' in self.target_domains_order:
            logger.info('use test split for inference (only for Det workloads)')
            split = 'test'
        else:
            split = 'train'

        return dict(dataset_name=target_dataset_name,
                    root_dir=self.config['data_dirs'][target_dataset_name],
                    split=split,
                    transform=None,
                    ignore_classes=self.all_datasets_ignore_classes_map[target_dataset_name],
                    idx_map=self.all_datasets_idx_map[target_dataset_name])

    def get_online_ith_domain_datasets_args_for_training(self, domain_index):
        target_dataset_name = self.target_domains_order[domain_index]
        source_datasets_name = list(self.target_source_map[target_dataset_name].keys())

        res = {}

        # dataset_name: Any, root_dir: Any, split: Any, transform: Any | None = None, ignore_classes: Any = [], idx_map: Any | None = None
        res[target_dataset_name] = {split: dict(dataset_name=target_dataset_name,
                                                root_dir=self.config['data_dirs'][target_dataset_name],
                                                split=split,
                                                transform=None,
                                                ignore_classes=self.all_datasets_ignore_classes_map[target_dataset_name],
                                                idx_map=self.all_datasets_idx_map[target_dataset_name]) for split in ['train', 'val']}

        for d in source_datasets_name:
            res[d] = {split: dict(dataset_name=d,
                                  root_dir=self.config['data_dirs'][d],
                                  split=split,
                                  transform=None,
                                  ignore_classes=self.all_datasets_ignore_classes_map[d + '|' + target_dataset_name],
                                  idx_map=self.all_datasets_idx_map[d + '|' + target_dataset_name]) for split in ['train', 'val']}

        return res

    def get_online_cur_domain_datasets_args_for_inference(self):
        return self.get_online_ith_domain_datasets_args_for_inference(self.cur_domain_index)

    def get_online_cur_domain_datasets_args_for_training(self):
        return self.get_online_ith_domain_datasets_args_for_training(self.cur_domain_index)

    def get_online_cur_domain_datasets_for_training(self, transform=None):
        res = {}
        datasets_args = self.get_online_ith_domain_datasets_args_for_training(self.cur_domain_index)
        for dataset_name, dataset_args in datasets_args.items():
            res[dataset_name] = {}
            for split, args in dataset_args.items():
                if transform is not None:
                    args['transform'] = transform
                dataset = get_dataset(**args)
                res[dataset_name][split] = dataset
        return res

    def get_online_cur_domain_datasets_for_inference(self, transform=None):
        datasets_args = self.get_online_ith_domain_datasets_args_for_inference(self.cur_domain_index)
        if transform is not None:
            datasets_args['transform'] = transform
        return get_dataset(**datasets_args)

    def get_online_cur_domain_samples_for_training(self, num_samples, transform=None, collate_fn=None):
        # take a single batch of `num_samples` samples from the current target domain's train split
        dataset = self.get_online_cur_domain_datasets_for_training(transform=transform)
        dataset = dataset[self.target_domains_order[self.cur_domain_index]]['train']
        return next(iter(build_dataloader(dataset, num_samples, 0, True, None, collate_fn=collate_fn)))[0]

    def next_domain(self):
        self.cur_domain_index += 1
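
# Typical online loop over the target domains (sketch; `scenario` is assumed to be a fully
# built Scenario, and `my_transform` / `my_collate_fn` are user-provided placeholders):
#   for _ in range(len(scenario.target_domains_order)):
#       infer_dataset = scenario.get_online_cur_domain_datasets_for_inference(transform=my_transform)
#       train_datasets = scenario.get_online_cur_domain_datasets_for_training(transform=my_transform)
#       warmup_batch = scenario.get_online_cur_domain_samples_for_training(16, transform=my_transform, collate_fn=my_collate_fn)
#       ...  # run inference / adaptation on the current domain
#       scenario.next_domain()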