diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td index 22d5afcd77381..309079e549846 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td @@ -23,7 +23,7 @@ class SPIRV_ArithmeticBinaryOp { + [Pure, AllTypesMatch<["operand1", "operand2", "result"]>])> { // In addition to normal types arithmetic instructions can support cooperative // matrix. let arguments = (ins @@ -42,7 +42,7 @@ class SPIRV_ArithmeticUnaryOp { + [Pure, AllTypesMatch<["operand", "result"]>])> { // In addition to normal types arithmetic instructions can support cooperative // matrix. let arguments = (ins diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h index 2e29e9afaabf4..787535d0a6bd2 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h @@ -394,7 +394,8 @@ hash_value(const StructType::MemberDecorationInfo &memberDecorationInfo); // SPIR-V KHR cooperative matrix type class CooperativeMatrixType : public Type::TypeBase { + detail::CooperativeMatrixTypeStorage, + ShapedType::Trait> { public: using Base::Base; @@ -418,6 +419,22 @@ class CooperativeMatrixType std::optional storage = std::nullopt); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, std::optional storage = std::nullopt); + + operator ShapedType() const { return llvm::cast(*this); } + + ArrayRef getShape() const; + + bool hasRank() const { return true; } + + CooperativeMatrixType cloneWith(std::optional> shape, + Type elementType) const { + if (!shape) + return get(elementType, getRows(), getColumns(), getScope(), getUse()); + + assert(shape.value().size() == 2); + return get(elementType, shape.value()[0], shape.value()[1], getScope(), + getUse()); + } }; // SPIR-V matrix type diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 337df3a5a65f0..1aff43c301334 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -194,8 +194,21 @@ std::optional CompositeType::getSizeInBytes() { //===----------------------------------------------------------------------===// struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage { + // In the specification dimensions of the Cooperative Matrix are 32-bit + // integers --- the initial implementation kept those values as such. However, + // the `ShapedType` expects the shape to be `int64_t`. We could keep the shape + // as 32-bits and expose it as int64_t through `getShape`, however, this + // method returns an `ArrayRef`, so returning `ArrayRef` having two + // 32-bits integers would require an extra logic and storage. So, we diverge + // from the spec and internally represent the dimensions as 64-bit integers, + // so we can easily return an `ArrayRef` from `getShape` without any extra + // logic. Alternatively, we could store both rows and columns (both 32-bits) + // and shape (64-bits), assigning rows and columns to shape whenever + // `getShape` is called. This would be at the cost of extra logic and storage. + // Note: Because `ArrayRef` is returned we cannot construct an object in + // `getShape` on the fly. using KeyTy = - std::tuple; + std::tuple; static CooperativeMatrixTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key) { @@ -204,17 +217,17 @@ struct spirv::detail::CooperativeMatrixTypeStorage final : TypeStorage { } bool operator==(const KeyTy &key) const { - return key == KeyTy(elementType, rows, columns, scope, use); + return key == KeyTy(elementType, shape[0], shape[1], scope, use); } CooperativeMatrixTypeStorage(const KeyTy &key) - : elementType(std::get<0>(key)), rows(std::get<1>(key)), - columns(std::get<2>(key)), scope(std::get<3>(key)), + : elementType(std::get<0>(key)), + shape({std::get<1>(key), std::get<2>(key)}), scope(std::get<3>(key)), use(std::get<4>(key)) {} Type elementType; - uint32_t rows; - uint32_t columns; + // [#rows, #columns] + std::array shape; Scope scope; CooperativeMatrixUseKHR use; }; @@ -231,10 +244,18 @@ Type CooperativeMatrixType::getElementType() const { return getImpl()->elementType; } -uint32_t CooperativeMatrixType::getRows() const { return getImpl()->rows; } +uint32_t CooperativeMatrixType::getRows() const { + assert(getImpl()->shape[0] != ShapedType::kDynamic); + return static_cast(getImpl()->shape[0]); +} uint32_t CooperativeMatrixType::getColumns() const { - return getImpl()->columns; + assert(getImpl()->shape[1] != ShapedType::kDynamic); + return static_cast(getImpl()->shape[1]); +} + +ArrayRef CooperativeMatrixType::getShape() const { + return getImpl()->shape; } Scope CooperativeMatrixType::getScope() const { return getImpl()->scope; } diff --git a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir index d3e1dbc229ef9..8733ff93768ab 100644 --- a/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/khr-cooperative-matrix-ops.mlir @@ -524,7 +524,7 @@ spirv.func @matrix_times_scalar(%a: !matA_f32, %b: f32) "None" { spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) "None" { - // expected-error @+1 {{op requires the same type for all operands and results}} + // expected-error @+1 {{failed to verify that all of {operand1, operand2, result} have same type}} %q = "spirv.IAdd"(%a, %b) : (!spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xi32, Subgroup, MatrixB>) -> !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA> @@ -535,7 +535,7 @@ spirv.func @iadd(%a: !spirv.coopmatrix<2x2xi32, Subgroup, MatrixA>, spirv.func @fadd(%a: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, %b: !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) "None" { - // expected-error @+1 {{op requires the same type for all operands and results}} + // expected-error @+1 {{failed to verify that all of {operand1, operand2, result} have same type}} %q = "spirv.FAdd"(%a, %b) : (!spirv.coopmatrix<2x2xf32, Subgroup, MatrixA>, !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>) -> !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc>