Skip to content

Commit 13b8cf5

Browse files
authored
Merge pull request #4678 from folkertdev/ternary-logic
add shim for avx512 ternarylogic functions
2 parents 15299b7 + 050412a commit 13b8cf5

File tree

3 files changed

+163
-0
lines changed

3 files changed

+163
-0
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
use rustc_abi::CanonAbi;
2+
use rustc_middle::ty::Ty;
3+
use rustc_span::Symbol;
4+
use rustc_target::callconv::FnAbi;
5+
6+
use crate::*;
7+
8+
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
9+
pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
10+
fn emulate_x86_avx512_intrinsic(
11+
&mut self,
12+
link_name: Symbol,
13+
abi: &FnAbi<'tcx, Ty<'tcx>>,
14+
args: &[OpTy<'tcx>],
15+
dest: &MPlaceTy<'tcx>,
16+
) -> InterpResult<'tcx, EmulateItemResult> {
17+
let this = self.eval_context_mut();
18+
// Prefix should have already been checked.
19+
let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx512.").unwrap();
20+
21+
match unprefixed_name {
22+
// Used by the ternarylogic functions.
23+
"pternlog.d.128" | "pternlog.d.256" | "pternlog.d.512" => {
24+
this.expect_target_feature_for_intrinsic(link_name, "avx512f")?;
25+
if matches!(unprefixed_name, "pternlog.d.128" | "pternlog.d.256") {
26+
this.expect_target_feature_for_intrinsic(link_name, "avx512vl")?;
27+
}
28+
29+
let [a, b, c, imm8] =
30+
this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
31+
32+
assert_eq!(dest.layout, a.layout);
33+
assert_eq!(dest.layout, b.layout);
34+
assert_eq!(dest.layout, c.layout);
35+
36+
// The signatures of these operations are:
37+
//
38+
// ```
39+
// fn vpternlogd(a: i32x16, b: i32x16, c: i32x16, imm8: i32) -> i32x16;
40+
// fn vpternlogd256(a: i32x8, b: i32x8, c: i32x8, imm8: i32) -> i32x8;
41+
// fn vpternlogd128(a: i32x4, b: i32x4, c: i32x4, imm8: i32) -> i32x4;
42+
// ```
43+
//
44+
// The element type is always a 32-bit integer, the width varies.
45+
46+
let (a, _a_len) = this.project_to_simd(a)?;
47+
let (b, _b_len) = this.project_to_simd(b)?;
48+
let (c, _c_len) = this.project_to_simd(c)?;
49+
let (dest, dest_len) = this.project_to_simd(dest)?;
50+
51+
// Compute one lane with ternary table.
52+
let tern = |xa: u32, xb: u32, xc: u32, imm: u32| -> u32 {
53+
let mut out = 0u32;
54+
// At each bit position, select bit from imm8 at index = (a << 2) | (b << 1) | c
55+
for bit in 0..32 {
56+
let ia = (xa >> bit) & 1;
57+
let ib = (xb >> bit) & 1;
58+
let ic = (xc >> bit) & 1;
59+
let idx = (ia << 2) | (ib << 1) | ic;
60+
let v = (imm >> idx) & 1;
61+
out |= v << bit;
62+
}
63+
out
64+
};
65+
66+
let imm8 = this.read_scalar(imm8)?.to_u32()? & 0xFF;
67+
for i in 0..dest_len {
68+
let a_lane = this.project_index(&a, i)?;
69+
let b_lane = this.project_index(&b, i)?;
70+
let c_lane = this.project_index(&c, i)?;
71+
let d_lane = this.project_index(&dest, i)?;
72+
73+
let va = this.read_scalar(&a_lane)?.to_u32()?;
74+
let vb = this.read_scalar(&b_lane)?.to_u32()?;
75+
let vc = this.read_scalar(&c_lane)?.to_u32()?;
76+
77+
let r = tern(va, vb, vc, imm8);
78+
this.write_scalar(Scalar::from_u32(r), &d_lane)?;
79+
}
80+
}
81+
_ => return interp_ok(EmulateItemResult::NotSupported),
82+
}
83+
interp_ok(EmulateItemResult::NeedsReturn)
84+
}
85+
}

src/tools/miri/src/shims/x86/mod.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use crate::*;
1313
mod aesni;
1414
mod avx;
1515
mod avx2;
16+
mod avx512;
1617
mod bmi;
1718
mod gfni;
1819
mod sha;
@@ -152,6 +153,11 @@ pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
152153
this, link_name, abi, args, dest,
153154
);
154155
}
156+
name if name.starts_with("avx512.") => {
157+
return avx512::EvalContextExt::emulate_x86_avx512_intrinsic(
158+
this, link_name, abi, args, dest,
159+
);
160+
}
155161

156162
_ => return interp_ok(EmulateItemResult::NotSupported),
157163
}

src/tools/miri/tests/pass/shims/x86/intrinsics-x86-avx512.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ fn main() {
1717
unsafe {
1818
test_avx512bitalg();
1919
test_avx512vpopcntdq();
20+
test_avx512ternarylogic();
2021
}
2122
}
2223

@@ -191,6 +192,77 @@ unsafe fn test_avx512vpopcntdq() {
191192
test_mm_popcnt_epi64();
192193
}
193194

195+
#[target_feature(enable = "avx512f,avx512vl")]
196+
unsafe fn test_avx512ternarylogic() {
197+
#[target_feature(enable = "avx512f")]
198+
unsafe fn test_mm512_ternarylogic_epi32() {
199+
let a = _mm512_set4_epi32(0b100, 0b110, 0b001, 0b101);
200+
let b = _mm512_set4_epi32(0b010, 0b011, 0b001, 0b110);
201+
let c = _mm512_set4_epi32(0b001, 0b000, 0b001, 0b111);
202+
203+
// Identity of A.
204+
let r = _mm512_ternarylogic_epi32::<0b1111_0000>(a, b, c);
205+
assert_eq_m512i(r, a);
206+
207+
// Bitwise xor.
208+
let r = _mm512_ternarylogic_epi32::<0b10010110>(a, b, c);
209+
let e = _mm512_set4_epi32(0b111, 0b101, 0b001, 0b100);
210+
assert_eq_m512i(r, e);
211+
212+
// Majority (2 or more bits set).
213+
let r = _mm512_ternarylogic_epi32::<0b1110_1000>(a, b, c);
214+
let e = _mm512_set4_epi32(0b000, 0b010, 0b001, 0b111);
215+
assert_eq_m512i(r, e);
216+
}
217+
test_mm512_ternarylogic_epi32();
218+
219+
#[target_feature(enable = "avx512f,avx512vl")]
220+
unsafe fn test_mm256_ternarylogic_epi32() {
221+
let _mm256_set4_epi32 = |a, b, c, d| _mm256_setr_epi32(a, b, c, d, a, b, c, d);
222+
223+
let a = _mm256_set4_epi32(0b100, 0b110, 0b001, 0b101);
224+
let b = _mm256_set4_epi32(0b010, 0b011, 0b001, 0b110);
225+
let c = _mm256_set4_epi32(0b001, 0b000, 0b001, 0b111);
226+
227+
// Identity of A.
228+
let r = _mm256_ternarylogic_epi32::<0b1111_0000>(a, b, c);
229+
assert_eq_m256i(r, a);
230+
231+
// Bitwise xor.
232+
let r = _mm256_ternarylogic_epi32::<0b10010110>(a, b, c);
233+
let e = _mm256_set4_epi32(0b111, 0b101, 0b001, 0b100);
234+
assert_eq_m256i(r, e);
235+
236+
// Majority (2 or more bits set).
237+
let r = _mm256_ternarylogic_epi32::<0b1110_1000>(a, b, c);
238+
let e = _mm256_set4_epi32(0b000, 0b010, 0b001, 0b111);
239+
assert_eq_m256i(r, e);
240+
}
241+
test_mm256_ternarylogic_epi32();
242+
243+
#[target_feature(enable = "avx512f,avx512vl")]
244+
unsafe fn test_mm_ternarylogic_epi32() {
245+
let a = _mm_setr_epi32(0b100, 0b110, 0b001, 0b101);
246+
let b = _mm_setr_epi32(0b010, 0b011, 0b001, 0b110);
247+
let c = _mm_setr_epi32(0b001, 0b000, 0b001, 0b111);
248+
249+
// Identity of A.
250+
let r = _mm_ternarylogic_epi32::<0b1111_0000>(a, b, c);
251+
assert_eq_m128i(r, a);
252+
253+
// Bitwise xor.
254+
let r = _mm_ternarylogic_epi32::<0b10010110>(a, b, c);
255+
let e = _mm_setr_epi32(0b111, 0b101, 0b001, 0b100);
256+
assert_eq_m128i(r, e);
257+
258+
// Majority (2 or more bits set).
259+
let r = _mm_ternarylogic_epi32::<0b1110_1000>(a, b, c);
260+
let e = _mm_setr_epi32(0b000, 0b010, 0b001, 0b111);
261+
assert_eq_m128i(r, e);
262+
}
263+
test_mm_ternarylogic_epi32();
264+
}
265+
194266
#[track_caller]
195267
unsafe fn assert_eq_m512i(a: __m512i, b: __m512i) {
196268
assert_eq!(transmute::<_, [i32; 16]>(a), transmute::<_, [i32; 16]>(b))

0 commit comments

Comments
 (0)