 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
@@ -143,13 +144,10 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
 /// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
 /// emitting `vector.extract_strided_slice`.
 static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
-                                        VectorType extractType, Value source,
-                                        int64_t frontOffset,
+                                        Value source, int64_t frontOffset,
                                         int64_t subvecSize) {
   auto vectorType = cast<VectorType>(source.getType());
-  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
-         "expected 1-D source and destination types");
-  (void)vectorType;
+  assert(vectorType.getRank() == 1 && "expected 1-D source types");
   assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
          "subvector out of bounds");
@@ -160,9 +158,12 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
   auto offsets = rewriter.getI64ArrayAttr({frontOffset});
   auto sizes = rewriter.getI64ArrayAttr({subvecSize});
   auto strides = rewriter.getI64ArrayAttr({1});
+
+  auto resultVectorType =
+      VectorType::get({subvecSize}, vectorType.getElementType());
   return rewriter
-      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
-                                             sizes, strides)
+      .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
+                                             offsets, sizes, strides)
       ->getResult(0);
 }
 
@@ -171,12 +172,10 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
 /// `vector.insert_strided_slice`.
 static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                        Value src, Value dest, int64_t offset) {
-  auto srcType = cast<VectorType>(src.getType());
-  auto destType = cast<VectorType>(dest.getType());
+  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
+  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
   assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
          "expected source and dest to be vector type");
-  (void)srcType;
-  (void)destType;
   auto offsets = rewriter.getI64ArrayAttr({offset});
   auto strides = rewriter.getI64ArrayAttr({1});
   return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
@@ -243,6 +242,63 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
       newLoad);
 }
 
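+/// Store a whole-byte-aligned chunk of a sub-byte vector with a plain
+/// (non-atomic) vector.store: `value` is bitcast to the memref's byte-sized
+/// element type (e.g. vector<8xi2> becomes vector<2xi8>) and stored at `index`.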
+static void nonAtomicStore(ConversionPatternRewriter &rewriter, Location loc,
+                           Value memref, Value index, Value value) {
+  auto originType = dyn_cast<VectorType>(value.getType());
+  auto memrefElemType = dyn_cast<MemRefType>(memref.getType()).getElementType();
+  auto scale = memrefElemType.getIntOrFloatBitWidth() /
+               originType.getElementType().getIntOrFloatBitWidth();
+  auto storeType =
+      VectorType::get({originType.getNumElements() / scale}, memrefElemType);
+  auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType, value);
+  rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memref, index);
+}
+
+/// Atomically store a sub-byte-sized value to memory, with a mask.
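+/// This is a read-modify-write sequence: memref.generic_atomic_rmw reads the
+/// current byte, the byte is bitcast to the sub-byte vector type of `value`,
+/// arith.select merges in the masked lanes of `value`, and the result is
+/// bitcast back to a byte and yielded.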
+static Value atomicStore(OpBuilder &rewriter, Location loc,
+                         Value emulatedMemref, Value emulatedIndex,
+                         TypedValue<VectorType> value, Value mask,
+                         int64_t scale) {
+  auto atomicOp = rewriter.create<memref::GenericAtomicRMWOp>(
+      loc, emulatedMemref, ValueRange{emulatedIndex});
+  OpBuilder builder =
+      OpBuilder::atBlockEnd(atomicOp.getBody(), rewriter.getListener());
+  Value origValue = atomicOp.getCurrentValue();
+
+  // i8 -> vector<1xi8>, then bitcast to the sub-byte vector type of `value`
+  // (e.g. vector<4xi2>).
+  auto oneVectorType = VectorType::get({1}, origValue.getType());
+  auto fromElem = builder.create<vector::FromElementsOp>(loc, oneVectorType,
+                                                         ValueRange{origValue});
+  auto vectorBitCast =
+      builder.create<vector::BitCastOp>(loc, value.getType(), fromElem);
+
+  auto select =
+      builder.create<arith::SelectOp>(loc, mask, value, vectorBitCast);
+  auto bitcast2 = builder.create<vector::BitCastOp>(loc, oneVectorType, select);
+  auto extract = builder.create<vector::ExtractOp>(loc, bitcast2, 0);
+  builder.create<memref::AtomicYieldOp>(loc, extract.getResult());
+  return atomicOp;
+}
+
+// Extract a slice of a vector and insert it into a zero-initialized,
+// byte-sized vector.
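+// For example, with i2 elements (four per byte), extracting 2 elements at
+// sliceOffset 1 and inserting them at byteOffset 2 produces
+// vector<4xi2> = [0, 0, v[1], v[2]].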
+static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
+                                  Location loc, TypedValue<VectorType> vector,
+                                  int64_t sliceOffset, int64_t sliceNumElements,
+                                  int64_t byteOffset) {
+  auto vectorElementType = vector.getType().getElementType();
+  assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
+         "vector element must be a valid sub-byte type");
+  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emptyByteVector = rewriter.create<arith::ConstantOp>(
+      loc, VectorType::get({scale}, vectorElementType),
+      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+  auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
+                                              sliceOffset, sliceNumElements);
+  auto inserted = staticallyInsertSubvector(rewriter, loc, extracted,
+                                            emptyByteVector, byteOffset);
+  return inserted;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -263,7 +319,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     auto loc = op.getLoc();
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
-    Type oldElementType = op.getValueToStore().getType().getElementType();
+    auto valueToStore = op.getValueToStore();
+    Type oldElementType = valueToStore.getType().getElementType();
     Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
@@ -287,30 +344,124 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
     // vector<4xi8>
 
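+    // For the unaligned case, the store is instead split into a masked atomic
+    // RMW on the first byte, full-byte non-atomic stores in the middle, and a
+    // masked atomic RMW on the last byte; see the three steps below.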
-    auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
-      return failure();
+    auto origElements = valueToStore.getType().getNumElements();
+    bool isUnalignedEmulation = origElements % scale != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
 
     OpFoldResult linearizedIndices;
-    std::tie(std::ignore, linearizedIndices) =
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
             rewriter, loc, srcBits, dstBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
-    auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements, newElementType),
-        op.getValueToStore());
+    auto foldedIntraVectorOffset =
+        isUnalignedEmulation
+            ? getConstantIntValue(linearizedInfo.intraDataOffset)
+            : 0;
+
+    if (!foldedIntraVectorOffset) {
+      // Unimplemented case for dynamic front padding size.
+      return failure();
+    }
+
+    // Atomic RMW stores are not needed when both:
+    // 1. The source vector size is a multiple of the byte size, and
+    // 2. The address of the store is byte-aligned.
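+    // For example, a byte-aligned store of vector<8xi4> lowers directly to a
+    // single vector.store of vector<4xi8>.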
+    if (!isUnalignedEmulation && *foldedIntraVectorOffset == 0) {
+      auto numElements = origElements / scale;
+      auto bitCast = rewriter.create<vector::BitCastOp>(
+          loc, VectorType::get(numElements, newElementType),
+          op.getValueToStore());
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(
+          op, bitCast.getResult(), adaptor.getBase(),
+          getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+      return llvm::success();
+    }
+
+    Value emulatedMemref = adaptor.getBase();
+    // The index into the target memref we are storing to.
+    Value currentDestIndex =
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    auto constantOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto atomicMaskType = VectorType::get({scale}, rewriter.getI1Type());
+    // The index into the source vector we are currently processing.
+    int64_t currentSourceIndex = 0;
+
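+    // Worked example (i2 elements, scale = 4): storing vector<7xi2> at an
+    // intra-byte offset of 2 becomes (1) an atomic store of elements 0-1 into
+    // the last two lanes of the first byte, (2) a non-atomic vector.store of
+    // elements 2-5 as one full byte, and (3) an atomic store of element 6
+    // into the first lane of the last byte.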
+    // 1. Atomic store for the first (partial) byte.
+    auto frontAtomicStoreElem = (scale - *foldedIntraVectorOffset) % scale;
+    if (frontAtomicStoreElem != 0) {
+      auto frontMaskValues = llvm::SmallVector<bool>(scale, false);
+      if (*foldedIntraVectorOffset + origElements < scale) {
+        // The entire source vector fits within the first byte.
+        std::fill_n(frontMaskValues.begin() + *foldedIntraVectorOffset,
+                    origElements, true);
+        frontAtomicStoreElem = origElements;
+      } else {
+        std::fill_n(frontMaskValues.end() - frontAtomicStoreElem,
+                    frontAtomicStoreElem, true);
+      }
+      auto frontMask = rewriter.create<arith::ConstantOp>(
+          loc, DenseElementsAttr::get(atomicMaskType, frontMaskValues));
+
+      currentSourceIndex = scale - (*foldedIntraVectorOffset);
+      auto value = extractSliceIntoByte(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore), 0,
+          frontAtomicStoreElem, *foldedIntraVectorOffset);
+
+      atomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                  cast<TypedValue<VectorType>>(value), frontMask.getResult(),
+                  scale);
+
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex, constantOne);
+    }
+
+    if (currentSourceIndex >= origElements) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // 2. Non-atomic store of the full bytes in the middle.
+    int64_t nonAtomicStoreSize = (origElements - currentSourceIndex) / scale;
+    int64_t numNonAtomicElements = nonAtomicStoreSize * scale;
+    if (nonAtomicStoreSize != 0) {
+      auto nonAtomicStorePart = staticallyExtractSubvector(
+          rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
+          currentSourceIndex, numNonAtomicElements);
+
+      nonAtomicStore(rewriter, loc, emulatedMemref, currentDestIndex,
+                     nonAtomicStorePart);
+
+      currentSourceIndex += numNonAtomicElements;
+      currentDestIndex = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), currentDestIndex,
+          rewriter.create<arith::ConstantIndexOp>(loc, nonAtomicStoreSize));
+    }
445+
446+ // 3. atomic store for the last byte
447+ auto remainingElements = origElements - currentSourceIndex;
448+ if (remainingElements != 0 ) {
449+ auto atomicStorePart = extractSliceIntoByte (
450+ rewriter, loc, cast<TypedValue<VectorType>>(valueToStore),
451+ currentSourceIndex, remainingElements, 0 );
452+
453+ // back mask
454+ auto maskValues = llvm::SmallVector<bool >(scale, 0 );
455+ std::fill_n (maskValues.begin (), remainingElements, 1 );
456+ auto backMask = rewriter.create <arith::ConstantOp>(
457+ loc, DenseElementsAttr::get (atomicMaskType, maskValues));
458+
459+ atomicStore (rewriter, loc, emulatedMemref, currentDestIndex,
460+ cast<TypedValue<VectorType>>(atomicStorePart),
461+ backMask.getResult (), scale);
462+ }
 
-    rewriter.replaceOpWithNewOp<vector::StoreOp>(
-        op, bitCast.getResult(), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    rewriter.eraseOp(op);
     return success();
   }
 };
@@ -518,9 +669,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
     return success();
@@ -679,9 +829,8 @@ struct ConvertVectorMaskedLoad final
           rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
           op.getPassThru(), linearizedInfo.intraDataOffset, origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 
@@ -764,9 +913,8 @@ struct ConvertVectorTransferRead final
                                            linearizedInfo.intraDataOffset,
                                            origElements);
     } else if (isUnalignedEmulation) {
-      result =
-          staticallyExtractSubvector(rewriter, loc, op.getType(), result,
-                                     *foldedIntraVectorOffset, origElements);
+      result = staticallyExtractSubvector(
+          rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
     rewriter.replaceOp(op, result);
 