11from __future__ import division
2+ import math
3+ from typing import Tuple
4+ import unittest
5+
26import numpy as np
7+
38import torch
49from torch .autograd import gradcheck
5-
10+ from torch .nn .modules .utils import _pair
11+ from torch import Tensor
612from torchvision import ops
713
8- from itertools import product
9- import unittest
10-
1114
12- class RoIOpTester (object ):
15+ class OpTester (object ):
1316 @classmethod
1417 def setUpClass (cls ):
1518 cls .dtype = torch .float64
@@ -42,6 +45,14 @@ def test_backward_cuda_contiguous(self):
4245 def test_backward_cuda_non_contiguous (self ):
4346 self ._test_backward (device = torch .device ('cuda' ), contiguous = False )
4447
48+ def _test_forward (self , device , contiguous ):
49+ pass
50+
51+ def _test_backward (self , device , contiguous ):
52+ pass
53+
54+
55+ class RoIOpTester (OpTester ):
4556 def _test_forward (self , device , contiguous ):
4657 pool_size = 5
4758 # n_channels % (pool_size ** 2) == 0 required for PS opeartions.
@@ -79,7 +90,6 @@ def func(z):
7990
8091 self .assertTrue (gradcheck (func , (x ,)))
8192 self .assertTrue (gradcheck (script_func , (x ,)))
82- return
8393
8494 def fn (* args , ** kwargs ):
8595 pass
@@ -98,7 +108,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
98108 def get_script_fn (self , rois , pool_size ):
99109 @torch .jit .script
100110 def script_fn (input , rois , pool_size ):
101- # type: (torch. Tensor, torch. Tensor, int) -> torch. Tensor
111+ # type: (Tensor, Tensor, int) -> Tensor
102112 return ops .roi_pool (input , rois , pool_size , 1.0 )[0 ]
103113 return lambda x : script_fn (x , rois , pool_size )
104114
@@ -137,7 +147,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
137147 def get_script_fn (self , rois , pool_size ):
138148 @torch .jit .script
139149 def script_fn (input , rois , pool_size ):
140- # type: (torch. Tensor, torch. Tensor, int) -> torch. Tensor
150+ # type: (Tensor, Tensor, int) -> Tensor
141151 return ops .ps_roi_pool (input , rois , pool_size , 1.0 )[0 ]
142152 return lambda x : script_fn (x , rois , pool_size )
143153
@@ -174,29 +184,35 @@ def get_slice(k, block):
174184 return y
175185
176186
def bilinear_interpolate(data, y, x, snap_border=False):
    """Sample 2-D `data` at fractional position (y, x) by bilinear interpolation.

    Corner points that fall outside `data` contribute zero.  When
    `snap_border` is True, coordinates lying within one pixel outside the
    valid range are first clamped onto the nearest border row/column.
    """
    height, width = data.shape

    if snap_border:
        # snap near-border coordinates onto the border itself
        if -1 < y <= 0:
            y = 0
        elif height - 1 <= y < height:
            y = height - 1

        if -1 < x <= 0:
            x = 0
        elif width - 1 <= x < width:
            x = width - 1

    # integer cell containing (y, x) and the fractional offsets inside it
    y0 = int(math.floor(y))
    x0 = int(math.floor(x))
    y1 = y0 + 1
    x1 = x0 + 1
    dy = y - y0
    dx = x - x0

    # accumulate the four corners; out-of-range corners are skipped
    # (same iteration order as a nested wx/wy loop, so float sums match)
    val = 0
    for wx, xx in ((1 - dx, x0), (dx, x1)):
        for wy, yy in ((1 - dy, y0), (dy, y1)):
            if 0 <= yy < height and 0 <= xx < width:
                val += wx * wy * data[yy, xx]
    return val
201217
202218
@@ -208,7 +224,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
208224 def get_script_fn (self , rois , pool_size ):
209225 @torch .jit .script
210226 def script_fn (input , rois , pool_size ):
211- # type: (torch. Tensor, torch. Tensor, int) -> torch. Tensor
227+ # type: (Tensor, Tensor, int) -> Tensor
212228 return ops .roi_align (input , rois , pool_size , 1.0 )[0 ]
213229 return lambda x : script_fn (x , rois , pool_size )
214230
@@ -242,12 +258,7 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r
242258 y = start_h + (iy + 0.5 ) * bin_h / grid_h
243259 for ix in range (0 , grid_w ):
244260 x = start_w + (ix + 0.5 ) * bin_w / grid_w
245- val += bilinear_interpolate (
246- in_data [batch_idx , channel , :, :].flatten (),
247- in_data .size (- 2 ),
248- in_data .size (- 1 ),
249- y , x
250- )
261+ val += bilinear_interpolate (in_data [batch_idx , channel , :, :], y , x , snap_border = True )
251262 val /= grid_h * grid_w
252263
253264 out_data [r , channel , i , j ] = val
@@ -262,7 +273,7 @@ def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwar
262273 def get_script_fn (self , rois , pool_size ):
263274 @torch .jit .script
264275 def script_fn (input , rois , pool_size ):
265- # type: (torch. Tensor, torch. Tensor, int) -> torch. Tensor
276+ # type: (Tensor, Tensor, int) -> Tensor
266277 return ops .ps_roi_align (input , rois , pool_size , 1.0 )[0 ]
267278 return lambda x : script_fn (x , rois , pool_size )
268279
@@ -298,12 +309,7 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, device, spatial_scale=1,
298309 y = start_h + (iy + 0.5 ) * bin_h / grid_h
299310 for ix in range (0 , grid_w ):
300311 x = start_w + (ix + 0.5 ) * bin_w / grid_w
301- val += bilinear_interpolate (
302- in_data [batch_idx , c_in , :, :].flatten (),
303- in_data .size (- 2 ),
304- in_data .size (- 1 ),
305- y , x
306- )
312+ val += bilinear_interpolate (in_data [batch_idx , c_in , :, :], y , x , snap_border = True )
307313 val /= grid_h * grid_w
308314
309315 out_data [r , c_out , i , j ] = val
@@ -367,5 +373,106 @@ def test_nms_cuda(self):
367373 self .assertTrue (torch .allclose (r_cpu , r_cuda .cpu ()), err_msg .format (iou ))
368374
369375
class DeformConvTester(OpTester, unittest.TestCase):
    """Checks deformable convolution against a slow pure-Python reference."""

    def expected_fn(self, x, offsets, weights, *args, stride=1, pad=0, dilation=1):
        """Reference deform-conv: for every output element, sample the input at
        offset-shifted kernel positions via bilinear interpolation.

        x:       (n_batches, n_in_channels, in_h, in_w)
        offsets: (n_batches, n_offset_grps * 2 * kh * kw, out_h, out_w)
        weights: (n_out_channels, in_c_per_weight_grp, kh, kw)
        Returns the (n_batches, n_out_channels, out_h, out_w) result tensor.
        """
        stride_h, stride_w = _pair(stride)
        pad_h, pad_w = _pair(pad)
        dil_h, dil_w = _pair(dilation)
        weights_h, weights_w = weights.shape[-2:]

        n_batches, n_in_channels, in_h, in_w = x.shape
        n_out_channels = weights.shape[0]

        # standard convolution output-size formula
        out_h = (in_h + 2 * pad_h - (dil_h * (weights_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weights_w - 1) + 1)) // stride_w + 1

        # each offset group carries an (dy, dx) pair per kernel position
        n_offset_grps = offsets.shape[1] // (2 * weights_h * weights_w)
        in_c_per_offset_grp = n_in_channels // n_offset_grps

        n_weight_grps = n_in_channels // weights.shape[1]
        in_c_per_weight_grp = weights.shape[1]
        out_c_per_weight_grp = n_out_channels // n_weight_grps

        out = torch.zeros(n_batches, n_out_channels, out_h, out_w, device=x.device, dtype=x.dtype)
        for b in range(n_batches):
            for c_out in range(n_out_channels):
                for i in range(out_h):
                    for j in range(out_w):
                        for di in range(weights_h):
                            for dj in range(weights_w):
                                for c in range(in_c_per_weight_grp):
                                    weight_grp = c_out // out_c_per_weight_grp
                                    c_in = weight_grp * in_c_per_weight_grp + c

                                    # index of the (dy, dx) pair for this channel's
                                    # offset group at kernel position (di, dj)
                                    offset_grp = c_in // in_c_per_offset_grp
                                    offset_idx = 2 * (offset_grp * (weights_h * weights_w) + di * weights_w + dj)

                                    pi = stride_h * i - pad_h + dil_h * di + offsets[b, offset_idx, i, j]
                                    pj = stride_w * j - pad_w + dil_w * dj + offsets[b, offset_idx + 1, i, j]

                                    out[b, c_out, i, j] += (weights[c_out, c, di, dj] *
                                                            bilinear_interpolate(x[b, c_in, :, :], pi, pj))
        return out

    def get_fn_args(self, device, contiguous):
        """Build a small (x, offset, weight, stride, pad, dilation) fixture;
        when `contiguous` is False the tensors are made non-contiguous via
        permute round-trips without changing their values."""
        batch_sz = 1
        n_in_channels = 6
        n_out_channels = 2
        n_weight_grps = 2
        n_offset_grps = 3

        stride = (2, 1)
        pad = (1, 0)
        dilation = (2, 1)

        stride_h, stride_w = stride
        pad_h, pad_w = pad
        dil_h, dil_w = dilation
        weight_h, weight_w = (3, 2)
        in_h, in_w = (5, 4)

        out_h = (in_h + 2 * pad_h - (dil_h * (weight_h - 1) + 1)) // stride_h + 1
        out_w = (in_w + 2 * pad_w - (dil_w * (weight_w - 1) + 1)) // stride_w + 1

        x = torch.rand(batch_sz, n_in_channels, in_h, in_w, device=device, dtype=self.dtype, requires_grad=True)

        offset = torch.randn(batch_sz, n_offset_grps * 2 * weight_h * weight_w, out_h, out_w,
                             device=device, dtype=self.dtype, requires_grad=True)

        weight = torch.randn(n_out_channels, n_in_channels // n_weight_grps, weight_h, weight_w,
                             device=device, dtype=self.dtype, requires_grad=True)

        if not contiguous:
            x = x.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
            offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
            weight = weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)

        return x, offset, weight, stride, pad, dilation

    def _test_forward(self, device, contiguous):
        """Compare the op's forward result against the reference implementation."""
        x, offset, weight, stride, pad, dilation = self.get_fn_args(device, contiguous)

        res = ops.DeformConv(stride=stride, pad=pad, dilation=dilation)(x, offset, weight)
        expected = self.expected_fn(x, offset, weight, stride=stride, pad=pad, dilation=dilation)

        # bug fix: the message previously did .format(x, res, expected) — three
        # args for two slots, printing x/res and silently dropping `expected`
        self.assertTrue(torch.allclose(res, expected),
                        '\nres:\n{}\nexpected:\n{}'.format(res, expected))

    def _test_backward(self, device, contiguous):
        """Gradcheck the eager op and its TorchScript-compiled counterpart."""
        x, offset, weight, stride, pad, dilation = self.get_fn_args(device, contiguous)

        def func(x_, offset_, weight_):
            return ops.deform_conv(x_, offset_, weight_, stride=stride, pad=pad, dilation=dilation)

        gradcheck(func, (x, offset, weight), nondet_tol=1e-5)

        @torch.jit.script
        def script_func(x_, offset_, weight_, stride_, pad_, dilation_):
            # type: (Tensor, Tensor, Tensor, Tuple[int, int], Tuple[int, int], Tuple[int, int]) -> Tensor
            return ops.deform_conv(x_, offset_, weight_, stride=stride_, pad=pad_, dilation=dilation_)

        gradcheck(lambda z, off, wei: script_func(z, off, wei, stride, pad, dilation),
                  (x, offset, weight), nondet_tol=1e-5)
475+
476+
# Run the whole suite when executed as a script.
if __name__ == '__main__':
    unittest.main()
0 commit comments