@@ -742,7 +742,7 @@ struct HelperInterpBase {
742742 }
743743
744744 template <typename scalar_t , typename aa_filter_fn_t >
745- static inline void _compute_weights_aa (
745+ static inline scalar_t _compute_weights_aa (
746746 const int64_t i, const int64_t input_size, const scalar_t scale, const scalar_t support,
747747 scalar_t * wt_ptr, const int64_t max_interp_size, aa_filter_fn_t filter_fn,
748748 int64_t & xmin, int64_t & xsize, bool antialias, double align_corners_delta
@@ -764,14 +764,19 @@ struct HelperInterpBase {
764764 wt_ptr[j] = w;
765765 total_w += w;
766766 }
767- for (j = 0 ; j < xsize; j++) {
768- if (total_w != 0.0 ) {
767+
768+ scalar_t wt_max = 0.0 ;
769+ if (total_w != 0.0 ) {
770+ for (j = 0 ; j < xsize; j++) {
769771 wt_ptr[j] /= total_w;
772+ wt_max = std::max (wt_max, wt_ptr[j]);
770773 }
771774 }
775+
772776 for (; j < max_interp_size; j++) {
773777 wt_ptr[j] = static_cast <scalar_t >(0.0 );
774778 }
779+ return wt_max;
775780 }
776781
777782 // Note [ Support for antialias=False as a subcase of antilias=True ]
@@ -785,7 +790,7 @@ struct HelperInterpBase {
785790 // indices, but this can be optimized further when aa=False since we know
786791 // their actual dimensions.
787792 template <typename scalar_t , typename aa_filter_fn_t , int weight_index_stride=sizeof (scalar_t )>
788- static inline std::tuple<std::vector<Tensor>, int > _compute_indices_weights_aa (
793+ static inline std::tuple<std::vector<Tensor>, int , scalar_t > _compute_indices_weights_aa (
789794 int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
790795 int64_t reshape_dim, scalar_t scale,
791796 int interp_size, aa_filter_fn_t aa_filter_fn, bool antialias, double align_corners_delta
@@ -834,10 +839,10 @@ struct HelperInterpBase {
834839 scalar_t * wt_ptr = output[3 ].data_ptr <scalar_t >();
835840 int64_t * wt_idx_ptr = output[4 ].data_ptr <int64_t >();
836841
837- int64_t xmin, xmax;
838-
842+ scalar_t wt_max = 0.0 ;
839843 for (const auto i : c10::irange (output_size)) {
840- HelperInterpBase::_compute_weights_aa (
844+ int64_t xmin, xmax;
845+ auto wt_max_i = HelperInterpBase::_compute_weights_aa (
841846 i,
842847 input_size,
843848 scale,
@@ -850,12 +855,14 @@ struct HelperInterpBase {
850855 antialias,
851856 align_corners_delta);
852857
858+ wt_max = std::max (wt_max, wt_max_i);
859+
853860 idx_ptr_xmin[i] = xmin * stride;
854861 idx_ptr_size[i] = xmax;
855862 idx_ptr_stride[i] = stride;
856863 wt_idx_ptr[i] = i * max_interp_size * weight_index_stride;
857864 }
858- return {output, max_interp_size};
865+ return {output, max_interp_size, wt_max };
859866 }
860867
861868 /*
@@ -911,25 +918,17 @@ struct HelperInterpBase {
911918
912919 std::vector<Tensor> indices_weights;
913920 auto align_corners_delta = (align_corners && !antialias) ? 0.5 : 0.0 ;
914- std::tie (indices_weights, interp_size) = HelperInterpBase::_compute_indices_weights_aa<double , aa_filter_fn_t , sizeof (int16_t )>(
921+ double wt_max;
922+ std::tie (indices_weights, interp_size, wt_max) = HelperInterpBase::_compute_indices_weights_aa<double , aa_filter_fn_t , sizeof (int16_t )>(
915923 input_size, output_size, stride, ndims, reshape_dim, scale, interp_size, aa_filter_fn, antialias, align_corners_delta);
916924
917925 // Rescale float weights to int16 and compute weights precision
918926 auto weights_f64 = indices_weights[3 ];
919927 double * data_f64 = weights_f64.data_ptr <double >();
920- int64_t weights_f64_size = output_size * interp_size;
921- // can't use weights_f64.max() here as tensor is restrided
922- double w_max = data_f64[0 ];
923- for (const auto i : c10::irange (weights_f64_size)) {
924- double v = data_f64[i];
925- if (w_max < v) {
926- w_max = v;
927- }
928- }
929928
930929 unsigned int weights_precision = 0 ;
931- for (weights_precision = 0 ; weights_precision < 22 ; weights_precision += 1 ) {
932- int next_value = (int ) (0.5 + w_max * (1 << (weights_precision + 1 )));
930+ for (weights_precision = 0 ; weights_precision < 22 ; ++weights_precision ) {
931+ int next_value = (int ) (0.5 + wt_max * (1 << (weights_precision + 1 )));
933932 if (next_value >= (1 << 15 ))
934933 break ;
935934 }
@@ -939,8 +938,7 @@ struct HelperInterpBase {
939938 auto aligned_interp_size = interp_size;
940939
941940 if (align_i32) {
942- // We should respect int32 alignment as
943- // we will load data as int32 with AVX2
941+ // We should respect int32 alignment as we will load int16 data as int32
944942 // See ImagingResampleHorizontalConvolution8u4x, mmk0 = _mm256_set1_epi32(*(int32_t*)&k[x]);
945943 // compute aligned_interp_size = nearest pair value to interp_size
946944 while (aligned_interp_size % sizeof (int32_t ) != 0 ) {
@@ -952,20 +950,13 @@ struct HelperInterpBase {
952950
953951 for (const auto j : c10::irange (output_size)) {
954952 for (const auto k : c10::irange (interp_size)) {
955- double v = data_f64[j * interp_size + k];
956- if (v < 0 ) {
957- data_i16[j * aligned_interp_size + k] = (int ) (-0.5 + v * (1 << weights_precision));
958- } else {
959- data_i16[j * aligned_interp_size + k] = (int ) (0.5 + v * (1 << weights_precision));
960- }
953+ double v = data_f64[j * interp_size + k] * (1 << weights_precision);
954+ data_i16[j * aligned_interp_size + k] = (v < 0 ) ? (int ) (-0.5 + v) : (int ) (0.5 + v);
961955 }
962956 }
963957
964958 return {indices_weights, aligned_interp_size, weights_precision};
965959 }
966-
967-
968-
969960};
970961
971962struct HelperInterpNearest : public HelperInterpBase {
@@ -1175,8 +1166,9 @@ struct HelperInterpLinear : public HelperInterpBase {
11751166
11761167 auto interp_size = HelperInterpLinear::interp_size;
11771168 int unused;
1169+ scalar_t unused_2;
11781170
1179- std::tie (indices_weights, unused) = HelperInterpLinear::_compute_indices_weights_aa<scalar_t >(
1171+ std::tie (indices_weights, unused, unused_2 ) = HelperInterpLinear::_compute_indices_weights_aa<scalar_t >(
11801172 input_size,
11811173 output_size,
11821174 stride,
@@ -1307,8 +1299,9 @@ struct HelperInterpCubic : public HelperInterpBase {
13071299
13081300 auto interp_size = HelperInterpCubic::interp_size;
13091301 int unused;
1302+ scalar_t unused_2;
13101303
1311- std::tie (indices_weights, unused) = HelperInterpCubic::_compute_indices_weights_aa<scalar_t >(
1304+ std::tie (indices_weights, unused, unused_2 ) = HelperInterpCubic::_compute_indices_weights_aa<scalar_t >(
13121305 input_size,
13131306 output_size,
13141307 stride,
0 commit comments