From 58cc333c6d1e85c4cd6beef8ac4d5ca94e4a733f Mon Sep 17 00:00:00 2001 From: Alex MacLean Date: Mon, 13 May 2024 21:14:43 +0000 Subject: [PATCH 1/2] [NVPTX] fixup support for over-aligned parameters --- llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp | 14 ++- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 33 +++--- llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 3 + llvm/lib/Target/NVPTX/NVPTXUtilities.cpp | 47 +++++---- llvm/lib/Target/NVPTX/NVPTXUtilities.h | 5 +- llvm/test/CodeGen/NVPTX/param-overalign.ll | 109 ++++++++++++++++++++ 6 files changed, 168 insertions(+), 43 deletions(-) create mode 100644 llvm/test/CodeGen/NVPTX/param-overalign.ll diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp index 9f31b72bbceb1..dc9377df208d2 100644 --- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp @@ -72,6 +72,7 @@ #include "llvm/MC/MCStreamer.h" #include "llvm/MC/MCSymbol.h" #include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/Alignment.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Endian.h" @@ -370,11 +371,10 @@ void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) { << " func_retval0"; } else if (ShouldPassAsArray(Ty)) { unsigned totalsz = DL.getTypeAllocSize(Ty); - unsigned retAlignment = 0; - if (!getAlign(*F, 0, retAlignment)) - retAlignment = TLI->getFunctionParamOptimizedAlign(F, Ty, DL).value(); - O << ".param .align " << retAlignment << " .b8 func_retval0[" << totalsz - << "]"; + Align RetAlignment = TLI->getFunctionArgumentAlignment( + F, Ty, AttributeList::ReturnIndex, DL); + O << ".param .align " << RetAlignment.value() << " .b8 func_retval0[" + << totalsz << "]"; } else llvm_unreachable("Unknown return type"); } else { @@ -1558,6 +1558,10 @@ void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) { auto getOptimalAlignForParam = [TLI, &DL, &PAL, F, paramIndex](Type *Ty) -> Align { + if (MaybeAlign StackAlign = + getAlign(*F, paramIndex + AttributeList::FirstArgIndex)) + return StackAlign.value(); + Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL); MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex); return std::max(TypeAlign, ParamAlign.valueOrOne()); diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index b03803f52b78e..1e7477cf9d60e 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -1434,12 +1434,11 @@ std::string NVPTXTargetLowering::getPrototype( if (!Outs[OIdx].Flags.isByVal()) { if (IsTypePassedAsArray(Ty)) { - unsigned ParamAlign = 0; const CallInst *CallI = cast(&CB); - // +1 because index 0 is reserved for return type alignment - if (!getAlign(*CallI, i + 1, ParamAlign)) - ParamAlign = getFunctionParamOptimizedAlign(F, Ty, DL).value(); - O << ".param .align " << ParamAlign << " .b8 "; + Align ParamAlign = + getAlign(*CallI, i + AttributeList::FirstArgIndex) + .value_or(getFunctionParamOptimizedAlign(F, Ty, DL)); + O << ".param .align " << ParamAlign.value() << " .b8 "; O << "_"; O << "[" << DL.getTypeAllocSize(Ty) << "]"; // update the index for Outs @@ -1489,6 +1488,11 @@ std::string NVPTXTargetLowering::getPrototype( return Prototype; } +Align NVPTXTargetLowering::getFunctionArgumentAlignment( + const Function *F, Type *Ty, unsigned Idx, const DataLayout &DL) const { + return getAlign(*F, Idx).value_or(getFunctionParamOptimizedAlign(F, Ty, DL)); +} + Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx, const DataLayout &DL) const { @@ -1497,7 +1501,6 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty, return DL.getABITypeAlign(Ty); } - unsigned Alignment = 0; const Function *DirectCallee = CB->getCalledFunction(); if (!DirectCallee) { @@ -1507,21 +1510,16 @@ Align NVPTXTargetLowering::getArgumentAlignment(const CallBase *CB, Type *Ty, // With bitcast'd call targets, the instruction will be the call if (const auto *CI = dyn_cast(CB)) { // Check if we have call alignment metadata - if (getAlign(*CI, Idx, Alignment)) - return Align(Alignment); + if (MaybeAlign StackAlign = getAlign(*CI, Idx)) + return StackAlign.value(); } DirectCallee = getMaybeBitcastedCallee(CB); } // Check for function alignment information if we found that the // ultimate target is a Function - if (DirectCallee) { - if (getAlign(*DirectCallee, Idx, Alignment)) - return Align(Alignment); - // If alignment information is not available, fall back to the - // default function param optimized type alignment - return getFunctionParamOptimizedAlign(DirectCallee, Ty, DL); - } + if (DirectCallee) + return getFunctionArgumentAlignment(DirectCallee, Ty, Idx, DL); // Call is indirect, fall back to the ABI type alignment return DL.getABITypeAlign(Ty); @@ -3195,8 +3193,9 @@ SDValue NVPTXTargetLowering::LowerFormalArguments( if (VTs.empty()) report_fatal_error("Empty parameter types are not supported"); - auto VectorInfo = - VectorizePTXValueVTs(VTs, Offsets, DL.getABITypeAlign(Ty)); + Align ArgAlign = getFunctionArgumentAlignment( + F, Ty, i + AttributeList::FirstArgIndex, DL); + auto VectorInfo = VectorizePTXValueVTs(VTs, Offsets, ArgAlign); SDValue Arg = getParamSymbol(DAG, i, PtrVT); int VecIdx = -1; // Index of the first element of the current vector. diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index c9db10e555cef..e211286fcc556 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -462,6 +462,9 @@ class NVPTXTargetLowering : public TargetLowering { MachineFunction &MF, unsigned Intrinsic) const override; + Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx, + const DataLayout &DL) const; + /// getFunctionParamOptimizedAlign - since function arguments are passed via /// .param space, we may want to increase their alignment in a way that /// ensures that we can effectively vectorize their loads & stores. We can diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp index 35302889095f8..80896a5bc4fd9 100644 --- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp @@ -19,11 +19,13 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" +#include "llvm/Support/Alignment.h" #include "llvm/Support/Mutex.h" #include #include #include #include +#include #include #include @@ -296,37 +298,44 @@ bool isKernelFunction(const Function &F) { return (x == 1); } -bool getAlign(const Function &F, unsigned index, unsigned &align) { +MaybeAlign getAlign(const Function &F, unsigned Index) { + // First check the alignstack metadata + if (MaybeAlign AlignStack = + F.getAttributes().getAttributes(Index).getStackAlignment()) + return AlignStack; + + // If that is missing, check the legacy nvvm metadata std::vector Vs; bool retval = findAllNVVMAnnotation(&F, "align", Vs); if (!retval) - return false; - for (unsigned v : Vs) { - if ((v >> 16) == index) { - align = v & 0xFFFF; - return true; - } - } - return false; + return std::nullopt; + for (unsigned V : Vs) + if ((V >> 16) == Index) + return Align(V & 0xFFFF); + + return std::nullopt; } -bool getAlign(const CallInst &I, unsigned index, unsigned &align) { +MaybeAlign getAlign(const CallInst &I, unsigned Index) { + // First check the alignstack metadata + if (MaybeAlign AlignStack = + I.getAttributes().getAttributes(Index).getStackAlignment()) + return AlignStack; + + // If that is missing, check the legacy nvvm metadata if (MDNode *alignNode = I.getMetadata("callalign")) { for (int i = 0, n = alignNode->getNumOperands(); i < n; i++) { if (const ConstantInt *CI = mdconst::dyn_extract(alignNode->getOperand(i))) { - unsigned v = CI->getZExtValue(); - if ((v >> 16) == index) { - align = v & 0xFFFF; - return true; - } - if ((v >> 16) > index) { - return false; - } + unsigned V = CI->getZExtValue(); + if ((V >> 16) == Index) + return Align(V & 0xFFFF); + if ((V >> 16) > Index) + return std::nullopt; } } } - return false; + return std::nullopt; } Function *getMaybeBitcastedCallee(const CallBase *CB) { diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.h b/llvm/lib/Target/NVPTX/NVPTXUtilities.h index 449973bb53de7..2872db9fa2131 100644 --- a/llvm/lib/Target/NVPTX/NVPTXUtilities.h +++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.h @@ -18,6 +18,7 @@ #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Value.h" +#include "llvm/Support/Alignment.h" #include #include #include @@ -60,8 +61,8 @@ bool getMinCTASm(const Function &, unsigned &); bool getMaxNReg(const Function &, unsigned &); bool isKernelFunction(const Function &); -bool getAlign(const Function &, unsigned index, unsigned &); -bool getAlign(const CallInst &, unsigned index, unsigned &); +MaybeAlign getAlign(const Function &, unsigned); +MaybeAlign getAlign(const CallInst &, unsigned); Function *getMaybeBitcastedCallee(const CallBase *CB); // PTX ABI requires all scalar argument/return values to have diff --git a/llvm/test/CodeGen/NVPTX/param-overalign.ll b/llvm/test/CodeGen/NVPTX/param-overalign.ll new file mode 100644 index 0000000000000..63e706982f394 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/param-overalign.ll @@ -0,0 +1,109 @@ +; RUN: llc < %s -march=nvptx | FileCheck %s +; RUN: %if ptxas %{ llc < %s -march=nvptx -verify-machineinstrs | %ptxas-verify %} + +target triple = "nvptx64-nvidia-cuda" + +%struct.float2 = type { float, float } + +; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee_md +; CHECK-NEXT: ( +; CHECK-NEXT: .param .align 8 .b8 callee_md_param_0[8] +; CHECK-NEXT: ) +; CHECK-NEXT: ; + +; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee +; CHECK-NEXT: ( +; CHECK-NEXT: .param .align 8 .b8 callee_param_0[8] +; CHECK-NEXT: ) +; CHECK-NEXT: ; + +define float @caller_md(float %a, float %b) { +; CHECK-LABEL: .visible .func (.param .b32 func_retval0) caller_md( +; CHECK-NEXT: .param .b32 caller_md_param_0, +; CHECK-NEXT: .param .b32 caller_md_param_1 +; CHECK-NEXT: ) +; CHECK-NEXT: { + +; CHECK: ld.param.f32 %f1, [caller_md_param_0]; +; CHECK-NEXT: ld.param.f32 %f2, [caller_md_param_1]; +; CHECK-NEXT: { +; CHECK-NEXT: .param .align 8 .b8 param0[8]; +; CHECK-NEXT: st.param.v2.f32 [param0+0], {%f1, %f2}; +; CHECK-NEXT: .param .b32 retval0; +; CHECK-NEXT: call.uni (retval0), +; CHECK-NEXT: callee_md, +; CHECK-NEXT: ( +; CHECK-NEXT: param0 +; CHECK-NEXT: ); +; CHECK-NEXT: ld.param.f32 %f3, [retval0+0]; +; CHECK-NEXT: } +; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3; +; CHECK-NEXT: ret; + %s1 = insertvalue %struct.float2 poison, float %a, 0 + %s2 = insertvalue %struct.float2 %s1, float %b, 1 + %r = call float @callee_md(%struct.float2 %s2) + ret float %r +} + +define float @callee_md(%struct.float2 %a) { +; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee_md( +; CHECK-NEXT: .param .align 8 .b8 callee_md_param_0[8] +; CHECK-NEXT: ) +; CHECK-NEXT: { + +; CHECK: ld.param.v2.f32 {%f1, %f2}, [callee_md_param_0]; +; CHECK-NEXT: add.rn.f32 %f3, %f1, %f2; +; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3; +; CHECK-NEXT: ret; + %v0 = extractvalue %struct.float2 %a, 0 + %v1 = extractvalue %struct.float2 %a, 1 + %2 = fadd float %v0, %v1 + ret float %2 +} + +define float @caller(float %a, float %b) { +; CHECK-LABEL: .visible .func (.param .b32 func_retval0) caller( +; CHECK-NEXT: .param .b32 caller_param_0, +; CHECK-NEXT: .param .b32 caller_param_1 +; CHECK-NEXT: ) +; CHECK-NEXT: { + +; CHECK: ld.param.f32 %f1, [caller_param_0]; +; CHECK-NEXT: ld.param.f32 %f2, [caller_param_1]; +; CHECK-NEXT: { +; CHECK-NEXT: .param .align 8 .b8 param0[8]; +; CHECK-NEXT: st.param.v2.f32 [param0+0], {%f1, %f2}; +; CHECK-NEXT: .param .b32 retval0; +; CHECK-NEXT: call.uni (retval0), +; CHECK-NEXT: callee, +; CHECK-NEXT: ( +; CHECK-NEXT: param0 +; CHECK-NEXT: ); +; CHECK-NEXT: ld.param.f32 %f3, [retval0+0]; +; CHECK-NEXT: } +; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3; +; CHECK-NEXT: ret; + %s1 = insertvalue %struct.float2 poison, float %a, 0 + %s2 = insertvalue %struct.float2 %s1, float %b, 1 + %r = call float @callee(%struct.float2 %s2) + ret float %r +} + +define float @callee(%struct.float2 alignstack(8) %a ) { +; CHECK-LABEL: .visible .func (.param .b32 func_retval0) callee( +; CHECK-NEXT: .param .align 8 .b8 callee_param_0[8] +; CHECK-NEXT: ) +; CHECK-NEXT: { + +; CHECK: ld.param.v2.f32 {%f1, %f2}, [callee_param_0]; +; CHECK-NEXT: add.rn.f32 %f3, %f1, %f2; +; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3; +; CHECK-NEXT: ret; + %v0 = extractvalue %struct.float2 %a, 0 + %v1 = extractvalue %struct.float2 %a, 1 + %2 = fadd float %v0, %v1 + ret float %2 +} + +!nvvm.annotations = !{!0} +!0 = !{ptr @callee_md, !"align", i32 u0x00010008} From 0b60083e6ed5545d668999c2c6316711f3d50add Mon Sep 17 00:00:00 2001 From: Alex MacLean Date: Fri, 17 May 2024 19:18:26 +0000 Subject: [PATCH 2/2] address comments --- llvm/lib/Target/NVPTX/NVPTXUtilities.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp index 80896a5bc4fd9..013afe916e86c 100644 --- a/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXUtilities.cpp @@ -300,9 +300,9 @@ bool isKernelFunction(const Function &F) { MaybeAlign getAlign(const Function &F, unsigned Index) { // First check the alignstack metadata - if (MaybeAlign AlignStack = + if (MaybeAlign StackAlign = F.getAttributes().getAttributes(Index).getStackAlignment()) - return AlignStack; + return StackAlign; // If that is missing, check the legacy nvvm metadata std::vector Vs; @@ -318,9 +318,9 @@ MaybeAlign getAlign(const Function &F, unsigned Index) { MaybeAlign getAlign(const CallInst &I, unsigned Index) { // First check the alignstack metadata - if (MaybeAlign AlignStack = + if (MaybeAlign StackAlign = I.getAttributes().getAttributes(Index).getStackAlignment()) - return AlignStack; + return StackAlign; // If that is missing, check the legacy nvvm metadata if (MDNode *alignNode = I.getMetadata("callalign")) {