@@ -511,7 +511,6 @@ def _test_batch_shape(self, functional, tensor, *args, **kwargs):
511511 atol = kwargs ['atol' ]
512512 del kwargs ['atol' ]
513513 kwargs_compare ['atol' ] = atol
514- print (kwargs )
515514
516515 if 'rtol' in kwargs :
517516 rtol = kwargs ['rtol' ]
@@ -520,13 +519,16 @@ def _test_batch_shape(self, functional, tensor, *args, **kwargs):
520519
521520 # Single then transform then batch
522521
523- expected = functional (tensor , * args , ** kwargs )
522+ torch .random .manual_seed (42 )
523+ expected = functional (tensor .clone (), * args , ** kwargs )
524524 expected = expected .unsqueeze (0 ).unsqueeze (0 )
525525
526526 # 1-Batch then transform
527527
528528 tensors = tensor .unsqueeze (0 ).unsqueeze (0 )
529- computed = functional (tensors , * args , ** kwargs )
529+
530+ torch .random .manual_seed (42 )
531+ computed = functional (tensors .clone (), * args , ** kwargs )
530532
531533 self ._compare_estimate (computed , expected , ** kwargs_compare )
532534
@@ -555,19 +557,30 @@ def _test_batch(self, functional, tensor, *args, **kwargs):
555557 ind = [3 ] + [1 ] * (int (expected .dim ()) - 1 )
556558 expected = expected .repeat (* ind )
557559
558- computed = functional (tensors , * args , ** kwargs )
560+ torch .random .manual_seed (42 )
561+ computed = functional (tensors .clone (), * args , ** kwargs )
559562
560563 self ._compare_estimate (computed , expected , ** kwargs_compare )
561564
562565 def test_batch_mask_along_axis_iid (self ):
563566
567+ mask_param = 2
568+ mask_value = 30.
569+ axis = 2
570+
571+ tensor = torch .rand (2 , 5 , 5 )
572+
573+ self ._test_batch (F .mask_along_axis_iid , tensor , mask_param = mask_param , mask_value = mask_value , axis = axis , atol = 1e-1 , rtol = 1e-1 )
574+
575+ def test_batch_mask_along_axis (self ):
576+
564577 tensor = torch .rand (2 , 5 , 5 )
565578
566579 mask_param = 2
567580 mask_value = 30.
568581 axis = 2
569582
570- self ._test_batch_shape (F .mask_along_axis_iid , tensor , mask_param = mask_param , mask_value = mask_value , axis = axis )
583+ self ._test_batch (F .mask_along_axis , tensor , mask_param = mask_param , mask_value = mask_value , axis = axis )
571584
572585 def test_torchscript_create_fb_matrix (self ):
573586
0 commit comments