@@ -15,7 +15,7 @@ class RASampler(torch.utils.data.Sampler):
1515 https://github.com/facebookresearch/deit/blob/main/samplers.py
1616 """
1717
18- def __init__ (self , dataset , num_replicas = None , rank = None , shuffle = True , seed = 0 ):
18+ def __init__ (self , dataset , num_replicas = None , rank = None , shuffle = True , seed = 0 , repetitions = 3 ):
1919 if num_replicas is None :
2020 if not dist .is_available ():
2121 raise RuntimeError ("Requires distributed package to be available!" )
@@ -28,11 +28,12 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, seed=0):
2828 self .num_replicas = num_replicas
2929 self .rank = rank
3030 self .epoch = 0
31- self .num_samples = int (math .ceil (len (self .dataset ) * 3.0 / self .num_replicas ))
31+ self .num_samples = int (math .ceil (len (self .dataset ) * float ( repetitions ) / self .num_replicas ))
3232 self .total_size = self .num_samples * self .num_replicas
3333 self .num_selected_samples = int (math .floor (len (self .dataset ) // 256 * 256 / self .num_replicas ))
3434 self .shuffle = shuffle
3535 self .seed = seed
36+ self .repetitions = repetitions
3637
3738 def __iter__ (self ):
3839 # Deterministically shuffle based on epoch
@@ -44,7 +45,7 @@ def __iter__(self):
4445 indices = list (range (len (self .dataset )))
4546
4647 # Add extra samples to make it evenly divisible
47- indices = [ele for ele in indices for i in range (3 )]
48+ indices = [ele for ele in indices for i in range (self . repetitions )]
4849 indices += indices [: (self .total_size - len (indices ))]
4950 assert len (indices ) == self .total_size
5051
0 commit comments