@@ -76,7 +76,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                   int numSrcElemsPerDest,
                                                   int numFrontPadElems = 0) {
 
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "intraDataOffset must be less than scale");
 
   auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                      numSrcElemsPerDest;
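
For orientation, the rounded-up count above answers: how many destination elements (e.g. i8 bytes) does a store of numSrcElems sub-byte elements touch when it starts numFrontPadElems elements into the first byte? A minimal standalone sketch with illustrative numbers (the helper name below is not from the patch):

// Ceiling division mirroring the numElements computation in getCompressedMaskOp.
static int numDestElems(int numSrcElems, int numSrcElemsPerDest,
                        int numFrontPadElems) {
  return (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
         numSrcElemsPerDest;
}
// Example: 6 i2 elements starting 1 element into the first i8 (4 i2 per i8)
// touch numDestElems(6, 4, 1) == 2 bytes.
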
@@ -256,23 +257,11 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
       newLoad);
 }
 
-static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
-                           Value memref, Value index, Value value) {
-  auto originType = dyn_cast<VectorType>(value.getType());
-  auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
-  auto scale = memrefElemType.getIntOrFloatBitWidth() /
-               originType.getElementType().getIntOrFloatBitWidth();
-  auto storeType =
-      VectorType::get({originType.getNumElements() / scale}, memrefElemType);
-  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
-  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
-}
-
-/// atomically store a subbyte-sized value to memory, with a mask.
-static Value atomicStore(OpBuilder &rewriter, Location loc,
-                         Value emulatedMemref, Value emulatedIndex,
-                         TypedValue<VectorType> value, Value mask,
-                         int64_t scale) {
+/// Atomically store a subbyte-sized value to memory, with a mask.
+static void atomicStore(OpBuilder &rewriter, Location loc,
+                        TypedValue<MemRefType> emulatedMemref,
+                        Value emulatedIndex, TypedValue<VectorType> value,
+                        Value mask, int64_t scale) {
   auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
       loc, emulatedMemref, ValueRange{emulatedIndex});
   OpBuilder builder =
@@ -291,9 +280,31 @@ static Value atomicStore(OpBuilder &rewriter, Location loc,
   auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
   auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
   builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
-  return atomicOp;
 }
 
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+static void rmwStore(OpBuilder &rewriter, Location loc,
+                     TypedValue<MemRefType> emulatedMemref, Value emulatedIndex,
+                     TypedValue<VectorType> value, Value mask,
+                     int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, fromBitcast, value);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+                                   emulatedIndex);
+}
+
+static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+              "`atomicStore` and `rmwStore` must have same function type.");
+
 // Extract a slice of a vector, and insert it into a byte vector.
 static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                   Location loc, TypedValue<VectorType> vector,
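
Both store paths perform the same masked merge of the incoming sub-byte elements into one emulated byte; they differ only in how that byte-wide read-modify-write reaches memory. atomicStore wraps it in memref.generic_atomic_rmw so it cannot race with neighbouring sub-byte stores, while rmwStore emits a plain vector.load / arith.select / vector.store sequence and relies on the caller's guarantee that no other thread touches the byte. A plain C++ model of that per-lane merge, with illustrative names and the four i2 lanes widened to uint8_t for clarity (a sketch of the intent, not a transcription of the MLIR above):

#include <array>
#include <cstdint>

using I2Lanes = std::array<uint8_t, 4>; // four i2 elements packed in one i8

// Lanes selected by writeMask take the incoming elements; the others keep
// whatever the emulated byte already held.
static I2Lanes mergeUnderMask(I2Lanes stored, I2Lanes incoming,
                              std::array<bool, 4> writeMask) {
  I2Lanes result{};
  for (int i = 0; i < 4; ++i)
    result[i] = writeMask[i] ? incoming[i] : stored[i];
  return result;
}
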
@@ -322,6 +333,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -343,7 +358,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -359,7 +374,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = valueToStore.getType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -374,62 +389,68 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto foldedIntraVectorOffset =
+    auto foldedNumFrontPadElems =
         isUnalignedEmulation
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic front padding size
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case for dynamic front padding size != 0
       return failure();
     }
 
-    // conditions when atomic stores and all that are not needed:
+    TypedValue<MemRefType> emulatedMemref =
+        cast<TypedValue<MemRefType>>(adaptor.getBase());
+
+    // Shortcut: conditions when subbyte store at the front is not needed:
     // 1. The source vector size is multiple of byte size
-    // 2. The address of the store is byte aligned
-    if (!isUnalignedEmulation && *foldedIntraVectorOffset == 0) {
-      auto numElements = origElements / scale;
+    // 2. The address of the store is aligned to the emulated width boundary
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, newElementType),
           op.getValueToStore());
       rewriter.replaceOpWithNewOp<vector::StoreOp>(
-          op, bitCast.getResult(), adaptor.getBase(),
+          op, bitCast.getResult(), emulatedMemref,
           getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
       return llvm::success();
     }
 
-    Value emulatedMemref = adaptor.getBase();
-    // the index into the target memref we are storing to
+    // The index into the target memref we are storing to
     Value currentDestIndex =
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
     auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
-    // the index into the source vector we are currently processing
+    auto subWidthStoreMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing
     auto currentSourceIndex = 0;
 
-    // 1. atomic store for the first byte
-    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
-    if (frontAtomicStoreElem != 0) {
-      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
-      if (*foldedIntraVectorOffset + origElements < scale) {
-        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+    // 1. Partial width store for the first byte, when the store address is not
+    // aligned to emulated width boundary, deal with the unaligned part so that
+    // the rest elements are aligned to width boundary.
+    auto frontSubWidthStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+    if (frontSubWidthStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                     origElements, true);
-        frontAtomicStoreElem = origElements;
+        frontSubWidthStoreElem = origElements;
       } else {
-        std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
-                    *foldedIntraVectorOffset, true);
+        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+                    *foldedNumFrontPadElems, true);
       }
       auto frontMask = rewriter.create<arith::ConstantOp>(
-          loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
 
-      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
       auto value = extractSliceIntoByte(
           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
-          frontAtomicStoreElem, *foldedIntraVectorOffset);
+          frontSubWidthStoreElem, *foldedNumFrontPadElems);
 
-      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
-                  scale);
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(value),
+                            frontMask.getResult(), numSrcElemsPerDest);
 
       currentDestIndex = rewriter.create<arith::AddIOp>(
           loc, rewriter.getIndexType(), currentDestIndex, constantOne);
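
To make the front-mask construction above concrete, here is the "store fits entirely in the first byte" branch with illustrative numbers: four i2 elements per i8, one element of front padding, and a vector<2xi2> to store. The values are hypothetical and llvm::SmallVector is swapped for std::vector to keep the sketch self-contained; the fill_n call mirrors the patch:

#include <algorithm>
#include <vector>

static std::vector<bool> frontMaskExample() {
  // numSrcElemsPerDest = 4, *foldedNumFrontPadElems = 1, origElements = 2.
  std::vector<bool> frontMaskValues(4, false);
  // 1 + 2 < 4: the whole store lands in the first emulated byte, so mark the
  // two lanes the incoming elements occupy.
  std::fill_n(frontMaskValues.begin() + 1, 2, true);
  return frontMaskValues; // {false, true, true, false}
}
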
@@ -440,44 +461,66 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return success();
     }
 
-    // 2. non-atomic store
-    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
-    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
-    if (nonAtomicStoreSize != 0) {
-      auto nonAtomicStorePart = staticallyExtractSubvector(
+    // 2. Full width store. After the previous step, the store address is
+    // aligned to the emulated width boundary.
+    int64_t fullWidthStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+    if (fullWidthStoreSize != 0) {
+      auto fullWidthStorePart = staticallyExtractSubvector(
           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
-          currentSourceIndex, numNonAtomicElements);
-
-      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                     nonAtomicStorePart);
-
-      currentSourceIndex += numNonAtomicElements;
+          currentSourceIndex, numNonFullWidthElements);
+
+      auto originType = dyn_cast<VectorType>(fullWidthStorePart.getType());
+      auto memrefElemType =
+          dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        fullWidthStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+                                       currentDestIndex);
+
+      currentSourceIndex += numNonFullWidthElements;
       currentDestIndex = rewriter.create<arith::AddIOp>(
           loc, rewriter.getIndexType(), currentDestIndex,
-          rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
     }
 
-    // 3. atomic store for the last byte
+    // 3. Deal with trailing elements that are aligned to the emulated width,
+    // but their length is smaller than the emulated width.
     auto remainingElements = origElements - currentSourceIndex;
     if (remainingElements != 0) {
-      auto atomicStorePart = extractSliceIntoByte(
+      auto subWidthStorePart = extractSliceIntoByte(
           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
           currentSourceIndex, remainingElements, 0);
 
-      // back mask
-      auto maskValues = llvm::SmallVector<bool>(scale, 0);
+      // Generate back mask
+      auto maskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, 0);
       std::fill_n(maskValues.begin(), remainingElements, 1);
       auto backMask = rewriter.create<arith::ConstantOp>(
-          loc, DenseElementsAttr::get(atomicMaskType, maskValues));
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
 
-      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                  cast<TypedValue<VectorType>>(atomicStorePart),
-                  backMask.getResult(), scale);
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(subWidthStorePart),
+                            backMask.getResult(), numSrcElemsPerDest);
     }
 
     rewriter.eraseOp(op);
     return success();
   }
+
+  /// Store a subbyte-sized value to memory, with a mask. Depending on the
+  /// configuration, it could be an atomic store or an RMW sequence.
+  template <typename... Args>
+  void subEmulatedWidthStore(Args &&...args) const {
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+  const bool useAtomicWrites_;
 };
 
 //===----------------------------------------------------------------------===//
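
The subEmulatedWidthStore helper above picks between two free functions that the earlier static_assert forces to share one signature. A self-contained sketch of that dispatch idiom, with toy signatures standing in for the real store functions:

#include <functional>
#include <type_traits>

static void storeAtomic(int, int) {} // toy stand-in for atomicStore
static void storeRMW(int, int) {}    // toy stand-in for rmwStore

static_assert(std::is_same_v<decltype(storeAtomic), decltype(storeRMW)>,
              "both store flavours must share one function type");

static void store(bool useAtomicWrites, int addr, int value) {
  // Same shape as subEmulatedWidthStore: decltype(storeAtomic) is the function
  // type void(int, int), so std::function<decltype(storeAtomic)> can hold
  // whichever flavour is chosen at runtime.
  std::function<decltype(storeAtomic)> storeFunc =
      useAtomicWrites ? storeAtomic : storeRMW;
  storeFunc(addr, value);
}

A plain function pointer (decltype(&storeAtomic)) would express the same dispatch without the std::function wrapper; std::function keeps the declaration phrased directly in terms of the original function type.
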
@@ -1673,12 +1716,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {
 
-  // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // Populate `vector.*` load conversion patterns.
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
                ConvertVectorMaskedStore, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose
+  // to avoid emitting atomic operations and reduce it to load-modify-write
+  // sequence for stores if it is known there are no thread contentions.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
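
Downstream users reach the new behaviour through the extra bool parameter added above. A hypothetical caller sketch: the surrounding function and the targetIsSingleThreaded flag are placeholders, and the usual MLIR Vector/Arith transform headers are assumed to be available; only populateVectorNarrowTypeEmulationPatterns and its new parameter come from this patch.

static void configureNarrowTypeEmulation(
    const arith::NarrowTypeEmulationConverter &converter,
    RewritePatternSet &patterns, bool targetIsSingleThreaded) {
  // When the target guarantees no thread contention on the emulated bytes,
  // request the cheaper load-modify-write lowering instead of
  // memref.generic_atomic_rmw.
  vector::populateVectorNarrowTypeEmulationPatterns(
      converter, patterns, /*useAtomicWrites=*/!targetIsSingleThreaded);
}
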