Skip to content

Commit 30e1f70

Browse files
committed
Fix use_mask logic
1 parent 3f69c76 commit 30e1f70

File tree

6 files changed, with 74 additions and 67 deletions.

torchvision/csrc/DeformConv.h

Lines changed: 25 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,8 @@ at::Tensor DeformConv2d_forward(
1919
const std::pair<int, int>& padding,
2020
const std::pair<int, int>& dilation,
2121
const int groups,
22-
const int offset_groups) {
22+
const int offset_groups,
23+
const bool use_mask) {
2324
if (input.is_cuda()) {
2425
#if defined(WITH_CUDA) || defined(WITH_HIP)
2526
return DeformConv2d_forward_cuda(
@@ -32,7 +33,8 @@ at::Tensor DeformConv2d_forward(
3233
padding,
3334
dilation,
3435
groups,
35-
offset_groups);
36+
offset_groups,
37+
use_mask);
3638
#else
3739
AT_ERROR("Not compiled with GPU support");
3840
#endif
@@ -47,7 +49,8 @@ at::Tensor DeformConv2d_forward(
4749
padding,
4850
dilation,
4951
groups,
50-
offset_groups);
52+
offset_groups,
53+
use_mask);
5154
}
5255

5356
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -62,7 +65,8 @@ DeformConv2d_backward(
6265
const std::pair<int, int>& padding,
6366
const std::pair<int, int>& dilation,
6467
const int groups,
65-
const int offset_groups) {
68+
const int offset_groups,
69+
const bool use_mask) {
6670
if (grad.is_cuda()) {
6771
#if defined(WITH_CUDA) || defined(WITH_HIP)
6872
return DeformConv2d_backward_cuda(
@@ -76,7 +80,8 @@ DeformConv2d_backward(
7680
padding,
7781
dilation,
7882
groups,
79-
offset_groups);
83+
offset_groups,
84+
use_mask);
8085
#else
8186
AT_ERROR("Not compiled with GPU support");
8287
#endif
@@ -92,7 +97,8 @@ DeformConv2d_backward(
9297
padding,
9398
dilation,
9499
groups,
95-
offset_groups);
100+
offset_groups,
101+
use_mask);
96102
}
97103

98104
class DeformConv2dFunction
@@ -112,7 +118,8 @@ class DeformConv2dFunction
112118
int64_t dilation_h,
113119
int64_t dilation_w,
114120
int64_t groups,
115-
int64_t offset_groups) {
121+
int64_t offset_groups,
122+
bool use_mask) {
116123
auto output = DeformConv2d_forward(
117124
input,
118125
weight,
@@ -123,7 +130,8 @@ class DeformConv2dFunction
123130
{pad_h, pad_w},
124131
{dilation_h, dilation_w},
125132
groups,
126-
offset_groups);
133+
offset_groups,
134+
use_mask);
127135

128136
ctx->save_for_backward({input, weight, offset, mask, bias});
129137
ctx->saved_data["stride_h"] = stride_h;
@@ -134,6 +142,7 @@ class DeformConv2dFunction
134142
ctx->saved_data["dilation_w"] = dilation_w;
135143
ctx->saved_data["groups"] = groups;
136144
ctx->saved_data["offset_groups"] = offset_groups;
145+
ctx->saved_data["use_mask"] = use_mask;
137146

138147
return {
139148
output,
@@ -158,6 +167,7 @@ class DeformConv2dFunction
158167
auto dilation_w = ctx->saved_data["dilation_w"].toInt();
159168
auto groups = ctx->saved_data["groups"].toInt();
160169
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
170+
auto use_mask = ctx->saved_data["use_mask"].toBool();
161171

162172
auto grads = DeformConv2d_backward(
163173
grad_output[0],
@@ -170,7 +180,8 @@ class DeformConv2dFunction
170180
{pad_h, pad_w},
171181
{dilation_h, dilation_w},
172182
groups,
173-
offset_groups);
183+
offset_groups,
184+
use_mask);
174185
auto grad_input = std::get<0>(grads);
175186
auto grad_weight = std::get<1>(grads);
176187
auto grad_offset = std::get<2>(grads);
@@ -191,6 +202,7 @@ class DeformConv2dFunction
191202
torch::autograd::Variable(),
192203
torch::autograd::Variable(),
193204
torch::autograd::Variable(),
205+
torch::autograd::Variable(),
194206
};
195207
}
196208
};
@@ -208,7 +220,8 @@ at::Tensor deform_conv2d(
208220
int64_t dilation_h,
209221
int64_t dilation_w,
210222
int64_t groups,
211-
int64_t offset_groups) {
223+
int64_t offset_groups,
224+
bool use_mask) {
212225
auto result = DeformConv2dFunction::apply(
213226
input,
214227
weight,
@@ -222,6 +235,7 @@ at::Tensor deform_conv2d(
222235
dilation_h,
223236
dilation_w,
224237
groups,
225-
offset_groups);
238+
offset_groups,
239+
use_mask);
226240
return result[0];
227241
}

torchvision/csrc/cpu/DeformConv_cpu.cpp

Lines changed: 19 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -211,8 +211,8 @@ static void deformable_im2col(
211211
int out_w,
212212
int parallel_imgs,
213213
int deformable_group,
214+
bool use_mask,
214215
at::Tensor data_col) {
215-
bool use_mask = data_mask.numel() != 0;
216216
int num_kernels = n_in_channels * out_h * out_w * parallel_imgs;
217217

218218
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
@@ -261,14 +261,13 @@ at::Tensor DeformConv2d_forward_cpu(
261261
std::pair<int, int> pad,
262262
std::pair<int, int> dilation,
263263
int n_weight_grps,
264-
int n_offset_grps) {
264+
int n_offset_grps,
265+
bool use_mask) {
265266
at::Tensor input = input_param;
266267
at::Tensor offset = offset_param;
267268
at::Tensor mask = mask_param;
268269
at::Tensor weight = weight_param;
269270

270-
bool use_mask = mask.numel() != 0;
271-
272271
TORCH_CHECK(input.ndimension() == 4);
273272
TORCH_CHECK(offset.ndimension() == 4);
274273
TORCH_CHECK(!use_mask || mask.ndimension() == 4);
@@ -442,6 +441,7 @@ at::Tensor DeformConv2d_forward_cpu(
442441
out_w,
443442
n_parallel_imgs,
444443
n_offset_grps,
444+
use_mask,
445445
columns);
446446

447447
columns = columns.view(
@@ -561,9 +561,8 @@ static void compute_grad_input(
561561
const int dilation_w,
562562
const int parallel_imgs,
563563
const int n_offset_grps,
564+
const bool use_mask,
564565
at::Tensor grad_im) {
565-
bool use_mask = mask.numel() != 0;
566-
567566
int out_h =
568567
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
569568
int out_w =
@@ -762,10 +761,9 @@ static void compute_grad_offset_and_mask(
762761
const int dilation_w,
763762
const int parallel_imgs,
764763
const int n_offset_grps,
764+
const bool use_mask,
765765
at::Tensor grad_offset,
766766
at::Tensor grad_mask) {
767-
bool use_mask = mask.numel() != 0;
768-
769767
int out_h =
770768
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
771769
int out_w =
@@ -815,9 +813,8 @@ deform_conv2d_backward_input_cpu(
815813
std::pair<int, int> dilation,
816814
int n_weight_grps,
817815
int n_offset_grps,
818-
int n_parallel_imgs) {
819-
bool use_mask = mask.numel() != 0;
820-
816+
int n_parallel_imgs,
817+
bool use_mask) {
821818
int batch_sz = input.size(0);
822819
int n_in_channels = input.size(1);
823820
int in_h = input.size(2);
@@ -927,6 +924,7 @@ deform_conv2d_backward_input_cpu(
927924
dil_w,
928925
n_parallel_imgs,
929926
n_offset_grps,
927+
use_mask,
930928
grad_offset[elt],
931929
grad_mask[elt]);
932930

@@ -947,6 +945,7 @@ deform_conv2d_backward_input_cpu(
947945
dil_w,
948946
n_parallel_imgs,
949947
n_offset_grps,
948+
use_mask,
950949
grad_input[elt]);
951950
}
952951

@@ -973,9 +972,8 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
973972
std::pair<int, int> dilation,
974973
int n_weight_grps,
975974
int n_offset_grps,
976-
int n_parallel_imgs) {
977-
bool use_mask = mask.numel() != 0;
978-
975+
int n_parallel_imgs,
976+
bool use_mask) {
979977
int batch_sz = input.size(0);
980978
int n_in_channels = input.size(1);
981979
int in_h = input.size(2);
@@ -1063,6 +1061,7 @@ static at::Tensor deform_conv2d_backward_parameters_cpu(
10631061
out_w,
10641062
n_parallel_imgs,
10651063
n_offset_grps,
1064+
use_mask,
10661065
columns);
10671066

10681067
for (int g = 0; g < n_weight_grps; g++) {
@@ -1094,7 +1093,8 @@ DeformConv2d_backward_cpu(
10941093
std::pair<int, int> pad,
10951094
std::pair<int, int> dilation,
10961095
int n_weight_grps,
1097-
int n_offset_grps) {
1096+
int n_offset_grps,
1097+
bool use_mask) {
10981098
const int batch_sz = input.size(0);
10991099
const int n_parallel_imgs =
11001100
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
@@ -1110,7 +1110,8 @@ DeformConv2d_backward_cpu(
11101110
dilation,
11111111
n_weight_grps,
11121112
n_offset_grps,
1113-
n_parallel_imgs);
1113+
n_parallel_imgs,
1114+
use_mask);
11141115

11151116
auto grad_input = std::get<0>(grad_input_and_offset_and_mask);
11161117
auto grad_offset = std::get<1>(grad_input_and_offset_and_mask);
@@ -1127,7 +1128,8 @@ DeformConv2d_backward_cpu(
11271128
dilation,
11281129
n_weight_grps,
11291130
n_offset_grps,
1130-
n_parallel_imgs);
1131+
n_parallel_imgs,
1132+
use_mask);
11311133

11321134
auto grad_bias = at::ones_like(bias) * grad_out.sum({0, 2, 3});
11331135

torchvision/csrc/cpu/vision_cpu.h

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,8 @@ VISION_API at::Tensor DeformConv2d_forward_cpu(
107107
std::pair<int, int> pad,
108108
std::pair<int, int> dilation,
109109
int groups,
110-
int deformable_groups);
110+
int deformable_groups,
111+
bool use_mask);
111112

112113
VISION_API std::
113114
tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -122,4 +123,5 @@ VISION_API std::
122123
std::pair<int, int> pad,
123124
std::pair<int, int> dilation,
124125
int groups,
125-
int deformable_groups);
126+
int deformable_groups,
127+
bool use_mask);

Comments (0): this commit has no comments.