Skip to content

Commit 2675b8d

Browse files
authored
Merge pull request #4705 from folkertdev/avx512-adler32
Add avx512 permute and pmaddbw
2 parents 3a89419 + 51f7643 commit 2675b8d

File tree

5 files changed

+193
-80
lines changed

5 files changed

+193
-80
lines changed

src/shims/x86/avx2.rs

Lines changed: 4 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rustc_target::callconv::FnAbi;
66

77
use super::{
88
ShiftOp, horizontal_bin_op, mask_load, mask_store, mpsadbw, packssdw, packsswb, packusdw,
9-
packuswb, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd,
9+
packuswb, permute, pmaddbw, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd,
1010
};
1111
use crate::*;
1212

@@ -102,39 +102,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
102102
}
103103
}
104104
// Used to implement the _mm256_maddubs_epi16 function.
105-
// Multiplies packed 8-bit unsigned integers from `left` and packed
106-
// signed 8-bit integers from `right` into 16-bit signed integers. Then,
107-
// the saturating sum of the products with indices `2*i` and `2*i+1`
108-
// produces the output at index `i`.
109105
"pmadd.ub.sw" => {
110106
let [left, right] =
111107
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
112108

113-
let (left, left_len) = this.project_to_simd(left)?;
114-
let (right, right_len) = this.project_to_simd(right)?;
115-
let (dest, dest_len) = this.project_to_simd(dest)?;
116-
117-
assert_eq!(left_len, right_len);
118-
assert_eq!(dest_len.strict_mul(2), left_len);
119-
120-
for i in 0..dest_len {
121-
let j1 = i.strict_mul(2);
122-
let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_u8()?;
123-
let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i8()?;
124-
125-
let j2 = j1.strict_add(1);
126-
let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_u8()?;
127-
let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i8()?;
128-
129-
let dest = this.project_index(&dest, i)?;
130-
131-
// Multiplication of a u8 and an i8 into an i16 cannot overflow.
132-
let mul1 = i16::from(left1).strict_mul(right1.into());
133-
let mul2 = i16::from(left2).strict_mul(right2.into());
134-
let res = mul1.saturating_add(mul2);
135-
136-
this.write_scalar(Scalar::from_i16(res), &dest)?;
137-
}
109+
pmaddbw(this, left, right, dest)?;
138110
}
139111
// Used to implement the _mm_maskload_epi32, _mm_maskload_epi64,
140112
// _mm256_maskload_epi32 and _mm256_maskload_epi64 functions.
@@ -217,28 +189,12 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
217189

218190
packusdw(this, left, right, dest)?;
219191
}
220-
// Used to implement the _mm256_permutevar8x32_epi32 and
221-
// _mm256_permutevar8x32_ps function.
222-
// Shuffles `left` using the three low bits of each element of `right`
223-
// as indices.
192+
// Used to implement _mm256_permutevar8x32_epi32 and _mm256_permutevar8x32_ps.
224193
"permd" | "permps" => {
225194
let [left, right] =
226195
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
227196

228-
let (left, left_len) = this.project_to_simd(left)?;
229-
let (right, right_len) = this.project_to_simd(right)?;
230-
let (dest, dest_len) = this.project_to_simd(dest)?;
231-
232-
assert_eq!(dest_len, left_len);
233-
assert_eq!(dest_len, right_len);
234-
235-
for i in 0..dest_len {
236-
let dest = this.project_index(&dest, i)?;
237-
let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u32()?;
238-
let left = this.project_index(&left, (right & 0b111).into())?;
239-
240-
this.copy_op(&left, &dest)?;
241-
}
197+
permute(this, left, right, dest)?;
242198
}
243199
// Used to implement the _mm256_sad_epu8 function.
244200
"psad.bw" => {

src/shims/x86/avx512.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use rustc_middle::ty::Ty;
33
use rustc_span::Symbol;
44
use rustc_target::callconv::FnAbi;
55

6-
use super::psadbw;
6+
use super::{permute, pmaddbw, psadbw};
77
use crate::*;
88

99
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
@@ -88,6 +88,20 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
8888

8989
psadbw(this, left, right, dest)?
9090
}
91+
// Used to implement the _mm512_maddubs_epi16 function.
92+
"pmaddubs.w.512" => {
93+
let [left, right] =
94+
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
95+
96+
pmaddbw(this, left, right, dest)?;
97+
}
98+
// Used to implement the _mm512_permutexvar_epi32 function.
99+
"permvar.si.512" => {
100+
let [left, right] =
101+
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
102+
103+
permute(this, left, right, dest)?;
104+
}
91105
_ => return interp_ok(EmulateItemResult::NotSupported),
92106
}
93107
interp_ok(EmulateItemResult::NeedsReturn)

src/shims/x86/mod.rs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,6 +1086,90 @@ fn psadbw<'tcx>(
10861086
interp_ok(())
10871087
}
10881088

1089+
/// Multiplies packed 8-bit unsigned integers from `left` and packed
1090+
/// signed 8-bit integers from `right` into 16-bit signed integers. Then,
1091+
/// the saturating sum of the products with indices `2*i` and `2*i+1`
1092+
/// produces the output at index `i`.
1093+
///
1094+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_maddubs_epi16>
1095+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_maddubs_epi16>
1096+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_maddubs_epi16>
1097+
fn pmaddbw<'tcx>(
1098+
ecx: &mut crate::MiriInterpCx<'tcx>,
1099+
left: &OpTy<'tcx>,
1100+
right: &OpTy<'tcx>,
1101+
dest: &MPlaceTy<'tcx>,
1102+
) -> InterpResult<'tcx, ()> {
1103+
let (left, left_len) = ecx.project_to_simd(left)?;
1104+
let (right, right_len) = ecx.project_to_simd(right)?;
1105+
let (dest, dest_len) = ecx.project_to_simd(dest)?;
1106+
1107+
// fn pmaddubsw128(a: u8x16, b: i8x16) -> i16x8;
1108+
// fn pmaddubsw( a: u8x32, b: i8x32) -> i16x16;
1109+
// fn vpmaddubsw( a: u8x64, b: i8x64) -> i16x32;
1110+
assert_eq!(left_len, right_len);
1111+
assert_eq!(dest_len.strict_mul(2), left_len);
1112+
1113+
for i in 0..dest_len {
1114+
let j1 = i.strict_mul(2);
1115+
let left1 = ecx.read_scalar(&ecx.project_index(&left, j1)?)?.to_u8()?;
1116+
let right1 = ecx.read_scalar(&ecx.project_index(&right, j1)?)?.to_i8()?;
1117+
1118+
let j2 = j1.strict_add(1);
1119+
let left2 = ecx.read_scalar(&ecx.project_index(&left, j2)?)?.to_u8()?;
1120+
let right2 = ecx.read_scalar(&ecx.project_index(&right, j2)?)?.to_i8()?;
1121+
1122+
let dest = ecx.project_index(&dest, i)?;
1123+
1124+
// Multiplication of a u8 and an i8 into an i16 cannot overflow.
1125+
let mul1 = i16::from(left1).strict_mul(right1.into());
1126+
let mul2 = i16::from(left2).strict_mul(right2.into());
1127+
let res = mul1.saturating_add(mul2);
1128+
1129+
ecx.write_scalar(Scalar::from_i16(res), &dest)?;
1130+
}
1131+
1132+
interp_ok(())
1133+
}
1134+
1135+
/// Shuffle 32-bit integers in `values` across lanes using the corresponding
1136+
/// index in `indices`, and store the results in dst.
1137+
///
1138+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_permutevar8x32_epi32>
1139+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_permutevar8x32_ps>
1140+
/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_permutexvar_epi32>
1141+
fn permute<'tcx>(
1142+
ecx: &mut crate::MiriInterpCx<'tcx>,
1143+
values: &OpTy<'tcx>,
1144+
indices: &OpTy<'tcx>,
1145+
dest: &MPlaceTy<'tcx>,
1146+
) -> InterpResult<'tcx, ()> {
1147+
let (values, values_len) = ecx.project_to_simd(values)?;
1148+
let (indices, indices_len) = ecx.project_to_simd(indices)?;
1149+
let (dest, dest_len) = ecx.project_to_simd(dest)?;
1150+
1151+
// fn permd(a: u32x8, b: u32x8) -> u32x8;
1152+
// fn permps(a: __m256, b: i32x8) -> __m256;
1153+
// fn vpermd(a: i32x16, idx: i32x16) -> i32x16;
1154+
assert_eq!(dest_len, values_len);
1155+
assert_eq!(dest_len, indices_len);
1156+
1157+
// Only use the lower 3 bits to index into a vector with 8 lanes,
1158+
// or the lower 4 bits when indexing into a 16-lane vector.
1159+
assert!(dest_len.is_power_of_two());
1160+
let mask = u32::try_from(dest_len).unwrap().strict_sub(1);
1161+
1162+
for i in 0..dest_len {
1163+
let dest = ecx.project_index(&dest, i)?;
1164+
let index = ecx.read_scalar(&ecx.project_index(&indices, i)?)?.to_u32()?;
1165+
let element = ecx.project_index(&values, (index & mask).into())?;
1166+
1167+
ecx.copy_op(&element, &dest)?;
1168+
}
1169+
1170+
interp_ok(())
1171+
}
1172+
10891173
/// Multiplies packed 16-bit signed integer values, truncates the 32-bit
10901174
/// product to the 18 most significant bits by right-shifting, and then
10911175
/// divides the 18-bit value by 2 (rounding to nearest) by first adding

src/shims/x86/ssse3.rs

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use rustc_middle::ty::Ty;
44
use rustc_span::Symbol;
55
use rustc_target::callconv::FnAbi;
66

7-
use super::{horizontal_bin_op, pmulhrsw, psign};
7+
use super::{horizontal_bin_op, pmaddbw, pmulhrsw, psign};
88
use crate::*;
99

1010
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
@@ -67,40 +67,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
6767
horizontal_bin_op(this, which, /*saturating*/ true, left, right, dest)?;
6868
}
6969
// Used to implement the _mm_maddubs_epi16 function.
70-
// Multiplies packed 8-bit unsigned integers from `left` and packed
71-
// signed 8-bit integers from `right` into 16-bit signed integers. Then,
72-
// the saturating sum of the products with indices `2*i` and `2*i+1`
73-
// produces the output at index `i`.
74-
// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_maddubs_epi16
7570
"pmadd.ub.sw.128" => {
7671
let [left, right] =
7772
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
7873

79-
let (left, left_len) = this.project_to_simd(left)?;
80-
let (right, right_len) = this.project_to_simd(right)?;
81-
let (dest, dest_len) = this.project_to_simd(dest)?;
82-
83-
assert_eq!(left_len, right_len);
84-
assert_eq!(dest_len.strict_mul(2), left_len);
85-
86-
for i in 0..dest_len {
87-
let j1 = i.strict_mul(2);
88-
let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_u8()?;
89-
let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i8()?;
90-
91-
let j2 = j1.strict_add(1);
92-
let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_u8()?;
93-
let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i8()?;
94-
95-
let dest = this.project_index(&dest, i)?;
96-
97-
// Multiplication of a u8 and an i8 into an i16 cannot overflow.
98-
let mul1 = i16::from(left1).strict_mul(right1.into());
99-
let mul2 = i16::from(left2).strict_mul(right2.into());
100-
let res = mul1.saturating_add(mul2);
101-
102-
this.write_scalar(Scalar::from_i16(res), &dest)?;
103-
}
74+
pmaddbw(this, left, right, dest)?;
10475
}
10576
// Used to implement the _mm_mulhrs_epi16 function.
10677
// Multiplies packed 16-bit signed integer values, truncates the 32-bit

tests/pass/shims/x86/intrinsics-x86-avx512.rs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,94 @@ unsafe fn test_avx512() {
5555
assert_eq_m512i(r, e);
5656
}
5757
test_mm512_sad_epu8();
58+
59+
#[target_feature(enable = "avx512bw")]
60+
unsafe fn test_mm512_maddubs_epi16() {
61+
// `a` is interpreted as `u8x16`, but `_mm512_set_epi8` expects `i8`, so we have to cast.
62+
#[rustfmt::skip]
63+
let a = _mm512_set_epi8(
64+
255u8 as i8, 255u8 as i8, 60, 50, 100, 100, 255u8 as i8, 200u8 as i8,
65+
255u8 as i8, 200u8 as i8, 200u8 as i8, 100, 60, 50, 20, 10,
66+
67+
255u8 as i8, 255u8 as i8, 60, 50, 100, 100, 255u8 as i8, 200u8 as i8,
68+
255u8 as i8, 200u8 as i8, 200u8 as i8, 100, 60, 50, 20, 10,
69+
70+
255u8 as i8, 255u8 as i8, 60, 50, 100, 100, 255u8 as i8, 200u8 as i8,
71+
255u8 as i8, 200u8 as i8, 200u8 as i8, 100, 60, 50, 20, 10,
72+
73+
255u8 as i8, 255u8 as i8, 60, 50, 100, 100, 255u8 as i8, 200u8 as i8,
74+
255u8 as i8, 200u8 as i8, 200u8 as i8, 100, 60, 50, 20, 10,
75+
);
76+
77+
let b = _mm512_set_epi8(
78+
64, 64, -2, 1, 100, 100, -128, -128, //
79+
127, 127, -1, 1, 2, 2, 1, 1, //
80+
64, 64, -2, 1, 100, 100, -128, -128, //
81+
127, 127, -1, 1, 2, 2, 1, 1, //
82+
64, 64, -2, 1, 100, 100, -128, -128, //
83+
127, 127, -1, 1, 2, 2, 1, 1, //
84+
64, 64, -2, 1, 100, 100, -128, -128, //
85+
127, 127, -1, 1, 2, 2, 1, 1, //
86+
);
87+
88+
let r = _mm512_maddubs_epi16(a, b);
89+
90+
let e = _mm512_set_epi16(
91+
32640, -70, 20000, -32768, 32767, -100, 220, 30, //
92+
32640, -70, 20000, -32768, 32767, -100, 220, 30, //
93+
32640, -70, 20000, -32768, 32767, -100, 220, 30, //
94+
32640, -70, 20000, -32768, 32767, -100, 220, 30, //
95+
);
96+
97+
assert_eq_m512i(r, e);
98+
}
99+
test_mm512_maddubs_epi16();
100+
101+
#[target_feature(enable = "avx512f")]
102+
unsafe fn test_mm512_permutexvar_epi32() {
103+
let a = _mm512_set_epi32(
104+
15, 14, 13, 12, //
105+
11, 10, 9, 8, //
106+
7, 6, 5, 4, //
107+
3, 2, 1, 0, //
108+
);
109+
110+
let idx_identity = _mm512_set_epi32(
111+
15, 14, 13, 12, //
112+
11, 10, 9, 8, //
113+
7, 6, 5, 4, //
114+
3, 2, 1, 0, //
115+
);
116+
let r_id = _mm512_permutexvar_epi32(idx_identity, a);
117+
assert_eq_m512i(r_id, a);
118+
119+
// Test some out-of-bounds indices.
120+
let edge_cases = _mm512_set_epi32(
121+
0,
122+
-1,
123+
-128,
124+
i32::MIN,
125+
15,
126+
16,
127+
128,
128+
i32::MAX,
129+
0,
130+
-1,
131+
-128,
132+
i32::MIN,
133+
15,
134+
16,
135+
128,
136+
i32::MAX,
137+
);
138+
139+
let r = _mm512_permutexvar_epi32(edge_cases, a);
140+
141+
let e = _mm512_set_epi32(0, 15, 0, 0, 15, 0, 0, 15, 0, 15, 0, 0, 15, 0, 0, 15);
142+
143+
assert_eq_m512i(r, e);
144+
}
145+
test_mm512_permutexvar_epi32();
58146
}
59147

60148
// Some of the constants in the tests below are just bit patterns. They should not

0 commit comments

Comments
 (0)