 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -211,13 +212,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
   auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  assert(vectorType.getRank() == 1 && "expected 1-D source types");
   assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
          "subvector out of bounds");
@@ -228,9 +226,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
       ->getResult(0);
 }
 
@@ -309,6 +310,76 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                                          newLoad);
 }
 
+/// Atomically store a subbyte-sized value to memory, with a mask.
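+/// A rough sketch of the emitted IR, shown for i4 elements emulated in an i8
+/// memref (two i4 values per byte; names and types are illustrative only):
+///
+///   %res = memref.generic_atomic_rmw %memref[%index] : memref<?xi8> {
+///   ^bb0(%old: i8):
+///     %old_vec = vector.from_elements %old : vector<1xi8>
+///     %old_i4s = vector.bitcast %old_vec : vector<1xi8> to vector<2xi4>
+///     %merged  = arith.select %mask, %value, %old_i4s : vector<2xi1>, vector<2xi4>
+///     %new_vec = vector.bitcast %merged : vector<2xi4> to vector<1xi8>
+///     %new     = vector.extract %new_vec[0] : i8 from vector<1xi8>
+///     memref.atomic_yield %new : i8
+///   }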
+static void atomicStore(OpBuilder &builder, Location loc,
+                        TypedValue<MemRefType> emulatedMemref,
+                        Value emulatedIndex, TypedValue<VectorType> value,
+                        Value mask, int64_t) {
+  auto atomicOp = builder.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
+  Value origValue = atomicOp.getCurrentValue();
+
+  OpBuilder::InsertionGuard guard(builder);
+  builder.setInsertionPointToStart(atomicOp.getBody());
+
+  // i8 -> <1xi8> -> <numSrcElemsPerDest x i.>
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+}
+
+/// Generate a non-atomic read-modify-write sequence for subbyte storing.
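+/// The emitted sequence loads the containing emulated element as a
+/// one-element vector, bitcasts it to the sub-byte element type, merges in
+/// `value` under `mask`, and writes the result back with a plain store.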
+static void rmwStore(OpBuilder &rewriter, Location loc,
+                     TypedValue<MemRefType> emulatedMemref, Value emulatedIndex,
+                     TypedValue<VectorType> value, Value mask,
+                     int64_t numSrcElemsPerDest) {
+  auto emulatedIOType =
+      VectorType::get({1}, emulatedMemref.getType().getElementType());
+  auto elemLoad = rewriter.create<vector::LoadOp>(
+      loc, emulatedIOType, emulatedMemref, ValueRange{emulatedIndex});
+  auto fromBitcast = rewriter.create<vector::BitCastOp>(
+      loc,
+      VectorType::get({numSrcElemsPerDest}, value.getType().getElementType()),
+      elemLoad);
+  // Take the new elements where the mask is set and keep the loaded ones
+  // elsewhere, mirroring the select in `atomicStore`.
+  auto select = rewriter.create<arith::SelectOp>(loc, mask, value, fromBitcast);
+  auto toBitcast =
+      rewriter.create<vector::BitCastOp>(loc, emulatedIOType, select);
+  rewriter.create<vector::StoreOp>(loc, toBitcast, emulatedMemref,
+                                   emulatedIndex);
+}
+
+static_assert(std::is_same_v<decltype(atomicStore), decltype(rmwStore)> &&
+              "`atomicStore` and `rmwStore` must have the same signature; "
+              "which one to call is determined by the `useAtomicWrites` "
+              "flag, so both must be callable through the same interface.");
+
+// Extract a slice of a vector, and insert it into a byte vector.
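+// For example, with i4 elements (two per byte), extracting one element at
+// sliceOffset 3 with byteOffset 1 produces a vector<2xi4> whose element 1 is
+// vector[3] and whose element 0 is zero, ready to be merged under a mask.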
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  auto vectorElementType = vector.getType().getElementType();
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -318,6 +389,10 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+      : OpConversionPattern<vector::StoreOp>(context),
+        useAtomicWrites_(useAtomicWrites) {}
+
   LogicalResult
   matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
@@ -329,16 +404,17 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
-    Type newElementType = convertedType.getElementType();
+    auto valueToStore = cast<TypedValue<VectorType>>(op.getValueToStore());
+    auto oldElementType = valueToStore.getType().getElementType();
+    auto newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
 
     if (dstBits % srcBits != 0) {
       return rewriter.notifyMatchFailure(
           op, "only dstBits % srcBits == 0 supported");
     }
-    int scale = dstBits / srcBits;
+    int numSrcElemsPerDest = dstBits / srcBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -353,32 +429,153 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     //   vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     //   vector<4xi8>
 
-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedNumFrontPadElems =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    if (!foldedNumFrontPadElems) {
+      // Unimplemented case for dynamic front padding size != 0.
+      return failure();
+    }
+
+    auto emulatedMemref = cast<TypedValue<MemRefType>>(adaptor.getBase());
+
+    // Shortcut: conditions when a subbyte store at the front is not needed:
+    // 1. The source vector size is a multiple of byte size.
+    // 2. The address of the store is aligned to the emulated width boundary.
+    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+      auto numElements = origElements / numSrcElemsPerDest;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), emulatedMemref,
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return success();
+    }
+
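+    // The rest of the lowering handles the unaligned cases: the first and/or
+    // last destination byte is only partially covered by the stored value and
+    // must be merged with what is already in memory. For example, with i2
+    // elements (four per byte), storing vector<7xi2> with a front padding of
+    // two elements is split into three parts:
+    //   * source elements 0..1 -> positions 2..3 of the first byte
+    //     (masked sub-byte store);
+    //   * source elements 2..5 -> the next byte in full (plain vector.store);
+    //   * source element 6     -> position 0 of the last byte
+    //     (masked sub-byte store).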
+    // The index into the target memref we are storing to.
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto subWidthStoreMaskType =
+        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing.
+    auto currentSourceIndex = 0;
+
+    // 1. Partial width store for the first byte, when the store address is
+    // not aligned to the emulated width boundary: deal with the unaligned
+    // part so that the remaining elements are aligned to the width boundary.
+    auto frontSubWidthStoreElem =
+        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+    if (frontSubWidthStoreElem != 0) {
+      SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
+      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
+                    origElements, true);
+        frontSubWidthStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontSubWidthStoreElem,
+                    frontSubWidthStoreElem, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
+
+      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      auto value =
+          extractSliceIntoByte(rewriter, loc, valueToStore, 0,
+                               frontSubWidthStoreElem, *foldedNumFrontPadElems);
+
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(value),
+                            frontMask.getResult(), numSrcElemsPerDest);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. Full width store. After the previous step, the store address is
+    // aligned to the emulated width boundary.
+    int64_t fullWidthStoreSize =
+        (origElements - currentSourceIndex) / numSrcElemsPerDest;
+    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+    if (fullWidthStoreSize != 0) {
+      auto fullWidthStorePart = staticallyExtractSubvector(
+          rewriter, loc, valueToStore, currentSourceIndex,
+          numNonFullWidthElements);
+
+      auto originType = dyn_cast<VectorType>(fullWidthStorePart.getType());
+      auto memrefElemType =
+          dyn_cast<MemRefType>(emulatedMemref.getType()).getElementType();
+      auto storeType = VectorType::get(
+          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+      auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
+                                                        fullWidthStorePart);
+      rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), emulatedMemref,
+                                       currentDestIndex);
+
+      currentSourceIndex += numNonFullWidthElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
+    }
+
+    // 3. Deal with trailing elements that start at the emulated width
+    // boundary but do not fill a full emulated element.
+    auto remainingElements = origElements - currentSourceIndex;
+    if (remainingElements != 0) {
+      auto subWidthStorePart = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, remainingElements, 0);
+
+      // Generate the back mask.
+      auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+      std::fill_n(maskValues.begin(), remainingElements, 1);
+      auto backMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
+
+      subEmulatedWidthStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                            cast<TypedValue<VectorType>>(subWidthStorePart),
+                            backMask.getResult(), numSrcElemsPerDest);
+    }
+
+    rewriter.eraseOp(op);
     return success();
   }
+
+  /// Store a subbyte-sized value to memory, with a mask. Depending on the
+  /// configuration, it could be an atomic store or an RMW sequence.
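+  /// Both callees share the signature enforced by the `static_assert` above,
+  /// so the selected function can be invoked through a single interface.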
+  template <typename... Args>
+  void subEmulatedWidthStore(Args &&...args) const {
+    std::function<decltype(atomicStore)> storeFunc =
+        useAtomicWrites_ ? atomicStore : rmwStore;
+    storeFunc(std::forward<Args>(args)...);
+  }
+
+private:
+  const bool useAtomicWrites_;
 };
 
 //===----------------------------------------------------------------------===//
@@ -584,9 +781,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -745,9 +941,8 @@ struct ConvertVectorMaskedLoad final
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
           op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -830,9 +1025,8 @@ struct ConvertVectorTransferRead final
                                         linearizedInfo.intraDataOffset,
                                         origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -1577,12 +1771,17 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
-    RewritePatternSet &patterns) {
+    RewritePatternSet &patterns, bool useAtomicWrites) {
 
-  // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
+  // Populate the `vector.*` conversion patterns, except for `vector.store`.
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
                ConvertVectorMaskedStore, ConvertVectorTransferRead>(
       typeConverter, patterns.getContext());
+
+  // Populate the `vector.store` conversion pattern. The caller can choose to
+  // avoid emitting atomic operations and reduce them to load-modify-write
+  // sequences if it is known there is no thread contention.
+  patterns.insert<ConvertVectorStore>(patterns.getContext(), useAtomicWrites);
 }
 
 void vector::populateVectorNarrowTypeRewritePatterns(