From 1c005844a0444549de16b23c47547df861740271 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Fri, 17 Jan 2025 13:54:34 +0000 Subject: [PATCH 1/4] [mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N) This is PR 4 in a series of N patches aimed at improving "VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no major functional changes are made/added. 1. Update `alignedConversionPrecondition` (1): This method didn't require the vector type for the "destination" argument. The underlying element type is sufficient. The corresponding argument has been renamed as `multiByteScalarTy` - this is meant as the multi-byte emulated type (`i8`, `i16`, `i32`, etc). 2. Update `alignedConversionPrecondition` (2): In #121298, we replaced `dstElemBitwidt` in this calculation: ```cpp const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth; ``` with the hard-coded value of 8: ```cpp const int numSrcElemsPerDestElem = 8 / srcElemBitwidth; ``` That was correct as for the patterns for which this hook was/is used: * `RewriteAlignedSubByteIntExt`, * `RewriteAlignedSubByteIntTrunc`. The destination type (or, more precisely, the emulated type) was always `i8`. In this PR, I am switching back to a more generic approach - the calculation should take into account the bit-width of the emulated type. Note that at the call sites I am passing `i8` as the emulated type, so the end-result is effectively identical. However, the intent is clearer, i.e., the underlying value is 8 because the emulated type happens to be `i8` (as opposed using a magic number). 3. Update alignedConversionPrecondition (3): The final check has been replaced with a new helper method, `isSubByteVecFittable`. This new method is also re-used within the code and hopefully will allow us more code re-use moving forward (to avoid re-implementing the same condition). NEXT STEPS (1): We need to clarify the meaning of "source" and "destination" types. Currently the usage is ambiguous. For example, for this `arith.extsi` Op, `vector<8xi2>` and `vector<8xi32>` are the "source" and "destination" types, respectively: ```mlir %0 = arith.extsi %arg0 : vector<8xi2> to vector<8xi32> } ``` However, patterns like `RewriteAlignedSubByteIntExt` introduce `vector.bitcast` Ops like this: ```mlir %bitcast = vector.bitcast %arg0 : vector<8xi2> to vector<2xi8> ``` I've noticed that we tend to mix `vector<2xi8>` and `vector<8xi32>` as the destination types and that should be clarified. NEXT STEPS (2): With this PR, I am introducing explicit references to "sub-byte" as that is effectively what this logic is used of (i.e. for emulating "sub-byte" types). We should either generalise (which would include increasing test coverage) or restrict everything to "sub-byte" type emulation. --- .../Transforms/VectorEmulateNarrowType.cpp | 134 +++++++++++++----- 1 file changed, 102 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 51e72753ff162..59ed3b5521470 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -1091,6 +1091,38 @@ struct ConvertVectorMaskedLoad final } }; +/// Check whether `subByteVecTy` fits wthin a vector of `multiByteScalarTy` +/// +/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g. +/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy` +/// (a multi-byte scalar, e.g. i16), where N is some integer. +/// +/// Put differently, this method checks whether this would be valid: +/// +/// vector.bitcast subByteVecTy into vector +/// +/// EXAMPLES: +/// * vector<4xi4> -> i16 - yes (N = 1) +/// * vector<4xi4> -> i8 - yes (N = 2) +/// * vector<3xi4> -> i8 - no (N would have to be 1.5) +/// * vector<3xi2> -> i16 - no (N would have to be 0.5) +static bool isSubByteVecFittable(VectorType subByteVecTy, + Type multiByteScalarTy) { + assert((isa(multiByteScalarTy)) && "Not scalar!"); + + int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth(); + int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth(); + + assert(subByteBits < 8 && "Not a sub-byte scalar type!"); + assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!"); + assert(multiByteBits % subByteBits == 0 && "Unalagined element types!"); + + int elemsPerMultiByte = multiByteBits / subByteBits; + + // TODO: This is a bit too restrictive for vectors rank > 1. + return subByteVecTy.getShape().back() % elemsPerMultiByte == 0; +} + //===----------------------------------------------------------------------===// // ConvertVectorTransferRead //===----------------------------------------------------------------------===// @@ -1127,7 +1159,8 @@ struct ConvertVectorTransferRead final auto origElements = op.getVectorType().getNumElements(); // Note, per-element-alignment was already verified above. - bool isFullyAligned = origElements % emulatedPerContainerElem == 0; + bool isFullyAligned = + isSubByteVecFittable(op.getVectorType(), containerElemTy); auto newPadding = rewriter.create(loc, containerElemTy, adaptor.getPadding()); @@ -1428,41 +1461,76 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter, return commonConversionPrecondition(rewriter, preconditionType, op); } -/// Verify that `subByteVecType` and `dstType` are aligned. Alignment -/// means that: -/// 1. The `dstType` element type is a multiple of the -/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8 -/// is not supported). Let this multiple be `N`. -/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a -/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is -/// not supported). +/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned. +/// +/// Alignment means that `subByteVecTy` can be packed into a vector of +/// `containerTy` elements. More specifically: +/// 1. The bit-width of `containerTy` is a multiple of the +/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16` +/// this multiple is 4. +/// 2. The multiple from 1. above divides evenly the number of the (trailing) +/// elements in `subByteVecTy`. +/// +/// EXAMPLE 1: +/// `subByteVecTy = vector<2xi4>`, and +/// `containerTy = i16` +/// +/// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_. +/// +/// EXAMPLE 2: +/// `subByteVecTy = vector<3xi4>`, and +/// `containerTy = i16` +/// +/// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_. +/// +/// EXAMPLE 3: +/// `subByteVecTy = vector<3xi3>`, and +/// `containerTy = i16` +/// +/// 16 _is not_ a multiple of 3, hence the conditions are _not met_. /// /// NOTE: This method assumes that common conversion preconditions are met. In -/// particular, the element type of `dstType` is assumed to be a multi-byte -/// type (e.g. i8, i16, i32). +/// particular, `containerTy` is assumed to be a +/// multi-byte scalar type (e.g., i8, i16, i32). static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, - VectorType subByteVecType, - VectorType dstType, + VectorType subByteVecTy, + Type containerTy, Operation *op) { - if (!subByteVecType || !dstType) - return rewriter.notifyMatchFailure(op, "Not a supported aligned case"); - unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth(); - unsigned dstElemBitwidth = dstType.getElementTypeBitWidth(); + // TODO: This is validating the inputs rather than checking the conditions + // documented above. Replace with an assert. + if (!subByteVecTy) + return rewriter.notifyMatchFailure(op, "not a vector!"); - if (dstElemBitwidth < 8) - return rewriter.notifyMatchFailure( - op, "the bitwidth of dstType must be greater than or equal to 8"); - if (dstElemBitwidth % srcElemBitwidth != 0) - return rewriter.notifyMatchFailure(op, "unaligned cases are not supported"); - if (srcElemBitwidth != 2 && srcElemBitwidth != 4) + // TODO: This is validating the inputs rather than checking the conditions + // documented above. Replace with an assert. + if (!containerTy.isIntOrFloat()) + return rewriter.notifyMatchFailure(op, "not a scalar!"); + + unsigned subByteBits = subByteVecTy.getElementTypeBitWidth(); + unsigned multiByteBits = containerTy.getIntOrFloatBitWidth(); + + // Enforced by the common pre-conditions. + assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!"); + + // TODO: Remove this condition - the assert above (and + // commonConversionPrecondtion) takes care of that. + if (multiByteBits < 8) + return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!"); + + // TODO: Add support other widths (when/if needed) + if (subByteBits != 2 && subByteBits != 4) return rewriter.notifyMatchFailure( - op, "only src bitwidth of 2 or 4 is supported at this moment"); + op, "only 2-bit and 4-bit sub-byte type is supported at this moment"); + + // Condition 1. + if (multiByteBits % subByteBits != 0) + return rewriter.notifyMatchFailure(op, "unalagined element types"); - const int numSrcElemsPerByte = 8 / srcElemBitwidth; - if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0) + // Condition 2. + if (!isSubByteVecFittable(subByteVecTy, containerTy)) return rewriter.notifyMatchFailure( - op, "the trailing dimension of the input vector of sub-bytes must be a " - "multiple of 8 / "); + op, "not possible to fit this sub-byte vector type into a vector of " + "the given multi-byte type"); return success(); } @@ -1899,8 +1967,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern { return failure(); // Check general alignment preconditions. - if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType, - conversionOp))) + Type containerType = rewriter.getI8Type(); + if (failed(alignedConversionPrecondition(rewriter, srcVecType, + containerType, conversionOp))) return failure(); // Perform the rewrite. @@ -1964,8 +2033,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern { // Check general alignment preconditions. We invert the src/dst type order // to reuse the existing precondition logic. - if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType, - truncOp))) + Type containerType = rewriter.getI8Type(); + if (failed(alignedConversionPrecondition(rewriter, dstVecType, + containerType, truncOp))) return failure(); // Create a new iX -> i8 truncation op. From 91b8aac4278860fa9982e1bc36b05a81e961fb0b Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Mon, 10 Mar 2025 16:21:12 +0000 Subject: [PATCH 2/4] fixup! [mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N) Address comments from Alan --- .../Transforms/VectorEmulateNarrowType.cpp | 37 ++++++++----------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 59ed3b5521470..649d73a0a460f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -1106,8 +1106,8 @@ struct ConvertVectorMaskedLoad final /// * vector<4xi4> -> i8 - yes (N = 2) /// * vector<3xi4> -> i8 - no (N would have to be 1.5) /// * vector<3xi2> -> i16 - no (N would have to be 0.5) -static bool isSubByteVecFittable(VectorType subByteVecTy, - Type multiByteScalarTy) { +static bool fitsInMultiByteContainerTy(VectorType subByteVecTy, + Type multiByteScalarTy) { assert((isa(multiByteScalarTy)) && "Not scalar!"); int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth(); @@ -1160,7 +1160,7 @@ struct ConvertVectorTransferRead final // Note, per-element-alignment was already verified above. bool isFullyAligned = - isSubByteVecFittable(op.getVectorType(), containerElemTy); + fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy); auto newPadding = rewriter.create(loc, containerElemTy, adaptor.getPadding()); @@ -1496,38 +1496,31 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, VectorType subByteVecTy, Type containerTy, Operation *op) { + assert(containerTy.isIntOrFloat() && + "container element type is not a scalar"); + // TODO: This is validating the inputs rather than checking the conditions // documented above. Replace with an assert. if (!subByteVecTy) return rewriter.notifyMatchFailure(op, "not a vector!"); - // TODO: This is validating the inputs rather than checking the conditions - // documented above. Replace with an assert. - if (!containerTy.isIntOrFloat()) - return rewriter.notifyMatchFailure(op, "not a scalar!"); - unsigned subByteBits = subByteVecTy.getElementTypeBitWidth(); unsigned multiByteBits = containerTy.getIntOrFloatBitWidth(); // Enforced by the common pre-conditions. assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!"); - // TODO: Remove this condition - the assert above (and - // commonConversionPrecondtion) takes care of that. - if (multiByteBits < 8) - return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!"); - // TODO: Add support other widths (when/if needed) if (subByteBits != 2 && subByteBits != 4) return rewriter.notifyMatchFailure( op, "only 2-bit and 4-bit sub-byte type is supported at this moment"); - // Condition 1. + // Condition 1 ("per-element" alignment) if (multiByteBits % subByteBits != 0) return rewriter.notifyMatchFailure(op, "unalagined element types"); - // Condition 2. - if (!isSubByteVecFittable(subByteVecTy, containerTy)) + // Condition 2 ("full" alignment) + if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy)) return rewriter.notifyMatchFailure( op, "not possible to fit this sub-byte vector type into a vector of " "the given multi-byte type"); @@ -1967,9 +1960,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern { return failure(); // Check general alignment preconditions. - Type containerType = rewriter.getI8Type(); - if (failed(alignedConversionPrecondition(rewriter, srcVecType, - containerType, conversionOp))) + if (failed(alignedConversionPrecondition( + rewriter, srcVecType, + /*containerTy=*/rewriter.getI8Type(), conversionOp))) return failure(); // Perform the rewrite. @@ -2033,9 +2026,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern { // Check general alignment preconditions. We invert the src/dst type order // to reuse the existing precondition logic. - Type containerType = rewriter.getI8Type(); - if (failed(alignedConversionPrecondition(rewriter, dstVecType, - containerType, truncOp))) + if (failed(alignedConversionPrecondition( + rewriter, dstVecType, + /*containerTy=*/rewriter.getI8Type(), truncOp))) return failure(); // Create a new iX -> i8 truncation op. From aeea3e0227120239e0d1bf700e3562c6b0a4176c Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Tue, 11 Mar 2025 19:57:08 +0000 Subject: [PATCH 3/4] fixup! fixup! [mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N) isFullyAligned -> isDivisibleInSize --- .../Transforms/VectorEmulateNarrowType.cpp | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 649d73a0a460f..1b274095dc625 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -519,7 +519,7 @@ struct ConvertVectorStore final : OpConversionPattern { auto origElements = valueToStore.getType().getNumElements(); // Note, per-element-alignment was already verified above. - bool isFullyAligned = origElements % emulatedPerContainerElem == 0; + bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0; auto stridedMetadata = rewriter.create(loc, op.getBase()); @@ -535,8 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern { getAsOpFoldResult(adaptor.getIndices())); std::optional foldedNumFrontPadElems = - isFullyAligned ? 0 - : getConstantIntValue(linearizedInfo.intraDataOffset); + isDivisibleInSize ? 0 + : getConstantIntValue(linearizedInfo.intraDataOffset); if (!foldedNumFrontPadElems) { return rewriter.notifyMatchFailure( @@ -554,7 +554,7 @@ struct ConvertVectorStore final : OpConversionPattern { // need unaligned emulation because the store address is aligned and the // source is a whole byte. bool emulationRequiresPartialStores = - !isFullyAligned || *foldedNumFrontPadElems != 0; + !isDivisibleInSize || *foldedNumFrontPadElems != 0; if (!emulationRequiresPartialStores) { // Basic case: storing full bytes. auto numElements = origElements / emulatedPerContainerElem; @@ -881,7 +881,7 @@ struct ConvertVectorLoad final : OpConversionPattern { auto origElements = op.getVectorType().getNumElements(); // Note, per-element-alignment was already verified above. - bool isFullyAligned = origElements % emulatedPerContainerElem == 0; + bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0; auto stridedMetadata = rewriter.create(loc, op.getBase()); @@ -897,8 +897,8 @@ struct ConvertVectorLoad final : OpConversionPattern { getAsOpFoldResult(adaptor.getIndices())); std::optional foldedIntraVectorOffset = - isFullyAligned ? 0 - : getConstantIntValue(linearizedInfo.intraDataOffset); + isDivisibleInSize ? 0 + : getConstantIntValue(linearizedInfo.intraDataOffset); // Always load enough elements which can cover the original elements. int64_t maxintraDataOffset = @@ -915,7 +915,7 @@ struct ConvertVectorLoad final : OpConversionPattern { result = dynamicallyExtractSubVector( rewriter, loc, dyn_cast>(result), resultVector, linearizedInfo.intraDataOffset, origElements); - } else if (!isFullyAligned) { + } else if (!isDivisibleInSize) { result = staticallyExtractSubvector( rewriter, loc, result, *foldedIntraVectorOffset, origElements); } @@ -1002,7 +1002,7 @@ struct ConvertVectorMaskedLoad final auto origType = op.getVectorType(); auto origElements = origType.getNumElements(); // Note, per-element-alignment was already verified above. - bool isFullyAligned = origElements % emulatedPerContainerElem == 0; + bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0; auto stridedMetadata = rewriter.create(loc, op.getBase()); @@ -1017,8 +1017,8 @@ struct ConvertVectorMaskedLoad final getAsOpFoldResult(adaptor.getIndices())); std::optional foldedIntraVectorOffset = - isFullyAligned ? 0 - : getConstantIntValue(linearizedInfo.intraDataOffset); + isDivisibleInSize ? 0 + : getConstantIntValue(linearizedInfo.intraDataOffset); int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1); @@ -1042,7 +1042,7 @@ struct ConvertVectorMaskedLoad final passthru = dynamicallyInsertSubVector( rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset, origElements); - } else if (!isFullyAligned) { + } else if (!isDivisibleInSize) { passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset); } @@ -1070,7 +1070,7 @@ struct ConvertVectorMaskedLoad final mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask, linearizedInfo.intraDataOffset, origElements); - } else if (!isFullyAligned) { + } else if (!isDivisibleInSize) { mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask, *foldedIntraVectorOffset); } @@ -1081,7 +1081,7 @@ struct ConvertVectorMaskedLoad final result = dynamicallyExtractSubVector( rewriter, loc, result, op.getPassThru(), linearizedInfo.intraDataOffset, origElements); - } else if (!isFullyAligned) { + } else if (!isDivisibleInSize) { result = staticallyExtractSubvector( rewriter, loc, result, *foldedIntraVectorOffset, origElements); } @@ -1159,7 +1159,7 @@ struct ConvertVectorTransferRead final auto origElements = op.getVectorType().getNumElements(); // Note, per-element-alignment was already verified above. - bool isFullyAligned = + bool isDivisibleInSize = fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy); auto newPadding = rewriter.create(loc, containerElemTy, @@ -1179,8 +1179,8 @@ struct ConvertVectorTransferRead final getAsOpFoldResult(adaptor.getIndices())); std::optional foldedIntraVectorOffset = - isFullyAligned ? 0 - : getConstantIntValue(linearizedInfo.intraDataOffset); + isDivisibleInSize ? 0 + : getConstantIntValue(linearizedInfo.intraDataOffset); int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1); @@ -1204,7 +1204,7 @@ struct ConvertVectorTransferRead final result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros, linearizedInfo.intraDataOffset, origElements); - } else if (!isFullyAligned) { + } else if (!isDivisibleInSize) { result = staticallyExtractSubvector( rewriter, loc, result, *foldedIntraVectorOffset, origElements); } From b0a3d06fbe24c74ce8b49f3de4187006cf993be4 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Sat, 15 Mar 2025 19:22:03 +0000 Subject: [PATCH 4/4] fixup! fixup! fixup! [mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N) Add minor missing re-naming (otherwise there are inconsistent names remaining)) --- .../Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 1b274095dc625..cf6efaa04ae44 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -1505,10 +1505,10 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, return rewriter.notifyMatchFailure(op, "not a vector!"); unsigned subByteBits = subByteVecTy.getElementTypeBitWidth(); - unsigned multiByteBits = containerTy.getIntOrFloatBitWidth(); + unsigned containerBits = containerTy.getIntOrFloatBitWidth(); // Enforced by the common pre-conditions. - assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!"); + assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!"); // TODO: Add support other widths (when/if needed) if (subByteBits != 2 && subByteBits != 4) @@ -1516,7 +1516,7 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, op, "only 2-bit and 4-bit sub-byte type is supported at this moment"); // Condition 1 ("per-element" alignment) - if (multiByteBits % subByteBits != 0) + if (containerBits % subByteBits != 0) return rewriter.notifyMatchFailure(op, "unalagined element types"); // Condition 2 ("full" alignment)