Skip to content

Commit 0798f54

Browse files
committed
Code update for vectorized interpolate cpu uint8
- code style update - use idx_ptr_xmin/idx_ptr_size instead of bounds - compute wt_max inside _compute_indices_weights_aa (no significant overhead) - added comments and explanations - renamed xmin/xmax into ids_min, ids_size ghstack-source-id: 8761e1f Pull Request resolved: pytorch#96847
1 parent 397fb27 commit 0798f54

File tree

2 files changed

+583
-363
lines changed

2 files changed

+583
-363
lines changed

aten/src/ATen/native/cpu/UpSampleKernel.cpp

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,7 @@ struct HelperInterpBase {
742742
}
743743

744744
template <typename scalar_t, typename aa_filter_fn_t>
745-
static inline void _compute_weights_aa(
745+
static inline scalar_t _compute_weights_aa(
746746
const int64_t i, const int64_t input_size, const scalar_t scale, const scalar_t support,
747747
scalar_t* wt_ptr, const int64_t max_interp_size, aa_filter_fn_t filter_fn,
748748
int64_t& xmin, int64_t& xsize, bool antialias, double align_corners_delta
@@ -764,14 +764,19 @@ struct HelperInterpBase {
764764
wt_ptr[j] = w;
765765
total_w += w;
766766
}
767-
for (j = 0; j < xsize; j++) {
768-
if (total_w != 0.0) {
767+
768+
scalar_t wt_max = 0.0;
769+
if (total_w != 0.0) {
770+
for (j = 0; j < xsize; j++) {
769771
wt_ptr[j] /= total_w;
772+
wt_max = std::max(wt_max, wt_ptr[j]);
770773
}
771774
}
775+
772776
for (; j < max_interp_size; j++) {
773777
wt_ptr[j] = static_cast<scalar_t>(0.0);
774778
}
779+
return wt_max;
775780
}
776781

777782
// Note [ Support for antialias=False as a subcase of antilias=True ]
@@ -785,7 +790,7 @@ struct HelperInterpBase {
785790
// indices, but this can be optimized further when aa=False since we know
786791
// their actual dimensions.
787792
template <typename scalar_t, typename aa_filter_fn_t, int weight_index_stride=sizeof(scalar_t)>
788-
static inline std::tuple<std::vector<Tensor>, int> _compute_indices_weights_aa(
793+
static inline std::tuple<std::vector<Tensor>, int, scalar_t> _compute_indices_weights_aa(
789794
int64_t input_size, int64_t output_size, int64_t stride, int64_t ndims,
790795
int64_t reshape_dim, scalar_t scale,
791796
int interp_size, aa_filter_fn_t aa_filter_fn, bool antialias, double align_corners_delta
@@ -834,10 +839,10 @@ struct HelperInterpBase {
834839
scalar_t* wt_ptr = output[3].data_ptr<scalar_t>();
835840
int64_t* wt_idx_ptr = output[4].data_ptr<int64_t>();
836841

837-
int64_t xmin, xmax;
838-
842+
scalar_t wt_max = 0.0;
839843
for (const auto i : c10::irange(output_size)) {
840-
HelperInterpBase::_compute_weights_aa(
844+
int64_t xmin, xmax;
845+
auto wt_max_i = HelperInterpBase::_compute_weights_aa(
841846
i,
842847
input_size,
843848
scale,
@@ -850,12 +855,14 @@ struct HelperInterpBase {
850855
antialias,
851856
align_corners_delta);
852857

858+
wt_max = std::max(wt_max, wt_max_i);
859+
853860
idx_ptr_xmin[i] = xmin * stride;
854861
idx_ptr_size[i] = xmax;
855862
idx_ptr_stride[i] = stride;
856863
wt_idx_ptr[i] = i * max_interp_size * weight_index_stride;
857864
}
858-
return {output, max_interp_size};
865+
return {output, max_interp_size, wt_max};
859866
}
860867

861868
/*
@@ -911,25 +918,17 @@ struct HelperInterpBase {
911918

912919
std::vector<Tensor> indices_weights;
913920
auto align_corners_delta = (align_corners && !antialias) ? 0.5 : 0.0;
914-
std::tie(indices_weights, interp_size) = HelperInterpBase::_compute_indices_weights_aa<double, aa_filter_fn_t, sizeof(int16_t)>(
921+
double wt_max;
922+
std::tie(indices_weights, interp_size, wt_max) = HelperInterpBase::_compute_indices_weights_aa<double, aa_filter_fn_t, sizeof(int16_t)>(
915923
input_size, output_size, stride, ndims, reshape_dim, scale, interp_size, aa_filter_fn, antialias, align_corners_delta);
916924

917925
// Rescale float weights to int16 and compute weights precision
918926
auto weights_f64 = indices_weights[3];
919927
double * data_f64 = weights_f64.data_ptr<double>();
920-
int64_t weights_f64_size = output_size * interp_size;
921-
// can't use weights_f64.max() here as tensor is restrided
922-
double w_max = data_f64[0];
923-
for (const auto i : c10::irange(weights_f64_size)) {
924-
double v = data_f64[i];
925-
if (w_max < v) {
926-
w_max = v;
927-
}
928-
}
929928

930929
unsigned int weights_precision = 0;
931-
for (weights_precision = 0; weights_precision < 22; weights_precision += 1) {
932-
int next_value = (int) (0.5 + w_max * (1 << (weights_precision + 1)));
930+
for (weights_precision = 0; weights_precision < 22; ++weights_precision) {
931+
int next_value = (int) (0.5 + wt_max * (1 << (weights_precision + 1)));
933932
if (next_value >= (1 << 15))
934933
break;
935934
}
@@ -939,8 +938,7 @@ struct HelperInterpBase {
939938
auto aligned_interp_size = interp_size;
940939

941940
if (align_i32) {
942-
// We should respect int32 alignment as
943-
// we will load data as int32 with AVX2
941+
// We should respect int32 alignment as we will load int16 data as int32
944942
// See ImagingResampleHorizontalConvolution8u4x, mmk0 = _mm256_set1_epi32(*(int32_t*)&k[x]);
945943
// compute aligned_interp_size = nearest pair value to interp_size
946944
while (aligned_interp_size % sizeof(int32_t) != 0) {
@@ -952,20 +950,13 @@ struct HelperInterpBase {
952950

953951
for (const auto j : c10::irange(output_size)) {
954952
for (const auto k : c10::irange(interp_size)) {
955-
double v = data_f64[j * interp_size + k];
956-
if (v < 0) {
957-
data_i16[j * aligned_interp_size + k] = (int) (-0.5 + v * (1 << weights_precision));
958-
} else {
959-
data_i16[j * aligned_interp_size + k] = (int) (0.5 + v * (1 << weights_precision));
960-
}
953+
double v = data_f64[j * interp_size + k] * (1 << weights_precision);
954+
data_i16[j * aligned_interp_size + k] = (v < 0) ? (int) (-0.5 + v) : (int) (0.5 + v);
961955
}
962956
}
963957

964958
return {indices_weights, aligned_interp_size, weights_precision};
965959
}
966-
967-
968-
969960
};
970961

971962
struct HelperInterpNearest : public HelperInterpBase {
@@ -1175,8 +1166,9 @@ struct HelperInterpLinear : public HelperInterpBase {
11751166

11761167
auto interp_size = HelperInterpLinear::interp_size;
11771168
int unused;
1169+
scalar_t unused_2;
11781170

1179-
std::tie(indices_weights, unused) = HelperInterpLinear::_compute_indices_weights_aa<scalar_t>(
1171+
std::tie(indices_weights, unused, unused_2) = HelperInterpLinear::_compute_indices_weights_aa<scalar_t>(
11801172
input_size,
11811173
output_size,
11821174
stride,
@@ -1307,8 +1299,9 @@ struct HelperInterpCubic : public HelperInterpBase {
13071299

13081300
auto interp_size = HelperInterpCubic::interp_size;
13091301
int unused;
1302+
scalar_t unused_2;
13101303

1311-
std::tie(indices_weights, unused) = HelperInterpCubic::_compute_indices_weights_aa<scalar_t>(
1304+
std::tie(indices_weights, unused, unused_2) = HelperInterpCubic::_compute_indices_weights_aa<scalar_t>(
13121305
input_size,
13131306
output_size,
13141307
stride,

0 commit comments

Comments
 (0)