From ac434160702c78f4508a045e375b8c1141d0c22e Mon Sep 17 00:00:00 2001
From: Guray Ozen
Date: Fri, 1 Sep 2023 09:30:56 +0200
Subject: [PATCH 1/8] [MLIR][NVGPU] Adding `nvgpu.wargroup.mma` Op for Hopper GPUs

This work introduces a new operation called `wargroup.mma` to the NVGPU
dialect of MLIR. The purpose of this operation is to facilitate
warpgroup-level matrix multiply and accumulate (WGMMA) operations on Hopper
GPUs with the sm_90a architecture.

Previously, the `nvvm.wgmma.mma_async` operation was introduced to support
warpgroup-level matrix operations in the NVVM dialect. Multiple instances of
`nvvm.wgmma.mma_async` must be issued to achieve the desired computation
shape. The new `nvgpu.wargroup.mma` operation abstracts this complexity and
provides a higher-level interface for performing warpgroup-level matrix
operations.

The `nvgpu.wargroup.mma` op does the following:
1) Corresponds to multiple `wgmma` instructions.
2) Iterates over input matrix descriptors to achieve the desired computation
   shape.
3) Groups and runs `wgmma` instructions asynchronously, and eventually waits
   for them. This is done by `wgmma.fence.aligned`,
   `wgmma.commit.group.sync.aligned`, and `wgmma.wait.group.sync.aligned`.
4) Results in fragmented matrices.

Here's an example usage of the `nvgpu.wargroup.mma` operation:
```
%wgmmaResult, %wgmmaResult2 = nvgpu.wargroup.mma %descA, %descB, %acc, group = 1 {transposeB}:
  !nvgpu.wgmma.descriptor>,
  !nvgpu.wgmma.descriptor>,
  vector<128x128xf32> -> !nvgpu.warpgroup.result, !nvgpu.warpgroup.result>
```

Differential Revision: https://reviews.llvm.org/D158434
---
 mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td   |  48 +++++
 .../mlir/Dialect/NVGPU/IR/NVGPUDialect.h      |   2 +
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 169 +++++++++++++++++-
 mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp    | 108 ++++++++++-
 .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir |  66 ++++++-
 mlir/test/Dialect/NVGPU/invalid.mlir          |  61 +++++++
 6 files changed, 446 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index a3245bf9196ee..f891aae136eba 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -192,6 +192,15 @@ def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "w
   let assemblyFormat = "`<` struct(params) `>`";
 }

+def NVGPU_WarpgroupResult : NVGPU_Type<"WarpgroupResult", "warpgroup.result", []> {
+  let parameters = (ins "Type":$tensor);
+  let assemblyFormat = "`<` struct(params) `>`";
+  let description = [{
+    It is fragmented result matrix from `nvgpu.wargroup.mma`.
+    [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPU Op Definitions
 //===----------------------------------------------------------------------===//
@@ -664,5 +673,44 @@ def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
   let hasVerifier = 1;
 }

+def NVGPU_WarpgroupMmaOp : NVGPU_Op<"wargroup.mma"> {
+  let description = [{
+    The `nvgpu.wargroup.mma` op performs the warpgroup-level (4 warps)
+    matrix-multiply-and-accumulate (mma) operation that results in
+    `nvvm.wgmma.mma_async`.
+
+    The operands are `descriptorA` and `descriptorB` that are wgmma matrix
+    descriptors that shows the properties of the matrix in shared memory. The
+    results are thread-level ownership to the warpgroup-level mma operation
+    shape.
The shape is deduced from the descriptor types and output vector. + + The Op corresponds multiple `nvvm.wgmma.mma_async` operations to complete the + given shape. As the the instruction `nvvm.wgmma.async` is an asyncronous, + this Op groups the `nvvm.wgmma.async` and surrounds them between + `wgmma.fence.aligned` and `wgmma.commit.group.sync.aligned`, + `wgmma.wait.group.sync.aligned` Ops. + + Example: + ```mlir + %res = nvgpu.wargroup.mma %wgmmaDescA, %wgmmaDescB, %acc: + !nvgpu.wgmma.descriptor>, + !nvgpu.wgmma.descriptor>, + vector<128x128xf32> -> !nvgpu.warpgroup.result + ``` + }]; + + let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA, + NVGPU_WarpgroupMatrixDescriptor:$descriptorB, + AnyVector:$matrixC, + DefaultValuedOptionalAttr:$waitGroup, + OptionalAttr:$transposeA, + OptionalAttr:$transposeB); + let results = (outs Variadic:$matrixD); + let assemblyFormat = [{ + $descriptorA`,` $descriptorB`,` $matrixC (`,` `group` `=` $waitGroup^ )? attr-dict + `:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD) + }]; + let hasVerifier = 1; +} #endif // NVGPU diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h index 192afcb2dba79..96af26842dafe 100644 --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h @@ -21,6 +21,8 @@ #include "mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc" +constexpr int kWarpSize = 32; + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc" diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index b045089244ff1..90d138bd206e0 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" @@ -34,6 +35,10 @@ namespace mlir { using namespace mlir; +/// Number of bits that needs to excluded when building matrix descriptor for +/// wgmma operations. +constexpr int exclude4LSB = 4; + /// GPU has 32 bit registers, this function truncates values when larger width /// is not needed. 
static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc, @@ -984,10 +989,9 @@ struct NVGPUGenerateGmmaDescriptorLowering shiftLeft(val, startBit)); }; - int ex4LSB = 4; int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0); - uint64_t strideDimVal = (layout << 3) >> ex4LSB; - uint64_t leadDimVal = (sizeN * layout) >> ex4LSB; + uint64_t strideDimVal = (layout << 3) >> exclude4LSB; + uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB; uint64_t offsetVal = 0; Value strideDim = makeConst(strideDimVal); @@ -1141,6 +1145,164 @@ struct NVGPUTmaCreateDescriptorOpLowering } }; +struct NVGPUWarpgroupMmaOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType, + int &wgmmaShapeM, int &wgmmaShapeN, + int &wgmmaShapeK) const { + wgmmaShapeM = 64; + wgmmaShapeN = sizeN; + if (inputElemType.isTF32()) { + wgmmaShapeK = 8; + } else if (inputElemType.isF16() || inputElemType.isBF16()) { + wgmmaShapeK = 16; + } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() || + inputElemType.isInteger(16)) { + wgmmaShapeK = 32; + } else if (inputElemType.isInteger(1)) { + wgmmaShapeK = 256; + } else { + return failure(); + } + LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM + << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK + << "]\n"); + return success(); + } + + Value generateNVVMWgmmaOp(MLIRContext *ctx, + ConversionPatternRewriter &rewriter, Location loc, + int m, int n, int k, Type resultStructType, + Value inout, Value descriptorA, + Value descriptorB) const { + TypeRange resultTypes = {resultStructType}; + auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k); + auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one); + auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one); + auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row); + auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col); + // todo input type + auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16); + auto overflow = + NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped); + Value res = rewriter.create( + loc, resultTypes, inout, descriptorA, descriptorB, shape, itype, itype, + scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow); + return res; + } + + static Type buildOutputStructType(MLIRContext *ctx, Type outElemType, + int sizeN) { + int outputElements = 0; + if (outElemType.isF32() || outElemType.isInteger(32)) + outputElements = sizeN / 2; + if (outElemType.isF16()) + outputElements = sizeN / 4; + SmallVector structBody; + for (int i = 0; i < outputElements; i++) + structBody.push_back(outElemType); + return LLVM::LLVMStructType::getLiteral(ctx, structBody); + } + + LogicalResult + matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector wgmmaResults; + + int64_t sizeM = op.getMatrixC().getType().getDimSize(0); + int64_t sizeN = op.getMatrixC().getType().getDimSize(1); + int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1); + + LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A[" + << sizeM << "][" << sizeK << "] * B[" << sizeK << "][" + << sizeN << "] ---===\n"); + + int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK; + if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM, + wgmmaShapeN, wgmmaShapeK))) { + return 
failure(); + } + + Value descriptorA = adaptor.getDescriptorA(); + Value descriptorB = adaptor.getDescriptorB(); + + // Generate wgmma group + + auto loc = op->getLoc(); + Type outElemType = op.getMatrixC().getType().getElementType(); + Type stype = buildOutputStructType(op->getContext(), outElemType, sizeN); + MemRefType typeTensorA = op.getDescriptorA().getType().getTensor(); + MemRefType typeTensorB = op.getDescriptorB().getType().getTensor(); + + auto makeAdd = [&](Value lhs, Value rhs) -> Value { + return rewriter.create(loc, lhs.getType(), lhs, rhs); + }; + + auto iterateDescA = [&](Value desc, int iterM, int iterN, + int iterK) -> Value { + // todo : Handle column major + int byte = typeTensorA.getElementTypeBitWidth() / 8; + int tileShapeA = typeTensorA.getDimSize(1); + int incrementVal = + ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte; + incrementVal = incrementVal >> exclude4LSB; + LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: " + << iterK << "] [wgmma descriptors] Descriptor A + " + << incrementVal << " | \t "); + return incrementVal + ? makeAdd(desc, makeI64Const(rewriter, op, incrementVal)) + : desc; + }; + + auto iterateDescB = [&](Value desc, int iterM, int iterN, + int iterK) -> Value { + // todo : Handle row major + int byte = typeTensorB.getElementTypeBitWidth() / 8; + int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte; + incrementVal = incrementVal >> exclude4LSB; + LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n"); + return incrementVal + ? makeAdd(desc, makeI64Const(rewriter, op, incrementVal)) + : desc; + }; + + rewriter.create(loc); + for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) { + Value undefOp = rewriter.create(loc, stype); + Value inout = undefOp; + LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":" + << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0 + << ":" << wgmmaShapeN << "] += \n"); + for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) { + Value descA = iterateDescA(descriptorA, iterM, 0, iterK); + Value descB = iterateDescB(descriptorB, iterM, 0, iterK); + LLVM_DEBUG(DBGS() << "\t wgmma." 
+ << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k" + << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM) + << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" + << (iterK * wgmmaShapeK) << ":" + << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * " + << " B[" << (iterK * wgmmaShapeK) << ":" + << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0 + << ":" << wgmmaShapeN << "])\n"); + inout = generateNVVMWgmmaOp(op->getContext(), rewriter, loc, + wgmmaShapeM, wgmmaShapeN, wgmmaShapeK, + stype, inout, descA, descB); + } + wgmmaResults.push_back(inout); + } + + rewriter.create(loc); + rewriter.create(loc, op.getWaitGroup()); + + ValueRange myres(wgmmaResults); + rewriter.replaceOp(op, myres); + return success(); + } +}; + } // namespace void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, @@ -1156,6 +1318,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor + NVGPUWarpgroupMmaOpLowering, // nvgpu.wargroup.mma MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering, NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering, NVGPUMmaSparseSyncLowering>(converter); diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index d832a983a132d..cd0d65ddd9a65 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -151,7 +151,6 @@ static LogicalResult verifyMmaSyncOp(Operation *op, // - For F32 (TF32), F16, S8, and S4 data // types the fundamental tensor core operation is of shape 8-by-8-by-128b. // - F64 is an exception and is of shape 8-by-8-by-256b. - constexpr int kThreads = 32; // 32 threads per warp int64_t shapeM = 8; int64_t shapeN = 8; int64_t shapeK; // set based on data type (128b for all data types except F64) @@ -206,17 +205,17 @@ static LogicalResult verifyMmaSyncOp(Operation *op, // verify warp-wide size for vector a int64_t sparseFactor = sparse ? 
2 : 1; - if (aShape[0] * aShape[1] * kThreads != m * k / sparseFactor) + if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor) return op->emitOpError() << "expected " << m * k << " warp-wide matrix A elements"; // verify warp-wide size for vector b - if (bShape[0] * bShape[1] * kThreads != k * n) + if (bShape[0] * bShape[1] * kWarpSize != k * n) return op->emitOpError() << "expected " << k * n << " warp-wide matrix B elements"; // verify warp-wide size for vector c - if (cShape[0] * cShape[1] * kThreads != m * n) + if (cShape[0] * cShape[1] * kWarpSize != m * n) return op->emitOpError() << "expected " << m * n << " warp-wide matrix C elements"; @@ -402,6 +401,107 @@ LogicalResult GenerateGmmaDescriptorOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// WarpgroupMmaOp +//===----------------------------------------------------------------------===// + +LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) { + // F32 += F16 + F16 + // F16 += F16 + F16 + if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16())) + return success(); + // F32 += TF32 + TF32 + if (typeA.isTF32() && typeD.isF32() && typeB.isTF32()) + return success(); + // s32 += i8 + i8 + if (typeA.isInteger(16) && typeB.isInteger(16) && typeD.isInteger(32)) + return success(); + // s32 += i1 + i1 + if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32)) + return success(); + // F32 += BF16 + BF16 + // F16 += BF16 + BF16 + if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16())) + return success(); + // F16 += f8 + f8 + // F32 += f8 + f8 + if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) && + (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) && + (typeD.isF32() || typeD.isF16())) + return success(); + + return failure(); +} + +LogicalResult isAllowedSizeN(int sizeN, Type typeA) { + SmallVector allowedN = {8, 16, 24, 32, 40, 48, 56, 64, + 72, 80, 88, 96, 104, 112, 120, 128, + 136, 144, 152, 160, 168, 176, 184, 192, + 200, 208, 216, 224, 232, 240, 248, 256}; + SmallVector allowedNshort = {8, 16, 24, 32, 48, 64, + 80, 96, 112, 128, 144, 160, + 176, 192, 208, 224, 240, 256}; + if (typeA.isBF16() || typeA.isF16() || typeA.isTF32() || + typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2()) + if (llvm::any_of(allowedN, [&](int n) { return sizeN == n; })) + return success(); + + if (typeA.isInteger(8) || typeA.isInteger(1)) + if (llvm::any_of(allowedNshort, [&](int n) { return sizeN == n; })) + return success(); + return failure(); +} + +LogicalResult WarpgroupMmaOp::verify() { + if (getTransposeA() && !getTransposeB()) + return emitOpError() << "supports non-transpose A (Row Major) " + "and transpose B (Column Major) for the time being"; + auto matrixA = getDescriptorA().getType().getTensor(); + auto matrixB = getDescriptorB().getType().getTensor(); + auto matrixC = getMatrixC().getType(); + if (matrixA.getRank() != 2 || matrixB.getRank() != 2 || + matrixC.getRank() != 2) + return emitOpError() + << "has input matrices A, B and D, they must be 2 dimensional"; + + if (matrixA.getShape()[1] != matrixB.getShape()[0]) + return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1] + << ")!= 1st dim matrix-B (" << matrixB.getShape()[0] + << " )"; + if (matrixA.getShape()[0] != matrixC.getShape()[0]) + return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0] + << " )!= 1st dim matrix-C ( " << matrixC.getShape()[0] + << " )"; + if (matrixB.getShape()[1] != matrixC.getShape()[1]) + return 
emitOpError() << "2nd dim matrix-B ( " << matrixB.getShape()[1] + << " ) != 2nd dim matrix-C ( " << matrixC.getShape()[1] + << " )"; + + if (failed(isAllowedWGMMADataType(matrixC.getElementType(), + matrixA.getElementType(), + matrixB.getElementType()))) + return emitOpError() << matrixC.getElementType() + << " += " << matrixA.getElementType() << " * " + << matrixB.getElementType() + << ", it is not supported."; + // Check N + if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) { + return emitOpError() << "has input type " << matrixB << " n is set to " + << matrixB.getDimSize(1) << ", it is not supported"; + } + + // Currently, f16/bf16 supported + if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() && + !matrixA.getElementType().isBF16()) { + return emitOpError() << "hit a limitation: " << matrixC.getElementType() + << " += " << matrixA.getElementType() << " * " + << matrixB.getElementType() + << ", it is not supported yet"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd dialect, type, and op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index 0d7ace52ccb36..cafeb785e31ff 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -672,6 +672,70 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.desc func.return %descA : !nvgpu.wgmma.descriptor> } +!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32)> + +// CHECK-LABEL: @warpgroup_mma_128_128_64( +// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>) +func.func @warpgroup_mma_128_128_64( + %descA: !nvgpu.wgmma.descriptor>, + %descB: !nvgpu.wgmma.descriptor>, + %D: memref<128x128xf32,3>) +{ +// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %arg0 : !nvgpu.wgmma.descriptor> to i64 +// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %arg1 : !nvgpu.wgmma.descriptor> to i64 +// CHECK: nvvm.wgmma.fence.aligned +// CHECK: %[[S3:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], , D[%3, , ], A[, , ], B[, , ] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i64 +// CHECK: %[[S6:.+]] = llvm.add %[[S0]], %[[S5]] : i64 +// CHECK: %[[S7:.+]] = llvm.mlir.constant(128 : i32) : i64 +// CHECK: %[[S8:.+]] = llvm.add %[[S1]], %[[S7]] : i64 +// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], , D[%[[S4]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S10:.+]] = llvm.mlir.constant(4 : i32) : i64 +// CHECK: %[[S11:.+]] = llvm.add %[[S0]], %[[S10]] : i64 +// CHECK: %[[S12:.+]] = llvm.mlir.constant(256 : i32) : i64 +// CHECK: %[[S13:.+]] = llvm.add %[[S1]], %[[S12]] : i64 +// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], , D[%[[S9]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S15:.+]] = llvm.mlir.constant(6 : i32) : i64 +// CHECK: %[[S16:.+]] = llvm.add %[[S0]], %[[S15]] : i64 +// CHECK: %[[S17:.+]] = llvm.mlir.constant(384 : i32) : i64 +// CHECK: %[[S18:.+]] = llvm.add %[[S1]], %[[S17]] : i64 +// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], , D[%[[S14]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S20:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S21:.+]] = llvm.mlir.constant(512 : i32) : i64 +// CHECK: %[[S22:.+]] = llvm.add %[[S0]], %[[S21]] : i64 +// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], , D[%[[S20]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S24:.+]] = llvm.mlir.constant(514 : i32) : i64 +// CHECK: %[[S25:.+]] = llvm.add %[[S0]], %[[S24]] : i64 +// CHECK: %[[S26:.+]] = llvm.mlir.constant(128 : i32) : i64 +// CHECK: %[[S27:.+]] = llvm.add %[[S1]], %[[S26]] : i64 +// CHECK: %[[S28:.+]] = nvvm.wgmma.mma_async %[[S25]], %[[S27]], , D[%[[S23]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S29:.+]] = llvm.mlir.constant(516 : i32) : i64 +// CHECK: %[[S30:.+]] = llvm.add %[[S0]], %[[S29]] : i64 +// CHECK: %[[S31:.+]] = llvm.mlir.constant(256 : i32) : i64 +// CHECK: %[[S32:.+]] = llvm.add %[[S1]], %[[S31]] : i64 +// CHECK: %[[S33:.+]] = nvvm.wgmma.mma_async %[[S30]], %[[S32]], , D[%[[S28]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S34:.+]] = llvm.mlir.constant(518 : i32) : i64 +// CHECK: %[[S35:.+]] = llvm.add %[[S0]], %[[S34]] : i64 +// CHECK: %[[S36:.+]] = llvm.mlir.constant(384 : i32) : i64 +// CHECK: %[[S37:.+]] = llvm.add %[[S1]], %[[S36]] : i64 +// CHECK: %[[S38:.+]] = nvvm.wgmma.mma_async %[[S35]], %[[S37]], , D[%[[S33]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: nvvm.wgmma.commit.group.sync.aligned +// CHECK: nvvm.wgmma.wait.group.sync.aligned 1 + %c0 = arith.constant 0 : index + %f0 = arith.constant 0.0 : f32 + %acc = vector.transfer_read %D[%c0, %c0], %f0 {in_bounds = [true, true]} : memref<128x128xf32,3>, vector<128x128xf32> + %wgmmaResult, %wgmmaResult2 = nvgpu.wargroup.mma %descA, %descB, %acc, group = 1 {transposeB}: + !nvgpu.wgmma.descriptor>, + !nvgpu.wgmma.descriptor>, + vector<128x128xf32> -> !nvgpu.warpgroup.result, !nvgpu.warpgroup.result + + return +} + transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match 
ops{["func.func"]} in %arg1 @@ -681,5 +745,5 @@ transform.sequence failures(propagate) { } with type_converter { transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter {use_opaque_pointers = true} - } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "scf"], partial_conversion} : !transform.any_op + } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "vector", "scf"], partial_conversion} : !transform.any_op } \ No newline at end of file diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir index ef721b1801407..d7af22085c10b 100644 --- a/mlir/test/Dialect/NVGPU/invalid.mlir +++ b/mlir/test/Dialect/NVGPU/invalid.mlir @@ -221,3 +221,64 @@ func.func @async_cp_size_invalid_f64( %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 3: memref<128x128xf64> to memref<3x16x128xf64, 3> return } + +// ----- + +!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32)> +!tResult = !nvgpu.warpgroup.result +!tDescA = !nvgpu.wgmma.descriptor> +!tDescB = !nvgpu.wgmma.descriptor> + +func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) { + // expected-error @+1 {{'nvgpu.wargroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}} + %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult + return +} + +// ----- + +!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32)> +!tResult = !nvgpu.warpgroup.result +!tDescA = !nvgpu.wgmma.descriptor> +!tDescB = !nvgpu.wgmma.descriptor> +func.func @warpgroup_mma_wrong_accumulator(%descA: !tDescA, %descB: !tDescB, %D: vector<128xf32>) { + // expected-error @+1 {{'nvgpu.wargroup.mma' op has input matrices A, B and D, they must be 2 dimensional}} + %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128xf32> -> !tResult, !tResult + return +} + +// ----- + +!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32)> +!tResult = !nvgpu.warpgroup.result +!tDescA = !nvgpu.wgmma.descriptor> +!tDescB = !nvgpu.wgmma.descriptor> +func.func @warpgroup_mma_wrong_datatypes(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) { + // expected-error @+1 {{'nvgpu.wargroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}} + %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult + return +} + +// ----- + +!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, 
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, + f32, f32, f32, f32, f32, f32)> +!tResult = !nvgpu.warpgroup.result +!tDescA = !nvgpu.wgmma.descriptor> +!tDescB = !nvgpu.wgmma.descriptor> +func.func @warpgroup_mma_wrong_large_shape(%descA: !tDescA, %descB: !tDescB, %D: vector<128x512xf32>) { + // expected-error @+1 {{'nvgpu.wargroup.mma' op has input type 'memref<64x512xf16, 3>' n is set to 512, it is not supported}} + %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x512xf32> -> !tResult, !tResult + return +} From b1c92c26db2ddded92fad3b44724179c5daddb30 Mon Sep 17 00:00:00 2001 From: Guray Ozen Date: Thu, 7 Sep 2023 11:13:05 +0200 Subject: [PATCH 2/8] Include WGMMA descriptor type in transform dialect --- mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index a173317bbbdb3..d13f640147c52 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -64,6 +64,11 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns( return llvmTypeConverter.convertType( getMBarrierMemrefType(type.getContext(), type)); }); + llvmTypeConverter.addConversion( + [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type { + return llvmTypeConverter.convertType( + IntegerType::get(type.getContext(), 64)); + }); llvmTypeConverter.addConversion( [&](nvgpu::TensorMapDescriptorType type) -> Type { return llvmTypeConverter.getPointerType( From f95617331f7a3c9759a4010ffeef5d8d1a063747 Mon Sep 17 00:00:00 2001 From: Guray Ozen Date: Thu, 7 Sep 2023 11:13:40 +0200 Subject: [PATCH 3/8] wargroup -> warpgroup --- mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 8 ++++---- mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 1 - .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 2 +- mlir/test/Dialect/NVGPU/invalid.mlir | 16 ++++++++-------- 4 files changed, 13 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td index f891aae136eba..060fe656d32e8 100644 --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -196,7 +196,7 @@ def NVGPU_WarpgroupResult : NVGPU_Type<"WarpgroupResult", "warpgroup.result", [] let parameters = (ins "Type":$tensor); let assemblyFormat = "`<` struct(params) `>`"; let description = [{ - It is fragmented result matrix from `nvgpu.wargroup.mma`. + It is fragmented result matrix from `nvgpu.warpgroup.mma`. [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) }]; } @@ -673,9 +673,9 @@ def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> { let hasVerifier = 1; } -def NVGPU_WarpgroupMmaOp : NVGPU_Op<"wargroup.mma"> { +def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> { let description = [{ - The `nvgpu.wargroup.mma` op performs the warpgroup-level (4 warps) + The `nvgpu.warpgroup.mma` op performs the warpgroup-level (4 warps) matrix-multiply-and-accumulate (mma) operation that results in `nvvm.wgmma.mma_async`. 
@@ -692,7 +692,7 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"wargroup.mma"> { Example: ```mlir - %res = nvgpu.wargroup.mma %wgmmaDescA, %wgmmaDescB, %acc: + %res = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc: !nvgpu.wgmma.descriptor>, !nvgpu.wgmma.descriptor>, vector<128x128xf32> -> !nvgpu.warpgroup.result diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 90d138bd206e0..c8d91e7c5893a 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -17,7 +17,6 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index cafeb785e31ff..fdd6cbc519b6a 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -728,7 +728,7 @@ func.func @warpgroup_mma_128_128_64( %c0 = arith.constant 0 : index %f0 = arith.constant 0.0 : f32 %acc = vector.transfer_read %D[%c0, %c0], %f0 {in_bounds = [true, true]} : memref<128x128xf32,3>, vector<128x128xf32> - %wgmmaResult, %wgmmaResult2 = nvgpu.wargroup.mma %descA, %descB, %acc, group = 1 {transposeB}: + %wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc, group = 1 {transposeB}: !nvgpu.wgmma.descriptor>, !nvgpu.wgmma.descriptor>, vector<128x128xf32> -> !nvgpu.warpgroup.result, !nvgpu.warpgroup.result diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir index d7af22085c10b..a915f7f3b8095 100644 --- a/mlir/test/Dialect/NVGPU/invalid.mlir +++ b/mlir/test/Dialect/NVGPU/invalid.mlir @@ -233,8 +233,8 @@ func.func @async_cp_size_invalid_f64( !tDescB = !nvgpu.wgmma.descriptor> func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) { - // expected-error @+1 {{'nvgpu.wargroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}} - %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult + // expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}} + %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult return } @@ -248,8 +248,8 @@ func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %D: vecto !tDescA = !nvgpu.wgmma.descriptor> !tDescB = !nvgpu.wgmma.descriptor> func.func @warpgroup_mma_wrong_accumulator(%descA: !tDescA, %descB: !tDescB, %D: vector<128xf32>) { - // expected-error @+1 {{'nvgpu.wargroup.mma' op has input matrices A, B and D, they must be 2 dimensional}} - %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128xf32> -> !tResult, !tResult + // expected-error @+1 {{'nvgpu.warpgroup.mma' op has input matrices A, B and D, they must be 2 dimensional}} + %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128xf32> -> !tResult, !tResult return } @@ -263,8 +263,8 @@ func.func @warpgroup_mma_wrong_accumulator(%descA: !tDescA, %descB: !tDescB, %D: !tDescA = !nvgpu.wgmma.descriptor> !tDescB = !nvgpu.wgmma.descriptor> func.func @warpgroup_mma_wrong_datatypes(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) { - // expected-error @+1 
{{'nvgpu.wargroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}} - %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult + // expected-error @+1 {{'nvgpu.warpgroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}} + %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult return } @@ -278,7 +278,7 @@ func.func @warpgroup_mma_wrong_datatypes(%descA: !tDescA, %descB: !tDescB, %D: v !tDescA = !nvgpu.wgmma.descriptor> !tDescB = !nvgpu.wgmma.descriptor> func.func @warpgroup_mma_wrong_large_shape(%descA: !tDescA, %descB: !tDescB, %D: vector<128x512xf32>) { - // expected-error @+1 {{'nvgpu.wargroup.mma' op has input type 'memref<64x512xf16, 3>' n is set to 512, it is not supported}} - %0:2 = nvgpu.wargroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x512xf32> -> !tResult, !tResult + // expected-error @+1 {{'nvgpu.warpgroup.mma' op has input type 'memref<64x512xf16, 3>' n is set to 512, it is not supported}} + %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x512xf32> -> !tResult, !tResult return } From 8e1f698183a95d95f74c7d998dcf7bf032db4786 Mon Sep 17 00:00:00 2001 From: Guray Ozen Date: Wed, 13 Sep 2023 16:05:05 +0200 Subject: [PATCH 4/8] Improve accumulator matrix type --- mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td | 36 ++++++---- .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 66 +++++++++---------- mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 39 +++++++++-- .../NVGPU/TransformOps/NVGPUTransformOps.cpp | 10 +++ .../Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 36 +++++----- mlir/test/Dialect/NVGPU/invalid.mlir | 45 ++++--------- 6 files changed, 127 insertions(+), 105 deletions(-) diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td index 060fe656d32e8..90381648dac6a 100644 --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -192,12 +192,16 @@ def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "w let assemblyFormat = "`<` struct(params) `>`"; } -def NVGPU_WarpgroupResult : NVGPU_Type<"WarpgroupResult", "warpgroup.result", []> { - let parameters = (ins "Type":$tensor); +def NVGPU_WarpgroupAccumulator : NVGPU_Type<"WarpgroupAccumulator", "warpgroup.accumulator", []> { + let parameters = (ins "VectorType":$fragmented); let assemblyFormat = "`<` struct(params) `>`"; let description = [{ - It is fragmented result matrix from `nvgpu.warpgroup.mma`. - [See the details of register fragment layout for accumulator matrix D](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) + This type represents the result matrix obtained from `nvgpu.warpgroup.mma`. + The `$fragmented` type signifies the distributed or fragmented result + vector that is collectively owned by all the threads in the warp-group + that executed `nvgpu.warpgroup.mma`. + [See the details of register fragment layout for accumulator matrix D] + (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) }]; } @@ -685,29 +689,33 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> { shape. The shape is deduced from the descriptor types and output vector. The Op corresponds multiple `nvvm.wgmma.mma_async` operations to complete the - given shape. As the the instruction `nvvm.wgmma.async` is an asyncronous, + given shape. 
As the instruction `nvvm.wgmma.async` is an asynchronous, this Op groups the `nvvm.wgmma.async` and surrounds them between `wgmma.fence.aligned` and `wgmma.commit.group.sync.aligned`, `wgmma.wait.group.sync.aligned` Ops. Example: ```mlir - %res = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc: - !nvgpu.wgmma.descriptor>, - !nvgpu.wgmma.descriptor>, - vector<128x128xf32> -> !nvgpu.warpgroup.result + %r1,%r2 = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc1, %acc2: + !nvgpu.wgmma.descriptor>, + !nvgpu.wgmma.descriptor>, + !nvgpu.warpgroup.accumulator>, + !nvgpu.warpgroup.accumulator> + -> + !nvgpu.warpgroup.accumulator>, + !nvgpu.warpgroup.accumulator> ``` }]; let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA, - NVGPU_WarpgroupMatrixDescriptor:$descriptorB, - AnyVector:$matrixC, + NVGPU_WarpgroupMatrixDescriptor:$descriptorB, DefaultValuedOptionalAttr:$waitGroup, OptionalAttr:$transposeA, - OptionalAttr:$transposeB); - let results = (outs Variadic:$matrixD); + OptionalAttr:$transposeB, + Variadic:$matrixC); + let results = (outs Variadic:$matrixD); let assemblyFormat = [{ - $descriptorA`,` $descriptorB`,` $matrixC (`,` `group` `=` $waitGroup^ )? attr-dict + $descriptorA`,` $descriptorB`,` $matrixC attr-dict `:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD) }]; let hasVerifier = 1; diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index c8d91e7c5893a..046727e4ea9ab 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -17,10 +17,12 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "nvgpu-to-nvvm" @@ -423,6 +425,15 @@ struct ConvertNVGPUToNVVMPass converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type { return converter.convertType(IntegerType::get(type.getContext(), 32)); }); + converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type { + VectorType vtype = type.getFragmented(); + SmallVector structBody; + for (unsigned i = 0; i < vtype.getDimSize(0); i++) + structBody.push_back(vtype.getElementType()); + auto convertedType = + LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); + return converter.convertType(convertedType); + }); converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type { return converter.convertType(IntegerType::get(type.getContext(), 64)); }); @@ -442,6 +453,8 @@ struct ConvertNVGPUToNVVMPass target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); target.addLegalDialect<::mlir::memref::MemRefDialect>(); target.addLegalDialect<::mlir::NVVM::NVVMDialect>(); + mlir::scf::populateSCFStructuralTypeConversionsAndLegality( + converter, patterns, target); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); @@ -1163,7 +1176,7 @@ struct NVGPUWarpgroupMmaOpLowering } else if (inputElemType.isInteger(1)) { wgmmaShapeK = 256; } else { - return failure(); + llvm_unreachable("msg: not supported K shape"); } LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK @@ -1192,26 
+1205,11 @@ struct NVGPUWarpgroupMmaOpLowering return res; } - static Type buildOutputStructType(MLIRContext *ctx, Type outElemType, - int sizeN) { - int outputElements = 0; - if (outElemType.isF32() || outElemType.isInteger(32)) - outputElements = sizeN / 2; - if (outElemType.isF16()) - outputElements = sizeN / 4; - SmallVector structBody; - for (int i = 0; i < outputElements; i++) - structBody.push_back(outElemType); - return LLVM::LLVMStructType::getLiteral(ctx, structBody); - } - LogicalResult matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector wgmmaResults; - - int64_t sizeM = op.getMatrixC().getType().getDimSize(0); - int64_t sizeN = op.getMatrixC().getType().getDimSize(1); + int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0); + int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1); int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1); LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A[" @@ -1230,8 +1228,6 @@ struct NVGPUWarpgroupMmaOpLowering // Generate wgmma group auto loc = op->getLoc(); - Type outElemType = op.getMatrixC().getType().getElementType(); - Type stype = buildOutputStructType(op->getContext(), outElemType, sizeN); MemRefType typeTensorA = op.getDescriptorA().getType().getTensor(); MemRefType typeTensorB = op.getDescriptorB().getType().getTensor(); @@ -1250,9 +1246,9 @@ struct NVGPUWarpgroupMmaOpLowering LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: " << iterK << "] [wgmma descriptors] Descriptor A + " << incrementVal << " | \t "); - return incrementVal - ? makeAdd(desc, makeI64Const(rewriter, op, incrementVal)) - : desc; + if (!incrementVal) + return desc; + return makeAdd(desc, makeI64Const(rewriter, op, incrementVal)); }; auto iterateDescB = [&](Value desc, int iterM, int iterN, @@ -1262,15 +1258,18 @@ struct NVGPUWarpgroupMmaOpLowering int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte; incrementVal = incrementVal >> exclude4LSB; LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n"); - return incrementVal - ? 
makeAdd(desc, makeI64Const(rewriter, op, incrementVal)) - : desc; + if (!incrementVal) + return desc; + return makeAdd(desc, makeI64Const(rewriter, op, incrementVal)); }; rewriter.create(loc); + + SmallVector wgmmaResults; for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) { - Value undefOp = rewriter.create(loc, stype); - Value inout = undefOp; + Value matrixC = adaptor.getMatrixC()[iterM]; + Value matrixD = op.getMatrixD()[iterM]; + Type structType = getTypeConverter()->convertType(matrixD.getType()); LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0 << ":" << wgmmaShapeN << "] += \n"); @@ -1286,13 +1285,12 @@ struct NVGPUWarpgroupMmaOpLowering << " B[" << (iterK * wgmmaShapeK) << ":" << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0 << ":" << wgmmaShapeN << "])\n"); - inout = generateNVVMWgmmaOp(op->getContext(), rewriter, loc, - wgmmaShapeM, wgmmaShapeN, wgmmaShapeK, - stype, inout, descA, descB); + matrixC = generateNVVMWgmmaOp(op->getContext(), rewriter, loc, + wgmmaShapeM, wgmmaShapeN, wgmmaShapeK, + structType, matrixC, descA, descB); } - wgmmaResults.push_back(inout); + wgmmaResults.push_back(matrixC); } - rewriter.create(loc); rewriter.create(loc, op.getWaitGroup()); @@ -1317,7 +1315,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor - NVGPUWarpgroupMmaOpLowering, // nvgpu.wargroup.mma + NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering, NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering, NVGPUMmaSparseSyncLowering>(converter); diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index cd0d65ddd9a65..d96ed69982870 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Verifier.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" @@ -456,19 +457,45 @@ LogicalResult WarpgroupMmaOp::verify() { if (getTransposeA() && !getTransposeB()) return emitOpError() << "supports non-transpose A (Row Major) " "and transpose B (Column Major) for the time being"; - auto matrixA = getDescriptorA().getType().getTensor(); - auto matrixB = getDescriptorB().getType().getTensor(); - auto matrixC = getMatrixC().getType(); + MemRefType matrixA = getDescriptorA().getType().getTensor(); + MemRefType matrixB = getDescriptorB().getType().getTensor(); + VectorType matrixC = getMatrixC() + .front() + .getType() + .cast() + .getFragmented(); + VectorType matrixD = getMatrixD() + .front() + .getType() + .cast() + .getFragmented(); + unsigned sizeAcc = getMatrixC().size(); + + if (getMatrixC().size() != getMatrixD().size()) + return emitOpError() << "number of matrix C and matrix D must be the same"; + + if (llvm::all_of(getMatrixC(), + [&](Value rhs) { return rhs.getType() == matrixC; })) { + return emitOpError() + << "types of all operands in matrix C must be the same"; + } + if (llvm::all_of(getMatrixD(), + [&](Value rhs) { return rhs.getType() == matrixC; })) { + return emitOpError() + << "types of all operands in matrix D must be the same as matrix C"; + } + if (matrixA.getRank() != 2 
|| matrixB.getRank() != 2 || - matrixC.getRank() != 2) + matrixC.getRank() != 2 || matrixD.getRank() != 2) { return emitOpError() - << "has input matrices A, B and D, they must be 2 dimensional"; + << "has matrices A, B, C and D, they must be 2 dimensional"; + } if (matrixA.getShape()[1] != matrixB.getShape()[0]) return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1] << ")!= 1st dim matrix-B (" << matrixB.getShape()[0] << " )"; - if (matrixA.getShape()[0] != matrixC.getShape()[0]) + if (matrixA.getShape()[0] != (matrixC.getShape()[0] * sizeAcc)) return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0] << " )!= 1st dim matrix-C ( " << matrixC.getShape()[0] << " )"; diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index d13f640147c52..680c21ab74fe0 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -60,6 +60,16 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns( return llvmTypeConverter.convertType( IntegerType::get(type.getContext(), 64)); }); + llvmTypeConverter.addConversion( + [&](nvgpu::WarpgroupAccumulatorType type) -> Type { + VectorType vtype = type.getFragmented(); + SmallVector structBody; + for (unsigned i = 0; i < vtype.getDimSize(0); i++) + structBody.push_back(vtype.getElementType()); + auto convertedType = + LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); + return llvmTypeConverter.convertType(convertedType); + }); llvmTypeConverter.addConversion([&](nvgpu::MBarrierType type) -> Type { return llvmTypeConverter.convertType( getMBarrierMemrefType(type.getContext(), type)); diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index fdd6cbc519b6a..b7aa0c7382d80 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -672,23 +672,20 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.desc func.return %descA : !nvgpu.wgmma.descriptor> } -!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, - f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, - f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, - f32, f32, f32, f32, f32, f32)> - // CHECK-LABEL: @warpgroup_mma_128_128_64( -// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor>, %[[arg2:[a-zA-Z0-9_]+]]: memref<128x128xf32, 3>) +// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator>) func.func @warpgroup_mma_128_128_64( %descA: !nvgpu.wgmma.descriptor>, %descB: !nvgpu.wgmma.descriptor>, - %D: memref<128x128xf32,3>) + %acc1: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, + %acc2: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>) { -// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %arg0 : !nvgpu.wgmma.descriptor> to i64 -// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %arg1 : !nvgpu.wgmma.descriptor> to i64 +// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.wgmma.descriptor> 
to i64 +// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.wgmma.descriptor> to i64 +// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[arg3]] : !nvgpu.warpgroup.accumulator> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: nvvm.wgmma.fence.aligned -// CHECK: %[[S3:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], , D[%3, , ], A[, , ], B[, , ] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], , D[%[[S2]], , ], A[, , ], B[, , ] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> // CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i64 // CHECK: %[[S6:.+]] = llvm.add %[[S0]], %[[S5]] : i64 // CHECK: %[[S7:.+]] = llvm.mlir.constant(128 : i32) : i64 @@ -704,10 +701,9 @@ func.func @warpgroup_mma_128_128_64( // CHECK: %[[S17:.+]] = llvm.mlir.constant(384 : i32) : i64 // CHECK: %[[S18:.+]] = llvm.add %[[S1]], %[[S17]] : i64 // CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], , D[%[[S14]], , ], A[, , ], B[, , ] : !llvm.struct -// CHECK: %[[S20:.+]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
index a915f7f3b8095..ff391e469815d 100644
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -224,61 +224,44 @@ func.func @async_cp_size_invalid_f64(

 // -----

-!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32)>
-!tResult = !nvgpu.warpgroup.result
+!tResult = !nvgpu.warpgroup.accumulator>
 !tDescA = !nvgpu.wgmma.descriptor>
 !tDescB = !nvgpu.wgmma.descriptor>
-func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) {
+func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
   // expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}}
-  %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult
+  %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
   return
 }

 // -----

-!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32)>
-!tResult = !nvgpu.warpgroup.result
+!tResult = !nvgpu.warpgroup.accumulator>
 !tDescA = !nvgpu.wgmma.descriptor>
 !tDescB = !nvgpu.wgmma.descriptor>
-func.func @warpgroup_mma_wrong_accumulator(%descA: !tDescA, %descB: !tDescB, %D: vector<128xf32>) {
-  // expected-error @+1 {{'nvgpu.warpgroup.mma' op has input matrices A, B and D, they must be 2 dimensional}}
-  %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128xf32> -> !tResult, !tResult
+func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
+  // expected-error @+1 {{'nvgpu.warpgroup.mma' op has matrices A, B, C and D, they must be 2 dimensional}}
+  %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
   return
 }

 // -----
-
-!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32)>
-!tResult = !nvgpu.warpgroup.result
+!tResult = !nvgpu.warpgroup.accumulator>
 !tDescA = !nvgpu.wgmma.descriptor>
 !tDescB = !nvgpu.wgmma.descriptor>
-func.func @warpgroup_mma_wrong_datatypes(%descA: !tDescA, %descB: !tDescB, %D: vector<128x128xf32>) {
+func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
   // expected-error @+1 {{'nvgpu.warpgroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}}
-  %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x128xf32> -> !tResult, !tResult
+  %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
   return
 }

 // -----

-!accMatrixStruct = !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32,
- f32, f32, f32, f32, f32, f32)>
-!tResult = !nvgpu.warpgroup.result
+!tResult = !nvgpu.warpgroup.accumulator>
 !tDescA = !nvgpu.wgmma.descriptor>
 !tDescB = !nvgpu.wgmma.descriptor>
-func.func @warpgroup_mma_wrong_large_shape(%descA: !tDescA, %descB: !tDescB, %D: vector<128x512xf32>) {
-  // expected-error @+1 {{'nvgpu.warpgroup.mma' op has input type 'memref<64x512xf16, 3>' n is set to 512, it is not supported}}
-  %0:2 = nvgpu.warpgroup.mma %descA, %descB, %D: !tDescA, !tDescB, vector<128x512xf32> -> !tResult, !tResult
+func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) {
+  // expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 512 ) != 2nd dim matrix-C ( 128 )}}
+  %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult
   return
 }
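Editor's note: the four negative tests above pin down the verifier's diagnostics for the new accumulator-based form of the op. The snippet below is a minimal, self-contained C++ model of just the two shape rules they exercise; the function name and signature are hypothetical, and the real checks (which also cover element types and the supported N sizes) live in the dialect's verifier in NVGPUDialect.cpp, not in anything resembling this helper.

```
// Illustrative shape checks mirroring the expected-error strings above.
#include <cstdio>
#include <vector>

static bool verifyWarpgroupMmaShapes(const std::vector<int> &shapeB,
                                     const std::vector<int> &shapeC) {
  if (shapeB.size() != 2 || shapeC.size() != 2) {
    std::puts("error: matrices A, B, C and D must be 2 dimensional");
    return false;
  }
  if (shapeB[1] != shapeC[1]) {
    std::printf("error: 2nd dim matrix-B ( %d ) != 2nd dim matrix-C ( %d )\n",
                shapeB[1], shapeC[1]);
    return false;
  }
  return true;
}

int main() {
  verifyWarpgroupMmaShapes({64, 121}, {128, 128});  // mirrors the first test
  verifyWarpgroupMmaShapes({64, 512}, {128, 128});  // mirrors the last test
  return 0;
}
```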

From 98246b06183f5a6ef99fc2bd7a18a1174258e97c Mon Sep 17 00:00:00 2001
From: Guray Ozen
Date: Wed, 13 Sep 2023 16:18:33 +0200
Subject: [PATCH 5/8] clarify the todo

---
 mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 046727e4ea9ab..ea4e77742d1ba 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1195,7 +1195,7 @@ struct NVGPUWarpgroupMmaOpLowering
     auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
     auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row);
     auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col);
-    // todo input type
+    // todo: handle other input and output types
    auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16);
     auto overflow =
         NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);

From 8b1cdd5fefb9cd1e87a15afba124832945a800da Mon Sep 17 00:00:00 2001
From: Guray Ozen
Date: Fri, 22 Sep 2023 11:08:06 +0200
Subject: [PATCH 6/8] add newline

---
 mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index b7aa0c7382d80..f011007e040ce 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -742,4 +742,4 @@ transform.sequence failures(propagate) {
       transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter {use_opaque_pointers = true}
   } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "vector", "scf"], partial_conversion} : !transform.any_op
-}
\ No newline at end of file
+}

From be740569a826dc14ea4935c4a9229ca32f7ecf65 Mon Sep 17 00:00:00 2001
From: Guray Ozen
Date: Fri, 22 Sep 2023 11:35:27 +0200
Subject: [PATCH 7/8] build with result type, not array

---
 mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index ea4e77742d1ba..f74aa05c0c4c4 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1189,7 +1189,6 @@ struct NVGPUWarpgroupMmaOpLowering
                             int m, int n, int k, Type resultStructType,
                             Value inout, Value descriptorA,
                             Value descriptorB) const {
-    TypeRange resultTypes = {resultStructType};
     auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k);
     auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one);
     auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one);
@@ -1200,8 +1199,8 @@ struct NVGPUWarpgroupMmaOpLowering
     auto overflow =
         NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped);
     Value res = rewriter.create(
-        loc, resultTypes, inout, descriptorA, descriptorB, shape, itype, itype,
-        scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
+        loc, resultStructType, inout, descriptorA, descriptorB, shape, itype,
+        itype, scaleOut, scaleIn, scaleIn, layoutA, layoutB, overflow);
     return res;
   }
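Editor's note: the change above builds the `nvvm.wgmma.mma_async` op from the single accumulator struct type directly instead of first materializing a `TypeRange`. Besides being shorter, one plausible benefit is that MLIR's range types are non-owning views, and initializing one from a braced temporary is an easy way to end up with a dangling view; whether that motivated this particular commit is an assumption. The self-contained snippet below reproduces the hazard with a hypothetical `View` type rather than MLIR's classes.

```
// Hypothetical illustration of a dangling non-owning view; not MLIR code.
#include <cstddef>
#include <cstdio>
#include <initializer_list>

struct View {
  const int *data;
  std::size_t size;
  // Keeps a pointer into the initializer_list's backing array.
  View(std::initializer_list<int> list) : data(list.begin()), size(list.size()) {}
};

int main() {
  int resultType = 42;
  // The backing array of {resultType} only lives until the end of this
  // declaration, so `results.data` dangles afterwards; the size stays
  // readable, but the data must not be dereferenced.
  View results = {resultType};
  std::printf("view of %zu element(s)\n", results.size);
  return 0;
}
```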

From 4e92df3db1ecd25f7b577548109279212106ac35 Mon Sep 17 00:00:00 2001
From: Guray Ozen
Date: Fri, 22 Sep 2023 11:35:43 +0200
Subject: [PATCH 8/8] we always expect struct, remove aggregate

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a528e015523e1..0d4d734edd2b6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1610,9 +1610,9 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
     PredOpTrait<"input struct and result struct must be the same type",
                 TCresIsSameAsOpBase<0, 0>>,]> {
-  let results = (outs LLVM_AnyAggregate:$results);
+  let results = (outs LLVM_AnyStruct:$results);
   let arguments = (ins
-    LLVM_AnyAggregate:$inouts,
+    LLVM_AnyStruct:$inouts,
     I64:$descriptorA,
     I64:$descriptorB,
     NVVM_MMAShapeAttr:$shape,