[mlir][MemRef] Add support for emulating narrow floats #148036
This enables memref.load/store + vector.load/store support for sub-byte float types. Since the memref element types don't matter for loads/stores, we still use the same types as integers with equivalent widths, with a few extra bitcasts needed around certain operations.
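For illustration, here is a minimal sketch of what the emulation produces for a scalar f4E2M1FN load with an i8 load/store width, mirroring the memref test added in this PR (the SSA names %linearized and %bitoffset are illustrative, and the index arithmetic is elided):

func.func @memref_load_f4(%arg0: index) -> f4E2M1FN {
  %0 = memref.alloc() : memref<5xf4E2M1FN>
  %1 = memref.load %0[%arg0] : memref<5xf4E2M1FN>
  return %1 : f4E2M1FN
}

// After narrow-type emulation with an i8 container, the storage becomes i8 and
// the loaded byte is shifted, truncated to i4, then bitcast back to the float type:
//   %alloc   = memref.alloc() : memref<3xi8>
//   %byte    = memref.load %alloc[%linearized] : memref<3xi8>
//   %shifted = arith.shrsi %byte, %bitoffset : i8
//   %i4      = arith.trunci %shifted : i8 to i4
//   %f4      = arith.bitcast %i4 : i4 to f4E2M1FN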
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-memref

Author: Quinn Dawkins (qedawkins)

Changes: This enables memref.load/store + vector.load/store support for sub-byte float types. Since the memref types don't matter, we still use the same types as integers with equivalent widths, with a few extra bitcasts needed around certain operations.

Full diff: https://github.com/llvm/llvm-project/pull/148036.diff

4 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index d2a032688fb6d..ec2bc95291455 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -323,19 +323,28 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
// It is not clear if this case actually happens in practice, but we keep
// the operations just in case. Otherwise, if the arith computation bitwidth
// is different from the emulated bitwidth we truncate the result.
- Operation *result;
+ Value result;
auto resultTy = getTypeConverter()->convertType(oldElementType);
- if (resultTy == convertedElementType) {
+ auto conversionTy =
+ resultTy.isInteger()
+ ? resultTy
+ : IntegerType::get(rewriter.getContext(),
+ resultTy.getIntOrFloatBitWidth());
+ if (conversionTy == convertedElementType) {
auto mask = rewriter.create<arith::ConstantOp>(
loc, convertedElementType,
rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
} else {
- result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
+ result = rewriter.create<arith::TruncIOp>(loc, conversionTy, bitsLoad);
}
- rewriter.replaceOp(op, result->getResult(0));
+ if (conversionTy != resultTy) {
+ result = rewriter.create<arith::BitcastOp>(loc, resultTy, result);
+ }
+
+ rewriter.replaceOp(op, result);
return success();
}
};
@@ -415,8 +424,18 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
}
Location loc = op.getLoc();
- Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
- adaptor.getValue());
+
+ // Pad the input value with 0s on the left.
+ Value input = adaptor.getValue();
+ if (!input.getType().isInteger()) {
+ input = rewriter.create<arith::BitcastOp>(
+ loc,
+ IntegerType::get(rewriter.getContext(),
+ input.getType().getIntOrFloatBitWidth()),
+ input);
+ }
+ Value extendedInput =
+ rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, input);
// Special case 0-rank memref stores. No need for masking.
if (convertedType.getRank() == 0) {
@@ -619,11 +638,11 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
arith::NarrowTypeEmulationConverter &typeConverter) {
typeConverter.addConversion(
[&typeConverter](MemRefType ty) -> std::optional<Type> {
- auto intTy = dyn_cast<IntegerType>(ty.getElementType());
- if (!intTy)
+ Type elementType = ty.getElementType();
+ if (!elementType.isIntOrFloat())
return ty;
- unsigned width = intTy.getWidth();
+ unsigned width = elementType.getIntOrFloatBitWidth();
unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
if (width >= loadStoreWidth)
return ty;
@@ -636,8 +655,11 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
if (!strides.empty() && strides.back() != 1)
return nullptr;
- auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
- intTy.getSignedness());
+ auto newElemTy = IntegerType::get(
+ ty.getContext(), loadStoreWidth,
+ elementType.isInteger()
+ ? cast<IntegerType>(elementType).getSignedness()
+ : IntegerType::SignednessSemantics::Signless);
if (!newElemTy)
return nullptr;
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 004beadc9ec7d..0fe08417f818f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1268,8 +1268,18 @@ struct ConvertVectorTransferRead final
bool isDivisibleInSize =
fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
- auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
- adaptor.getPadding());
+ // Pad the padding value with 0s on the left. These bits are discarded and
+ // thus their values don't matter.
+ Value padding = adaptor.getPadding();
+ if (!padding.getType().isInteger()) {
+ padding = rewriter.create<arith::BitcastOp>(
+ loc,
+ IntegerType::get(rewriter.getContext(),
+ padding.getType().getIntOrFloatBitWidth()),
+ padding);
+ }
+ auto newPadding =
+ rewriter.create<arith::ExtUIOp>(loc, containerElemTy, padding);
auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 3378d329e8205..0cce8c18a40bc 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -61,6 +61,41 @@ func.func @memref_load_i4(%arg0: index) -> i4 {
// -----
+func.func @memref_load_f4(%arg0: index) -> f4E2M1FN {
+ %0 = memref.alloc() : memref<5xf4E2M1FN>
+ %1 = memref.load %0[%arg0] : memref<5xf4E2M1FN>
+ return %1 : f4E2M1FN
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+// CHECK: func @memref_load_f4(
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: %[[BC:.+]] = arith.bitcast %[[TRUNC]] : i4 to f4E2M1FN
+// CHECK: return %[[BC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+// CHECK32: func @memref_load_f4(
+// CHECK32-SAME: %[[ARG0:.+]]: index
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: %[[BC:.+]] = arith.bitcast %[[TRUNC]] : i4 to f4E2M1FN
+// CHECK32: return %[[BC]]
+
+// -----
+
func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
%0 = memref.alloc() : memref<3x125xi4>
%align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
@@ -470,6 +505,29 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
// -----
+func.func @rank_zero_memref_store_f4(%arg0: f4E2M1FN) -> () {
+ %0 = memref.alloc() : memref<f4E2M1FN>
+ memref.store %arg0, %0[] : memref<f4E2M1FN>
+ return
+}
+// CHECK-LABEL: func @rank_zero_memref
+// CHECK-SAME: %[[ARG0:.+]]: f4E2M1FN
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+// CHECK: %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: return
+
+// CHECK32-LABEL: func @rank_zero_memref
+// CHECK32-SAME: %[[ARG0:.+]]: f4E2M1FN
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+// CHECK32: %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
+// CHECK32: %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: return
+
+// -----
+
func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
%arr = memref.alloc() : memref<32x8x128xi4>
%collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 6c924492b513e..98b1f07ef5fb0 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -53,6 +53,31 @@ func.func @vector_load_i4(%arg1: index, %arg2: index) -> vector<3x8xi4> {
// -----
+func.func @vector_load_f4(%arg1: index, %arg2: index) -> vector<3x8xf4E2M1FN> {
+ %0 = memref.alloc() : memref<3x8xf4E2M1FN>
+ %cst = arith.constant dense<0.0> : vector<3x8xf4E2M1FN>
+ %1 = vector.load %0[%arg1, %arg2] : memref<3x8xf4E2M1FN>, vector<8xf4E2M1FN>
+ %2 = vector.insert %1, %cst [0] : vector<8xf4E2M1FN> into vector<3x8xf4E2M1FN>
+ return %2 : vector<3x8xf4E2M1FN>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_load_f4
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<12xi8>, vector<4xi8>
+// CHECK: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xf4E2M1FN>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_load_f4
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32>
+// CHECK32: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xf4E2M1FN>
+
+// -----
+
func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) -> vector<8xi4> {
%0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
%1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
@@ -119,6 +144,37 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
// -----
+func.func @vector_transfer_read_f4(%arg1: index, %arg2: index) -> vector<8xf4E2M1FN> {
+ %c0 = arith.constant 0.0 : f4E2M1FN
+ %0 = memref.alloc() : memref<3x8xf4E2M1FN>
+ %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true]} :
+ memref<3x8xf4E2M1FN>, vector<8xf4E2M1FN>
+ return %1 : vector<8xf4E2M1FN>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_transfer_read_f4
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[CONST:.+]] = arith.constant 0.{{0+}}e+00 : f4E2M1FN
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[BC:.+]] = arith.bitcast %[[CONST]] : f4E2M1FN to i4
+// CHECK: %[[PAD:.+]] = arith.extui %[[BC]] : i4 to i8
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<12xi8>, vector<4xi8>
+// CHECK: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xf4E2M1FN>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_transfer_read_f4
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK32: %[[CONST:.+]] = arith.constant 0.{{0+}}e+00 : f4E2M1FN
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[BC:.+]] = arith.bitcast %[[CONST]] : f4E2M1FN to i4
+// CHECK32: %[[PAD:.+]] = arith.extui %[[BC]] : i4 to i32
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<3xi32>, vector<1xi32>
+// CHECK32: %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xf4E2M1FN>
+
+// -----
+
///----------------------------------------------------------------------------------------
/// vector.maskedload
///----------------------------------------------------------------------------------------
@@ -439,6 +495,28 @@ func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {
// -----
+func.func @vector_store_f4(%arg0: vector<8xf4E2M1FN>, %arg1: index, %arg2: index) {
+ %0 = memref.alloc() : memref<4x8xf4E2M1FN>
+ vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xf4E2M1FN>, vector<8xf4E2M1FN>
+ return
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_store_f4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<16xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xf4E2M1FN> to vector<4xi8>
+// CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<16xi8>, vector<4xi8>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_store_f4
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xf4E2M1FN> to vector<1xi32>
+// CHECK32: vector.store %[[VEC_I32:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<4xi32>, vector<1xi32>
+
+// -----
+
// FIXME: This example assumes that the store happens at a byte boundary, but
// that's not guaranteed. Below is a counter-example with specific dimensions:
// vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>
@llvm/pr-subscribers-mlir-vector

Author: Quinn Dawkins (qedawkins)

Changes: (same summary and full diff as in the comment above.)
Nice!
Could you elaborate on why that's the case? (*) For example, any kind of arithmetic on floats would be incorrect if we were interpreting them as ints - though I understand this patch only touches loads and stores.

I'm generally a bit concerned about the overall state of narrow type emulation. I'm not opposed to this patch - the change itself makes sense. But if there are more changes coming in this area, it would be helpful to get a better sense of the broader direction. In particular, would you mind chiming in on this discussion? Thanks!

(*) Preferably in the summary :)
Right, it's because we only touch loads and stores. Even in the integer case, if we chose to use float-typed containers, that would be fine. The fact that we load packed sub-byte types as a single larger int makes arithmetic on the immediate load just as nonsensical in the integer case as in the floating point case. Your confusion might be because there's a naming overload with narrow type emulation: the only thing this narrow type emulation pattern set should care about is loads/stores. Arithmetic is the responsibility of sibling patterns such as ArithNarrowTypeEmulation or ArithExpandOps, depending on the flavor of emulation you want per the underlying data type.
There are no further changes planned at the moment.
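To make the split of responsibilities described above concrete, here is a hedged sketch: after the memref/vector narrow-type emulation patterns run, the loaded value is already a plain f4E2M1FN (packed i8 storage + trunci + bitcast, as in the tests), and any arithmetic on it is deliberately left alone. The widening op below is only illustrative of "arithmetic some later pattern set would have to legalize", not something this PR handles.

func.func @use_loaded_f4(%f4: f4E2M1FN) -> f32 {
  // %f4 is assumed to come from an emulated load (i8 storage, shifted,
  // truncated to i4, bitcast to f4E2M1FN). The computation below is not
  // touched by the load/store emulation patterns; legalizing it, if needed,
  // is the job of sibling arith emulation/expansion patterns.
  %wide = arith.extf %f4 : f4E2M1FN to f32
  return %wide : f32
}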
This enables memref.load/store + vector.load/store support for sub-byte float types. Since the memref types don't matter for loads/stores, we still use the same types as integers with equivalent widths, with a few extra bitcasts needed around certain operations.

There is no direct change needed for vector.load/store support. The tests added for them are to verify that float types are supported as well.
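A matching sketch of the vector path, taken from the new vector.load test with the default i8 container (only the shape of the lowering is shown; index computation is elided):

func.func @vector_load_f4(%arg1: index, %arg2: index) -> vector<8xf4E2M1FN> {
  %0 = memref.alloc() : memref<3x8xf4E2M1FN>
  %1 = vector.load %0[%arg1, %arg2] : memref<3x8xf4E2M1FN>, vector<8xf4E2M1FN>
  return %1 : vector<8xf4E2M1FN>
}

// Emulated form (per the added CHECK lines): the packed data is loaded as bytes
// and bitcast back to the float vector type, with no new vector patterns required.
//   %alloc = memref.alloc() : memref<12xi8>
//   %vec   = vector.load %alloc[%index] : memref<12xi8>, vector<4xi8>
//   %f4    = vector.bitcast %vec : vector<4xi8> to vector<8xf4E2M1FN>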