Skip to content

Conversation

AlexVlx
Copy link
Contributor

@AlexVlx AlexVlx commented May 15, 2025

When compiling in --hipstdpar mode, the builtins corresponding to the standard library might end up in code that is expected to execute on the accelerator (e.g. by using the std:: prefixed functions from <cmath>). We do not have uniform handling for this in AMDGPU, and the errors that obtain are quite arcane. Furthermore, the user-space changes required to work around this tend to be rather intrusive.

This patch adds an additional --hipstdpar specific pass which forwards to the run time component of HIPSTDPAR the intrinsics / libcalls which result from the use of the math builtins, and which are not properly handled. In the long run we will want to stop relying on this and handle things in the compiler, but it is going to be a rather lengthy journey, which makes this medium term escape hatch necessary.

The paired change in the run time component is here ROCm/rocThrust#551.

@llvmbot
Copy link
Member

llvmbot commented May 15, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Alex Voicu (AlexVlx)

Changes

When compiling in --hipstdpar mode, the builtins corresponding to the standard library might end up in code that is expected to execute on the accelerator (e.g. by using the std:: prefixed functions from &lt;cmath&gt;). We do not have uniform handling for this in AMDGPU, and the errors that obtain are quite arcane. Furthermore, the user-space changes required to work around this tend to be rather intrusive.

This patch adds an additional --hipstdpar specific pass which forwards to the run time component of HIPSTDPAR the intrinsics / libcalls which result from the use of the math builtins, and which are not properly handled. In the long run we will want to stop relying on this and handle things in the compiler, but it is going to be a rather lengthy journey, which makes this medium term escape hatch necessary.

The paired change in the run time component is here <ROCm/rocThrust#551>.


Patch is 21.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140158.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h (+7)
  • (modified) llvm/lib/Passes/PassRegistry.def (+1)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp (+6-2)
  • (modified) llvm/lib/Transforms/HipStdPar/HipStdPar.cpp (+117)
  • (added) llvm/test/Transforms/HipStdPar/math-fixup.ll (+240)
diff --git a/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h b/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
index 5ff38bdf04812..27195051ed7eb 100644
--- a/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
+++ b/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
@@ -40,6 +40,13 @@ class HipStdParAllocationInterpositionPass
   static bool isRequired() { return true; }
 };
 
+class HipStdParMathFixupPass : public PassInfoMixin<HipStdParMathFixupPass> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
+
+  static bool isRequired() { return true; }
+};
+
 } // namespace llvm
 
 #endif // LLVM_TRANSFORMS_HIPSTDPAR_HIPSTDPAR_H
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 94dabe290213d..3acdbf4d49fde 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -80,6 +80,7 @@ MODULE_PASS("global-merge-func", GlobalMergeFuncPass())
 MODULE_PASS("globalopt", GlobalOptPass())
 MODULE_PASS("globalsplit", GlobalSplitPass())
 MODULE_PASS("hipstdpar-interpose-alloc", HipStdParAllocationInterpositionPass())
+MODULE_PASS("hipstdpar-math-fixup", HipStdParMathFixupPass())
 MODULE_PASS("hipstdpar-select-accelerator-code",
             HipStdParAcceleratorCodeSelectionPass())
 MODULE_PASS("hotcoldsplit", HotColdSplittingPass())
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
index ccb251b730f16..c3f8cee1e1783 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
@@ -819,8 +819,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
           // When we are not using -fgpu-rdc, we can run accelerator code
           // selection relatively early, but still after linking to prevent
           // eager removal of potentially reachable symbols.
-          if (EnableHipStdPar)
+          if (EnableHipStdPar) {
+            PM.addPass(HipStdParMathFixupPass());
             PM.addPass(HipStdParAcceleratorCodeSelectionPass());
+          }
           PM.addPass(AMDGPUPrintfRuntimeBindingPass());
         }
 
@@ -899,8 +901,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
         // selection after linking to prevent, otherwise we end up removing
         // potentially reachable symbols that were exported as external in other
         // modules.
-        if (EnableHipStdPar)
+        if (EnableHipStdPar) {
+          PM.addPass(HipStdParMathFixupPass());
           PM.addPass(HipStdParAcceleratorCodeSelectionPass());
+        }
         // We want to support the -lto-partitions=N option as "best effort".
         // For that, we need to lower LDS earlier in the pipeline before the
         // module is partitioned for codegen.
diff --git a/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp b/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
index 5a87cf8c83d79..815878089c69e 100644
--- a/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
+++ b/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
@@ -37,6 +37,16 @@
 //    memory that ends up in one of the runtime equivalents, since this can
 //    happen if e.g. a library that was compiled without interposition returns
 //    an allocation that can be validly passed to `free`.
+//
+// 3. MathFixup (required): Some accelerators might have an incomplete
+//    implementation for the intrinsics used to implement some of the math
+//    functions in <cmath> / their corresponding libcall lowerings. Since this
+//    can vary quite significantly between accelerators, we replace calls to a
+//    set of intrinsics / lib functions known to be problematic with calls to a
+//    HIPSTDPAR specific forwarding layer, which gives an uniform interface for
+//    accelerators to implement in their own runtime components. This pass
+//    should run before AcceleratorCodeSelection so as to prevent the spurious
+//    removal of the HIPSTDPAR specific forwarding functions.
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/HipStdPar/HipStdPar.h"
@@ -48,6 +58,7 @@
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 
@@ -321,3 +332,109 @@ HipStdParAllocationInterpositionPass::run(Module &M, ModuleAnalysisManager&) {
 
   return PreservedAnalyses::none();
 }
+
+static constexpr std::pair<StringLiteral, StringLiteral> MathLibToHipStdPar[]{
+    {"acosh", "__hipstdpar_acosh_f64"},
+    {"acoshf", "__hipstdpar_acosh_f32"},
+    {"asinh", "__hipstdpar_asinh_f64"},
+    {"asinhf", "__hipstdpar_asinh_f32"},
+    {"atanh", "__hipstdpar_atanh_f64"},
+    {"atanhf", "__hipstdpar_atanh_f32"},
+    {"cbrt", "__hipstdpar_cbrt_f64"},
+    {"cbrtf", "__hipstdpar_cbrt_f32"},
+    {"erf", "__hipstdpar_erf_f64"},
+    {"erff", "__hipstdpar_erf_f32"},
+    {"erfc", "__hipstdpar_erfc_f64"},
+    {"erfcf", "__hipstdpar_erfc_f32"},
+    {"fdim", "__hipstdpar_fdim_f64"},
+    {"fdimf", "__hipstdpar_fdim_f32"},
+    {"expm1", "__hipstdpar_expm1_f64"},
+    {"expm1f", "__hipstdpar_expm1_f32"},
+    {"hypot", "__hipstdpar_hypot_f64"},
+    {"hypotf", "__hipstdpar_hypot_f32"},
+    {"ilogb", "__hipstdpar_ilogb_f64"},
+    {"ilogbf", "__hipstdpar_ilogb_f32"},
+    {"lgamma", "__hipstdpar_lgamma_f64"},
+    {"lgammaf", "__hipstdpar_lgamma_f32"},
+    {"log1p", "__hipstdpar_log1p_f64"},
+    {"log1pf", "__hipstdpar_log1p_f32"},
+    {"logb", "__hipstdpar_logb_f64"},
+    {"logbf", "__hipstdpar_logb_f32"},
+    {"nextafter", "__hipstdpar_nextafter_f64"},
+    {"nextafterf", "__hipstdpar_nextafter_f32"},
+    {"nexttoward", "__hipstdpar_nexttoward_f64"},
+    {"nexttowardf", "__hipstdpar_nexttoward_f32"},
+    {"remainder", "__hipstdpar_remainder_f64"},
+    {"remainderf", "__hipstdpar_remainder_f32"},
+    {"remquo", "__hipstdpar_remquo_f64"},
+    {"remquof", "__hipstdpar_remquo_f32"},
+    {"scalbln", "__hipstdpar_scalbln_f64"},
+    {"scalblnf", "__hipstdpar_scalbln_f32"},
+    {"scalbn", "__hipstdpar_scalbn_f64"},
+    {"scalbnf", "__hipstdpar_scalbn_f32"},
+    {"tgamma", "__hipstdpar_tgamma_f64"},
+    {"tgammaf", "__hipstdpar_tgamma_f32"}};
+
+PreservedAnalyses HipStdParMathFixupPass::run(Module &M,
+                                              ModuleAnalysisManager &) {
+  if (M.empty())
+    return PreservedAnalyses::all();
+
+  SmallVector<std::pair<Function *, std::string>> ToReplace;
+  for (auto &&F : M) {
+    if (!F.hasName())
+      continue;
+
+    auto N = F.getName().str();
+    auto ID = F.getIntrinsicID();
+
+    switch (ID) {
+    case Intrinsic::not_intrinsic: {
+      auto It = find_if(MathLibToHipStdPar,
+                        [&](auto &&M) { return M.first == N; });
+      if (It == std::cend(MathLibToHipStdPar))
+        continue;
+      ToReplace.emplace_back(&F, It->second);
+      break;
+    }
+    case Intrinsic::acos:
+    case Intrinsic::asin:
+    case Intrinsic::atan:
+    case Intrinsic::atan2:
+    case Intrinsic::cosh:
+    case Intrinsic::modf:
+    case Intrinsic::sinh:
+    case Intrinsic::tan:
+    case Intrinsic::tanh:
+      break;
+    default: {
+      if (F.getReturnType()->isDoubleTy()) {
+        switch (ID) {
+        case Intrinsic::cos:
+        case Intrinsic::exp:
+        case Intrinsic::exp2:
+        case Intrinsic::log:
+        case Intrinsic::log10:
+        case Intrinsic::log2:
+        case Intrinsic::pow:
+        case Intrinsic::sin:
+          break;
+        default:
+          continue;
+        }
+        break;
+      }
+      continue;
+    }
+    }
+
+    llvm::replace(N, '.', '_');
+    N.replace(0, sizeof("llvm"), "__hipstdpar_");
+    ToReplace.emplace_back(&F, std::move(N));
+  }
+  for (auto &&F : ToReplace)
+    F.first->replaceAllUsesWith(M.getOrInsertFunction(
+        F.second, F.first->getFunctionType()).getCallee());
+
+  return PreservedAnalyses::none();
+}
\ No newline at end of file
diff --git a/llvm/test/Transforms/HipStdPar/math-fixup.ll b/llvm/test/Transforms/HipStdPar/math-fixup.ll
new file mode 100644
index 0000000000000..e0e2ca79c0843
--- /dev/null
+++ b/llvm/test/Transforms/HipStdPar/math-fixup.ll
@@ -0,0 +1,240 @@
+; RUN: opt -S -passes=hipstdpar-math-fixup %s | FileCheck %s
+
+define void @foo(double noundef %dbl, float noundef %flt, i32 noundef %quo) #0 {
+; CHECK-LABEL: define void @foo(
+; CHECK-SAME: double noundef [[DBL:%.*]], float noundef [[FLT:%.*]], i32 noundef [[QUO:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[QUO_ADDR:%.*]] = alloca i32, align 4
+; CHECK-NEXT:    store i32 [[QUO]], ptr [[QUO_ADDR]], align 4
+entry:
+  %quo.addr = alloca i32, align 4
+  store i32 %quo, ptr %quo.addr, align 4
+  ; CHECK-NEXT:    [[TMP0:%.*]] = tail call contract double @llvm.fabs.f64(double [[DBL]])
+  %0 = tail call contract double @llvm.fabs.f64(double %dbl)
+  ; CHECK-NEXT:    [[TMP1:%.*]] = tail call contract float @llvm.fabs.f32(float [[FLT]])
+  %1 = tail call contract float @llvm.fabs.f32(float %flt)
+  ; CHECK-NEXT:    [[CALL:%.*]] = tail call contract double @__hipstdpar_remainder_f64(double noundef [[TMP0]], double noundef [[TMP0]]) #[[ATTR4:[0-9]+]]
+  %call = tail call contract double @remainder(double noundef %0, double noundef %0) #4
+  ; CHECK-NEXT:    [[CALL1:%.*]] = tail call contract float @__hipstdpar_remainder_f32(float noundef [[TMP1]], float noundef [[TMP1]]) #[[ATTR4]]
+  %call1 = tail call contract float @remainderf(float noundef %1, float noundef %1) #4
+  ; CHECK-NEXT:    [[CALL2:%.*]] = call contract double @__hipstdpar_remquo_f64(double noundef [[CALL]], double noundef [[CALL]], ptr noundef nonnull [[QUO_ADDR]]) #[[ATTR3:[0-9]+]]
+  %call2 = call contract double @remquo(double noundef %call, double noundef %call, ptr noundef nonnull %quo.addr) #5
+  ; CHECK-NEXT:    [[CALL3:%.*]] = call contract float @__hipstdpar_remquo_f32(float noundef [[CALL1]], float noundef [[CALL1]], ptr noundef nonnull [[QUO_ADDR]]) #[[ATTR3]]
+  %call3 = call contract float @remquof(float noundef %call1, float noundef %call1, ptr noundef nonnull %quo.addr) #5
+  ; CHECK-NEXT:    [[TMP2:%.*]] = call contract double @llvm.fma.f64(double [[CALL2]], double [[CALL2]], double [[CALL2]])
+  %2 = call contract double @llvm.fma.f64(double %call2, double %call2, double %call2)
+  ; CHECK-NEXT:    [[TMP3:%.*]] = call contract float @llvm.fma.f32(float [[CALL3]], float [[CALL3]], float [[CALL3]])
+  %3 = call contract float @llvm.fma.f32(float %call3, float %call3, float %call3)
+  ; CHECK-NEXT:    [[CALL4:%.*]] = call contract double @__hipstdpar_fdim_f64(double noundef [[TMP2]], double noundef [[TMP2]]) #[[ATTR4]]
+  %call4 = call contract double @fdim(double noundef %2, double noundef %2) #4
+  ; CHECK-NEXT:    [[CALL5:%.*]] = call contract float @__hipstdpar_fdim_f32(float noundef [[TMP3]], float noundef [[TMP3]]) #[[ATTR4]]
+  %call5 = call contract float @fdimf(float noundef %3, float noundef %3) #4
+  ; CHECK-NEXT:    [[TMP4:%.*]] = call contract double @__hipstdpar_exp_f64(double [[CALL4]])
+  %4 = call contract double @llvm.exp.f64(double %call4)
+  ; CHECK-NEXT:    [[TMP5:%.*]] = call contract float @llvm.exp.f32(float [[CALL5]])
+  %5 = call contract float @llvm.exp.f32(float %call5)
+  ; CHECK-NEXT:    [[TMP6:%.*]] = call contract double @__hipstdpar_exp2_f64(double [[TMP4]])
+  %6 = call contract double @llvm.exp2.f64(double %4)
+  ; CHECK-NEXT:    [[TMP7:%.*]] = call contract float @llvm.exp2.f32(float [[TMP5]])
+  %7 = call contract float @llvm.exp2.f32(float %5)
+  ; CHECK-NEXT:    [[CALL6:%.*]] = call contract double @__hipstdpar_expm1_f64(double noundef [[TMP6]]) #[[ATTR4]]
+  %call6 = call contract double @expm1(double noundef %6) #4
+  ; CHECK-NEXT:    [[CALL7:%.*]] = call contract float @__hipstdpar_expm1_f32(float noundef [[TMP7]]) #[[ATTR4]]
+  %call7 = call contract float @expm1f(float noundef %7) #4
+  ; CHECK-NEXT:    [[TMP8:%.*]] = call contract double @__hipstdpar_log_f64(double [[CALL6]])
+  %8 = call contract double @llvm.log.f64(double %call6)
+  ; CHECK-NEXT:    [[TMP9:%.*]] = call contract float @llvm.log.f32(float [[CALL7]])
+  %9 = call contract float @llvm.log.f32(float %call7)
+  ; CHECK-NEXT:    [[TMP10:%.*]] = call contract double @__hipstdpar_log10_f64(double [[TMP8]])
+  %10 = call contract double @llvm.log10.f64(double %8)
+  ; CHECK-NEXT:    [[TMP11:%.*]] = call contract float @llvm.log10.f32(float [[TMP9]])
+  %11 = call contract float @llvm.log10.f32(float %9)
+  ; CHECK-NEXT:    [[TMP12:%.*]] = call contract double @__hipstdpar_log2_f64(double [[TMP10]])
+  %12 = call contract double @llvm.log2.f64(double %10)
+  ; CHECK-NEXT:    [[TMP13:%.*]] = call contract float @llvm.log2.f32(float [[TMP11]])
+  %13 = call contract float @llvm.log2.f32(float %11)
+  ; CHECK-NEXT:    [[CALL8:%.*]] = call contract double @__hipstdpar_log1p_f64(double noundef [[TMP12]]) #[[ATTR4]]
+  %call8 = call contract double @log1p(double noundef %12) #4
+  ; CHECK-NEXT:    [[CALL9:%.*]] = call contract float @__hipstdpar_log1p_f32(float noundef [[TMP13]]) #[[ATTR4]]
+  %call9 = call contract float @log1pf(float noundef %13) #4
+  ; CHECK-NEXT:    [[TMP14:%.*]] = call contract float @llvm.pow.f32(float [[CALL9]], float [[CALL9]])
+  %14 = call contract float @llvm.pow.f32(float %call9, float %call9)
+  ; CHECK-NEXT:    [[TMP15:%.*]] = call contract double @llvm.sqrt.f64(double [[CALL8]])
+  %15 = call contract double @llvm.sqrt.f64(double %call8)
+  ; CHECK-NEXT:    [[TMP16:%.*]] = call contract float @llvm.sqrt.f32(float [[TMP14]])
+  %16 = call contract float @llvm.sqrt.f32(float %14)
+  ; CHECK-NEXT:    [[CALL10:%.*]] = call contract double @__hipstdpar_cbrt_f64(double noundef [[TMP15]]) #[[ATTR4]]
+  %call10 = call contract double @cbrt(double noundef %15) #4
+  ; CHECK-NEXT:    [[CALL11:%.*]] = call contract float @__hipstdpar_cbrt_f32(float noundef [[TMP16]]) #[[ATTR4]]
+  %call11 = call contract float @cbrtf(float noundef %16) #4
+  ; CHECK-NEXT:    [[CALL12:%.*]] = call contract double @__hipstdpar_hypot_f64(double noundef [[CALL10]], double noundef [[CALL10]]) #[[ATTR4]]
+  %call12 = call contract double @hypot(double noundef %call10, double noundef %call10) #4
+  ; CHECK-NEXT:    [[CALL13:%.*]] = call contract float @__hipstdpar_hypot_f32(float noundef [[CALL11]], float noundef [[CALL11]]) #[[ATTR4]]
+  %call13 = call contract float @hypotf(float noundef %call11, float noundef %call11) #4
+  ; CHECK-NEXT:    [[TMP17:%.*]] = call contract float @llvm.sin.f32(float [[CALL13]])
+  %17 = call contract float @llvm.sin.f32(float %call13)
+  ; CHECK-NEXT:    [[TMP18:%.*]] = call contract float @llvm.cos.f32(float [[TMP17]])
+  %18 = call contract float @llvm.cos.f32(float %17)
+  ; CHECK-NEXT:    [[TMP19:%.*]] = call contract double @__hipstdpar_tan_f64(double [[CALL12]])
+  %19 = call contract double @llvm.tan.f64(double %call12)
+  ; CHECK-NEXT:    [[TMP20:%.*]] = call contract double @__hipstdpar_asin_f64(double [[TMP19]])
+  %20 = call contract double @llvm.asin.f64(double %19)
+  ; CHECK-NEXT:    [[TMP21:%.*]] = call contract double @__hipstdpar_acos_f64(double [[TMP20]])
+  %21 = call contract double @llvm.acos.f64(double %20)
+  ; CHECK-NEXT:    [[TMP22:%.*]] = call contract double @__hipstdpar_atan_f64(double [[TMP21]])
+  %22 = call contract double @llvm.atan.f64(double %21)
+  ; CHECK-NEXT:    [[TMP23:%.*]] = call contract double @__hipstdpar_atan2_f64(double [[TMP22]], double [[TMP22]])
+  %23 = call contract double @llvm.atan2.f64(double %22, double %22)
+  ; CHECK-NEXT:    [[TMP24:%.*]] = call contract double @__hipstdpar_sinh_f64(double [[TMP23]])
+  %24 = call contract double @llvm.sinh.f64(double %23)
+  ; CHECK-NEXT:    [[TMP25:%.*]] = call contract double @__hipstdpar_cosh_f64(double [[TMP24]])
+  %25 = call contract double @llvm.cosh.f64(double %24)
+  ; CHECK-NEXT:    [[TMP26:%.*]] = call contract double @__hipstdpar_tanh_f64(double [[TMP25]])
+  %26 = call contract double @llvm.tanh.f64(double %25)
+  ; CHECK-NEXT:    [[CALL14:%.*]] = call contract double @__hipstdpar_asinh_f64(double noundef [[TMP26]]) #[[ATTR4]]
+  %call14 = call contract double @asinh(double noundef %26) #4
+  ; CHECK-NEXT:    [[CALL15:%.*]] = call contract float @__hipstdpar_asinh_f32(float noundef [[TMP18]]) #[[ATTR4]]
+  %call15 = call contract float @asinhf(float noundef %18) #4
+  ; CHECK-NEXT:    [[CALL16:%.*]] = call contract double @__hipstdpar_acosh_f64(double noundef [[CALL14]]) #[[ATTR4]]
+  %call16 = call contract double @acosh(double noundef %call14) #4
+  ; CHECK-NEXT:    [[CALL17:%.*]] = call contract float @__hipstdpar_acosh_f32(float noundef [[CALL15]]) #[[ATTR4]]
+  %call17 = call contract float @acoshf(float noundef %call15) #4
+  ; CHECK-NEXT:    [[CALL18:%.*]] = call contract double @__hipstdpar_atanh_f64(double noundef [[CALL16]]) #[[ATTR4]]
+  %call18 = call contract double @atanh(double noundef %call16) #4
+  ; CHECK-NEXT:    [[CALL19:%.*]] = call contract float @__hipstdpar_atanh_f32(float noundef [[CALL17]]) #[[ATTR4]]
+  %call19 = call contract float @atanhf(float noundef %call17) #4
+  ; CHECK-NEXT:    [[CALL20:%.*]] = call contract double @__hipstdpar_erf_f64(double noundef [[CALL18]]) #[[ATTR4]]
+  %call20 = call contract double @erf(double noundef %call18) #4
+  ; CHECK-NEXT:    [[CALL21:%.*]] = call contract float @__hipstdpar_erf_f32(float noundef [[CALL19]]) #[[ATTR4]]
+  %call21 = call contract float @erff(float noundef %call19) #4
+  ; CHECK-NEXT:    [[CALL22:%.*]] = call contract double @__hipstdpar_erfc_f64(double noundef [[CALL20]]) #[[ATTR4]]
+  %call22 = call contract double @erfc(double noundef %call20) #4
+  ; CHECK-NEXT:    [[CALL23:%.*]] = call contract float @__hipstdpar_erfc_f32(float noundef [[CALL21]]) #[[ATTR4]]
+  %call23 = call contract float @erfcf(float noundef %call21) #4
+  ; CHECK-NEXT:    [[CALL24:%.*]] = call contract double @__hipstdpar_tgamma_f64(double noundef [[CALL22]]) #[[ATTR4]]
+  %call24 = call contract double @tgamma(double noundef %call22) #4
+  ; CHECK-NEXT:    [[CALL25:%.*]] = call contract float @__hipstdpar_tgamma_f32(float noundef [[CALL23]]) #[[ATTR4]]
+  %call25 = call contract float @tgammaf(float noundef %call23) #4
+  ; CHECK-NEXT:    [[CALL26:%.*]] = call contract double @__hipstdpar_lgamma_f64(double noundef [[CALL24]]) #[[ATTR3]]
+  %call26 = call contract double @lgamma(double noundef %call24) #5
+  ; CHECK-NEXT:    [[CALL27:%.*]] = call contract float @__hipstdpar_lgamma_f32(float noundef [[CALL25]]) #[[ATTR3]]
+  %call27 = call contract float @lgammaf(float noundef %call25) #5
+  ret void
+}
+
+declare double @llvm.fabs.f64(double) #1
+
+declare float @llvm.fabs.f32(float) #1
+
+declare hidden double @remainder(double noundef, double noundef) local_unnamed_addr #2
+
+declare hidden float @remainderf(float noundef, float noundef) local_unnamed_addr #2
+
+declare hidden double @remquo(double noundef, double noundef, ptr noundef) local_unnamed_addr #3
+
+declare hidden float @remquof(float noundef, float noundef, ptr noundef) local_unnamed_addr #3
+
+declare double @llvm.fma.f64(double, double, double) #1
+
+declare float @llvm.fma.f32(float, float, float) #1
+
+declare hidden double @fdim(double noundef, double noundef) local_unnamed_addr #2
+
+declare hidden float @fdimf(float noundef, float noundef) local_unnamed_addr #2
+
+declare double @llvm.exp.f64(double) #1
+
+declare float @llvm.exp.f32(float) #1
+
+declare double @llvm.exp2.f64(double) #1
+
+declare float @llvm.exp2.f32(float) #1
+
+declare hidden double @expm1(double noundef) local_unnamed_addr #2
+
+declare hidden float @expm1f(float noundef) local_unnamed_addr #2
+
+declare double @llvm.log.f64(double) #1
+
+declare float @llvm.log.f32(float) #1
+
+declare double @llvm.log10.f64(double) #1
+
+declare float @llvm.log10.f32(float) #1
+
+declare double @llvm.log2.f64(double) #1
+
+declare float @llvm.log2.f32(float) #1
+
+declare hidden double @log1p(double noundef) local_unnamed_addr #2
+
+declare hidden float @log1pf(float noundef) local_unnamed_addr #2
+
+declare float @llvm.pow.f32(float, float) #1
+
+declare double @llvm.sqrt.f64(double) #1
+
+declare float @llvm.sqrt.f32(float) #1
+
+declare hidden double @cbrt(double noundef) local_unnamed_addr #2
+
+declare hidden float @cbrtf(float noundef) local_unnamed_addr #2
+
+declare hidden double @hypot(double noundef...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 15, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Alex Voicu (AlexVlx)

Changes

When compiling in --hipstdpar mode, the builtins corresponding to the standard library might end up in code that is expected to execute on the accelerator (e.g. by using the std:: prefixed functions from &lt;cmath&gt;). We do not have uniform handling for this in AMDGPU, and the errors that obtain are quite arcane. Furthermore, the user-space changes required to work around this tend to be rather intrusive.

This patch adds an additional --hipstdpar specific pass which forwards to the run time component of HIPSTDPAR the intrinsics / libcalls which result from the use of the math builtins, and which are not properly handled. In the long run we will want to stop relying on this and handle things in the compiler, but it is going to be a rather lengthy journey, which makes this medium term escape hatch necessary.

The paired change in the run time component is here <ROCm/rocThrust#551>.


Patch is 21.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140158.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h (+7)
  • (modified) llvm/lib/Passes/PassRegistry.def (+1)
  • (modified) llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp (+6-2)
  • (modified) llvm/lib/Transforms/HipStdPar/HipStdPar.cpp (+117)
  • (added) llvm/test/Transforms/HipStdPar/math-fixup.ll (+240)
diff --git a/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h b/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
index 5ff38bdf04812..27195051ed7eb 100644
--- a/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
+++ b/llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h
@@ -40,6 +40,13 @@ class HipStdParAllocationInterpositionPass
   static bool isRequired() { return true; }
 };
 
+class HipStdParMathFixupPass : public PassInfoMixin<HipStdParMathFixupPass> {
+public:
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
+
+  static bool isRequired() { return true; }
+};
+
 } // namespace llvm
 
 #endif // LLVM_TRANSFORMS_HIPSTDPAR_HIPSTDPAR_H
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 94dabe290213d..3acdbf4d49fde 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -80,6 +80,7 @@ MODULE_PASS("global-merge-func", GlobalMergeFuncPass())
 MODULE_PASS("globalopt", GlobalOptPass())
 MODULE_PASS("globalsplit", GlobalSplitPass())
 MODULE_PASS("hipstdpar-interpose-alloc", HipStdParAllocationInterpositionPass())
+MODULE_PASS("hipstdpar-math-fixup", HipStdParMathFixupPass())
 MODULE_PASS("hipstdpar-select-accelerator-code",
             HipStdParAcceleratorCodeSelectionPass())
 MODULE_PASS("hotcoldsplit", HotColdSplittingPass())
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
index ccb251b730f16..c3f8cee1e1783 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp
@@ -819,8 +819,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
           // When we are not using -fgpu-rdc, we can run accelerator code
           // selection relatively early, but still after linking to prevent
           // eager removal of potentially reachable symbols.
-          if (EnableHipStdPar)
+          if (EnableHipStdPar) {
+            PM.addPass(HipStdParMathFixupPass());
             PM.addPass(HipStdParAcceleratorCodeSelectionPass());
+          }
           PM.addPass(AMDGPUPrintfRuntimeBindingPass());
         }
 
@@ -899,8 +901,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
         // selection after linking to prevent, otherwise we end up removing
         // potentially reachable symbols that were exported as external in other
         // modules.
-        if (EnableHipStdPar)
+        if (EnableHipStdPar) {
+          PM.addPass(HipStdParMathFixupPass());
           PM.addPass(HipStdParAcceleratorCodeSelectionPass());
+        }
         // We want to support the -lto-partitions=N option as "best effort".
         // For that, we need to lower LDS earlier in the pipeline before the
         // module is partitioned for codegen.
diff --git a/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp b/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
index 5a87cf8c83d79..815878089c69e 100644
--- a/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
+++ b/llvm/lib/Transforms/HipStdPar/HipStdPar.cpp
@@ -37,6 +37,16 @@
 //    memory that ends up in one of the runtime equivalents, since this can
 //    happen if e.g. a library that was compiled without interposition returns
 //    an allocation that can be validly passed to `free`.
+//
+// 3. MathFixup (required): Some accelerators might have an incomplete
+//    implementation for the intrinsics used to implement some of the math
+//    functions in <cmath> / their corresponding libcall lowerings. Since this
+//    can vary quite significantly between accelerators, we replace calls to a
+//    set of intrinsics / lib functions known to be problematic with calls to a
+//    HIPSTDPAR specific forwarding layer, which gives an uniform interface for
+//    accelerators to implement in their own runtime components. This pass
+//    should run before AcceleratorCodeSelection so as to prevent the spurious
+//    removal of the HIPSTDPAR specific forwarding functions.
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/HipStdPar/HipStdPar.h"
@@ -48,6 +58,7 @@
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Transforms/Utils/ModuleUtils.h"
 
@@ -321,3 +332,109 @@ HipStdParAllocationInterpositionPass::run(Module &M, ModuleAnalysisManager&) {
 
   return PreservedAnalyses::none();
 }
+
+static constexpr std::pair<StringLiteral, StringLiteral> MathLibToHipStdPar[]{
+    {"acosh", "__hipstdpar_acosh_f64"},
+    {"acoshf", "__hipstdpar_acosh_f32"},
+    {"asinh", "__hipstdpar_asinh_f64"},
+    {"asinhf", "__hipstdpar_asinh_f32"},
+    {"atanh", "__hipstdpar_atanh_f64"},
+    {"atanhf", "__hipstdpar_atanh_f32"},
+    {"cbrt", "__hipstdpar_cbrt_f64"},
+    {"cbrtf", "__hipstdpar_cbrt_f32"},
+    {"erf", "__hipstdpar_erf_f64"},
+    {"erff", "__hipstdpar_erf_f32"},
+    {"erfc", "__hipstdpar_erfc_f64"},
+    {"erfcf", "__hipstdpar_erfc_f32"},
+    {"fdim", "__hipstdpar_fdim_f64"},
+    {"fdimf", "__hipstdpar_fdim_f32"},
+    {"expm1", "__hipstdpar_expm1_f64"},
+    {"expm1f", "__hipstdpar_expm1_f32"},
+    {"hypot", "__hipstdpar_hypot_f64"},
+    {"hypotf", "__hipstdpar_hypot_f32"},
+    {"ilogb", "__hipstdpar_ilogb_f64"},
+    {"ilogbf", "__hipstdpar_ilogb_f32"},
+    {"lgamma", "__hipstdpar_lgamma_f64"},
+    {"lgammaf", "__hipstdpar_lgamma_f32"},
+    {"log1p", "__hipstdpar_log1p_f64"},
+    {"log1pf", "__hipstdpar_log1p_f32"},
+    {"logb", "__hipstdpar_logb_f64"},
+    {"logbf", "__hipstdpar_logb_f32"},
+    {"nextafter", "__hipstdpar_nextafter_f64"},
+    {"nextafterf", "__hipstdpar_nextafter_f32"},
+    {"nexttoward", "__hipstdpar_nexttoward_f64"},
+    {"nexttowardf", "__hipstdpar_nexttoward_f32"},
+    {"remainder", "__hipstdpar_remainder_f64"},
+    {"remainderf", "__hipstdpar_remainder_f32"},
+    {"remquo", "__hipstdpar_remquo_f64"},
+    {"remquof", "__hipstdpar_remquo_f32"},
+    {"scalbln", "__hipstdpar_scalbln_f64"},
+    {"scalblnf", "__hipstdpar_scalbln_f32"},
+    {"scalbn", "__hipstdpar_scalbn_f64"},
+    {"scalbnf", "__hipstdpar_scalbn_f32"},
+    {"tgamma", "__hipstdpar_tgamma_f64"},
+    {"tgammaf", "__hipstdpar_tgamma_f32"}};
+
+PreservedAnalyses HipStdParMathFixupPass::run(Module &M,
+                                              ModuleAnalysisManager &) {
+  if (M.empty())
+    return PreservedAnalyses::all();
+
+  SmallVector<std::pair<Function *, std::string>> ToReplace;
+  for (auto &&F : M) {
+    if (!F.hasName())
+      continue;
+
+    auto N = F.getName().str();
+    auto ID = F.getIntrinsicID();
+
+    switch (ID) {
+    case Intrinsic::not_intrinsic: {
+      auto It = find_if(MathLibToHipStdPar,
+                        [&](auto &&M) { return M.first == N; });
+      if (It == std::cend(MathLibToHipStdPar))
+        continue;
+      ToReplace.emplace_back(&F, It->second);
+      break;
+    }
+    case Intrinsic::acos:
+    case Intrinsic::asin:
+    case Intrinsic::atan:
+    case Intrinsic::atan2:
+    case Intrinsic::cosh:
+    case Intrinsic::modf:
+    case Intrinsic::sinh:
+    case Intrinsic::tan:
+    case Intrinsic::tanh:
+      break;
+    default: {
+      if (F.getReturnType()->isDoubleTy()) {
+        switch (ID) {
+        case Intrinsic::cos:
+        case Intrinsic::exp:
+        case Intrinsic::exp2:
+        case Intrinsic::log:
+        case Intrinsic::log10:
+        case Intrinsic::log2:
+        case Intrinsic::pow:
+        case Intrinsic::sin:
+          break;
+        default:
+          continue;
+        }
+        break;
+      }
+      continue;
+    }
+    }
+
+    llvm::replace(N, '.', '_');
+    N.replace(0, sizeof("llvm"), "__hipstdpar_");
+    ToReplace.emplace_back(&F, std::move(N));
+  }
+  for (auto &&F : ToReplace)
+    F.first->replaceAllUsesWith(M.getOrInsertFunction(
+        F.second, F.first->getFunctionType()).getCallee());
+
+  return PreservedAnalyses::none();
+}
\ No newline at end of file
diff --git a/llvm/test/Transforms/HipStdPar/math-fixup.ll b/llvm/test/Transforms/HipStdPar/math-fixup.ll
new file mode 100644
index 0000000000000..e0e2ca79c0843
--- /dev/null
+++ b/llvm/test/Transforms/HipStdPar/math-fixup.ll
@@ -0,0 +1,240 @@
+; RUN: opt -S -passes=hipstdpar-math-fixup %s | FileCheck %s
+
+define void @foo(double noundef %dbl, float noundef %flt, i32 noundef %quo) #0 {
+; CHECK-LABEL: define void @foo(
+; CHECK-SAME: double noundef [[DBL:%.*]], float noundef [[FLT:%.*]], i32 noundef [[QUO:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[QUO_ADDR:%.*]] = alloca i32, align 4
+; CHECK-NEXT:    store i32 [[QUO]], ptr [[QUO_ADDR]], align 4
+entry:
+  %quo.addr = alloca i32, align 4
+  store i32 %quo, ptr %quo.addr, align 4
+  ; CHECK-NEXT:    [[TMP0:%.*]] = tail call contract double @llvm.fabs.f64(double [[DBL]])
+  %0 = tail call contract double @llvm.fabs.f64(double %dbl)
+  ; CHECK-NEXT:    [[TMP1:%.*]] = tail call contract float @llvm.fabs.f32(float [[FLT]])
+  %1 = tail call contract float @llvm.fabs.f32(float %flt)
+  ; CHECK-NEXT:    [[CALL:%.*]] = tail call contract double @__hipstdpar_remainder_f64(double noundef [[TMP0]], double noundef [[TMP0]]) #[[ATTR4:[0-9]+]]
+  %call = tail call contract double @remainder(double noundef %0, double noundef %0) #4
+  ; CHECK-NEXT:    [[CALL1:%.*]] = tail call contract float @__hipstdpar_remainder_f32(float noundef [[TMP1]], float noundef [[TMP1]]) #[[ATTR4]]
+  %call1 = tail call contract float @remainderf(float noundef %1, float noundef %1) #4
+  ; CHECK-NEXT:    [[CALL2:%.*]] = call contract double @__hipstdpar_remquo_f64(double noundef [[CALL]], double noundef [[CALL]], ptr noundef nonnull [[QUO_ADDR]]) #[[ATTR3:[0-9]+]]
+  %call2 = call contract double @remquo(double noundef %call, double noundef %call, ptr noundef nonnull %quo.addr) #5
+  ; CHECK-NEXT:    [[CALL3:%.*]] = call contract float @__hipstdpar_remquo_f32(float noundef [[CALL1]], float noundef [[CALL1]], ptr noundef nonnull [[QUO_ADDR]]) #[[ATTR3]]
+  %call3 = call contract float @remquof(float noundef %call1, float noundef %call1, ptr noundef nonnull %quo.addr) #5
+  ; CHECK-NEXT:    [[TMP2:%.*]] = call contract double @llvm.fma.f64(double [[CALL2]], double [[CALL2]], double [[CALL2]])
+  %2 = call contract double @llvm.fma.f64(double %call2, double %call2, double %call2)
+  ; CHECK-NEXT:    [[TMP3:%.*]] = call contract float @llvm.fma.f32(float [[CALL3]], float [[CALL3]], float [[CALL3]])
+  %3 = call contract float @llvm.fma.f32(float %call3, float %call3, float %call3)
+  ; CHECK-NEXT:    [[CALL4:%.*]] = call contract double @__hipstdpar_fdim_f64(double noundef [[TMP2]], double noundef [[TMP2]]) #[[ATTR4]]
+  %call4 = call contract double @fdim(double noundef %2, double noundef %2) #4
+  ; CHECK-NEXT:    [[CALL5:%.*]] = call contract float @__hipstdpar_fdim_f32(float noundef [[TMP3]], float noundef [[TMP3]]) #[[ATTR4]]
+  %call5 = call contract float @fdimf(float noundef %3, float noundef %3) #4
+  ; CHECK-NEXT:    [[TMP4:%.*]] = call contract double @__hipstdpar_exp_f64(double [[CALL4]])
+  %4 = call contract double @llvm.exp.f64(double %call4)
+  ; CHECK-NEXT:    [[TMP5:%.*]] = call contract float @llvm.exp.f32(float [[CALL5]])
+  %5 = call contract float @llvm.exp.f32(float %call5)
+  ; CHECK-NEXT:    [[TMP6:%.*]] = call contract double @__hipstdpar_exp2_f64(double [[TMP4]])
+  %6 = call contract double @llvm.exp2.f64(double %4)
+  ; CHECK-NEXT:    [[TMP7:%.*]] = call contract float @llvm.exp2.f32(float [[TMP5]])
+  %7 = call contract float @llvm.exp2.f32(float %5)
+  ; CHECK-NEXT:    [[CALL6:%.*]] = call contract double @__hipstdpar_expm1_f64(double noundef [[TMP6]]) #[[ATTR4]]
+  %call6 = call contract double @expm1(double noundef %6) #4
+  ; CHECK-NEXT:    [[CALL7:%.*]] = call contract float @__hipstdpar_expm1_f32(float noundef [[TMP7]]) #[[ATTR4]]
+  %call7 = call contract float @expm1f(float noundef %7) #4
+  ; CHECK-NEXT:    [[TMP8:%.*]] = call contract double @__hipstdpar_log_f64(double [[CALL6]])
+  %8 = call contract double @llvm.log.f64(double %call6)
+  ; CHECK-NEXT:    [[TMP9:%.*]] = call contract float @llvm.log.f32(float [[CALL7]])
+  %9 = call contract float @llvm.log.f32(float %call7)
+  ; CHECK-NEXT:    [[TMP10:%.*]] = call contract double @__hipstdpar_log10_f64(double [[TMP8]])
+  %10 = call contract double @llvm.log10.f64(double %8)
+  ; CHECK-NEXT:    [[TMP11:%.*]] = call contract float @llvm.log10.f32(float [[TMP9]])
+  %11 = call contract float @llvm.log10.f32(float %9)
+  ; CHECK-NEXT:    [[TMP12:%.*]] = call contract double @__hipstdpar_log2_f64(double [[TMP10]])
+  %12 = call contract double @llvm.log2.f64(double %10)
+  ; CHECK-NEXT:    [[TMP13:%.*]] = call contract float @llvm.log2.f32(float [[TMP11]])
+  %13 = call contract float @llvm.log2.f32(float %11)
+  ; CHECK-NEXT:    [[CALL8:%.*]] = call contract double @__hipstdpar_log1p_f64(double noundef [[TMP12]]) #[[ATTR4]]
+  %call8 = call contract double @log1p(double noundef %12) #4
+  ; CHECK-NEXT:    [[CALL9:%.*]] = call contract float @__hipstdpar_log1p_f32(float noundef [[TMP13]]) #[[ATTR4]]
+  %call9 = call contract float @log1pf(float noundef %13) #4
+  ; CHECK-NEXT:    [[TMP14:%.*]] = call contract float @llvm.pow.f32(float [[CALL9]], float [[CALL9]])
+  %14 = call contract float @llvm.pow.f32(float %call9, float %call9)
+  ; CHECK-NEXT:    [[TMP15:%.*]] = call contract double @llvm.sqrt.f64(double [[CALL8]])
+  %15 = call contract double @llvm.sqrt.f64(double %call8)
+  ; CHECK-NEXT:    [[TMP16:%.*]] = call contract float @llvm.sqrt.f32(float [[TMP14]])
+  %16 = call contract float @llvm.sqrt.f32(float %14)
+  ; CHECK-NEXT:    [[CALL10:%.*]] = call contract double @__hipstdpar_cbrt_f64(double noundef [[TMP15]]) #[[ATTR4]]
+  %call10 = call contract double @cbrt(double noundef %15) #4
+  ; CHECK-NEXT:    [[CALL11:%.*]] = call contract float @__hipstdpar_cbrt_f32(float noundef [[TMP16]]) #[[ATTR4]]
+  %call11 = call contract float @cbrtf(float noundef %16) #4
+  ; CHECK-NEXT:    [[CALL12:%.*]] = call contract double @__hipstdpar_hypot_f64(double noundef [[CALL10]], double noundef [[CALL10]]) #[[ATTR4]]
+  %call12 = call contract double @hypot(double noundef %call10, double noundef %call10) #4
+  ; CHECK-NEXT:    [[CALL13:%.*]] = call contract float @__hipstdpar_hypot_f32(float noundef [[CALL11]], float noundef [[CALL11]]) #[[ATTR4]]
+  %call13 = call contract float @hypotf(float noundef %call11, float noundef %call11) #4
+  ; CHECK-NEXT:    [[TMP17:%.*]] = call contract float @llvm.sin.f32(float [[CALL13]])
+  %17 = call contract float @llvm.sin.f32(float %call13)
+  ; CHECK-NEXT:    [[TMP18:%.*]] = call contract float @llvm.cos.f32(float [[TMP17]])
+  %18 = call contract float @llvm.cos.f32(float %17)
+  ; CHECK-NEXT:    [[TMP19:%.*]] = call contract double @__hipstdpar_tan_f64(double [[CALL12]])
+  %19 = call contract double @llvm.tan.f64(double %call12)
+  ; CHECK-NEXT:    [[TMP20:%.*]] = call contract double @__hipstdpar_asin_f64(double [[TMP19]])
+  %20 = call contract double @llvm.asin.f64(double %19)
+  ; CHECK-NEXT:    [[TMP21:%.*]] = call contract double @__hipstdpar_acos_f64(double [[TMP20]])
+  %21 = call contract double @llvm.acos.f64(double %20)
+  ; CHECK-NEXT:    [[TMP22:%.*]] = call contract double @__hipstdpar_atan_f64(double [[TMP21]])
+  %22 = call contract double @llvm.atan.f64(double %21)
+  ; CHECK-NEXT:    [[TMP23:%.*]] = call contract double @__hipstdpar_atan2_f64(double [[TMP22]], double [[TMP22]])
+  %23 = call contract double @llvm.atan2.f64(double %22, double %22)
+  ; CHECK-NEXT:    [[TMP24:%.*]] = call contract double @__hipstdpar_sinh_f64(double [[TMP23]])
+  %24 = call contract double @llvm.sinh.f64(double %23)
+  ; CHECK-NEXT:    [[TMP25:%.*]] = call contract double @__hipstdpar_cosh_f64(double [[TMP24]])
+  %25 = call contract double @llvm.cosh.f64(double %24)
+  ; CHECK-NEXT:    [[TMP26:%.*]] = call contract double @__hipstdpar_tanh_f64(double [[TMP25]])
+  %26 = call contract double @llvm.tanh.f64(double %25)
+  ; CHECK-NEXT:    [[CALL14:%.*]] = call contract double @__hipstdpar_asinh_f64(double noundef [[TMP26]]) #[[ATTR4]]
+  %call14 = call contract double @asinh(double noundef %26) #4
+  ; CHECK-NEXT:    [[CALL15:%.*]] = call contract float @__hipstdpar_asinh_f32(float noundef [[TMP18]]) #[[ATTR4]]
+  %call15 = call contract float @asinhf(float noundef %18) #4
+  ; CHECK-NEXT:    [[CALL16:%.*]] = call contract double @__hipstdpar_acosh_f64(double noundef [[CALL14]]) #[[ATTR4]]
+  %call16 = call contract double @acosh(double noundef %call14) #4
+  ; CHECK-NEXT:    [[CALL17:%.*]] = call contract float @__hipstdpar_acosh_f32(float noundef [[CALL15]]) #[[ATTR4]]
+  %call17 = call contract float @acoshf(float noundef %call15) #4
+  ; CHECK-NEXT:    [[CALL18:%.*]] = call contract double @__hipstdpar_atanh_f64(double noundef [[CALL16]]) #[[ATTR4]]
+  %call18 = call contract double @atanh(double noundef %call16) #4
+  ; CHECK-NEXT:    [[CALL19:%.*]] = call contract float @__hipstdpar_atanh_f32(float noundef [[CALL17]]) #[[ATTR4]]
+  %call19 = call contract float @atanhf(float noundef %call17) #4
+  ; CHECK-NEXT:    [[CALL20:%.*]] = call contract double @__hipstdpar_erf_f64(double noundef [[CALL18]]) #[[ATTR4]]
+  %call20 = call contract double @erf(double noundef %call18) #4
+  ; CHECK-NEXT:    [[CALL21:%.*]] = call contract float @__hipstdpar_erf_f32(float noundef [[CALL19]]) #[[ATTR4]]
+  %call21 = call contract float @erff(float noundef %call19) #4
+  ; CHECK-NEXT:    [[CALL22:%.*]] = call contract double @__hipstdpar_erfc_f64(double noundef [[CALL20]]) #[[ATTR4]]
+  %call22 = call contract double @erfc(double noundef %call20) #4
+  ; CHECK-NEXT:    [[CALL23:%.*]] = call contract float @__hipstdpar_erfc_f32(float noundef [[CALL21]]) #[[ATTR4]]
+  %call23 = call contract float @erfcf(float noundef %call21) #4
+  ; CHECK-NEXT:    [[CALL24:%.*]] = call contract double @__hipstdpar_tgamma_f64(double noundef [[CALL22]]) #[[ATTR4]]
+  %call24 = call contract double @tgamma(double noundef %call22) #4
+  ; CHECK-NEXT:    [[CALL25:%.*]] = call contract float @__hipstdpar_tgamma_f32(float noundef [[CALL23]]) #[[ATTR4]]
+  %call25 = call contract float @tgammaf(float noundef %call23) #4
+  ; CHECK-NEXT:    [[CALL26:%.*]] = call contract double @__hipstdpar_lgamma_f64(double noundef [[CALL24]]) #[[ATTR3]]
+  %call26 = call contract double @lgamma(double noundef %call24) #5
+  ; CHECK-NEXT:    [[CALL27:%.*]] = call contract float @__hipstdpar_lgamma_f32(float noundef [[CALL25]]) #[[ATTR3]]
+  %call27 = call contract float @lgammaf(float noundef %call25) #5
+  ret void
+}
+
+declare double @llvm.fabs.f64(double) #1
+
+declare float @llvm.fabs.f32(float) #1
+
+declare hidden double @remainder(double noundef, double noundef) local_unnamed_addr #2
+
+declare hidden float @remainderf(float noundef, float noundef) local_unnamed_addr #2
+
+declare hidden double @remquo(double noundef, double noundef, ptr noundef) local_unnamed_addr #3
+
+declare hidden float @remquof(float noundef, float noundef, ptr noundef) local_unnamed_addr #3
+
+declare double @llvm.fma.f64(double, double, double) #1
+
+declare float @llvm.fma.f32(float, float, float) #1
+
+declare hidden double @fdim(double noundef, double noundef) local_unnamed_addr #2
+
+declare hidden float @fdimf(float noundef, float noundef) local_unnamed_addr #2
+
+declare double @llvm.exp.f64(double) #1
+
+declare float @llvm.exp.f32(float) #1
+
+declare double @llvm.exp2.f64(double) #1
+
+declare float @llvm.exp2.f32(float) #1
+
+declare hidden double @expm1(double noundef) local_unnamed_addr #2
+
+declare hidden float @expm1f(float noundef) local_unnamed_addr #2
+
+declare double @llvm.log.f64(double) #1
+
+declare float @llvm.log.f32(float) #1
+
+declare double @llvm.log10.f64(double) #1
+
+declare float @llvm.log10.f32(float) #1
+
+declare double @llvm.log2.f64(double) #1
+
+declare float @llvm.log2.f32(float) #1
+
+declare hidden double @log1p(double noundef) local_unnamed_addr #2
+
+declare hidden float @log1pf(float noundef) local_unnamed_addr #2
+
+declare float @llvm.pow.f32(float, float) #1
+
+declare double @llvm.sqrt.f64(double) #1
+
+declare float @llvm.sqrt.f32(float) #1
+
+declare hidden double @cbrt(double noundef) local_unnamed_addr #2
+
+declare hidden float @cbrtf(float noundef) local_unnamed_addr #2
+
+declare hidden double @hypot(double noundef...
[truncated]

Copy link

github-actions bot commented May 15, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

return PreservedAnalyses::all();

SmallVector<std::pair<Function *, std::string>> ToReplace;
for (auto &&F : M) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why &&?

No auto as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because it's a forwarding ref (NOT an rvalue ref), and this is somewhat idiomatic usage in range-for loops. It's also used elsewhere in this file already, so it's self consistent.

@shiltian shiltian requested a review from arsenm May 15, 2025 23:47
@AlexVlx AlexVlx requested a review from shiltian May 20, 2025 08:10
}
}

ToReplace.emplace_back(&F, std::move(N));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: no need to use std::move here.

Comment on lines +410 to +426
default: {
if (F.getReturnType()->isDoubleTy()) {
switch (ID) {
case Intrinsic::cos:
case Intrinsic::exp:
case Intrinsic::exp2:
case Intrinsic::log:
case Intrinsic::log10:
case Intrinsic::log2:
case Intrinsic::pow:
case Intrinsic::sin:
break;
default:
continue;
}
break;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why not just take the inner switch out and merge it with the outer one? I don't think they are semantically different.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are only missing for FP64, it'd have to check per case, rather than once per switch. I don't know that one is significantly more readable than the other.

StringRef Prefix = "llvm";
ToReplace.back().second.replace(0, Prefix.size(), "__hipstdpar");
}
for (auto &&F : ToReplace)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use structured binding here ? F is not a function but a pair.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We most definitely can, but I never know just how hipster LLVM is with feature use. I'll switch it over and see if anyone objects.

Comment on lines 97 to 98
%call10 = call contract double @cbrt(double noundef %15) #4
%call11 = call contract float @cbrtf(float noundef %16) #4
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the callsite attributes. Separate from this, probably should test that callsite function and parameter / return values are preserved (plus fpmath metadata)

case Intrinsic::tanh:
break;
default: {
if (F.getReturnType()->isDoubleTy()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't handle vectors, which is the main plus of the intrinsics

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but this is (at least for the time being) intended to cover the standard library, and neither the C nor the C++ one have vector support at the moment.

// happen if e.g. a library that was compiled without interposition returns
// an allocation that can be validly passed to `free`.
//
// 3. MathFixup (required): Some accelerators might have an incomplete
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We really need to fix our whole library usage strategy (which I'm working on writing up how)

}
for (auto &&F : ToReplace)
F.first->replaceAllUsesWith(
M.getOrInsertFunction(F.second, F.first->getFunctionType())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the function happened to have a non-default address space, this would produce a function with the wrong pointer type. It's probably not too important and this bug is widespread anyway

Comment on lines +393 to +395
auto It =
find_if(MathLibToHipStdPar, [&](auto &&M) { return M.first == N; });
if (It == std::cend(MathLibToHipStdPar))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally we would cross reference this with the TargetLibraryInfo for whether this is a recognized call. That will always fail for AMDGPU since we have to say we have no library calls.

Failing that, should at least make an effort to respect nobuiltin

Comment on lines +385 to +386
if (!F.hasName())
continue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think anonymous functions need special handling

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It saves on doing the lookup for the not_intrinsic case.

ToReplace.back().second.replace(0, Prefix.size(), "__hipstdpar");
}
for (auto &&F : ToReplace)
F.first->replaceAllUsesWith(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add a test with some abnormal transformed function uses? i.e. address capturing stores, call sites with the wrong signature, global initializer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I most definitely can but could you please clarify the middle scenario - what would a call site with the wrong signature look like / where would it originate, given that these would originate from well formed correct HLL invocations of math lib interfaces, which have passed the FE and are thus linguistically correct? Perhaps I am missing something obvious here.

searlmc1 added a commit to ROCm/llvm-project that referenced this pull request Jun 24, 2025
Port so to intercept 7.0 for HLRS drop
searlmc1 added a commit to ROCm/llvm-project that referenced this pull request Jun 24, 2025
Port so to intercept 7.0 for HLRS drop
@AlexVlx
Copy link
Contributor Author

AlexVlx commented Jul 15, 2025

Gentle ping, I'd like to progress this and add it to the backport list ASAP, since it's rather small, hipstdpar specific and user requested (we already have it in AMD's fork but it's a spurious delta to maintain).

@b-sumner
Copy link

I think we should land this to keep making progress. The grand redesign is certainly not imminent.

Copy link
Collaborator

@yxsamliu yxsamliu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks

@AlexVlx AlexVlx merged commit 6bcff9e into llvm:main Jul 28, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants