From 5563f46b0d4b559e0d29d508a3acbbf270a173c2 Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Fri, 18 Jul 2025 14:11:05 -0700 Subject: [PATCH 1/8] [WebAssembly] Precommit test for constant folding dot --- .../InstSimplify/ConstProp/WebAssembly/dot.ll | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll new file mode 100644 index 0000000000000..75a500c6278ad --- /dev/null +++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll @@ -0,0 +1,37 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 + +; RUN: opt -passes=instsimplify -S < %s | FileCheck %s + +; Test that intrinsics wasm dot call are constant folded + +target triple = "wasm32-unknown-unknown" + + +define <4 x i32> @dot_zero() { +; CHECK-LABEL: define <4 x i32> @dot_zero() { +; CHECK-NEXT: [[RES:%.*]] = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer) +; CHECK-NEXT: ret <4 x i32> [[RES]] +; + %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer) + ret <4 x i32> %res +} + +define <4 x i32> @dot_nonzero() { +; CHECK-LABEL: define <4 x i32> @dot_nonzero() { +; CHECK-NEXT: [[RES:%.*]] = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> , <8 x i16> ) +; CHECK-NEXT: ret <4 x i32> [[RES]] +; + %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> , <8 x i16> ) + ret <4 x i32> %res +} + +define <4 x i32> @dot_doubly_negative() { +; CHECK-LABEL: define <4 x i32> @dot_doubly_negative() { +; CHECK-NEXT: [[RES:%.*]] = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> splat (i16 -1), <8 x i16> splat (i16 -1)) +; CHECK-NEXT: ret <4 x i32> [[RES]] +; + %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> , <8 x i16> ) + ret <4 x i32> %res +} + + From 50ca839dbee8ea3d505dee87c3e21c458ee28b7f Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Fri, 18 Jul 2025 16:34:41 -0700 Subject: [PATCH 2/8] [WebAssembly] Constant fold dot operation --- llvm/lib/Analysis/ConstantFolding.cpp | 31 +++++++++++++++++++ .../InstSimplify/ConstProp/WebAssembly/dot.ll | 14 +++++---- 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 9c1c2c6e60f02..2304c58b3f95f 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -1657,6 +1657,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) { case Intrinsic::aarch64_sve_convert_from_svbool: case Intrinsic::wasm_alltrue: case Intrinsic::wasm_anytrue: + case Intrinsic::wasm_dot: // WebAssembly float semantics are always known case Intrinsic::wasm_trunc_signed: case Intrinsic::wasm_trunc_unsigned: @@ -3826,6 +3827,36 @@ static Constant *ConstantFoldFixedVectorCall( } return ConstantVector::get(Result); } + case Intrinsic::wasm_dot: { + unsigned NumElements = + cast(Operands[0]->getType())->getNumElements(); + + assert(NumElements == 8 && NumElements / 2 == Result.size() && + "wasm dot takes i16x8 and produce i32x4"); + assert(Ty->isIntegerTy()); + SmallVector MulVector; + + for (unsigned I = 0; I < NumElements; ++I) { + ConstantInt *Elt0 = + cast(Operands[0]->getAggregateElement(I)); + ConstantInt *Elt1 = + cast(Operands[1]->getAggregateElement(I)); + + // sext 32 first, according to specs + APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32); + + // TODO: imul in specs includes a modulo operation + // Is this performed automatically via trunc = true in APInt creation of * + MulVector.push_back(IMul); + } + for (unsigned I = 0; I < Result.size(); ++I) { + // Same case as with imul + APInt IAdd = MulVector[I] + MulVector[I + Result.size()]; + Result[I] = ConstantInt::get(Ty, IAdd); + } + + return ConstantVector::get(Result); + } default: break; } diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll index 75a500c6278ad..02c6649becbce 100644 --- a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll +++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll @@ -9,17 +9,20 @@ target triple = "wasm32-unknown-unknown" define <4 x i32> @dot_zero() { ; CHECK-LABEL: define <4 x i32> @dot_zero() { -; CHECK-NEXT: [[RES:%.*]] = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer) -; CHECK-NEXT: ret <4 x i32> [[RES]] +; CHECK-NEXT: ret <4 x i32> zeroinitializer ; %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer) ret <4 x i32> %res } +; a = 1 2 3 4 5 6 7 8 +; b = 1 2 3 4 5 6 7 8 +; k1|k2 = a * b = 1 4 9 16 25 36 49 64 +; k1 + k2 = (1+25) | (4+36) | (9+49) | (16+64) +; result = 26 | 40 | 58 | 80 define <4 x i32> @dot_nonzero() { ; CHECK-LABEL: define <4 x i32> @dot_nonzero() { -; CHECK-NEXT: [[RES:%.*]] = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> , <8 x i16> ) -; CHECK-NEXT: ret <4 x i32> [[RES]] +; CHECK-NEXT: ret <4 x i32> ; %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> , <8 x i16> ) ret <4 x i32> %res @@ -27,8 +30,7 @@ define <4 x i32> @dot_nonzero() { define <4 x i32> @dot_doubly_negative() { ; CHECK-LABEL: define <4 x i32> @dot_doubly_negative() { -; CHECK-NEXT: [[RES:%.*]] = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> splat (i16 -1), <8 x i16> splat (i16 -1)) -; CHECK-NEXT: ret <4 x i32> [[RES]] +; CHECK-NEXT: ret <4 x i32> splat (i32 2) ; %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> , <8 x i16> ) ret <4 x i32> %res From fa8c096313e2eab8661a5d63b04477f01c86ed78 Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Tue, 22 Jul 2025 13:40:55 -0700 Subject: [PATCH 3/8] Addresses specs questions and added test to reflect --- llvm/lib/Analysis/ConstantFolding.cpp | 9 ++++----- .../InstSimplify/ConstProp/WebAssembly/dot.ll | 18 +++++++++++++++--- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 2304c58b3f95f..a63be47e21eaa 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -3845,13 +3845,12 @@ static Constant *ConstantFoldFixedVectorCall( // sext 32 first, according to specs APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32); - // TODO: imul in specs includes a modulo operation - // Is this performed automatically via trunc = true in APInt creation of * + // i16 -> i32 bypasses specs modulo on imul MulVector.push_back(IMul); } - for (unsigned I = 0; I < Result.size(); ++I) { - // Same case as with imul - APInt IAdd = MulVector[I] + MulVector[I + Result.size()]; + for (unsigned I = 0; I < Result.size(); I++) { + // i16 -> i32 bypasses specs modulo on iadd + APInt IAdd = MulVector[I * 2] + MulVector[I * 2 + 1]; Result[I] = ConstantInt::get(Ty, IAdd); } diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll index 02c6649becbce..b2f23d0f153ef 100644 --- a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll +++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll @@ -18,11 +18,11 @@ define <4 x i32> @dot_zero() { ; a = 1 2 3 4 5 6 7 8 ; b = 1 2 3 4 5 6 7 8 ; k1|k2 = a * b = 1 4 9 16 25 36 49 64 -; k1 + k2 = (1+25) | (4+36) | (9+49) | (16+64) -; result = 26 | 40 | 58 | 80 +; k1 + k2 = (1+4) | (9 + 16) | (25 + 36) | (49 + 64) +; result = 5 | 25 | 61 | 113 define <4 x i32> @dot_nonzero() { ; CHECK-LABEL: define <4 x i32> @dot_nonzero() { -; CHECK-NEXT: ret <4 x i32> +; CHECK-NEXT: ret <4 x i32> ; %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> , <8 x i16> ) ret <4 x i32> %res @@ -36,4 +36,16 @@ define <4 x i32> @dot_doubly_negative() { ret <4 x i32> %res } +; This test checks for llvm's compliance on spec's wasm.dot's imul and iadd +; Since the original number can only be i16::max == 2^15 - 1, +; subsequent modulo of 2^32 of imul and iadd +; should return the same result +; 2*(2^15 - 1)^2 % 2^32 == 2*(2^15 - 1)^2 +define <4 x i32> @dot_follow_modulo_spec() { +; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec() { +; CHECK-NEXT: ret <4 x i32> +; + %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> , <8 x i16> ) + ret <4 x i32> %res +} From 10d7dc9cd06333783b4877dc30d2d17a7d1753b2 Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Thu, 24 Jul 2025 20:05:02 -0700 Subject: [PATCH 4/8] Addresses nit, added negative test case --- llvm/lib/Analysis/ConstantFolding.cpp | 10 ++++++---- .../InstSimplify/ConstProp/WebAssembly/dot.ll | 15 +++++++++++++-- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index a63be47e21eaa..765bd3ec4b9fd 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -3831,8 +3831,8 @@ static Constant *ConstantFoldFixedVectorCall( unsigned NumElements = cast(Operands[0]->getType())->getNumElements(); - assert(NumElements == 8 && NumElements / 2 == Result.size() && - "wasm dot takes i16x8 and produce i32x4"); + assert(NumElements == 8 && Result.size() == 4 && + "wasm dot takes i16x8 and produces i32x4"); assert(Ty->isIntegerTy()); SmallVector MulVector; @@ -3845,11 +3845,13 @@ static Constant *ConstantFoldFixedVectorCall( // sext 32 first, according to specs APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32); - // i16 -> i32 bypasses specs modulo on imul + // Multiplication can never be more than 32 bit. + // We can opt to not perform modulo of imul here. MulVector.push_back(IMul); } for (unsigned I = 0; I < Result.size(); I++) { - // i16 -> i32 bypasses specs modulo on iadd + // Addition can never be more than 32 bit. + // We can opt to not perform modulo of iadd here. APInt IAdd = MulVector[I * 2] + MulVector[I * 2 + 1]; Result[I] = ConstantInt::get(Ty, IAdd); } diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll index b2f23d0f153ef..9c5dc74033f5b 100644 --- a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll +++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll @@ -41,11 +41,22 @@ define <4 x i32> @dot_doubly_negative() { ; subsequent modulo of 2^32 of imul and iadd ; should return the same result ; 2*(2^15 - 1)^2 % 2^32 == 2*(2^15 - 1)^2 -define <4 x i32> @dot_follow_modulo_spec() { -; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec() { +define <4 x i32> @dot_follow_modulo_spec_1() { +; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec_1() { ; CHECK-NEXT: ret <4 x i32> ; %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> , <8 x i16> ) ret <4 x i32> %res } +; This test checks for llvm's compliance on spec's wasm.dot's imul and iadd +; 2*(- 2^15)^2 == 2^31, doesn't exceed 2^32 so we don't have to mod +; wrapping around is -(2^31), still doesn't exceed 2^32 +define <4 x i32> @dot_follow_modulo_spec_2() { +; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec_2() { +; CHECK-NEXT: ret <4 x i32> +; + %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> , <8 x i16> ) + ret <4 x i32> %res +} + From 0271d8d3929bd8339ba6d11f43181025337b6282 Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Tue, 29 Jul 2025 08:43:23 -0700 Subject: [PATCH 5/8] Address nits and performance issues --- llvm/lib/Analysis/ConstantFolding.cpp | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 765bd3ec4b9fd..09886d75b80fa 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -3834,7 +3834,7 @@ static Constant *ConstantFoldFixedVectorCall( assert(NumElements == 8 && Result.size() == 4 && "wasm dot takes i16x8 and produces i32x4"); assert(Ty->isIntegerTy()); - SmallVector MulVector; + int32_t MulVector[8]; for (unsigned I = 0; I < NumElements; ++I) { ConstantInt *Elt0 = @@ -3842,17 +3842,12 @@ static Constant *ConstantFoldFixedVectorCall( ConstantInt *Elt1 = cast(Operands[1]->getAggregateElement(I)); - // sext 32 first, according to specs APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32); - // Multiplication can never be more than 32 bit. - // We can opt to not perform modulo of imul here. - MulVector.push_back(IMul); + MulVector[I] = IMul.getSExtValue(); } for (unsigned I = 0; I < Result.size(); I++) { - // Addition can never be more than 32 bit. - // We can opt to not perform modulo of iadd here. - APInt IAdd = MulVector[I * 2] + MulVector[I * 2 + 1]; + int32_t IAdd = MulVector[I * 2] + MulVector[I * 2 + 1]; Result[I] = ConstantInt::get(Ty, IAdd); } From 67492fa23db705619e9e784b1687d9314cea6982 Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Tue, 5 Aug 2025 11:22:32 -0700 Subject: [PATCH 6/8] Addresses PR reviews --- llvm/lib/Analysis/ConstantFolding.cpp | 4 +--- .../InstSimplify/ConstProp/WebAssembly/dot.ll | 10 ++-------- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 09886d75b80fa..6f5573fc939cb 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -3842,9 +3842,7 @@ static Constant *ConstantFoldFixedVectorCall( ConstantInt *Elt1 = cast(Operands[1]->getAggregateElement(I)); - APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32); - - MulVector[I] = IMul.getSExtValue(); + MulVector[I] = Elt0->getSExtValue() * Elt1->getSExtValue(); } for (unsigned I = 0; I < Result.size(); I++) { int32_t IAdd = MulVector[I * 2] + MulVector[I * 2 + 1]; diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll index 9c5dc74033f5b..b537b7bccf861 100644 --- a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll +++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll @@ -36,11 +36,7 @@ define <4 x i32> @dot_doubly_negative() { ret <4 x i32> %res } -; This test checks for llvm's compliance on spec's wasm.dot's imul and iadd -; Since the original number can only be i16::max == 2^15 - 1, -; subsequent modulo of 2^32 of imul and iadd -; should return the same result -; 2*(2^15 - 1)^2 % 2^32 == 2*(2^15 - 1)^2 +; Tests that i16 max signed values fit in i32 define <4 x i32> @dot_follow_modulo_spec_1() { ; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec_1() { ; CHECK-NEXT: ret <4 x i32> @@ -49,9 +45,7 @@ define <4 x i32> @dot_follow_modulo_spec_1() { ret <4 x i32> %res } -; This test checks for llvm's compliance on spec's wasm.dot's imul and iadd -; 2*(- 2^15)^2 == 2^31, doesn't exceed 2^32 so we don't have to mod -; wrapping around is -(2^31), still doesn't exceed 2^32 +; Tests that i16 min signed values fit in i32 define <4 x i32> @dot_follow_modulo_spec_2() { ; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec_2() { ; CHECK-NEXT: ret <4 x i32> From a151fdc2c03a17f95dd3228d32bf6e4cdd487ec0 Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Mon, 11 Aug 2025 11:53:47 -0700 Subject: [PATCH 7/8] UB-proof wrap around of adding constant --- llvm/lib/Analysis/ConstantFolding.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index 6f5573fc939cb..bf291809b07a2 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -3845,7 +3845,7 @@ static Constant *ConstantFoldFixedVectorCall( MulVector[I] = Elt0->getSExtValue() * Elt1->getSExtValue(); } for (unsigned I = 0; I < Result.size(); I++) { - int32_t IAdd = MulVector[I * 2] + MulVector[I * 2 + 1]; + int64_t IAdd = (int64_t)MulVector[I * 2] + MulVector[I * 2 + 1]; Result[I] = ConstantInt::get(Ty, IAdd); } From 751de4c761a31930ec2813517e7ebb6594a8da67 Mon Sep 17 00:00:00 2001 From: Jasmine Tang Date: Thu, 14 Aug 2025 18:18:01 -0700 Subject: [PATCH 8/8] Clarify casting of MulVector[...] --- llvm/lib/Analysis/ConstantFolding.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp index bf291809b07a2..77a762c0f3c19 100644 --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -3845,7 +3845,7 @@ static Constant *ConstantFoldFixedVectorCall( MulVector[I] = Elt0->getSExtValue() * Elt1->getSExtValue(); } for (unsigned I = 0; I < Result.size(); I++) { - int64_t IAdd = (int64_t)MulVector[I * 2] + MulVector[I * 2 + 1]; + int64_t IAdd = (int64_t)MulVector[I * 2] + (int64_t)MulVector[I * 2 + 1]; Result[I] = ConstantInt::get(Ty, IAdd); }