diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h new file mode 100644 index 0000000000000..fa7a635568c7c --- /dev/null +++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h @@ -0,0 +1,26 @@ +//===- MathToROCDL.h - Utils to convert from the complex dialect --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ +#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/IR/PatternMatch.h" +#include + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTMATHTOROCDL +#include "mlir/Conversion/Passes.h.inc" + +/// Populate the given list with patterns that convert from Math to ROCDL calls. +void populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); +} // namespace mlir + +#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_ diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 8c6f85d461aea..208f26489d6c3 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -46,6 +46,7 @@ #include "mlir/Conversion/MathToFuncs/MathToFuncs.h" #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/MathToLibm/MathToLibm.h" +#include "mlir/Conversion/MathToROCDL/MathToROCDL.h" #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h" #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 560b088dbe5cd..54b94bbfb93d1 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -733,6 +733,23 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> { ]; } +//===----------------------------------------------------------------------===// +// MathToLibm +//===----------------------------------------------------------------------===// + +def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> { + let summary = "Convert Math dialect to ROCDL library calls"; + let description = [{ + This pass converts supported Math ops to ROCDL library calls. + }]; + let dependentDialects = [ + "arith::ArithDialect", + "func::FuncDialect", + "ROCDL::ROCDLDialect", + "vector::VectorDialect", + ]; +} + //===----------------------------------------------------------------------===// // MathToSPIRV //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index e107738a4c50c..80c8b84d9ae89 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -36,6 +36,7 @@ add_subdirectory(LLVMCommon) add_subdirectory(MathToFuncs) add_subdirectory(MathToLibm) add_subdirectory(MathToLLVM) +add_subdirectory(MathToROCDL) add_subdirectory(MathToSPIRV) add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt index 70707b5c3a049..945e3ccdfa87b 100644 --- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt @@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms MLIRArithToLLVM MLIRArithTransforms MLIRMathToLLVM + MLIRMathToROCDL MLIRAMDGPUToROCDL MLIRFuncToLLVM MLIRGPUDialect diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index 40eb15a491063..100181cdc69fe 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -26,6 +26,7 @@ #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/MathToROCDL/MathToROCDL.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" @@ -386,50 +387,7 @@ void mlir::populateGpuToROCDLConversionPatterns( patterns.add(converter); - populateOpPatterns(converter, patterns, "__ocml_fabs_f32", - "__ocml_fabs_f64"); - populateOpPatterns(converter, patterns, "__ocml_atan_f32", - "__ocml_atan_f64"); - populateOpPatterns(converter, patterns, "__ocml_atan2_f32", - "__ocml_atan2_f64"); - populateOpPatterns(converter, patterns, "__ocml_cbrt_f32", - "__ocml_cbrt_f64"); - populateOpPatterns(converter, patterns, "__ocml_ceil_f32", - "__ocml_ceil_f64"); - populateOpPatterns(converter, patterns, "__ocml_cos_f32", - "__ocml_cos_f64"); - populateOpPatterns(converter, patterns, "__ocml_exp_f32", - "__ocml_exp_f64"); - populateOpPatterns(converter, patterns, "__ocml_exp2_f32", - "__ocml_exp2_f64"); - populateOpPatterns(converter, patterns, "__ocml_expm1_f32", - "__ocml_expm1_f64"); - populateOpPatterns(converter, patterns, "__ocml_floor_f32", - "__ocml_floor_f64"); - populateOpPatterns(converter, patterns, "__ocml_fmod_f32", - "__ocml_fmod_f64"); - populateOpPatterns(converter, patterns, "__ocml_log_f32", - "__ocml_log_f64"); - populateOpPatterns(converter, patterns, "__ocml_log10_f32", - "__ocml_log10_f64"); - populateOpPatterns(converter, patterns, "__ocml_log1p_f32", - "__ocml_log1p_f64"); - populateOpPatterns(converter, patterns, "__ocml_log2_f32", - "__ocml_log2_f64"); - populateOpPatterns(converter, patterns, "__ocml_pow_f32", - "__ocml_pow_f64"); - populateOpPatterns(converter, patterns, "__ocml_rsqrt_f32", - "__ocml_rsqrt_f64"); - populateOpPatterns(converter, patterns, "__ocml_sin_f32", - "__ocml_sin_f64"); - populateOpPatterns(converter, patterns, "__ocml_sqrt_f32", - "__ocml_sqrt_f64"); - populateOpPatterns(converter, patterns, "__ocml_tanh_f32", - "__ocml_tanh_f64"); - populateOpPatterns(converter, patterns, "__ocml_tan_f32", - "__ocml_tan_f64"); - populateOpPatterns(converter, patterns, "__ocml_erf_f32", - "__ocml_erf_f64"); + populateMathToROCDLConversionPatterns(converter, patterns); } std::unique_ptr> diff --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt new file mode 100644 index 0000000000000..2771955aa9493 --- /dev/null +++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt @@ -0,0 +1,23 @@ +add_mlir_conversion_library(MLIRMathToROCDL + MathToROCDL.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToROCDL + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRDialectUtils + MLIRFuncDialect + MLIRGPUToGPURuntimeTransforms + MLIRMathDialect + MLIRLLVMCommonConversion + MLIRPass + MLIRTransformUtils + MLIRVectorDialect + MLIRVectorUtils + ) diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp new file mode 100644 index 0000000000000..03c7ce5dac0d1 --- /dev/null +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -0,0 +1,146 @@ +//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/MathToROCDL/MathToROCDL.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "../GPUCommon/GPUOpsLowering.h" +#include "../GPUCommon/IndexIntrinsicsOpLowering.h" +#include "../GPUCommon/OpToFuncCallLowering.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTMATHTOROCDL +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +#define DEBUG_TYPE "math-to-rocdl" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") + +template +static void populateOpPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns, StringRef f32Func, + StringRef f64Func) { + patterns.add>(converter); + patterns.add>(converter, f32Func, f64Func); +} + +void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns) { + // Handled by mathToLLVM: math::AbsIOp + // Handled by mathToLLVM: math::CopySignOp + // Handled by mathToLLVM: math::CountLeadingZerosOp + // Handled by mathToLLVM: math::CountTrailingZerosOp + // Handled by mathToLLVM: math::CgPopOp + // Handled by mathToLLVM: math::FmaOp + // FIXME: math::IPowIOp + // FIXME: math::FPowIOp + // Handled by mathToLLVM: math::RoundEvenOp + // Handled by mathToLLVM: math::RoundOp + // Handled by mathToLLVM: math::TruncOp + populateOpPatterns(converter, patterns, "__ocml_fabs_f32", + "__ocml_fabs_f64"); + populateOpPatterns(converter, patterns, "__ocml_acos_f32", + "__ocml_acos_f64"); + populateOpPatterns(converter, patterns, "__ocml_acosh_f32", + "__ocml_acosh_f64"); + populateOpPatterns(converter, patterns, "__ocml_asin_f32", + "__ocml_asin_f64"); + populateOpPatterns(converter, patterns, "__ocml_asinh_f32", + "__ocml_asinh_f64"); + populateOpPatterns(converter, patterns, "__ocml_atan_f32", + "__ocml_atan_f64"); + populateOpPatterns(converter, patterns, "__ocml_atanh_f32", + "__ocml_atanh_f64"); + populateOpPatterns(converter, patterns, "__ocml_atan2_f32", + "__ocml_atan2_f64"); + populateOpPatterns(converter, patterns, "__ocml_cbrt_f32", + "__ocml_cbrt_f64"); + populateOpPatterns(converter, patterns, "__ocml_ceil_f32", + "__ocml_ceil_f64"); + populateOpPatterns(converter, patterns, "__ocml_cos_f32", + "__ocml_cos_f64"); + populateOpPatterns(converter, patterns, "__ocml_cosh_f32", + "__ocml_cosh_f64"); + populateOpPatterns(converter, patterns, "__ocml_sinh_f32", + "__ocml_sinh_f64"); + populateOpPatterns(converter, patterns, "__ocml_exp_f32", + "__ocml_exp_f64"); + populateOpPatterns(converter, patterns, "__ocml_exp2_f32", + "__ocml_exp2_f64"); + populateOpPatterns(converter, patterns, "__ocml_expm1_f32", + "__ocml_expm1_f64"); + populateOpPatterns(converter, patterns, "__ocml_floor_f32", + "__ocml_floor_f64"); + populateOpPatterns(converter, patterns, "__ocml_log_f32", + "__ocml_log_f64"); + populateOpPatterns(converter, patterns, "__ocml_log10_f32", + "__ocml_log10_f64"); + populateOpPatterns(converter, patterns, "__ocml_log1p_f32", + "__ocml_log1p_f64"); + populateOpPatterns(converter, patterns, "__ocml_log2_f32", + "__ocml_log2_f64"); + populateOpPatterns(converter, patterns, "__ocml_pow_f32", + "__ocml_pow_f64"); + populateOpPatterns(converter, patterns, "__ocml_rsqrt_f32", + "__ocml_rsqrt_f64"); + populateOpPatterns(converter, patterns, "__ocml_sin_f32", + "__ocml_sin_f64"); + populateOpPatterns(converter, patterns, "__ocml_sqrt_f32", + "__ocml_sqrt_f64"); + populateOpPatterns(converter, patterns, "__ocml_tanh_f32", + "__ocml_tanh_f64"); + populateOpPatterns(converter, patterns, "__ocml_tan_f32", + "__ocml_tan_f64"); + populateOpPatterns(converter, patterns, "__ocml_erf_f32", + "__ocml_erf_f64"); + // Single arith pattern that needs a ROCDL call, probably not + // worth creating a separate pass for it. + populateOpPatterns(converter, patterns, "__ocml_fmod_f32", + "__ocml_fmod_f64"); +} + +namespace { +struct ConvertMathToROCDLPass + : public impl::ConvertMathToROCDLBase { + ConvertMathToROCDLPass() = default; + void runOnOperation() override; +}; +} // namespace + +void ConvertMathToROCDLPass::runOnOperation() { + auto m = getOperation(); + MLIRContext *ctx = m.getContext(); + + RewritePatternSet patterns(&getContext()); + LowerToLLVMOptions options(ctx, DataLayout(m)); + LLVMTypeConverter converter(ctx, options); + populateMathToROCDLConversionPatterns(converter, patterns); + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalOp(); + if (failed(applyPartialConversion(m, target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir new file mode 100644 index 0000000000000..a406ec45a7f10 --- /dev/null +++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir @@ -0,0 +1,435 @@ +// RUN: mlir-opt %s -convert-math-to-rocdl -split-input-file | FileCheck %s + +module @test_module { + // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32 + // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64 + // CHECK-LABEL: func @arith_remf + func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = arith.remf %arg_f32, %arg_f32 : f32 + // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 + %result64 = arith.remf %arg_f64, %arg_f64 : f64 + // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64 + // CHECK-LABEL: func @math_absf + func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.absf %arg_f32 : f32 + // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.absf %arg_f64 : f64 + // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64 + // CHECK-LABEL: func @math_acos + func.func @math_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.acos %arg_f32 : f32 + // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.acos %arg_f64 : f64 + // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_acosh_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_acosh_f64(f64) -> f64 + // CHECK-LABEL: func @math_acosh + func.func @math_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.acosh %arg_f32 : f32 + // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.acosh %arg_f64 : f64 + // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_asin_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_asin_f64(f64) -> f64 + // CHECK-LABEL: func @math_asin + func.func @math_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.asin %arg_f32 : f32 + // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.asin %arg_f64 : f64 + // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_asinh_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_asinh_f64(f64) -> f64 + // CHECK-LABEL: func @math_asinh + func.func @math_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.asinh %arg_f32 : f32 + // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.asinh %arg_f64 : f64 + // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_atan_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_atan_f64(f64) -> f64 + // CHECK-LABEL: func @math_atan + func.func @math_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.atan %arg_f32 : f32 + // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.atan %arg_f64 : f64 + // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_atanh_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_atanh_f64(f64) -> f64 + // CHECK-LABEL: func @math_atanh + func.func @math_atanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.atanh %arg_f32 : f32 + // CHECK: llvm.call @__ocml_atanh_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.atanh %arg_f64 : f64 + // CHECK: llvm.call @__ocml_atanh_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_atan2_f32(f32, f32) -> f32 + // CHECK: llvm.func @__ocml_atan2_f64(f64, f64) -> f64 + // CHECK-LABEL: func @math_atan2 + func.func @math_atan2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.atan2 %arg_f32, %arg_f32 : f32 + // CHECK: llvm.call @__ocml_atan2_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 + %result64 = math.atan2 %arg_f64, %arg_f64 : f64 + // CHECK: llvm.call @__ocml_atan2_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64 + // CHECK-LABEL: func @math_cbrt + func.func @math_cbrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.cbrt %arg_f32 : f32 + // CHECK: llvm.call @__ocml_cbrt_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.cbrt %arg_f64 : f64 + // CHECK: llvm.call @__ocml_cbrt_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_ceil_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_ceil_f64(f64) -> f64 + // CHECK-LABEL: func @math_ceil + func.func @math_ceil(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.ceil %arg_f32 : f32 + // CHECK: llvm.call @__ocml_ceil_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.ceil %arg_f64 : f64 + // CHECK: llvm.call @__ocml_ceil_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_cos_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_cos_f64(f64) -> f64 + // CHECK-LABEL: func @math_cos + func.func @math_cos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.cos %arg_f32 : f32 + // CHECK: llvm.call @__ocml_cos_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.cos %arg_f64 : f64 + // CHECK: llvm.call @__ocml_cos_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_cosh_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_cosh_f64(f64) -> f64 + // CHECK-LABEL: func @math_cosh + func.func @math_cosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.cosh %arg_f32 : f32 + // CHECK: llvm.call @__ocml_cosh_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.cosh %arg_f64 : f64 + // CHECK: llvm.call @__ocml_cosh_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_sinh_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_sinh_f64(f64) -> f64 + // CHECK-LABEL: func @math_sinh + func.func @math_sinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.sinh %arg_f32 : f32 + // CHECK: llvm.call @__ocml_sinh_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.sinh %arg_f64 : f64 + // CHECK: llvm.call @__ocml_sinh_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_exp_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_exp_f64(f64) -> f64 + // CHECK-LABEL: func @math_exp + func.func @math_exp(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.exp %arg_f32 : f32 + // CHECK: llvm.call @__ocml_exp_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.exp %arg_f64 : f64 + // CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_exp2_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_exp2_f64(f64) -> f64 + // CHECK-LABEL: func @math_exp2 + func.func @math_exp2(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.exp2 %arg_f32 : f32 + // CHECK: llvm.call @__ocml_exp2_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.exp2 %arg_f64 : f64 + // CHECK: llvm.call @__ocml_exp2_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_expm1_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_expm1_f64(f64) -> f64 + // CHECK-LABEL: func @math_expm1 + func.func @math_expm1(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.expm1 %arg_f32 : f32 + // CHECK: llvm.call @__ocml_expm1_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.expm1 %arg_f64 : f64 + // CHECK: llvm.call @__ocml_expm1_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_floor_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_floor_f64(f64) -> f64 + // CHECK-LABEL: func @math_floor + func.func @math_floor(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.floor %arg_f32 : f32 + // CHECK: llvm.call @__ocml_floor_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.floor %arg_f64 : f64 + // CHECK: llvm.call @__ocml_floor_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_log_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_log_f64(f64) -> f64 + // CHECK-LABEL: func @math_log + func.func @math_log(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.log %arg_f32 : f32 + // CHECK: llvm.call @__ocml_log_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.log %arg_f64 : f64 + // CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_log10_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_log10_f64(f64) -> f64 + // CHECK-LABEL: func @math_log10 + func.func @math_log10(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.log10 %arg_f32 : f32 + // CHECK: llvm.call @__ocml_log10_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.log10 %arg_f64 : f64 + // CHECK: llvm.call @__ocml_log10_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_log1p_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_log1p_f64(f64) -> f64 + // CHECK-LABEL: func @math_log1p + func.func @math_log1p(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.log1p %arg_f32 : f32 + // CHECK: llvm.call @__ocml_log1p_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.log1p %arg_f64 : f64 + // CHECK: llvm.call @__ocml_log1p_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_pow_f32(f32, f32) -> f32 + // CHECK: llvm.func @__ocml_pow_f64(f64, f64) -> f64 + // CHECK-LABEL: func @math_powf + func.func @math_powf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.powf %arg_f32, %arg_f32 : f32 + // CHECK: llvm.call @__ocml_pow_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 + %result64 = math.powf %arg_f64, %arg_f64 : f64 + // CHECK: llvm.call @__ocml_pow_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_rsqrt_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_rsqrt_f64(f64) -> f64 + // CHECK-LABEL: func @math_rsqrt + func.func @math_rsqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.rsqrt %arg_f32 : f32 + // CHECK: llvm.call @__ocml_rsqrt_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.rsqrt %arg_f64 : f64 + // CHECK: llvm.call @__ocml_rsqrt_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_sin_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_sin_f64(f64) -> f64 + // CHECK-LABEL: func @math_sin + func.func @math_sin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.sin %arg_f32 : f32 + // CHECK: llvm.call @__ocml_sin_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.sin %arg_f64 : f64 + // CHECK: llvm.call @__ocml_sin_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_sqrt_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_sqrt_f64(f64) -> f64 + // CHECK-LABEL: func @math_sqrt + func.func @math_sqrt(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.sqrt %arg_f32 : f32 + // CHECK: llvm.call @__ocml_sqrt_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.sqrt %arg_f64 : f64 + // CHECK: llvm.call @__ocml_sqrt_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_tanh_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_tanh_f64(f64) -> f64 + // CHECK-LABEL: func @math_tanh + func.func @math_tanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.tanh %arg_f32 : f32 + // CHECK: llvm.call @__ocml_tanh_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.tanh %arg_f64 : f64 + // CHECK: llvm.call @__ocml_tanh_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_tan_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_tan_f64(f64) -> f64 + // CHECK-LABEL: func @math_tan + func.func @math_tan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.tan %arg_f32 : f32 + // CHECK: llvm.call @__ocml_tan_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.tan %arg_f64 : f64 + // CHECK: llvm.call @__ocml_tan_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_erf_f32(f32) -> f32 + // CHECK: llvm.func @__ocml_erf_f64(f64) -> f64 + // CHECK-LABEL: func @math_erf + func.func @math_erf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = math.erf %arg_f32 : f32 + // CHECK: llvm.call @__ocml_erf_f32(%{{.*}}) : (f32) -> f32 + %result64 = math.erf %arg_f64 : f64 + // CHECK: llvm.call @__ocml_erf_f64(%{{.*}}) : (f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} + +// ----- + +module @test_module { + // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32 + // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64 + // CHECK-LABEL: func @arith_remf + func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) { + %result32 = arith.remf %arg_f32, %arg_f32 : f32 + // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32 + %result64 = arith.remf %arg_f64, %arg_f64 : f64 + // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64 + func.return %result32, %result64 : f32, f64 + } +} +