From 1799e6a9448a382dd76bdecf2457e9c78c7c1087 Mon Sep 17 00:00:00 2001 From: Giacomo Castiglioni Date: Fri, 31 Oct 2025 16:05:46 +0100 Subject: [PATCH 1/3] GPU mma fp64 extension --- .../mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h | 2 +- mlir/include/mlir/Dialect/GPU/IR/GPUBase.td | 2 +- mlir/include/mlir/Dialect/GPU/IR/GPUOps.td | 8 +-- .../Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp | 52 +++++++++++--- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 4 +- .../GPUToNVVM/wmma-ops-to-nvvm.mlir | 22 ++++++ .../GPU/CUDA/TensorCore/wmma-matmul-f64.mlir | 72 +++++++++++++++++++ 7 files changed, 144 insertions(+), 18 deletions(-) create mode 100644 mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h index 4c8abea680b66..48982ac6efe7c 100644 --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -27,7 +27,7 @@ class MMAMatrixType; #define GEN_PASS_DECL_CONVERTGPUOPSTONVVMOPS #include "mlir/Conversion/Passes.h.inc" -LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type); +Type convertMMAToLLVMType(gpu::MMAMatrixType type); /// Configure target to convert from the GPU dialect to NVVM. void configureGpuToNVVMConversionLegality(ConversionTarget &target); diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td index 860f893367203..2c29bb8a01a41 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td @@ -114,7 +114,7 @@ def GPU_MMAMatrix : DialectType< GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">; // Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops. -def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, VectorOfRankAndType<[1], [I8, I32, F16, F32]>]>; +def GPU_MMAMemRef : MemRefOf<[I8, I32, F16, F32, F64, VectorOfRankAndType<[1], [I8, I32, F16, F32, F64]>]>; class MMAMatrixOf allowedTypes> : ContainerType, IsMMAMatrixTypePred, diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td index a6c6038e1e224..5c7df25c58cde 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td @@ -1872,7 +1872,7 @@ def GPU_SubgroupMmaStoreMatrixOp : GPU_Op<"subgroup_mma_store_matrix", ``` }]; - let arguments = (ins Arg>:$src, + let arguments = (ins Arg>:$src, Arg]>:$dstMemref, Variadic:$indices, IndexAttr:$leadDimension, @@ -1919,9 +1919,9 @@ def GPU_SubgroupMmaComputeOp ``` }]; - let arguments = (ins Arg>:$opA, - Arg>:$opB, - Arg>:$opC, + let arguments = (ins Arg>:$opA, + Arg>:$opB, + Arg>:$opC, OptionalAttr:$a_transpose, OptionalAttr:$b_transpose); diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index 99c059cb03299..fb1a37a03fe4d 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" using namespace mlir; @@ -57,7 +58,8 @@ static NVVM::MMATypes getElementType(gpu::MMAMatrixType type) { if (type.getElementType().isF32()) return type.getOperand() == "COp" ? 
NVVM::MMATypes::f32 : NVVM::MMATypes::tf32; - + if (type.getElementType().isF64()) + return NVVM::MMATypes::f64; if (type.getElementType().isSignedInteger(8)) return NVVM::MMATypes::s8; if (type.getElementType().isUnsignedInteger(8)) @@ -212,8 +214,13 @@ struct WmmaMmaOpToNVVMLowering // then passed on to the intrinsic call. Emit llvm ops to extract individual // values form lowered memrefs. SmallVector unpackedOps; - auto unpackOp = [&](Value operand) { + // f64 a and b fragments are not structs but scalars. + if (!isa(operand.getType())) { + unpackedOps.push_back(operand); + return; + } + // every other type is lowered to an LLVM struct, extract the values. auto structType = cast(operand.getType()); for (size_t i = 0, e = structType.getBody().size(); i < e; ++i) { Value toUse = LLVM::ExtractValueOp::create(rewriter, loc, operand, i); @@ -276,10 +283,16 @@ struct WmmaConstantOpToNVVMLowering return failure(); Location loc = subgroupMmaConstantOp.getLoc(); Value cst = adaptor.getOperands()[0]; - LLVM::LLVMStructType type = convertMMAToLLVMType( + Type type = convertMMAToLLVMType( cast(subgroupMmaConstantOp.getType())); + // If the element is not a struct, it means it's a scalar f64. + LLVM::LLVMStructType structType = dyn_cast(type); + if (!structType) { + rewriter.replaceOp(subgroupMmaConstantOp, cst); + return success(); + } // If the element type is a vector create a vector from the operand. - if (auto vecType = dyn_cast(type.getBody()[0])) { + if (auto vecType = dyn_cast(structType.getBody()[0])) { Value vecCst = LLVM::PoisonOp::create(rewriter, loc, vecType); for (int64_t vecEl = 0; vecEl < vecType.getNumElements(); vecEl++) { Value idx = LLVM::ConstantOp::create(rewriter, loc, @@ -289,8 +302,8 @@ struct WmmaConstantOpToNVVMLowering } cst = vecCst; } - Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, type); - for (size_t i : llvm::seq(size_t(0), type.getBody().size())) { + Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structType); + for (size_t i : llvm::seq(size_t(0), structType.getBody().size())) { matrixStruct = LLVM::InsertValueOp::create(rewriter, loc, matrixStruct, cst, i); } @@ -354,10 +367,24 @@ struct WmmaElementwiseOpToNVVMLowering return failure(); Location loc = subgroupMmaElementwiseOp.getLoc(); size_t numOperands = adaptor.getOperands().size(); - LLVM::LLVMStructType destType = convertMMAToLLVMType( + Type destType = convertMMAToLLVMType( cast(subgroupMmaElementwiseOp.getType())); - Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, destType); - for (size_t i = 0, e = destType.getBody().size(); i < e; ++i) { + + // If the element is not a struct, it means it's a scalar f64. + LLVM::LLVMStructType structDestTy = dyn_cast(destType); + if (!structDestTy) { + SmallVector operands; + for (auto operand : adaptor.getOperands()) { + operands.push_back(operand); + } + Value element = + createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(), + operands); + rewriter.replaceOp(subgroupMmaElementwiseOp, element); + return success(); + } + Value matrixStruct = LLVM::PoisonOp::create(rewriter, loc, structDestTy); + for (size_t i = 0, e = structDestTy.getBody().size(); i < e; ++i) { SmallVector extractedOperands; for (size_t opIdx = 0; opIdx < numOperands; opIdx++) { extractedOperands.push_back(LLVM::ExtractValueOp::create( @@ -377,13 +404,18 @@ struct WmmaElementwiseOpToNVVMLowering } // namespace /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`. 
-LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) { +Type mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) { NVVM::MMAFrag frag = convertOperand(type.getOperand()); NVVM::MMATypes eltType = getElementType(type); auto nRow = type.getShape()[0]; auto nCol = type.getShape()[1]; std::pair typeInfo = NVVM::inferMMAType(eltType, frag, nRow, nCol, type.getContext()); + // Special handling for f64 a and b fragments + Type f64Ty = Float64Type::get(type.getContext()); + if (typeInfo.first == f64Ty && typeInfo.second == 1) { + return f64Ty; + } return LLVM::LLVMStructType::getLiteral( type.getContext(), SmallVector(typeInfo.second, typeInfo.first)); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 6c6d8d2bad55d..61a630aa88960 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -208,7 +208,7 @@ Type MMAMatrixType::getElementType() const { return getImpl()->elementType; } StringRef MMAMatrixType::getOperand() const { return getImpl()->getOperand(); } bool MMAMatrixType::isValidElementType(Type elementType) { - return elementType.isF16() || elementType.isF32() || + return elementType.isF16() || elementType.isF32() || elementType.isF64() || elementType.isUnsignedInteger(8) || elementType.isSignedInteger(8) || elementType.isInteger(32); } @@ -225,7 +225,7 @@ MMAMatrixType::verifyInvariants(function_ref emitError, if (!MMAMatrixType::isValidElementType(elementType)) return emitError() - << "MMAMatrixType elements must be SI8, UI8, I32, F16, or F32"; + << "MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64"; return success(); } diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir index b479467efc208..83b5fb5e6ea54 100644 --- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir @@ -79,6 +79,28 @@ gpu.module @test_module { // ----- +gpu.module @test_module { + + // CHECK-LABEL: func @gpu_wmma_f64_load_op() -> + // CHECK-SAME: f64 + // CHECK32-LABEL: func @gpu_wmma_f64_load_op() -> + func.func @gpu_wmma_f64_load_op() -> (!gpu.mma_matrix<8x4xf64, "AOp">) { + %wg = memref.alloca() {alignment = 32} : memref<32x32xf64, 3> + %i = arith.constant 16 : index + %j = arith.constant 16 : index + %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index} : memref<32x32xf64, 3> -> !gpu.mma_matrix<8x4xf64, "AOp"> + return %0 : !gpu.mma_matrix<8x4xf64, "AOp"> + // CHECK: %[[MUL:.*]] = llvm.mul %{{.*}}, %{{.*}} : i64 + // CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %{{.*}} : i64 + // CHECK: %[[GEP:.*]] = llvm.getelementptr %{{.*}}[%[[ADD]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f64 + // CHECK: %[[C32_I32:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[LOAD:.*]] = nvvm.wmma.load %[[GEP]], %[[C32_I32]] {eltype = #nvvm.mma_type, frag = #nvvm.mma_frag, k = 4 : i32, layout = #nvvm.mma_layout, m = 8 : i32, n = 8 : i32} : (!llvm.ptr<3>) -> f64 + // CHECK: llvm.return %[[LOAD]] : f64 + } +} + +// ----- + gpu.module @test_module { // CHECK-LABEL: func @gpu_wmma_store_op diff --git a/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir new file mode 100644 index 0000000000000..a016a60022699 --- /dev/null +++ b/mlir/test/Integration/GPU/CUDA/TensorCore/wmma-matmul-f64.mlir @@ -0,0 +1,72 @@ +// RUN: mlir-opt %s \ +// RUN: | mlir-opt -gpu-lower-to-nvvm-pipeline="cubin-chip=sm_80 
cubin-format=%gpu_compilation_format" \
+// RUN: | mlir-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --entry-point-result=void \
+// RUN: | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+
+func.func @main() {
+  %a = memref.alloc() : memref<8x4xf64>
+  %b = memref.alloc() : memref<4x8xf64>
+  %c = memref.alloc() : memref<8x8xf64>
+  %d = memref.alloc() : memref<8x8xf64>
+
+  %f1 = arith.constant 1.0e+00 : f64
+  %fcst = arith.constant 3.14e+00 : f64
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %c4 = arith.constant 4 : index
+  %c1 = arith.constant 1 : index
+  %c32 = arith.constant 32 : index
+
+  // Initialize the input matrices with ones.
+  scf.for %arg0 = %c0 to %c8 step %c1 {
+    scf.for %arg1 = %c0 to %c4 step %c1 {
+      memref.store %f1, %a[%arg0, %arg1] : memref<8x4xf64>
+      memref.store %f1, %b[%arg1, %arg0] : memref<4x8xf64>
+    }
+  }
+  // Initialize the accumulator matrix with a constant.
+  scf.for %arg0 = %c0 to %c8 step %c1 {
+    scf.for %arg1 = %c0 to %c8 step %c1 {
+      memref.store %fcst, %c[%arg0, %arg1] : memref<8x8xf64>
+    }
+  }
+
+  %2 = memref.cast %a : memref<8x4xf64> to memref<*xf64>
+  %20 = memref.cast %b : memref<4x8xf64> to memref<*xf64>
+  %33 = memref.cast %c : memref<8x8xf64> to memref<*xf64>
+  %34 = memref.cast %d : memref<8x8xf64> to memref<*xf64>
+
+  gpu.host_register %2 : memref<*xf64>
+  gpu.host_register %20 : memref<*xf64>
+  gpu.host_register %33 : memref<*xf64>
+
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c1, %grid_y = %c1, %grid_z = %c1)
+             threads(%tx, %ty, %tz) in (%block_x = %c32, %block_y = %c1, %block_z = %c1) {
+    %A = gpu.subgroup_mma_load_matrix %a[%c0, %c0] {leadDimension = 4 : index} : memref<8x4xf64> -> !gpu.mma_matrix<8x4xf64, "AOp">
+    %B = gpu.subgroup_mma_load_matrix %b[%c0, %c0] {leadDimension = 8 : index} : memref<4x8xf64> -> !gpu.mma_matrix<4x8xf64, "BOp">
+    %C = gpu.subgroup_mma_load_matrix %c[%c0, %c0] {leadDimension = 8 : index} : memref<8x8xf64> -> !gpu.mma_matrix<8x8xf64, "COp">
+
+    %R = gpu.subgroup_mma_compute %A, %B, %C : !gpu.mma_matrix<8x4xf64, "AOp">, !gpu.mma_matrix<4x8xf64, "BOp"> -> !gpu.mma_matrix<8x8xf64, "COp">
+
+    gpu.subgroup_mma_store_matrix %R, %d[%c0, %c0] {leadDimension = 8 : index}: !gpu.mma_matrix<8x8xf64, "COp">, memref<8x8xf64>
+    gpu.terminator
+  }
+  // Print the memref after computation.
+ call @printMemrefF64(%34) : (memref<*xf64>) -> () + // CHECK: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14], + // CHECK-NEXT: [7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14, 7.14] + return +} + +func.func private @printMemrefF64(memref<*xf64>) From 4cee36a2d674a28f741eb4f77160ce878a3cfd93 Mon Sep 17 00:00:00 2001 From: Giacomo Castiglioni Date: Fri, 31 Oct 2025 16:24:03 +0100 Subject: [PATCH 2/3] format --- mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index fb1a37a03fe4d..13bb2231b13ca 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -371,15 +371,15 @@ struct WmmaElementwiseOpToNVVMLowering cast(subgroupMmaElementwiseOp.getType())); // If the element is not a struct, it means it's a scalar f64. - LLVM::LLVMStructType structDestTy = dyn_cast(destType); + LLVM::LLVMStructType structDestTy = + dyn_cast(destType); if (!structDestTy) { SmallVector operands; for (auto operand : adaptor.getOperands()) { operands.push_back(operand); } - Value element = - createScalarOp(rewriter, loc, subgroupMmaElementwiseOp.getOpType(), - operands); + Value element = createScalarOp( + rewriter, loc, subgroupMmaElementwiseOp.getOpType(), operands); rewriter.replaceOp(subgroupMmaElementwiseOp, element); return success(); } From 094264fa3f398cee597c739c0de8d61ba00bc618 Mon Sep 17 00:00:00 2001 From: Giacomo Castiglioni Date: Fri, 31 Oct 2025 17:12:48 +0100 Subject: [PATCH 3/3] fix invalid IR test --- mlir/test/Dialect/GPU/invalid.mlir | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir index 35381dab7b200..26bcf948bc85d 100644 --- a/mlir/test/Dialect/GPU/invalid.mlir +++ b/mlir/test/Dialect/GPU/invalid.mlir @@ -688,7 +688,7 @@ func.func @mmamatrix_operand_type(){ func.func @mmamatrix_invalid_element_type(){ %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3> %i = arith.constant 16 : index - // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, or F32}} + // expected-error @+1 {{MMAMatrixType elements must be SI8, UI8, I32, F16, F32, or F64}} %0 = gpu.subgroup_mma_load_matrix %wg[%i, %i] {leadDimension = 32 : index} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xbf16, "AOp"> return } @@ -708,7 +708,7 @@ func.func @mmaLoadOp_identity_layout(){ // ----- func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) { - // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float values of ranks 1 values}} + // expected-error @+1 {{operand #0 must be memref of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or 64-bit float or vector of 8-bit signless integer or 32-bit signless integer or 16-bit float or 32-bit float or 64-bit float values of ranks 1 
values}} %0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp"> return }
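
A note on what the lowering change means for code that consumes convertMMAToLLVMType: with these patches the function no longer always returns an LLVM::LLVMStructType. For the f64 WMMA shape m8n8k4 (the only f64 shape NVVM supports, and the reason the integration test targets sm_80), each lane of the 8x4 "AOp" and 4x8 "BOp" fragments holds a single f64 element, so those fragments convert to a plain f64 rather than a one-element struct, while the 8x8 "COp" fragment still converts to a struct of two f64 values. Any pattern that unpacks fragments therefore has to branch on the returned type, as the patch does in WmmaMmaOpToNVVMLowering. The sketch below restates that dispatch in isolation, using the same builder-style API the patch already uses; it is illustrative only, and the helper name unpackFragment is made up rather than taken from the patch.

#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Illustrative sketch, not part of the patches above: flatten a converted
// MMA fragment into the individual values an NVVM wmma intrinsic expects.
static SmallVector<Value> unpackFragment(OpBuilder &b, Location loc,
                                         Value frag) {
  // f64 "AOp"/"BOp" fragments convert to a plain f64 scalar; there is
  // nothing to extract, so forward the value unchanged.
  if (!isa<LLVM::LLVMStructType>(frag.getType()))
    return {frag};
  // Every other fragment converts to an LLVM struct; extract each member.
  auto structTy = cast<LLVM::LLVMStructType>(frag.getType());
  SmallVector<Value> elems;
  for (size_t i = 0, e = structTy.getBody().size(); i < e; ++i)
    elems.push_back(LLVM::ExtractValueOp::create(b, loc, frag, i));
  return elems;
}

The wmma-ops-to-nvvm.mlir test above shows the same split from the IR side: the f64 "AOp" load lowers to an nvvm.wmma.load that returns a bare f64, whereas the existing f16/f32 fragment loads keep returning LLVM structs.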