@@ -9,9 +9,32 @@ class DistributedSampler(Sampler):
99 """
1010 Extension of DistributedSampler, as discussed in
1111 https://github.com/pytorch/pytorch/issues/23430
12+
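+    A group_size parameter keeps runs of group_size consecutive samples
+    together, so a group is never split across replicas.
+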
+    Example:
+        dataset: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
+        num_replicas: 4
+        shuffle: False
+
+        when group_size = 1
+
+        RANK   | shard_dataset
+        =========================
+        rank_0 | [0, 4, 8, 12]
+        rank_1 | [1, 5, 9, 13]
+        rank_2 | [2, 6, 10, 0]
+        rank_3 | [3, 7, 11, 1]
+
+        when group_size = 2
+
+        RANK   | shard_dataset
+        =========================
+        rank_0 | [0, 1, 8, 9]
+        rank_1 | [2, 3, 10, 11]
+        rank_2 | [4, 5, 12, 13]
+        rank_3 | [6, 7, 0, 1]
+
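+    Usage sketch (constructor arguments follow this file; the DataLoader
+    wiring is illustrative):
+
+        sampler = DistributedSampler(dataset, num_replicas=4, rank=0,
+                                     shuffle=False, group_size=2)
+        loader = torch.utils.data.DataLoader(dataset, sampler=sampler)
+        for batch in loader:
+            ...  # each rank sees whole groups of 2 consecutive samples
+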
1235 """
1336
14- def __init__ (self , dataset , num_replicas = None , rank = None , shuffle = False ):
37+ def __init__ (self , dataset , num_replicas = None , rank = None , shuffle = False , group_size = 1 ):
1538 if num_replicas is None :
1639 if not dist .is_available ():
1740 raise RuntimeError ("Requires distributed package to be available" )
@@ -20,11 +43,20 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=False):
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
             rank = dist.get_rank()
+        assert len(dataset) % group_size == 0, (
+            "dataset length must be a multiple of group size; "
+            "dataset length: %d, group size: %d" % (len(dataset), group_size)
+        )
         self.dataset = dataset
+        self.group_size = group_size
         self.num_replicas = num_replicas
         self.rank = rank
         self.epoch = 0
-        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
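+        # shard whole groups across ranks: each rank receives num_group_samples
+        # groups, i.e. num_group_samples * group_size individual samples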
+        dataset_group_length = len(dataset) // group_size
+        self.num_group_samples = int(
+            math.ceil(dataset_group_length * 1.0 / self.num_replicas)
+        )
+        self.num_samples = self.num_group_samples * group_size
         self.total_size = self.num_samples * self.num_replicas
         self.shuffle = shuffle
 
@@ -41,8 +73,14 @@ def __iter__(self):
         indices += indices[:(self.total_size - len(indices))]
         assert len(indices) == self.total_size
 
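+        # view the padded index list as rows of group_size consecutive samples,
+        # so each row is one group that must stay on a single rank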
+        total_group_size = self.total_size // self.group_size
+        indices = torch.reshape(
+            torch.LongTensor(indices), (total_group_size, self.group_size)
+        )
+
         # subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
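+        # take every num_replicas-th group row starting at this rank's offset,
+        # then flatten the selected groups back into a plain index list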
+        indices = indices[self.rank:total_group_size:self.num_replicas, :]
+        indices = torch.reshape(indices, (-1,)).tolist()
         assert len(indices) == self.num_samples
 
         if isinstance(self.dataset, Sampler):