Skip to content

Commit 8ae14d8

Browse files
committed
[mlir][spirv] Deserialize OpConstantComposite of type Cooperative Matrix
1 parent a3c7d46 commit 8ae14d8

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1468,11 +1468,11 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
14681468
}
14691469

14701470
auto resultID = operands[1];
1471-
if (auto vectorType = dyn_cast<VectorType>(resultType)) {
1472-
auto attr = DenseElementsAttr::get(vectorType, elements);
1471+
if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
1472+
auto attr = DenseElementsAttr::get(shapedType, elements);
14731473
// For normal constants, we just record the attribute (and its type) for
14741474
// later materialization at use sites.
1475-
constantMap.try_emplace(resultID, attr, resultType);
1475+
constantMap.try_emplace(resultID, attr, shapedType);
14761476
} else if (auto arrayType = dyn_cast<spirv::ArrayType>(resultType)) {
14771477
auto attr = opBuilder.getArrayAttr(elements);
14781478
constantMap.try_emplace(resultID, attr, resultType);

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -845,18 +845,38 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
845845
return 0;
846846
}
847847

848+
int64_t numberOfConstituents = shapedType.getDimSize(dim);
848849
uint32_t resultID = getNextID();
849850
SmallVector<uint32_t, 4> operands = {typeID, resultID};
850-
operands.reserve(shapedType.getDimSize(dim) + 2);
851851
auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
852-
for (int i = 0; i < shapedType.getDimSize(dim); ++i) {
853-
index[dim] = i;
852+
853+
// "If the Result Type is a cooperative matrix type, then there must be only
854+
// one Constituent, with scalar type matching the cooperative matrix Component
855+
// Type, and all components of the matrix are initialized to that value."
856+
// (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
857+
if (isa<spirv::CooperativeMatrixType>(constType)) {
858+
// numberOfConstituents is 1, so we only need one more elements in the
859+
// SmallVector, so the total is 3 (1 + 2).
860+
operands.reserve(3);
861+
// We set dim directly to `shapedType.getRank()` so the recursive call
862+
// directly returns the scalar type.
854863
if (auto elementID = prepareDenseElementsConstant(
855-
loc, elementType, valueAttr, dim + 1, index)) {
864+
loc, elementType, valueAttr, /*dim=*/shapedType.getRank(), index)) {
856865
operands.push_back(elementID);
857866
} else {
858867
return 0;
859868
}
869+
} else {
870+
operands.reserve(numberOfConstituents + 2);
871+
for (int i = 0; i < numberOfConstituents; ++i) {
872+
index[dim] = i;
873+
if (auto elementID = prepareDenseElementsConstant(
874+
loc, elementType, valueAttr, dim + 1, index)) {
875+
operands.push_back(elementID);
876+
} else {
877+
return 0;
878+
}
879+
}
860880
}
861881
spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
862882
encodeInstructionInto(typesGlobalValues, opcode, operands);

mlir/test/Target/SPIRV/constant.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,4 +277,11 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
277277
%signed_minus_one = spirv.Constant -1 : si16
278278
spirv.ReturnValue %signed_minus_one : si16
279279
}
280+
281+
// CHECK-LABEL: @coop_matrix_const
282+
spirv.func @coop_matrix_const() -> (!spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>) "None" {
283+
// CHECK: {{%.*}} = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
284+
%coop = spirv.Constant dense<0.000000e+00> : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
285+
spirv.ReturnValue %coop : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
286+
}
280287
}

0 commit comments

Comments
 (0)