diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index eab871ab49998..403f811a2569a 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -12,6 +12,7 @@
 #include <memory>
 
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 
 namespace mlir {
 class Pass;
@@ -21,7 +22,8 @@ class RewritePatternSet;
 #include "mlir/Conversion/Passes.h.inc"
 
 /// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
-std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
+std::unique_ptr<Pass>
+createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges = false);
 
 /// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
 void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index d094ee3b36ab9..e6d678dc1b12b 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1285,7 +1285,7 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> {
 // ArmSMEToLLVM
 //===----------------------------------------------------------------------===//
 
-def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
+def ConvertArmSMEToLLVM : InterfacePass<"convert-arm-sme-to-llvm", "FunctionOpInterface"> {
   let summary = "Lower the operations from the ArmSME dialect into the LLVM "
                 "dialect";
   let constructor = "mlir::createConvertArmSMEToLLVMPass()";
@@ -1293,6 +1293,11 @@ def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> {
     "arm_sme::ArmSMEDialect",
     "LLVM::LLVMDialect"
   ];
+  let options = [
+    Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+           "bool", /*default=*/"false",
+           "Dump the live ranges of SME tiles (for debugging)">
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
index c507cea5357a7..dac54712c7f47 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
@@ -24,11 +25,6 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
-namespace mlir::arm_sme {
-static constexpr unsigned kInMemoryTileIdBase = 16;
-#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc"
-} // namespace mlir::arm_sme
-
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/ArmSME/IR/ArmSMEAttrDefs.h.inc"
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
new file mode 100644
index 0000000000000..9153fbb57ea88
--- /dev/null
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h
@@ -0,0 +1,27 @@
+//===- ArmSMEOpInterfaces.h - Arm SME Dialect OpInterfaces ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSME_OPINTERFACES_H +#define MLIR_DIALECT_ARMSME_OPINTERFACES_H + +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +namespace mlir::arm_sme { + +namespace detail { +LogicalResult verifyArmSMETileOpInterface(Operation *); +} + +// The first in-memory SME tile ID. This is set to 16 as that is the first tile +// ID larger than any virtual tile ID supported by the SME ISA. +static constexpr unsigned kInMemoryTileIdBase = 16; + +#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h.inc" +} // namespace mlir::arm_sme + +#endif // MLIR_DIALECT_ARMSME_OPINTERFACES_H diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td index 239c4beab10d2..9178655f010c9 100644 --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td @@ -39,10 +39,10 @@ def ArmSMETileType : I32EnumAttr<"ArmSMETileType", "Arm SME tile type", def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> { let description = [{ - An interface for operations that use or allocate Arm SME tiles. These - operations need to be assigned a tile ID, an i32 attribute, which specifies - which virtual tile within the ZA storage to use. The number of tiles - available depends on the type of the tile. This is summarized below: + An interface for operations that use Arm SME tiles. These operations need to + be assigned a tile ID, an i32 attribute, which specifies which virtual tile + within the ZA storage to use. The number of tiles available depends on the + type of the tile. This is summarized below: | Tile Vector Types | Possible Tile IDs | |-------------------------------------------------------------------------|---------------------| @@ -51,10 +51,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> { | `vector<[4]x[4]xi32>` or `vector<[4]x[4]xf32>` | 0 to 3 (inclusive) | | `vector<[2]x[2]xi64>` or `vector<[2]x[2]xf64>` | 0 to 7 (inclusive) | | `vector<[1]x[1]xi128>` | 0 to 15 (inclusive) | - - Operations that allocate a new tile (such as arm_sme.get_tile), are used as - the roots for tile allocation, with all operations that (transitively) - depend on a root being assigned the same tile ID. }]; let methods = [ InterfaceMethod< @@ -84,20 +80,6 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> { return op->getAttrOfType("tile_id"); }] >, - InterfaceMethod< - [{ - The type of tile this operation allocates. Returns none (std::nullopt) - if this operation does not allocate a tile. - }], - /*returnType=*/"std::optional<::mlir::arm_sme::ArmSMETileType>", - /*methodName=*/"getAllocatedTileType", - /*arguments=*/(ins), - /*methodBody=*/[{}], - /*defaultImpl=*/ [{ - // This operation does not allocate a tile. - return std::nullopt; - }] - >, InterfaceMethod< "Returns the VectorType of the tile used by this operation.", /*returnType=*/"VectorType", @@ -106,30 +88,13 @@ def ArmSMETileOpInterface : OpInterface<"ArmSMETileOpInterface"> { ]; let extraSharedClassDeclaration = [{ - // A helper to create a new operation and propagate this operations tile ID. 
-    template <typename T, typename... Args>
-    T createOpAndForwardTileId(::mlir::RewriterBase& rewriter, ::mlir::Location loc, Args &&...args) {
-      auto op = rewriter.create<T>(loc, std::forward<Args>(args)...);
-      if (auto tileOp = ::llvm::dyn_cast<ArmSMETileOpInterface>(op.getOperation()))
-        tileOp.setTileId($_op.getTileId());
-      return op;
-    }
-
-    // A helper to replace this operation and forward its tile ID (if present).
-    template <typename T, typename... Args>
-    T replaceWithAndForwardTileId(::mlir::RewriterBase& rewriter, Args &&...args) {
-      auto newOp = createOpAndForwardTileId<T>(rewriter, $_op.getLoc(), std::forward<Args>(args)...);
-      rewriter.replaceOp($_op, newOp);
-      return newOp;
-    }
-
     bool isInMemoryTile() {
       auto tileId = getTileId();
       return tileId && tileId.getInt() >= kInMemoryTileIdBase;
     }
   }];
 
-  let verify = [{ return ::mlir::arm_sme::verifyOperationHasValidTileId($_op); }];
+  let verify = [{ return detail::verifyArmSMETileOpInterface($_op); }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -255,30 +220,30 @@ def ArmSME_TypeSizeAttr : EnumAttr<ArmSMEDialect, ArmSME_TypeSize,
 class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
     Op<ArmSMEDialect, mnemonic, traits> {}
 
-def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
-  let summary = "Returns a SME virtual tile";
+def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface, Pure]> {
+  let summary = "Creates an undefined value of SME virtual tile type";
   let description = [{
-    Allocates a new SME "virtual tile" within a function. The contents of the
-    tile returned from this operation are undefined.
+    Creates a new SME "virtual tile" value within a function. The contents of
+    the tile returned from this operation are undefined.
 
     Example 1:
 
     ```mlir
-    // Allocate an 8-bit element "virtual tile"
+    // Create an 8-bit element "virtual tile" value:
     %za0_b = arm_sme.get_tile: vector<[16]x[16]xi8>
     ```
 
    Example 2:
 
    ```mlir
-    // Allocate two 16-bit element "virtual tiles"
+    // Create two 16-bit element "virtual tile" values:
    %za0_h = arm_sme.get_tile : vector<[8]x[8]xi16>
    %za1_h = arm_sme.get_tile : vector<[8]x[8]xi16>
    ```
 
    Example 3:
 
    ```mlir
-    // Allocate a 128-bit element "virtual tile"
+    // Create a 128-bit element "virtual tile" value:
    %za0_q = arm_sme.get_tile : vector<[1]x[1]xi128>
    ```
  }];
@@ -290,37 +255,15 @@ def GetTileOp : ArmSME_Op<"get_tile", [ArmSMETileOpInterface]> {
     VectorType getTileType() {
       return ::llvm::cast<VectorType>(getTile().getType());
     }
-
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      return arm_sme::getSMETileType(getTileType());
-    }
-  }];
-}
-
-def MaterializeSSATileOp : ArmSME_Op<"materialize_ssa_tile", [Pure]> {
-  let summary = "SME tile placeholder";
-  let description = [{
-    A placeholder to preserve dataflow while lowering to SME intrinsics (which
-    do not take or return SME virtual tile values). This operation is intended
-    to be DCE'd once all ArmSME operations have been lowered.
-
-    This operation is not intended to be used outside of the ArmSME -> LLVM
-    conversion.
   }];
-  let results = (outs SMETile:$tile);
-  let assemblyFormat = "attr-dict `:` type($tile)";
 }
-
-//
-// Tile reset.
-//
-
-def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> {
-  let summary = "Initialize the two-dimensional ZA array with 0s";
+def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface, Pure]> {
+  let summary = "Creates a zero-initialized value of SME virtual tile type";
   let results = (outs SMETile:$res);
   let description = [{
-    Initialise ZA with 0. This operation is a convenient wrapper for the SME
-    `zero` intrinsic and instruction.
+    Creates a new SME "virtual tile" value within a function. The contents of
+    the tile returned from this operation are zero-initialized.
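+
+    When lowered (see the `arm_sme.zero` lowering in ArmSMEToLLVM), the
+    zero-initialization is performed by an `arm_sme.intr.zero` intrinsic
+    targeting whichever tile is assigned during tile allocation.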
Example 1: Zero an 8-bit element ZA tile. @@ -338,9 +281,6 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> { VectorType getVectorType() { return ::llvm::cast(getRes().getType()); } - std::optional getAllocatedTileType() { - return arm_sme::getSMETileType(getVectorType()); - } VectorType getTileType() { return getVectorType(); } @@ -348,6 +288,32 @@ def ZeroOp : ArmSME_Op<"zero", [ArmSMETileOpInterface]> { let assemblyFormat = "attr-dict `:` type($res)"; } +def CopyTileOp : ArmSME_Op<"copy_tile", [ + Pure, + ArmSMETileOpInterface, + AllTypesMatch<["tile", "result"]> +]> { + let summary = "Copies an SME tile value"; + let arguments = (ins SMETile:$tile); + let results = (outs SMETile:$result); + let description = [{ + Copies an SME "virtual tile" value to a new SSA value. This operation is + primarily intended to be used to normalize the IR prior to tile allocation. + + Example: + + ```mlir + %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32> + ``` + }]; + let extraClassDeclaration = [{ + VectorType getTileType() { + return ::llvm::cast(getResult().getType()); + } + }]; + let assemblyFormat = "$tile attr-dict `:` type($result)"; +} + def TileLoadOp : ArmSME_Op<"tile_load", [ ArmSMETileOpInterface, AttrSizedOperandSegments, @@ -417,9 +383,6 @@ def TileLoadOp : ArmSME_Op<"tile_load", [ VectorType getVectorType() { return ::llvm::cast(getResult().getType()); } - std::optional getAllocatedTileType() { - return arm_sme::getSMETileType(getVectorType()); - } VectorType getTileType() { return getVectorType(); } @@ -545,7 +508,7 @@ def LoadTileSliceOp : ArmSME_Op<"load_tile_slice", [ ``` }]; let arguments = (ins - Arg:$base, SVEPredicate:$mask, + Arg:$base, SVEPredicate:$mask, SMETile:$tile, Variadic:$indices, Index:$tile_slice_index, ArmSME_TileSliceLayoutAttr:$layout ); @@ -630,7 +593,7 @@ def StoreTileSliceOp : ArmSME_Op<"store_tile_slice", [ } def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [ - ArmSMETileOpInterface, + ArmSMETileOpInterface, Pure, AllTypesMatch<["tile", "result"]>, TypesMatchWith< "type of 'vector' matches type of 'tile' slice", @@ -679,7 +642,7 @@ def MoveVectorToTileSliceOp : ArmSME_Op<"move_vector_to_tile_slice", [ } def MoveTileSliceToVectorOp : ArmSME_Op<"move_tile_slice_to_vector", [ - ArmSMETileOpInterface, + ArmSMETileOpInterface, Pure, TypesMatchWith< "type of 'result' matches type of 'tile' slice", "tile", "result", @@ -736,6 +699,7 @@ class OuterProductResultTileTypeConstraint : def OuterProductOp : ArmSME_Op<"outerproduct", [ + Pure, ArmSMETileOpInterface, AttrSizedOperandSegments, AllTypesMatch<["lhs", "rhs"]>, @@ -802,12 +766,6 @@ let arguments = (ins VectorType getLhsType() { return llvm::cast(getLhs().getType()); } VectorType getRhsType() { return llvm::cast(getRhs().getType()); } VectorType getResultType() { return llvm::cast(getResult().getType()); } - std::optional getAllocatedTileType() { - // The outerproduct op allocates a new tile if no accumulator is passed. 
-    if (!getAcc())
-      return arm_sme::getSMETileType(getResultType());
-    return std::nullopt;
-  }
     VectorType getTileType() {
       return getResultType();
     }
@@ -819,6 +777,7 @@ class OuterProductWideningBase<string mnemonic,
                                list<Type> allowedInputVectorTypes,
                                list<Type> allowedResultVectorTypes,
                                int numOuterProducts> :
     ArmSME_Op<mnemonic, [
+      Pure,
@@ -857,12 +816,6 @@ class OuterProductWideningBase<string mnemonic,
     VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
     VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
     VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
-    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
-      // The outerproduct op allocates a new tile if no accumulator is passed.
-      if (!getAcc())
-        return arm_sme::getSMETileType(getResultType());
-      return std::nullopt;
-    }
     VectorType getTileType() {
       return getResultType();
     }
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index c2f1b1f1b874e..156744ba57e7b 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -29,9 +29,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
     const ArmStreamingMode = ArmStreamingMode::Streaming,
     const ArmZaMode = ArmZaMode::Disabled, bool onlyIfRequiredByOps = false);
 
-/// Pass that allocates tile IDs to ArmSME operations.
-std::unique_ptr<Pass> createTileAllocationPass();
-
 /// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
 /// variants.
 std::unique_ptr<Pass> createOuterProductFusionPass();
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
index 7959d291e8926..869a031d6cae8 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -124,17 +124,25 @@ def EnableArmStreaming
   let dependentDialects = ["func::FuncDialect"];
 }
 
-def TileAllocation
-    : Pass<"allocate-arm-sme-tiles", "mlir::func::FuncOp"> {
-  let summary = "Allocate SME tiles";
+def TestTileAllocation
+    : Pass<"test-arm-sme-tile-allocation", "mlir::func::FuncOp"> {
+  let summary = "Tests SME 'virtual tile' allocation";
   let description = [{
     This pass does tile allocation for SME "virtual tiles". It is run at the
     'func.func' op level, and assigns tile IDs (via an attribute) to all ops
-    that implement the `ArmSMETileOpInterface`. An error will be emitted when
-    there's no tiles left.
+    that implement the `ArmSMETileOpInterface`. Note: This pass is only
+    intended to be used for testing; tile allocation is done as part of the
+    ArmSME to LLVM conversion (`convert-arm-sme-to-llvm`).
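+
+    As a rough illustration, after this pass an op such as:
+
+    ```mlir
+    %tile = arm_sme.zero : vector<[4]x[4]xf32>
+    ```
+
+    is annotated with its assigned tile (the ID shown is illustrative):
+
+    ```mlir
+    %tile = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+    ```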
  }];
-  let constructor = "mlir::arm_sme::createTileAllocationPass()";
-  let dependentDialects = ["func::FuncDialect"];
+  let options = [
+    Option<"dumpTileLiveRanges", "dump-tile-live-ranges",
+           "bool", /*default=*/"false",
+           "Dump the live ranges of SME tiles (for debugging)">,
+    Option<"preprocessOnly", "preprocess-only", "bool", /*default=*/"false",
+           "Only preprocess IR so it is ready for tile allocation "
+           "(but do not allocate any tiles)">
+  ];
+  let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect"];
 }
 
 def OuterProductFusion
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
index e00c7503e6999..a25b844f01eaa 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -9,6 +9,8 @@
 #ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H
 #define MLIR_DIALECT_ARMSME_TRANSFORMS_H
 
+#include "mlir/Interfaces/FunctionInterfaces.h"
+
 namespace mlir {
 
 class LLVMConversionTarget;
@@ -16,7 +18,14 @@ class LLVMTypeConverter;
 class RewritePatternSet;
 
 namespace arm_sme {
+
 void populateOuterProductFusionPatterns(RewritePatternSet &patterns);
+
+/// Allocate tile IDs to all ArmSME operations in a function. Requires the
+/// function to be lowered to control flow (cf dialect).
+LogicalResult allocateSMETiles(FunctionOpInterface function,
+                               bool dumpRanges = false);
+
 } // namespace arm_sme
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 027ad8954f92f..1f40eb6fc693c 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -16,8 +16,10 @@
 #define MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
 
 #include "mlir/Dialect/ArmSME/IR/ArmSMEEnums.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSMEOpInterfaces.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Interfaces/FunctionInterfaces.h"
 #include <optional>
 
 namespace mlir {
@@ -42,6 +44,11 @@ bool isValidSMETileElementType(Type type);
 /// otherwise.
 bool isValidSMETileVectorType(VectorType vType);
 
+inline bool isValidSMETileVectorType(Type type) {
+  auto vType = dyn_cast<VectorType>(type);
+  return vType && isValidSMETileVectorType(vType);
+}
+
 /// Returns the type of SME tile this vector type corresponds to, or none if
 /// the vector type does not fit within an SME tile.
 std::optional<ArmSMETileType> getSMETileType(VectorType);
@@ -63,6 +70,31 @@ bool isMultipleOfSMETileVectorType(VectorType vType);
 /// Creates a vector type for the SME tile of `elementType`.
 VectorType getSMETileTypeForElement(Type elementType);
 
+/// Erase trivially dead tile ops from a function.
+void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
+                               FunctionOpInterface function);
+
+/// Returns true if `tileOp` is trivially cloneable. A tile operation is
+/// trivially cloneable if:
+///
+/// 1. It has no operands (and only a single tile result)
+/// 2. It is 'Pure'
+///
+/// This ensures that the cloned operation will not share any dependencies with
+/// the original operation (which could also need to be considered), and that
+/// inserting the cloned operation at a different point in the program won't
+/// change the semantics of the program (as it has no side effects).
+bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp);
+
+/// Returns true if `tileOp` produces a tile result.
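+/// (Not all tile ops do, e.g. `arm_sme.tile_store` only consumes a tile.)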
+bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp);
+
+/// Returns the tile `OpOperand` for this `tileOp` (or null).
+OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp);
+
+/// Returns true if `typeA` is greater than or equal to `typeB` (in terms of
+/// bytes).
+bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB);
+
 } // namespace mlir::arm_sme
 
 #endif // MLIR_DIALECT_ARMSME_UTILS_UTILS_H_
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 1ba1b88fc1234..3dbc8e9916df6 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
@@ -245,6 +246,10 @@ struct ConvertArmSMESpillsAndFillsToLLVM : public ConvertToLLVMPattern {
     if (!tileOp.isInMemoryTile())
       return failure();
 
+    tileOp->emitWarning(
+        "failed to allocate SME virtual tile to operation, all tile "
+        "operations will go through memory, expect degraded performance");
+
     // Step 1. Create an alloca for the tile at the top of the function (if one
     // does not already exist).
     auto loc = tileOp.getLoc();
@@ -391,20 +396,6 @@ addArmSMEConversionPatterns(RewritePatternSet &patterns,
   (addArmSMEConversionPattern<Patterns>(patterns, typeConverter), ...);
 }
 
-struct GetTileConversion
-    : public ConvertArmSMEOpToLLVMPattern<arm_sme::GetTileOp> {
-  using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(arm_sme::GetTileOp getTile, OpAdaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
-        getTile, getTile.getTileType());
-    return success();
-  }
-};
-
 /// Lower 'arm_sme.zero' to SME intrinsics.
 ///
 /// BEFORE:
 /// ```mlir
 /// %v = arm_sme.zero : vector<[4]x[4]xi32>
 /// ```
 ///
 /// AFTER:
 /// ```mlir
 /// "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
-/// %v = arm_sme.materialize_ssa_tile : vector<[4]x[4]xi32>
+/// %v = arm_sme.get_tile : vector<[4]x[4]xi32>
 /// ```
 ///
-/// The 'arm_sme.materialize_ssa_tile' (which models the return) will fold away
-/// once all ArmSME ops have been converted to LLVM intrinsics.
+/// The 'arm_sme.get_tile' (which models the return) will fold away once all
+/// ArmSME ops have been converted to LLVM intrinsics.
 struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
   using ConvertArmSMEOpToLLVMPattern::ConvertArmSMEOpToLLVMPattern;
 
@@ -436,7 +427,8 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
     // The base mask is just the mask to zero the first tile (of a size).
     // These masks are derived from:
    // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles-
-    arm_sme::ArmSMETileType tileType = *zero.getAllocatedTileType();
+    arm_sme::ArmSMETileType tileType =
+        *arm_sme::getSMETileType(zero.getTileType());
     auto baseMaskForSize = [&] {
       switch (tileType) {
       case arm_sme::ArmSMETileType::ZAB:
@@ -488,8 +480,7 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
         loc, rewriter.getI32IntegerAttr(zeroMask));
 
     // Create a placeholder op to preserve dataflow.
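    // (arm_sme.get_tile remains legal during this conversion, so the
    // placeholder survives until all of its uses are lowered, after which it
    // is trivially dead and removed.)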
-    rewriter.replaceOpWithNewOp<arm_sme::MaterializeSSATileOp>(
-        zero, zero.getVectorType());
+    rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
 
     return success();
   }
@@ -746,10 +737,12 @@ struct OuterProductOpConversion
     auto loc = outerProductOp.getLoc();
 
     Value acc = outerProductOp.getAcc();
-    if (!acc)
+    if (!acc) {
       // Initialize accumulator with zero.
-      acc = outerProductOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
-          rewriter, loc, resultVectorType);
+      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, resultVectorType);
+      zero.setTileId(tileId);
+      acc = zero;
+    }
 
     Value lhsMask = outerProductOp.getLhsMask();
     Value rhsMask = outerProductOp.getRhsMask();
@@ -791,25 +784,27 @@ struct OuterProductWideningOpConversion
     if (!tileId)
       return failure();
 
+    auto loc = op.getLoc();
     Value acc = op.getAcc();
-    if (!acc)
+    if (!acc) {
       // Initialize accumulator with zero.
-      acc = op.template createOpAndForwardTileId<arm_sme::ZeroOp>(
-          rewriter, op.getLoc(), op.getResultType());
+      auto zero = rewriter.create<arm_sme::ZeroOp>(loc, op.getResultType());
+      zero.setTileId(tileId);
+      acc = zero;
+    }
 
     Value lhsMask = op.getLhsMask();
     Value rhsMask = op.getRhsMask();
     if (!lhsMask || !rhsMask) {
       auto predTy = op.getLhsType().cloneWith({}, rewriter.getI1Type());
       Value allActiveMask = rewriter.create<arith::ConstantOp>(
-          op.getLoc(), DenseElementsAttr::get(predTy, true));
+          loc, DenseElementsAttr::get(predTy, true));
       lhsMask = allActiveMask;
       rhsMask = allActiveMask;
     }
 
-    rewriter.create<OuterProductWideningIntrOp>(op.getLoc(), tileId, lhsMask,
-                                                rhsMask, adaptor.getLhs(),
-                                                adaptor.getRhs());
+    rewriter.create<OuterProductWideningIntrOp>(
+        loc, tileId, lhsMask, rhsMask, adaptor.getLhs(), adaptor.getRhs());
 
     // The outerproduct intrinsics have no result, replace
     // 'arm_sme.outerproduct' with the input tile to preserve dataflow.
@@ -865,15 +860,22 @@ namespace {
 struct ConvertArmSMEToLLVMPass
     : public impl::ConvertArmSMEToLLVMBase<ConvertArmSMEToLLVMPass> {
+  ConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) {
+    this->dumpTileLiveRanges = dumpTileLiveRanges;
+  }
   void runOnOperation() override {
+    auto function = getOperation();
+
+    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
+      return signalPassFailure();
+
     LLVMConversionTarget target(getContext());
     RewritePatternSet patterns(&getContext());
     LLVMTypeConverter converter(&getContext());
     configureArmSMEToLLVMConversionLegality(target);
     populateArmSMEToLLVMConversionPatterns(converter, patterns);
 
-    if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+    if (failed(applyPartialConversion(function, target, std::move(patterns))))
      signalPassFailure();
   }
 };
@@ -883,34 +885,38 @@
 void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
   target.addIllegalDialect<arm_sme::ArmSMEDialect>();
   target.addLegalOp<
-      arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero,
-      arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz,
-      arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz,
-      arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz,
-      arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz,
-      arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz,
-      arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert,
-      arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert,
-      arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert,
-      arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert,
-      arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert,
-      arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz,
-      arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz,
-      arm_sme::aarch64_sme_write_vert,
arm_sme::aarch64_sme_mopa, - arm_sme::aarch64_sme_mopa_wide, arm_sme::aarch64_sme_mops_wide, - arm_sme::aarch64_sme_smopa_wide, arm_sme::aarch64_sme_smops_wide, - arm_sme::aarch64_sme_umopa_wide, arm_sme::aarch64_sme_umops_wide, - arm_sme::aarch64_sme_smopa_za32, arm_sme::aarch64_sme_smops_za32, - arm_sme::aarch64_sme_umopa_za32, arm_sme::aarch64_sme_umops_za32, - arm_sme::aarch64_sme_sumopa_wide, arm_sme::aarch64_sme_sumops_wide, - arm_sme::aarch64_sme_usmopa_wide, arm_sme::aarch64_sme_usmops_wide, - arm_sme::aarch64_sme_cntsb, arm_sme::aarch64_sme_cntsh, - arm_sme::aarch64_sme_cntsw, arm_sme::aarch64_sme_cntsd>(); + arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str, + arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz, + arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz, + arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz, + arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz, + arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz, + arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert, + arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert, + arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert, + arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert, + arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert, + arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert, + arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert, + arm_sme::aarch64_sme_mopa, arm_sme::aarch64_sme_mopa_wide, + arm_sme::aarch64_sme_mops_wide, arm_sme::aarch64_sme_smopa_wide, + arm_sme::aarch64_sme_smops_wide, arm_sme::aarch64_sme_umopa_wide, + arm_sme::aarch64_sme_umops_wide, arm_sme::aarch64_sme_smopa_za32, + arm_sme::aarch64_sme_smops_za32, arm_sme::aarch64_sme_umopa_za32, + arm_sme::aarch64_sme_umops_za32, arm_sme::aarch64_sme_sumopa_wide, + arm_sme::aarch64_sme_sumops_wide, arm_sme::aarch64_sme_usmopa_wide, + arm_sme::aarch64_sme_usmops_wide, arm_sme::aarch64_sme_cntsb, + arm_sme::aarch64_sme_cntsh, arm_sme::aarch64_sme_cntsw, + arm_sme::aarch64_sme_cntsd>(); target.addLegalDialect(); - target.addLegalOp(); + // Pseudo operations. These cannot be code-generated but may exist in the + // input IR, or be generated during the conversion. They need to be eliminated + // before the final conversion to LLVM IR (and likely will be due to DCE). + target.addLegalOp(); } void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, @@ -955,9 +961,10 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter, arm_sme::aarch64_sme_usmopa_wide>, OuterProductWideningOpConversion, - ZeroOpConversion, GetTileConversion>(patterns, converter); + ZeroOpConversion>(patterns, converter); } -std::unique_ptr mlir::createConvertArmSMEToLLVMPass() { - return std::make_unique(); +std::unique_ptr +mlir::createConvertArmSMEToLLVMPass(bool dumpTileLiveRanges) { + return std::make_unique(dumpTileLiveRanges); } diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index 16b61c282749c..9f55932c33af6 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -196,12 +196,9 @@ struct TileLoadOpConversion : public OpRewritePattern { // Initialize tile with zero to satisfy padding. Inactive cols will be // zeroed anyway since the loads use zeroing predication. For inactive // rows however, no load will occur so these need to be zeroed. 
-      initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::ZeroOp>(
-          rewriter, loc, tileType);
+      initTile = rewriter.create<arm_sme::ZeroOp>(loc, tileType);
     } else {
-      // Allocate a new SME tile.
-      initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
-          rewriter, loc, tileType);
+      initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
     }
 
     // Create a loop to load the active tile slices from memory.
@@ -212,10 +209,9 @@ struct TileLoadOpConversion : public OpRewritePattern<arm_sme::TileLoadOp> {
                    Value currentTile) -> Value {
           // Create 'arm_sme.load_tile_slice' to load tile slice from memory
          // into tile.
-          return tileLoadOp.createOpAndForwardTileId<arm_sme::LoadTileSliceOp>(
-              rewriter, loc, tileType, tileLoadOp.getBase(), predicate,
-              currentTile, memrefIndices, tileSliceIndex,
-              tileLoadOp.getLayout());
+          return rewriter.create<arm_sme::LoadTileSliceOp>(
+              loc, tileType, tileLoadOp.getBase(), predicate, currentTile,
+              memrefIndices, tileSliceIndex, tileLoadOp.getLayout());
        });
 
     if (failed(forOp))
@@ -292,9 +288,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
     auto numColsI32 = rewriter.create<arith::IndexCastUIOp>(
         loc, rewriter.getI32Type(), numCols);
 
-    // Allocate a new SME tile.
-    auto initTile = tileLoadOp.createOpAndForwardTileId<arm_sme::GetTileOp>(
-        rewriter, loc, tileType);
+    auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
 
     // Create a loop that loads each ZA tile slice from memory.
     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
@@ -339,10 +333,9 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion
           /*passthru=*/pad1DOp);
 
       // Create 'arm_sme.move_vector_to_tile_slice' to move slice into tile.
-      auto moveSlice =
-          tileLoadOp.createOpAndForwardTileId<arm_sme::MoveVectorToTileSliceOp>(
-              rewriter, loc, tileType, loadSlice->getResult(0), currentTile,
-              tileSliceIndex, tileLoadOp.getLayout());
+      auto moveSlice = rewriter.create<arm_sme::MoveVectorToTileSliceOp>(
+          loc, tileType, loadSlice->getResult(0), currentTile, tileSliceIndex,
+          tileLoadOp.getLayout());
       rewriter.create<scf::YieldOp>(loc, moveSlice.getResult());
 
     rewriter.setInsertionPointAfter(forOp);
@@ -386,8 +379,8 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
         tileStoreOp.getIndices(), tileStoreOp.getMemRefType().getRank(),
         tileStoreOp.getMask(),
         [&](Value tileSliceIndex, ValueRange memrefIndices, Value predicate) {
-          tileStoreOp.replaceWithAndForwardTileId<arm_sme::StoreTileSliceOp>(
-              rewriter, tileStoreOp.getValueToStore(), tileSliceIndex,
+          rewriter.replaceOpWithNewOp<arm_sme::StoreTileSliceOp>(
+              tileStoreOp, tileStoreOp.getValueToStore(), tileSliceIndex,
              predicate, tileStoreOp.getBase(), memrefIndices,
              tileStoreOp.getLayout());
        });
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
index 29fa9085a0a96..cb3a665844872 100644
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -20,6 +20,12 @@
 using namespace mlir;
 using namespace mlir::arm_sme;
 
+namespace mlir::arm_sme::detail {
+LogicalResult verifyArmSMETileOpInterface(Operation *op) {
+  return verifyOperationHasValidTileId(op);
+}
+} // namespace mlir::arm_sme::detail
+
 //===----------------------------------------------------------------------===//
 // Tablegen Definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
index 6a9e022182226..1f7305a5f8141 100644
--- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp
@@ -116,4 +116,57 @@ VectorType getSMETileTypeForElement(Type elementType) {
   return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
 }
 
+void eraseTriviallyDeadTileOps(IRRewriter &rewriter,
+                               FunctionOpInterface function) {
+  SmallVector<Operation *> worklist;
+  function->walk([&](Operation *op) {
+    auto armSMEOp =
+        dyn_cast<arm_sme::ArmSMETileOpInterface>(op);
+    if (armSMEOp && isOpTriviallyDead(armSMEOp))
+      worklist.push_back(armSMEOp);
+  });
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+    if (!isOpTriviallyDead(op))
+      continue;
+    for (Value value : op->getOperands()) {
+      if (auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>())
+        worklist.push_back(armSMEOp);
+    }
+    rewriter.eraseOp(op);
+  }
+}
+
+bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp) {
+  return tileOp && tileOp->getNumResults() == 1 &&
+         tileOp->getNumOperands() == 0 && isPure(tileOp);
+}
+
+bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) {
+  for (Value result : tileOp->getResults()) {
+    if (arm_sme::isValidSMETileVectorType(result.getType()))
+      return true;
+  }
+  return false;
+}
+
+OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
+  if (!tileOp)
+    return nullptr;
+  auto isTileOperandType = [](OpOperand &operand) {
+    return arm_sme::isValidSMETileVectorType(operand.get().getType());
+  };
+  assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
+         "expected at most one tile operand");
+  OpOperand *tileOperand =
+      llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
+  if (tileOperand == tileOp->getOpOperands().end())
+    return nullptr;
+  return tileOperand;
+}
+
+bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB) {
+  // Note: This is <= due to how tile types are numbered in ArmSMEOps.td.
+  return static_cast<unsigned>(typeA) <= static_cast<unsigned>(typeB);
+}
+
 } // namespace mlir::arm_sme
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index 4acb2a8fb7b53..1e1e0e569124d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -6,12 +6,18 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// This pass allocates SME tiles at the 'func.func' op level for ArmSME
-// operations. It does this using a 16-bit tile mask that has a bit for each
-// 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile granule.
+// This transform allocates SME tiles at the 'func.func' op level for ArmSME
+// operations. It roughly implements a linear scan register allocator, similar
+// to the one outlined in [1], but with simplifications and assumptions made
+// for our use case. Note that this is a greedy allocator (so it may not
+// always find an optimal allocation of tiles).
+//
+// The allocator operates at the CF dialect level. It is the responsibility of
+// users to ensure the IR has been lowered to CF before invoking the tile
+// allocator.
 //
 // The 128-bit tiles overlap with other element tiles as follows (see section
-// B2.3.2 of SME spec [1]):
+// B2.3.2 of SME spec [2]):
 //
 //   Tile    Overlaps
 //   ---------------------------------------------------------------------------
 //   ZA0.B   ZA0.Q-ZA15.Q
 //   ...
 //   ZA6.D   ZA6.Q, ZA14.Q
 //   ZA7.D   ZA7.Q, ZA15.Q
 //
-// The tiles in use are tracked via a function attribute 'arm_sme.tiles_in_use'
-// that is initialized during the first tile allocation within a function and
-// updated on each subsequent allocation.
-// -// [1] https://developer.arm.com/documentation/ddi0616/aa +// [1] "Linear Scan Register Allocation in the Context of SSA Form and Register +// Constraints" (Hanspeter Mössenböck and Michael Pfeiffer) +// https://link.springer.com/content/pdf/10.1007/3-540-45937-5_17.pdf +// [2] https://developer.arm.com/documentation/ddi0616/aa // //===----------------------------------------------------------------------===// +#include "mlir/Analysis/Liveness.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Transforms/Passes.h" +#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" +#include "llvm/ADT/IntervalMap.h" #include "llvm/ADT/TypeSwitch.h" +#include -#define DEBUG_TYPE "allocate-arm-sme-tiles" - -namespace mlir { -namespace arm_sme { -#define GEN_PASS_DEF_TILEALLOCATION +namespace mlir::arm_sme { +#define GEN_PASS_DEF_TESTTILEALLOCATION #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc" -} // namespace arm_sme -} // namespace mlir +} // namespace mlir::arm_sme using namespace mlir; using namespace mlir::arm_sme; namespace { -static constexpr StringLiteral kTilesInUseAttr("arm_sme.tiles_in_use"); -static constexpr StringLiteral - kNextInMemoryTileIdAttr("arm_sme.next_in_memory_tile_id"); - enum class TileMask : unsigned { // clang-format off kZA0B = 0xffff, // 1111 1111 1111 1111 @@ -137,172 +138,640 @@ static ArrayRef getMasks(ArmSMETileType type) { } } -/// Allocates and returns a tile ID. Returns an error if there are no tiles -/// left. -static FailureOr allocateTileId(ArmSMETileType tileType, - TileMask &tilesInUse) { - auto masks = getMasks(tileType); - for (auto [tileId, tileMask] : llvm::enumerate(masks)) { - if ((tilesInUse & tileMask) == TileMask::kNone) { - tilesInUse |= tileMask; - return tileId; +class TileAllocator { +public: + /// Allocates and returns a tile ID. Fails if there are no tiles left. + FailureOr allocateTileId(ArmSMETileType tileType) { + auto masks = getMasks(tileType); + for (auto [tileId, tileMask] : llvm::enumerate(masks)) { + if ((tilesInUse & tileMask) == TileMask::kNone) { + tilesInUse |= tileMask; + return tileId; + } } + return failure(); + } + + /// Releases a previously allocated tile ID. + void releaseTileId(ArmSMETileType tileType, unsigned tileId) { + TileMask tileMask = getMasks(tileType)[tileId]; + assert((tilesInUse & tileMask) != TileMask::kNone && + "cannot release unallocated tile!"); + tilesInUse ^= tileMask; } - return failure(); -} -/// Collects transitive uses of a root value through control flow. This can -/// handle basic SCF constructs, along with control flow (br and cond_br). -/// Simple loops work at the SCF level, while more complex control flow can be -/// dealt with after lowering to CF. This is used to implement basic tile -/// allocation. -static void findDependantOps(Value rootValue, - SetVector &dependantOps) { - auto traverseCorrespondingValues = [&](auto inputValues, auto exitValues) { - for (auto [idx, value] : llvm::enumerate(inputValues)) { - if (value == rootValue) - findDependantOps(exitValues[idx], dependantOps); + /// Allocates an in-memory tile ID. + unsigned allocateInMemoryTileId() { + // Note: We never release in-memory tile IDs. We could, which may allow + // reusing an allocation, but as we _never_ want to spill an SME tile this + // is not optimized. 
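+    // (In-memory tile IDs count up from kInMemoryTileIdBase (16), which is
+    // larger than any valid ZA tile ID, so they can never alias a real tile.)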
+    return nextInMemoryTileId++;
+  }
+
+private:
+  TileMask tilesInUse = TileMask::kNone;
+  unsigned nextInMemoryTileId = kInMemoryTileIdBase;
+};
+
+/// Add new intermediate blocks for the true and false destinations of
+/// `cf.cond_br`s that contain tile operands. This prevents spurious liveness
+/// overlaps due to copies at branches.
+///
+/// BEFORE:
+/// ```mlir
+/// cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// cf.cond_br %cond, ^bb1_copy, ^bb2_copy
+/// ^bb1_copy:
+///   cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+/// ^bb2_copy:
+///   cf.br ^bb2
+/// ```
+void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
+  SmallVector<cf::CondBranchOp> worklist;
+  function.walk([&](cf::CondBranchOp condBranch) {
+    if (llvm::any_of(condBranch->getOperands(), [&](Value value) {
+          return isValidSMETileVectorType(value.getType());
+        })) {
+      worklist.push_back(condBranch);
+    }
+  });
+
+  auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) {
+    rewriter.setInsertionPointToEnd(source);
+    rewriter.create<cf::BranchOp>(loc, dest, args);
+  };
-  for (Operation *user : rootValue.getUsers()) {
-    if (dependantOps.contains(user))
+
+  for (auto condBranch : worklist) {
+    auto loc = condBranch.getLoc();
+    Block *block = condBranch->getBlock();
+    auto newTrueBranch = rewriter.splitBlock(block, block->end());
+    auto newFalseBranch = rewriter.splitBlock(block, block->end());
+    insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
+               condBranch.getTrueDestOperands());
+    insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
+               condBranch.getFalseDestOperands());
+    rewriter.modifyOpInPlace(condBranch, [&] {
+      condBranch.getFalseDestOperandsMutable().clear();
+      condBranch.getTrueDestOperandsMutable().clear();
+      condBranch.setSuccessor(newTrueBranch, 0);
+      condBranch.setSuccessor(newFalseBranch, 1);
+    });
+  }
+}
+
+/// Inserts tile copies at `cf.br` operations.
+///
+/// BEFORE:
+/// ```mlir
+/// cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+/// ```
+///
+/// AFTER:
+/// ```mlir
+/// %copy = arm_sme.copy_tile %tile : vector<[4]x[4]xf32>
+/// cf.br ^bb1(%copy: vector<[4]x[4]xf32>)
+/// ```
+void insertCopiesAtBranches(IRRewriter &rewriter,
+                            FunctionOpInterface function) {
+  for (Block &block : function.getBlocks()) {
+    Operation *terminator = block.getTerminator();
+    if (!isa<cf::BranchOp>(terminator))
       continue;
-    dependantOps.insert(user);
-    TypeSwitch<Operation *>(user)
-        .Case<cf::BranchOp>([&](auto branchOp) {
-          // (CF) Follow branch.
-          traverseCorrespondingValues(branchOp.getDestOperands(),
-                                      branchOp.getDest()->getArguments());
-        })
-        .Case<cf::CondBranchOp>([&](auto condBranchOp) {
-          // (CF) Follow true branch.
-          traverseCorrespondingValues(
-              condBranchOp.getTrueOperands(),
-              condBranchOp.getTrueDest()->getArguments());
-          // (CF) Follow false branch.
-          traverseCorrespondingValues(
-              condBranchOp.getFalseOperands(),
-              condBranchOp.getFalseDest()->getArguments());
-        })
-        .Case<LoopLikeOpInterface>([&](auto loopOp) {
-          // (SCF) Follow iter_args of (basic) loops (e.g. for loops).
-          traverseCorrespondingValues(loopOp.getInits(),
-                                      loopOp.getRegionIterArgs());
-        })
-        .Case<scf::YieldOp>([&](auto yieldOp) {
-          // (SCF) Follow yields of (basic) control flow (e.g. for loops).
-          auto parent = user->getParentOp();
-          traverseCorrespondingValues(user->getOperands(),
-                                      parent->getResults());
+    rewriter.setInsertionPoint(terminator);
+    for (OpOperand &operand : terminator->getOpOperands()) {
+      if (isValidSMETileVectorType(operand.get().getType())) {
+        auto copy =
+            rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get());
+        rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); });
+      }
+    }
+  }
+}
+
+/// Prepares the IR for tile allocation. It does this by first 'splitting'
+/// conditional branches (see `splitCondBranches`), then inserting tile copies
+/// at branch operations. The conditional branches are split to prevent the
+/// copies needed for them overlapping between the true and false paths of the
+/// branch (see `tile-allocation-copies.mlir` and
+/// `tile-allocation-liveness.mlir` for examples). The copies break up live
+/// ranges and ensure that, when moving out of SSA, the semantics of the
+/// program are preserved.
+void preprocessForTileAllocation(IRRewriter &rewriter,
+                                 FunctionOpInterface function) {
+  splitCondBranches(rewriter, function);
+  insertCopiesAtBranches(rewriter, function);
+}
+
+/// A live range for a (collection of) tile values. A live range is built up of
+/// non-overlapping intervals [start, end) which represent parts of the program
+/// where a value in the range needs to be live (i.e. in an SME virtual tile).
+/// Note that as the intervals are non-overlapping all values within a live
+/// range can be allocated to the same SME virtual tile.
+struct LiveRange {
+  using RangeSet = llvm::IntervalMap<unsigned, uint8_t, 16,
+                                     llvm::IntervalMapHalfOpenInfo<unsigned>>;
+  using Allocator = RangeSet::Allocator;
+  // Dummy value for the IntervalMap. Only the keys matter (the intervals).
+  static constexpr uint8_t kValidLiveRange = 0xff;
+
+  LiveRange(Allocator &allocator)
+      : ranges(std::make_unique<RangeSet>(allocator)) {}
+
+  /// Returns true if this range overlaps with `otherRange`.
+  bool overlaps(LiveRange const &otherRange) const {
+    return llvm::IntervalMapOverlaps<RangeSet, RangeSet>(*ranges,
+                                                         *otherRange.ranges)
+        .valid();
+  }
+
+  /// Unions this live range with `otherRange`, aborts if the ranges overlap.
+  void unionWith(LiveRange const &otherRange) {
+    for (auto it = otherRange.ranges->begin(); it != otherRange.ranges->end();
+         ++it)
+      ranges->insert(it.start(), it.stop(), kValidLiveRange);
+    values.set_union(otherRange.values);
+  }
+
+  /// Inserts an interval [start, end) for `value` into this range.
+  void insert(Value value, unsigned start, unsigned end) {
+    values.insert(value);
+    if (start != end)
+      ranges->insert(start, end, kValidLiveRange);
+  }
+
+  bool empty() const { return ranges->empty(); }
+  unsigned start() const { return ranges->start(); }
+  unsigned end() const { return ranges->stop(); }
+  bool operator<(LiveRange const &other) const {
+    return start() < other.start();
+  }
+
+  ArmSMETileType getTileType() const {
+    return *getSMETileType(cast<VectorType>(values[0].getType()));
+  }
+
+  /// The values contained in this live range.
+  SetVector<Value> values;
+
+  /// A set of (non-overlapping) intervals that mark where any value in
+  /// `values` is live.
+  std::unique_ptr<RangeSet> ranges;
+
+  /// The tile ID (or none) assigned to this live range.
+  std::optional<unsigned> tileId;
+};
+
+/// Number operations within a function to allow computing live ranges.
+/// Operations are numbered consecutively within blocks, and the blocks are
+/// topologically sorted (using forward edges). This function is only correct
+/// if all ArmSME operations have been converted to CF (which is asserted).
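+/// Note: Each block is also given its own leading number, so that a block
+/// argument can begin its live range before the block's first operation.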
+DenseMap<Operation *, unsigned>
+generateOperationNumbering(FunctionOpInterface function) {
+  unsigned index = 0;
+  SetVector<Block *> blocks =
+      getTopologicallySortedBlocks(function.getFunctionBody());
+  DenseMap<Operation *, unsigned> operationToIndexMap;
+  for (Block *block : blocks) {
+    index++; // We want block args to have their own number.
+    for (Operation &op : block->getOperations()) {
+#ifndef NDEBUG
+      op.walk([&](ArmSMETileOpInterface nestedOp) {
+        assert(&op == nestedOp.getOperation() &&
+               "ArmSME tile allocation does not support nested regions");
+      });
+#endif
+      operationToIndexMap.try_emplace(&op, index++);
+    }
+  }
+  return operationToIndexMap;
+}
+
+/// Gather live ranges for SME tiles from the MLIR liveness analysis.
+DenseMap<Value, LiveRange>
+gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+                     LiveRange::Allocator &liveRangeAllocator,
+                     Liveness &liveness, FunctionOpInterface function) {
+  assert(!operationToIndexMap.empty() && "expected operation numbering");
+  DenseMap<Value, LiveRange> liveRanges;
+  /// Defines or updates a live range for an SME tile value. Live-ins may
+  /// update an existing live range (rather than define a new one). Note: If
+  /// `liveAtBlockEntry` is true then `firstUseOrDef` is the first operation in
+  /// the block.
+  auto defineOrUpdateValueLiveRange = [&](Value value, Operation *firstUseOrDef,
+                                          LivenessBlockInfo const &livenessInfo,
+                                          bool liveAtBlockEntry = false) {
+    if (!isValidSMETileVectorType(value.getType()))
+      return;
+    // Find or create a live range for `value`.
+    auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
+    LiveRange &valueLiveRange = it->second;
+    auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
+    // Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
+    unsigned startOpIdx =
+        operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
+    unsigned endOpIdx = operationToIndexMap.at(lastUseInBlock);
+    valueLiveRange.insert(value, startOpIdx, endOpIdx);
+  };
+
+  for (Block &block : function.getBlocks()) {
+    LivenessBlockInfo const *livenessInfo = liveness.getLiveness(&block);
+    // Handle block arguments:
+    for (Value argument : block.getArguments())
+      defineOrUpdateValueLiveRange(argument, &block.front(), *livenessInfo,
+                                   /*liveAtBlockEntry=*/true);
+    // Handle live-ins:
+    for (Value liveIn : livenessInfo->in())
+      defineOrUpdateValueLiveRange(liveIn, &block.front(), *livenessInfo,
+                                   /*liveAtBlockEntry=*/true);
+    // Handle new definitions:
+    for (Operation &op : block) {
+      for (Value result : op.getResults())
+        defineOrUpdateValueLiveRange(result, &op, *livenessInfo);
+    }
+  }
+
+  return liveRanges;
+}
+
+/// Iterate over all predecessor tile values to a (tile) block argument.
+static void forEachPredecessorTileValue(BlockArgument blockArg,
+                                        function_ref<void(Value)> callback) {
+  Block *block = blockArg.getOwner();
+  unsigned argNumber = blockArg.getArgNumber();
+  for (Block *pred : block->getPredecessors()) {
+    TypeSwitch<Operation *>(pred->getTerminator())
+        .Case<cf::BranchOp>([&](auto branch) {
+          Value predecessorOperand = branch.getDestOperands()[argNumber];
+          callback(predecessorOperand);
+        })
-        .Default([&](auto) {
-          // Otherwise, assume users of _any_ result are dependant.
- for (Value result : user->getResults()) - findDependantOps(result, dependantOps); + .Case([&](auto condBranch) { + if (condBranch.getFalseDest() == block) { + Value predecessorOperand = + condBranch.getFalseDestOperands()[argNumber]; + callback(predecessorOperand); + } + if (condBranch.getTrueDest() == block) { + Value predecessorOperand = + condBranch.getTrueDestOperands()[argNumber]; + callback(predecessorOperand); + } }); } } -struct AssignTileIDsPattern - : public OpInterfaceRewritePattern { - using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - LogicalResult matchAndRewrite(ArmSMETileOpInterface tileOp, - PatternRewriter &rewriter) const override { - if (tileOp.getTileId()) - return failure(); - - auto func = tileOp->getParentOfType(); - auto getDiscardableIntAttr = [&](StringRef name, unsigned defaultVal = 0) { - if (auto attr = llvm::dyn_cast_or_null( - func->getDiscardableAttr(name))) - return unsigned(attr.getInt()); - return defaultVal; - }; - auto setDiscardableIntAttr = [&](StringRef name, auto value) { - rewriter.modifyOpInPlace(tileOp, [&] { - func->setDiscardableAttr(name, - rewriter.getI32IntegerAttr((unsigned)value)); - }); - }; - std::optional tileType = tileOp.getAllocatedTileType(); - if (!tileType) - return rewriter.notifyMatchFailure(tileOp, "op does not allocate a tile"); - - TileMask tilesInUse = - static_cast(getDiscardableIntAttr(kTilesInUseAttr)); - auto tileId = allocateTileId(*tileType, tilesInUse); - bool tileIsInMemory = failed(tileId); - if (tileIsInMemory) { - // If we could not find a real tile ID, use an in-memory tile ID (ID >= - // 16). A later pass will insert the necessary spills and reloads. - tileId = - getDiscardableIntAttr(kNextInMemoryTileIdAttr, kInMemoryTileIdBase); - tileOp->emitWarning( - "failed to allocate SME virtual tile to operation, all tile " - "operations will go through memory, expect degraded performance"); +/// Coalesce live ranges where it would prevent unnecessary tile moves. +SmallVector +coalesceTileLiveRanges(DenseMap &initialLiveRanges) { + DenseMap liveRanges; + for (auto &[value, liveRange] : initialLiveRanges) { + liveRanges.insert({value, &liveRange}); + } + + // Merge the live ranges of values `a` and `b` into one (if they do not + // overlap). After this, the values `a` and `b` will both point to the same + // live range (which will contain multiple values). + auto mergeValuesIfNonOverlapping = [&](Value a, Value b) { + LiveRange *aLiveRange = liveRanges.at(a); + LiveRange *bLiveRange = liveRanges.at(b); + if (aLiveRange != bLiveRange && !aLiveRange->overlaps(*bLiveRange)) { + aLiveRange->unionWith(*bLiveRange); + for (Value value : bLiveRange->values) + liveRanges[value] = aLiveRange; + } + }; + + // Merge the live ranges of new definitions with their tile operands. + auto unifyDefinitionsWithOperands = [&](Value value) { + auto armSMEOp = value.getDefiningOp(); + if (!armSMEOp) + return; + for (auto operand : armSMEOp->getOperands()) { + if (isValidSMETileVectorType(operand.getType())) + mergeValuesIfNonOverlapping(value, operand); } + }; + + // Merge the live ranges of block arguments with their predecessors. 
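+  // (When a block argument and the values its predecessors pass to it end up
+  // in one live range, they share a tile ID, and the copies inserted at
+  // branches later fold away as redundant.)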
+  auto unifyBlockArgumentsWithPredecessors = [&](Value value) {
+    auto blockArg = dyn_cast<BlockArgument>(value);
+    if (!blockArg)
+      return;
+    forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+      mergeValuesIfNonOverlapping(blockArg, predecessorTile);
+    });
+  };
+
+  auto applyRule = [&](auto rule) {
+    llvm::for_each(llvm::make_first_range(initialLiveRanges), rule);
+  };
+
+  // Unify as many live ranges as we can. This prevents unnecessary moves.
+  applyRule(unifyBlockArgumentsWithPredecessors);
+  applyRule(unifyDefinitionsWithOperands);
+
+  // Remove duplicate live range entries.
+  SetVector<LiveRange *> uniqueLiveRanges;
+  for (auto [_, liveRange] : liveRanges) {
+    if (!liveRange->empty())
+      uniqueLiveRanges.insert(liveRange);
+  }
+
+  // Sort the new live ranges by starting point (ready for tile allocation).
+  auto coalescedLiveRanges = uniqueLiveRanges.takeVector();
+  std::sort(coalescedLiveRanges.begin(), coalescedLiveRanges.end(),
+            [](LiveRange *a, LiveRange *b) { return *a < *b; });
+  return std::move(coalescedLiveRanges);
+}
+
+/// Choose a live range to spill (via some heuristics). This picks either an
+/// active live range from `activeRanges` or the new live range `newRange`.
+LiveRange *chooseSpillUsingHeuristics(ArrayRef<LiveRange *> activeRanges,
+                                      LiveRange *newRange) {
+  // Heuristic: Spill trivially copyable operations (usually free).
+  auto isTrivialSpill = [&](LiveRange *allocatedRange) {
+    return isTileTypeGreaterOrEqual(allocatedRange->getTileType(),
+                                    newRange->getTileType()) &&
+           allocatedRange->values.size() == 1 &&
+           isTriviallyCloneableTileOp(
+               allocatedRange->values[0]
+                   .getDefiningOp<arm_sme::ArmSMETileOpInterface>());
+  };
+  if (isTrivialSpill(newRange))
+    return newRange;
+  auto trivialSpill = llvm::find_if(activeRanges, isTrivialSpill);
+  if (trivialSpill != activeRanges.end())
+    return *trivialSpill;
+
+  // Heuristic: Spill the range that ends last (with a compatible tile type).
+  auto isSmallerTileTypeOrEndsEarlier = [](LiveRange *a, LiveRange *b) {
+    return !isTileTypeGreaterOrEqual(a->getTileType(), b->getTileType()) ||
+           a->end() < b->end();
+  };
+  LiveRange *lastActiveLiveRange = *std::max_element(
+      activeRanges.begin(), activeRanges.end(), isSmallerTileTypeOrEndsEarlier);
+  if (!isSmallerTileTypeOrEndsEarlier(lastActiveLiveRange, newRange))
+    return lastActiveLiveRange;
+  return newRange;
+}
+
+/// Greedily allocate tile IDs to live ranges. Spill using simple heuristics.
+/// Note: This does not attempt to fill holes in active live ranges.
+void allocateTilesToLiveRanges(
+    ArrayRef<LiveRange *> liveRangesSortedByStartPoint) {
+  TileAllocator tileAllocator;
+  SetVector<LiveRange *> activeRanges;
+  for (LiveRange *nextRange : liveRangesSortedByStartPoint) {
+    // Release tile IDs from live ranges that have ended.
+    activeRanges.remove_if([&](LiveRange *activeRange) {
+      if (activeRange->end() <= nextRange->start()) {
+        tileAllocator.releaseTileId(activeRange->getTileType(),
+                                    *activeRange->tileId);
+        return true;
+      }
+      return false;
+    });
-    // Set all operations dependent on `tileOp` to use the same tile ID.
-    // This is a naive tile allocation scheme, but works for common cases.
For - // example, as this only allocates tile IDs to existing ops, it can't solve - // cases like this (%tileA and %tileB come from different root operations): - // - // %tile = scf.if %some_cond -> vector<[4]x[4]xi32> { - // scf.yield %tileA {tile_id = 0} : vector<[4]x[4]xi32> - // } else { - // scf.yield %tileB {tile_id = 1} : vector<[4]x[4]xi32> - // } - // - // This case would require allocating a new tile for the result of the - // scf.if, and moving the contents of %tileA or %tileB to result tile (based - // on the %some_cond). - // Find all the ops that (transitively) depend on this tile. - SetVector dependantOps; - findDependantOps(tileOp->getResult(0), dependantOps); - auto tileIDAttr = rewriter.getI32IntegerAttr(*tileId); - for (auto *op : dependantOps) { - if (auto dependantTileOp = llvm::dyn_cast(op)) { - auto currentTileId = dependantTileOp.getTileId(); - if (currentTileId && unsigned(currentTileId.getInt()) != tileId) - return dependantTileOp.emitOpError( - "already assigned different SME virtual tile!"); + // Allocate a tile ID to `nextRange`. + auto rangeTileType = nextRange->getTileType(); + auto tileId = tileAllocator.allocateTileId(rangeTileType); + if (succeeded(tileId)) { + nextRange->tileId = *tileId; + } else { + LiveRange *rangeToSpill = + chooseSpillUsingHeuristics(activeRanges.getArrayRef(), nextRange); + if (rangeToSpill != nextRange) { + // Spill an active live range (so release its tile ID first). + tileAllocator.releaseTileId(rangeToSpill->getTileType(), + *rangeToSpill->tileId); + activeRanges.remove(rangeToSpill); + // This will always succeed after a spill (of an active live range). + nextRange->tileId = *tileAllocator.allocateTileId(rangeTileType); } + rangeToSpill->tileId = tileAllocator.allocateInMemoryTileId(); + } + + // Insert the live range into the active ranges. + if (nextRange->tileId < kInMemoryTileIdBase) + activeRanges.insert(nextRange); + } +} + +/// Assigns a tile ID to an MLIR value. +void assignTileIdToValue(IRRewriter &rewriter, Value value, + IntegerAttr tileIdAttr) { + if (auto tileOp = value.getDefiningOp()) + rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); }); + for (Operation *user : value.getUsers()) { + if (auto tileOp = dyn_cast(user)) { + // Ensure ArmSME ops that don't produce a value still get a tile ID. + if (!hasTileResult(tileOp)) + rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIdAttr); }); } + } +} + +/// Assign tile IDs back to IR and attempt to resolve trivial tile ID conflicts. +LogicalResult assignTileIdsAndResolveTrivialConflicts( + IRRewriter &rewriter, FunctionOpInterface function, + ArrayRef allocatedLiveRanges) { + for (LiveRange const *liveRange : allocatedLiveRanges) { + auto tileIdAttr = rewriter.getI32IntegerAttr(*liveRange->tileId); + auto isAllocatedToSameTile = [&](Value value) { + if (auto tileOp = value.getDefiningOp(); + tileOp && tileOp.getTileId() == tileIdAttr) + return true; + return liveRange->values.contains(value); + }; + + /// Eliminates copies where the operand has the same tile ID. + auto foldRedundantCopies = [&](Value value) -> LogicalResult { + auto copyOp = value.getDefiningOp(); + if (!copyOp || !isAllocatedToSameTile(copyOp.getTile())) + return failure(); + rewriter.replaceAllUsesWith(copyOp, copyOp.getTile()); + return success(); + }; + + /// Validates each predecessor to a tile block argument has been assigned + /// the same tile ID. 
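+    /// (A mismatch would require a tile-to-tile move at runtime, which the
+    /// allocator does not materialize, so it is reported as an error.)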
+    auto validateBlockArguments = [&](Value value) {
+      auto blockArg = dyn_cast<BlockArgument>(value);
+      if (!blockArg) {
+        // Not a block argument (nothing to validate).
+        return success();
+      }
+      bool tileMismatch = false;
+      forEachPredecessorTileValue(blockArg, [&](Value predecessorTile) {
+        if (tileMismatch)
+          return;
+        if (!isAllocatedToSameTile(predecessorTile)) {
+          blockArg.getOwner()->getParentOp()->emitOpError(
+              "block argument not allocated to the same SME virtual tile as "
+              "predecessors");
+          tileMismatch = true;
+        }
+      });
+      return success(/*isSuccess=*/!tileMismatch);
+    };
-    // Rewrite IR.
-    if (!tileIsInMemory)
-      setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
-    else
-      setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
-    rewriter.modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
-    for (auto *op : dependantOps) {
-      if (auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
+    /// Attempts to resolve (trivial) tile ID conflicts.
+    auto resolveTrivialTileConflicts = [&](Value value) -> LogicalResult {
+      auto tileOp = value.getDefiningOp<ArmSMETileOpInterface>();
+      OpOperand *tileOperand = getTileOpOperand(tileOp);
+      if (!tileOperand || isAllocatedToSameTile(tileOperand->get())) {
+        // Operand already allocated to the correct tile.
+        // No conflict to resolve.
+        return success();
+      }
+      auto operandTileOp =
+          tileOperand->get().getDefiningOp<ArmSMETileOpInterface>();
+      if (!isTriviallyCloneableTileOp(operandTileOp)) {
+        auto error =
+            tileOp.emitOpError("tile operand allocated to different SME "
+                               "virtual tile (move required)");
+        error.attachNote(tileOperand->get().getLoc())
+            << "tile operand is: " << tileOperand->get();
+        return error;
+      }
+      // Cloning prevents a move/spill (though may require recomputation).
+      rewriter.setInsertionPoint(tileOp);
+      auto clonedOp = operandTileOp.clone();
+      rewriter.modifyOpInPlace(clonedOp,
+                               [&] { clonedOp.setTileId(tileOp.getTileId()); });
+      rewriter.insert(clonedOp);
+      if (isa<CopyTileOp>(tileOp)) {
+        rewriter.replaceAllUsesWith(tileOp->getResult(0),
+                                    clonedOp->getResult(0));
+      } else {
         rewriter.modifyOpInPlace(
-            dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); });
+            tileOp, [&] { tileOperand->assign(clonedOp->getResult(0)); });
       }
+      return success();
+    };
+
+    for (Value value : liveRange->values) {
+      // 1. Assign the tile ID to the value.
+      assignTileIdToValue(rewriter, value, tileIdAttr);
+
+      // 2. Attempt to eliminate redundant tile copies.
+      if (succeeded(foldRedundantCopies(value)))
+        continue;
+
+      // 3. Validate tile block arguments.
+      if (failed(validateBlockArguments(value)))
+        return failure();
+
+      // 4. Attempt to resolve (trivial) tile ID conflicts.
+      if (failed(resolveTrivialTileConflicts(value)))
+        return failure();
     }
+  }
+  return success();
+}
-
-  return success();
+/// Prints live ranges alongside operation names for debugging.
+void dumpLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
+                    ArrayRef<LiveRange const *> liveRanges,
+                    FunctionOpInterface function) {
+  llvm::errs() << "SME Tile Liveness: @" << function.getName()
+               << "\nKey:\nS - Start\nE - End\n| - Live\n";
+  for (auto [blockIdx, block] : llvm::enumerate(function.getBlocks())) {
+    llvm::errs() << "^bb" << blockIdx << ":\n";
+    for (Operation &op : block.getOperations()) {
+      unsigned operationIndex = operationToIndexMap.at(&op);
+      for (LiveRange const *range : liveRanges) {
+        char liveness = ' ';
+        for (auto it = range->ranges->begin(); it != range->ranges->end();
+             ++it) {
+          if (it.start() == operationIndex)
+            liveness = (liveness == 'E' ? '|' : 'S');
+          else if (it.stop() == operationIndex)
+            liveness = (liveness == 'S' ? '|' : 'E');
+          else if (operationIndex >= it.start() && operationIndex < it.stop())
+            liveness = '|';
+        }
+        llvm::errs() << liveness;
+      }
+      llvm::errs() << ' ' << op.getName() << '\n';
+    }
   }
-};
+  llvm::errs() << "==========\n";
+}
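For reference, piecing the above together, the dump this produces for the @constant_with_multiple_users test (in tile-allocation-liveness.mlir, further down this patch) comes out roughly as follows, with one column per live range (exact spacing is illustrative):

    ========== Coalesced Live Ranges:
    SME Tile Liveness: @constant_with_multiple_users
    Key:
    S - Start
    E - End
    | - Live
    ^bb0:
    S  arm_sme.zero
    |S arm_sme.move_vector_to_tile_slice
    || arm_sme.move_vector_to_tile_slice
    |E test.some_use
    E  test.some_use
    ==========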
 
-struct TileAllocationPass
-    : public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
+struct TestTileAllocationPass
+    : public arm_sme::impl::TestTileAllocationBase<TestTileAllocationPass> {
+  using TestTileAllocationBase::TestTileAllocationBase;
   void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    patterns.add<AssignTileIDsPattern>(patterns.getContext());
-    GreedyRewriteConfig config;
-    // Setting useTopDownTraversal ensures tiles are allocated in program
-    // order.
-    config.useTopDownTraversal = true;
-    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
-            getOperation(), std::move(patterns), config))) {
-      signalPassFailure();
+    FunctionOpInterface function = getOperation();
+    if (preprocessOnly) {
+      IRRewriter rewriter(function);
+      return preprocessForTileAllocation(rewriter, function);
     }
+    if (failed(arm_sme::allocateSMETiles(function, dumpTileLiveRanges)))
+      signalPassFailure();
   }
 };
 } // namespace
 
-std::unique_ptr<Pass> mlir::arm_sme::createTileAllocationPass() {
-  return std::make_unique<TileAllocationPass>();
+LogicalResult mlir::arm_sme::allocateSMETiles(FunctionOpInterface function,
+                                              bool dumpRanges) {
+  if (function.empty()) {
+    // TODO: Also return early if the function contains no ArmSME ops?
+    return success();
+  }
+
+  LiveRange::Allocator liveRangeAllocator;
+  IRRewriter rewriter(function.getContext());
+
+  // 1. Preprocess the IR for tile allocation.
+  preprocessForTileAllocation(rewriter, function);
+
+  // 2. Gather live ranges for each ArmSME tile within the function.
+  Liveness liveness(function);
+  auto operationToIndexMap = generateOperationNumbering(function);
+  auto initialLiveRanges = gatherTileLiveRanges(
+      operationToIndexMap, liveRangeAllocator, liveness, function);
+  if (initialLiveRanges.empty())
+    return success();
+
+  if (dumpRanges) {
+    // Wrangle initial live ranges into a form suitable for printing.
+    auto nonEmpty = llvm::make_filter_range(
+        llvm::make_second_range(initialLiveRanges),
+        [&](LiveRange const &liveRange) { return !liveRange.empty(); });
+    auto initialRanges = llvm::to_vector(llvm::map_range(
+        nonEmpty, [](LiveRange const &liveRange) { return &liveRange; }));
+    std::sort(initialRanges.begin(), initialRanges.end(),
+              [](LiveRange const *a, LiveRange const *b) { return *a < *b; });
+    llvm::errs() << "\n========== Initial Live Ranges:\n";
+    dumpLiveRanges(operationToIndexMap, initialRanges, function);
+  }
+
+  // 3. Coalesce (non-overlapping) live ranges where it would be beneficial
+  // for tile allocation. E.g. Unify the result of an operation with its
+  // operands.
+  auto coalescedLiveRanges = coalesceTileLiveRanges(initialLiveRanges);
+
+  if (dumpRanges) {
+    llvm::errs() << "\n========== Coalesced Live Ranges:\n";
+    dumpLiveRanges(operationToIndexMap, coalescedLiveRanges, function);
+  }
+
+  // 4. Allocate tile IDs to live ranges.
+  allocateTilesToLiveRanges(coalescedLiveRanges);
+
+  // 5. Assign the tile IDs back to the ArmSME operations.
+  if (failed(assignTileIdsAndResolveTrivialConflicts(rewriter, function,
+                                                     coalescedLiveRanges))) {
+    return failure();
+  }
+
+  // 6. Erase trivially dead tile operations (e.g. a ZeroOp with no
+  // users). This prevents the LLVM conversion needlessly inserting spills.
+ eraseTriviallyDeadTileOps(rewriter, function); + return success(); } diff --git a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir index f48046a8d7995..14b1f323da3a2 100644 --- a/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir +++ b/mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir @@ -1,5 +1,4 @@ -// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s - +// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" -split-input-file | FileCheck %s // Test conversion of ArmSME ops to LLVM intrinsics. //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir index a9c1a65a296f4..2c3868d7f25cb 100644 --- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir +++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir @@ -1,6 +1,6 @@ -// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | \ +// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file | \ // RUN: FileCheck %s --check-prefix=AFTER-TILE-ALLOC -// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize -cse \ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,cse,canonicalize))" \ // RUN: -split-input-file -verify-diagnostics | \ // RUN: FileCheck %s --check-prefix=AFTER-LLVM-LOWERING @@ -56,6 +56,9 @@ func.func @use_too_many_tiles() { %1 = arm_sme.zero : vector<[4]x[4]xi32> // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %2 = arm_sme.zero : vector<[8]x[8]xi16> + "test.some_use"(%0) : (vector<[4]x[4]xi32>) -> () + "test.some_use"(%1) : (vector<[4]x[4]xi32>) -> () + "test.some_use"(%2) : (vector<[8]x[8]xi16>) -> () return } // AFTER-TILE-ALLOC-LABEL: @use_too_many_tiles @@ -131,18 +134,16 @@ func.func @use_too_many_tiles() { /// Note: In this example an entire tile swap is inserted before/after the /// `arm_sme.load_tile_slice` operation. Really, this only needs to spill a /// single tile slice (and can omit the initial load, like in the previous example). 
-func.func @very_excessive_spills(%memref : memref<?x?xf32>) -> vector<[4]x[4]xf32> {
-  %useAllTiles = arm_sme.get_tile : vector<[16]x[16]xi8>
+func.func @very_excessive_spills(%useAllTiles : vector<[16]x[16]xi8>, %memref: memref<?x?xf32>) -> vector<[4]x[4]xf32> {
   %c0 = arith.constant 0 : index
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %mask = vector.constant_mask [4] : vector<[4]xi1>
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %loadSlice = arm_sme.load_tile_slice %memref[%c0, %c0], %mask, %tile, %c0 : memref<?x?xf32>, vector<[4]xi1>, vector<[4]x[4]xf32>
-  "test.some_use"(%loadSlice) : (vector<[4]x[4]xf32>) -> ()
+  "test.some_use"(%useAllTiles) : (vector<[16]x[16]xi8>) -> ()
+  return %loadSlice : vector<[4]x[4]xf32>
 }
 // AFTER-TILE-ALLOC-LABEL: @very_excessive_spills
-// AFTER-TILE-ALLOC: arm_sme.get_tile
-// AFTER-TILE-ALLOC-SAME: tile_id = 0
 // AFTER-TILE-ALLOC: arm_sme.load_tile_slice
 // AFTER-TILE-ALLOC-SAME: tile_id = 16
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
index 15767ff1dec3f..a62ca080ab8d9 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/unsupported.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -split-input-file -allow-unregistered-dialect -verify-diagnostics
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm))" -verify-diagnostics
 
 //===----------------------------------------------------------------------===//
 // arm_sme.outerproduct
diff --git a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
index e144bac970a7d..8b46998d56b04 100644
--- a/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
+++ b/mlir/test/Dialect/ArmSME/basic-tile-allocation.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file | FileCheck %s
 
 // -----
 
diff --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
index b7ba3f728c705..643dfd4a7cbd9 100644
--- a/mlir/test/Dialect/ArmSME/canonicalize.mlir
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -1,18 +1,14 @@
-// RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s -canonicalize | mlir-opt | FileCheck %s
 
-// This tests that the `arm_sme.materialize_ssa_tile` placeholder is removed
-// once it becomes unused, after lowering to control flow.
+// This tests that dead tile values are removed from control flow.
-// -----
-
-// CHECK-LABEL: @unused_materialize_ssa_tile_is_removed_from_blocks
-// CHECK-NOT: arm_sme.materialize_ssa_tile
+// CHECK-LABEL: @unused_ssa_tile_is_removed_from_blocks
 // CHECK-NOT: vector<[4]x[4]xf32>
-func.func @unused_materialize_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xf32>) {
+func.func @unused_ssa_tile_is_removed_from_blocks(%arg0: memref<?x?xf32>) {
   %c10 = arith.constant 10 : index
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
-  %tile = arm_sme.materialize_ssa_tile : vector<[4]x[4]xf32>
+  %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   cf.br ^bb1(%c0, %tile : index, vector<[4]x[4]xf32>)
 ^bb1(%1: index, %2: vector<[4]x[4]xf32>): // 2 preds: ^bb0, ^bb2
   %3 = arith.cmpi slt, %1, %c10 : index
diff --git a/mlir/test/Dialect/ArmSME/cse.mlir b/mlir/test/Dialect/ArmSME/cse.mlir
deleted file mode 100644
index 74e7293eaeca5..0000000000000
--- a/mlir/test/Dialect/ArmSME/cse.mlir
+++ /dev/null
@@ -1,30 +0,0 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s
-
-// This test is checking that CSE does not remove 'arm_sme.zero/get_tile' ops as
-// duplicates.
-
-// CHECK-LABEL: @zero_tile
-// CHECK: %[[TILE_0:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
-// CHECK: %[[TILE_1:.*]] = arm_sme.zero : vector<[4]x[4]xi32>
-// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
-// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
-func.func @zero_tile() {
-  %tile_1 = arm_sme.zero : vector<[4]x[4]xi32>
-  %tile_2 = arm_sme.zero : vector<[4]x[4]xi32>
-  "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
-  "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
-  return
-}
-
-// CHECK-LABEL: @get_tile
-// CHECK: %[[TILE_0:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
-// CHECK: %[[TILE_1:.*]] = arm_sme.get_tile : vector<[4]x[4]xi32>
-// CHECK: "prevent.dce"(%[[TILE_0]]) : (vector<[4]x[4]xi32>) -> ()
-// CHECK: "prevent.dce"(%[[TILE_1]]) : (vector<[4]x[4]xi32>) -> ()
-func.func @get_tile() {
-  %tile_1 = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %tile_2 = arm_sme.get_tile : vector<[4]x[4]xi32>
-  "prevent.dce"(%tile_1) : (vector<[4]x[4]xi32>) -> ()
-  "prevent.dce"(%tile_2) : (vector<[4]x[4]xi32>) -> ()
-  return
-}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
index ab46c7adca596..6095fdc11ead8 100644
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -1403,3 +1403,12 @@ func.func @arm_sme_usmops_4way_i16i16_to_i64(%vecA: vector<[8]xi16>, %vecB: vect
   %reuslt = arm_sme.usmops_4way %vecA, %vecB : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
   return %reuslt : vector<[2]x[2]xi64>
 }
+
+//===----------------------------------------------------------------------===//
+// arm_sme.copy_tile
+//===----------------------------------------------------------------------===//
+
+func.func @arm_sme_copy_tile(%vec: vector<[4]x[4]xf32>) -> vector<[4]x[4]xf32> {
+  %result = arm_sme.copy_tile %vec : vector<[4]x[4]xf32>
+  return %result : vector<[4]x[4]xf32>
+}
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
new file mode 100644
index 0000000000000..6d9cbf36a162f
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-copies.mlir
@@ -0,0 +1,159 @@
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation=preprocess-only -split-input-file | FileCheck %s
+
+// This file tests the insertion of copies for SME tile allocation. Copies are
+// inserted at `cf.br` ops (the predecessors to block arguments). Conditional
+// branches are split to prevent conflicts (see cond_br_with_backedge).
+
+// CHECK-LABEL: func.func @simple_branch(
+// CHECK-SAME:    %[[TILE:.*]]: vector<[4]x[4]xf32>)
+// CHECK: %[[COPY:.*]] = arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+// CHECK: cf.br ^bb1(%[[COPY]] : vector<[4]x[4]xf32>)
+// CHECK: ^bb1(%[[BLOCK_ARG:.*]]: vector<[4]x[4]xf32>):
+
+func.func @simple_branch(%tile : vector<[4]x[4]xf32>) {
+  cf.br ^bb1(%tile: vector<[4]x[4]xf32>)
+^bb1(%blockArg: vector<[4]x[4]xf32>):
+  return
+}
+
+// -----
+
+// Note: The ^POINTLESS_SHIM_FOR_BB2 block is added as the cond_br splitting does
+// not check if it needs to insert a copy or not (there is no harm in the empty
+// block though -- it will fold away later).
+
+// CHECK-LABEL: func.func @cond_branch(
+// CHECK-SAME: %[[COND:.*]]: i1, %[[TILE:.*]]: vector<[4]x[4]xf32>
+// CHECK: cf.cond_br %[[COND]], ^[[BB1_COPIES:[[:alnum:]]+]], ^[[POINTLESS_SHIM_FOR_BB2:[[:alnum:]]+]]
+// CHECK: ^[[POINTLESS_SHIM_FOR_BB2]]:
+// CHECK: cf.br ^[[BB2:.*]]
+// CHECK: ^[[BB1_COPIES]]:
+// CHECK: arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32>
+// CHECK: cf.br ^[[BB1:.*]]
+func.func @cond_branch(%cond: i1, %tile: vector<[4]x[4]xf32>) {
+  cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2
+^bb1(%blockArg: vector<[4]x[4]xf32>):
+  return
+^bb2:
+  return
+}
+
+// -----
+
+// Reduction of a real-world example that shows why we must split conditional branches.
+
+// CHECK-LABEL: @cond_branch_with_backedge(
+// CHECK-SAME: %[[TILEA:[[:alnum:]]+]]: vector<[4]x[4]xf32>, %[[TILEB:[[:alnum:]]+]]: vector<[4]x[4]xf32>,
+// CHECK-SAME: %[[TILEC:[[:alnum:]]+]]: vector<[4]x[4]xf32>, %[[TILED:[[:alnum:]]+]]: vector<[4]x[4]xf32>,
+// CHECK: %[[BB1_COPY_0:.*]] = arm_sme.copy_tile %[[TILEA]] : vector<[4]x[4]xf32>
+// CHECK: cf.br ^bb1(%{{[[:alnum:]]+}}, %[[BB1_COPY_0]]
+// CHECK: ^bb1(%[[CURRENT_INDEX:.*]]: index, %[[ITER_TILE:.*]]: vector<[4]x[4]xf32>):
+// CHECK: %[[CONTINUE_LOOP:.*]] = arith.cmpi
+// CHECK: cf.cond_br %[[CONTINUE_LOOP]], ^[[BB2_COPIES:[[:alnum:]]+]], ^[[BB3_COPIES:[[:alnum:]]+]]
+// CHECK: ^[[BB3_COPIES]]:
+// CHECK-NEXT: %[[BB3_COPY_0:.*]] = arm_sme.copy_tile %[[ITER_TILE]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: %[[BB3_COPY_1:.*]] = arm_sme.copy_tile %[[TILEB]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: %[[BB3_COPY_2:.*]] = arm_sme.copy_tile %[[TILEC]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: %[[BB3_COPY_3:.*]] = arm_sme.copy_tile %[[TILED]] : vector<[4]x[4]xf32>
+// CHECK-NEXT: cf.br ^[[BB3:[[:alnum:]]+]](%[[BB3_COPY_0]], %[[BB3_COPY_1]], %[[BB3_COPY_2]], %[[BB3_COPY_3]]
+// CHECK: ^[[BB2_COPIES]]:
+// CHECK-NEXT: cf.br ^[[BB2:[[:alnum:]]+]]
+// CHECK: ^[[BB2]]:
+// CHECK-NEXT: %[[NEXT_TILE:.*]] = arm_sme.move_vector_to_tile_slice %{{.*}}, %[[ITER_TILE]]
+// CHECK: %[[BB1_COPY_1:.*]] = arm_sme.copy_tile %[[NEXT_TILE]] : vector<[4]x[4]xf32>
+// CHECK: cf.br ^bb1(%{{[[:alnum:]]+}}, %[[BB1_COPY_1]]
+// CHECK: ^[[BB3]](%{{.*}}: vector<[4]x[4]xf32>):
+// CHECK-NEXT: return
+func.func @cond_branch_with_backedge(%tileA: vector<[4]x[4]xf32>, %tileB: vector<[4]x[4]xf32>, %tileC: vector<[4]x[4]xf32>, %tileD: vector<[4]x[4]xf32>, %slice: vector<[4]xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  // Live here: %tileA, %tileB, %tileC, %tileD
+  cf.br ^bb1(%c0, %tileA : index, vector<[4]x[4]xf32>)
+^bb1(%currentIndex: index, %iterTile: vector<[4]x[4]xf32>):
+  %continueLoop = arith.cmpi slt, %currentIndex, %c10 : index
+  // Live
here: %iterTile, %tileB, %tileC, %tileD + // %iterTile, %tileB, %tileC, %tileD are live out (in the ^bb2 case). If we + // inserted the (four) `arm_sme.copy_tile` operations here we would run out of tiles. + // However, note that the copies are only needed if we take the ^bb3 path. So, if we add + // a new block along that path we can insert the copies without any conflicts. + cf.cond_br %continueLoop, ^bb2, ^bb3(%iterTile, %tileB, %tileC, %tileD : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>) +^bb2: + // Live here: %iterTile, %tileB, %tileC, %tileD + %nextTile = arm_sme.move_vector_to_tile_slice %slice, %iterTile, %currentIndex : vector<[4]xf32> into vector<[4]x[4]xf32> + %nextIndex = arith.addi %currentIndex, %c1 : index + cf.br ^bb1(%nextIndex, %nextTile : index, vector<[4]x[4]xf32>) +^bb3(%finalTileA: vector<[4]x[4]xf32>, %finalTileB: vector<[4]x[4]xf32>, %finalTileC: vector<[4]x[4]xf32>, %finalTileD: vector<[4]x[4]xf32>): + // Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD + return +} + +// ----- + +// CHECK-LABEL: @tile_dominance +// CHECK-NOT: arm_sme.copy_tile +func.func @tile_dominance(%arg0: vector<[4]x[4]xf32>) { + cf.br ^bb1 +^bb1: // 2 preds: ^bb0, ^bb4 + "test.some_use"(%arg0) : (vector<[4]x[4]xf32>) -> () + return +^bb2: // no predecessors + %0 = arm_sme.get_tile : vector<[4]x[4]xf32> + cf.br ^bb3 +^bb3: // pred: ^bb2 + "test.some_use"(%0) : (vector<[4]x[4]xf32>) -> () + return +^bb4: // no predecessors + cf.br ^bb1 +^bb5: // no predecessors + return +} + +// ----- + +// CHECK-LABEL: func.func @cond_branch_true_and_false_tile_args( +// CHECK-SAME: %[[COND:.*]]: i1, %[[TILE:.*]]: vector<[4]x[4]xf32> +// CHECK-NEXT: cf.cond_br %[[COND]], ^[[BB1_COPIES:[[:alnum:]]+]], ^[[BB2_COPIES:[[:alnum:]]+]] +// CHECK: ^[[BB2_COPIES]]: +// CHECK-NEXT: %[[COPY_0:.*]] = arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32> +// CHECK-NEXT: cf.br ^[[BB2:[[:alnum:]]+]](%[[COPY_0]] +// CHECK: ^[[BB1_COPIES]]: +// CHECK-NEXT: %[[COPY_1:.*]] = arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32> +// CHECK-NEXT: cf.br ^[[BB1:[[:alnum:]]+]](%[[COPY_1]] +// CHECK: ^[[BB1]]{{.*}}: +// CHECK-NEXT: return +// CHECK: ^[[BB2]]{{.*}}: +// CHECK-NEXT: return +func.func @cond_branch_true_and_false_tile_args(%cond: i1, %tile: vector<[4]x[4]xf32>) { + cf.cond_br %cond, ^bb1(%tile: vector<[4]x[4]xf32>), ^bb2(%tile: vector<[4]x[4]xf32>) +^bb1(%blockArg0: vector<[4]x[4]xf32>): + return +^bb2(%blockArg1: vector<[4]x[4]xf32>): + return +} + +// ----- + +// CHECK-LABEL: @multiple_predecessors +// CHECK: ^bb1: +// CHECK-NEXT: %[[TILE:.*]] = arm_sme.get_tile : vector<[4]x[4]xf32> +// CHECK-NEXT: %[[COPY_0:.*]] = arm_sme.copy_tile %[[TILE]] : vector<[4]x[4]xf32> +// CHECK-NEXT: cf.br ^bb3(%[[COPY_0]] : vector<[4]x[4]xf32>) +// CHECK: ^bb2: +// CHECK-NEXT: %[[ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xf32> +// CHECK-NEXT: %[[COPY_1:.*]] = arm_sme.copy_tile %[[ZERO]] : vector<[4]x[4]xf32> +// CHECK-NEXT: cf.br ^bb3(%[[COPY_1]] : vector<[4]x[4]xf32>) +// CHECK: ^bb3({{.*}}): +// CHECK-NEXT: return +func.func @multiple_predecessors(%cond: i1) +{ + cf.cond_br %cond, ^bb1, ^bb2 +^bb1: + %tile = arm_sme.get_tile : vector<[4]x[4]xf32> + cf.br ^bb3(%tile : vector<[4]x[4]xf32>) +^bb2: + %zero = arm_sme.zero : vector<[4]x[4]xf32> + cf.br ^bb3(%zero : vector<[4]x[4]xf32>) +^bb3(%blockArg: vector<[4]x[4]xf32>): // pred: ^bb1, ^bb2 + return +} diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir index 
39d9ab6491e3b..6b5e44365bf58 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-invalid.mlir
@@ -1,19 +1,17 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -verify-diagnostics
 
-// -----
+// Select between tileA and tileB. This is currently unsupported as it would
+// require inserting (runtime) tile moves.
 
-func.func @selecting_between_different_tiles_is_unsupported(%dest : memref<?x?xi32>, %cond: i1) {
+// expected-note@below {{tile operand is: <block argument> of type 'vector<[4]x[4]xi32>'}}
+func.func @selecting_between_different_tiles_is_unsupported(%dest : memref<?x?xi32>, %tileA : vector<[4]x[4]xi32>, %tileB : vector<[4]x[4]xi32>, %cond: i1) {
   %c0 = arith.constant 0 : index
-  %tileA = arm_sme.get_tile : vector<[4]x[4]xi32>
-  %tileB = arm_sme.get_tile : vector<[4]x[4]xi32>
-  // Select between tileA and tileB. This is currently unsupported as it would
-  // require inserting tile move operations during tile allocation.
+  // expected-error@below {{op tile operand allocated to different SME virtual tile (move required)}}
   %tile = scf.if %cond -> vector<[4]x[4]xi32> {
     scf.yield %tileA : vector<[4]x[4]xi32>
   } else {
     scf.yield %tileB : vector<[4]x[4]xi32>
   }
-  // expected-error@+1 {{op already assigned different SME virtual tile!}}
   arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
index 2dedcb2fbc24e..88fc8a8923d34 100644
--- a/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-liveness.mlir
@@ -1,18 +1,26 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CHECK-BAD
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -convert-scf-to-cf -test-arm-sme-tile-allocation=dump-tile-live-ranges -mlir-disable-threading -split-input-file -verify-diagnostics 2>&1 >/dev/null | FileCheck %s --check-prefix=CHECK-LIVE-RANGE
 
-// This file tests some aspects of liveness issues in the SME tile allocator.
-// These tests were designed with a new liveness-based tile allocator in mind
-// (where the names of test cases make more sense), with the current tile
-// allocator these tests all give incorrect results (which is documented by
-// `CHECK-BAD`).
+// This file tests some simple aspects of using liveness in the SME tile allocator.
+// Note: We use -convert-scf-to-cf first as the tile allocator expects CF, but
+// some of these tests are written in SCF (to make things easier to follow).
 
-// Incorrect result! The second `move_vector_to_tile_slice` overwrites the first (which is still live).
-//
-// CHECK-BAD-LABEL: @constant_with_multiple_users
-// CHECK-BAD: %[[ZERO_TILE:.*]] = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
-// CHECK-BAD: %[[INSERT_TILE_1:.*]] = arm_sme.move_vector_to_tile_slice %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
-// CHECK-BAD: %[[INSERT_TILE_0:.*]] = arm_sme.move_vector_to_tile_slice %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+// CHECK-LIVE-RANGE-LABEL: @constant_with_multiple_users
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE: S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |E test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
+
+// CHECK-LABEL: @constant_with_multiple_users(
+// CHECK-SAME: %[[VECTOR_A:.*]]: vector<[4]xf32>, %[[VECTOR_B:.*]]: vector<[4]xf32>
 func.func @constant_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %index: index) {
+  // CHECK-NEXT: %[[ZERO_TILE_0:.*]] = arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[ZERO_TILE_1:.*]] = arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+  // CHECK-NEXT: %[[INSERT_TILE_1:.*]] = arm_sme.move_vector_to_tile_slice %[[VECTOR_A]], %[[ZERO_TILE_1]], %{{.*}} {tile_id = 1 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
  // CHECK-NEXT: %[[INSERT_TILE_0:.*]] = arm_sme.move_vector_to_tile_slice %[[VECTOR_B]], %[[ZERO_TILE_0]], %{{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
   %zero = arm_sme.zero : vector<[4]x[4]xf32>
   %tile_a = arm_sme.move_vector_to_tile_slice %a, %zero, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
   %tile_b = arm_sme.move_vector_to_tile_slice %b, %zero, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
@@ -23,12 +31,17 @@ func.func @constant_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>
 
 // -----
 
-// (No tile IDs -- the current tile allocator ignores this case)
+// CHECK-LIVE-RANGE-LABEL: @value_with_multiple_users
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: || arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |E test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
 
-// CHECK-BAD-LABEL: @value_with_multiple_users
-// CHECK-BAD-NOT: tile_id
+// expected-note@below {{tile operand is: <block argument> of type 'vector<[4]x[4]xf32>'}}
 func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]xf32>, %b: vector<[4]xf32>, %index: index) {
-  // A future allocator should error here (as `%tile` would need to be copied).
+  // expected-error@below {{op tile operand allocated to different SME virtual tile (move required)}}
   %tile_a = arm_sme.move_vector_to_tile_slice %a, %tile, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
   %tile_b = arm_sme.move_vector_to_tile_slice %b, %tile, %index : vector<[4]xf32> into vector<[4]x[4]xf32>
   "test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
@@ -38,12 +51,38 @@ func.func @value_with_multiple_users(%tile: vector<[4]x[4]xf32>, %a: vector<[4]x
 
 // -----
 
-// CHECK-BAD-LABEL: @reuse_tiles_after_initial_use
+// CHECK-LIVE-RANGE-LABEL: @reuse_tiles_after_initial_use
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb0:
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.get_tile
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: E||| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E|| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
+// CHECK-LIVE-RANGE-NEXT: S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.zero
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: |||| test.dummy
+// CHECK-LIVE-RANGE-NEXT: E||| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E|| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E| test.some_use
+// CHECK-LIVE-RANGE-NEXT: E test.some_use
+
+// CHECK-LABEL: @reuse_tiles_after_initial_use
 func.func @reuse_tiles_after_initial_use() {
-  // CHECK-BAD: arm_sme.get_tile {tile_id = 0 : i32}
-  // CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32}
-  // CHECK-BAD: arm_sme.get_tile {tile_id = 2 : i32}
-  // CHECK-BAD: arm_sme.get_tile {tile_id = 3 : i32}
+  // CHECK: arm_sme.get_tile {tile_id = 0 : i32}
+  // CHECK: arm_sme.get_tile {tile_id = 1 : i32}
+  // CHECK: arm_sme.get_tile {tile_id = 2 : i32}
+  // CHECK: arm_sme.get_tile {tile_id = 3 : i32}
   %tile_a = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_b = arm_sme.get_tile : vector<[4]x[4]xf32>
   %tile_c = arm_sme.get_tile : vector<[4]x[4]xf32>
@@ -55,19 +94,13 @@ func.func @reuse_tiles_after_initial_use() {
   "test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
   "test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> ()
   "test.some_use"(%tile_d) : (vector<[4]x[4]xf32>) -> ()
-  // -> Spills after the fourth tile (unnecessary):
-  // CHECK-BAD: arm_sme.zero {tile_id = 16 : i32}
-  // CHECK-BAD: arm_sme.zero {tile_id = 17 : i32}
-  // CHECK-BAD: arm_sme.zero {tile_id = 18 : i32}
-  // CHECK-BAD: arm_sme.zero {tile_id = 19 : i32}
-  // Unnecessary spills:
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
+  // CHECK: arm_sme.zero {tile_id = 0 : i32}
+  // CHECK: arm_sme.zero {tile_id = 1 : i32}
+  // CHECK: arm_sme.zero {tile_id = 2 : i32}
+  // CHECK: arm_sme.zero {tile_id = 3 : i32}
   %tile_1 = arm_sme.zero : vector<[4]x[4]xf32>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_2 = arm_sme.zero : vector<[4]x[4]xf32>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_3 =
arm_sme.zero : vector<[4]x[4]xf32> - // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}} %tile_4 = arm_sme.zero : vector<[4]x[4]xf32> "test.dummy"(): () -> () "test.dummy"(): () -> () @@ -81,16 +114,123 @@ func.func @reuse_tiles_after_initial_use() { // ----- -// Incorrect result! Both branches should yield the result via the same tile. +// CHECK-LIVE-RANGE-LABEL: @tile_live_ins +// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges: +// CHECK-LIVE-RANGE: ^bb0: +// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile +// CHECK-LIVE-RANGE-NEXT: |S arm_sme.zero +// CHECK-LIVE-RANGE-NEXT: EE cf.br +// CHECK-LIVE-RANGE-NEXT: ^bb1: +// CHECK-LIVE-RANGE-NEXT: || test.dummy +// CHECK-LIVE-RANGE-NEXT: || test.dummy +// CHECK-LIVE-RANGE-NEXT: EE cf.br +// CHECK-LIVE-RANGE-NEXT: ^bb2: +// CHECK-LIVE-RANGE-NEXT: || test.dummy +// CHECK-LIVE-RANGE-NEXT: || test.dummy +// CHECK-LIVE-RANGE-NEXT: EE cf.br +// CHECK-LIVE-RANGE-NEXT: ^bb3: +// CHECK-LIVE-RANGE-NEXT: E| test.some_use +// CHECK-LIVE-RANGE-NEXT: E test.some_use + +// CHECK-LABEL: @tile_live_ins +func.func @tile_live_ins() +{ + // CHECK: arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32> + // CHECK: arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xf32> + %tile_1 = arm_sme.get_tile : vector<[4]x[4]xf32> + %tile_2 = arm_sme.zero : vector<[4]x[4]xf32> + cf.br ^bb1 +^bb1: + "test.dummy"(): () -> () + "test.dummy"(): () -> () + cf.br ^bb2 +^bb2: + "test.dummy"(): () -> () + "test.dummy"(): () -> () + cf.br ^bb3 +^bb3: + "test.some_use"(%tile_1) : (vector<[4]x[4]xf32>) -> () + "test.some_use"(%tile_2) : (vector<[4]x[4]xf32>) -> () + return +} + +// ----- + +// This is basically the same test as tile_live_ins but shows that the order of +// the blocks within the source does not relate to the liveness, which is based +// on successors and predecessors (not textual order). +// +// So %tile_1 is live on the path bb0 -> bb2 -> bb1 (and dies in bb1). The +// 'hole' when looking at the live range dump comes from the textual order +// (and would disappear if bb1 was moved before bb2 in the source). // -// CHECK-BAD-LABEL: @non_overlapping_branches -// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32> -// CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32} : vector<[4]x[4]xf32> +// When looking at the live range dump (outside of straight-line code) it +// normally makes more sense to consider blocks in isolation (and how they +// relate to the CFG). 
+ +// CHECK-LIVE-RANGE-LABEL: @non_sequential_live_ins +// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges: +// CHECK-LIVE-RANGE: ^bb0: +// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile +// CHECK-LIVE-RANGE-NEXT: | test.dummy +// CHECK-LIVE-RANGE-NEXT: E cf.br +// CHECK-LIVE-RANGE-NEXT: ^bb1: +// CHECK-LIVE-RANGE-NEXT: E| test.some_use +// CHECK-LIVE-RANGE-NEXT: | test.dummy +// CHECK-LIVE-RANGE-NEXT: E cf.br +// CHECK-LIVE-RANGE-NEXT: ^bb2: +// CHECK-LIVE-RANGE-NEXT: |S arm_sme.zero +// CHECK-LIVE-RANGE-NEXT: || test.dummy +// CHECK-LIVE-RANGE-NEXT: EE cf.cond_br +// CHECK-LIVE-RANGE-NEXT: ^bb3: +// CHECK-LIVE-RANGE-NEXT: | test.dummy +// CHECK-LIVE-RANGE-NEXT: E test.some_use +// CHECK-LIVE-RANGE-NEXT: func.return + +// CHECK-LABEL: @non_sequential_live_ins +func.func @non_sequential_live_ins(%cond: i1) { + // CHECK: arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32> + // CHECK: arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xf32> + %tile_1 = arm_sme.get_tile : vector<[4]x[4]xf32> + "test.dummy"(): () -> () + cf.br ^bb2 +^bb1: + "test.some_use"(%tile_1) : (vector<[4]x[4]xf32>) -> () + "test.dummy"(): () -> () + cf.br ^bb3 +^bb2: + %tile_2 = arm_sme.zero : vector<[4]x[4]xf32> + "test.dummy"(): () -> () + cf.cond_br %cond, ^bb1, ^bb3 +^bb3: + "test.dummy"(): () -> () + "test.some_use"(%tile_2) : (vector<[4]x[4]xf32>) -> () + return +} + +// ----- + +// CHECK-LIVE-RANGE-LABEL: @non_overlapping_branches +// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges: +// CHECK-LIVE-RANGE: ^bb1: +// CHECK-LIVE-RANGE-NEXT: S arm_sme.zero +// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile +// CHECK-LIVE-RANGE-NEXT: E cf.br +// CHECK-LIVE-RANGE-NEXT: ^bb2: +// CHECK-LIVE-RANGE-NEXT: S arm_sme.get_tile +// CHECK-LIVE-RANGE-NEXT: | arm_sme.copy_tile +// CHECK-LIVE-RANGE-NEXT: E cf.br + +// CHECK-LABEL: @non_overlapping_branches func.func @non_overlapping_branches(%cond: i1) { + // CHECK: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32> + // CHECK: arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32> %tile = scf.if %cond -> vector<[4]x[4]xf32> { + // ^bb1: %zero = arm_sme.zero : vector<[4]x[4]xf32> scf.yield %zero : vector<[4]x[4]xf32> } else { + // ^bb2: %undef = arm_sme.get_tile : vector<[4]x[4]xf32> scf.yield %undef : vector<[4]x[4]xf32> } @@ -100,52 +240,65 @@ func.func @non_overlapping_branches(%cond: i1) { // ----- -// Incorrect result! Everything assigned to tile 0 (which means values that are still live are overwritten). 
-//
-// CHECK-BAD-LABEL: @constant_loop_init_with_multiple_users
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
-// CHECK-BAD: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
-// CHECK-BAD: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
-func.func @constant_loop_init_with_multiple_users(%a: vector<[4]xf32>, %b: vector<[4]xf32>) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 1 : index
-  %c10 = arith.constant 10 : index
-  %init = arm_sme.zero : vector<[4]x[4]xf32>
-  %tile_a = scf.for %i = %c0 to %c10 step %c1 iter_args(%iter = %init) -> vector<[4]x[4]xf32> {
-    %new_tile = arm_sme.move_vector_to_tile_slice %a, %iter, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
-    scf.yield %new_tile : vector<[4]x[4]xf32>
-  }
-  %tile_b = scf.for %i = %c0 to %c10 step %c1 iter_args(%iter = %init) -> vector<[4]x[4]xf32> {
-    %new_tile = arm_sme.move_vector_to_tile_slice %a, %iter, %i : vector<[4]xf32> into vector<[4]x[4]xf32>
-    scf.yield %new_tile : vector<[4]x[4]xf32>
+// Here %vecA and %vecB are not merged into the same live range (as they are unknown values).
+// This means that %vecA and %vecB are allocated to different tiles (which is not legal).
+
+// expected-note@below {{tile operand is: <block argument> of type 'vector<[4]x[4]xf32>'}}
+func.func @overlapping_branches(%cond: i1, %vecA: vector<[4]x[4]xf32>, %vecB: vector<[4]x[4]xf32>) {
+  // expected-error@below {{op tile operand allocated to different SME virtual tile (move required)}}
+  %tile = scf.if %cond -> vector<[4]x[4]xf32> {
+    scf.yield %vecA : vector<[4]x[4]xf32>
+  } else {
+    scf.yield %vecB : vector<[4]x[4]xf32>
   }
-  "test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
-  "test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
+  "test.some_use"(%tile) : (vector<[4]x[4]xf32>) -> ()
   return
 }
 
 // -----
 
-// Incorrect result! Everything assigned to tile 0 (which means values that are still live are overwritten).
-//
-// CHECK-BAD-LABEL: @run_out_of_tiles_but_avoid_spill
-// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32}
-// CHECK-BAD-COUNT-4: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+// CHECK-LIVE-RANGE-LABEL: @run_out_of_tiles_but_avoid_spill
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb2:
+// CHECK-LIVE-RANGE-NEXT: |S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: EEEEE cf.br
+
+// Note in the live ranges (above) there are five tile values, but we only have four tiles.
+// There is no 'real' spill, as we spill the `arm_sme.zero` but are then able to clone it
+// at each of its uses.
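The 'no real spill' outcome relies on the trivially-cloneable check used by chooseSpillUsingHeuristics: ops like arm_sme.zero can simply be re-created at each use instead of round-tripping through memory. The actual helper is defined earlier in TileAllocation.cpp (outside this diff); the following is only a sketch of the check it performs:

    // Sketch only: an op is trivially cloneable if rematerializing it at each
    // use is effectively free, i.e. it takes no operands and has no side
    // effects (arm_sme.zero and arm_sme.get_tile both qualify).
    static bool isTriviallyCloneableTileOpSketch(
        mlir::arm_sme::ArmSMETileOpInterface tileOp) {
      return tileOp && tileOp->getNumOperands() == 0 && mlir::isPure(tileOp);
    }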
+ +// CHECK-LABEL: @run_out_of_tiles_but_avoid_spill func.func @run_out_of_tiles_but_avoid_spill(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<[4]xf32>, %d: vector<[4]xf32>) { %init = arm_sme.zero : vector<[4]x[4]xf32> %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c10 = arith.constant 10 : index + // Live = %init scf.for %i = %c0 to %c10 step %c1 { + // CHECK: arm_sme.zero {tile_id = 1 : i32} + // CHECK: arm_sme.zero {tile_id = 2 : i32} + // CHECK: arm_sme.zero {tile_id = 3 : i32} + // CHECK: arm_sme.zero {tile_id = 0 : i32} %tile_a, %tile_b, %tile_c, %tile_d = scf.for %j = %c0 to %c10 step %c1 iter_args(%iter_a = %init, %iter_b = %init, %iter_c = %init, %iter_d = %init) -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32> , vector<[4]x[4]xf32> , vector<[4]x[4]xf32>) { + // ^bb2: + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 1 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32> + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 2 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32> + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 3 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32> + // CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32> %new_a = arm_sme.move_vector_to_tile_slice %a, %iter_a, %i : vector<[4]xf32> into vector<[4]x[4]xf32> %new_b = arm_sme.move_vector_to_tile_slice %b, %iter_b, %i : vector<[4]xf32> into vector<[4]x[4]xf32> %new_c = arm_sme.move_vector_to_tile_slice %c, %iter_c, %i : vector<[4]xf32> into vector<[4]x[4]xf32> %new_d = arm_sme.move_vector_to_tile_slice %d, %iter_d, %i : vector<[4]xf32> into vector<[4]x[4]xf32> scf.yield %new_a, %new_b, %new_c, %new_d : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32> } + // Live = %init, %tile_a, %tile_b, %tile_c, %tile_d (out of tiles!) + // This should be resolved by duplicating the arm_sme.zero (from folding + // arm_sme.copy_tile operations inserted by the tile allocator). "test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> () "test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> () "test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> () @@ -156,24 +309,48 @@ func.func @run_out_of_tiles_but_avoid_spill(%a: vector<[4]xf32>, %b: vector<[4]x // ----- -// Incorrect result! Everything other than zero assigned to tile 1 (which means values that are still live are overwritten). -// -// CHECK-BAD-LABEL: @avoidable_spill -// CHECK-BAD: arm_sme.zero {tile_id = 0 : i32} -// CHECK-BAD: arm_sme.get_tile {tile_id = 1 : i32} -// CHECK-BAD-COUNT-4: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 1 : i32} +// We should be able to avoid spills like this, but logic handling this case is +// not implemented yet. Note tile ID >= 16 means a spill/in-memory tile. 
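Since in-memory tile IDs start at kInMemoryTileIdBase (16), classifying an allocation result only takes a comparison, which is what the `tile_id = 16` checks in these tests exploit. A trivial helper to that effect (illustrative only, not part of this patch):

    // Virtual ZA tile IDs are 0-15; IDs from kInMemoryTileIdBase (16) upwards
    // denote software-managed "in-memory tiles" that lower to loads/stores.
    inline bool isInMemoryTileId(unsigned tileId) {
      return tileId >= mlir::arm_sme::kInMemoryTileIdBase;
    }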
+
+// CHECK-LIVE-RANGE-LABEL: @avoidable_spill
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb2:
+// CHECK-LIVE-RANGE-NEXT: || test.some_use
+// CHECK-LIVE-RANGE-NEXT: ||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: ||||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: |||||S arm_sme.move_vector_to_tile_slice
+// CHECK-LIVE-RANGE-NEXT: ||E||| test.some_use
+// CHECK-LIVE-RANGE-NEXT: || E|| test.some_use
+// CHECK-LIVE-RANGE-NEXT: || E| test.some_use
+// CHECK-LIVE-RANGE-NEXT: || E test.some_use
+// CHECK-LIVE-RANGE-NEXT: || arith.addi
+// CHECK-LIVE-RANGE-NEXT: EE cf.br
+
+// Note in the live ranges (above) there are two constant live-ins (the first two ranges),
+// which gives six overlapping live ranges (at the point where %tile_d is defined).
+// The allocator will currently spill the first constant (which results in a real
+// spill at its use); however, this could be avoided by using the knowledge that
+// at the first "test.some_use" there are actually only two live ranges (so we can
+// fix this by duplicating the constant).
+
+// CHECK-LABEL: @avoidable_spill
 func.func @avoidable_spill(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<[4]xf32>, %d: vector<[4]xf32>) {
+  // CHECK: arm_sme.zero {tile_id = 16 : i32} : vector<[4]x[4]xf32>
   %zero = arm_sme.zero : vector<[4]x[4]xf32>
   %tile = arm_sme.get_tile : vector<[4]x[4]xf32>
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
   scf.for %i = %c0 to %c10 step %c1 {
+    // So spilled here (unnecessarily).
+    // The arm_sme.zero op could be moved into the loop to avoid this.
     "test.some_use"(%zero) : (vector<[4]x[4]xf32>) -> ()
     %tile_a = arm_sme.move_vector_to_tile_slice %a, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
     %tile_b = arm_sme.move_vector_to_tile_slice %b, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
     %tile_c = arm_sme.move_vector_to_tile_slice %c, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
     %tile_d = arm_sme.move_vector_to_tile_slice %d, %tile, %c0 : vector<[4]xf32> into vector<[4]x[4]xf32>
+    // %zero is still live here (due to the backedge).
     "test.some_use"(%tile_a) : (vector<[4]x[4]xf32>) -> ()
     "test.some_use"(%tile_b) : (vector<[4]x[4]xf32>) -> ()
     "test.some_use"(%tile_c) : (vector<[4]x[4]xf32>) -> ()
@@ -181,3 +358,75 @@ func.func @avoidable_spill(%a: vector<[4]xf32>, %b: vector<[4]xf32>, %c: vector<
   }
   return
 }
+
+// -----
+
+// This test is a follow-up to the test of the same name in `tile-allocation-copies.mlir`.
+// This shows the live ranges (which are why we need to split the conditional branch).
+
+// CHECK-LIVE-RANGE-LABEL: @cond_branch_with_backedge
+// CHECK-LIVE-RANGE: ^bb1:
+// CHECK-LIVE-RANGE-NEXT: ||| | arith.cmpi
+// CHECK-LIVE-RANGE-NEXT: EEE E cf.cond_br
+//
+// CHECK-LIVE-RANGE-NEXT: ^[[BB3_COPIES:[[:alnum:]]+]]:
+// CHECK-LIVE-RANGE-NEXT: ||| ES arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: E|| |S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: E| ||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: E |||S arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: EEEE cf.br
+//
+// It is important to note that the first three live ranges in ^bb1 do not end
+// at the `cf.cond_br`; they are live-out via the backedge bb1 -> bb2 -> bb1.
+// This means that if we placed the `arm_sme.copy_tile` operations before the `cf.cond_br`
+// then those live ranges would not end at the copies, resulting in unwanted
+// overlapping live ranges (and hence tile spills).
+//
+// With the conditional branch split and the copies placed in the BB3_COPIES
+// block the first three live ranges end at the copy operations (as the
+// BB3_COPIES block is on the path out of the loop and has no backedge). This
+// means there are no overlaps and the live ranges all merge, as shown below.
+//
+// CHECK-LIVE-RANGE: ========== Coalesced Live Ranges:
+// CHECK-LIVE-RANGE: ^bb1:
+// CHECK-LIVE-RANGE-NEXT: |||| arith.cmpi
+// CHECK-LIVE-RANGE-NEXT: EEEE cf.cond_br
+//
+// CHECK-LIVE-RANGE-NEXT: ^[[BB3_COPIES]]:
+// CHECK-LIVE-RANGE-NEXT: |||| arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: |||| arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: |||| arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: |||| arm_sme.copy_tile
+// CHECK-LIVE-RANGE-NEXT: EEEE cf.br
+
+// CHECK-LABEL: @cond_branch_with_backedge
+// CHECK-NOT: tile_id = 16
+// CHECK: arm_sme.get_tile {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.get_tile {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.get_tile {tile_id = 2 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.get_tile {tile_id = 3 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.move_vector_to_tile_slice {{.*}} {tile_id = 0 : i32} : vector<[4]xf32> into vector<[4]x[4]xf32>
+// CHECK-NOT: tile_id = 16
+func.func @cond_branch_with_backedge(%slice: vector<[4]xf32>) {
+  %tileA = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %tileB = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %tileC = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %tileD = arm_sme.get_tile : vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  // Live here: %tileA, %tileB, %tileC, %tileD
+  cf.br ^bb1(%c0, %tileA : index, vector<[4]x[4]xf32>)
+^bb1(%currentIndex: index, %iterTile: vector<[4]x[4]xf32>):
+  %continueLoop = arith.cmpi slt, %currentIndex, %c10 : index
+  // Live here: %iterTile, %tileB, %tileC, %tileD
+  cf.cond_br %continueLoop, ^bb2, ^bb3(%iterTile, %tileB, %tileC, %tileD : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>)
+^bb2:
+  // Live here: %iterTile, %tileB, %tileC, %tileD
+  %nextTile = arm_sme.move_vector_to_tile_slice %slice, %iterTile, %currentIndex : vector<[4]xf32> into vector<[4]x[4]xf32>
+  %nextIndex = arith.addi %currentIndex, %c1 : index
+  cf.br ^bb1(%nextIndex, %nextTile : index, vector<[4]x[4]xf32>)
+^bb3(%finalTileA: vector<[4]x[4]xf32>, %finalTileB: vector<[4]x[4]xf32>, %finalTileC: vector<[4]x[4]xf32>, %finalTileD: vector<[4]x[4]xf32>):
+  // Live here: %finalTileA, %finalTileB, %finalTileC, %finalTileD
+  return
+}
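The new test file below exercises the tile-type side of the spill heuristics (isTileTypeGreaterOrEqual): a spill candidate must occupy at least as much of ZA as the range that needs a tile. The enum and helper here are illustrative stand-ins for the dialect's actual types, capturing only the ordering the tests rely on:

    // A tile with smaller elements covers a greater or equal share of ZA (the
    // single i8 tile is all of ZA, each f32 tile is a quarter of it), so
    // element bit-width orders tile footprints in reverse.
    enum class TileElementBits { B8 = 8, H16 = 16, S32 = 32, D64 = 64, Q128 = 128 };

    constexpr bool tileTypeGreaterOrEqual(TileElementBits a, TileElementBits b) {
      return static_cast<int>(a) <= static_cast<int>(b);
    }

    static_assert(tileTypeGreaterOrEqual(TileElementBits::H16, TileElementBits::S32),
                  "a 16-bit tile covers enough ZA to stand in for a 32-bit tile");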
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation-spills-with-mixed-tile-types.mlir b/mlir/test/Dialect/ArmSME/tile-allocation-spills-with-mixed-tile-types.mlir
new file mode 100644
index 0000000000000..27757e29c1e2f
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/tile-allocation-spills-with-mixed-tile-types.mlir
@@ -0,0 +1,38 @@
+
+// RUN: mlir-opt %s -test-arm-sme-tile-allocation -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @always_spill_larger_or_equal_tile_type
+// CHECK: arm_sme.zero {tile_id = 0 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.zero {tile_id = 1 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.zero {tile_id = 2 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.zero {tile_id = 3 : i32} : vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_load {{.*}} {tile_id = 16 : i32} : memref<?x?xf16>, vector<[8]x[8]xf16>
+func.func @always_spill_larger_or_equal_tile_type(%memref: memref<?x?xf16>) -> (vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[8]x[8]xf16>) {
+  %c0 = arith.constant 0 : index
+  %0 = arm_sme.zero : vector<[4]x[4]xf32>
+  %1 = arm_sme.zero : vector<[4]x[4]xf32>
+  %2 = arm_sme.zero : vector<[4]x[4]xf32>
+  %3 = arm_sme.zero : vector<[4]x[4]xf32>
+  // The load will be spilled (even though the zeros are 'trivial' spills) as a single `f32` tile would not fit the load.
+  %load = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return %0, %1, %2, %3, %load : vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[8]x[8]xf16>
+}
+
+// -----
+
+// CHECK-LABEL: @spill_larger_tile_type
+// CHECK: arm_sme.zero {tile_id = 16 : i32} : vector<[16]x[16]xi8>
+// CHECK: arm_sme.tile_load {{.*}} {tile_id = 0 : i32} : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_load {{.*}} {tile_id = 1 : i32} : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_load {{.*}} {tile_id = 2 : i32} : memref<?x?xf32>, vector<[4]x[4]xf32>
+// CHECK: arm_sme.tile_load {{.*}} {tile_id = 3 : i32} : memref<?x?xf32>, vector<[4]x[4]xf32>
+func.func @spill_larger_tile_type(%memref: memref<?x?xf32>) -> (vector<[16]x[16]xi8>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>) {
+  %c0 = arith.constant 0 : index
+  // Spilling the `arm_sme.zero` should free up space for all four f32 tiles.
+  %0 = arm_sme.zero : vector<[16]x[16]xi8>
+  %1 = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %2 = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %3 = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %4 = arm_sme.tile_load %memref[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return %0, %1, %2, %3, %4 : vector<[16]x[16]xi8>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>, vector<[4]x[4]xf32>
+}
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index cac2dcc24d104..ca339be5fb56f 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -allocate-arm-sme-tiles -convert-arm-sme-to-llvm -canonicalize | FileCheck %s
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm,canonicalize))" | FileCheck %s
 
 // This test verifies the tile mask operand of the zero intrinsic zeroes
 // the correct tiles. Both integer and floating-point datatypes are checked.
diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
index 588b44a36c29f..14d9712e971a8 100644
--- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
+++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/use-too-many-tiles.mlir
@@ -13,26 +13,29 @@
 /// performance (hence the warning).
 func.func @use_too_many_tiles(%a: memref<?x?xi16>, %b: memref<?x?xi16>, %c: memref<?x?xi16>) {
   %c0 = arith.constant 0 : index
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_a = arith.constant dense<0> : vector<[8]x[8]xi16>
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_b = arith.constant dense<1> : vector<[8]x[8]xi16>
   // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_c = arm_sme.tile_load %a[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_d = arm_sme.tile_load %b[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
-  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   %tile_e = arm_sme.tile_load %c[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
 
   // CHECK-LABEL: tile_a:
   // CHECK-COUNT-8: ( 0, 0, 0, 0, 0, 0, 0, 0
   vector.print str "tile_a:\n"
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   vector.print %tile_a : vector<[8]x[8]xi16>
   // CHECK-LABEL: tile_b:
   // CHECK-COUNT-8: ( 1, 1, 1, 1, 1, 1, 1, 1
   vector.print str "tile_b:\n"
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   vector.print %tile_b : vector<[8]x[8]xi16>
   // CHECK-LABEL: tile_c:
   // CHECK-COUNT-8: ( 2, 2, 2, 2, 2, 2, 2, 2
   vector.print str "tile_c:\n"
+  // expected-warning @below {{failed to allocate SME virtual tile to operation, all tile operations will go through memory, expect degraded performance}}
   vector.print %tile_c : vector<[8]x[8]xi16>
   // CHECK-LABEL: tile_d:
   // CHECK-COUNT-8: ( 3, 3, 3, 3, 3, 3, 3, 3
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
index 1794564a6a724..0648e771b8891 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/Emulated/test-setArmSVLBits.mlir
@@ -1,5 +1,6 @@
 // DEFINE: %{entry_point} = main
-// DEFINE: %{compile} = mlir-opt %s -convert-arm-sme-to-llvm -test-lower-to-llvm
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --pass-pipeline="builtin.module(func.func(convert-arm-sme-to-llvm),test-lower-to-llvm)"
 // DEFINE: %{run} = %mcr_aarch64_cmd \
 // DEFINE:   -march=aarch64 -mattr=+sve,+sme \
 // DEFINE:   -e %{entry_point} -entry-point-result=void \
diff --git a/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
index e942c7b8ac058..cdd8afe141421 100644
--- a/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/ArmSME/CMakeLists.txt
@@ -15,4 +15,5 @@ add_mlir_library(MLIRArmSMETestPasses
   MLIRTransforms
   MLIRVectorToArmSME
   MLIRVectorToSCF
+  MLIRSCFToControlFlow
 )
diff --git a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
index 48d4a5859f8a0..d3dabaf200fdc 100644
--- a/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
+++ b/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp
@@ -14,10 +14,12 @@
 #include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"
 #include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
 #include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Conversion/VectorToArmSME/VectorToArmSME.h"
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Dialect/ArmSME/Transforms/Passes.h"
 #include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/DialectRegistry.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
@@ -34,6 +36,10 @@ struct TestLowerToArmSMEOptions
       llvm::cl::desc("Fuse outer product operations via "
                      "'-arm-sme-outer-product-fusion' pass"),
       llvm::cl::init(true)};
+  PassOptions::Option<bool> dumpTileLiveRanges{
+      *this, "dump-tile-live-ranges",
+      llvm::cl::desc("Dump the live ranges of SME tiles (for debugging)"),
+      llvm::cl::init(false)};
 };
 
 void buildTestLowerToArmSME(OpPassManager &pm,
@@ -65,20 +71,17 @@ void buildTestLowerToArmSME(OpPassManager &pm,
   pm.addPass(createConvertVectorToSCFPass(
       VectorTransferToSCFOptions().enableFullUnroll()));
 
-  // Allocate tiles for ArmSME operations.
-  //
-  // Later passes may create further ArmSME ops that implement the
-  // ArmSMETileOpInterface, but tiles are allocated for root operations,
-  // all of which should now exist.
-  pm.addPass(arm_sme::createTileAllocationPass());
-
   // Enable streaming-mode and ZA.
   pm.addPass(arm_sme::createEnableArmStreamingPass(
      arm_sme::ArmStreamingMode::StreamingLocally, arm_sme::ArmZaMode::NewZA,
      /*onlyIfRequiredByOps=*/true));
 
+  // Convert SCF to CF (required for ArmSME tile allocation).
+  pm.addPass(createConvertSCFToCFPass());
+
   // Convert ArmSME to LLVM.
-  pm.addPass(createConvertArmSMEToLLVMPass());
+  pm.addNestedPass<func::FuncOp>(
+      createConvertArmSMEToLLVMPass(options.dumpTileLiveRanges));
 
   // Sprinkle some cleanups.
   pm.addPass(createCanonicalizerPass());
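For downstream pipelines, the takeaway is that tile allocation no longer runs as a standalone pass: it happens inside the function-level ArmSME-to-LLVM conversion, which in turn expects unstructured control flow. Code that previously scheduled createTileAllocationPass() can instead call the new allocateSMETiles entry point directly; a hedged sketch (the wrapper function and its error handling are illustrative, not part of this patch):

    // Run SME tile allocation over every function in a module. Assumes SCF
    // has already been lowered to CF, as the allocator requires.
    static mlir::LogicalResult allocateTilesInModule(mlir::ModuleOp module,
                                                     bool dumpRanges) {
      for (auto function : module.getOps<mlir::FunctionOpInterface>()) {
        if (mlir::failed(mlir::arm_sme::allocateSMETiles(function, dumpRanges)))
          return mlir::failure();
      }
      return mlir::success();
    }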