@@ -240,9 +240,10 @@ static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
240240// / function emits multiple `vector.extract` and `vector.insert` ops, so only
241241// / use it when `offset` cannot be folded into a constant value.
242242static Value dynamicallyExtractSubVector (OpBuilder &rewriter, Location loc,
243- VectorValue source, Value dest,
243+ Value source, Value dest,
244244 OpFoldResult offset,
245245 int64_t numElementsToExtract) {
246+ assert (isa<VectorValue>(source) && " expected `source` to be a vector type" );
246247 for (int i = 0 ; i < numElementsToExtract; ++i) {
247248 Value extractLoc =
248249 (i == 0 ) ? offset.dyn_cast <Value>()
@@ -258,9 +259,10 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
258259
259260// / Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
260261static Value dynamicallyInsertSubVector (RewriterBase &rewriter, Location loc,
261- VectorValue source, Value dest,
262+ Value source, Value dest,
262263 OpFoldResult destOffsetVar,
263264 size_t length) {
265+ assert (isa<VectorValue>(source) && " expected `source` to be a vector type" );
264266 assert (length > 0 && " length must be greater than 0" );
265267 Value destOffsetVal =
266268 getValueOrCreateConstantIndexOp (rewriter, loc, destOffsetVar);
@@ -468,7 +470,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
468470
469471 auto memrefBase = cast<MemRefValue>(adaptor.getBase ());
470472
471- // Conditions when subbyte emulated store is not needed:
473+ // Conditions when atomic RMWs are not needed:
472474 // 1. The source vector size (in bits) is a multiple of byte size.
473475 // 2. The address of the store is aligned to the emulated width boundary.
474476 //
@@ -499,7 +501,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
499501 // Destination: memref<12xi2>
500502 // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
501503 //
502- // MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
504+ // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
503505 //
504506 // Destination memref before:
505507 //
@@ -817,9 +819,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
817819 if (!foldedIntraVectorOffset) {
818820 auto resultVector = rewriter.create <arith::ConstantOp>(
819821 loc, op.getType (), rewriter.getZeroAttr (op.getType ()));
820- result = dynamicallyExtractSubVector (
821- rewriter, loc, cast<VectorValue>(result), resultVector ,
822- linearizedInfo. intraDataOffset , origElements);
822+ result = dynamicallyExtractSubVector (rewriter, loc, result, resultVector,
823+ linearizedInfo. intraDataOffset ,
824+ origElements);
823825 } else if (!isAlignedEmulation) {
824826 result = staticallyExtractSubvector (
825827 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
@@ -938,8 +940,8 @@ struct ConvertVectorMaskedLoad final
938940 loc, newBitcastType, rewriter.getZeroAttr (newBitcastType));
939941 if (!foldedIntraVectorOffset) {
940942 passthru = dynamicallyInsertSubVector (
941- rewriter, loc, cast<VectorValue>( passthru) , emptyVector,
942- linearizedInfo. intraDataOffset , origElements);
943+ rewriter, loc, passthru, emptyVector, linearizedInfo. intraDataOffset ,
944+ origElements);
943945 } else if (!isAlignedEmulation) {
944946 passthru = staticallyInsertSubvector (rewriter, loc, passthru, emptyVector,
945947 *foldedIntraVectorOffset);
@@ -965,9 +967,9 @@ struct ConvertVectorMaskedLoad final
965967 auto emptyMask = rewriter.create <arith::ConstantOp>(
966968 loc, newSelectMaskType, rewriter.getZeroAttr (newSelectMaskType));
967969 if (!foldedIntraVectorOffset) {
968- mask = dynamicallyInsertSubVector (
969- rewriter, loc, cast<VectorValue>(mask), emptyMask ,
970- linearizedInfo. intraDataOffset , origElements);
970+ mask = dynamicallyInsertSubVector (rewriter, loc, mask, emptyMask,
971+ linearizedInfo. intraDataOffset ,
972+ origElements);
971973 } else if (!isAlignedEmulation) {
972974 mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyMask,
973975 *foldedIntraVectorOffset);
@@ -977,7 +979,7 @@ struct ConvertVectorMaskedLoad final
977979 rewriter.create <arith::SelectOp>(loc, mask, bitCast, passthru);
978980 if (!foldedIntraVectorOffset) {
979981 result = dynamicallyExtractSubVector (
980- rewriter, loc, cast<VectorValue>( result) , op.getPassThru (),
982+ rewriter, loc, result, op.getPassThru (),
981983 linearizedInfo.intraDataOffset , origElements);
982984 } else if (!isAlignedEmulation) {
983985 result = staticallyExtractSubvector (
0 commit comments