Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| # @Time : 2021/12/2 5:41 p.m. | |
| # @Author : JianingWang | |
| # @File : sampler.py | |
| import numpy as np | |
| from typing import Optional | |
| """ | |
| random sampling for each label | |
| """ | |
def random_sampling(raw_datasets, num_examples_per_label: Optional[int]=16) -> list:
    """Randomly pick up to ``num_examples_per_label`` example indices per label.

    Examples are grouped by the value found at the same position in
    ``raw_datasets["label"]``; a sample is then drawn without replacement
    from each group using NumPy's global random state (``np.random``).

    Args:
        raw_datasets: Any object whose ``["label"]`` item is an iterable of
            hashable labels, one per example (e.g. ``[0, 1, 0, 0, ...]``).
        num_examples_per_label: Maximum number of examples to draw for each
            label. A label with fewer examples contributes all of them
            instead of raising. ``None`` selects every example of every label.

    Returns:
        A list of example indices (positions in ``raw_datasets["label"]``),
        grouped by label in order of first appearance; the order inside each
        group is random.

    Raises:
        ValueError: If ``num_examples_per_label`` is negative.
    """
    if num_examples_per_label is not None and num_examples_per_label < 0:
        raise ValueError(
            "num_examples_per_label must be non-negative, got "
            f"{num_examples_per_label}"
        )

    label_list = raw_datasets["label"]  # [0, 1, 0, 0, ...]

    # Map each label to the indices of the examples carrying it.
    label_to_ids = {}
    for example_id, label in enumerate(label_list):
        label_to_ids.setdefault(label, []).append(example_id)

    # Randomly sample k examples of each class.
    few_example_ids = []
    for example_ids in label_to_ids.values():
        available = len(example_ids)
        # Clamp the sample size: sampling without replacement cannot draw
        # more items than the class holds (np.random.choice would raise).
        if num_examples_per_label is None:
            sample_size = available
        else:
            sample_size = min(num_examples_per_label, available)
        idxs = np.random.choice(available, size=sample_size, replace=False)
        few_example_ids.extend(example_ids[i] for i in idxs)
    return few_example_ids