Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21490,6 +21490,15 @@ void RISCVTargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
Known = Known.sext(BitWidth);
break;
}
case RISCVISD::SRLW: {
KnownBits Known2;
Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Known2 = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
Known = KnownBits::lshr(Known.trunc(32), Known2.trunc(5).zext(32));
// Restore the original width by sign extending.
Known = Known.sext(BitWidth);
break;
}
case RISCVISD::SRAW: {
KnownBits Known2;
Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
Expand Down
1 change: 1 addition & 0 deletions llvm/unittests/Target/RISCV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ set(LLVM_LINK_COMPONENTS
add_llvm_target_unittest(RISCVTests
MCInstrAnalysisTest.cpp
RISCVInstrInfoTest.cpp
RISCVSelectionDAGTest.cpp
)
134 changes: 134 additions & 0 deletions llvm/unittests/Target/RISCV/RISCVSelectionDAGTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
//===----------------------------------------------------------------------===//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "RISCVISelLowering.h"
#include "RISCVSelectionDAGInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "gtest/gtest.h"

namespace llvm {

class RISCVSelectionDAGTest : public testing::Test {

protected:
static void SetUpTestCase() {
LLVMInitializeRISCVTargetInfo();
LLVMInitializeRISCVTarget();
LLVMInitializeRISCVTargetMC();
}

void SetUp() override {
StringRef Assembly = "define void @f() { ret void }";

Triple TargetTriple("riscv64", "unknown", "linux");

std::string Error;
const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error);

TargetOptions Options;
TM = std::unique_ptr<TargetMachine>(T->createTargetMachine(
TargetTriple, "generic", "", Options, std::nullopt, std::nullopt,
CodeGenOptLevel::Default));

SMDiagnostic SMError;
M = parseAssemblyString(Assembly, SMError, Context);
if (!M)
report_fatal_error(SMError.getMessage());
M->setDataLayout(TM->createDataLayout());

F = M->getFunction("f");
if (!F)
report_fatal_error("Function 'f' not found");

MachineModuleInfo MMI(TM.get());

MF = std::make_unique<MachineFunction>(*F, *TM, *TM->getSubtargetImpl(*F),
MMI.getContext(), /*FunctionNum*/ 0);

DAG = std::make_unique<SelectionDAG>(*TM, CodeGenOptLevel::None);
if (!DAG)
report_fatal_error("SelectionDAG allocation failed");

OptimizationRemarkEmitter ORE(F);
DAG->init(*MF, ORE, /*LibInfo*/ nullptr, /*AA*/ nullptr,
/*AC*/ nullptr, /*MDT*/ nullptr, /*MSDT*/ nullptr, MMI, nullptr);
}

LLVMContext Context;
std::unique_ptr<TargetMachine> TM;
std::unique_ptr<Module> M;
Function *F = nullptr;
std::unique_ptr<MachineFunction> MF;
std::unique_ptr<SelectionDAG> DAG;
};

/// SRLW: Logical Shift Right
TEST_F(RISCVSelectionDAGTest, computeKnownBits_SRLW) {
// Given the following IR snippet:
// define i64 @f(i32 %x, i32 %y) {
// %a = and i32 %x, 2147483647 ; zeros the MSB for %x
// %b = lshr i32 %a, %y
// %c = zext i32 %b to i64 ; makes the most significant 32 bits 0
// ret i64 %c
// }
// The Optimized SelectionDAG as show by llc -mtriple="riscv64"
// -debug-only=isel-dump is:
// t0: ch,glue = EntryToken
// t2: i64,ch = CopyFromReg t0, Register:i64 %0
// t18: i64 = and t2, Constant:i64<2147483647>
// t4: i64,ch = CopyFromReg t0, Register:i64 %1
// t20: i64 = RISCVISD::SRLW t18, t4
// t22: i64 = and t20, Constant:i64<4294967295>
// t13: ch,glue = CopyToReg t0, Register:i64 $x10, t22
// t14: ch = RISCVISD::RET_GLUE t13, Register:i64 $x10, t13:1
//
// The DAG created below is derived from this
SDLoc Loc;
auto Int64VT = EVT::getIntegerVT(Context, 64);
auto Px = DAG->getRegister(0, Int64VT);
auto Py = DAG->getConstant(2147483647, Loc, Int64VT);
auto N1 = DAG->getNode(ISD::AND, Loc, Int64VT, Px, Py);
auto Qx = DAG->getRegister(0, Int64VT);
auto N2 = DAG->getNode(RISCVISD::SRLW, Loc, Int64VT, N1, Qx);
auto Py2 = DAG->getConstant(4294967295, Loc, Int64VT);
auto N3 = DAG->getNode(ISD::AND, Loc, Int64VT, N2, Py2);
// N1 = Px & 0x7FFFFFFF
// The first AND ensures that the input to the shift has bit 31 cleared.
// This means bits [63:31] of N1 are known to be zero.
//
// N2 = SRLW N1, Qx
// SRLW performs a 32-bit logical right shift and then sign-extends the
// 32-bit result to 64 bits. Because we know N1's bit 31 is 0, the
// 32-bit result of the shift will also have its sign bit (bit 31) as 0.
// Therefore, the sign-extension is guaranteed to be a zero-extension.
//
// N3 = N2 & 0xFFFFFFFF
// This second AND is part of the canonical pattern to clear the upper
// 32 bits, explicitly performing the zero-extension. From a KnownBits
// perspective, it's redundant, as N2's upper bits are already known zero.
//
// As a result, for N3, we know the upper 32 bits are zero (from the effective
// zero-extension) and we also know bit 31 is zero (from the initial AND).
// This gives us 33 known most-significant zero bits.
KnownBits Known = DAG->computeKnownBits(N3);
EXPECT_EQ(Known.Zero, APInt(64, -2147483648));
EXPECT_EQ(Known.One, APInt(64, 0));
}

} // end namespace llvm