diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index a528e015523e1..0d4d734edd2b6 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1610,9 +1610,9 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
       PredOpTrait<"input struct and result struct must be the same type",
                   TCresIsSameAsOpBase<0, 0>>,]> {
-  let results = (outs LLVM_AnyAggregate:$results);
+  let results = (outs LLVM_AnyStruct:$results);
   let arguments = (ins
-    LLVM_AnyAggregate:$inouts,
+    LLVM_AnyStruct:$inouts,
     I64:$descriptorA,
     I64:$descriptorB,
     NVVM_MMAShapeAttr:$shape,
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index a3245bf9196ee..90381648dac6a 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -192,6 +192,19 @@ def NVGPU_WarpgroupMatrixDescriptor : NVGPU_Type<"WarpgroupMatrixDescriptor", "w
   let assemblyFormat = "`<` struct(params) `>`";
 }
 
+def NVGPU_WarpgroupAccumulator : NVGPU_Type<"WarpgroupAccumulator", "warpgroup.accumulator", []> {
+  let parameters = (ins "VectorType":$fragmented);
+  let assemblyFormat = "`<` struct(params) `>`";
+  let description = [{
+    This type represents the result matrix obtained from `nvgpu.warpgroup.mma`.
+    The `$fragmented` type signifies the distributed or fragmented result
+    vector that is collectively owned by all the threads in the warp-group
+    that executed `nvgpu.warpgroup.mma`.
+    [See the details of register fragment layout for accumulator matrix D]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVGPU Op Definitions
 //===----------------------------------------------------------------------===//
@@ -664,5 +677,48 @@ def NVGPU_GenerateGmmaDescriptorOp : NVGPU_Op<"wgmma.generate.descriptor", []> {
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
+  let description = [{
+    The `nvgpu.warpgroup.mma` op performs the warpgroup-level (4 warps)
+    matrix-multiply-and-accumulate (mma) operation that lowers to
+    `nvvm.wgmma.mma_async`.
+
+    The operands `descriptorA` and `descriptorB` are wgmma matrix descriptors
+    that describe the properties of the matrices in shared memory. The
+    results are the per-thread fragments of the warpgroup-level mma result.
+    The shape is deduced from the descriptor types and the output vector.
+
+    The op may expand to multiple `nvvm.wgmma.mma_async` operations to
+    complete the given shape. Because `nvvm.wgmma.mma_async` is asynchronous,
+    this op groups the generated `nvvm.wgmma.mma_async` ops and surrounds
+    them with `wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned`, and
+    `wgmma.wait.group.sync.aligned` ops.
+
+    Example:
+    ```mlir
+      %r1,%r2 = nvgpu.warpgroup.mma %wgmmaDescA, %wgmmaDescB, %acc1, %acc2:
+                !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
+                !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
+                !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+                !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
+                ->
+                !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
+                !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
+    ```
+  }];
+
+  let arguments = (ins NVGPU_WarpgroupMatrixDescriptor:$descriptorA,
+                       NVGPU_WarpgroupMatrixDescriptor:$descriptorB,
+                       DefaultValuedOptionalAttr<I32Attr, "1">:$waitGroup,
+                       OptionalAttr<UnitAttr>:$transposeA,
+                       OptionalAttr<UnitAttr>:$transposeB,
+                       Variadic<NVGPU_WarpgroupAccumulator>:$matrixC);
+  let results = (outs Variadic<NVGPU_WarpgroupAccumulator>:$matrixD);
+  let assemblyFormat = [{
+    $descriptorA`,` $descriptorB`,` $matrixC attr-dict
+    `:` type($descriptorA) `,` type($descriptorB) `,` type($matrixC) `->` type($matrixD)
+  }];
+  let hasVerifier = 1;
+}
 
 #endif // NVGPU
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
index 192afcb2dba79..96af26842dafe 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPUDialect.h
@@ -21,6 +21,8 @@
 #include "mlir/Dialect/NVGPU/IR/NVGPUEnums.h.inc"
 
+constexpr int kWarpSize = 32;
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPUAttrDefs.h.inc"
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index b045089244ff1..f74aa05c0c4c4 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -17,10 +17,12 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 
 #define DEBUG_TYPE "nvgpu-to-nvvm"
@@ -34,6 +36,10 @@ namespace mlir {
 
 using namespace mlir;
 
+/// Number of bits that need to be excluded when building the matrix
+/// descriptor for wgmma operations.
+constexpr int exclude4LSB = 4;
+
 /// GPU has 32 bit registers, this function truncates values when larger width
 /// is not needed.
static Value truncToI32(ConversionPatternRewriter &rewriter, Location loc, @@ -419,6 +425,15 @@ struct ConvertNVGPUToNVVMPass converter.addConversion([&](nvgpu::DeviceAsyncTokenType type) -> Type { return converter.convertType(IntegerType::get(type.getContext(), 32)); }); + converter.addConversion([&](nvgpu::WarpgroupAccumulatorType type) -> Type { + VectorType vtype = type.getFragmented(); + SmallVector structBody; + for (unsigned i = 0; i < vtype.getDimSize(0); i++) + structBody.push_back(vtype.getElementType()); + auto convertedType = + LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); + return converter.convertType(convertedType); + }); converter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type { return converter.convertType(IntegerType::get(type.getContext(), 64)); }); @@ -438,6 +453,8 @@ struct ConvertNVGPUToNVVMPass target.addLegalDialect<::mlir::LLVM::LLVMDialect>(); target.addLegalDialect<::mlir::memref::MemRefDialect>(); target.addLegalDialect<::mlir::NVVM::NVVMDialect>(); + mlir::scf::populateSCFStructuralTypeConversionsAndLegality( + converter, patterns, target); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); @@ -984,10 +1001,9 @@ struct NVGPUGenerateGmmaDescriptorLowering shiftLeft(val, startBit)); }; - int ex4LSB = 4; int64_t sizeN = op.getTensorMap().getType().getTensor().getDimSize(0); - uint64_t strideDimVal = (layout << 3) >> ex4LSB; - uint64_t leadDimVal = (sizeN * layout) >> ex4LSB; + uint64_t strideDimVal = (layout << 3) >> exclude4LSB; + uint64_t leadDimVal = (sizeN * layout) >> exclude4LSB; uint64_t offsetVal = 0; Value strideDim = makeConst(strideDimVal); @@ -1141,6 +1157,148 @@ struct NVGPUTmaCreateDescriptorOpLowering } }; +struct NVGPUWarpgroupMmaOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult getWgmmaShape(int64_t sizeM, int64_t sizeN, Type inputElemType, + int &wgmmaShapeM, int &wgmmaShapeN, + int &wgmmaShapeK) const { + wgmmaShapeM = 64; + wgmmaShapeN = sizeN; + if (inputElemType.isTF32()) { + wgmmaShapeK = 8; + } else if (inputElemType.isF16() || inputElemType.isBF16()) { + wgmmaShapeK = 16; + } else if (inputElemType.isFloat8E4M3FN() || inputElemType.isFloat8E5M2() || + inputElemType.isInteger(16)) { + wgmmaShapeK = 32; + } else if (inputElemType.isInteger(1)) { + wgmmaShapeK = 256; + } else { + llvm_unreachable("msg: not supported K shape"); + } + LLVM_DEBUG(DBGS() << "Generating wgmma.mma.async shape[m = " << wgmmaShapeM + << ", n = " << wgmmaShapeN << ", k = " << wgmmaShapeK + << "]\n"); + return success(); + } + + Value generateNVVMWgmmaOp(MLIRContext *ctx, + ConversionPatternRewriter &rewriter, Location loc, + int m, int n, int k, Type resultStructType, + Value inout, Value descriptorA, + Value descriptorB) const { + auto shape = NVVM::MMAShapeAttr::get(ctx, m, n, k); + auto scaleOut = NVVM::WGMMAScaleOutAttr::get(ctx, NVVM::WGMMAScaleOut::one); + auto scaleIn = NVVM::WGMMAScaleInAttr::get(ctx, NVVM::WGMMAScaleIn::one); + auto layoutA = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::row); + auto layoutB = NVVM::MMALayoutAttr::get(ctx, NVVM::MMALayout::col); + // todo: handle other input and output types + auto itype = NVVM::WGMMATypesAttr::get(ctx, NVVM::WGMMATypes::f16); + auto overflow = + NVVM::MMAIntOverflowAttr::get(ctx, NVVM::MMAIntOverflow::wrapped); + Value res = rewriter.create( + loc, resultStructType, inout, descriptorA, descriptorB, shape, itype, + itype, scaleOut, scaleIn, scaleIn, layoutA, 
layoutB, overflow); + return res; + } + + LogicalResult + matchAndRewrite(nvgpu::WarpgroupMmaOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + int64_t sizeM = op.getDescriptorA().getType().getTensor().getDimSize(0); + int64_t sizeN = op.getDescriptorB().getType().getTensor().getDimSize(1); + int64_t sizeK = op.getDescriptorA().getType().getTensor().getDimSize(1); + + LLVM_DEBUG(DBGS() << "===--- GEMM D[" << sizeM << "][" << sizeN << "] += A[" + << sizeM << "][" << sizeK << "] * B[" << sizeK << "][" + << sizeN << "] ---===\n"); + + int wgmmaShapeM, wgmmaShapeN, wgmmaShapeK; + if (failed(getWgmmaShape(sizeM, sizeN, rewriter.getF16Type(), wgmmaShapeM, + wgmmaShapeN, wgmmaShapeK))) { + return failure(); + } + + Value descriptorA = adaptor.getDescriptorA(); + Value descriptorB = adaptor.getDescriptorB(); + + // Generate wgmma group + + auto loc = op->getLoc(); + MemRefType typeTensorA = op.getDescriptorA().getType().getTensor(); + MemRefType typeTensorB = op.getDescriptorB().getType().getTensor(); + + auto makeAdd = [&](Value lhs, Value rhs) -> Value { + return rewriter.create(loc, lhs.getType(), lhs, rhs); + }; + + auto iterateDescA = [&](Value desc, int iterM, int iterN, + int iterK) -> Value { + // todo : Handle column major + int byte = typeTensorA.getElementTypeBitWidth() / 8; + int tileShapeA = typeTensorA.getDimSize(1); + int incrementVal = + ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * byte; + incrementVal = incrementVal >> exclude4LSB; + LLVM_DEBUG(DBGS() << "\t\t[m: " << iterM << " n: " << iterN << " k: " + << iterK << "] [wgmma descriptors] Descriptor A + " + << incrementVal << " | \t "); + if (!incrementVal) + return desc; + return makeAdd(desc, makeI64Const(rewriter, op, incrementVal)); + }; + + auto iterateDescB = [&](Value desc, int iterM, int iterN, + int iterK) -> Value { + // todo : Handle row major + int byte = typeTensorB.getElementTypeBitWidth() / 8; + int incrementVal = typeTensorB.getDimSize(0) * wgmmaShapeK * iterK * byte; + incrementVal = incrementVal >> exclude4LSB; + LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n"); + if (!incrementVal) + return desc; + return makeAdd(desc, makeI64Const(rewriter, op, incrementVal)); + }; + + rewriter.create(loc); + + SmallVector wgmmaResults; + for (int iterM = 0; iterM < (sizeM / wgmmaShapeM); iterM++) { + Value matrixC = adaptor.getMatrixC()[iterM]; + Value matrixD = op.getMatrixD()[iterM]; + Type structType = getTypeConverter()->convertType(matrixD.getType()); + LLVM_DEBUG(DBGS() << " D[" << (iterM * wgmmaShapeM) << ":" + << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" << 0 + << ":" << wgmmaShapeN << "] += \n"); + for (int iterK = 0; iterK < (sizeK / wgmmaShapeK); iterK++) { + Value descA = iterateDescA(descriptorA, iterM, 0, iterK); + Value descB = iterateDescB(descriptorB, iterM, 0, iterK); + LLVM_DEBUG(DBGS() << "\t wgmma." 
+ << "m" << wgmmaShapeM << "n" << wgmmaShapeN << "k" + << wgmmaShapeK << "(A[" << (iterM * wgmmaShapeM) + << ":" << (iterM * wgmmaShapeM) + wgmmaShapeM << "][" + << (iterK * wgmmaShapeK) << ":" + << (iterK * wgmmaShapeK + wgmmaShapeK) << "] * " + << " B[" << (iterK * wgmmaShapeK) << ":" + << (iterK * wgmmaShapeK + wgmmaShapeK) << "][" << 0 + << ":" << wgmmaShapeN << "])\n"); + matrixC = generateNVVMWgmmaOp(op->getContext(), rewriter, loc, + wgmmaShapeM, wgmmaShapeN, wgmmaShapeK, + structType, matrixC, descA, descB); + } + wgmmaResults.push_back(matrixC); + } + rewriter.create(loc); + rewriter.create(loc, op.getWaitGroup()); + + ValueRange myres(wgmmaResults); + rewriter.replaceOp(op, myres); + return success(); + } +}; + } // namespace void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, @@ -1156,6 +1314,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter, NVGPUTmaCreateDescriptorOpLowering, // nvgpu.tma.create.descriptor NVGPUMBarrierArriveExpectTxLowering, // nvgpu.mbarrier.arrive.expect_tx NVGPUGenerateGmmaDescriptorLowering, // nvgpu.wgmma.generate.descriptor + NVGPUWarpgroupMmaOpLowering, // nvgpu.warpgroup.mma MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering, NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering, NVGPUMmaSparseSyncLowering>(converter); diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index d832a983a132d..d96ed69982870 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Verifier.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" @@ -151,7 +152,6 @@ static LogicalResult verifyMmaSyncOp(Operation *op, // - For F32 (TF32), F16, S8, and S4 data // types the fundamental tensor core operation is of shape 8-by-8-by-128b. // - F64 is an exception and is of shape 8-by-8-by-256b. - constexpr int kThreads = 32; // 32 threads per warp int64_t shapeM = 8; int64_t shapeN = 8; int64_t shapeK; // set based on data type (128b for all data types except F64) @@ -206,17 +206,17 @@ static LogicalResult verifyMmaSyncOp(Operation *op, // verify warp-wide size for vector a int64_t sparseFactor = sparse ? 
2 : 1; - if (aShape[0] * aShape[1] * kThreads != m * k / sparseFactor) + if (aShape[0] * aShape[1] * kWarpSize != m * k / sparseFactor) return op->emitOpError() << "expected " << m * k << " warp-wide matrix A elements"; // verify warp-wide size for vector b - if (bShape[0] * bShape[1] * kThreads != k * n) + if (bShape[0] * bShape[1] * kWarpSize != k * n) return op->emitOpError() << "expected " << k * n << " warp-wide matrix B elements"; // verify warp-wide size for vector c - if (cShape[0] * cShape[1] * kThreads != m * n) + if (cShape[0] * cShape[1] * kWarpSize != m * n) return op->emitOpError() << "expected " << m * n << " warp-wide matrix C elements"; @@ -402,6 +402,133 @@ LogicalResult GenerateGmmaDescriptorOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// WarpgroupMmaOp +//===----------------------------------------------------------------------===// + +LogicalResult isAllowedWGMMADataType(Type typeD, Type typeA, Type typeB) { + // F32 += F16 + F16 + // F16 += F16 + F16 + if (typeA.isF16() && typeB.isF16() && (typeD.isF32() || typeD.isF16())) + return success(); + // F32 += TF32 + TF32 + if (typeA.isTF32() && typeD.isF32() && typeB.isTF32()) + return success(); + // s32 += i8 + i8 + if (typeA.isInteger(16) && typeB.isInteger(16) && typeD.isInteger(32)) + return success(); + // s32 += i1 + i1 + if (typeA.isInteger(1) && typeB.isInteger(1) && typeD.isInteger(32)) + return success(); + // F32 += BF16 + BF16 + // F16 += BF16 + BF16 + if (typeA.isBF16() && typeB.isBF16() && (typeD.isF32() || typeD.isF16())) + return success(); + // F16 += f8 + f8 + // F32 += f8 + f8 + if ((typeA.isFloat8E5M2() || typeA.isFloat8E4M3FN()) && + (typeB.isFloat8E5M2() || typeB.isFloat8E4M3FN()) && + (typeD.isF32() || typeD.isF16())) + return success(); + + return failure(); +} + +LogicalResult isAllowedSizeN(int sizeN, Type typeA) { + SmallVector allowedN = {8, 16, 24, 32, 40, 48, 56, 64, + 72, 80, 88, 96, 104, 112, 120, 128, + 136, 144, 152, 160, 168, 176, 184, 192, + 200, 208, 216, 224, 232, 240, 248, 256}; + SmallVector allowedNshort = {8, 16, 24, 32, 48, 64, + 80, 96, 112, 128, 144, 160, + 176, 192, 208, 224, 240, 256}; + if (typeA.isBF16() || typeA.isF16() || typeA.isTF32() || + typeA.isFloat8E4M3FN() || typeA.isFloat8E5M2()) + if (llvm::any_of(allowedN, [&](int n) { return sizeN == n; })) + return success(); + + if (typeA.isInteger(8) || typeA.isInteger(1)) + if (llvm::any_of(allowedNshort, [&](int n) { return sizeN == n; })) + return success(); + return failure(); +} + +LogicalResult WarpgroupMmaOp::verify() { + if (getTransposeA() && !getTransposeB()) + return emitOpError() << "supports non-transpose A (Row Major) " + "and transpose B (Column Major) for the time being"; + MemRefType matrixA = getDescriptorA().getType().getTensor(); + MemRefType matrixB = getDescriptorB().getType().getTensor(); + VectorType matrixC = getMatrixC() + .front() + .getType() + .cast() + .getFragmented(); + VectorType matrixD = getMatrixD() + .front() + .getType() + .cast() + .getFragmented(); + unsigned sizeAcc = getMatrixC().size(); + + if (getMatrixC().size() != getMatrixD().size()) + return emitOpError() << "number of matrix C and matrix D must be the same"; + + if (llvm::all_of(getMatrixC(), + [&](Value rhs) { return rhs.getType() == matrixC; })) { + return emitOpError() + << "types of all operands in matrix C must be the same"; + } + if (llvm::all_of(getMatrixD(), + [&](Value rhs) { return rhs.getType() == matrixC; })) { + return emitOpError() + << 
"types of all operands in matrix D must be the same as matrix C"; + } + + if (matrixA.getRank() != 2 || matrixB.getRank() != 2 || + matrixC.getRank() != 2 || matrixD.getRank() != 2) { + return emitOpError() + << "has matrices A, B, C and D, they must be 2 dimensional"; + } + + if (matrixA.getShape()[1] != matrixB.getShape()[0]) + return emitOpError() << "2nd dim matrix-A (" << matrixA.getShape()[1] + << ")!= 1st dim matrix-B (" << matrixB.getShape()[0] + << " )"; + if (matrixA.getShape()[0] != (matrixC.getShape()[0] * sizeAcc)) + return emitOpError() << "1st dim matrix-A ( " << matrixA.getShape()[0] + << " )!= 1st dim matrix-C ( " << matrixC.getShape()[0] + << " )"; + if (matrixB.getShape()[1] != matrixC.getShape()[1]) + return emitOpError() << "2nd dim matrix-B ( " << matrixB.getShape()[1] + << " ) != 2nd dim matrix-C ( " << matrixC.getShape()[1] + << " )"; + + if (failed(isAllowedWGMMADataType(matrixC.getElementType(), + matrixA.getElementType(), + matrixB.getElementType()))) + return emitOpError() << matrixC.getElementType() + << " += " << matrixA.getElementType() << " * " + << matrixB.getElementType() + << ", it is not supported."; + // Check N + if (failed(isAllowedSizeN(matrixB.getDimSize(1), matrixA.getElementType()))) { + return emitOpError() << "has input type " << matrixB << " n is set to " + << matrixB.getDimSize(1) << ", it is not supported"; + } + + // Currently, f16/bf16 supported + if (!matrixC.getElementType().isF32() && !matrixA.getElementType().isF16() && + !matrixA.getElementType().isBF16()) { + return emitOpError() << "hit a limitation: " << matrixC.getElementType() + << " += " << matrixA.getElementType() << " * " + << matrixB.getElementType() + << ", it is not supported yet"; + } + + return success(); +} + //===----------------------------------------------------------------------===// // TableGen'd dialect, type, and op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index a173317bbbdb3..680c21ab74fe0 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -60,10 +60,25 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns( return llvmTypeConverter.convertType( IntegerType::get(type.getContext(), 64)); }); + llvmTypeConverter.addConversion( + [&](nvgpu::WarpgroupAccumulatorType type) -> Type { + VectorType vtype = type.getFragmented(); + SmallVector structBody; + for (unsigned i = 0; i < vtype.getDimSize(0); i++) + structBody.push_back(vtype.getElementType()); + auto convertedType = + LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); + return llvmTypeConverter.convertType(convertedType); + }); llvmTypeConverter.addConversion([&](nvgpu::MBarrierType type) -> Type { return llvmTypeConverter.convertType( getMBarrierMemrefType(type.getContext(), type)); }); + llvmTypeConverter.addConversion( + [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type { + return llvmTypeConverter.convertType( + IntegerType::get(type.getContext(), 64)); + }); llvmTypeConverter.addConversion( [&](nvgpu::TensorMapDescriptorType type) -> Type { return llvmTypeConverter.getPointerType( diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir index 0d7ace52ccb36..f011007e040ce 100644 --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ 
b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -672,6 +672,66 @@ func.func @create_wgmma_descriptor(%tensorMap : !tensorMap) -> !nvgpu.wgmma.desc func.return %descA : !nvgpu.wgmma.descriptor> } +// CHECK-LABEL: @warpgroup_mma_128_128_64( +// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor>, %[[arg1:[a-zA-Z0-9_]+]]: !nvgpu.wgmma.descriptor>, %[[arg2:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator>, %[[arg3:[a-zA-Z0-9_]+]]: !nvgpu.warpgroup.accumulator>) +func.func @warpgroup_mma_128_128_64( + %descA: !nvgpu.wgmma.descriptor>, + %descB: !nvgpu.wgmma.descriptor>, + %acc1: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, + %acc2: !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>) +{ +// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.wgmma.descriptor> to i64 +// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[arg1]] : !nvgpu.wgmma.descriptor> to i64 +// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[arg2]] : !nvgpu.warpgroup.accumulator> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S3:.+]] = builtin.unrealized_conversion_cast %[[arg3]] : !nvgpu.warpgroup.accumulator> to !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: nvvm.wgmma.fence.aligned +// CHECK: %[[S4:.+]] = nvvm.wgmma.mma_async %[[S0]], %[[S1]], , D[%[[S2]], , ], A[, , ], B[, , ] : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> +// CHECK: %[[S5:.+]] = llvm.mlir.constant(2 : i32) : i64 +// CHECK: %[[S6:.+]] = llvm.add %[[S0]], %[[S5]] : i64 +// CHECK: %[[S7:.+]] = llvm.mlir.constant(128 : i32) : i64 +// CHECK: %[[S8:.+]] = llvm.add %[[S1]], %[[S7]] : i64 +// CHECK: %[[S9:.+]] = nvvm.wgmma.mma_async %[[S6]], %[[S8]], , D[%[[S4]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S10:.+]] = llvm.mlir.constant(4 : i32) : i64 +// CHECK: %[[S11:.+]] = llvm.add %[[S0]], %[[S10]] : i64 +// CHECK: %[[S12:.+]] = llvm.mlir.constant(256 : i32) : i64 +// CHECK: %[[S13:.+]] = llvm.add %[[S1]], %[[S12]] : i64 +// CHECK: %[[S14:.+]] = nvvm.wgmma.mma_async %[[S11]], %[[S13]], , D[%[[S9]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S15:.+]] = llvm.mlir.constant(6 : i32) : i64 +// CHECK: %[[S16:.+]] = llvm.add %[[S0]], %[[S15]] : i64 +// CHECK: %[[S17:.+]] = llvm.mlir.constant(384 : i32) : i64 +// CHECK: %[[S18:.+]] = llvm.add %[[S1]], %[[S17]] 
: i64 +// CHECK: %[[S19:.+]] = nvvm.wgmma.mma_async %[[S16]], %[[S18]], , D[%[[S14]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S21:.+]] = llvm.mlir.constant(512 : i32) : i64 +// CHECK: %[[S22:.+]] = llvm.add %[[S0]], %[[S21]] : i64 +// CHECK: %[[S23:.+]] = nvvm.wgmma.mma_async %[[S22]], %[[S1]], , D[%[[S3]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S24:.+]] = llvm.mlir.constant(514 : i32) : i64 +// CHECK: %[[S25:.+]] = llvm.add %[[S0]], %[[S24]] : i64 +// CHECK: %[[S26:.+]] = llvm.mlir.constant(128 : i32) : i64 +// CHECK: %[[S27:.+]] = llvm.add %[[S1]], %[[S26]] : i64 +// CHECK: %[[S28:.+]] = nvvm.wgmma.mma_async %[[S25]], %[[S27]], , D[%[[S23]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S29:.+]] = llvm.mlir.constant(516 : i32) : i64 +// CHECK: %[[S30:.+]] = llvm.add %[[S0]], %[[S29]] : i64 +// CHECK: %[[S31:.+]] = llvm.mlir.constant(256 : i32) : i64 +// CHECK: %[[S32:.+]] = llvm.add %[[S1]], %[[S31]] : i64 +// CHECK: %[[S33:.+]] = nvvm.wgmma.mma_async %[[S30]], %[[S32]], , D[%[[S28]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: %[[S34:.+]] = llvm.mlir.constant(518 : i32) : i64 +// CHECK: %[[S35:.+]] = llvm.add %[[S0]], %[[S34]] : i64 +// CHECK: %[[S36:.+]] = llvm.mlir.constant(384 : i32) : i64 +// CHECK: %[[S37:.+]] = llvm.add %[[S1]], %[[S36]] : i64 +// CHECK: %[[S38:.+]] = nvvm.wgmma.mma_async %[[S35]], %[[S37]], , D[%[[S33]], , ], A[, , ], B[, , ] : !llvm.struct +// CHECK: nvvm.wgmma.commit.group.sync.aligned +// CHECK: nvvm.wgmma.wait.group.sync.aligned 1 + %wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB}: + !nvgpu.wgmma.descriptor>, + !nvgpu.wgmma.descriptor>, + !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, + !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>> + -> + !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>, + !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>> + return +} + transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 @@ -681,5 +741,5 @@ transform.sequence failures(propagate) { } with type_converter { transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter {use_opaque_pointers = true} - } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "scf"], partial_conversion} : !transform.any_op -} \ No newline at end of file + } {legal_dialects = ["arith", "func", "llvm", "memref", "nvvm", "vector", "scf"], partial_conversion} : !transform.any_op +} diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir index ef721b1801407..ff391e469815d 100644 --- a/mlir/test/Dialect/NVGPU/invalid.mlir +++ b/mlir/test/Dialect/NVGPU/invalid.mlir @@ -221,3 +221,47 @@ func.func @async_cp_size_invalid_f64( %0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 3: memref<128x128xf64> to memref<3x16x128xf64, 3> return } + +// ----- + +!tResult = !nvgpu.warpgroup.accumulator> +!tDescA = !nvgpu.wgmma.descriptor> +!tDescB = !nvgpu.wgmma.descriptor> + +func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) { + // expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 121 ) != 2nd dim matrix-C ( 128 )}} + %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult + return +} + +// ----- + +!tResult = !nvgpu.warpgroup.accumulator> +!tDescA = !nvgpu.wgmma.descriptor> +!tDescB = !nvgpu.wgmma.descriptor> +func.func 
@warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) { + // expected-error @+1 {{'nvgpu.warpgroup.mma' op has matrices A, B, C and D, they must be 2 dimensional}} + %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult + return +} + +// ----- +!tResult = !nvgpu.warpgroup.accumulator> +!tDescA = !nvgpu.wgmma.descriptor> +!tDescB = !nvgpu.wgmma.descriptor> +func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) { + // expected-error @+1 {{'nvgpu.warpgroup.mma' op 'f32' += 'f16' * 'f32', it is not supported.}} + %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult + return +} + +// ----- + +!tResult = !nvgpu.warpgroup.accumulator> +!tDescA = !nvgpu.wgmma.descriptor> +!tDescB = !nvgpu.wgmma.descriptor> +func.func @warpgroup_mma_wrong_input(%descA: !tDescA, %descB: !tDescB, %acc1: !tResult, %acc2: !tResult) { + // expected-error @+1 {{'nvgpu.warpgroup.mma' op 2nd dim matrix-B ( 512 ) != 2nd dim matrix-C ( 128 )}} + %0:2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc1: !tDescA, !tDescB, !tResult, !tResult -> !tResult, !tResult + return +}
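For reference, below is a minimal standalone sketch (not part of the patch) of the descriptor-offset arithmetic that `NVGPUWarpgroupMmaOpLowering` performs in its `iterateDescA`/`iterateDescB` lambdas, under the same 128x128x64 f16 configuration exercised by `@warpgroup_mma_128_128_64` above. The helper names (`descAOffset`, `descBOffset`, `kExclude4LSB`) are illustrative only; the expected values are the `llvm.mlir.constant` increments that appear in the CHECK lines of that test.

```cpp
// Standalone sketch of the wgmma descriptor-offset computation used by the
// lowering: offsets are byte distances into shared memory with the 4 least
// significant bits dropped (matching the exclude4LSB shift in the patch).
#include <cassert>
#include <cstdint>
#include <cstdio>

constexpr int kExclude4LSB = 4; // wgmma descriptors ignore the 4 LSBs

// Offset added to descriptor A for the (iterM, iterK) wgmma tile.
int64_t descAOffset(int64_t wgmmaShapeK, int64_t sizeK, int64_t tileShapeA,
                    int64_t elemBytes, int iterM, int iterK) {
  int64_t inc = ((wgmmaShapeK * iterK) + (sizeK * tileShapeA * iterM)) * elemBytes;
  return inc >> kExclude4LSB;
}

// Offset added to descriptor B for the iterK-th wgmma tile.
int64_t descBOffset(int64_t dimB0, int64_t wgmmaShapeK, int64_t elemBytes,
                    int iterK) {
  return (dimB0 * wgmmaShapeK * iterK * elemBytes) >> kExclude4LSB;
}

int main() {
  // A: memref<128x64xf16, 3>, B: memref<64x128xf16, 3>, wgmma shape m64n128k16.
  const int64_t sizeK = 64, tileShapeA = 64, dimB0 = 64, wgmmaShapeK = 16,
                f16Bytes = 2;

  // First accumulator (iterM = 0): A advances by 2, 4, 6 per k-step,
  // matching the constants in the CHECK lines of the lowering test.
  assert(descAOffset(wgmmaShapeK, sizeK, tileShapeA, f16Bytes, 0, 1) == 2);
  assert(descAOffset(wgmmaShapeK, sizeK, tileShapeA, f16Bytes, 0, 3) == 6);
  // Second accumulator (iterM = 1) starts 512 descriptor units further.
  assert(descAOffset(wgmmaShapeK, sizeK, tileShapeA, f16Bytes, 1, 0) == 512);
  // B advances by 128 per k-step: 128, 256, 384.
  assert(descBOffset(dimB0, wgmmaShapeK, f16Bytes, 2) == 256);

  printf("descriptor offsets match the lowering test\n");
  return 0;
}
```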