@@ -76,7 +76,8 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
                                                    int numSrcElemsPerDest,
                                                    int numFrontPadElems = 0) {
 
-  assert(numFrontPadElems < numSrcElemsPerDest && "intraDataOffset must be less than scale");
+  assert(numFrontPadElems < numSrcElemsPerDest &&
+         "intraDataOffset must be less than scale");
 
   auto numElements = (numFrontPadElems + numSrcElems + numSrcElemsPerDest - 1) /
                      numSrcElemsPerDest;
@@ -256,23 +257,11 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
       newLoad);
 }
 
-static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
-                           Value memref, Value index, Value value) {
-  auto originType = dyn_cast<VectorType>(value.getType());
-  auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
-  auto scale = memrefElemType.getIntOrFloatBitWidth() /
-               originType.getElementType().getIntOrFloatBitWidth();
-  auto storeType =
-      VectorType::get({originType.getNumElements() / scale}, memrefElemType);
-  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
-  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
-}
-
-/// atomically store a subbyte-sized value to memory, with a mask.
+/// Atomically store a subbyte-sized value to memory, with a mask.
 static Value atomicStore(OpBuilder &rewriter, Location loc,
-                         Value emulatedMemref, Value emulatedIndex,
-                         TypedValue<VectorType> value, Value mask,
-                         int64_t scale) {
+                         TypedValue<MemRefType> emulatedMemref,
+                         Value emulatedIndex, TypedValue<VectorType> value,
+                         Value mask, int64_t scale) {
   auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
       loc, emulatedMemref, ValueRange{emulatedIndex});
   OpBuilder builder =
@@ -294,6 +283,27 @@ static Value atomicStore(OpBuilder &rewriter, Location loc,
   return atomicOp;
 }
 
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
+static Value rmwStore(OpBuilder &rewriter, Location loc,
+                      TypedValue<MemRefType> emulatedMemref,
+                      Value emulatedIndex, TypedValue<VectorType> value,
+                      Value mask, int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, value, fromBitcast);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  return rewriter
+      .create<vector::StoreOp>(loc, toBitcast, emulatedMemref, emulatedIndex)
+      ->getResult(0);
+}
+
 // Extract a slice of a vector, and insert it into a byte vector.
 static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
                                   Location loc, TypedValue<VectorType> vector,
@@ -322,6 +332,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -343,7 +357,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -359,7 +373,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = valueToStore.getType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -374,21 +388,21 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto foldedIntraVectorOffset =
+    auto foldedNumFrontPadElems =
         isUnalignedEmulation
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic front padding size
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case for dynamic front padding size != 0
       return failure();
     }
 
-    // conditions when atomic stores and all that are not needed:
+    // Conditions when atomic stores are not needed:
     // 1. The source vector size is multiple of byte size
     // 2. The address of the store is byte aligned
-    if (!isUnalignedEmulation && *foldedIntraVectorOffset == 0) {
-      auto numElements = origElements / scale;
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, newElementType),
           op.getValueToStore());
@@ -398,38 +412,41 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return llvm::success();
     }
 
-    Value emulatedMemref = adaptor.getBase();
-    // the index into the target memref we are storing to
+    TypedValue<MemRefType> emulatedMemref =
+        cast<TypedValue<MemRefType>>(adaptor.getBase());
+    // The index into the target memref we are storing to
     Value currentDestIndex =
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
     auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
-    // the index into the source vector we are currently processing
+    auto atomicMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing
     auto currentSourceIndex = 0;
 
-    // 1. atomic store for the first byte
-    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+    // 1. Atomic store for the first byte
+    auto frontAtomicStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
     if (frontAtomicStoreElem != 0) {
-      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
-      if (*foldedIntraVectorOffset + origElements < scale) {
-        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+      auto frontMaskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                     origElements, true);
         frontAtomicStoreElem = origElements;
       } else {
         std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
-                    *foldedIntraVectorOffset, true);
+                    *foldedNumFrontPadElems, true);
       }
       auto frontMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
 
-      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
       auto value = extractSliceIntoByte(
           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
-          frontAtomicStoreElem, *foldedIntraVectorOffset);
+          frontAtomicStoreElem, *foldedNumFrontPadElems);
 
-      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
-                  scale);
+      subByteStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                   cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+                   numSrcElemsPerDest);
 
       currentDestIndex = rewriter.create<arith::AddIOp>(
           loc, rewriter.getIndexType(), currentDestIndex, constantOne);
@@ -440,44 +457,62 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return success();
     }
 
-    // 2. non-atomic store
-    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
-    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+    // 2. Non-atomic store
+    int64_t nonAtomicStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonAtomicElements = nonAtomicStoreSize * numSrcElemsPerDest;
     if (nonAtomicStoreSize != 0) {
       auto nonAtomicStorePart = staticallyExtractSubvector(
           rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
          currentSourceIndex, numNonAtomicElements);
 
-      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                     nonAtomicStorePart);
+      auto originType = dyn_cast<VectorType>(nonAtomicStorePart.getType());
+      auto memrefElemType =
+          dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        nonAtomicStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+                                       currentDestIndex);
 
       currentSourceIndex += numNonAtomicElements;
       currentDestIndex = rewriter.create<arith::AddIOp>(
          loc, rewriter.getIndexType(), currentDestIndex,
          rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
     }
 
-    // 3. atomic store for the last byte
+    // 3. Atomic store for the last byte
     auto remainingElements = origElements - currentSourceIndex;
     if (remainingElements != 0) {
       auto atomicStorePart = extractSliceIntoByte(
          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
          currentSourceIndex, remainingElements, 0);
 
-      // back mask
-      auto maskValues = llvm::SmallVector<bool>(scale, 0);
+      // Generate back mask
+      auto maskValues = llvm::SmallVector<bool>(numSrcElemsPerDest, 0);
       std::fill_n(maskValues.begin(), remainingElements, 1);
       auto backMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(atomicMaskType, maskValues));
 
-      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
-                  cast<TypedValue<VectorType>>(atomicStorePart),
-                  backMask.getResult(), scale);
+      subByteStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                   cast<TypedValue<VectorType>>(atomicStorePart),
+                   backMask.getResult(), numSrcElemsPerDest);
     }
 
     rewriter.eraseOp(op);
     return success();
   }
+
+  template <typename... Args>
+  Value subByteStore(Args &&...args) const {
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    return storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+  const bool useAtomicWrites_;
 };
 
 //===----------------------------------------------------------------------===//
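Editorial aside on the `subByteStore` dispatcher added above: `decltype` applied to a free function yields its function type, so `std::function<decltype(atomicStore)>` can hold either `atomicStore` or `rmwStore`, which share a signature, and the ternary picks one at construction time. A minimal standalone sketch of the same idiom follows; the functions `addOne`, `addTwo`, and `dispatch` are made up for illustration and are not part of this patch.

#include <functional>
#include <iostream>

static int addOne(int x) { return x + 1; }
static int addTwo(int x) { return x + 2; }

// decltype(addOne) is the function type int(int), so the std::function can
// wrap whichever of the two same-signature functions the flag selects.
static int dispatch(bool useOne, int x) {
  std::function<decltype(addOne)> fn = useOne ? addOne : addTwo;
  return fn(x);
}

int main() {
  std::cout << dispatch(true, 3) << " " << dispatch(false, 3) << "\n"; // prints: 4 5
}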
@@ -1673,12 +1708,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {
 
-  // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // Populate `vector.*` load conversion patterns.
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
                ConvertVectorMaskedStore, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
+
+  // Populate `vector.*` store conversion patterns. The caller can choose
+  // to avoid emitting atomic operations and instead lower the stores to a
+  // load-modify-write sequence when it is known there is no thread contention.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(
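For context, a hypothetical caller-side sketch of how the new `useAtomicWrites` parameter might be threaded through. The wrapper function and flag names below are illustrative only; just `populateVectorNarrowTypeEmulationPatterns`, `arith::NarrowTypeEmulationConverter`, and `RewritePatternSet` come from the patch, and the headers declaring them are assumed to be included.

// Illustrative wrapper, not part of this patch.
static void populateMyNarrowTypeEmulation(
    mlir::arith::NarrowTypeEmulationConverter &converter,
    mlir::RewritePatternSet &patterns, bool assumeSingleThreaded) {
  // A caller that knows no other thread writes to the emulated buffers can
  // request the plain load/select/store lowering instead of
  // memref.generic_atomic_rmw.
  mlir::vector::populateVectorNarrowTypeEmulationPatterns(
      converter, patterns, /*useAtomicWrites=*/!assumeSingleThreaded);
}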