Skip to content

Conversation

@silee2
Copy link
Contributor

@silee2 silee2 commented Oct 20, 2025

Existing pattern for lowering gpu.printf op to LLVM call uses fixed function name and calling convention.
Those two should be exposed as pass option to allow supporting Intel Compute Runtime for GPU.

Also adds gpu.printf op pattern to GPU to LLVMSPV pass.
It may appear out of place, but integration test is added to XeVM integration test as that is the current best folder for testing with Intel Compute Runtime.
Test should be moved in the future if a better test folder is added.

@llvmbot
Copy link
Member

llvmbot commented Oct 20, 2025

@llvm/pr-subscribers-mlir-gpu

Author: Sang Ik Lee (silee2)

Changes

Existing pattern for lowering gpu.printf op to LLVM call uses fixed function name and calling convention.
Those two should be exposed as pass option to allow supporting Intel Compute Runtime for GPU.


Full diff: https://github.com/llvm/llvm-project/pull/164297.diff

5 Files Affected:

  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+4-2)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h (+11-4)
  • (modified) mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp (+4-1)
  • (added) mlir/test/Conversion/GPUToLLVMSPV/printf.mlir (+16)
  • (added) mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir (+30)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 2285d2695db4e..eb662a1b056de 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -507,7 +507,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
       LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
                                   /*isVarArg=*/true);
   LLVM::LLVMFuncOp printfDecl =
-      getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
+      getOrDefineFunction(moduleOp, loc, rewriter, funcName, printfType);
+  printfDecl.setCConv(callingConvention);
 
   // Create the global op or find an existing one.
   LLVM::GlobalOp global = getOrCreateStringConstant(
@@ -530,7 +531,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
   printfArgs.push_back(stringStart);
   printfArgs.append(argsRange.begin(), argsRange.end());
 
-  LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
+  auto call = LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
+  call.setCConv(callingConvention);
   rewriter.eraseOp(gpuPrintfOp);
   return success();
 }
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 66d3bb40a8f5a..adf5ba2feb591 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -10,6 +10,7 @@
 
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 
 namespace mlir {
@@ -142,13 +143,17 @@ struct GPUPrintfOpToHIPLowering : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
 /// This pass will add a declaration of printf() to the GPUModule if needed
 /// and separate out the format strings into global constants. For some
 /// runtimes, such as OpenCL on AMD, this is sufficient setup, as the compiler
-/// will lower printf calls to appropriate device-side code
+/// will lower printf calls to appropriate device-side code.
+/// callingConvention and funcName can be adjusted as needed.
 struct GPUPrintfOpToLLVMCallLowering
     : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
-  GPUPrintfOpToLLVMCallLowering(const LLVMTypeConverter &converter,
-                                int addressSpace = 0)
+  GPUPrintfOpToLLVMCallLowering(
+      const LLVMTypeConverter &converter, int addressSpace = 0,
+      LLVM::cconv::CConv callingConvention = LLVM::cconv::CConv::C,
+      StringRef funcName = "printf")
       : ConvertOpToLLVMPattern<gpu::PrintfOp>(converter),
-        addressSpace(addressSpace) {}
+        addressSpace(addressSpace), callingConvention(callingConvention),
+        funcName(funcName) {}
 
   LogicalResult
   matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
@@ -156,6 +161,8 @@ struct GPUPrintfOpToLLVMCallLowering
 
 private:
   int addressSpace;
+  LLVM::cconv::CConv callingConvention;
+  StringRef funcName;
 };
 
 /// Lowering of gpu.printf to a vprintf standard library.
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index c2363a1a40294..25f1e1b184d61 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -470,10 +470,13 @@ struct GPUToLLVMSPVConversionPass final
                         gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
                         gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
                         gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
-                        gpu::ThreadIdOp>();
+                        gpu::ThreadIdOp, gpu::PrintfOp>();
 
     populateGpuToLLVMSPVConversionPatterns(converter, patterns);
     populateGpuMemorySpaceAttributeConversions(converter);
+    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/2,
+                                                LLVM::cconv::CConv::SPIR_FUNC,
+                                                "_Z6printfPU3AS2Kcz");
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/printf.mlir b/mlir/test/Conversion/GPUToLLVMSPV/printf.mlir
new file mode 100644
index 0000000000000..74017e8354cf1
--- /dev/null
+++ b/mlir/test/Conversion/GPUToLLVMSPV/printf.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -convert-gpu-to-llvm-spv | FileCheck %s
+
+gpu.module @test_module {
+  // CHECK: llvm.mlir.global internal constant @[[$PRINT_GLOBAL:[A-Za-z0-9_]+]]("Hello: %d\0A\00")  {addr_space = 2 : i32}
+  // CHECK: llvm.func spir_funccc @_Z6printfPU3AS2Kcz(!llvm.ptr<2>, ...) -> i32
+  // CHECK-LABEL: llvm.func spir_funccc @test_printf
+  // CHECK: (%[[ARG0:.*]]: i32)
+  gpu.func @test_printf(%arg0: i32) {
+    // CHECK: %[[IMM0:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL]] : !llvm.ptr<2>
+    // CHECK-NEXT: %[[IMM2:.*]] = llvm.getelementptr %[[IMM0]][0, 0] : (!llvm.ptr<2>) -> !llvm.ptr<2>, !llvm.array<11 x i8>
+    // CHECK-NEXT: %{{.*}} = llvm.call spir_funccc @_Z6printfPU3AS2Kcz(%[[IMM2]], %[[ARG0]]) vararg(!llvm.func<i32 (ptr<2>, ...)>) : (!llvm.ptr<2>, i32) -> i32
+    gpu.printf "Hello: %d\n", %arg0 : i32
+    gpu.return
+  }
+}
+
diff --git a/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir b/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir
new file mode 100644
index 0000000000000..edf8775c72418
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(cse,func.func(gpu-async-region),xevm-attach-target,gpu.module(convert-gpu-to-llvm-spv{use-64bit-index=true},convert-xevm-to-llvm,cse))' \
+// RUN: | mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
+// RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -cse -gpu-module-to-binary \
+// RUN: | mlir-runner \
+// RUN:   --shared-libs=%mlir_sycl_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --shared-libs=%mlir_c_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module @test attributes {gpu.container_module} {
+  gpu.module @test_module {
+    gpu.func @test_printf(%arg0: i32, %arg1: f32) kernel {
+      gpu.printf "Hello: %d\n", %arg0 : i32
+      gpu.printf "Hello: %f\n", %arg1 : f32
+      gpu.return
+    }
+  }
+
+  func.func @main() attributes {llvm.emit_c_interface} {
+    %c1 = arith.constant 1 : index
+    %c11 = arith.constant 11 : i32
+    %c4 = arith.constant 4.0 : f32
+    // CHECK: Hello: 11
+    // CHECK: Hello: 4.000000
+    gpu.launch_func @test_module::@test_printf blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%c11 : i32, %c4 : f32)
+    return
+  }
+}

@llvmbot
Copy link
Member

llvmbot commented Oct 20, 2025

@llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

Changes

Existing pattern for lowering gpu.printf op to LLVM call uses fixed function name and calling convention.
Those two should be exposed as pass option to allow supporting Intel Compute Runtime for GPU.


Full diff: https://github.com/llvm/llvm-project/pull/164297.diff

5 Files Affected:

  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+4-2)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h (+11-4)
  • (modified) mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp (+4-1)
  • (added) mlir/test/Conversion/GPUToLLVMSPV/printf.mlir (+16)
  • (added) mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir (+30)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 2285d2695db4e..eb662a1b056de 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -507,7 +507,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
       LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {ptrType},
                                   /*isVarArg=*/true);
   LLVM::LLVMFuncOp printfDecl =
-      getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
+      getOrDefineFunction(moduleOp, loc, rewriter, funcName, printfType);
+  printfDecl.setCConv(callingConvention);
 
   // Create the global op or find an existing one.
   LLVM::GlobalOp global = getOrCreateStringConstant(
@@ -530,7 +531,8 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
   printfArgs.push_back(stringStart);
   printfArgs.append(argsRange.begin(), argsRange.end());
 
-  LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
+  auto call = LLVM::CallOp::create(rewriter, loc, printfDecl, printfArgs);
+  call.setCConv(callingConvention);
   rewriter.eraseOp(gpuPrintfOp);
   return success();
 }
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
index 66d3bb40a8f5a..adf5ba2feb591 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h
@@ -10,6 +10,7 @@
 
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 
 namespace mlir {
@@ -142,13 +143,17 @@ struct GPUPrintfOpToHIPLowering : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
 /// This pass will add a declaration of printf() to the GPUModule if needed
 /// and separate out the format strings into global constants. For some
 /// runtimes, such as OpenCL on AMD, this is sufficient setup, as the compiler
-/// will lower printf calls to appropriate device-side code
+/// will lower printf calls to appropriate device-side code.
+/// callingConvention and funcName can be adjusted as needed.
 struct GPUPrintfOpToLLVMCallLowering
     : public ConvertOpToLLVMPattern<gpu::PrintfOp> {
-  GPUPrintfOpToLLVMCallLowering(const LLVMTypeConverter &converter,
-                                int addressSpace = 0)
+  GPUPrintfOpToLLVMCallLowering(
+      const LLVMTypeConverter &converter, int addressSpace = 0,
+      LLVM::cconv::CConv callingConvention = LLVM::cconv::CConv::C,
+      StringRef funcName = "printf")
       : ConvertOpToLLVMPattern<gpu::PrintfOp>(converter),
-        addressSpace(addressSpace) {}
+        addressSpace(addressSpace), callingConvention(callingConvention),
+        funcName(funcName) {}
 
   LogicalResult
   matchAndRewrite(gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
@@ -156,6 +161,8 @@ struct GPUPrintfOpToLLVMCallLowering
 
 private:
   int addressSpace;
+  LLVM::cconv::CConv callingConvention;
+  StringRef funcName;
 };
 
 /// Lowering of gpu.printf to a vprintf standard library.
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index c2363a1a40294..25f1e1b184d61 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -470,10 +470,13 @@ struct GPUToLLVMSPVConversionPass final
                         gpu::GPUFuncOp, gpu::GlobalIdOp, gpu::GridDimOp,
                         gpu::LaneIdOp, gpu::NumSubgroupsOp, gpu::ReturnOp,
                         gpu::ShuffleOp, gpu::SubgroupIdOp, gpu::SubgroupSizeOp,
-                        gpu::ThreadIdOp>();
+                        gpu::ThreadIdOp, gpu::PrintfOp>();
 
     populateGpuToLLVMSPVConversionPatterns(converter, patterns);
     populateGpuMemorySpaceAttributeConversions(converter);
+    patterns.add<GPUPrintfOpToLLVMCallLowering>(converter, /*addressSpace=*/2,
+                                                LLVM::cconv::CConv::SPIR_FUNC,
+                                                "_Z6printfPU3AS2Kcz");
 
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/printf.mlir b/mlir/test/Conversion/GPUToLLVMSPV/printf.mlir
new file mode 100644
index 0000000000000..74017e8354cf1
--- /dev/null
+++ b/mlir/test/Conversion/GPUToLLVMSPV/printf.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt %s -convert-gpu-to-llvm-spv | FileCheck %s
+
+gpu.module @test_module {
+  // CHECK: llvm.mlir.global internal constant @[[$PRINT_GLOBAL:[A-Za-z0-9_]+]]("Hello: %d\0A\00")  {addr_space = 2 : i32}
+  // CHECK: llvm.func spir_funccc @_Z6printfPU3AS2Kcz(!llvm.ptr<2>, ...) -> i32
+  // CHECK-LABEL: llvm.func spir_funccc @test_printf
+  // CHECK: (%[[ARG0:.*]]: i32)
+  gpu.func @test_printf(%arg0: i32) {
+    // CHECK: %[[IMM0:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL]] : !llvm.ptr<2>
+    // CHECK-NEXT: %[[IMM2:.*]] = llvm.getelementptr %[[IMM0]][0, 0] : (!llvm.ptr<2>) -> !llvm.ptr<2>, !llvm.array<11 x i8>
+    // CHECK-NEXT: %{{.*}} = llvm.call spir_funccc @_Z6printfPU3AS2Kcz(%[[IMM2]], %[[ARG0]]) vararg(!llvm.func<i32 (ptr<2>, ...)>) : (!llvm.ptr<2>, i32) -> i32
+    gpu.printf "Hello: %d\n", %arg0 : i32
+    gpu.return
+  }
+}
+
diff --git a/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir b/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir
new file mode 100644
index 0000000000000..edf8775c72418
--- /dev/null
+++ b/mlir/test/Integration/Dialect/XeVM/GPU/gpu_printf.mlir
@@ -0,0 +1,30 @@
+// RUN: mlir-opt %s \
+// RUN: | mlir-opt -pass-pipeline='builtin.module(cse,func.func(gpu-async-region),xevm-attach-target,gpu.module(convert-gpu-to-llvm-spv{use-64bit-index=true},convert-xevm-to-llvm,cse))' \
+// RUN: | mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \
+// RUN: | mlir-opt -gpu-to-llvm -reconcile-unrealized-casts -cse -gpu-module-to-binary \
+// RUN: | mlir-runner \
+// RUN:   --shared-libs=%mlir_sycl_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --shared-libs=%mlir_c_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module @test attributes {gpu.container_module} {
+  gpu.module @test_module {
+    gpu.func @test_printf(%arg0: i32, %arg1: f32) kernel {
+      gpu.printf "Hello: %d\n", %arg0 : i32
+      gpu.printf "Hello: %f\n", %arg1 : f32
+      gpu.return
+    }
+  }
+
+  func.func @main() attributes {llvm.emit_c_interface} {
+    %c1 = arith.constant 1 : index
+    %c11 = arith.constant 11 : i32
+    %c4 = arith.constant 4.0 : f32
+    // CHECK: Hello: 11
+    // CHECK: Hello: 4.000000
+    gpu.launch_func @test_module::@test_printf blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1) args(%c11 : i32, %c4 : f32)
+    return
+  }
+}

@silee2 silee2 requested a review from charithaintc October 20, 2025 18:30
Copy link
Contributor

@fabianmcg fabianmcg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other than that, LGTM.

/// runtimes, such as OpenCL on AMD, this is sufficient setup, as the compiler
/// will lower printf calls to appropriate device-side code
/// will lower printf calls to appropriate device-side code.
/// callingConvention and funcName can be adjusted as needed.
Copy link
Contributor

@fabianmcg fabianmcg Oct 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please expand the comment here explaining why these needs to be customizable. This means, saying callingConvention, and funcName are customizable as not all backends use the same, for example, the SPV LLVM backend uses....

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated comment as suggested.

@fabianmcg fabianmcg self-requested a review October 22, 2025 21:35
@silee2 silee2 merged commit 1501454 into llvm:main Oct 23, 2025
10 checks passed
dvbuka pushed a commit to dvbuka/llvm-project that referenced this pull request Oct 27, 2025
…lvm#164297)

Existing pattern for lowering gpu.printf op to LLVM call uses fixed
function name and calling convention.
Those two should be exposed as pass option to allow supporting Intel
Compute Runtime for GPU.

Also adds gpu.printf op pattern to GPU to LLVMSPV pass.
It may appear out of place, but integration test is added to XeVM
integration test as that is the current best folder for testing with
Intel Compute Runtime.
Test should be moved in the future if a better test folder is added.
Lukacma pushed a commit to Lukacma/llvm-project that referenced this pull request Oct 29, 2025
…lvm#164297)

Existing pattern for lowering gpu.printf op to LLVM call uses fixed
function name and calling convention.
Those two should be exposed as pass option to allow supporting Intel
Compute Runtime for GPU.

Also adds gpu.printf op pattern to GPU to LLVMSPV pass.
It may appear out of place, but integration test is added to XeVM
integration test as that is the current best folder for testing with
Intel Compute Runtime.
Test should be moved in the future if a better test folder is added.
aokblast pushed a commit to aokblast/llvm-project that referenced this pull request Oct 30, 2025
…lvm#164297)

Existing pattern for lowering gpu.printf op to LLVM call uses fixed
function name and calling convention.
Those two should be exposed as pass option to allow supporting Intel
Compute Runtime for GPU.

Also adds gpu.printf op pattern to GPU to LLVMSPV pass.
It may appear out of place, but integration test is added to XeVM
integration test as that is the current best folder for testing with
Intel Compute Runtime.
Test should be moved in the future if a better test folder is added.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants