@@ -25,9 +25,135 @@ PyMODINIT_FUNC PyInit__custom_ops(void) {
2525#endif
2626#endif
2727
28+ using torch::Tensor;
29+ using torch::autograd::AutogradContext;
30+ using torch::autograd::Variable;
31+ using torch::autograd::variable_list;
32+
33+ class ROIAlignFunction : public torch ::autograd::Function<ROIAlignFunction> {
34+ public:
35+ static variable_list forward (
36+ AutogradContext* ctx,
37+ Variable input,
38+ Variable rois,
39+ const double spatial_scale,
40+ const int64_t pooled_height,
41+ const int64_t pooled_width,
42+ const int64_t sampling_ratio) {
43+ ctx->saved_data [" spatial_scale" ] = spatial_scale;
44+ ctx->saved_data [" pooled_height" ] = pooled_height;
45+ ctx->saved_data [" pooled_width" ] = pooled_width;
46+ ctx->saved_data [" sampling_ratio" ] = sampling_ratio;
47+ ctx->saved_data [" input_shape" ] = input.sizes ();
48+ ctx->save_for_backward ({rois});
49+ auto result = ROIAlign_forward (
50+ input,
51+ rois,
52+ spatial_scale,
53+ pooled_height,
54+ pooled_width,
55+ sampling_ratio);
56+ return {result};
57+ }
58+
59+ static variable_list backward (
60+ AutogradContext* ctx,
61+ variable_list grad_output) {
62+ // Use data saved in forward
63+ auto saved = ctx->get_saved_variables ();
64+ auto rois = saved[0 ];
65+ auto input_shape = ctx->saved_data [" input_shape" ].toIntList ();
66+ auto grad_in = ROIAlign_backward (
67+ grad_output[0 ],
68+ rois,
69+ ctx->saved_data [" spatial_scale" ].toDouble (),
70+ ctx->saved_data [" pooled_height" ].toInt (),
71+ ctx->saved_data [" pooled_width" ].toInt (),
72+ input_shape[0 ],
73+ input_shape[1 ],
74+ input_shape[2 ],
75+ input_shape[3 ],
76+ ctx->saved_data [" sampling_ratio" ].toInt ());
77+ return {
78+ grad_in, Variable (), Variable (), Variable (), Variable (), Variable ()};
79+ }
80+ };
81+
82+ Tensor roi_align (
83+ const Tensor& input,
84+ const Tensor& rois,
85+ const double spatial_scale,
86+ const int64_t pooled_height,
87+ const int64_t pooled_width,
88+ const int64_t sampling_ratio) {
89+ return ROIAlignFunction::apply (
90+ input,
91+ rois,
92+ spatial_scale,
93+ pooled_height,
94+ pooled_width,
95+ sampling_ratio)[0 ];
96+ }
97+
98+ class ROIPoolFunction : public torch ::autograd::Function<ROIPoolFunction> {
99+ public:
100+ static variable_list forward (
101+ AutogradContext* ctx,
102+ Variable input,
103+ Variable rois,
104+ const double spatial_scale,
105+ const int64_t pooled_height,
106+ const int64_t pooled_width) {
107+ ctx->saved_data [" spatial_scale" ] = spatial_scale;
108+ ctx->saved_data [" pooled_height" ] = pooled_height;
109+ ctx->saved_data [" pooled_width" ] = pooled_width;
110+ ctx->saved_data [" input_shape" ] = input.sizes ();
111+ auto result = ROIPool_forward (
112+ input, rois, spatial_scale, pooled_height, pooled_width);
113+ auto output = std::get<0 >(result);
114+ auto argmax = std::get<1 >(result);
115+ ctx->save_for_backward ({rois, argmax});
116+ ctx->mark_non_differentiable ({argmax});
117+ return {output, argmax};
118+ }
119+
120+ static variable_list backward (
121+ AutogradContext* ctx,
122+ variable_list grad_output) {
123+ // Use data saved in forward
124+ auto saved = ctx->get_saved_variables ();
125+ auto rois = saved[0 ];
126+ auto argmax = saved[1 ];
127+ auto input_shape = ctx->saved_data [" input_shape" ].toIntList ();
128+ auto grad_in = ROIPool_backward (
129+ grad_output[0 ],
130+ rois,
131+ argmax,
132+ ctx->saved_data [" spatial_scale" ].toDouble (),
133+ ctx->saved_data [" pooled_height" ].toInt (),
134+ ctx->saved_data [" pooled_width" ].toInt (),
135+ input_shape[0 ],
136+ input_shape[1 ],
137+ input_shape[2 ],
138+ input_shape[3 ]);
139+ return {grad_in, Variable (), Variable (), Variable (), Variable ()};
140+ }
141+ };
142+
143+ std::tuple<Tensor, Tensor> roi_pool (
144+ const Tensor& input,
145+ const Tensor& rois,
146+ const double spatial_scale,
147+ const int64_t pooled_height,
148+ const int64_t pooled_width) {
149+ auto result = ROIPoolFunction::apply (
150+ input, rois, spatial_scale, pooled_height, pooled_width);
151+ return std::tuple<Tensor, Tensor>(result[0 ], result[1 ]);
152+ }
153+
// Registers the torchvision custom operators (nms, roi_align, roi_pool) via
// torch::RegisterOperators.
// NOTE(review): this span is a diff hunk. The `-` lines bind roi_align and
// roi_pool to ROIAlign_forward / ROIPool_forward; the `+` lines replace those
// bindings with the roi_align / roi_pool wrapper functions.
28154static auto registry =
29155 torch::RegisterOperators ()
30156 .op(" torchvision::nms" , &nms)
31157 .op(" torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> Tensor" ,
32- &ROIAlign_forward )
33- .op(" torchvision::roi_pool" , &ROIPool_forward );
158+ &roi_align )
159+ .op(" torchvision::roi_pool" , &roi_pool );
0 commit comments