Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. | |
| import unittest | |
| from torch.utils.data.sampler import SequentialSampler | |
| from detectron2.data.samplers import GroupedBatchSampler | |
class TestGroupedBatchSampler(unittest.TestCase):
    """Tests for ``GroupedBatchSampler`` driven by a ``SequentialSampler``.

    Each test makes one pass over the sampler, materializes the yielded
    batches, and checks them against the ``group_ids`` it was built with.
    """

    def _sample_batches(self, group_ids, batch_size):
        """Return every batch yielded by one pass of a sampler over ``group_ids``."""
        sampler = SequentialSampler(list(range(len(group_ids))))
        # Materialize the batches so they can be counted as well as inspected.
        return list(GroupedBatchSampler(sampler, group_ids, batch_size))

    def _assert_each_index_once(self, batches, num_indices):
        """Assert the batches together contain each of ``range(num_indices)`` exactly once."""
        flat = sorted(idx for batch in batches for idx in batch)
        self.assertEqual(flat, list(range(num_indices)))

    def test_missing_group_id(self):
        """With a single group id for all 100 indices, expect 50 full batches of 2.

        NOTE(review): despite the name, every index is assigned group 1 here;
        no index lacks a group id. The name is kept for test-ID stability.
        """
        group_ids = [1] * 100
        batches = self._sample_batches(group_ids, 2)
        # Without this count check the per-batch loop passes vacuously
        # if the sampler yields nothing.
        self.assertEqual(len(batches), 50)
        for mini_batch in batches:
            self.assertEqual(len(mini_batch), 2)
        self._assert_each_index_once(batches, 100)

    def test_groups(self):
        """With alternating group ids, every batch must come from a single group."""
        group_ids = [1, 0] * 50
        batches = self._sample_batches(group_ids, 2)
        self.assertEqual(len(batches), 50)
        for mini_batch in batches:
            self.assertEqual(len(mini_batch), 2)
            # All indices in a batch must share one group id. For the
            # alternating ids above this is the same as "all indices have
            # equal parity", but it checks every element, not just the first two.
            self.assertEqual(len({group_ids[idx] for idx in mini_batch}), 1)
        self._assert_each_index_once(batches, 100)