diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 5d2dfe76b1b98..90b16290d0a19 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -822,7 +822,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
-                       ISD::BUILD_VECTOR});
+                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5209,6 +5209,26 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   return DAG.getNode(ISD::BITCAST, DL, VT, PRMT);
 }
 
+/// Combine for ISD::ADDRSPACECAST nodes.
+///
+/// Folds a round-trip pair of casts, asc[B -> A](asc[A -> B](x)), into x.
+/// Returns an empty SDValue when the operand of \p N is not itself an
+/// address-space cast, or when the pair does not return to space A.
+static SDValue combineADDRSPACECAST(SDNode *N,
+                                    TargetLowering::DAGCombinerInfo &DCI) {
+  auto *ASCN1 = cast<AddrSpaceCastSDNode>(N);
+
+  if (auto *ASCN2 = dyn_cast<AddrSpaceCastSDNode>(ASCN1->getOperand(0))) {
+    assert(ASCN2->getDestAddressSpace() == ASCN1->getSrcAddressSpace());
+
+    // Fold asc[B -> A](asc[A -> B](x)) -> x
+    if (ASCN1->getDestAddressSpace() == ASCN2->getSrcAddressSpace())
+      return ASCN2->getOperand(0);
+  }
+
+  return SDValue();
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5243,6 +5263,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
       return PerformVSELECTCombine(N, DCI);
     case ISD::BUILD_VECTOR:
       return PerformBUILD_VECTORCombine(N, DCI);
+    case ISD::ADDRSPACECAST:
+      return combineADDRSPACECAST(N, DCI);
   }
   return SDValue();
 }
diff --git a/llvm/test/CodeGen/NVPTX/addrspacecast-folding.ll b/llvm/test/CodeGen/NVPTX/addrspacecast-folding.ll
new file mode 100644
index 0000000000000..87698c1c9644b
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/addrspacecast-folding.ll
@@ -0,0 +1,35 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_20 -O0 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mcpu=sm_20 -O0 | %ptxas-verify %}
+
+target triple = "nvptx64-unknown-unknown"
+
+define ptr @test1(ptr %p) {
+; CHECK-LABEL: test1(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [test1_param_0];
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd1;
+; CHECK-NEXT:    ret;
+  %a = addrspacecast ptr %p to ptr addrspace(5)
+  %b = addrspacecast ptr addrspace(5) %a to ptr
+  ret ptr %b
+}
+
+define ptr addrspace(1) @test2(ptr addrspace(5) %p) {
+; CHECK-LABEL: test2(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b64 %rd<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.u64 %rd1, [test2_param_0];
+; CHECK-NEXT:    cvta.local.u64 %rd2, %rd1;
+; CHECK-NEXT:    cvta.to.global.u64 %rd3, %rd2;
+; CHECK-NEXT:    st.param.b64 [func_retval0], %rd3;
+; CHECK-NEXT:    ret;
+  %a = addrspacecast ptr addrspace(5) %p to ptr
+  %b = addrspacecast ptr %a to ptr addrspace(1)
+  ret ptr addrspace(1) %b
+}