77#endif
88
99at::Tensor DeformConv2d_forward (
10- const Tensor& input,
11- const Tensor& weight,
12- const Tensor& offset,
13- const Tensor& bias,
10+ const at:: Tensor& input,
11+ const at:: Tensor& weight,
12+ const at:: Tensor& offset,
13+ const at:: Tensor& bias,
1414 const std::pair<int , int >& stride,
1515 const std::pair<int , int >& padding,
1616 const std::pair<int , int >& dilation,
17- const int groups, const int offset_groups) {
17+ const int groups,
18+ const int offset_groups) {
1819 if (input.type ().is_cuda ()) {
1920#ifdef WITH_CUDA
20- return DeformConv2d_forward_cuda (input.contiguous (), weight.contiguous (), offset.contiguous (),
21- bias.contiguous (), stride, padding, dilation, groups, offset_groups);
21+ return DeformConv2d_forward_cuda (
22+ input.contiguous (),
23+ weight.contiguous (),
24+ offset.contiguous (),
25+ bias.contiguous (),
26+ stride,
27+ padding,
28+ dilation,
29+ groups,
30+ offset_groups);
2231#else
2332 AT_ERROR (" Not compiled with GPU support" );
2433#endif
2534 }
26- return DeformConv2d_forward_cpu (input.contiguous (), weight.contiguous (), offset.contiguous (),
27- bias.contiguous (), stride, padding, dilation, groups, offset_groups);
35+ return DeformConv2d_forward_cpu (
36+ input.contiguous (),
37+ weight.contiguous (),
38+ offset.contiguous (),
39+ bias.contiguous (),
40+ stride,
41+ padding,
42+ dilation,
43+ groups,
44+ offset_groups);
2845}
2946
3047std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> DeformConv2d_backward (
3148 const at::Tensor& grad,
32- const Tensor& input,
33- const Tensor& weight,
34- const Tensor& offset,
35- const Tensor& bias,
49+ const at:: Tensor& input,
50+ const at:: Tensor& weight,
51+ const at:: Tensor& offset,
52+ const at:: Tensor& bias,
3653 const std::pair<int , int >& stride,
3754 const std::pair<int , int >& padding,
3855 const std::pair<int , int >& dilation,
3956 const int groups,
4057 const int offset_groups) {
4158 if (grad.type ().is_cuda ()) {
4259#ifdef WITH_CUDA
43- return DeformConv2d_backward_cuda (grad.contiguous (), input.contiguous (), weight.contiguous (), offset.contiguous (),
44- bias.contiguous (), stride, padding, dilation, groups, offset_groups);
60+ return DeformConv2d_backward_cuda (
61+ grad.contiguous (),
62+ input.contiguous (),
63+ weight.contiguous (),
64+ offset.contiguous (),
65+ bias.contiguous (),
66+ stride,
67+ padding,
68+ dilation,
69+ groups,
70+ offset_groups);
4571#else
4672 AT_ERROR (" Not compiled with GPU support" );
4773#endif
4874 }
49- return DeformConv2d_backward_cpu (grad.contiguous (), input.contiguous (), weight.contiguous (), offset.contiguous (),
50- bias.contiguous (), stride, padding, dilation, groups, offset_groups);
75+ return DeformConv2d_backward_cpu (
76+ grad.contiguous (),
77+ input.contiguous (),
78+ weight.contiguous (),
79+ offset.contiguous (),
80+ bias.contiguous (),
81+ stride,
82+ padding,
83+ dilation,
84+ groups,
85+ offset_groups);
5186}
5287
5388using namespace at ;
@@ -56,25 +91,33 @@ using torch::autograd::AutogradContext;
5691using torch::autograd::Variable;
5792using torch::autograd::variable_list;
5893
59- class DeformConv2dFunction : public torch ::autograd::Function<DeformConv2dFunction> {
94+ class DeformConv2dFunction
95+ : public torch::autograd::Function<DeformConv2dFunction> {
6096 public:
6197 static variable_list forward (
6298 AutogradContext* ctx,
6399 Variable input,
64100 Variable weight,
65101 Variable offset,
66102 Variable bias,
67- int64_t stride_h, int64_t stride_w,
68- int64_t pad_h, int64_t pad_w,
69- int64_t dilation_h, int64_t dilation_w,
103+ int64_t stride_h,
104+ int64_t stride_w,
105+ int64_t pad_h,
106+ int64_t pad_w,
107+ int64_t dilation_h,
108+ int64_t dilation_w,
70109 int64_t groups,
71110 int64_t offset_groups) {
72111 auto output = DeformConv2d_forward (
73- input, weight, offset, bias,
112+ input,
113+ weight,
114+ offset,
115+ bias,
74116 {stride_h, stride_w},
75117 {pad_h, pad_w},
76118 {dilation_h, dilation_w},
77- groups, offset_groups);
119+ groups,
120+ offset_groups);
78121
79122 ctx->save_for_backward ({input, weight, offset, bias});
80123 ctx->saved_data [" stride_h" ] = stride_h;
@@ -86,7 +129,9 @@ class DeformConv2dFunction : public torch::autograd::Function<DeformConv2dFuncti
86129 ctx->saved_data [" groups" ] = groups;
87130 ctx->saved_data [" offset_groups" ] = offset_groups;
88131
89- return {output,};
132+ return {
133+ output,
134+ };
90135 }
91136
92137 static variable_list backward (
@@ -107,34 +152,64 @@ class DeformConv2dFunction : public torch::autograd::Function<DeformConv2dFuncti
107152 auto groups = ctx->saved_data [" groups" ].toInt ();
108153 auto offset_groups = ctx->saved_data [" offset_groups" ].toInt ();
109154
110- auto grads = DeformConv2d_backward (grad_output[0 ],
111- input, weight, offset, bias,
155+ auto grads = DeformConv2d_backward (
156+ grad_output[0 ],
157+ input,
158+ weight,
159+ offset,
160+ bias,
112161 {stride_h, stride_w},
113162 {pad_h, pad_w},
114163 {dilation_h, dilation_w},
115- groups, offset_groups);
164+ groups,
165+ offset_groups);
116166 auto grad_input = std::get<0 >(grads);
117167 auto grad_weight = std::get<1 >(grads);
118168 auto grad_offset = std::get<2 >(grads);
119169 auto grad_bias = std::get<3 >(grads);
120170
121- return {grad_input, grad_weight, grad_offset,
122- grad_bias, Variable (), Variable (),
123- Variable (), Variable (), Variable (),
124- Variable (), Variable (), Variable (),};
171+ return {
172+ grad_input,
173+ grad_weight,
174+ grad_offset,
175+ grad_bias,
176+ Variable (),
177+ Variable (),
178+ Variable (),
179+ Variable (),
180+ Variable (),
181+ Variable (),
182+ Variable (),
183+ Variable (),
184+ };
125185 }
126186};
127187
128- Tensor deform_conv2d (
129- const Tensor& input,
130- const Tensor& weight,
131- const Tensor& offset,
132- const Tensor& bias,
133- int64_t stride_h, int64_t stride_w,
134- int64_t pad_h, int64_t pad_w,
135- int64_t dilation_h, int64_t dilation_w,
136- int64_t groups, int64_t offset_groups) {
137- auto result = DeformConv2dFunction::apply (input, weight, offset, bias, stride_h, stride_w, pad_h, pad_w,
138- dilation_h, dilation_w, groups, offset_groups);
188+ at::Tensor deform_conv2d (
189+ const at::Tensor& input,
190+ const at::Tensor& weight,
191+ const at::Tensor& offset,
192+ const at::Tensor& bias,
193+ int64_t stride_h,
194+ int64_t stride_w,
195+ int64_t pad_h,
196+ int64_t pad_w,
197+ int64_t dilation_h,
198+ int64_t dilation_w,
199+ int64_t groups,
200+ int64_t offset_groups) {
201+ auto result = DeformConv2dFunction::apply (
202+ input,
203+ weight,
204+ offset,
205+ bias,
206+ stride_h,
207+ stride_w,
208+ pad_h,
209+ pad_w,
210+ dilation_h,
211+ dilation_w,
212+ groups,
213+ offset_groups);
139214 return result[0 ];
140215}
0 commit comments