Spaces:
Running
Running
| from typing import Dict, List, Optional, Type, Union | |
| from ..datasets.ab_dataset import ABDataset | |
| # from benchmark.data.visualize import visualize_classes_in_object_detection | |
| # from benchmark.scenario.val_domain_shift import get_val_domain_shift_transform | |
| from ..dataset import get_dataset | |
| import copy | |
| from torchvision.transforms import Compose | |
| from .merge_alias import merge_the_same_meaning_classes | |
| from ..datasets.registery import static_dataset_registery | |
# some legacy aliases of variables:
# ignore_classes == discarded classes
# private_classes == unknown classes in open-set / universal DA
def _merge_the_same_meaning_classes(classes_info_of_all_datasets):
    """Merge classes that share the same meaning across datasets.

    Delegates to ``merge_the_same_meaning_classes`` and unpacks its result into
    exactly two values.

    Args:
        classes_info_of_all_datasets: forwarded unchanged to the delegate.

    Returns:
        ``(final_classes_of_all_datasets, rename_map)`` as produced by the delegate.
    """
    merged_classes, renames = merge_the_same_meaning_classes(classes_info_of_all_datasets)
    return merged_classes, renames
def _find_ignore_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode):
    """Compute the classes to discard ("ignore") for one target and the sources paired with it.

    Let ``A`` be the union of all source class lists and ``B`` the target's classes.

    * ``'da'`` (close-set): discard ``A - B`` in the sources and ``B - A`` in the target;
    * ``'partial_da'``: discard only ``B - A`` in the target;
    * ``'open_set_da'``: discard only ``A - B`` in the sources;
    * ``'universal_da'``: discard nothing.

    Args:
        as_classes: one class list per source dataset.
        b_classes: the target dataset's class list.
        da_mode: one of the four keys above (``KeyError`` otherwise).

    Returns:
        ``(as_ignore_classes, b_ignore_classes)``. ``as_ignore_classes[i]`` lists the discarded
        classes of ``as_classes[i]``. Every list is de-duplicated and ordered as the classes first
        appear in the corresponding input list, so the result is deterministic. Returns ``None``
        when ``A`` and ``B`` are both non-empty and share no class.
    """
    thres = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}[da_mode]
    a_set = set().union(*as_classes)
    b_set = set(b_classes)

    if a_set and b_set and not (a_set & b_set):
        return None  # a has no intersection with b

    # Equal / superset / subset / partial-overlap cases all reduce to these two differences;
    # each difference is simply empty in the cases where it does not apply.
    a_ignore = a_set - b_set if thres in (3, 1) else set()
    b_ignore = b_set - a_set if thres in (3, 2) else set()

    # dict.fromkeys de-duplicates while keeping first-seen order; list(set(...)) would make the
    # order depend on string hash randomization and differ between runs.
    as_ignore_classes = [[c for c in dict.fromkeys(classes) if c in a_ignore] for classes in as_classes]
    b_ignore_classes = [c for c in dict.fromkeys(b_classes) if c in b_ignore]
    return as_ignore_classes, b_ignore_classes
def _find_private_classes_when_sources_as_to_target_b(as_classes: List[List[str]], b_classes: List[str], da_mode):
    """Compute the target's private ("unknown") classes: target classes absent from every source.

    Let ``A`` be the union of all source class lists and ``B`` the target's classes. Private
    classes exist only in ``'open_set_da'`` and ``'universal_da'``, where they are ``B - A``. In
    ``'da'`` and ``'partial_da'`` the result is empty.

    Args:
        as_classes: one class list per source dataset.
        b_classes: the target dataset's class list.
        da_mode: ``'da'``, ``'partial_da'``, ``'open_set_da'`` or ``'universal_da'``
            (``KeyError`` otherwise).

    Returns:
        The private classes, de-duplicated and ordered as they first appear in ``b_classes``
        (deterministic). Returns ``None`` when ``A`` and ``B`` are both non-empty and share no class.
    """
    thres = {'da': 3, 'partial_da': 2, 'open_set_da': 1, 'universal_da': 0}[da_mode]
    a_set = set().union(*as_classes)
    b_set = set(b_classes)

    if a_set and b_set and not (a_set & b_set):
        return None  # a has no intersection with b

    if thres not in (1, 0):
        return []

    b_private = b_set - a_set
    # Order-preserving de-duplication keeps the output independent of hash randomization.
    return [c for c in dict.fromkeys(b_classes) if c in b_private]
class _ABDatasetMetaInfo:
    """Plain record of a dataset's static metadata.

    Attributes:
        name: dataset name.
        classes: the dataset's class list.
        task_type: task type as registered for the dataset.
        object_type: object type as registered for the dataset.
        class_aliases: alias information used when merging same-meaning classes.
        shift_type: ``None``, or a mapping from a reference dataset name to the description
            of the distribution shift relative to it.
    """

    def __init__(self, name, classes, task_type, object_type, class_aliases, shift_type):
        self.name, self.classes = name, classes
        self.class_aliases, self.shift_type = class_aliases, shift_type
        self.task_type, self.object_type = task_type, object_type
def _get_dist_shift_type_when_source_a_to_target_b(a: _ABDatasetMetaInfo, b: _ABDatasetMetaInfo):
    """Describe the distribution shift from dataset ``a`` to dataset ``b``.

    * If ``b.shift_type`` is ``None``, return the generic ``'Dataset Shifts'``.
    * If ``a`` is a direct key of ``b.shift_type``, return its shift description.
    * Otherwise, take the first entry of ``b.shift_type`` as an intermediate dataset. Look up
      its metadata in ``static_dataset_registery``, describe ``a`` -> intermediate recursively,
      and append ``' + '`` plus the intermediate -> ``b`` description.
    """
    shift_map = b.shift_type
    if shift_map is None:
        return 'Dataset Shifts'
    if a.name in shift_map.keys():
        return shift_map[a.name]

    via_name, via_shift = list(shift_map.items())[0]
    via_meta = _ABDatasetMetaInfo(via_name, *static_dataset_registery[via_name][1:])
    upstream = _get_dist_shift_type_when_source_a_to_target_b(a, via_meta)
    return upstream + ' + ' + via_shift
def _handle_all_datasets_v2(source_datasets: List[_ABDatasetMetaInfo], target_datasets: List[_ABDatasetMetaInfo], da_mode):
    """Resolve discarded/unknown classes and a shared label space for all source-target pairs.

    Steps:
      1. merge same-meaning classes across all datasets;
      2. pair every target with each source sharing at least one class with it, recording the
         distribution-shift description of the pair;
      3. per target, decide the discarded ("ignore") classes of each paired source and of the
         target, and the target's private ("unknown") classes, according to ``da_mode``;
      4. give every kept known class a global index, then give each target that has private
         classes one extra index shared by all of its private classes;
      5. for every pair and every target, map raw class position / class name to its global index.

    Args:
        source_datasets: metadata of the source datasets.
        target_datasets: metadata of the target datasets.
        da_mode: ``'da'``, ``'partial_da'``, ``'open_set_da'`` or ``'universal_da'``.

    Returns:
        ``(all_datasets_ignore_classes, target_datasets_private_classes, all_datasets_e2e_idx_map,
        all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx,
        target_source_relationship_map, rename_map, num_classes)``.
        Per-pair entries are keyed ``'<source>|<target>'`` and per-target entries by target name.
        ``num_classes`` is the number of distinct global indices used by the class-to-index maps.
    """
    # 1. merge the same-meaning classes
    classes_info_of_all_datasets = {
        d.name: (d.classes, d.class_aliases)
        for d in source_datasets + target_datasets
    }
    final_classes_of_all_datasets, rename_map = _merge_the_same_meaning_classes(classes_info_of_all_datasets)
    all_datasets_classes = copy.deepcopy(final_classes_of_all_datasets)

    def _rename(dataset_name, raw_c):
        # Datasets without an entry in rename_map are treated as having no renamed classes.
        mapping = rename_map.get(dataset_name, {})
        return mapping[raw_c] if raw_c in mapping else raw_c

    def _dedup(mapping):
        # Order-preserving de-duplication of every value list.
        return {k: list(dict.fromkeys(v)) for k, v in mapping.items()}

    # 2. only keep source/target pairs that share at least one class
    target_source_relationship_map = {td.name: {} for td in target_datasets}
    for sd in source_datasets:
        for td in target_datasets:
            if not set(all_datasets_classes[sd.name]) & set(all_datasets_classes[td.name]):
                continue
            target_source_relationship_map[td.name][sd.name] = _get_dist_shift_type_when_source_a_to_target_b(sd, td)

    # 3. Decide discarded/private classes per target, so each target's DA setting matches the
    # requested label shift. The same source may therefore be trimmed differently for
    # different targets.
    source_datasets_ignore_classes = {}
    target_datasets_ignore_classes = {d.name: [] for d in target_datasets}
    target_datasets_private_classes = {d.name: [] for d in target_datasets}
    for td_name, sources in target_source_relationship_map.items():
        sd_names = list(sources)
        sds_classes = [all_datasets_classes[sd_name] for sd_name in sd_names]
        td_classes = all_datasets_classes[td_name]
        ss_ignore_classes, t_ignore_classes = _find_ignore_classes_when_sources_as_to_target_b(sds_classes, td_classes, da_mode)
        t_private_classes = _find_private_classes_when_sources_as_to_target_b(sds_classes, td_classes, da_mode)
        for sd_name, s_ignore_classes in zip(sd_names, ss_ignore_classes):
            source_datasets_ignore_classes[sd_name + '|' + td_name] = s_ignore_classes
        target_datasets_ignore_classes[td_name] = t_ignore_classes
        target_datasets_private_classes[td_name] = t_private_classes
    source_datasets_ignore_classes = _dedup(source_datasets_ignore_classes)
    target_datasets_ignore_classes = _dedup(target_datasets_ignore_classes)
    target_datasets_private_classes = _dedup(target_datasets_private_classes)

    # 4.1 global indices of known classes
    global_idx = 0
    all_used_classes_idx_map = {}
    for dataset_name, classes in all_datasets_classes.items():
        if dataset_name in target_datasets_ignore_classes:
            ignore_classes = set(target_datasets_ignore_classes[dataset_name])
            private_classes = set(target_datasets_private_classes[dataset_name])
        else:
            # Source-only dataset: use its pairing with the fewest discarded classes (first one
            # on ties). The '|' suffix prevents 'CIFAR10' from matching 'CIFAR100|...' keys.
            pair_prefix = dataset_name + '|'
            pair_ignores = [v for k, v in source_datasets_ignore_classes.items() if k.startswith(pair_prefix)]
            ignore_classes = set(min(pair_ignores, key=len, default=[]))
            private_classes = set()
        for c in classes:
            if c in ignore_classes or c in private_classes or c in all_used_classes_idx_map:
                continue
            all_used_classes_idx_map[c] = global_idx
            global_idx += 1

    # 4.2 one extra index per target that has private classes, allocated after all known classes
    target_private_class_idx = global_idx
    target_datasets_private_class_idx = {d: None for d in target_datasets_private_classes}
    for dataset_name in final_classes_of_all_datasets:
        if target_datasets_private_classes.get(dataset_name):
            target_datasets_private_class_idx[dataset_name] = target_private_class_idx
            target_private_class_idx += 1

    # 5. raw class -> renamed class -> global index, for every pair and every target
    all_datasets_e2e_idx_map = {}
    all_datasets_e2e_class_to_idx_map = {}
    for td_name, sources in target_source_relationship_map.items():
        for sd_name in sources:
            pair_name = sd_name + '|' + td_name
            pair_ignore = set(source_datasets_ignore_classes[pair_name])
            idx_map, class_to_idx_map = {}, {}
            for raw_ci, raw_c in enumerate(all_datasets_classes[sd_name]):
                # Rename with *this* source's map (previously a stale loop variable was used).
                renamed_c = _rename(sd_name, raw_c)
                if renamed_c in pair_ignore:
                    continue
                idx = all_used_classes_idx_map[renamed_c]
                idx_map[raw_ci] = idx
                class_to_idx_map[raw_c] = idx
            all_datasets_e2e_idx_map[pair_name] = idx_map
            all_datasets_e2e_class_to_idx_map[pair_name] = class_to_idx_map

        t_ignore = set(target_datasets_ignore_classes[td_name])
        t_private = set(target_datasets_private_classes[td_name])
        idx_map, class_to_idx_map = {}, {}
        for raw_ci, raw_c in enumerate(all_datasets_classes[td_name]):
            renamed_c = _rename(td_name, raw_c)
            if renamed_c in t_ignore:
                continue
            if renamed_c in t_private:
                idx = target_datasets_private_class_idx[td_name]
            else:
                idx = all_used_classes_idx_map[renamed_c]
            idx_map[raw_ci] = idx
            class_to_idx_map[raw_c] = idx
        all_datasets_e2e_idx_map[td_name] = idx_map
        all_datasets_e2e_class_to_idx_map[td_name] = class_to_idx_map

    all_datasets_ignore_classes = {**source_datasets_ignore_classes, **target_datasets_ignore_classes}
    num_classes = len({idx for m in all_datasets_e2e_class_to_idx_map.values() for idx in m.values()})

    return all_datasets_ignore_classes, target_datasets_private_classes, \
        all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, target_datasets_private_class_idx, \
        target_source_relationship_map, rename_map, num_classes
def _build_scenario_info_v2(
    source_datasets_name: List[str],
    target_datasets_order: List[str],
    da_mode: str
):
    """Build the class/index bookkeeping of a DA scenario from registered dataset names.

    Args:
        source_datasets_name: source dataset names (keys of ``static_dataset_registery``).
        target_datasets_order: target dataset names in visiting order; may contain repeats.
        da_mode: ``'close_set'``, ``'partial'``, ``'open_set'`` or ``'universal'``.

    Returns:
        The 8-tuple returned by ``_handle_all_datasets_v2``.
    """
    assert da_mode in ['close_set', 'partial', 'open_set', 'universal'], f'unknown da_mode: {da_mode}'
    da_mode = {'close_set': 'da', 'partial': 'partial_da', 'open_set': 'open_set_da', 'universal': 'universal_da'}[da_mode]

    def _meta(name):
        # Registered datasets provide their metadata here; adding a dataset only requires registering it.
        return _ABDatasetMetaInfo(name, *static_dataset_registery[name][1:])

    source_datasets_meta_info = [_meta(d) for d in source_datasets_name]
    # De-duplicate targets in first-appearance order. list(set(...)) made this order, and
    # therefore the order of every downstream per-target dict, depend on string hash
    # randomization.
    target_datasets_meta_info = [_meta(d) for d in dict.fromkeys(target_datasets_order)]
    return _handle_all_datasets_v2(source_datasets_meta_info, target_datasets_meta_info, da_mode)
def build_scenario_manually_v2(
    source_datasets_name: List[str],
    target_datasets_order: List[str],
    da_mode: str,
    data_dirs: Dict[str, str],
):
    """Build a ``Scenario`` for the given source datasets and target-domain sequence.

    Args:
        source_datasets_name: source dataset names (keys of ``static_dataset_registery``).
        target_datasets_order: target dataset names in visiting order; may contain repeats.
        da_mode: ``'close_set'``, ``'partial'``, ``'open_set'`` or ``'universal'``.
        data_dirs: dataset name -> data directory; only stored in the scenario config here.

    Returns:
        The constructed ``Scenario``.

    Side effects:
        Sets the environment variable ``_ZQL_NUMC`` to the number of classes.
    """
    # Snapshot of the call arguments used as the scenario config. This must remain the first
    # statement so that locals() contains only the parameters.
    configs = copy.deepcopy(locals())

    (all_datasets_ignore_classes, _target_datasets_private_classes,
     all_datasets_e2e_idx_map, all_datasets_e2e_class_to_idx_map, _target_datasets_private_class_idx,
     target_source_relationship_map, _rename_map, num_classes) = \
        _build_scenario_info_v2(source_datasets_name, target_datasets_order, da_mode)

    from .scenario import Scenario
    import os

    # NOTE(review): presumably read elsewhere to obtain the class count; confirm consumers.
    os.environ['_ZQL_NUMC'] = str(num_classes)
    return Scenario(config=configs, all_datasets_ignore_classes_map=all_datasets_ignore_classes,
                    all_datasets_idx_map=all_datasets_e2e_idx_map,
                    target_domains_order=target_datasets_order,
                    target_source_map=target_source_relationship_map,
                    all_datasets_e2e_class_to_idx_map=all_datasets_e2e_class_to_idx_map,
                    num_classes=num_classes)
if __name__ == '__main__':
    # data_dirs is a required argument of build_scenario_manually_v2; omitting it raised TypeError.
    # NOTE(review): placeholder empty mapping; fill in real directories if Scenario needs them.
    test_scenario = build_scenario_manually_v2(['CIFAR10', 'SVHN'],
                                               ['STL10', 'MNIST', 'STL10', 'USPS', 'MNIST', 'STL10'],
                                               'close_set',
                                               data_dirs={})
    print(test_scenario.num_classes)