-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[HIPSTDPAR] Add handling for math builtins #140158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
… HIPSTDPAR runtime component.
@llvm/pr-subscribers-backend-amdgpu Author: Alex Voicu (AlexVlx) ChangesWhen compiling in This patch adds an additional The paired change in the run time component is here <ROCm/rocThrust#551>. Patch is 21.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140158.diff 5 Files Affected:
diff --git a/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h b/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
index 5ff38bdf04812..27195051ed7eb 100644
--- a/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
+++ b/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
@@ -40,6 +40,13 @@ class HipStdParAllocationInterpositionPass
static bool isRequired() { return true; }
};
+class HipStdParMathFixupPass : public PassInfoMixin<HipStdParMathFixupPass> {
+public:
+ PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
+
+ static bool isRequired() { return true; }
+};
+
} // namespace llvm
#endif // LLVM_TRANSFORMS_HIPSTDPAR_HIPSTDPAR_H
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 94dabe290213d..3acdbf4d49fde 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -80,6 +80,7 @@ MODULE_PASS("global-merge-func", GlobalMergeFuncPass())
MODULE_PASS("globalopt", GlobalOptPass())
MODULE_PASS("globalsplit", GlobalSplitPass())
MODULE_PASS("hipstdpar-interpose-alloc", HipStdParAllocationInterpositionPass())
+MODULE_PASS("hipstdpar-math-fixup", HipStdParMathFixupPass())
MODULE_PASS("hipstdpar-select-accelerator-code",
HipStdParAcceleratorCodeSelectionPass())
MODULE_PASS("hotcoldsplit", HotColdSplittingPass())
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
index ccb251b730f16..c3f8cee1e1783 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
@@ -819,8 +819,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
// When we are not using -fgpu-rdc, we can run accelerator code
// selection relatively early, but still after linking to prevent
// eager removal of potentially reachable symbols.
- if (EnableHipStdPar)
+ if (EnableHipStdPar) {
+ PM.addPass(HipStdParMathFixupPass());
PM.addPass(HipStdParAcceleratorCodeSelectionPass());
+ }
PM.addPass(AMDGPUPrintfRuntimeBindingPass());
}
@@ -899,8 +901,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
// selection after linking to prevent, otherwise we end up removing
// potentially reachable symbols that were exported as external in other
// modules.
- if (EnableHipStdPar)
+ if (EnableHipStdPar) {
+ PM.addPass(HipStdParMathFixupPass());
PM.addPass(HipStdParAcceleratorCodeSelectionPass());
+ }
// We want to support the -lto-partitions=N option as "best effort".
// For that, we need to lower LDS earlier in the pipeline before the
// module is partitioned for codegen.
diff --git a/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp b/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
index 5a87cf8c83d79..815878089c69e 100644
--- a/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
+++ b/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
@@ -37,6 +37,16 @@
// memory that ends up in one of the runtime equivalents, since this can
// happen if e.g. a library that was compiled without interposition returns
// an allocation that can be validly passed to `free`.
+//
+// 3. MathFixup (required): Some accelerators might have an incomplete
+// implementation for the intrinsics used to implement some of the math
+// functions in <cmath> / their corresponding libcall lowerings. Since this
+// can vary quite significantly between accelerators, we replace calls to a
+// set of intrinsics / lib functions known to be problematic with calls to a
+// HIPSTDPAR specific forwarding layer, which gives an uniform interface for
+// accelerators to implement in their own runtime components. This pass
+// should run before AcceleratorCodeSelection so as to prevent the spurious
+// removal of the HIPSTDPAR specific forwarding functions.
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/HipStdPar/HipStdPar.h"
@@ -48,6 +58,7 @@
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -321,3 +332,109 @@ HipStdParAllocationInterpositionPass::run(Module &M, ModuleAnalysisManager&) {
return PreservedAnalyses::none();
}
+
+static constexpr std::pair<StringLiteral, StringLiteral> MathLibToHipStdPar[]{
+ {"acosh", "__hipstdpar_acosh_f64"},
+ {"acoshf", "__hipstdpar_acosh_f32"},
+ {"asinh", "__hipstdpar_asinh_f64"},
+ {"asinhf", "__hipstdpar_asinh_f32"},
+ {"atanh", "__hipstdpar_atanh_f64"},
+ {"atanhf", "__hipstdpar_atanh_f32"},
+ {"cbrt", "__hipstdpar_cbrt_f64"},
+ {"cbrtf", "__hipstdpar_cbrt_f32"},
+ {"erf", "__hipstdpar_erf_f64"},
+ {"erff", "__hipstdpar_erf_f32"},
+ {"erfc", "__hipstdpar_erfc_f64"},
+ {"erfcf", "__hipstdpar_erfc_f32"},
+ {"fdim", "__hipstdpar_fdim_f64"},
+ {"fdimf", "__hipstdpar_fdim_f32"},
+ {"expm1", "__hipstdpar_expm1_f64"},
+ {"expm1f", "__hipstdpar_expm1_f32"},
+ {"hypot", "__hipstdpar_hypot_f64"},
+ {"hypotf", "__hipstdpar_hypot_f32"},
+ {"ilogb", "__hipstdpar_ilogb_f64"},
+ {"ilogbf", "__hipstdpar_ilogb_f32"},
+ {"lgamma", "__hipstdpar_lgamma_f64"},
+ {"lgammaf", "__hipstdpar_lgamma_f32"},
+ {"log1p", "__hipstdpar_log1p_f64"},
+ {"log1pf", "__hipstdpar_log1p_f32"},
+ {"logb", "__hipstdpar_logb_f64"},
+ {"logbf", "__hipstdpar_logb_f32"},
+ {"nextafter", "__hipstdpar_nextafter_f64"},
+ {"nextafterf", "__hipstdpar_nextafter_f32"},
+ {"nexttoward", "__hipstdpar_nexttoward_f64"},
+ {"nexttowardf", "__hipstdpar_nexttoward_f32"},
+ {"remainder", "__hipstdpar_remainder_f64"},
+ {"remainderf", "__hipstdpar_remainder_f32"},
+ {"remquo", "__hipstdpar_remquo_f64"},
+ {"remquof", "__hipstdpar_remquo_f32"},
+ {"scalbln", "__hipstdpar_scalbln_f64"},
+ {"scalblnf", "__hipstdpar_scalbln_f32"},
+ {"scalbn", "__hipstdpar_scalbn_f64"},
+ {"scalbnf", "__hipstdpar_scalbn_f32"},
+ {"tgamma", "__hipstdpar_tgamma_f64"},
+ {"tgammaf", "__hipstdpar_tgamma_f32"}};
+
+PreservedAnalyses HipStdParMathFixupPass::run(Module &M,
+ ModuleAnalysisManager &) {
+ if (M.empty())
+ return PreservedAnalyses::all();
+
+ SmallVector<std::pair<Function *, std::string>> ToReplace;
+ for (auto &&F : M) {
+ if (!F.hasName())
+ continue;
+
+ auto N = F.getName().str();
+ auto ID = F.getIntrinsicID();
+
+ switch (ID) {
+ case Intrinsic::not_intrinsic: {
+ auto It = find_if(MathLibToHipStdPar,
+ [&](auto &&M) { return M.first == N; });
+ if (It == std::cend(MathLibToHipStdPar))
+ continue;
+ ToReplace.emplace_back(&F, It->second);
+ break;
+ }
+ case Intrinsic::acos:
+ case Intrinsic::asin:
+ case Intrinsic::atan:
+ case Intrinsic::atan2:
+ case Intrinsic::cosh:
+ case Intrinsic::modf:
+ case Intrinsic::sinh:
+ case Intrinsic::tan:
+ case Intrinsic::tanh:
+ break;
+ default: {
+ if (F.getReturnType()->isDoubleTy()) {
+ switch (ID) {
+ case Intrinsic::cos:
+ case Intrinsic::exp:
+ case Intrinsic::exp2:
+ case Intrinsic::log:
+ case Intrinsic::log10:
+ case Intrinsic::log2:
+ case Intrinsic::pow:
+ case Intrinsic::sin:
+ break;
+ default:
+ continue;
+ }
+ break;
+ }
+ continue;
+ }
+ }
+
+ llvm::replace(N, '.', '_');
+ N.replace(0, sizeof("llvm"), "__hipstdpar_");
+ ToReplace.emplace_back(&F, std::move(N));
+ }
+ for (auto &&F : ToReplace)
+ F.first->replaceAllUsesWith(M.getOrInsertFunction(
+ F.second, F.first->getFunctionType()).getCallee());
+
+ return PreservedAnalyses::none();
+}
\ No newline at end of file
diff --git a/llvm/test/Transforms/HipStdPar/math-fixup.ll b/llvm/test/Transforms/HipStdPar/math-fixup.ll
new file mode 100644
index 0000000000000..e0e2ca79c0843
--- /dev/null
+++ b/llvm/test/Transforms/HipStdPar/math-fixup.ll
@@ -0,0 +1,240 @@
+; RUN: opt -S -passes=hipstdpar-math-fixup %s | FileCheck %s
+
+define void @foo(double noundef %dbl, float noundef %flt, i32 noundef %quo) #0 {
+; CHECK-LABEL: define void @foo(
+; CHECK-SAME: double noundef [[DBL:%.*]], float noundef [[FLT:%.*]], i32 noundef [[QUO:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[QUO_ADDR:%.*]] = alloca i32, align 4
+; CHECK-NEXT: store i32 [[QUO]], ptr [[QUO_ADDR]], align 4
+entry:
+ %quo.addr = alloca i32, align 4
+ store i32 %quo, ptr %quo.addr, align 4
+ ; CHECK-NEXT: [[TMP0:%.*]] = tail call contract double @llvm.fabs.f64(double [[DBL]])
+ %0 = tail call contract double @llvm.fabs.f64(double %dbl)
+ ; CHECK-NEXT: [[TMP1:%.*]] = tail call contract float @llvm.fabs.f32(float [[FLT]])
+ %1 = tail call contract float @llvm.fabs.f32(float %flt)
+ ; CHECK-NEXT: [[CALL:%.*]] = tail call contract double @__hipstdpar_remainder_f64(double noundef [[TMP0]], double noundef [[TMP0]]) #[[ATTR4:[0-9]+]]
+ %call = tail call contract double @remainder(double noundef %0, double noundef %0) #4
+ ; CHECK-NEXT: [[CALL1:%.*]] = tail call contract float @__hipstdpar_remainder_f32(float noundef [[TMP1]], float noundef [[TMP1]]) #[[ATTR4]]
+ %call1 = tail call contract float @remainderf(float noundef %1, float noundef %1) #4
+ ; CHECK-NEXT: [[CALL2:%.*]] = call contract double @__hipstdpar_remquo_f64(double noundef [[CALL]], double noundef [[CALL]], ptr noundef nonnull [[QUO_ADDR]]) #[[ATTR3:[0-9]+]]
+ %call2 = call contract double @remquo(double noundef %call, double noundef %call, ptr noundef nonnull %quo.addr) #5
+ ; CHECK-NEXT: [[CALL3:%.*]] = call contract float @__hipstdpar_remquo_f32(float noundef [[CALL1]], float noundef [[CALL1]], ptr noundef nonnull [[QUO_ADDR]]) #[[ATTR3]]
+ %call3 = call contract float @remquof(float noundef %call1, float noundef %call1, ptr noundef nonnull %quo.addr) #5
+ ; CHECK-NEXT: [[TMP2:%.*]] = call contract double @llvm.fma.f64(double [[CALL2]], double [[CALL2]], double [[CALL2]])
+ %2 = call contract double @llvm.fma.f64(double %call2, double %call2, double %call2)
+ ; CHECK-NEXT: [[TMP3:%.*]] = call contract float @llvm.fma.f32(float [[CALL3]], float [[CALL3]], float [[CALL3]])
+ %3 = call contract float @llvm.fma.f32(float %call3, float %call3, float %call3)
+ ; CHECK-NEXT: [[CALL4:%.*]] = call contract double @__hipstdpar_fdim_f64(double noundef [[TMP2]], double noundef [[TMP2]]) #[[ATTR4]]
+ %call4 = call contract double @fdim(double noundef %2, double noundef %2) #4
+ ; CHECK-NEXT: [[CALL5:%.*]] = call contract float @__hipstdpar_fdim_f32(float noundef [[TMP3]], float noundef [[TMP3]]) #[[ATTR4]]
+ %call5 = call contract float @fdimf(float noundef %3, float noundef %3) #4
+ ; CHECK-NEXT: [[TMP4:%.*]] = call contract double @__hipstdpar_exp_f64(double [[CALL4]])
+ %4 = call contract double @llvm.exp.f64(double %call4)
+ ; CHECK-NEXT: [[TMP5:%.*]] = call contract float @llvm.exp.f32(float [[CALL5]])
+ %5 = call contract float @llvm.exp.f32(float %call5)
+ ; CHECK-NEXT: [[TMP6:%.*]] = call contract double @__hipstdpar_exp2_f64(double [[TMP4]])
+ %6 = call contract double @llvm.exp2.f64(double %4)
+ ; CHECK-NEXT: [[TMP7:%.*]] = call contract float @llvm.exp2.f32(float [[TMP5]])
+ %7 = call contract float @llvm.exp2.f32(float %5)
+ ; CHECK-NEXT: [[CALL6:%.*]] = call contract double @__hipstdpar_expm1_f64(double noundef [[TMP6]]) #[[ATTR4]]
+ %call6 = call contract double @expm1(double noundef %6) #4
+ ; CHECK-NEXT: [[CALL7:%.*]] = call contract float @__hipstdpar_expm1_f32(float noundef [[TMP7]]) #[[ATTR4]]
+ %call7 = call contract float @expm1f(float noundef %7) #4
+ ; CHECK-NEXT: [[TMP8:%.*]] = call contract double @__hipstdpar_log_f64(double [[CALL6]])
+ %8 = call contract double @llvm.log.f64(double %call6)
+ ; CHECK-NEXT: [[TMP9:%.*]] = call contract float @llvm.log.f32(float [[CALL7]])
+ %9 = call contract float @llvm.log.f32(float %call7)
+ ; CHECK-NEXT: [[TMP10:%.*]] = call contract double @__hipstdpar_log10_f64(double [[TMP8]])
+ %10 = call contract double @llvm.log10.f64(double %8)
+ ; CHECK-NEXT: [[TMP11:%.*]] = call contract float @llvm.log10.f32(float [[TMP9]])
+ %11 = call contract float @llvm.log10.f32(float %9)
+ ; CHECK-NEXT: [[TMP12:%.*]] = call contract double @__hipstdpar_log2_f64(double [[TMP10]])
+ %12 = call contract double @llvm.log2.f64(double %10)
+ ; CHECK-NEXT: [[TMP13:%.*]] = call contract float @llvm.log2.f32(float [[TMP11]])
+ %13 = call contract float @llvm.log2.f32(float %11)
+ ; CHECK-NEXT: [[CALL8:%.*]] = call contract double @__hipstdpar_log1p_f64(double noundef [[TMP12]]) #[[ATTR4]]
+ %call8 = call contract double @log1p(double noundef %12) #4
+ ; CHECK-NEXT: [[CALL9:%.*]] = call contract float @__hipstdpar_log1p_f32(float noundef [[TMP13]]) #[[ATTR4]]
+ %call9 = call contract float @log1pf(float noundef %13) #4
+ ; CHECK-NEXT: [[TMP14:%.*]] = call contract float @llvm.pow.f32(float [[CALL9]], float [[CALL9]])
+ %14 = call contract float @llvm.pow.f32(float %call9, float %call9)
+ ; CHECK-NEXT: [[TMP15:%.*]] = call contract double @llvm.sqrt.f64(double [[CALL8]])
+ %15 = call contract double @llvm.sqrt.f64(double %call8)
+ ; CHECK-NEXT: [[TMP16:%.*]] = call contract float @llvm.sqrt.f32(float [[TMP14]])
+ %16 = call contract float @llvm.sqrt.f32(float %14)
+ ; CHECK-NEXT: [[CALL10:%.*]] = call contract double @__hipstdpar_cbrt_f64(double noundef [[TMP15]]) #[[ATTR4]]
+ %call10 = call contract double @cbrt(double noundef %15) #4
+ ; CHECK-NEXT: [[CALL11:%.*]] = call contract float @__hipstdpar_cbrt_f32(float noundef [[TMP16]]) #[[ATTR4]]
+ %call11 = call contract float @cbrtf(float noundef %16) #4
+ ; CHECK-NEXT: [[CALL12:%.*]] = call contract double @__hipstdpar_hypot_f64(double noundef [[CALL10]], double noundef [[CALL10]]) #[[ATTR4]]
+ %call12 = call contract double @hypot(double noundef %call10, double noundef %call10) #4
+ ; CHECK-NEXT: [[CALL13:%.*]] = call contract float @__hipstdpar_hypot_f32(float noundef [[CALL11]], float noundef [[CALL11]]) #[[ATTR4]]
+ %call13 = call contract float @hypotf(float noundef %call11, float noundef %call11) #4
+ ; CHECK-NEXT: [[TMP17:%.*]] = call contract float @llvm.sin.f32(float [[CALL13]])
+ %17 = call contract float @llvm.sin.f32(float %call13)
+ ; CHECK-NEXT: [[TMP18:%.*]] = call contract float @llvm.cos.f32(float [[TMP17]])
+ %18 = call contract float @llvm.cos.f32(float %17)
+ ; CHECK-NEXT: [[TMP19:%.*]] = call contract double @__hipstdpar_tan_f64(double [[CALL12]])
+ %19 = call contract double @llvm.tan.f64(double %call12)
+ ; CHECK-NEXT: [[TMP20:%.*]] = call contract double @__hipstdpar_asin_f64(double [[TMP19]])
+ %20 = call contract double @llvm.asin.f64(double %19)
+ ; CHECK-NEXT: [[TMP21:%.*]] = call contract double @__hipstdpar_acos_f64(double [[TMP20]])
+ %21 = call contract double @llvm.acos.f64(double %20)
+ ; CHECK-NEXT: [[TMP22:%.*]] = call contract double @__hipstdpar_atan_f64(double [[TMP21]])
+ %22 = call contract double @llvm.atan.f64(double %21)
+ ; CHECK-NEXT: [[TMP23:%.*]] = call contract double @__hipstdpar_atan2_f64(double [[TMP22]], double [[TMP22]])
+ %23 = call contract double @llvm.atan2.f64(double %22, double %22)
+ ; CHECK-NEXT: [[TMP24:%.*]] = call contract double @__hipstdpar_sinh_f64(double [[TMP23]])
+ %24 = call contract double @llvm.sinh.f64(double %23)
+ ; CHECK-NEXT: [[TMP25:%.*]] = call contract double @__hipstdpar_cosh_f64(double [[TMP24]])
+ %25 = call contract double @llvm.cosh.f64(double %24)
+ ; CHECK-NEXT: [[TMP26:%.*]] = call contract double @__hipstdpar_tanh_f64(double [[TMP25]])
+ %26 = call contract double @llvm.tanh.f64(double %25)
+ ; CHECK-NEXT: [[CALL14:%.*]] = call contract double @__hipstdpar_asinh_f64(double noundef [[TMP26]]) #[[ATTR4]]
+ %call14 = call contract double @asinh(double noundef %26) #4
+ ; CHECK-NEXT: [[CALL15:%.*]] = call contract float @__hipstdpar_asinh_f32(float noundef [[TMP18]]) #[[ATTR4]]
+ %call15 = call contract float @asinhf(float noundef %18) #4
+ ; CHECK-NEXT: [[CALL16:%.*]] = call contract double @__hipstdpar_acosh_f64(double noundef [[CALL14]]) #[[ATTR4]]
+ %call16 = call contract double @acosh(double noundef %call14) #4
+ ; CHECK-NEXT: [[CALL17:%.*]] = call contract float @__hipstdpar_acosh_f32(float noundef [[CALL15]]) #[[ATTR4]]
+ %call17 = call contract float @acoshf(float noundef %call15) #4
+ ; CHECK-NEXT: [[CALL18:%.*]] = call contract double @__hipstdpar_atanh_f64(double noundef [[CALL16]]) #[[ATTR4]]
+ %call18 = call contract double @atanh(double noundef %call16) #4
+ ; CHECK-NEXT: [[CALL19:%.*]] = call contract float @__hipstdpar_atanh_f32(float noundef [[CALL17]]) #[[ATTR4]]
+ %call19 = call contract float @atanhf(float noundef %call17) #4
+ ; CHECK-NEXT: [[CALL20:%.*]] = call contract double @__hipstdpar_erf_f64(double noundef [[CALL18]]) #[[ATTR4]]
+ %call20 = call contract double @erf(double noundef %call18) #4
+ ; CHECK-NEXT: [[CALL21:%.*]] = call contract float @__hipstdpar_erf_f32(float noundef [[CALL19]]) #[[ATTR4]]
+ %call21 = call contract float @erff(float noundef %call19) #4
+ ; CHECK-NEXT: [[CALL22:%.*]] = call contract double @__hipstdpar_erfc_f64(double noundef [[CALL20]]) #[[ATTR4]]
+ %call22 = call contract double @erfc(double noundef %call20) #4
+ ; CHECK-NEXT: [[CALL23:%.*]] = call contract float @__hipstdpar_erfc_f32(float noundef [[CALL21]]) #[[ATTR4]]
+ %call23 = call contract float @erfcf(float noundef %call21) #4
+ ; CHECK-NEXT: [[CALL24:%.*]] = call contract double @__hipstdpar_tgamma_f64(double noundef [[CALL22]]) #[[ATTR4]]
+ %call24 = call contract double @tgamma(double noundef %call22) #4
+ ; CHECK-NEXT: [[CALL25:%.*]] = call contract float @__hipstdpar_tgamma_f32(float noundef [[CALL23]]) #[[ATTR4]]
+ %call25 = call contract float @tgammaf(float noundef %call23) #4
+ ; CHECK-NEXT: [[CALL26:%.*]] = call contract double @__hipstdpar_lgamma_f64(double noundef [[CALL24]]) #[[ATTR3]]
+ %call26 = call contract double @lgamma(double noundef %call24) #5
+ ; CHECK-NEXT: [[CALL27:%.*]] = call contract float @__hipstdpar_lgamma_f32(float noundef [[CALL25]]) #[[ATTR3]]
+ %call27 = call contract float @lgammaf(float noundef %call25) #5
+ ret void
+}
+
+declare double @llvm.fabs.f64(double) #1
+
+declare float @llvm.fabs.f32(float) #1
+
+declare hidden double @remainder(double noundef, double noundef) local_unnamed_addr #2
+
+declare hidden float @remainderf(float noundef, float noundef) local_unnamed_addr #2
+
+declare hidden double @remquo(double noundef, double noundef, ptr noundef) local_unnamed_addr #3
+
+declare hidden float @remquof(float noundef, float noundef, ptr noundef) local_unnamed_addr #3
+
+declare double @llvm.fma.f64(double, double, double) #1
+
+declare float @llvm.fma.f32(float, float, float) #1
+
+declare hidden double @fdim(double noundef, double noundef) local_unnamed_addr #2
+
+declare hidden float @fdimf(float noundef, float noundef) local_unnamed_addr #2
+
+declare double @llvm.exp.f64(double) #1
+
+declare float @llvm.exp.f32(float) #1
+
+declare double @llvm.exp2.f64(double) #1
+
+declare float @llvm.exp2.f32(float) #1
+
+declare hidden double @expm1(double noundef) local_unnamed_addr #2
+
+declare hidden float @expm1f(float noundef) local_unnamed_addr #2
+
+declare double @llvm.log.f64(double) #1
+
+declare float @llvm.log.f32(float) #1
+
+declare double @llvm.log10.f64(double) #1
+
+declare float @llvm.log10.f32(float) #1
+
+declare double @llvm.log2.f64(double) #1
+
+declare float @llvm.log2.f32(float) #1
+
+declare hidden double @log1p(double noundef) local_unnamed_addr #2
+
+declare hidden float @log1pf(float noundef) local_unnamed_addr #2
+
+declare float @llvm.pow.f32(float, float) #1
+
+declare double @llvm.sqrt.f64(double) #1
+
+declare float @llvm.sqrt.f32(float) #1
+
+declare hidden double @cbrt(double noundef) local_unnamed_addr #2
+
+declare hidden float @cbrtf(float noundef) local_unnamed_addr #2
+
+declare hidden double @hypot(double noundef...
[truncated]
|
@llvm/pr-subscribers-llvm-transforms Author: Alex Voicu (AlexVlx) ChangesWhen compiling in This patch adds an additional The paired change in the run time component is here <ROCm/rocThrust#551>. Patch is 21.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140158.diff 5 Files Affected:
diff --git a/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h b/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
index 5ff38bdf04812..27195051ed7eb 100644
--- a/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
+++ b/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
@@ -40,6 +40,13 @@ class HipStdParAllocationInterpositionPass
static bool isRequired() { return true; }
};
+class HipStdParMathFixupPass : public PassInfoMixin<HipStdParMathFixupPass> {
+public:
+ PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
+
+ static bool isRequired() { return true; }
+};
+
} // namespace llvm
#endif // LLVM_TRANSFORMS_HIPSTDPAR_HIPSTDPAR_H
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 94dabe290213d..3acdbf4d49fde 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -80,6 +80,7 @@ MODULE_PASS("global-merge-func", GlobalMergeFuncPass())
MODULE_PASS("globalopt", GlobalOptPass())
MODULE_PASS("globalsplit", GlobalSplitPass())
MODULE_PASS("hipstdpar-interpose-alloc", HipStdParAllocationInterpositionPass())
+MODULE_PASS("hipstdpar-math-fixup", HipStdParMathFixupPass())
MODULE_PASS("hipstdpar-select-accelerator-code",
HipStdParAcceleratorCodeSelectionPass())
MODULE_PASS("hotcoldsplit", HotColdSplittingPass())
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
index ccb251b730f16..c3f8cee1e1783 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
@@ -819,8 +819,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
// When we are not using -fgpu-rdc, we can run accelerator code
// selection relatively early, but still after linking to prevent
// eager removal of potentially reachable symbols.
- if (EnableHipStdPar)
+ if (EnableHipStdPar) {
+ PM.addPass(HipStdParMathFixupPass());
PM.addPass(HipStdParAcceleratorCodeSelectionPass());
+ }
PM.addPass(AMDGPUPrintfRuntimeBindingPass());
}
@@ -899,8 +901,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
// selection after linking to prevent, otherwise we end up removing
// potentially reachable symbols that were exported as external in other
// modules.
- if (EnableHipStdPar)
+ if (EnableHipStdPar) {
+ PM.addPass(HipStdParMathFixupPass());
PM.addPass(HipStdParAcceleratorCodeSelectionPass());
+ }
// We want to support the -lto-partitions=N option as "best effort".
// For that, we need to lower LDS earlier in the pipeline before the
// module is partitioned for codegen.
diff --git a/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp b/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
index 5a87cf8c83d79..815878089c69e 100644
--- a/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
+++ b/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
@@ -37,6 +37,16 @@
// memory that ends up in one of the runtime equivalents, since this can
// happen if e.g. a library that was compiled without interposition returns
// an allocation that can be validly passed to `free`.
+//
+// 3. MathFixup (required): Some accelerators might have an incomplete
+// implementation for the intrinsics used to implement some of the math
+// functions in <cmath> / their corresponding libcall lowerings. Since this
+// can vary quite significantly between accelerators, we replace calls to a
+// set of intrinsics / lib functions known to be problematic with calls to a
+// HIPSTDPAR specific forwarding layer, which gives an uniform interface for
+// accelerators to implement in their own runtime components. This pass
+// should run before AcceleratorCodeSelection so as to prevent the spurious
+// removal of the HIPSTDPAR specific forwarding functions.
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/HipStdPar/HipStdPar.h"
@@ -48,6 +58,7 @@
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
+#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -321,3 +332,109 @@ HipStdParAllocationInterpositionPass::run(Module &M, ModuleAnalysisManager&) {
return PreservedAnalyses::none();
}
+
+static constexpr std::pair<StringLiteral, StringLiteral> MathLibToHipStdPar[]{
+ {"acosh", "__hipstdpar_acosh_f64"},
+ {"acoshf", "__hipstdpar_acosh_f32"},
+ {"asinh", "__hipstdpar_asinh_f64"},
+ {"asinhf", "__hipstdpar_asinh_f32"},
+ {"atanh", "__hipstdpar_atanh_f64"},
+ {"atanhf", "__hipstdpar_atanh_f32"},
+ {"cbrt", "__hipstdpar_cbrt_f64"},
+ {"cbrtf", "__hipstdpar_cbrt_f32"},
+ {"erf", "__hipstdpar_erf_f64"},
+ {"erff", "__hipstdpar_erf_f32"},
+ {"erfc", "__hipstdpar_erfc_f64"},
+ {"erfcf", "__hipstdpar_erfc_f32"},
+ {"fdim", "__hipstdpar_fdim_f64"},
+ {"fdimf", "__hipstdpar_fdim_f32"},
+ {"expm1", "__hipstdpar_expm1_f64"},
+ {"expm1f", "__hipstdpar_expm1_f32"},
+ {"hypot", "__hipstdpar_hypot_f64"},
+ {"hypotf", "__hipstdpar_hypot_f32"},
+ {"ilogb", "__hipstdpar_ilogb_f64"},
+ {"ilogbf", "__hipstdpar_ilogb_f32"},
+ {"lgamma", "__hipstdpar_lgamma_f64"},
+ {"lgammaf", "__hipstdpar_lgamma_f32"},
+ {"log1p", "__hipstdpar_log1p_f64"},
+ {"log1pf", "__hipstdpar_log1p_f32"},
+ {"logb", "__hipstdpar_logb_f64"},
+ {"logbf", "__hipstdpar_logb_f32"},
+ {"nextafter", "__hipstdpar_nextafter_f64"},
+ {"nextafterf", "__hipstdpar_nextafter_f32"},
+ {"nexttoward", "__hipstdpar_nexttoward_f64"},
+ {"nexttowardf", "__hipstdpar_nexttoward_f32"},
+ {"remainder", "__hipstdpar_remainder_f64"},
+ {"remainderf", "__hipstdpar_remainder_f32"},
+ {"remquo", "__hipstdpar_remquo_f64"},
+ {"remquof", "__hipstdpar_remquo_f32"},
+ {"scalbln", "__hipstdpar_scalbln_f64"},
+ {"scalblnf", "__hipstdpar_scalbln_f32"},
+ {"scalbn", "__hipstdpar_scalbn_f64"},
+ {"scalbnf", "__hipstdpar_scalbn_f32"},
+ {"tgamma", "__hipstdpar_tgamma_f64"},
+ {"tgammaf", "__hipstdpar_tgamma_f32"}};
+
+PreservedAnalyses HipStdParMathFixupPass::run(Module &M,
+ ModuleAnalysisManager &) {
+ if (M.empty())
+ return PreservedAnalyses::all();
+
+ SmallVector<std::pair<Function *, std::string>> ToReplace;
+ for (auto &&F : M) {
+ if (!F.hasName())
+ continue;
+
+ auto N = F.getName().str();
+ auto ID = F.getIntrinsicID();
+
+ switch (ID) {
+ case Intrinsic::not_intrinsic: {
+ auto It = find_if(MathLibToHipStdPar,
+ [&](auto &&M) { return M.first == N; });
+ if (It == std::cend(MathLibToHipStdPar))
+ continue;
+ ToReplace.emplace_back(&F, It->second);
+ break;
+ }
+ case Intrinsic::acos:
+ case Intrinsic::asin:
+ case Intrinsic::atan:
+ case Intrinsic::atan2:
+ case Intrinsic::cosh:
+ case Intrinsic::modf:
+ case Intrinsic::sinh:
+ case Intrinsic::tan:
+ case Intrinsic::tanh:
+ break;
+ default: {
+ if (F.getReturnType()->isDoubleTy()) {
+ switch (ID) {
+ case Intrinsic::cos:
+ case Intrinsic::exp:
+ case Intrinsic::exp2:
+ case Intrinsic::log:
+ case Intrinsic::log10:
+ case Intrinsic::log2:
+ case Intrinsic::pow:
+ case Intrinsic::sin:
+ break;
+ default:
+ continue;
+ }
+ break;
+ }
+ continue;
+ }
+ }
+
+ llvm::replace(N, '.', '_');
+ N.replace(0, sizeof("llvm"), "__hipstdpar_");
+ ToReplace.emplace_back(&F, std::move(N));
+ }
+ for (auto &&F : ToReplace)
+ F.first->replaceAllUsesWith(M.getOrInsertFunction(
+ F.second, F.first->getFunctionType()).getCallee());
+
+ return PreservedAnalyses::none();
+}
\ No newline at end of file
diff --git a/llvm/test/Transforms/HipStdPar/math-fixup.ll b/llvm/test/Transforms/HipStdPar/math-fixup.ll
new file mode 100644
index 0000000000000..e0e2ca79c0843
--- /dev/null
+++ b/llvm/test/Transforms/HipStdPar/math-fixup.ll
@@ -0,0 +1,240 @@
+; RUN: opt -S -passes=hipstdpar-math-fixup %s | FileCheck %s
+
+define void @foo(double noundef %dbl, float noundef %flt, i32 noundef %quo) #0 {
+; CHECK-LABEL: define void @foo(
+; CHECK-SAME: double noundef [[DBL:%.*]], float noundef [[FLT:%.*]], i32 noundef [[QUO:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT: [[ENTRY:.*:]]
+; CHECK-NEXT: [[QUO_ADDR:%.*]] = alloca i32, align 4
+; CHECK-NEXT: store i32 [[QUO]], ptr [[QUO_ADDR]], align 4
+entry:
+ %quo.addr = alloca i32, align 4
+ store i32 %quo, ptr %quo.addr, align 4
+ ; CHECK-NEXT: [[TMP0:%.*]] = tail call contract double @llvm.fabs.f64(double [[DBL]])
+ %0 = tail call contract double @llvm.fabs.f64(double %dbl)
+ ; CHECK-NEXT: [[TMP1:%.*]] = tail call contract float @llvm.fabs.f32(float [[FLT]])
+ %1 = tail call contract float @llvm.fabs.f32(float %flt)
+ ; CHECK-NEXT: [[CALL:%.*]] = tail call contract double @__hipstdpar_remainder_f64(double noundef [[TMP0]], double noundef [[TMP0]]) #[[ATTR4:[0-9]+]]
+ %call = tail call contract double @remainder(double noundef %0, double noundef %0) #4
+ ; CHECK-NEXT: [[CALL1:%.*]] = tail call contract float @__hipstdpar_remainder_f32(float noundef [[TMP1]], float noundef [[TMP1]]) #[[ATTR4]]
+ %call1 = tail call contract float @remainderf(float noundef %1, float noundef %1) #4
+ ; CHECK-NEXT: [[CALL2:%.*]] = call contract double @__hipstdpar_remquo_f64(double noundef [[CALL]], double noundef [[CALL]], ptr noundef nonnull [[QUO_ADDR]]) #[[ATTR3:[0-9]+]]
+ %call2 = call contract double @remquo(double noundef %call, double noundef %call, ptr noundef nonnull %quo.addr) #5
+ ; CHECK-NEXT: [[CALL3:%.*]] = call contract float @__hipstdpar_remquo_f32(float noundef [[CALL1]], float noundef [[CALL1]], ptr noundef nonnull [[QUO_ADDR]]) #[[ATTR3]]
+ %call3 = call contract float @remquof(float noundef %call1, float noundef %call1, ptr noundef nonnull %quo.addr) #5
+ ; CHECK-NEXT: [[TMP2:%.*]] = call contract double @llvm.fma.f64(double [[CALL2]], double [[CALL2]], double [[CALL2]])
+ %2 = call contract double @llvm.fma.f64(double %call2, double %call2, double %call2)
+ ; CHECK-NEXT: [[TMP3:%.*]] = call contract float @llvm.fma.f32(float [[CALL3]], float [[CALL3]], float [[CALL3]])
+ %3 = call contract float @llvm.fma.f32(float %call3, float %call3, float %call3)
+ ; CHECK-NEXT: [[CALL4:%.*]] = call contract double @__hipstdpar_fdim_f64(double noundef [[TMP2]], double noundef [[TMP2]]) #[[ATTR4]]
+ %call4 = call contract double @fdim(double noundef %2, double noundef %2) #4
+ ; CHECK-NEXT: [[CALL5:%.*]] = call contract float @__hipstdpar_fdim_f32(float noundef [[TMP3]], float noundef [[TMP3]]) #[[ATTR4]]
+ %call5 = call contract float @fdimf(float noundef %3, float noundef %3) #4
+ ; CHECK-NEXT: [[TMP4:%.*]] = call contract double @__hipstdpar_exp_f64(double [[CALL4]])
+ %4 = call contract double @llvm.exp.f64(double %call4)
+ ; CHECK-NEXT: [[TMP5:%.*]] = call contract float @llvm.exp.f32(float [[CALL5]])
+ %5 = call contract float @llvm.exp.f32(float %call5)
+ ; CHECK-NEXT: [[TMP6:%.*]] = call contract double @__hipstdpar_exp2_f64(double [[TMP4]])
+ %6 = call contract double @llvm.exp2.f64(double %4)
+ ; CHECK-NEXT: [[TMP7:%.*]] = call contract float @llvm.exp2.f32(float [[TMP5]])
+ %7 = call contract float @llvm.exp2.f32(float %5)
+ ; CHECK-NEXT: [[CALL6:%.*]] = call contract double @__hipstdpar_expm1_f64(double noundef [[TMP6]]) #[[ATTR4]]
+ %call6 = call contract double @expm1(double noundef %6) #4
+ ; CHECK-NEXT: [[CALL7:%.*]] = call contract float @__hipstdpar_expm1_f32(float noundef [[TMP7]]) #[[ATTR4]]
+ %call7 = call contract float @expm1f(float noundef %7) #4
+ ; CHECK-NEXT: [[TMP8:%.*]] = call contract double @__hipstdpar_log_f64(double [[CALL6]])
+ %8 = call contract double @llvm.log.f64(double %call6)
+ ; CHECK-NEXT: [[TMP9:%.*]] = call contract float @llvm.log.f32(float [[CALL7]])
+ %9 = call contract float @llvm.log.f32(float %call7)
+ ; CHECK-NEXT: [[TMP10:%.*]] = call contract double @__hipstdpar_log10_f64(double [[TMP8]])
+ %10 = call contract double @llvm.log10.f64(double %8)
+ ; CHECK-NEXT: [[TMP11:%.*]] = call contract float @llvm.log10.f32(float [[TMP9]])
+ %11 = call contract float @llvm.log10.f32(float %9)
+ ; CHECK-NEXT: [[TMP12:%.*]] = call contract double @__hipstdpar_log2_f64(double [[TMP10]])
+ %12 = call contract double @llvm.log2.f64(double %10)
+ ; CHECK-NEXT: [[TMP13:%.*]] = call contract float @llvm.log2.f32(float [[TMP11]])
+ %13 = call contract float @llvm.log2.f32(float %11)
+ ; CHECK-NEXT: [[CALL8:%.*]] = call contract double @__hipstdpar_log1p_f64(double noundef [[TMP12]]) #[[ATTR4]]
+ %call8 = call contract double @log1p(double noundef %12) #4
+ ; CHECK-NEXT: [[CALL9:%.*]] = call contract float @__hipstdpar_log1p_f32(float noundef [[TMP13]]) #[[ATTR4]]
+ %call9 = call contract float @log1pf(float noundef %13) #4
+ ; CHECK-NEXT: [[TMP14:%.*]] = call contract float @llvm.pow.f32(float [[CALL9]], float [[CALL9]])
+ %14 = call contract float @llvm.pow.f32(float %call9, float %call9)
+ ; CHECK-NEXT: [[TMP15:%.*]] = call contract double @llvm.sqrt.f64(double [[CALL8]])
+ %15 = call contract double @llvm.sqrt.f64(double %call8)
+ ; CHECK-NEXT: [[TMP16:%.*]] = call contract float @llvm.sqrt.f32(float [[TMP14]])
+ %16 = call contract float @llvm.sqrt.f32(float %14)
+ ; CHECK-NEXT: [[CALL10:%.*]] = call contract double @__hipstdpar_cbrt_f64(double noundef [[TMP15]]) #[[ATTR4]]
+ %call10 = call contract double @cbrt(double noundef %15) #4
+ ; CHECK-NEXT: [[CALL11:%.*]] = call contract float @__hipstdpar_cbrt_f32(float noundef [[TMP16]]) #[[ATTR4]]
+ %call11 = call contract float @cbrtf(float noundef %16) #4
+ ; CHECK-NEXT: [[CALL12:%.*]] = call contract double @__hipstdpar_hypot_f64(double noundef [[CALL10]], double noundef [[CALL10]]) #[[ATTR4]]
+ %call12 = call contract double @hypot(double noundef %call10, double noundef %call10) #4
+ ; CHECK-NEXT: [[CALL13:%.*]] = call contract float @__hipstdpar_hypot_f32(float noundef [[CALL11]], float noundef [[CALL11]]) #[[ATTR4]]
+ %call13 = call contract float @hypotf(float noundef %call11, float noundef %call11) #4
+ ; CHECK-NEXT: [[TMP17:%.*]] = call contract float @llvm.sin.f32(float [[CALL13]])
+ %17 = call contract float @llvm.sin.f32(float %call13)
+ ; CHECK-NEXT: [[TMP18:%.*]] = call contract float @llvm.cos.f32(float [[TMP17]])
+ %18 = call contract float @llvm.cos.f32(float %17)
+ ; CHECK-NEXT: [[TMP19:%.*]] = call contract double @__hipstdpar_tan_f64(double [[CALL12]])
+ %19 = call contract double @llvm.tan.f64(double %call12)
+ ; CHECK-NEXT: [[TMP20:%.*]] = call contract double @__hipstdpar_asin_f64(double [[TMP19]])
+ %20 = call contract double @llvm.asin.f64(double %19)
+ ; CHECK-NEXT: [[TMP21:%.*]] = call contract double @__hipstdpar_acos_f64(double [[TMP20]])
+ %21 = call contract double @llvm.acos.f64(double %20)
+ ; CHECK-NEXT: [[TMP22:%.*]] = call contract double @__hipstdpar_atan_f64(double [[TMP21]])
+ %22 = call contract double @llvm.atan.f64(double %21)
+ ; CHECK-NEXT: [[TMP23:%.*]] = call contract double @__hipstdpar_atan2_f64(double [[TMP22]], double [[TMP22]])
+ %23 = call contract double @llvm.atan2.f64(double %22, double %22)
+ ; CHECK-NEXT: [[TMP24:%.*]] = call contract double @__hipstdpar_sinh_f64(double [[TMP23]])
+ %24 = call contract double @llvm.sinh.f64(double %23)
+ ; CHECK-NEXT: [[TMP25:%.*]] = call contract double @__hipstdpar_cosh_f64(double [[TMP24]])
+ %25 = call contract double @llvm.cosh.f64(double %24)
+ ; CHECK-NEXT: [[TMP26:%.*]] = call contract double @__hipstdpar_tanh_f64(double [[TMP25]])
+ %26 = call contract double @llvm.tanh.f64(double %25)
+ ; CHECK-NEXT: [[CALL14:%.*]] = call contract double @__hipstdpar_asinh_f64(double noundef [[TMP26]]) #[[ATTR4]]
+ %call14 = call contract double @asinh(double noundef %26) #4
+ ; CHECK-NEXT: [[CALL15:%.*]] = call contract float @__hipstdpar_asinh_f32(float noundef [[TMP18]]) #[[ATTR4]]
+ %call15 = call contract float @asinhf(float noundef %18) #4
+ ; CHECK-NEXT: [[CALL16:%.*]] = call contract double @__hipstdpar_acosh_f64(double noundef [[CALL14]]) #[[ATTR4]]
+ %call16 = call contract double @acosh(double noundef %call14) #4
+ ; CHECK-NEXT: [[CALL17:%.*]] = call contract float @__hipstdpar_acosh_f32(float noundef [[CALL15]]) #[[ATTR4]]
+ %call17 = call contract float @acoshf(float noundef %call15) #4
+ ; CHECK-NEXT: [[CALL18:%.*]] = call contract double @__hipstdpar_atanh_f64(double noundef [[CALL16]]) #[[ATTR4]]
+ %call18 = call contract double @atanh(double noundef %call16) #4
+ ; CHECK-NEXT: [[CALL19:%.*]] = call contract float @__hipstdpar_atanh_f32(float noundef [[CALL17]]) #[[ATTR4]]
+ %call19 = call contract float @atanhf(float noundef %call17) #4
+ ; CHECK-NEXT: [[CALL20:%.*]] = call contract double @__hipstdpar_erf_f64(double noundef [[CALL18]]) #[[ATTR4]]
+ %call20 = call contract double @erf(double noundef %call18) #4
+ ; CHECK-NEXT: [[CALL21:%.*]] = call contract float @__hipstdpar_erf_f32(float noundef [[CALL19]]) #[[ATTR4]]
+ %call21 = call contract float @erff(float noundef %call19) #4
+ ; CHECK-NEXT: [[CALL22:%.*]] = call contract double @__hipstdpar_erfc_f64(double noundef [[CALL20]]) #[[ATTR4]]
+ %call22 = call contract double @erfc(double noundef %call20) #4
+ ; CHECK-NEXT: [[CALL23:%.*]] = call contract float @__hipstdpar_erfc_f32(float noundef [[CALL21]]) #[[ATTR4]]
+ %call23 = call contract float @erfcf(float noundef %call21) #4
+ ; CHECK-NEXT: [[CALL24:%.*]] = call contract double @__hipstdpar_tgamma_f64(double noundef [[CALL22]]) #[[ATTR4]]
+ %call24 = call contract double @tgamma(double noundef %call22) #4
+ ; CHECK-NEXT: [[CALL25:%.*]] = call contract float @__hipstdpar_tgamma_f32(float noundef [[CALL23]]) #[[ATTR4]]
+ %call25 = call contract float @tgammaf(float noundef %call23) #4
+ ; CHECK-NEXT: [[CALL26:%.*]] = call contract double @__hipstdpar_lgamma_f64(double noundef [[CALL24]]) #[[ATTR3]]
+ %call26 = call contract double @lgamma(double noundef %call24) #5
+ ; CHECK-NEXT: [[CALL27:%.*]] = call contract float @__hipstdpar_lgamma_f32(float noundef [[CALL25]]) #[[ATTR3]]
+ %call27 = call contract float @lgammaf(float noundef %call25) #5
+ ret void
+}
+
+declare double @llvm.fabs.f64(double) #1
+
+declare float @llvm.fabs.f32(float) #1
+
+declare hidden double @remainder(double noundef, double noundef) local_unnamed_addr #2
+
+declare hidden float @remainderf(float noundef, float noundef) local_unnamed_addr #2
+
+declare hidden double @remquo(double noundef, double noundef, ptr noundef) local_unnamed_addr #3
+
+declare hidden float @remquof(float noundef, float noundef, ptr noundef) local_unnamed_addr #3
+
+declare double @llvm.fma.f64(double, double, double) #1
+
+declare float @llvm.fma.f32(float, float, float) #1
+
+declare hidden double @fdim(double noundef, double noundef) local_unnamed_addr #2
+
+declare hidden float @fdimf(float noundef, float noundef) local_unnamed_addr #2
+
+declare double @llvm.exp.f64(double) #1
+
+declare float @llvm.exp.f32(float) #1
+
+declare double @llvm.exp2.f64(double) #1
+
+declare float @llvm.exp2.f32(float) #1
+
+declare hidden double @expm1(double noundef) local_unnamed_addr #2
+
+declare hidden float @expm1f(float noundef) local_unnamed_addr #2
+
+declare double @llvm.log.f64(double) #1
+
+declare float @llvm.log.f32(float) #1
+
+declare double @llvm.log10.f64(double) #1
+
+declare float @llvm.log10.f32(float) #1
+
+declare double @llvm.log2.f64(double) #1
+
+declare float @llvm.log2.f32(float) #1
+
+declare hidden double @log1p(double noundef) local_unnamed_addr #2
+
+declare hidden float @log1pf(float noundef) local_unnamed_addr #2
+
+declare float @llvm.pow.f32(float, float) #1
+
+declare double @llvm.sqrt.f64(double) #1
+
+declare float @llvm.sqrt.f32(float) #1
+
+declare hidden double @cbrt(double noundef) local_unnamed_addr #2
+
+declare hidden float @cbrtf(float noundef) local_unnamed_addr #2
+
+declare hidden double @hypot(double noundef...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
return PreservedAnalyses::all(); | ||
|
||
SmallVector<std::pair<Function *, std::string>> ToReplace; | ||
for (auto &&F : M) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why &&
?
No auto
as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because it's a forwarding ref (NOT an rvalue ref), and this is somewhat idiomatic usage in range-for loops. It's also used elsewhere in this file already, so it's self consistent.
} | ||
} | ||
|
||
ToReplace.emplace_back(&F, std::move(N)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: no need to use std::move
here.
default: { | ||
if (F.getReturnType()->isDoubleTy()) { | ||
switch (ID) { | ||
case Intrinsic::cos: | ||
case Intrinsic::exp: | ||
case Intrinsic::exp2: | ||
case Intrinsic::log: | ||
case Intrinsic::log10: | ||
case Intrinsic::log2: | ||
case Intrinsic::pow: | ||
case Intrinsic::sin: | ||
break; | ||
default: | ||
continue; | ||
} | ||
break; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure why not just take the inner switch out and merge it with the outer one? I don't think they are semantically different.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are only missing for FP64, it'd have to check per case, rather than once per switch. I don't know that one is significantly more readable than the other.
StringRef Prefix = "llvm"; | ||
ToReplace.back().second.replace(0, Prefix.size(), "__hipstdpar"); | ||
} | ||
for (auto &&F : ToReplace) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use structured binding here ? F
is not a function but a pair.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We most definitely can, but I never know just how hipster LLVM is with feature use. I'll switch it over and see if anyone objects.
%call10 = call contract double @cbrt(double noundef %15) #4 | ||
%call11 = call contract float @cbrtf(float noundef %16) #4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove the callsite attributes. Separate from this, probably should test that callsite function and parameter / return values are preserved (plus fpmath metadata)
case Intrinsic::tanh: | ||
break; | ||
default: { | ||
if (F.getReturnType()->isDoubleTy()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't handle vectors, which is the main plus of the intrinsics
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, but this is (at least for the time being) intended to cover the standard library, and neither the C nor the C++ one have vector support at the moment.
// happen if e.g. a library that was compiled without interposition returns | ||
// an allocation that can be validly passed to `free`. | ||
// | ||
// 3. MathFixup (required): Some accelerators might have an incomplete |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We really need to fix our whole library usage strategy (which I'm working on writing up how)
} | ||
for (auto &&F : ToReplace) | ||
F.first->replaceAllUsesWith( | ||
M.getOrInsertFunction(F.second, F.first->getFunctionType()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the function happened to have a non-default address space, this would produce a function with the wrong pointer type. It's probably not too important and this bug is widespread anyway
auto It = | ||
find_if(MathLibToHipStdPar, [&](auto &&M) { return M.first == N; }); | ||
if (It == std::cend(MathLibToHipStdPar)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally we would cross reference this with the TargetLibraryInfo for whether this is a recognized call. That will always fail for AMDGPU since we have to say we have no library calls.
Failing that, should at least make an effort to respect nobuiltin
if (!F.hasName()) | ||
continue; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't think anonymous functions need special handling
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It saves on doing the lookup for the not_intrinsic
case.
ToReplace.back().second.replace(0, Prefix.size(), "__hipstdpar"); | ||
} | ||
for (auto &&F : ToReplace) | ||
F.first->replaceAllUsesWith( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also add a test with some abnormal transformed function uses? i.e. address capturing stores, call sites with the wrong signature, global initializer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I most definitely can but could you please clarify the middle scenario - what would a call site with the wrong signature look like / where would it originate, given that these would originate from well formed correct HLL invocations of math lib interfaces, which have passed the FE and are thus linguistically correct? Perhaps I am missing something obvious here.
…tdpar_math_forwarding
…tdpar_math_forwarding
…tdpar_math_forwarding
…tdpar_math_forwarding
Port so to intercept 7.0 for HLRS drop
Port so to intercept 7.0 for HLRS drop
Gentle ping, I'd like to progress this and add it to the backport list ASAP, since it's rather small, hipstdpar specific and user requested (we already have it in AMD's fork but it's a spurious delta to maintain). |
Fix ordering of includes.
I think we should land this to keep making progress. The grand redesign is certainly not imminent. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks
When compiling in
--hipstdpar
mode, the builtins corresponding to the standard library might end up in code that is expected to execute on the accelerator (e.g. by using thestd::
prefixed functions from<cmath>
). We do not have uniform handling for this in AMDGPU, and the errors that obtain are quite arcane. Furthermore, the user-space changes required to work around this tend to be rather intrusive.This patch adds an additional
--hipstdpar
specific pass which forwards to the run time component of HIPSTDPAR the intrinsics / libcalls which result from the use of the math builtins, and which are not properly handled. In the long run we will want to stop relying on this and handle things in the compiler, but it is going to be a rather lengthy journey, which makes this medium term escape hatch necessary.The paired change in the run time component is here ROCm/rocThrust#551.