@@ -458,7 +458,7 @@ def test_new_empty_tensor(self):
458458
459459
460460class DeformConvTester (OpTester , unittest .TestCase ):
461- def expected_fn (self , x , weight , offset , bias , stride = 1 , padding = 0 , dilation = 1 ):
461+ def expected_fn (self , x , weight , offset , mask , bias , stride = 1 , padding = 0 , dilation = 1 ):
462462 stride_h , stride_w = _pair (stride )
463463 pad_h , pad_w = _pair (padding )
464464 dil_h , dil_w = _pair (dilation )
@@ -489,12 +489,17 @@ def expected_fn(self, x, weight, offset, bias, stride=1, padding=0, dilation=1):
489489 c_in = weight_grp * in_c_per_weight_grp + c
490490
491491 offset_grp = c_in // in_c_per_offset_grp
492- offset_idx = 2 * (offset_grp * (weight_h * weight_w ) + di * weight_w + dj )
492+ mask_idx = offset_grp * (weight_h * weight_w ) + di * weight_w + dj
493+ offset_idx = 2 * mask_idx
493494
494495 pi = stride_h * i - pad_h + dil_h * di + offset [b , offset_idx , i , j ]
495496 pj = stride_w * j - pad_w + dil_w * dj + offset [b , offset_idx + 1 , i , j ]
496497
497- out [b , c_out , i , j ] += (weight [c_out , c , di , dj ] *
498+ mask_value = 1.0
499+ if mask is not None :
500+ mask_value = mask [b , mask_idx , i , j ]
501+
502+ out [b , c_out , i , j ] += (mask_value * weight [c_out , c , di , dj ] *
498503 bilinear_interpolate (x [b , c_in , :, :], pi , pj ))
499504 out += bias .view (1 , n_out_channels , 1 , 1 )
500505 return out
@@ -523,6 +528,9 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype):
523528 offset = torch .randn (batch_sz , n_offset_grps * 2 * weight_h * weight_w , out_h , out_w ,
524529 device = device , dtype = dtype , requires_grad = True )
525530
531+ mask = torch .randn (batch_sz , n_offset_grps * weight_h * weight_w , out_h , out_w ,
532+ device = device , dtype = dtype , requires_grad = True )
533+
526534 weight = torch .randn (n_out_channels , n_in_channels // n_weight_grps , weight_h , weight_w ,
527535 device = device , dtype = dtype , requires_grad = True )
528536
@@ -531,31 +539,39 @@ def get_fn_args(self, device, contiguous, batch_sz, dtype):
531539 if not contiguous :
532540 x = x .permute (0 , 1 , 3 , 2 ).contiguous ().permute (0 , 1 , 3 , 2 )
533541 offset = offset .permute (1 , 3 , 0 , 2 ).contiguous ().permute (2 , 0 , 3 , 1 )
542+ mask = mask .permute (1 , 3 , 0 , 2 ).contiguous ().permute (2 , 0 , 3 , 1 )
534543 weight = weight .permute (3 , 2 , 0 , 1 ).contiguous ().permute (2 , 3 , 1 , 0 )
535544
536- return x , weight , offset , bias , stride , pad , dilation
545+ return x , weight , offset , mask , bias , stride , pad , dilation
537546
538547 def _test_forward (self , device , contiguous , dtype = None ):
539548 dtype = self .dtype if dtype is None else dtype
540549 for batch_sz in [0 , 33 ]:
541550 self ._test_forward_with_batchsize (device , contiguous , batch_sz , dtype )
542551
543552 def _test_forward_with_batchsize (self , device , contiguous , batch_sz , dtype ):
544- x , _ , offset , _ , stride , padding , dilation = self .get_fn_args (device , contiguous , batch_sz , dtype )
553+ x , _ , offset , mask , _ , stride , padding , dilation = self .get_fn_args (device , contiguous , batch_sz , dtype )
545554 in_channels = 6
546555 out_channels = 2
547556 kernel_size = (3 , 2 )
548557 groups = 2
558+ tol = 1e-3 if dtype is torch .half else 1e-5
549559
550560 layer = ops .DeformConv2d (in_channels , out_channels , kernel_size , stride = stride , padding = padding ,
551561 dilation = dilation , groups = groups ).to (device = x .device , dtype = dtype )
552- res = layer (x , offset )
562+ res = layer (x , offset , mask )
553563
554564 weight = layer .weight .data
555565 bias = layer .bias .data
556- expected = self .expected_fn (x , weight , offset , bias , stride = stride , padding = padding , dilation = dilation )
566+ expected = self .expected_fn (x , weight , offset , mask , bias , stride = stride , padding = padding , dilation = dilation )
567+
568+ self .assertTrue (torch .allclose (res .to (expected .dtype ), expected , rtol = tol , atol = tol ),
569+ '\n res:\n {}\n expected:\n {}' .format (res , expected ))
570+
571+ # no modulation test
572+ res = layer (x , offset )
573+ expected = self .expected_fn (x , weight , offset , None , bias , stride = stride , padding = padding , dilation = dilation )
557574
558- tol = 1e-3 if dtype is torch .half else 1e-5
559575 self .assertTrue (torch .allclose (res .to (expected .dtype ), expected , rtol = tol , atol = tol ),
560576 '\n res:\n {}\n expected:\n {}' .format (res , expected ))
561577
@@ -564,24 +580,45 @@ def _test_forward_with_batchsize(self, device, contiguous, batch_sz, dtype):
564580 wrong_offset = torch .rand_like (offset [:, :2 ])
565581 res = layer (x , wrong_offset )
566582
583+ with self .assertRaises (RuntimeError ):
584+ wrong_mask = torch .rand_like (mask [:, :2 ])
585+ res = layer (x , offset , wrong_mask )
586+
567587 def _test_backward (self , device , contiguous ):
568588 for batch_sz in [0 , 33 ]:
569589 self ._test_backward_with_batchsize (device , contiguous , batch_sz )
570590
571591 def _test_backward_with_batchsize (self , device , contiguous , batch_sz ):
572- x , weight , offset , bias , stride , padding , dilation = self .get_fn_args (device , contiguous , batch_sz , self .dtype )
592+ x , weight , offset , mask , bias , stride , padding , dilation = self .get_fn_args (device , contiguous , batch_sz , self .dtype )
593+
594+ def func (x_ , offset_ , mask_ , weight_ , bias_ ):
595+ return ops .deform_conv2d (x_ , offset_ , weight_ , bias_ , stride = stride ,
596+ padding = padding , dilation = dilation , mask = mask_ )
573597
574- def func (x_ , offset_ , weight_ , bias_ ):
575- return ops .deform_conv2d (x_ , offset_ , weight_ , bias_ , stride = stride , padding = padding , dilation = dilation )
598+ gradcheck (func , (x , offset , mask , weight , bias ), nondet_tol = 1e-5 )
599+
600+ def func_no_mask (x_ , offset_ , weight_ , bias_ ):
601+ return ops .deform_conv2d (x_ , offset_ , weight_ , bias_ , stride = stride ,
602+ padding = padding , dilation = dilation , mask = None )
603+
604+ gradcheck (func_no_mask , (x , offset , weight , bias ), nondet_tol = 1e-5 )
605+
606+ @torch .jit .script
607+ def script_func (x_ , offset_ , mask_ , weight_ , bias_ , stride_ , pad_ , dilation_ ):
608+ # type:(Tensor, Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
609+ return ops .deform_conv2d (x_ , offset_ , weight_ , bias_ , stride = stride_ ,
610+ padding = pad_ , dilation = dilation_ , mask = mask_ )
576611
577- gradcheck (func , (x , offset , weight , bias ), nondet_tol = 1e-5 )
612+ gradcheck (lambda z , off , msk , wei , bi : script_func (z , off , msk , wei , bi , stride , padding , dilation ),
613+ (x , offset , mask , weight , bias ), nondet_tol = 1e-5 )
578614
579615 @torch .jit .script
580- def script_func (x_ , offset_ , weight_ , bias_ , stride_ , pad_ , dilation_ ):
581- # type: (Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
582- return ops .deform_conv2d (x_ , offset_ , weight_ , bias_ , stride = stride_ , padding = pad_ , dilation = dilation_ )
616+ def script_func_no_mask (x_ , offset_ , weight_ , bias_ , stride_ , pad_ , dilation_ ):
617+ # type:(Tensor, Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int])->Tensor
618+ return ops .deform_conv2d (x_ , offset_ , weight_ , bias_ , stride = stride_ ,
619+ padding = pad_ , dilation = dilation_ , mask = None )
583620
584- gradcheck (lambda z , off , wei , bi : script_func (z , off , wei , bi , stride , padding , dilation ),
621+ gradcheck (lambda z , off , wei , bi : script_func_no_mask (z , off , wei , bi , stride , padding , dilation ),
585622 (x , offset , weight , bias ), nondet_tol = 1e-5 )
586623
587624 # Test from https://github.com/pytorch/vision/issues/2598
@@ -593,17 +630,19 @@ def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
593630 init_weight = torch .randn (9 , 9 , 3 , 3 , requires_grad = True )
594631 img = torch .randn (8 , 9 , 1000 , 110 )
595632 offset = torch .rand (8 , 2 * 3 * 3 , 1000 , 110 )
633+ mask = torch .rand (8 , 3 * 3 , 1000 , 110 )
596634
597635 if not contiguous :
598636 img = img .permute (0 , 1 , 3 , 2 ).contiguous ().permute (0 , 1 , 3 , 2 )
599637 offset = offset .permute (1 , 3 , 0 , 2 ).contiguous ().permute (2 , 0 , 3 , 1 )
638+ mask = mask .permute (1 , 3 , 0 , 2 ).contiguous ().permute (2 , 0 , 3 , 1 )
600639 weight = init_weight .permute (3 , 2 , 0 , 1 ).contiguous ().permute (2 , 3 , 1 , 0 )
601640 else :
602641 weight = init_weight
603642
604643 for d in ["cpu" , "cuda" ]:
605644
606- out = ops .deform_conv2d (img .to (d ), offset .to (d ), weight .to (d ), padding = 1 )
645+ out = ops .deform_conv2d (img .to (d ), offset .to (d ), weight .to (d ), padding = 1 , mask = mask . to ( d ) )
607646 out .mean ().backward ()
608647 if true_cpu_grads is None :
609648 true_cpu_grads = init_weight .grad
0 commit comments