diff --git a/llvm/include/llvm/SandboxIR/Context.h b/llvm/include/llvm/SandboxIR/Context.h index 77924fbcd5ace..1285598a1c028 100644 --- a/llvm/include/llvm/SandboxIR/Context.h +++ b/llvm/include/llvm/SandboxIR/Context.h @@ -68,6 +68,7 @@ class Context { } /// Get or create a sandboxir::Constant from an existing LLVM IR \p LLVMC. Constant *getOrCreateConstant(llvm::Constant *LLVMC); + friend class Utils; // For getMemoryBase // Friends for getOrCreateConstant(). #define DEF_CONST(ID, CLASS) friend class CLASS; diff --git a/llvm/include/llvm/SandboxIR/Utils.h b/llvm/include/llvm/SandboxIR/Utils.h index e4156c6af9a22..b0037723a9e88 100644 --- a/llvm/include/llvm/SandboxIR/Utils.h +++ b/llvm/include/llvm/SandboxIR/Utils.h @@ -49,6 +49,16 @@ class Utils { return const_cast(I); } + /// \Returns the base Value for load or store instruction \p LSI. + template + static Value *getMemInstructionBase(const LoadOrStoreT *LSI) { + static_assert(std::is_same_v || + std::is_same_v, + "Expected sandboxir::Load or sandboxir::Store!"); + return LSI->Ctx.getOrCreateValue( + getUnderlyingObject(LSI->getPointerOperand()->Val)); + } + /// \Returns the number of bits required to represent the operands or return /// value of \p V in \p DL. static unsigned getNumBits(Value *V, const DataLayout &DL) { diff --git a/llvm/unittests/SandboxIR/UtilsTest.cpp b/llvm/unittests/SandboxIR/UtilsTest.cpp index 90396eaa53ab3..a30fc253a1a74 100644 --- a/llvm/unittests/SandboxIR/UtilsTest.cpp +++ b/llvm/unittests/SandboxIR/UtilsTest.cpp @@ -215,3 +215,35 @@ define void @foo(float %arg0, double %arg1, i8 %arg2, i64 %arg3, ptr %arg4) { EXPECT_EQ(sandboxir::Utils::getNumBits(L2), 8u); EXPECT_EQ(sandboxir::Utils::getNumBits(L3), 64u); } + +TEST_F(UtilsTest, GetMemBase) { + parseIR(C, R"IR( +define void @foo(ptr %ptrA, float %val, ptr %ptrB) { +bb: + %gepA0 = getelementptr float, ptr %ptrA, i32 0 + %gepA1 = getelementptr float, ptr %ptrA, i32 1 + %gepB0 = getelementptr float, ptr %ptrB, i32 0 + %gepB1 = getelementptr float, ptr %ptrB, i32 1 + store float %val, ptr %gepA0 + store float %val, ptr %gepA1 + store float %val, ptr %gepB0 + store float %val, ptr %gepB1 + ret void +} +)IR"); + llvm::Function &Foo = *M->getFunction("foo"); + sandboxir::Context Ctx(C); + sandboxir::Function *F = Ctx.createFunction(&Foo); + + auto It = std::next(F->begin()->begin(), 4); + auto *St0 = cast(&*It++); + auto *St1 = cast(&*It++); + auto *St2 = cast(&*It++); + auto *St3 = cast(&*It++); + EXPECT_EQ(sandboxir::Utils::getMemInstructionBase(St0), + sandboxir::Utils::getMemInstructionBase(St1)); + EXPECT_EQ(sandboxir::Utils::getMemInstructionBase(St2), + sandboxir::Utils::getMemInstructionBase(St3)); + EXPECT_NE(sandboxir::Utils::getMemInstructionBase(St0), + sandboxir::Utils::getMemInstructionBase(St3)); +}