diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td index d85ef963ae5dc..f051e03efbcda 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td @@ -105,6 +105,10 @@ def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">; def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">; def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">; def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">; +def LLVM_aarch64_sme_smopa_za32 : ArmSME_IntrMopOverloadedOp<"smopa.za32">; +def LLVM_aarch64_sme_umopa_za32 : ArmSME_IntrMopOverloadedOp<"umopa.za32">; +def LLVM_aarch64_sme_smops_za32 : ArmSME_IntrMopOverloadedOp<"smops.za32">; +def LLVM_aarch64_sme_umops_za32 : ArmSME_IntrMopOverloadedOp<"umops.za32">; class ArmSME_IntrLoadStoreOp : ArmSME_IntrOp allowedInputVectorTypes, + list allowedResultVectorTypes, + int numOuterProducts> : + ArmSME_Op, + HasMatchingMaskTypeConstraint<"lhs", "lhsMask">, + HasMatchingMaskTypeConstraint<"rhs", "rhsMask">, + PredOpTrait< + "both `lhsMask` and `rhsMask` should be provided or neither", + CPred<"bool(getLhsMask()) == bool(getRhsMask())"> + >, + OptionalTypesMatchWith<"`result` and `acc` have the same type", + "result", "acc", "::llvm::cast($_self)">, + // This trait ensures the input types match the correct output type for ops + // that takes multiple inputs and outputs (i.e., 4-way). 
+ PredOpTrait< + "tile element size equals input element size * " # numOuterProducts, + CPred<"getTileType().getElementTypeBitWidth() == " + "(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")"> + >, + ]> { + + let arguments = (ins + AnyTypeOf:$lhs, AnyVector:$rhs, + Optional:$lhsMask, Optional:$rhsMask, + Optional:$acc); + let results = (outs AnyTypeOf:$result); + + let assemblyFormat = [{ + $lhs `,` $rhs + oilist( + `acc` `` `(` $acc `)` + | `masks` `` `(` $lhsMask `,` $rhsMask `)` + ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result) + }]; + + let extraClassDeclaration = [{ + VectorType getLhsType() { return llvm::cast(getLhs().getType()); } + VectorType getRhsType() { return llvm::cast(getRhs().getType()); } + VectorType getResultType() { return llvm::cast(getResult().getType()); } + std::optional getAllocatedTileType() { + // The outerproduct op allocates a new tile if no accumulator is passed. + if (!getAcc()) + return arm_sme::getSMETileType(getResultType()); + return std::nullopt; + } + VectorType getTileType() { + return getResultType(); + } + }]; +} + +class OuterProduct2Way allowedInputVectorTypes, + list allowedResultVectorTypes> + : OuterProductWideningBase; + +def FMopa2WayOp + : OuterProduct2Way<"fmopa_2way", + [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>], + [nxnxv4f32]> { + let summary = "Floating-point sum of 2 outer products and accumulate"; + + let description = [{ + This operation represents a sum of 2 widened outer products. It takes 2 1-D + scalable vectors as input and a 2-D scalable vector (ZA tile) as output. + + For example (fp16 to fp32): + + ```mlir + %result = arm_sme.fmopa_2way %lhs, %rhs : + vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + ``` + + The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of + 2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit + elements in a vector of SVL bits. 
To illustrate, below is a breakdown of + this operation for fp16 to fp32, SVL=128 (i.e., vscale=1): + + ``` + LHS RHS + [A0 A1 A2 A3 A4 A5 A6 A7] [B0 B1 B2 B3 B4 B5 B6 B7] + + ---------------------------------------------------------------------------- + + implicit layout + + [A0 A1] | + [A2 A3] | [B0 B2 B4 B6] + [A4 A5] | [B1 B3 B5 B7] + [A6 A7] | + + ---------------------------------------------------------------------------- + + 2 outer products + + Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1 + ------------- | ------------- + | + [B0 B2 B4 B6] | [B1 B3 B5 B7] + | + [A0 [A0B0 A0B2 A0B4 A0B6] | [A1 [A1B1 A1B3 A1B5 A1B7] + A2 [A2B0 A2B2 A2B4 A2B6] | A3 [A3B1 A3B3 A3B5 A3B7] + A4 [A4B0 A4B2 A4B4 A4B6] | A5 [A5B1 A5B3 A5B5 A5B7] + A6] [A6B0 A6B2 A6B4 A6B6] | A7] [A7B1 A7B3 A7B5 A7B7] + | + + ---------------------------------------------------------------------------- + + sum of 2 outer products + + Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1 + + [A0B0 + A1B1 A0B2 + A1B3 A0B4 + A1B5 A0B6 + A1B7] + [A2B0 + A3B1 A2B2 + A3B3 A2B4 + A3B5 A2B6 + A3B7] + [A4B0 + A5B1 A4B2 + A5B3 A4B4 + A5B5 A4B6 + A5B7] + [A6B0 + A7B1 A6B2 + A7B3 A6B4 + A7B5 A6B6 + A7B7] + + ---------------------------------------------------------------------------- + ``` + + This operation enables the folding of 2 outer products chained via the + accumulator into a single outer product. 
+ + For example: + + ```mlir + %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> + %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> + %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> + %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32> + %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32> + ``` + + The 2 outer products in the example above can be fused into a single outer + product as follows: + + ```mlir + %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16> + %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16> + %0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + ``` + + This is implemented in the `-arm-sme-outer-product-fusion` pass. + + Example: FP16 to FP32 + ```mlir + %result = arm_sme.fmopa_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + ``` + + Example: BF16 to FP32 + ```mlir + %result = arm_sme.fmopa_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32> + ``` + + | Spec | Features | + | ---- | -------- | + | [FMOPA (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPA--widening--2-way--FP16-to-FP32---Half-precision-floating-point-sum-of-outer-products-and-accumulate-) | +sme | + | [BFMOPA (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPA--widening---BFloat16-sum-of-outer-products-and-accumulate-) | +sme | + + [1] https://developer.arm.com/documentation/ddi0616 + }]; +} + +// TODO: support: +// - FMOPA 2-way FP8 to FP16 +// - FMOPA 4-way FP16 to FP32 +// once intrinsic support lands in the backend. 
+ +def FMops2WayOp + : OuterProduct2Way<"fmops_2way", + [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>], + [nxnxv4f32]> { + let summary = "Floating-point sum of 2 outer products and subtract"; + let description = [{ + Equivalent to `fmopa_2way` but outer products are subtracted from + destination `result`. + + Example: FP16 to FP32 + ```mlir + %result = arm_sme.fmops_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + ``` + + Example: BF16 to FP32 + ```mlir + %result = arm_sme.fmops_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32> + + Refer to + [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed + description of 2-way outer products. + + | Spec | Features | + | ---- | -------- | + | [FMOPS (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPS--widening---Half-precision-floating-point-sum-of-outer-products-and-subtract-) | +sme | + | [BFMOPS (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BMOPS--Bitwise-exclusive-NOR-population-count-outer-product-and-subtract-) | +sme | + ``` + }]; +} + +def SMopa2WayOp + : OuterProduct2Way<"smopa_2way", + [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>], + [nxnxv4i32]> { + let summary = "Signed integer sum of 2 outer products and accumulate"; + let description = [{ + Example: + ```mlir + %result = arm_sme.smopa_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + + Refer to + [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed + description of 2-way outer products. 
+ + | Spec | Features | + | ---- | -------- | + | [SMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--2-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme2 | + ``` + }]; +} + +def SMops2WayOp + : OuterProduct2Way<"smops_2way", + [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>], + [nxnxv4i32]> { + let summary = "Signed integer sum of 2 outer products and subtract"; + let description = [{ + Example: + ```mlir + %result = arm_sme.smops_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + + Refer to + [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed + description of 2-way outer products. + + | Spec | Features | + | ---- | -------- | + | [SMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--2-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme2 | + ``` + }]; +} + +def UMopa2WayOp + : OuterProduct2Way<"umopa_2way", + [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>], + [nxnxv4i32]> { + let summary = "Unsigned integer sum of 2 outer products and accumulate"; + let description = [{ + Example: + ```mlir + %result = arm_sme.umopa_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + + Refer to + [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed + description of 2-way outer products. 
+ + | Spec | Features | + | ---- | -------- | + | [UMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--2-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme2 | + ``` + }]; +} + +def UMops2WayOp + : OuterProduct2Way<"umops_2way", + [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>], + [nxnxv4i32]> { + let summary = "Unsigned integer sum of 2 outer products and subtract"; + let description = [{ + Example: + ```mlir + %result = arm_sme.umops_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + + Refer to + [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed + description of 2-way outer products. + + | Spec | Features | + | ---- | -------- | + | [UMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--2-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme2 | + ``` + }]; +} + def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]> { let summary = "Query the streaming vector length"; diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h index aef2959265a7c..bb49ce4c62723 100644 --- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h @@ -32,6 +32,10 @@ std::unique_ptr createEnableArmStreamingPass( /// Pass that allocates tile IDs to ArmSME operations. std::unique_ptr createTileAllocationPass(); +/// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening +/// variants. 
+std::unique_ptr createOuterProductFusionPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td index 8d1ba6ed34e80..844e1957efc0a 100644 --- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td @@ -122,4 +122,38 @@ def TileAllocation let dependentDialects = ["func::FuncDialect"]; } +def OuterProductFusion + : Pass<"arm-sme-outer-product-fusion", "mlir::func::FuncOp"> { + let summary = "Fuse 'arm_sme.outerproduct' operations into 2-way or 4-way widening variants"; + let description = [{ + This pass fuses 'arm_sme.outerproduct' operations that are chained via the + accumulator into 2-way or 4-way ArmSME outer product operations. + + For example: + ```mlir + %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> + %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> + %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> + %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32> + %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32> + ``` + + Becomes: + + ```mlir + %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16> + %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16> + %0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + ``` + + For further information on the 2-way or 4-way widening ops see: + https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smefmopa_2way-arm_smefmopa_2wayop + 
https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_4way-arm_smesmopa_4wayop + }]; + let constructor = "mlir::arm_sme::createOuterProductFusionPass()"; + let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect", "LLVM::LLVMDialect"]; +} + #endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h index f622bc0562e9e..e00c7503e6999 100644 --- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h @@ -15,6 +15,10 @@ class LLVMConversionTarget; class LLVMTypeConverter; class RewritePatternSet; +namespace arm_sme { +void populateOuterProductFusionPatterns(RewritePatternSet &patterns); +} // namespace arm_sme + } // namespace mlir #endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index bbef3b996e40b..e73388b0906e8 100644 --- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -776,6 +776,49 @@ struct OuterProductOpConversion } }; +/// Lower 2-way and 4-way widening outer products to intrinsics. +template +struct OuterProductWideningOpConversion + : public ConvertArmSMEOpToLLVMPattern { + using ConvertArmSMEOpToLLVMPattern< + OuterProductWideningOp>::ConvertArmSMEOpToLLVMPattern; + + LogicalResult + matchAndRewrite(OuterProductWideningOp op, + typename OuterProductWideningOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto tileId = getTileIdOrError(op); + if (!tileId) + return failure(); + + Value acc = op.getAcc(); + if (!acc) + // Initialize accumulator with zero. 
+ acc = op.template createOpAndForwardTileId( + rewriter, op.getLoc(), op.getResultType()); + + Value lhsMask = op.getLhsMask(); + Value rhsMask = op.getRhsMask(); + if (!lhsMask || !rhsMask) { + auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type()); + Value allActiveMask = rewriter.create( + op.getLoc(), DenseElementsAttr::get(predTy, true)); + lhsMask = allActiveMask; + rhsMask = allActiveMask; + } + + rewriter.create(op.getLoc(), tileId, lhsMask, + rhsMask, adaptor.getLhs(), + adaptor.getRhs()); + + // The outerproduct intrinsics have no result, replace + // 'arm_sme.outerproduct' with the input tile to preserve dataflow. + rewriter.replaceOp(op, acc); + + return success(); + } +}; + /// Lower `arm_sme.streaming_vl` to SME CNTS intrinsics. /// /// Example: @@ -854,6 +897,13 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) { arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa, + arm_sme::aarch64_sme_mopa_wide, arm_sme::aarch64_sme_mops_wide, + arm_sme::aarch64_sme_smopa_wide, arm_sme::aarch64_sme_smops_wide, + arm_sme::aarch64_sme_umopa_wide, arm_sme::aarch64_sme_umops_wide, + arm_sme::aarch64_sme_smopa_za32, arm_sme::aarch64_sme_smops_za32, + arm_sme::aarch64_sme_umopa_za32, arm_sme::aarch64_sme_umops_za32, + arm_sme::aarch64_sme_sumopa_wide, arm_sme::aarch64_sme_sumops_wide, + arm_sme::aarch64_sme_usmopa_wide, arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>(); target.addLegalDialect(patterns, converter); + StreamingVLOpConversion, OuterProductOpConversion, + OuterProductWideningOpConversion, + OuterProductWideningOpConversion, + OuterProductWideningOpConversion, + OuterProductWideningOpConversion, + OuterProductWideningOpConversion, + OuterProductWideningOpConversion, + ZeroOpConversion, 
GetTileConversion>(patterns, converter); } std::unique_ptr mlir::createConvertArmSMEToLLVMPass() { diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt index 96eb584420438..c06f9d3cc7a9f 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIRArmSMETransforms EnableArmStreaming.cpp + OuterProductFusion.cpp TileAllocation.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp new file mode 100644 index 0000000000000..fa55f9b7b31e1 --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp @@ -0,0 +1,286 @@ +//===- OuterProductFusion.cpp - Fuse 'arm_sme.outerproduct' ops -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements rewrites that fuse 'arm_sme.outerproduct' operations +// into the 2-way or 4-way widening outerproduct operations. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSME/Transforms/Passes.h" +#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/TypeSwitch.h" + +#define DEBUG_TYPE "arm-sme-outerproduct-fusion" + +namespace mlir::arm_sme { +#define GEN_PASS_DEF_OUTERPRODUCTFUSION +#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" +} // namespace mlir::arm_sme + +using namespace mlir; +using namespace mlir::arm_sme; + +namespace { +// Fuse two 'arm_sme.outerproduct' operations that are chained via the +// accumulator into 2-way outer product operation. +// +// For example: +// +// %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> +// %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> +// %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, +// vector<[4]xf32> +// +// %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> +// %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> +// %1 = arm_sme.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>, +// vector<[4]xf32> +// +// Becomes: +// +// %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) +// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16> +// %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) +// : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16> +// %0 = arm_sme.fmopa_2way %a_packed, %b_packed +// : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> +class OuterProductFusion2Way + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arm_sme::OuterProductOp op, + PatternRewriter &rewriter) const override { + Value acc = op.getAcc(); + if (!acc) + return 
rewriter.notifyMatchFailure(op, "no accumulator operand"); + + arm_sme::OuterProductOp op1 = acc.getDefiningOp(); + arm_sme::OuterProductOp op2 = op; + if (!op1) + return rewriter.notifyMatchFailure(op, + "defining op of accumulator operand " + "must be an 'arm_sme.outerproduct'"); + + if (op1.getKind() != op2.getKind()) + return rewriter.notifyMatchFailure( + op, "combining kind (add or sub) of outer products must match"); + + if (!op1->hasOneUse()) { + // If the first outer product has uses other than as the input to another + // outer product, it can't be erased after fusion. This is a problem when + // it also has an accumulator as this will be used as the root for tile + // allocation and since the widening outer product uses the same + // accumulator it will get assigned the same tile ID, resulting in 3 + // outer products accumulating to the same tile and incorrect results. + // + // Example: + // + // %acc = arith.constant dense<0.0> ; root for tile allocation + // %0 = arm_sme.outerproduct %a0, %b0 acc(%acc) + // vector.print %0 ; intermediary use, can't erase %0 + // %1 = arm_sme.outerproduct %a1, %b1 acc(%0) + // + // After fusion and tile allocation + // + // %0 = arm_sme.zero {tile_id = 0 : i32} + // %1 = arm_sme.outerproduct %a0, %b0 acc(%0) {tile_id = 0 : i32} + // vector.print %1 + // %2 = arm_sme.fmopa_2way %a, %b acc(%0) {tile_id = 0 : i32} + // + // No accumulator would be ok, but it's simpler to prevent this + // altogether, since it has no benefit. 
+ return rewriter.notifyMatchFailure( + op, "first outer product is not single use and cannot be removed, " + "no benefit to fusing"); + } + + if (bool(op1.getLhsMask()) != bool(op2.getLhsMask())) + return rewriter.notifyMatchFailure( + op, "unsupported masking, either both outerproducts are masked " + "or neither"); + + if (failed(canFuseOuterProducts(rewriter, op1, op2))) + return failure(); + + auto loc = op.getLoc(); + + auto packInputs = [&](Value lhs, Value rhs) { + auto inputType = cast(lhs.getType()); + VectorType inputTypeX2 = + VectorType::Builder(inputType).setDim(0, inputType.getShape()[0] * 2); + return rewriter.create( + loc, inputTypeX2, lhs, rhs); + }; + + auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0), + op2.getLhs().getDefiningOp()->getOperand(0)); + auto rhs = packInputs(op1.getRhs().getDefiningOp()->getOperand(0), + op2.getRhs().getDefiningOp()->getOperand(0)); + + Value lhsMask, rhsMask; + if (op1.getLhsMask() || op2.getLhsMask()) { + lhsMask = packInputs(op1.getLhsMask(), op2.getLhsMask()); + rhsMask = packInputs(op1.getRhsMask(), op2.getRhsMask()); + } + + auto extOp = op.getLhs().getDefiningOp(); + + arm_sme::CombiningKind kind = op.getKind(); + if (kind == arm_sme::CombiningKind::Add) { + TypeSwitch(extOp) + .Case([&](auto) { + rewriter.replaceOpWithNewOp( + op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, + op1.getAcc()); + }) + .Case([&](auto) { + rewriter.replaceOpWithNewOp( + op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, + op1.getAcc()); + }) + .Case([&](auto) { + rewriter.replaceOpWithNewOp( + op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, + op1.getAcc()); + }) + .Default([&](auto) { llvm_unreachable("unexpected extend op!"); }); + } else if (kind == arm_sme::CombiningKind::Sub) { + TypeSwitch(extOp) + .Case([&](auto) { + rewriter.replaceOpWithNewOp( + op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, + op1.getAcc()); + }) + .Case([&](auto) { + rewriter.replaceOpWithNewOp( + op2, 
op.getResultType(), lhs, rhs, lhsMask, rhsMask, + op1.getAcc()); + }) + .Case([&](auto) { + rewriter.replaceOpWithNewOp( + op2, op.getResultType(), lhs, rhs, lhsMask, rhsMask, + op1.getAcc()); + }) + .Default([&](auto) { llvm_unreachable("unexpected extend op!"); }); + } else { + llvm_unreachable("unexpected arm_sme::CombiningKind!"); + } + + rewriter.eraseOp(op1); + + return success(); + } + +private: + // A pair of outer product can be fused if all of the following are true: + // - input and result types match. + // - the defining operations of the inputs are identical extensions, + // specifically either: + // - a signed or unsigned extension for integer types. + // - a floating-point extension for floating-point types. + // - the types and extension are supported, i.e. there's a 2-way operation + // they can be fused into. + LogicalResult canFuseOuterProducts(PatternRewriter &rewriter, + arm_sme::OuterProductOp op1, + arm_sme::OuterProductOp op2) const { + // Supported result types. + auto nxnxv4i32 = + VectorType::get({4, 4}, rewriter.getI32Type(), {true, true}); + auto nxnxv4f32 = + VectorType::get({4, 4}, rewriter.getF32Type(), {true, true}); + // Supported input types. + // Note: this is before packing so these have half the number of elements + // of the input vector types of the 2-way operations. 
+ auto nxv4i16 = VectorType::get({4}, rewriter.getI16Type(), true); + auto nxv4f16 = VectorType::get({4}, rewriter.getF16Type(), true); + auto nxv4bf16 = VectorType::get({4}, rewriter.getBF16Type(), true); + if ((failed( + isCompatible(rewriter, op1, nxnxv4f32, nxv4f16)) || + failed( + isCompatible(rewriter, op2, nxnxv4f32, nxv4f16))) && + (failed( + isCompatible(rewriter, op1, nxnxv4f32, nxv4bf16)) || + failed(isCompatible(rewriter, op2, nxnxv4f32, + nxv4bf16))) && + (failed( + isCompatible(rewriter, op1, nxnxv4i32, nxv4i16)) || + failed(isCompatible(rewriter, op2, nxnxv4i32, + nxv4i16))) && + (failed( + isCompatible(rewriter, op1, nxnxv4i32, nxv4i16)) || + failed( + isCompatible(rewriter, op2, nxnxv4i32, nxv4i16)))) + return failure(); + + return success(); + } + + // An outer product is compatible if all of the following are true: + // - the result type matches `resultType`. + // - the defining operations of the inputs are identical and of the type + // `ExtOp`. + // - the input types of the defining operations are identical and match + // `inputType`. 
+ template + LogicalResult isCompatible(PatternRewriter &rewriter, + arm_sme::OuterProductOp op, VectorType resultType, + VectorType inputType) const { + if (op.getResultType() != resultType) + return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) { + diag << "unsupported result type, expected " << resultType; + }); + + auto lhsDefOp = op.getLhs().getDefiningOp(); + auto rhsDefOp = op.getRhs().getDefiningOp(); + + if (!lhsDefOp || !rhsDefOp) + return rewriter.notifyMatchFailure( + op, "defining op of outerproduct operands must be one of: " + "'arith.extf' or 'arith.extsi' or 'arith.extui'"); + + auto lhsInType = cast(lhsDefOp.getIn().getType()); + auto rhsInType = cast(rhsDefOp.getIn().getType()); + + if (lhsInType != inputType || rhsInType != inputType) + return rewriter.notifyMatchFailure(op.getLoc(), [&](Diagnostic &diag) { + diag << "unsupported input type, expected " << inputType; + }); + + return success(); + } +}; + +struct OuterProductFusionPass + : public arm_sme::impl::OuterProductFusionBase { + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateOuterProductFusionPatterns(patterns); + + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +void mlir::arm_sme::populateOuterProductFusionPatterns( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +std::unique_ptr mlir::arm_sme::createOuterProductFusionPass() { + return std::make_unique(); +} diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir index f9cf77ca15ffb..c41504d0e4724 100644 --- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir +++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir @@ -601,3 +601,99 @@ func.func @arm_sme_streaming_vl_double_words() -> index { %svl_d = arm_sme.streaming_vl return %svl_d : index } + 
+//===----------------------------------------------------------------------===// +// arm_sme.fmopa_2way +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: arm_sme_fmopa_2way_f16f16_to_f32 +// CHECK: "arm_sme.intr.mopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> () +func.func @arm_sme_fmopa_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> { + %result = arm_sme.fmopa_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + return %result : vector<[4]x[4]xf32> +} + +// ----- + +// CHECK-LABEL: arm_sme_fmopa_2way_bf16bf16_to_f32 +// CHECK: "arm_sme.intr.mopa.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> () +func.func @arm_sme_fmopa_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> { + %result = arm_sme.fmopa_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32> + return %result : vector<[4]x[4]xf32> +} + +//===----------------------------------------------------------------------===// +// arm_sme.fmops_2way +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: arm_sme_fmops_2way_f16f16_to_f32 +// CHECK: "arm_sme.intr.mops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) -> () +func.func @arm_sme_fmops_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> { + %result = arm_sme.fmops_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + return %result : vector<[4]x[4]xf32> +} + +// ----- + +// CHECK-LABEL: arm_sme_fmops_2way_bf16bf16_to_f32 +// CHECK: "arm_sme.intr.mops.wide"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) -> 
() +func.func @arm_sme_fmops_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> { + %result = arm_sme.fmops_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32> + return %result : vector<[4]x[4]xf32> +} + +//===----------------------------------------------------------------------===// +// arm_sme.smopa_2way +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: arm_sme_smopa_2way_i16i16_to_i32 +// CHECK: "arm_sme.intr.smopa.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () +func.func @arm_sme_smopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> { + %result = arm_sme.smopa_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + return %result : vector<[4]x[4]xi32> +} + +//===----------------------------------------------------------------------===// +// arm_sme.smops_2way +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: arm_sme_smops_2way_i16i16_to_i32 +// CHECK: "arm_sme.intr.smops.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () +func.func @arm_sme_smops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> { + %result = arm_sme.smops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + return %result : vector<[4]x[4]xi32> +} + +//===----------------------------------------------------------------------===// +// arm_sme.umopa_2way +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: arm_sme_umopa_2way_i16i16_to_i32 +// CHECK: "arm_sme.intr.umopa.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () +func.func 
@arm_sme_umopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> { + %result = arm_sme.umopa_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + return %result : vector<[4]x[4]xi32> +} + +//===----------------------------------------------------------------------===// +// arm_sme.umops_2way +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: arm_sme_umops_2way_i16i16_to_i32 +// CHECK: "arm_sme.intr.umops.za32"({{.*}}) <{tile_id = 0 : i32}> : (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () +func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> { + %result = arm_sme.umops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + return %result : vector<[4]x[4]xi32> +} diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir index 85b95a8b6cf12..dcc231332f208 100644 --- a/mlir/test/Dialect/ArmSME/invalid.mlir +++ b/mlir/test/Dialect/ArmSME/invalid.mlir @@ -173,3 +173,56 @@ func.func @arm_sme_outerproduct__bad_vector_type(%vecA: vector<[4]xf32>, %vecB: %0 = arm_sme.outerproduct %vecA, %vecB : vector<[4]xf32>, vector<[8]xf32> return %0 : vector<[4]x[4]xf32> } + +//===----------------------------------------------------------------------===// +// arm_sme.fmopa_2way +//===----------------------------------------------------------------------===// + +// ----- + +func.func @arm_sme_fmopa_2way__bad_rhs_vector_type(%vecA: vector<[8]xf16>, %vecB: vector<[4]xf32>) -> vector<[4]x[4]xf32> +{ + // expected-error@+1 {{op failed to verify that all of {lhs, rhs} have same type}} + %0 = arm_sme.fmopa_2way %vecA, %vecB : vector<[8]xf16>, vector<[4]xf32> into vector<[4]x[4]xf32> + return %0 : vector<[4]x[4]xf32> +} + +// ----- + +func.func @arm_sme_fmopa_2way__bad_lhs_mask_type(%vecA: vector<[8]xf16>, %vecB: 
vector<[8]xf16>, %maskA : vector<[4]xi1>, %maskB : vector<[8]xi1>) -> vector<[4]x[4]xf32> +{ + // expected-note@-2 {{prior use here}} + // expected-error@+1 {{use of value '%maskA' expects different type than prior uses: 'vector<[8]xi1>' vs 'vector<[4]xi1>'}} + %0 = arm_sme.fmopa_2way %vecA, %vecB masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + return %0 : vector<[4]x[4]xf32> +} + +// ----- + +func.func @arm_sme_fmopa_2way__bad_rhs_mask_type(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %maskA : vector<[8]xi1>, %maskB : vector<[4]xi1>) -> vector<[4]x[4]xf32> +{ + // expected-note@-2 {{prior use here}} + // expected-error@+1 {{use of value '%maskB' expects different type than prior uses: 'vector<[8]xi1>' vs 'vector<[4]xi1>'}} + %0 = arm_sme.fmopa_2way %vecA, %vecB masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + return %0 : vector<[4]x[4]xf32> +} + +// ----- + +func.func @arm_sme_fmopa_2way__no_rhs_mask(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %maskA : vector<[8]xi1>) -> vector<[4]x[4]xf32> +{ + // expected-error@+1 {{op failed to verify that both `lhsMask` and `rhsMask` should be provided or neither}} + %0 = arm_sme.fmopa_2way %vecA, %vecB masks(%maskA,) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + return %0 : vector<[4]x[4]xf32> +} + +// ----- + +func.func @arm_sme_fmopa_2way__bad_acc_type(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> +{ + %acc = arm_sme.zero : vector<[2]x[2]xi64> + // expected-note@-1 {{prior use here}} + // expected-error@+1 {{use of value '%acc' expects different type than prior uses: 'vector<[4]x[4]xf32>' vs 'vector<[2]x[2]xi64>'}} + %0 = arm_sme.fmopa_2way %vecA, %vecB acc(%acc) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + return %0 : vector<[4]x[4]xf32> +} diff --git a/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir new file mode
100644 index 0000000000000..6a200848c2d8e --- /dev/null +++ b/mlir/test/Dialect/ArmSME/outer-product-fusion.mlir @@ -0,0 +1,364 @@ +// RUN: mlir-opt %s -arm-sme-outer-product-fusion -cse -split-input-file -allow-unregistered-dialect | FileCheck %s + +// CHECK-LABEL: @outerproduct_add_widening_2way_f16f16f32 +// CHECK-SAME: %[[A0:.*]]: vector<[4]xf16>, %[[B0:.*]]: vector<[4]xf16>, %[[A1:.*]]: vector<[4]xf16>, %[[B1:.*]]: vector<[4]xf16>, +// CHECK-SAME: %[[A0_MASK:.*]]: vector<[4]xi1>, %[[B0_MASK:.*]]: vector<[4]xi1>, %[[A1_MASK:.*]]: vector<[4]xi1>, %[[B1_MASK:.*]]: vector<[4]xi1> +// CHECK-DAG: %[[ACC:.*]] = arith.constant dense<0.000000e+00> : vector<[4]x[4]xf32> +// CHECK-DAG: %[[LHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0]], %[[A1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16> +// CHECK-DAG: %[[RHS:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0]], %[[B1]]) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16> +// CHECK-DAG: %[[LHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[A0_MASK]], %[[A1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1> +// CHECK-DAG: %[[RHS_MASK:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[B0_MASK]], %[[B1_MASK]]) : (vector<[4]xi1>, vector<[4]xi1>) -> vector<[8]xi1> +// CHECK-DAG: arm_sme.fmopa_2way %[[LHS]], %[[RHS]] acc(%[[ACC]]) masks(%[[LHS_MASK]], %[[RHS_MASK]]) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> +func.func @outerproduct_add_widening_2way_f16f16f32( + %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>, + %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>, + %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, + %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> { + %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> + %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> + %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> + %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> + + 
%acc = arith.constant dense<0.0> : vector<[4]x[4]xf32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32> + %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32> + + return %1 : vector<[4]x[4]xf32> +} + +// ----- + +/// Verify chain of 4 outer products are fused into 2 2-way widening outer +/// products. + +// CHECK-LABEL: @outerproduct_x2_add_widening_2way_f16f16f32 +// CHECK-COUNT-2: arm_sme.fmopa_2way +func.func @outerproduct_x2_add_widening_2way_f16f16f32( + %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>, + %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>, + %a2 : vector<[4]xf16>, %b2 : vector<[4]xf16>, + %a3 : vector<[4]xf16>, %b3 : vector<[4]xf16>) -> vector<[4]x[4]xf32> { + %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> + %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> + + %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> + %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> + + %a2_ext = arith.extf %a2 : vector<[4]xf16> to vector<[4]xf32> + %b2_ext = arith.extf %b2 : vector<[4]xf16> to vector<[4]xf32> + + %a3_ext = arith.extf %a3 : vector<[4]xf16> to vector<[4]xf32> + %b3_ext = arith.extf %b3 : vector<[4]xf16> to vector<[4]xf32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32> + %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32> + %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xf32>, vector<[4]xf32> + %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xf32>, vector<[4]xf32> + + return %3 : vector<[4]x[4]xf32> +} + +// ----- + +// CHECK-LABEL: @outerproduct_sub_widening_2way_f16f16f32 +// CHECK: arm_sme.fmops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> +func.func @outerproduct_sub_widening_2way_f16f16f32( + %a0 : 
vector<[4]xf16>, %b0 : vector<[4]xf16>, + %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>, + %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, + %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> { + %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> + %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> + %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> + %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> + + %acc = arith.constant dense<0.0> : vector<[4]x[4]xf32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32> + %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32> + + return %1 : vector<[4]x[4]xf32> +} + +// ----- + +// CHECK-LABEL: @outerproduct_add_widening_2way_bf16bf16f32 +// CHECK: arm_sme.fmopa_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32> +func.func @outerproduct_add_widening_2way_bf16bf16f32( + %a0 : vector<[4]xbf16>, %b0 : vector<[4]xbf16>, + %a1 : vector<[4]xbf16>, %b1 : vector<[4]xbf16>, + %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, + %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> { + %a0_ext = arith.extf %a0 : vector<[4]xbf16> to vector<[4]xf32> + %b0_ext = arith.extf %b0 : vector<[4]xbf16> to vector<[4]xf32> + %a1_ext = arith.extf %a1 : vector<[4]xbf16> to vector<[4]xf32> + %b1_ext = arith.extf %b1 : vector<[4]xbf16> to vector<[4]xf32> + + %acc = arith.constant dense<0.0> : vector<[4]x[4]xf32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32> + %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32> + + return %1 : vector<[4]x[4]xf32> +} + +// ----- + +// CHECK-LABEL: @outerproduct_sub_widening_2way_bf16bf16f32
+// CHECK: arm_sme.fmops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32> +func.func @outerproduct_sub_widening_2way_bf16bf16f32( + %a0 : vector<[4]xbf16>, %b0 : vector<[4]xbf16>, + %a1 : vector<[4]xbf16>, %b1 : vector<[4]xbf16>, + %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, + %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> { + %a0_ext = arith.extf %a0 : vector<[4]xbf16> to vector<[4]xf32> + %b0_ext = arith.extf %b0 : vector<[4]xbf16> to vector<[4]xf32> + %a1_ext = arith.extf %a1 : vector<[4]xbf16> to vector<[4]xf32> + %b1_ext = arith.extf %b1 : vector<[4]xbf16> to vector<[4]xf32> + + %acc = arith.constant dense<0.0> : vector<[4]x[4]xf32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xf32>, vector<[4]xf32> + %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32> + + return %1 : vector<[4]x[4]xf32> +} + +// ----- + +// CHECK-LABEL: @outerproduct_add_widening_2way_signed_i16i16i32 +// CHECK: arm_sme.smopa_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> +func.func @outerproduct_add_widening_2way_signed_i16i16i32( + %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>, + %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>, + %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, + %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { + %a0_ext = arith.extsi %a0 : vector<[4]xi16> to vector<[4]xi32> + %b0_ext = arith.extsi %b0 : vector<[4]xi16> to vector<[4]xi32> + %a1_ext = arith.extsi %a1 : vector<[4]xi16> to vector<[4]xi32> + %b1_ext = arith.extsi %b1 : vector<[4]xi16> to vector<[4]xi32> + + %acc = arith.constant dense<0> : vector<[4]x[4]xi32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32> + %1 =
arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32> + + return %1 : vector<[4]x[4]xi32> +} + +// ----- + +// CHECK-LABEL: @outerproduct_sub_widening_2way_signed_i16i16i32 +// CHECK: arm_sme.smops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> +func.func @outerproduct_sub_widening_2way_signed_i16i16i32( + %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>, + %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>, + %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, + %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { + %a0_ext = arith.extsi %a0 : vector<[4]xi16> to vector<[4]xi32> + %b0_ext = arith.extsi %b0 : vector<[4]xi16> to vector<[4]xi32> + %a1_ext = arith.extsi %a1 : vector<[4]xi16> to vector<[4]xi32> + %b1_ext = arith.extsi %b1 : vector<[4]xi16> to vector<[4]xi32> + + %acc = arith.constant dense<0> : vector<[4]x[4]xi32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32> + %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32> + + return %1 : vector<[4]x[4]xi32> +} + +// ----- + +// CHECK-LABEL: @outerproduct_add_widening_2way_unsigned_i16i16i32 +// CHECK: arm_sme.umopa_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> +func.func @outerproduct_add_widening_2way_unsigned_i16i16i32( + %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>, + %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>, + %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, + %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { + %a0_ext = arith.extui %a0 : vector<[4]xi16> to vector<[4]xi32> + %b0_ext = arith.extui %b0 : vector<[4]xi16> to vector<[4]xi32> + %a1_ext = arith.extui %a1 : vector<[4]xi16> to vector<[4]xi32> + %b1_ext =
arith.extui %b1 : vector<[4]xi16> to vector<[4]xi32> + + %acc = arith.constant dense<0> : vector<[4]x[4]xi32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32> + %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32> + + return %1 : vector<[4]x[4]xi32> +} + +// ----- + +// CHECK-LABEL: @outerproduct_sub_widening_2way_unsigned_i16i16i32 +// CHECK: arm_sme.umops_2way %{{.*}}, %{{.*}} acc(%{{.*}}) masks(%{{.*}}, %{{.*}}) : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> +func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32( + %a0 : vector<[4]xi16>, %b0 : vector<[4]xi16>, + %a1 : vector<[4]xi16>, %b1 : vector<[4]xi16>, + %a0_mask : vector<[4]xi1>, %b0_mask : vector<[4]xi1>, + %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xi32> { + %a0_ext = arith.extui %a0 : vector<[4]xi16> to vector<[4]xi32> + %b0_ext = arith.extui %b0 : vector<[4]xi16> to vector<[4]xi32> + %a1_ext = arith.extui %a1 : vector<[4]xi16> to vector<[4]xi32> + %b1_ext = arith.extui %b1 : vector<[4]xi16> to vector<[4]xi32> + + %acc = arith.constant dense<0> : vector<[4]x[4]xi32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<sub> acc(%acc) masks(%a0_mask, %b0_mask) : vector<[4]xi32>, vector<[4]xi32> + %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xi32>, vector<[4]xi32> + + return %1 : vector<[4]x[4]xi32> +} + +/// Negative tests + +// ----- + +// CHECK-LABEL: @outerproduct_widening_2way__no_acc +// CHECK-NOT: arm_sme.fmopa_2way +// CHECK: arm_sme.outerproduct +// CHECK-NOT: arm_sme.fmopa_2way +func.func @outerproduct_widening_2way__no_acc(%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>) -> vector<[4]x[4]xf32> { + %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> + %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext :
vector<[4]xf32>, vector<[4]xf32> + + return %0 : vector<[4]x[4]xf32> +} + +// ----- + +/// Defining op of accumulator operand must be an 'arm_sme.outerproduct'. + +// CHECK-LABEL: @outerproduct_widening_2way__bad_acc +// CHECK-NOT: arm_sme.fmopa_2way +// CHECK: arm_sme.outerproduct +// CHECK-NOT: arm_sme.fmopa_2way +func.func @outerproduct_widening_2way__bad_acc(%a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>, %acc : vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> { + %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> + %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) : vector<[4]xf32>, vector<[4]xf32> + + return %0 : vector<[4]x[4]xf32> +} + +// ----- + +/// Combining kinds of outer products must match to be fused. + +// CHECK-LABEL: @outerproduct_widening_2way__bad_combining_kind +// CHECK-NOT: arm_sme.fmopa_2way +// CHECK: arm_sme.outerproduct +// CHECK: arm_sme.outerproduct +// CHECK-NOT: arm_sme.fmopa_2way +func.func @outerproduct_widening_2way__bad_combining_kind( + %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>, + %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>) -> vector<[4]x[4]xf32> { + %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> + %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> + %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> + %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext kind<add> : vector<[4]xf32>, vector<[4]xf32> + %1 = arm_sme.outerproduct %a1_ext, %b1_ext kind<sub> acc(%0) : vector<[4]xf32>, vector<[4]xf32> + + return %1 : vector<[4]x[4]xf32> +} + +// ----- + +/// If the first outer product has uses other than as the input to another +/// outer product, it can't be erased after fusion.
This is a problem when +/// it also has an accumulator as this will be used as the root for tile +/// allocation and since the widening outer product uses the same +/// accumulator it will get assigned the same tile ID, resulting in 3 +/// outer products and incorrect results. Check this is prevented. + +// CHECK-LABEL: @outerproduct_widening_2way__cant_erase +// CHECK-NOT: arm_sme.fmopa_2way +// CHECK: arm_sme.outerproduct +// CHECK: arm_sme.outerproduct +// CHECK-NOT: arm_sme.fmopa_2way +func.func @outerproduct_widening_2way__cant_erase( + %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>, + %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>) -> vector<[4]x[4]xf32> { + %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> + %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> + %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> + %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> + + %acc = arith.constant dense<1.0> : vector<[4]x[4]xf32> + %0 = arm_sme.outerproduct %a0_ext, %b0_ext acc(%acc) : vector<[4]xf32>, vector<[4]xf32> + "fake.use"(%0) : (vector<[4]x[4]xf32>) -> () + %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32> + + return %1 : vector<[4]x[4]xf32> +} + +// ----- + +// CHECK-LABEL: @outerproduct_widening_2way__unsupported_type_f32f32f64 +// CHECK-NOT: arm_sme.fmopa_2way +// CHECK: arm_sme.outerproduct +// CHECK: arm_sme.outerproduct +// CHECK-NOT: arm_sme.fmopa_2way +func.func @outerproduct_widening_2way__unsupported_type_f32f32f64( + %a0 : vector<[2]xf32>, %b0 : vector<[2]xf32>, + %a1 : vector<[2]xf32>, %b1 : vector<[2]xf32>) -> vector<[2]x[2]xf64> { + %a0_ext = arith.extf %a0 : vector<[2]xf32> to vector<[2]xf64> + %b0_ext = arith.extf %b0 : vector<[2]xf32> to vector<[2]xf64> + %a1_ext = arith.extf %a1 : vector<[2]xf32> to vector<[2]xf64> + %b1_ext = arith.extf %b1 : vector<[2]xf32> to vector<[2]xf64> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[2]xf64>, vector<[2]xf64> + %1 = 
arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[2]xf64>, vector<[2]xf64> + + return %1 : vector<[2]x[2]xf64> +} + +// ----- + +/// Fusion only occurs if either both outer products are masked, or neither. + +// CHECK-LABEL: @outerproduct_widening_2way__bad_masking +// CHECK-NOT: arm_sme.fmopa_2way +// CHECK: arm_sme.outerproduct +// CHECK: arm_sme.outerproduct +// CHECK-NOT: arm_sme.fmopa_2way +func.func @outerproduct_widening_2way__bad_masking( + %a0 : vector<[4]xf16>, %b0 : vector<[4]xf16>, + %a1 : vector<[4]xf16>, %b1 : vector<[4]xf16>, + %a1_mask : vector<[4]xi1>, %b1_mask : vector<[4]xi1>) -> vector<[4]x[4]xf32> { + %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> + %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> + %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> + %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> + + %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32> + %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) masks(%a1_mask, %b1_mask) : vector<[4]xf32>, vector<[4]xf32> + + return %1 : vector<[4]x[4]xf32> +} + +// ----- + +/// Defining op of outer product must be a supported extension op. 
+ +// CHECK-LABEL: @outerproduct_widening_2way__bad_defining_op +// CHECK-NOT: arm_sme.fmopa_2way +// CHECK: arm_sme.outerproduct +// CHECK: arm_sme.outerproduct +// CHECK-NOT: arm_sme.fmopa_2way +func.func @outerproduct_widening_2way__bad_defining_op( + %a0 : vector<[4]xf32>, %b0 : vector<[4]xf32>, + %a1 : vector<[4]xf32>, %b1 : vector<[4]xf32>) -> vector<[4]x[4]xf32> { + %0 = arm_sme.outerproduct %a0, %b0 : vector<[4]xf32>, vector<[4]xf32> + %1 = arm_sme.outerproduct %a1, %b1 acc(%0) : vector<[4]xf32>, vector<[4]xf32> + + return %1 : vector<[4]x[4]xf32> +} diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir index 2ad742493408b..ca096363e7283 100644 --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -1131,3 +1131,115 @@ func.func @arm_sme_streaming_vl_double_words() -> index { %svl_d = arm_sme.streaming_vl return %svl_d : index } + +//===----------------------------------------------------------------------===// +// arm_sme.fmopa_2way +//===----------------------------------------------------------------------===// + +// ----- + +func.func @arm_sme_fmopa_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> { + // CHECK: arm_sme.fmopa_2way {{.*}}, {{.*}} : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + %result = arm_sme.fmopa_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + return %result : vector<[4]x[4]xf32> +} + +// ----- + +func.func @arm_sme_fmopa_2way_bf16bf16_to_f32(%vecA: vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> { + // CHECK: arm_sme.fmopa_2way {{.*}}, {{.*}} : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32> + %result = arm_sme.fmopa_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32> + return %result : vector<[4]x[4]xf32> +} + +// ----- + +func.func @arm_sme_fmopa_2way_with_masking(%vecA: vector<[8]xf16>, %vecB: 
vector<[8]xf16>, %maskA: vector<[8]xi1>, %maskB: vector<[8]xi1>) -> vector<[4]x[4]xf32> { + // CHECK: arm_sme.fmopa_2way {{.*}}, {{.*}} masks({{.*}}, {{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + %result = arm_sme.fmopa_2way %vecA, %vecB masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + return %result : vector<[4]x[4]xf32> +} + +// ----- + +func.func @arm_sme_fmopa_2way_with_acc(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %acc : vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> { + // CHECK: arm_sme.fmopa_2way {{.*}}, {{.*}} acc({{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + %result = arm_sme.fmopa_2way %vecA, %vecB acc(%acc) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + return %result : vector<[4]x[4]xf32> +} + +// ----- + +func.func @arm_sme_fmopa_2way_with_everything(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>, %acc : vector<[4]x[4]xf32>, %maskA: vector<[8]xi1>, %maskB: vector<[8]xi1>) -> vector<[4]x[4]xf32> { + // CHECK: arm_sme.fmopa_2way {{.*}}, {{.*}} acc({{.*}}) masks({{.*}}, {{.*}}) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + %result = arm_sme.fmopa_2way %vecA, %vecB acc(%acc) masks(%maskA, %maskB) : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + return %result : vector<[4]x[4]xf32> +} + +//===----------------------------------------------------------------------===// +// arm_sme.fmops_2way +//===----------------------------------------------------------------------===// + +// ----- + +func.func @arm_sme_fmops_2way_f16f16_to_f32(%vecA: vector<[8]xf16>, %vecB: vector<[8]xf16>) -> vector<[4]x[4]xf32> { + // CHECK: arm_sme.fmops_2way {{.*}}, {{.*}} : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + %result = arm_sme.fmops_2way %vecA, %vecB : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32> + return %result : vector<[4]x[4]xf32> +} + +// ----- + +func.func @arm_sme_fmops_2way_bf16bf16_to_f32(%vecA: 
vector<[8]xbf16>, %vecB: vector<[8]xbf16>) -> vector<[4]x[4]xf32> { + // CHECK: arm_sme.fmops_2way {{.*}}, {{.*}} : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32> + %result = arm_sme.fmops_2way %vecA, %vecB : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32> + return %result : vector<[4]x[4]xf32> +} + +//===----------------------------------------------------------------------===// +// arm_sme.smopa_2way +//===----------------------------------------------------------------------===// + +// ----- + +func.func @arm_sme_smopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> { + // CHECK: arm_sme.smopa_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + %result = arm_sme.smopa_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + return %result : vector<[4]x[4]xi32> +} + +//===----------------------------------------------------------------------===// +// arm_sme.smops_2way +//===----------------------------------------------------------------------===// + +// ----- + +func.func @arm_sme_smops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> { + // CHECK: arm_sme.smops_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + %result = arm_sme.smops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + return %result : vector<[4]x[4]xi32> +} + +//===----------------------------------------------------------------------===// +// arm_sme.umopa_2way +//===----------------------------------------------------------------------===// + +// ----- + +func.func @arm_sme_umopa_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> { + // CHECK: arm_sme.umopa_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + %result = arm_sme.umopa_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into 
vector<[4]x[4]xi32> + return %result : vector<[4]x[4]xi32> +} + +//===----------------------------------------------------------------------===// +// arm_sme.umops_2way +//===----------------------------------------------------------------------===// + +// ----- + +func.func @arm_sme_umops_2way_i16i16_to_i32(%vecA: vector<[8]xi16>, %vecB: vector<[8]xi16>) -> vector<[4]x[4]xi32> { + // CHECK: arm_sme.umops_2way {{.*}}, {{.*}} : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + %result = arm_sme.umops_2way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32> + return %result : vector<[4]x[4]xi32> +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir new file mode 100644 index 0000000000000..f081838300a9a --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir @@ -0,0 +1,114 @@ +// DEFINE: %{entry} = main +// DEFINE: %{fusion_opts} = -arm-sme-outer-product-fusion +// DEFINE: %{compile} = mlir-opt %s \ +// DEFINE: -convert-vector-to-arm-sme -convert-arith-to-arm-sme %{fusion_opts} \ +// DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za only-if-required-by-ops" \ +// DEFINE: -convert-arm-sme-to-scf -allocate-arm-sme-tiles \ +// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ +// DEFINE: -test-lower-to-llvm -o %t +// DEFINE: %{run} = %mcr_aarch64_cmd %t \ +// DEFINE: -march=aarch64 -mattr=+sve,+sme \ +// DEFINE: -e %{entry} -entry-point-result=void \ +// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%mlir_arm_runner_utils,%arm_sme_abi_shlib + +// RUN: %{compile} + +// RUN: %{run} | FileCheck %s + +// Check result is the same when outerproducts are not combined into widening +// variant. 
+ +// REDEFINE: %{fusion_opts} = +// RUN: %{run} | FileCheck %s + +func.func @main() { + %c128 = arith.constant 128 : i32 + func.call @setArmSVLBits(%c128) : (i32) -> () + + func.call @test_outerproduct_f16f16f32() : () -> () + + // TODO: A bug in QEMU causes masked FMOPAs to hang [1]. Should be fixed in + // 8.2.0, this test currently isn't run, once this version is available in CI + // it can be run. The output without check lines in the function are correct + // and have been verified on a version with the fix. + // [1] https://gitlab.com/qemu-project/qemu/-/issues/1985 + //func.call @test_masked_outerproduct_f16f16f32() : () -> () + + return +} + +func.func @test_outerproduct_f16f16f32() { + %undef = llvm.mlir.undef : vector<[4]xf16> + + %a0_data = arith.constant dense<[0., 2., 4., 6.]> : vector<4xf16> + %b0_data = arith.constant dense<[1., 3., 5., 7.]> : vector<4xf16> + %a1_data = arith.constant dense<[8., 10., 12., 14.]> : vector<4xf16> + %b1_data = arith.constant dense<[9., 11., 13., 15.]> : vector<4xf16> + + %a0 = vector.scalable.insert %a0_data, %undef[0] : vector<4xf16> into vector<[4]xf16> + %b0 = vector.scalable.insert %b0_data, %undef[0] : vector<4xf16> into vector<[4]xf16> + %a1 = vector.scalable.insert %a1_data, %undef[0] : vector<4xf16> into vector<[4]xf16> + %b1 = vector.scalable.insert %b1_data, %undef[0] : vector<4xf16> into vector<[4]xf16> + + %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> + %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> + %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> + %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> + + %acc = arith.constant dense<7.0> : vector<[4]x[4]xf32> + %0 = vector.outerproduct %a0_ext, %b0_ext, %acc : vector<[4]xf32>, vector<[4]xf32> + %1 = vector.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>, vector<[4]xf32> + + // CHECK: ( 79, 95, 111, 127 ) + // CHECK-NEXT: ( 99, 123, 147, 171 ) + // CHECK-NEXT: ( 119, 151, 183, 215 ) + // CHECK-NEXT: ( 
139, 179, 219, 259 ) + vector.print %1 : vector<[4]x[4]xf32> + + return +} + +func.func @test_masked_outerproduct_f16f16f32() { + %undef = llvm.mlir.undef : vector<[4]xf16> + + %a0_data = arith.constant dense<[0., 2., 4., 6.]> : vector<4xf16> + %b0_data = arith.constant dense<[1., 3., 5., 7.]> : vector<4xf16> + %a1_data = arith.constant dense<[8., 10., 12., 14.]> : vector<4xf16> + %b1_data = arith.constant dense<[9., 11., 13., 15.]> : vector<4xf16> + + %a0 = vector.scalable.insert %a0_data, %undef[0] : vector<4xf16> into vector<[4]xf16> + %b0 = vector.scalable.insert %b0_data, %undef[0] : vector<4xf16> into vector<[4]xf16> + %a1 = vector.scalable.insert %a1_data, %undef[0] : vector<4xf16> into vector<[4]xf16> + %b1 = vector.scalable.insert %b1_data, %undef[0] : vector<4xf16> into vector<[4]xf16> + + %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32> + %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32> + %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32> + %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32> + + %acc = arith.constant dense<7.0> : vector<[4]x[4]xf32> + + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %mask0 = vector.create_mask %c2, %c3 : vector<[4]x[4]xi1> + %mask1 = vector.create_mask %c3, %c2 : vector<[4]x[4]xi1> + + %0 = vector.mask %mask0 { + vector.outerproduct %a0_ext, %b0_ext, %acc : vector<[4]xf32>, vector<[4]xf32> + } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32> + + %1 = vector.mask %mask1 { + vector.outerproduct %a1_ext, %b1_ext, %0 : vector<[4]xf32>, vector<[4]xf32> + } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32> + + // TODO: CHECK these lines once QEMU is fixed. 
+ // ( 79, 95, 7, 7 ) + // ( 99, 123, 17, 7 ) + // ( 115, 139, 7, 7 ) + // ( 7, 7, 7, 7 ) + vector.print %1 : vector<[4]x[4]xf32> + + return +} + +func.func private @setArmSVLBits(%bits : i32) diff --git a/mlir/test/Target/LLVMIR/arm-sme.mlir b/mlir/test/Target/LLVMIR/arm-sme.mlir index 7a42033dc04bc..aedb6730b06bb 100644 --- a/mlir/test/Target/LLVMIR/arm-sme.mlir +++ b/mlir/test/Target/LLVMIR/arm-sme.mlir @@ -63,6 +63,12 @@ llvm.func @arm_sme_imopa(%nxv8i16 : vector<[8]xi16>, // CHECK: call void @llvm.aarch64.sme.usmopa.wide.nxv16i8 "arm_sme.intr.usmopa.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> () + // CHECK: call void @llvm.aarch64.sme.smopa.za32.nxv8i16 + "arm_sme.intr.smopa.za32"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> : + (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () + // CHECK: call void @llvm.aarch64.sme.umopa.za32.nxv8i16 + "arm_sme.intr.umopa.za32"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> : + (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () llvm.return } @@ -122,6 +128,12 @@ llvm.func @arm_sme_imops(%nxv8i16 : vector<[8]xi16>, // CHECK: call void @llvm.aarch64.sme.usmops.wide.nxv16i8 "arm_sme.intr.usmops.wide"(%nxv16i1, %nxv16i1, %nxv16i8, %nxv16i8) <{tile_id = 0 : i32}> : (vector<[16]xi1>, vector<[16]xi1>, vector<[16]xi8>, vector<[16]xi8>) -> () + // CHECK: call void @llvm.aarch64.sme.smops.za32.nxv8i16 + "arm_sme.intr.smops.za32"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> : + (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () + // CHECK: call void @llvm.aarch64.sme.umops.za32.nxv8i16 + "arm_sme.intr.umops.za32"(%nxv8i1, %nxv8i1, %nxv8i16, %nxv8i16) <{tile_id = 0 : i32}> : + (vector<[8]xi1>, vector<[8]xi1>, vector<[8]xi16>, vector<[8]xi16>) -> () llvm.return }