
Conversation

@antiagainst

This commit allows `inferFragType` to see through all `arith.ext` op users before reaching the contract op when figuring out the fragment type.

@llvmbot

llvmbot commented May 13, 2024

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Lei Zhang (antiagainst)

Changes

This commit allows `inferFragType` to see through all `arith.ext` op users before reaching the contract op when figuring out the fragment type.
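
For reference, here is a minimal standalone sketch of the patched helper. It is not the literal upstream function: the hunk below truncates the loop body after the `if (!contract)` check, so the operand-matching tail and the `inferFragTypeSketch` name are reconstructions, but it shows how the new see-through step composes with the existing contract-operand check.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Hypothetical sketch, simplified from the patched inferFragType.
static const char *inferFragTypeSketch(Operation *op) {
  // New in this patch: if the sole user is an arith extension op, the
  // contract consumes the extended value, so recurse on the ext op instead.
  if (op->hasOneUse()) {
    Operation *user = *op->user_begin();
    if (isa<arith::ExtFOp, arith::ExtUIOp, arith::ExtSIOp>(user))
      return inferFragTypeSketch(user);
  }
  // Existing behavior: find a vector.contract user and report which operand
  // this value feeds: LHS -> "AOp", RHS -> "BOp", accumulator -> "COp".
  for (Operation *user : op->getUsers()) {
    auto contract = dyn_cast<vector::ContractionOp>(user);
    if (!contract)
      continue;
    if (contract.getLhs() == op->getResult(0))
      return "AOp";
    if (contract.getRhs() == op->getResult(0))
      return "BOp";
    return "COp";
  }
  return nullptr; // No contract user found.
}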


Full diff: https://github.com/llvm/llvm-project/pull/91988.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp (+10-4)
  • (modified) mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir (+27)
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 782cc92f83fee..ad7408bb06fc1 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -515,6 +515,13 @@ struct CombineTransferReadOpTranspose final
 // TODO: Change the GPU dialect to abstract the layout at the this level and
 // only care about it during lowering to NVVM.
 static const char *inferFragType(Operation *op) {
+  // We can have arith.ext ops before reaching contract ops. See through them.
+  if (op->hasOneUse()) {
+    Operation *extOp = *op->user_begin();
+    if (isa<arith::ExtFOp, arith::ExtUIOp, arith::ExtSIOp>(extOp))
+      return inferFragType(extOp);
+  }
+
   for (Operation *users : op->getUsers()) {
     auto contract = dyn_cast<vector::ContractionOp>(users);
     if (!contract)
@@ -560,13 +567,12 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
   if (op->hasOneUse()) {
     auto *user = *op->user_begin();
     // Infer the signedness of the mma type from the integer extend.
-    bool isSignedExtend = isa<arith::ExtSIOp>(user);
-    if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
+    if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
       elType = IntegerType::get(
           op.getContext(), cast<IntegerType>(elType).getWidth(),
-          isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
+          isa<arith::ExtSIOp>(user) ? IntegerType::Signed
+                                    : IntegerType::Unsigned);
       mappingResult = user->getResult(0);
-      fragType = inferFragType(user);
     }
   }
   gpu::MMAMatrixType type =
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 962ed7de584a2..8526ff1392599 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -490,3 +490,30 @@ func.func @fold_transpose_into_transfer_read(%alloc: memref<64x128xf16>, %vector
 }
 
 // -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @cast_f16_to_f32_read
+//       CHECK:    %[[A:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//       CHECK:    %[[C:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK:    %[[AE:.+]] = gpu.subgroup_mma_elementwise  extf %[[A]] : (!gpu.mma_matrix<16x16xf16, "AOp">) -> !gpu.mma_matrix<16x16xf32, "AOp">
+//       CHECK:    %[[CE:.+]] = gpu.subgroup_mma_elementwise  extf %[[C]] : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+//       CHECK:    %[[B:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//       CHECK:    %[[BE:.+]] = gpu.subgroup_mma_elementwise  extf %[[B]] : (!gpu.mma_matrix<16x16xf16, "BOp">) -> !gpu.mma_matrix<16x16xf32, "BOp">
+//       CHECK:    gpu.subgroup_mma_compute %[[AE]], %[[BE]], %[[CE]]
+func.func @cast_f16_to_f32_read(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %Aext = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>
+  %Bext = arith.extf %B : vector<16x16xf16> to vector<16x16xf32>
+  %Cext = arith.extf %C : vector<16x16xf16> to vector<16x16xf32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+                        %Aext, %Bext, %Cext : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+  vector.transfer_write %D, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
+  return
+}
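
Note that the second hunk drops the old `fragType = inferFragType(user)` reassignment inside the integer-extension branch: with the see-through logic added above, `inferFragType` already accounts for the extension when called on the `transfer_read` itself, so reassigning the fragment type here is presumably no longer needed. The hunk also truncates right where the fragment string and element type are combined; a rough sketch of that construction follows (a hypothetical `buildMmaType` helper, with `vecTy` standing in for the read's result vector type, which is not shown in the hunk).

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Hypothetical helper (not upstream code): combine the fragment string
// returned by inferFragType ("AOp", "BOp", or "COp") with the possibly
// signedness-adjusted element type into the MMA matrix type produced by
// gpu.subgroup_mma_load_matrix.
static gpu::MMAMatrixType buildMmaType(VectorType vecTy, Type elType,
                                       const char *fragType) {
  return gpu::MMAMatrixType::get(vecTy.getShape(), elType, fragType);
}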

@qedawkins (Contributor) left a comment

LGTM

@antiagainst antiagainst merged commit a037d88 into llvm:main May 13, 2024
@antiagainst antiagainst deleted the mma-cast branch May 13, 2024 19:10