@@ -101,8 +101,13 @@ def forward(self, img1, img2, flow, valid_flow_mask):
101101
102102class RandomErasing (T .RandomErasing ):
103103 # This only erases img2, and with an extra max_erase param
104+ # This max_erase is needed because in the RAFT training ref does:
105+ # 0 erasing with .5 proba
106+ # 1 erase with .25 proba
107+ # 2 erase with .25 proba
108+ # and there's no accurate way to achieve this otherwise.
104109 def __init__ (self , p = 0.5 , scale = (0.02 , 0.33 ), ratio = (0.3 , 3.3 ), value = 0 , inplace = False , max_erase = 1 ):
105- super ().__init__ ()
110+ super ().__init__ (p = p , scale = scale , ratio = ratio , value = value , inplace = inplace )
106111 self .max_erase = max_erase
107112 assert self .max_erase > 0
108113
@@ -171,12 +176,12 @@ def forward(self, img1, img2, flow, valid_flow_mask):
171176 # It shouldn't matter much
172177 min_scale = max ((self .crop_size [0 ] + 8 ) / h , (self .crop_size [1 ] + 8 ) / w )
173178
174- scale = 2 ** torch .FloatTensor ( 1 ).uniform_ (self .min_scale , self .max_scale ).item ()
179+ scale = 2 ** torch .empty ( 1 , dtype = torch . float32 ).uniform_ (self .min_scale , self .max_scale ).item ()
175180 scale_x = scale
176181 scale_y = scale
177182 if torch .rand (1 ) < self .stretch_prob :
178- scale_x *= 2 ** torch .FloatTensor ( 1 ).uniform_ (- self .max_stretch , self .max_stretch ).item ()
179- scale_y *= 2 ** torch .FloatTensor ( 1 ).uniform_ (- self .max_stretch , self .max_stretch ).item ()
183+ scale_x *= 2 ** torch .empty ( 1 , dtype = torch . float32 ).uniform_ (- self .max_stretch , self .max_stretch ).item ()
184+ scale_y *= 2 ** torch .empty ( 1 , dtype = torch . float32 ).uniform_ (- self .max_stretch , self .max_stretch ).item ()
180185
181186 scale_x = max (scale_x , min_scale )
182187 scale_y = max (scale_y , min_scale )
@@ -245,8 +250,9 @@ def _resize_sparse_flow(self, flow, valid_flow_mask, scale_x=1.0, scale_y=1.0):
245250 return flow_new , valid_new
246251
247252
248- class Compose :
253+ class Compose ( torch . nn . Module ) :
249254 def __init__ (self , transforms ):
255+ super ().__init__ ()
250256 self .transforms = transforms
251257
252258 def forward (self , img1 , img2 , flow , valid_flow_mask ):
0 commit comments