diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index 3d2cb1dd7a032..7148027dae78d 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -558,6 +558,13 @@ void spirv::ConstantOp::print(OpAsmPrinter &printer) { static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType) { + if (isa(opType)) { + auto denseAttr = dyn_cast(value); + if (!denseAttr || !denseAttr.isSplat()) + return op.emitOpError("expected a splat dense attribute for cooperative " + "matrix constant, but found ") + << denseAttr; + } if (llvm::isa(value)) { auto valueType = llvm::cast(value).getType(); if (valueType != opType) diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 3957dbc0db984..c43d584d7b913 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1468,11 +1468,11 @@ spirv::Deserializer::processConstantComposite(ArrayRef operands) { } auto resultID = operands[1]; - if (auto vectorType = dyn_cast(resultType)) { - auto attr = DenseElementsAttr::get(vectorType, elements); + if (auto shapedType = dyn_cast(resultType)) { + auto attr = DenseElementsAttr::get(shapedType, elements); // For normal constants, we just record the attribute (and its type) for // later materialization at use sites. - constantMap.try_emplace(resultID, attr, resultType); + constantMap.try_emplace(resultID, attr, shapedType); } else if (auto arrayType = dyn_cast(resultType)) { auto attr = opBuilder.getArrayAttr(elements); constantMap.try_emplace(resultID, attr, resultType); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 15e06616f4492..647535809554c 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -845,18 +845,44 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, return 0; } + int64_t numberOfConstituents = shapedType.getDimSize(dim); uint32_t resultID = getNextID(); SmallVector operands = {typeID, resultID}; - operands.reserve(shapedType.getDimSize(dim) + 2); auto elementType = cast(constType).getElementType(0); - for (int i = 0; i < shapedType.getDimSize(dim); ++i) { - index[dim] = i; + + // "If the Result Type is a cooperative matrix type, then there must be only + // one Constituent, with scalar type matching the cooperative matrix Component + // Type, and all components of the matrix are initialized to that value." + // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html) + if (isa(constType)) { + if (!valueAttr.isSplat()) { + emitError( + loc, + "cannot serialize a non-splat value for a cooperative matrix type"); + return 0; + } + // numberOfConstituents is 1, so we only need one more elements in the + // SmallVector, so the total is 3 (1 + 2). + operands.reserve(3); + // We set dim directly to `shapedType.getRank()` so the recursive call + // directly returns the scalar type. if (auto elementID = prepareDenseElementsConstant( - loc, elementType, valueAttr, dim + 1, index)) { + loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) { operands.push_back(elementID); } else { return 0; } + } else { + operands.reserve(numberOfConstituents + 2); + for (int i = 0; i < numberOfConstituents; ++i) { + index[dim] = i; + if (auto elementID = prepareDenseElementsConstant( + loc, elementType, valueAttr, dim + 1, index)) { + operands.push_back(elementID); + } else { + return 0; + } + } } spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; encodeInstructionInto(typesGlobalValues, opcode, operands); diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir index 5e98b9fdb3c54..207549afdda94 100644 --- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir @@ -62,6 +62,10 @@ func.func @const() -> () { // CHECK: spirv.Constant dense<1.000000e+00> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>> // CHECK: spirv.Constant dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>> // CHECK: spirv.Constant dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>> + // CHECK: spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> + // CHECK: spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> + // CHECK: spirv.Constant dense<0> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> + // CHECK: spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> %0 = spirv.Constant true %1 = spirv.Constant 42 : i32 @@ -73,6 +77,10 @@ func.func @const() -> () { %7 = spirv.Constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32> : !spirv.array<2 x !spirv.array<3 x i32>> %8 = spirv.Constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32> : !spirv.array<2 x !spirv.array<3 x f32>> %9 = spirv.Constant [[dense<3.0> : vector<2xf32>]] : !spirv.array<1 x !spirv.array<1xvector<2xf32>>> + %10 = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> + %11 = spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> + %12 = spirv.Constant dense<0> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> + %13 = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> return } @@ -132,6 +140,31 @@ func.func @value_result_num_elements_mismatch() -> () { // ----- +func.func @coop_matrix_const_non_splat() -> () { + // expected-error @+1 {{expected a splat dense attribute for cooperative matrix constant, but found}} + %0 = spirv.Constant dense<[[1.0, 2.0], [3.0, 4.0]]> : !spirv.coopmatrix<2x2xf32, Subgroup, MatrixAcc> + return +} + +// ----- + +func.func @coop_matrix_const_non_dense() -> () { + // expected-error @+2 {{floating point value not valid for specified type}} + %0 = spirv.Constant 0.000000e+00 : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> + return +} + +// ----- + +func.func @coop_matrix_const_wrong_type() -> () { + // expected-error @below {{unexpected decimal integer literal for a floating point value}} + // expected-note @+1 {{add a trailing dot to make the literal a float}} + %0 = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> + return +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.EntryPoint //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/constant.mlir b/mlir/test/Target/SPIRV/constant.mlir index f3950214a7f05..8d4e53418b70f 100644 --- a/mlir/test/Target/SPIRV/constant.mlir +++ b/mlir/test/Target/SPIRV/constant.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip %s | FileCheck %s +// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip %s | FileCheck %s spirv.module Logical GLSL450 requires #spirv.vce { // CHECK-LABEL: @bool_const @@ -277,4 +277,32 @@ spirv.module Logical GLSL450 requires #spirv.vce { %signed_minus_one = spirv.Constant -1 : si16 spirv.ReturnValue %signed_minus_one : si16 } + + // CHECK-LABEL: @coop_matrix_const_zero_f32 + spirv.func @coop_matrix_const_zero_f32() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> + %coop = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> + spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> + } + + // CHECK-LABEL: @coop_matrix_const_non_zero_f32 + spirv.func @coop_matrix_const_non_zero_f32() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> + %coop = spirv.Constant dense<4.200000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> + spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc> + } + + // CHECK-LABEL: @coop_matrix_const_zero_i8 + spirv.func @coop_matrix_const_zero_i8() -> (!spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<0> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> + %coop = spirv.Constant dense<0> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> + spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> + } + + // CHECK-LABEL: @coop_matrix_const_non_zero_i8 + spirv.func @coop_matrix_const_non_zero_i8() -> (!spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc>) "None" { + // CHECK: {{%.*}} = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> + %coop = spirv.Constant dense<4> : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> + spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xi8, Subgroup, MatrixAcc> + } }