@@ -845,18 +845,38 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
845
845
return 0 ;
846
846
}
847
847
848
+ int64_t numberOfConstituents = shapedType.getDimSize (dim);
848
849
uint32_t resultID = getNextID ();
849
850
SmallVector<uint32_t , 4 > operands = {typeID, resultID};
850
- operands.reserve (shapedType.getDimSize (dim) + 2 );
851
851
auto elementType = cast<spirv::CompositeType>(constType).getElementType (0 );
852
- for (int i = 0 ; i < shapedType.getDimSize (dim); ++i) {
853
- index[dim] = i;
852
+
853
+ // "If the Result Type is a cooperative matrix type, then there must be only
854
+ // one Constituent, with scalar type matching the cooperative matrix Component
855
+ // Type, and all components of the matrix are initialized to that value."
856
+ // (https://github.khronos.org/SPIRV-Registry/extensions/KHR/SPV_KHR_cooperative_matrix.html)
857
+ if (isa<spirv::CooperativeMatrixType>(constType)) {
858
+ // numberOfConstituents is 1, so we only need one more elements in the
859
+ // SmallVector, so the total is 3 (1 + 2).
860
+ operands.reserve (3 );
861
+ // We set dim directly to `shapedType.getRank()` so the recursive call
862
+ // directly returns the scalar type.
854
863
if (auto elementID = prepareDenseElementsConstant (
855
- loc, elementType, valueAttr, dim + 1 , index)) {
864
+ loc, elementType, valueAttr, /* dim= */ shapedType. getRank () , index)) {
856
865
operands.push_back (elementID);
857
866
} else {
858
867
return 0 ;
859
868
}
869
+ } else {
870
+ operands.reserve (numberOfConstituents + 2 );
871
+ for (int i = 0 ; i < numberOfConstituents; ++i) {
872
+ index[dim] = i;
873
+ if (auto elementID = prepareDenseElementsConstant (
874
+ loc, elementType, valueAttr, dim + 1 , index)) {
875
+ operands.push_back (elementID);
876
+ } else {
877
+ return 0 ;
878
+ }
879
+ }
860
880
}
861
881
spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
862
882
encodeInstructionInto (typesGlobalValues, opcode, operands);
0 commit comments