diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h new file mode 100644 index 0000000000000..fe851d17867df --- /dev/null +++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h @@ -0,0 +1,38 @@ +//===- ArmSMEToLLVM.h - Convert ArmSME to LLVM dialect ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_ +#define MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_ + +#include + +#include "mlir/Dialect/ArmSME/Transforms/Passes.h" + +namespace mlir { +class Pass; +class RewritePatternSet; + +#define GEN_PASS_DECL_CONVERTARMSMETOLLVM +#include "mlir/Conversion/Passes.h.inc" + +using arm_sme::ArmSMETypeConverter; + +/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics. +std::unique_ptr createConvertArmSMEToLLVMPass(); + +/// Configure target to convert from the ArmSME dialect to LLVM intrinsics. +void configureArmSMEToLLVMConversionLegality(ConversionTarget &target); + +/// Populate the given list with patterns that convert from the ArmSME dialect +/// to LLVM intrinsics. +void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter, + RewritePatternSet &patterns); + +} // namespace mlir + +#endif // MLIR_CONVERSION_ARMSMETOLLVM_ARMSMETOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 3078d909a8946..a25fd17ea923f 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -15,6 +15,7 @@ #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" #include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h" +#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h" #include "mlir/Conversion/ArmSMEToSCF/ArmSMEToSCF.h" #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" #include "mlir/Conversion/BufferizationToMemRef/BufferizationToMemRef.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 626f5f3d19d30..06756ff3df0bb 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1241,6 +1241,20 @@ def ConvertArmSMEToSCF : Pass<"convert-arm-sme-to-scf"> { ]; } +//===----------------------------------------------------------------------===// +// ArmSMEToLLVM +//===----------------------------------------------------------------------===// + +def ConvertArmSMEToLLVM : Pass<"convert-arm-sme-to-llvm"> { + let summary = "Lower the operations from the ArmSME dialect into the LLVM " + "dialect"; + let constructor = "mlir::createConvertArmSMEToLLVMPass()"; + let dependentDialects = [ + "arm_sme::ArmSMEDialect", + "LLVM::LLVMDialect" + ]; +} + //===----------------------------------------------------------------------===// // VectorToLLVM //===----------------------------------------------------------------------===// @@ -1280,10 +1294,6 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> { "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " "dialect.">, - Option<"armSME", "enable-arm-sme", - "bool", /*default=*/"false", - "Enables the use of ArmSME dialect while lowering the vector " - "dialect.">, Option<"x86Vector", "enable-x86vector", "bool", /*default=*/"false", "Enables the use of X86Vector dialect while lowering the vector " diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h index fae0451385993..8ea3e1e57b7ca 100644 --- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h @@ -20,15 +20,6 @@ void populateVectorTransferLoweringPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns); } // namespace arm_sme -/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM -/// intrinsics. -void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns); - -/// Configure the target to support lowering ArmSME ops to ops that map to LLVM -/// intrinsics. -void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target); - } // namespace mlir #endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp similarity index 85% rename from mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp rename to mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp index 6ccb652ecbbc2..e409dc57fb020 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp @@ -1,24 +1,34 @@ -//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===// +//===- ArmSMEToLLVM.cpp - Convert ArmSME to LLVM dialect ------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// +// +// This file implements lowering of ArmSME operations to LLVM intrinsics. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" -#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/ArmSME/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTARMSMETOLLVM +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir using namespace mlir; -using namespace mlir::arm_sme; namespace { @@ -40,11 +50,11 @@ namespace { /// The 'arm_sme.cast_tile_to_vector' (which models the return) and the /// 'arith.shli' (which generates the mask) will be folded away after tile /// allocation and canonization. -struct ZeroOpConversion : public ConvertOpToLLVMPattern { - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; +struct ZeroOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(ZeroOp zero, OpAdaptor adaptor, + matchAndRewrite(arm_sme::ZeroOp zero, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = zero.getLoc(); @@ -121,7 +131,7 @@ struct ZeroOpConversion : public ConvertOpToLLVMPattern { }; /// Lower `arm_sme.load_tile_slice` to SME intrinsics. -struct LoadTileSliceToArmSMELowering +struct LoadTileSliceConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< arm_sme::LoadTileSliceOp>::ConvertOpToLLVMPattern; @@ -220,7 +230,7 @@ struct LoadTileSliceToArmSMELowering }; /// Lower for `arm_sme.store_tile_slice` to SME intrinsics. -struct StoreTileSliceToArmSMELowering +struct StoreTileSliceConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< arm_sme::StoreTileSliceOp>::ConvertOpToLLVMPattern; @@ -313,7 +323,7 @@ struct StoreTileSliceToArmSMELowering }; /// Lower `arm_sme.move_vector_to_tile_slice` to SME intrinsics. -struct MoveVectorToTileSliceToArmSMELowering +struct MoveVectorToTileSliceConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< arm_sme::MoveVectorToTileSliceOp>::ConvertOpToLLVMPattern; @@ -373,7 +383,7 @@ struct MoveVectorToTileSliceToArmSMELowering }; /// Lower `arm_sme.move_tile_slice_to_vector` to SME intrinsics. -struct MoveTileSliceToVectorArmSMELowering +struct MoveTileSliceToVectorConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< arm_sme::MoveTileSliceToVectorOp>::ConvertOpToLLVMPattern; @@ -456,7 +466,8 @@ struct OuterProductOpConversion // * half-precision - +sme2p1,+b16b16 // // It should be possible to control lowering based on target features. - // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile + // [1] + // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable()) return false; @@ -475,7 +486,7 @@ struct OuterProductOpConversion }; // TODO: Support CombiningKind::Sub for outer products. - if (outerProductOp.getKind() != CombiningKind::Add) + if (outerProductOp.getKind() != arm_sme::CombiningKind::Add) return outerProductOp.emitError("unsupported kind"); auto resultVectorType = outerProductOp.getResultType(); @@ -522,32 +533,56 @@ struct OuterProductOpConversion } // namespace -void mlir::configureArmSMELegalizeForExportTarget( - LLVMConversionTarget &target) { +namespace { + +struct ConvertArmSMEToLLVMPass + : public impl::ConvertArmSMEToLLVMBase { + void runOnOperation() override { + LLVMConversionTarget target(getContext()); + RewritePatternSet patterns(&getContext()); + ArmSMETypeConverter converter(&getContext(), + LowerToLLVMOptions(&getContext())); + + configureArmSMEToLLVMConversionLegality(target); + populateArmSMEToLLVMConversionPatterns(converter, patterns); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; + +} // namespace + +void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) { + target.addIllegalDialect(); target.addLegalOp< - scf::ForOp, scf::YieldOp, arm_sme::CastTileToVector, - arm_sme::CastVectorToTile, arm_sme::aarch64_sme_zero, - arm_sme::aarch64_sme_str, arm_sme::aarch64_sme_ld1b_horiz, - arm_sme::aarch64_sme_ld1h_horiz, arm_sme::aarch64_sme_ld1w_horiz, - arm_sme::aarch64_sme_ld1d_horiz, arm_sme::aarch64_sme_ld1q_horiz, - arm_sme::aarch64_sme_st1b_horiz, arm_sme::aarch64_sme_st1h_horiz, - arm_sme::aarch64_sme_st1w_horiz, arm_sme::aarch64_sme_st1d_horiz, - arm_sme::aarch64_sme_st1q_horiz, arm_sme::aarch64_sme_ld1b_vert, - arm_sme::aarch64_sme_ld1h_vert, arm_sme::aarch64_sme_ld1w_vert, - arm_sme::aarch64_sme_ld1d_vert, arm_sme::aarch64_sme_ld1q_vert, - arm_sme::aarch64_sme_st1b_vert, arm_sme::aarch64_sme_st1h_vert, - arm_sme::aarch64_sme_st1w_vert, arm_sme::aarch64_sme_st1d_vert, - arm_sme::aarch64_sme_st1q_vert, arm_sme::aarch64_sme_read_horiz, - arm_sme::aarch64_sme_read_vert, arm_sme::aarch64_sme_write_horiz, - arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>(); - target.addLegalOp(); - target.addIllegalOp(); + arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile, + arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_str, + arm_sme::aarch64_sme_ld1b_horiz, arm_sme::aarch64_sme_ld1h_horiz, + arm_sme::aarch64_sme_ld1w_horiz, arm_sme::aarch64_sme_ld1d_horiz, + arm_sme::aarch64_sme_ld1q_horiz, arm_sme::aarch64_sme_st1b_horiz, + arm_sme::aarch64_sme_st1h_horiz, arm_sme::aarch64_sme_st1w_horiz, + arm_sme::aarch64_sme_st1d_horiz, arm_sme::aarch64_sme_st1q_horiz, + arm_sme::aarch64_sme_ld1b_vert, arm_sme::aarch64_sme_ld1h_vert, + arm_sme::aarch64_sme_ld1w_vert, arm_sme::aarch64_sme_ld1d_vert, + arm_sme::aarch64_sme_ld1q_vert, arm_sme::aarch64_sme_st1b_vert, + arm_sme::aarch64_sme_st1h_vert, arm_sme::aarch64_sme_st1w_vert, + arm_sme::aarch64_sme_st1d_vert, arm_sme::aarch64_sme_st1q_vert, + arm_sme::aarch64_sme_read_horiz, arm_sme::aarch64_sme_read_vert, + arm_sme::aarch64_sme_write_horiz, arm_sme::aarch64_sme_write_vert, + arm_sme::aarch64_sme_mopa>(); + target.addLegalDialect(); + target.addLegalOp(); +} + +void mlir::populateArmSMEToLLVMConversionPatterns( + ArmSMETypeConverter &converter, RewritePatternSet &patterns) { + patterns.add(converter); } -void mlir::populateArmSMELegalizeForLLVMExportPatterns( - LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add< - LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering, - MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering, - OuterProductOpConversion, ZeroOpConversion>(converter); +std::unique_ptr mlir::createConvertArmSMEToLLVMPass() { + return std::make_unique(); } diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt new file mode 100644 index 0000000000000..9914f39e17a1a --- /dev/null +++ b/mlir/lib/Conversion/ArmSMEToLLVM/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_conversion_library(MLIRArmSMEToLLVM + ArmSMEToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ArmSMEToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArmSMETransforms + MLIRArmSMEDialect + MLIRArmSMEUtils + MLIRTransforms + MLIRLLVMCommonConversion + MLIRLLVMDialect) diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 822ce5aca2555..c3a2481975040 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -6,6 +6,7 @@ add_subdirectory(ArithToLLVM) add_subdirectory(ArithToSPIRV) add_subdirectory(ArmNeon2dToIntr) add_subdirectory(ArmSMEToSCF) +add_subdirectory(ArmSMEToLLVM) add_subdirectory(AsyncToLLVM) add_subdirectory(BufferizationToMemRef) add_subdirectory(ComplexToLibm) diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 4c6d0672d4108..ff8e78a668e0f 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,9 +14,6 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" -#include "mlir/Dialect/ArmSME/IR/ArmSME.h" -#include "mlir/Dialect/ArmSME/Transforms/Passes.h" -#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -52,8 +49,6 @@ struct LowerVectorToLLVMPass registry.insert(); if (armSVE) registry.insert(); - if (armSME) - registry.insert(); if (amx) registry.insert(); if (x86Vector) @@ -96,7 +91,6 @@ void LowerVectorToLLVMPass::runOnOperation() { target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); - arm_sme::ArmSMETypeConverter armSMEConverter(&getContext(), options); if (armNeon) { // TODO: we may or may not want to include in-dialect lowering to @@ -108,10 +102,6 @@ void LowerVectorToLLVMPass::runOnOperation() { configureArmSVELegalizeForExportTarget(target); populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); } - if (armSME) { - configureArmSMELegalizeForExportTarget(target); - populateArmSMELegalizeForLLVMExportPatterns(armSMEConverter, patterns); - } if (amx) { configureAMXLegalizeForExportTarget(target); populateAMXLegalizeForLLVMExportPatterns(converter, patterns); diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt index 8f485db4e8438..e2407d9f48f70 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt @@ -1,7 +1,6 @@ add_mlir_dialect_library(MLIRArmSMETransforms ArmSMETypeConverter.cpp EnableArmStreaming.cpp - LegalizeForLLVMExport.cpp TileAllocation.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir index 2c26c62ad4248..65996e81c42d9 100644 --- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir +++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm-casts.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -split-input-file | FileCheck %s // This test verifies the temporary casts that are emitted when lowering to // intrinsics to preserve data flow are correct. Canonicalization will remove diff --git a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir index 8fdcf69958244..fa62332bc3f5b 100644 --- a/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir +++ b/mlir/test/Dialect/ArmSME/arm-sme-to-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -verify-diagnostics | FileCheck %s // Test conversion of ArmSME ops to LLVM intrinsics. diff --git a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir index 0f31278eefd15..ba650b031e611 100644 --- a/mlir/test/Dialect/ArmSME/enable-arm-za.mlir +++ b/mlir/test/Dialect/ArmSME/enable-arm-za.mlir @@ -1,6 +1,6 @@ -// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=ENABLE-ZA -// RUN: mlir-opt %s -enable-arm-streaming -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=DISABLE-ZA -// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" | FileCheck %s -check-prefix=NO-ARM-STREAMING +// RUN: mlir-opt %s -enable-arm-streaming=za-mode=new-za -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=ENABLE-ZA +// RUN: mlir-opt %s -enable-arm-streaming -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=DISABLE-ZA +// RUN: mlir-opt %s -convert-arm-sme-to-llvm | FileCheck %s -check-prefix=NO-ARM-STREAMING // CHECK-LABEL: @declaration func.func private @declaration() diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir index 26cd91bd3e895..2378f4234aef1 100644 --- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir +++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" \ +// RUN: mlir-opt %s -convert-arm-sme-to-llvm \ // RUN: -allocate-arm-sme-tiles -canonicalize \ // RUN: -allow-unregistered-dialect \ // RUN: | FileCheck %s diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir index 721ff8f2c3589..77ac071ef67de 100644 --- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-arm-sme-to-llvm -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s //===----------------------------------------------------------------------===// // vector.transfer_write @@ -17,9 +17,8 @@ // CHECK-DAG: %[[EXT_TILE_ID:.*]] = arith.extui %[[TILE_ID]] : i8 to i32 // CHECK-DAG: %[[TILE_MASK:.*]] = arith.shli %[[C255]], %[[EXT_TILE_ID]] : i32 // CHECK-DAG: "arm_sme.intr.zero"(%[[TILE_MASK]]) : (i32) -> () -// CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 -// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index -// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index +// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale +// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] { // CHECK: %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64 // CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> @@ -58,9 +57,8 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref) { // CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index // CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[16]xi1> // CHECK-DAG: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64 -// CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 -// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index -// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index +// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale +// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] { // CHECK-NEXT: %[[TILE_SLICE_PLUS_OFF0:.*]] = arith.addi %[[TILE_SLICE]], %[[C123]] : index // CHECK-NEXT: %[[TILE_SLICE_PLUS_OFF0_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE_PLUS_OFF0]] : index to i64 @@ -92,9 +90,8 @@ func.func @vector_load_i8_with_offset(%arg0 : memref) -> vector<[16]x[16 // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index // CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[16]xi1> -// CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 -// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index -// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index +// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale +// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] { // CHECK-NEXT: %[[TILE_SLICE_IDX:.*]] = arith.muli %[[TILE_SLICE]], %[[SVL_B]] : index // CHECK-NEXT: %[[TILE_SLICE_IDX_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE_IDX]] : index to i64 @@ -255,9 +252,8 @@ func.func @vector_load_i128(%arg0 : memref) -> vector<[1]x[1]xi128> { // CHECK-DAG: %[[MIN_SVL_B:.*]] = arith.constant 16 : index // CHECK-DAG: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64 // CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense : vector<[16]xi1> -// CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 -// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index -// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE_IDX]], %[[MIN_SVL_B]] : index +// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale +// CHECK-NEXT: %[[SVL_B:.*]] = arith.muli %[[VSCALE]], %[[MIN_SVL_B]] : index // CHECK-NEXT: scf.for %[[TILE_SLICE:.*]] = %[[C0]] to %[[SVL_B]] step %[[C1]] { // CHECK: %[[TILE_SLICE_I64:.*]] = builtin.unrealized_conversion_cast %[[TILE_SLICE]] : index to i64 // CHECK-NEXT: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[16]x[16]xi8> to i8 @@ -466,14 +462,8 @@ func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vec // CHECK-LABEL: @vector_outerproduct_masked_f32 // CHECK-SAME: (%[[LHS:.*]]: vector<[4]xf32>, %[[RHS:.*]]: vector<[4]xf32>, %[[ACC:.*]]: vector<[4]x[4]xf32>, %[[DIM0:.*]]: index, %[[DIM1:.*]]: index func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %dim0 : index, %dim1 : index) { - // CHECK: %[[DIM0_I32:.*]] = arith.index_cast %[[DIM0]] : index to i32 - // CHECK: %[[INSERT_DIM0:.*]] = llvm.insertelement %[[DIM0_I32]], {{.*}} : vector<[4]xi32> - // CHECK: %[[SPLAT_DIM0:.*]] = llvm.shufflevector %[[INSERT_DIM0]], {{.*}} : vector<[4]xi32> - // CHECK: %[[LHS_MASK:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_DIM0]] : vector<[4]xi32> - // CHECK: %[[DIM1_I32:.*]] = arith.index_cast %[[DIM1]] : index to i32 - // CHECK: %[[INSERT_DIM1:.*]] = llvm.insertelement %[[DIM1_I32]], {{.*}} : vector<[4]xi32> - // CHECK: %[[SPLAT_DIM1:.*]] = llvm.shufflevector %[[INSERT_DIM1]], {{.*}} : vector<[4]xi32> - // CHECK: %[[RHS_MASK:.*]] = arith.cmpi slt, %{{.*}}, %[[SPLAT_DIM1]] : vector<[4]xi32> + // CHECK: %[[LHS_MASK:.*]] = vector.create_mask %[[DIM0]] : vector<[4]xi1> + // CHECK: %[[RHS_MASK:.*]] = vector.create_mask %[[DIM1]] : vector<[4]xi1> // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[4]x[4]xf32> to i32 // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE]], %[[LHS_MASK]], %[[RHS_MASK]], %[[LHS]], %[[RHS]]) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) %mask = vector.create_mask %dim0, %dim1 : vector<[4]x[4]xi1> @@ -486,8 +476,8 @@ func.func @vector_outerproduct_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector< // CHECK-LABEL: @vector_outerproduct_masked_f16 // CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>, func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>, %dim0 : index, %dim1 : index) { - // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32> - // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32> + // CHECK: vector.create_mask {{.*}} : vector<[8]xi1> + // CHECK: vector.create_mask {{.*}} : vector<[8]xi1> // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1> %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[8]xf16>, vector<[8]xf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xf16> @@ -499,8 +489,8 @@ func.func @vector_outerproduct_masked_f16(%lhs : vector<[8]xf16>, %rhs : vector< // CHECK-LABEL: @vector_outerproduct_masked_bf16 // CHECK-SAME: (%[[LHS:.*]]: vector<[8]xbf16>, %[[RHS:.*]]: vector<[8]xbf16>, %[[ACC:.*]]: vector<[8]x[8]xbf16>, func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>, %dim0 : index, %dim1 : index) { - // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32> - // CHECK: arith.cmpi slt, {{.*}} : vector<[8]xi32> + // CHECK: vector.create_mask {{.*}} : vector<[8]xi1> + // CHECK: vector.create_mask {{.*}} : vector<[8]xi1> // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) %mask = vector.create_mask %dim0, %dim1 : vector<[8]x[8]xi1> %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[8]xbf16>, vector<[8]xbf16> } : vector<[8]x[8]xi1> -> vector<[8]x[8]xbf16> @@ -512,8 +502,8 @@ func.func @vector_outerproduct_masked_bf16(%lhs : vector<[8]xbf16>, %rhs : vecto // CHECK-LABEL: @vector_outerproduct_masked_f64 // CHECK-SAME: (%[[LHS:.*]]: vector<[2]xf64>, %[[RHS:.*]]: vector<[2]xf64>, %[[ACC:.*]]: vector<[2]x[2]xf64>, func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>, %dim0 : index, %dim1 : index) { - // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32> - // CHECK: arith.cmpi slt, {{.*}} : vector<[2]xi32> + // CHECK: vector.create_mask {{.*}} : vector<[2]xi1> + // CHECK: vector.create_mask {{.*}} : vector<[2]xi1> // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) %mask = vector.create_mask %dim0, %dim1 : vector<[2]x[2]xi1> %result = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[2]xf64>, vector<[2]xf64> } : vector<[2]x[2]xi1> -> vector<[2]x[2]xf64> @@ -522,7 +512,6 @@ func.func @vector_outerproduct_masked_f64(%lhs : vector<[2]xf64>, %rhs : vector< // ----- -// CHECK-LABEL: @vector_outerproduct_unsupported_axpy func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> { // expected-error@+1 {{AXPY operations not supported}} %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[2]xf64>, f64 @@ -532,6 +521,7 @@ func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f // ----- func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) { + // expected-error@+2 {{failed to legalize operation 'arm_sme.outerproduct'}} // expected-error@+1 {{unsupported type}} %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[16]xi8>, vector<[16]xi8> "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> () @@ -540,7 +530,6 @@ func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : v // ----- func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) { - // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}} // expected-error@+1 {{unsupported kind}} %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[2]xf64>, vector<[2]xf64> "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () @@ -549,7 +538,7 @@ func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : v // ----- func.func @vector_outerproduct_unknown_mask(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) { - // expected-error@+1 {{failed to legalize operation 'vector.outerproduct'}} + // CHECK: vector.outerproduct %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32> "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> () } @@ -655,11 +644,10 @@ func.func @vector_insert_slice_f64(%tile: vector<[2]x[2]xf64>, %slice: vector<[2 func.func @vector_insert_element_i32(%tile: vector<[4]x[4]xi32>, %el: i32, %row: index, %col: index) -> vector<[4]x[4]xi32> { // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32> // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense : vector<[4]xi1> - // CHECK-NEXT: %[[COL_I32:.*]] = builtin.unrealized_conversion_cast %[[COL]] : index to i64 // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32 // CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32 // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32> - // CHECK-NEXT: %[[NEW_SLICE:.*]] = llvm.insertelement %[[EL]], %[[SLICE]]{{\[}}%[[COL_I32]] : i64] : vector<[4]xi32> + // CHECK-NEXT: %[[NEW_SLICE:.*]] = vector.insert %[[EL]], %[[SLICE]] [%[[COL]]] : i32 into vector<[4]xi32> // CHECK-NEXT: %[[SLICE_INDEX:.*]] = arith.index_castui %[[ROW]] : index to i32 // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> () %new_tile = vector.insert %el, %tile[%row, %col] : i32 into vector<[4]x[4]xi32> @@ -846,11 +834,10 @@ func.func @vector_extract_slice_f64(%tile: vector<[2]x[2]xf64>, %row: index) -> func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col: index) -> i32 { // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32> // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense : vector<[4]xi1> - // CHECK-NEXT: %[[COL_I32:.*]] = builtin.unrealized_conversion_cast %[[COL]] : index to i64 // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32 // CHECK-NEXT: %[[ROW_I32:.*]] = arith.index_cast %[[ROW]] : index to i32 // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[ROW_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32> - // CHECK-NEXT: %[[EL:.*]] = llvm.extractelement %[[SLICE]]{{\[}}%[[COL_I32]] : i64] : vector<[4]xi32> + // CHECK-NEXT: %[[EL:.*]] = vector.extract %[[SLICE]]{{\[}}%[[COL]]] : i32 from vector<[4]xi32> %el = vector.extract %tile[%row, %col] : i32 from vector<[4]x[4]xi32> return %el : i32 } @@ -860,7 +847,7 @@ func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %row: index, %col: // CHECK-LABEL: @vector_extract_element_i8 func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, %col: index) -> i8 { // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[16]xi8>, vector<[16]xi1>, i32, i32) -> vector<[16]xi8> - // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[16]xi8> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i8 from vector<[16]xi8> %el = vector.extract %tile[%row, %col] : i8 from vector<[16]x[16]xi8> return %el : i8 } @@ -870,7 +857,7 @@ func.func @vector_extract_element_i8(%tile: vector<[16]x[16]xi8>, %row: index, % // CHECK-LABEL: @vector_extract_element_i16 func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, %col: index) -> i16 { // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xi16>, vector<[8]xi1>, i32, i32) -> vector<[8]xi16> - // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[8]xi16> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i16 from vector<[8]xi16> %el = vector.extract %tile[%row, %col] : i16 from vector<[8]x[8]xi16> return %el : i16 } @@ -880,7 +867,7 @@ func.func @vector_extract_element_i16(%tile: vector<[8]x[8]xi16>, %row: index, % // CHECK-LABEL: @vector_extract_element_i64 func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, %col: index) -> i64 { // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xi64>, vector<[2]xi1>, i32, i32) -> vector<[2]xi64> - // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[2]xi64> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i64 from vector<[2]xi64> %el = vector.extract %tile[%row, %col] : i64 from vector<[2]x[2]xi64> return %el : i64 } @@ -890,7 +877,7 @@ func.func @vector_extract_element_i64(%tile: vector<[2]x[2]xi64>, %row: index, % // CHECK-LABEL: @vector_extract_element_i128 func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index, %col: index) -> i128 { // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[1]xi128>, vector<[1]xi1>, i32, i32) -> vector<[1]xi128> - // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[1]xi128> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : i128 from vector<[1]xi128> %el = vector.extract %tile[%row, %col] : i128 from vector<[1]x[1]xi128> return %el : i128 } @@ -900,7 +887,7 @@ func.func @vector_extract_element_i128(%tile: vector<[1]x[1]xi128>, %row: index, // CHECK-LABEL: @vector_extract_element_f16 func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, %col: index) -> f16 { // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xf16> - // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[8]xf16> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f16 from vector<[8]xf16> %el = vector.extract %tile[%row, %col] : f16 from vector<[8]x[8]xf16> return %el : f16 } @@ -910,7 +897,7 @@ func.func @vector_extract_element_f16(%tile: vector<[8]x[8]xf16>, %row: index, % // CHECK-LABEL: @vector_extract_element_bf16 func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index, %col: index) -> bf16 { // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[8]xbf16>, vector<[8]xi1>, i32, i32) -> vector<[8]xbf16> - // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[8]xbf16> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : bf16 from vector<[8]xbf16> %el = vector.extract %tile[%row, %col] : bf16 from vector<[8]x[8]xbf16> return %el : bf16 } @@ -920,7 +907,7 @@ func.func @vector_extract_element_bf16(%tile: vector<[8]x[8]xbf16>, %row: index, // CHECK-LABEL: @vector_extract_element_f32 func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, %col: index) -> f32 { // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32> - // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[4]xf32> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f32 from vector<[4]xf32> %el = vector.extract %tile[%row, %col] : f32 from vector<[4]x[4]xf32> return %el : f32 } @@ -930,7 +917,7 @@ func.func @vector_extract_element_f32(%tile: vector<[4]x[4]xf32>, %row: index, % // CHECK-LABEL: @vector_extract_element_f64 func.func @vector_extract_element_f64(%tile: vector<[2]x[2]xf64>, %row: index, %col: index) -> f64 { // CHECK: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (vector<[2]xf64>, vector<[2]xi1>, i32, i32) -> vector<[2]xf64> - // CHECK-NEXT: %{{.*}} = llvm.extractelement %[[SLICE]]{{\[}}%{{.*}} : i64] : vector<[2]xf64> + // CHECK-NEXT: %{{.*}} = vector.extract %[[SLICE]]{{\[}}%{{.*}}] : f64 from vector<[2]xf64> %el = vector.extract %tile[%row, %col] : f64 from vector<[2]x[2]xf64> return %el : f64 } diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir index efe4da7d3c50c..18b95cf2fdf84 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/fill-2d.mlir @@ -5,7 +5,7 @@ // RUN: -one-shot-bufferize="bufferize-function-boundaries" \ // RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ // RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ -// RUN: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// RUN: -convert-arm-sme-to-llvm -cse -canonicalize \ // RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \ // RUN: %mcr_aarch64_cmd \ // RUN: -e=entry -entry-point-result=void \ diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir index ab74f01004742..f189fd97d66cd 100644 --- a/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul-transpose-a.mlir @@ -4,7 +4,7 @@ // RUN: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ // RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ // RUN: -convert-vector-to-scf -cse -arm-sve-legalize-vector-storage \ -// RUN: -convert-vector-to-llvm=enable-arm-sme \ +// RUN: -convert-arm-sme-to-llvm \ // RUN: -convert-vector-to-llvm=enable-arm-sve \ // RUN: -cse -canonicalize -allocate-arm-sme-tiles -test-lower-to-llvm | \ // RUN: %mcr_aarch64_cmd \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir index 32e7e6b79ce09..59b4a7e6a52f9 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/load-store-128-bit-tile.mlir @@ -2,7 +2,7 @@ // DEFINE: %{compile} = mlir-opt %s \ // DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ // DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ -// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ // DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir index 44cf23f41b632..0c186cc373a3b 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-load-vertical.mlir @@ -2,7 +2,7 @@ // DEFINE: %{compile} = mlir-opt %s \ // DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ // DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ -// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ // DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir index f1ecf768ebe83..442a70cacd665 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir @@ -2,7 +2,7 @@ // DEFINE: %{compile} = mlir-opt %s \ // DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ // DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ -// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ // DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm -o %t // DEFINE: %{run} = %mcr_aarch64_cmd %t \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir index 5c907bb1675e4..74b51dcc9b4df 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir @@ -2,7 +2,7 @@ // DEFINE: %{compile} = mlir-opt %s \ // DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ // DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ -// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ // DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm -o %t // DEFINE: %{run} = %mcr_aarch64_cmd %t \ // DEFINE: -march=aarch64 -mattr=+sve,+sme-f64f64 \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir index ccc08289570af..82f38b4dbfa9d 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-read-2d.mlir @@ -2,7 +2,7 @@ // DEFINE: %{compile} = mlir-opt %s \ // DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ // DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ -// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ // DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir index f35f83dcec0da..3b218aefcd415 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transfer-write-2d.mlir @@ -2,7 +2,7 @@ // DEFINE: %{compile} = mlir-opt %s \ // DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ // DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ -// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ // DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir index 39b5ef2ade4b0..e2cbe735fa4ff 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-transpose.mlir @@ -2,7 +2,7 @@ // DEFINE: %{compile} = mlir-opt %s \ // DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ // DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ -// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ // DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir index baf2046722b9e..6e33a421bf799 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir @@ -1,6 +1,6 @@ // RUN: mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ // RUN: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ -// RUN: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// RUN: -convert-arm-sme-to-llvm -cse -canonicalize \ // RUN: -allocate-arm-sme-tiles -test-lower-to-llvm | \ // RUN: %mcr_aarch64_cmd \ // RUN: -march=aarch64 -mattr=+sve,+sme \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir index 8878dca8bdcb6..961bb274d1e33 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-load-store.mlir @@ -2,7 +2,7 @@ // DEFINE: %{compile} = mlir-opt %s \ // DEFINE: -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ // DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ -// DEFINE: -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize \ +// DEFINE: -convert-arm-sme-to-llvm -cse -canonicalize \ // DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \ diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir index a890aaa6f309d..25ef1799e63ad 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir @@ -1,7 +1,7 @@ // DEFINE: %{entry_point} = entry // DEFINE: %{compile} = mlir-opt %s -enable-arm-streaming="streaming-mode=streaming-locally za-mode=new-za" \ // DEFINE: -convert-vector-to-arm-sme -convert-arm-sme-to-scf \ -// DEFINE: -convert-vector-to-llvm="enable-arm-sme" \ +// DEFINE: -convert-arm-sme-to-llvm \ // DEFINE: -allocate-arm-sme-tiles -test-lower-to-llvm // DEFINE: %{run} = %mcr_aarch64_cmd \ // DEFINE: -march=aarch64 -mattr=+sve,+sme \