|
|
|
import unittest |
|
from torch.utils.data.sampler import SequentialSampler |
|
|
|
from detectron2.data.samplers import GroupedBatchSampler |
|
|
|
|
|
class TestGroupedBatchSampler(unittest.TestCase):
    """Unit tests for ``GroupedBatchSampler``.

    Each test wraps a ``SequentialSampler`` over the indices 0..99 and uses a
    batch size of 2. The expected output is therefore fully determined by
    ``group_ids``.
    """

    def test_missing_group_id(self):
        """A single shared group id (all 1) must still yield full batches.

        Despite the method name, every index here does have a group id; the
        case exercised is "only one group present".
        """
        sampler = SequentialSampler(list(range(100)))
        group_ids = [1] * 100
        samples = GroupedBatchSampler(sampler, group_ids, 2)

        mini_batches = list(samples)
        # Guard against a vacuous pass: an empty iterator would otherwise
        # satisfy the per-batch assertions below. 100 indices / 2 = 50.
        self.assertEqual(len(mini_batches), 50)
        for mini_batch in mini_batches:
            self.assertEqual(len(mini_batch), 2)

    def test_groups(self):
        """Alternating group ids: every batch must come from a single group."""
        sampler = SequentialSampler(list(range(100)))
        group_ids = [1, 0] * 50
        samples = GroupedBatchSampler(sampler, group_ids, 2)

        mini_batches = list(samples)
        # 50 indices per group / batch size 2 = 25 batches per group.
        self.assertEqual(len(mini_batches), 50)
        for mini_batch in mini_batches:
            self.assertEqual(len(mini_batch), 2)
            # Check group membership directly rather than via index parity.
            self.assertEqual(len({group_ids[i] for i in mini_batch}), 1)

        # Each group's size is a multiple of the batch size, so no index is
        # left in a partial buffer: every index must be emitted exactly once.
        emitted = sorted(i for mini_batch in mini_batches for i in mini_batch)
        self.assertEqual(emitted, list(range(100)))
|
|