[mlir][MemRef] Add support for emulating narrow floats #148036

Merged
merged 1 commit into llvm:main from emulate_narrow_floats on Jul 14, 2025

Conversation

@qedawkins (Contributor) commented Jul 10, 2025

This enables memref.load/store and vector.load/store support for sub-byte float types. Since the memref element types don't matter for loads/stores, the emulated memrefs keep using integer types of equivalent width, with a few extra bitcasts needed around certain operations.

No direct change is needed for vector.load/store support; the tests added for them verify that float types are supported as well.
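
For illustration, this is roughly what the emulation produces for a scalar f4E2M1FN load with an 8-bit load/store width. It mirrors the @memref_load_f4 test added below; the affine index/bit-offset computation is elided and the SSA value names are only for readability.

```mlir
// Input (from the @memref_load_f4 test):
func.func @memref_load_f4(%arg0: index) -> f4E2M1FN {
  %0 = memref.alloc() : memref<5xf4E2M1FN>
  %1 = memref.load %0[%arg0] : memref<5xf4E2M1FN>
  return %1 : f4E2M1FN
}

// Emulated form (index/bit-offset math elided): the data is stored as packed
// bytes and the element is recovered with shift + trunc + bitcast.
%alloc = memref.alloc() : memref<3xi8>
%byte = memref.load %alloc[%byte_index] : memref<3xi8>
%shifted = arith.shrsi %byte, %bit_offset : i8
%bits = arith.trunci %shifted : i8 to i4
%val = arith.bitcast %bits : i4 to f4E2M1FN  // the extra bitcast for float element types
return %val : f4E2M1FN
```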

@llvmbot (Member) commented Jul 10, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: Quinn Dawkins (qedawkins)

Changes

This enables memref.load/store + vector.load/store support for sub-byte float types. Since the memref types don't matter, we still use the same types as integers with equivalent widths, with a few extra bitcasts needed around certain operations.


Full diff: https://github.com/llvm/llvm-project/pull/148036.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+33-11)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+12-2)
  • (modified) mlir/test/Dialect/MemRef/emulate-narrow-type.mlir (+58)
  • (modified) mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir (+78)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index d2a032688fb6d..ec2bc95291455 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -323,19 +323,28 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
     // It is not clear if this case actually happens in practice, but we keep
     // the operations just in case. Otherwise, if the arith computation bitwidth
     // is different from the emulated bitwidth we truncate the result.
-    Operation *result;
+    Value result;
     auto resultTy = getTypeConverter()->convertType(oldElementType);
-    if (resultTy == convertedElementType) {
+    auto conversionTy =
+        resultTy.isInteger()
+            ? resultTy
+            : IntegerType::get(rewriter.getContext(),
+                               resultTy.getIntOrFloatBitWidth());
+    if (conversionTy == convertedElementType) {
       auto mask = rewriter.create<arith::ConstantOp>(
           loc, convertedElementType,
           rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
 
       result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
     } else {
-      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
+      result = rewriter.create<arith::TruncIOp>(loc, conversionTy, bitsLoad);
     }
 
-    rewriter.replaceOp(op, result->getResult(0));
+    if (conversionTy != resultTy) {
+      result = rewriter.create<arith::BitcastOp>(loc, resultTy, result);
+    }
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -415,8 +424,18 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
     }
 
     Location loc = op.getLoc();
-    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
-                                                          adaptor.getValue());
+
+    // Pad the input value with 0s on the left.
+    Value input = adaptor.getValue();
+    if (!input.getType().isInteger()) {
+      input = rewriter.create<arith::BitcastOp>(
+          loc,
+          IntegerType::get(rewriter.getContext(),
+                           input.getType().getIntOrFloatBitWidth()),
+          input);
+    }
+    Value extendedInput =
+        rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, input);
 
     // Special case 0-rank memref stores. No need for masking.
     if (convertedType.getRank() == 0) {
@@ -619,11 +638,11 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
     arith::NarrowTypeEmulationConverter &typeConverter) {
   typeConverter.addConversion(
       [&typeConverter](MemRefType ty) -> std::optional<Type> {
-        auto intTy = dyn_cast<IntegerType>(ty.getElementType());
-        if (!intTy)
+        Type elementType = ty.getElementType();
+        if (!elementType.isIntOrFloat())
           return ty;
 
-        unsigned width = intTy.getWidth();
+        unsigned width = elementType.getIntOrFloatBitWidth();
         unsigned loadStoreWidth = typeConverter.getLoadStoreBitwidth();
         if (width >= loadStoreWidth)
           return ty;
@@ -636,8 +655,11 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
         if (!strides.empty() && strides.back() != 1)
           return nullptr;
 
-        auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
-                                          intTy.getSignedness());
+        auto newElemTy = IntegerType::get(
+            ty.getContext(), loadStoreWidth,
+            elementType.isInteger()
+                ? cast<IntegerType>(elementType).getSignedness()
+                : IntegerType::SignednessSemantics::Signless);
         if (!newElemTy)
           return nullptr;
 
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 004beadc9ec7d..0fe08417f818f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1268,8 +1268,18 @@ struct ConvertVectorTransferRead final
     bool isDivisibleInSize =
         fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy);
 
-    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
-                                                      adaptor.getPadding());
+    // Pad the padding value with 0s on the left. These bits are discarded and
+    // thus their values don't matter.
+    Value padding = adaptor.getPadding();
+    if (!padding.getType().isInteger()) {
+      padding = rewriter.create<arith::BitcastOp>(
+          loc,
+          IntegerType::get(rewriter.getContext(),
+                           padding.getType().getIntOrFloatBitWidth()),
+          padding);
+    }
+    auto newPadding =
+        rewriter.create<arith::ExtUIOp>(loc, containerElemTy, padding);
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 3378d329e8205..0cce8c18a40bc 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -61,6 +61,41 @@ func.func @memref_load_i4(%arg0: index) -> i4 {
 
 // -----
 
+func.func @memref_load_f4(%arg0: index) -> f4E2M1FN {
+    %0 = memref.alloc() : memref<5xf4E2M1FN>
+    %1 = memref.load %0[%arg0] : memref<5xf4E2M1FN>
+    return %1 : f4E2M1FN
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+//      CHECK: func @memref_load_f4(
+// CHECK-SAME:     %[[ARG0:.+]]: index
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK:   %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+//      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+//      CHECK:   %[[BC:.+]] = arith.bitcast %[[TRUNC]] : i4 to f4E2M1FN
+//      CHECK:   return %[[BC]]
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+//      CHECK32: func @memref_load_f4(
+// CHECK32-SAME:     %[[ARG0:.+]]: index
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK32:   %[[LOADVAL:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
+//      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+//      CHECK32:   %[[BC:.+]] = arith.bitcast %[[TRUNC]] : i4 to f4E2M1FN
+//      CHECK32:   return %[[BC]]
+
+// -----
+
 func.func @memref_load_i4_rank2(%arg0: index, %arg1: index) -> i4 {
     %0 = memref.alloc() : memref<3x125xi4>
     %align0 = memref.assume_alignment %0, 64 : memref<3x125xi4>
@@ -470,6 +505,29 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
 
 // -----
 
+func.func @rank_zero_memref_store_f4(%arg0: f4E2M1FN) -> () {
+  %0 = memref.alloc() : memref<f4E2M1FN>
+  memref.store %arg0, %0[] : memref<f4E2M1FN>
+  return
+}
+// CHECK-LABEL: func @rank_zero_memref
+//  CHECK-SAME:     %[[ARG0:.+]]: f4E2M1FN
+//       CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+//       CHECK:   %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
+//       CHECK:   %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i8
+//       CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+//       CHECK:   return
+
+// CHECK32-LABEL: func @rank_zero_memref
+//  CHECK32-SAME:     %[[ARG0:.+]]: f4E2M1FN
+//       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+//       CHECK32:   %[[BC:.+]] = arith.bitcast %[[ARG0]] : f4E2M1FN to i4
+//       CHECK32:   %[[EXTUI:.+]] = arith.extui %[[BC]] : i4 to i32
+//       CHECK32:   %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+//       CHECK32:   return
+
+// -----
+
 func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
   %arr = memref.alloc() : memref<32x8x128xi4>
   %collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
index 6c924492b513e..98b1f07ef5fb0 100644
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -53,6 +53,31 @@ func.func @vector_load_i4(%arg1: index, %arg2: index) -> vector<3x8xi4> {
 
 // -----
 
+func.func @vector_load_f4(%arg1: index, %arg2: index) -> vector<3x8xf4E2M1FN> {
+    %0 = memref.alloc() : memref<3x8xf4E2M1FN>
+    %cst = arith.constant dense<0.0> : vector<3x8xf4E2M1FN>
+    %1 = vector.load %0[%arg1, %arg2] : memref<3x8xf4E2M1FN>, vector<8xf4E2M1FN>
+    %2 = vector.insert %1, %cst [0] : vector<8xf4E2M1FN> into vector<3x8xf4E2M1FN>
+    return %2 : vector<3x8xf4E2M1FN>
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+//      CHECK: func @vector_load_f4
+// CHECK-SAME:     (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<12xi8>, vector<4xi8>
+//      CHECK:   %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xf4E2M1FN>
+
+//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+//      CHECK32: func @vector_load_f4
+// CHECK32-SAME:     (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<3xi32>, vector<1xi32>
+//      CHECK32:   %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xf4E2M1FN>
+
+// -----
+
 func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index) -> vector<8xi4> {
   %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
   %1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
@@ -119,6 +144,37 @@ func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
 
 // -----
 
+func.func @vector_transfer_read_f4(%arg1: index, %arg2: index) -> vector<8xf4E2M1FN> {
+    %c0 = arith.constant 0.0 : f4E2M1FN
+    %0 = memref.alloc() : memref<3x8xf4E2M1FN>
+    %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true]} :
+      memref<3x8xf4E2M1FN>, vector<8xf4E2M1FN>
+    return %1 : vector<8xf4E2M1FN>
+}
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+//      CHECK: func @vector_transfer_read_f4
+// CHECK-SAME:     (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+//      CHECK:   %[[CONST:.+]] = arith.constant 0.{{0+}}e+00 : f4E2M1FN
+//      CHECK:   %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+//      CHECK:   %[[BC:.+]] = arith.bitcast %[[CONST]] : f4E2M1FN to i4
+//      CHECK:   %[[PAD:.+]] = arith.extui %[[BC]] : i4 to i8
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK:   %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<12xi8>, vector<4xi8>
+//      CHECK:   %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xf4E2M1FN>
+
+//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+//      CHECK32: func @vector_transfer_read_f4
+// CHECK32-SAME:     (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+//      CHECK32:   %[[CONST:.+]] = arith.constant 0.{{0+}}e+00 : f4E2M1FN
+//      CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+//      CHECK32:   %[[BC:.+]] = arith.bitcast %[[CONST]] : f4E2M1FN to i4
+//      CHECK32:   %[[PAD:.+]] = arith.extui %[[BC]] : i4 to i32
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+//      CHECK32:   %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<3xi32>, vector<1xi32>
+//      CHECK32:   %[[VEC_F4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xf4E2M1FN>
+
+// -----
+
 ///----------------------------------------------------------------------------------------
 /// vector.maskedload
 ///----------------------------------------------------------------------------------------
@@ -439,6 +495,28 @@ func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {
 
 // -----
 
+func.func @vector_store_f4(%arg0: vector<8xf4E2M1FN>, %arg1: index, %arg2: index) {
+    %0 = memref.alloc() : memref<4x8xf4E2M1FN>
+    vector.store %arg0, %0[%arg1, %arg2] :memref<4x8xf4E2M1FN>, vector<8xf4E2M1FN>
+    return
+}
+
+//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+//      CHECK: func @vector_store_f4
+//      CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<16xi8>
+//      CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+//      CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xf4E2M1FN> to vector<4xi8>
+//      CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<16xi8>, vector<4xi8>
+
+//  CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+//      CHECK32: func @vector_store_f4
+//      CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
+//      CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+//      CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xf4E2M1FN> to vector<1xi32>
+//      CHECK32: vector.store %[[VEC_I32:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<4xi32>, vector<1xi32>
+
+// -----
+
 // FIXME: This example assumes that the store happens at a byte boundary, but
 // that's not guaranteed. Below is a counter-example with specific dimensions:
 //    vector.store %arg0, %0[0, 3] : memref<2x13xi4>, vector<8xi4>

@llvmbot (Member) commented Jul 10, 2025

@llvm/pr-subscribers-mlir-vector

@MaheshRavishankar (Contributor) left a comment:

Nice!

@banach-space (Contributor) commented:

> Since the memref types don't matter

Could you elaborate on why that's the case? (*) For example, any kind of arithmetic on floats would be incorrect if we were interpreting them as ints - though I understand this patch only touches loads and stores.

I’m generally a bit concerned about the state of NarrowTypeEmulation in Vector. We should be cautious about how it's being extended. In my opinion, it currently lacks some of the rigor or design clarity we'd ideally want. Here's one case in point:

I'm not opposed to this patch - the change itself makes sense. But if there are more changes coming in this area, it would be helpful to get a better sense of the broader direction. In particular, would you mind chiming in on this discussion?

Thanks!

(*) Preferably in the summary :)

@qedawkins (Contributor, Author) commented:
> Could you elaborate on why that's the case? (*) For example, any kind of arithmetic on floats would be incorrect if we were interpreting them as ints - though I understand this patch only touches loads and stores.

Right, it's because we only touch loads and stores. Even in the integer case, if we chose to use float-typed containers that would be fine. The fact that we load packed sub-byte types as a single larger int makes arithmetic on the immediate load just as nonsensical in the integer case as in the floating-point case.

Your confusion might be because there is a naming overload with narrow type emulation. The only thing this narrow-type emulation pattern set should care about is loads/stores; arithmetic is the responsibility of sibling patterns such as ArithNarrowTypeEmulation or ArithExpandOps, depending on the flavor of emulation you want for the underlying data type.
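
The store side is symmetric: the float value is bitcast to its integer bit pattern before being zero-extended and packed, so the container element type only ever carries bits. A rough sketch based on the @rank_zero_memref_store_f4 test in this patch (8-bit load/store width, rank-zero case so no masking is needed; value names are illustrative):

```mlir
// memref.store %arg0, %alloc[] : memref<f4E2M1FN> becomes, roughly:
%bits = arith.bitcast %arg0 : f4E2M1FN to i4
%ext = arith.extui %bits : i4 to i8
%old = memref.atomic_rmw assign %ext, %alloc[] : (i8, memref<i8>) -> i8
```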

@qedawkins (Contributor, Author) commented:
> But if there are more changes coming in this area

There are no further changes planned at the moment.

@qedawkins merged commit b1ef5a8 into llvm:main on Jul 14, 2025 (14 checks passed).
@qedawkins deleted the emulate_narrow_floats branch on Jul 14, 2025 at 15:18.