From 92bc49cf94ccd6d0ab6fdc1ffb4be90613a3935a Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 30 Jul 2025 09:12:27 +0000 Subject: [PATCH] [mlir][tosa] Relax constraint on matmul verifier requiring equal operand types Removes the verifier constraint, allowing matmul with different operand types such as fp8e5m2 x fp8e4m3. Operand type combinations that do not strictly adhere to the TOSA specification will still be caught in the validation pass. Change-Id: I1453ded48326ea0460fa6caf52651c02b7d8c055 --- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 6 ------ mlir/test/Dialect/Tosa/ops.mlir | 9 +++++++++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 3cafb199d2db3..d58acf6a53d5e 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1605,12 +1605,6 @@ LogicalResult MatMulOp::verify() { return emitOpError("expect quantized operands to have same widths, got ") << aQuantWidth << " and " << bQuantWidth; } - } else { - // non-quantized element types - if (aElementType != bElementType) { - return emitOpError("expect same element type for inputs a and b, got ") - << aElementType << " and " << bElementType; - } } // check a_zp and b_zp diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 30361a882afe5..6061c04662de8 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -934,6 +934,15 @@ func.func @test_matmul_f8E5M2(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x return %0 : tensor<1x14x28xf16> } +// ----- +// CHECK-LABEL: test_matmul_f8E5M2_f8E4M3 +func.func @test_matmul_f8E5M2_f8E4M3(%arg0: tensor<1x14x19xf8E5M2>, %arg1: tensor<1x19x28xf8E4M3FN>) -> tensor<1x14x28xf16> { + %azp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E5M2>}> : () -> tensor<1xf8E5M2> + %bzp0 = "tosa.const"() <{values = dense<0.0> : tensor<1xf8E4M3FN>}> : () -> tensor<1xf8E4M3FN> + %0 = 
tosa.matmul %arg0, %arg1, %azp0, %bzp0 : (tensor<1x14x19xf8E5M2>, tensor<1x19x28xf8E4M3FN>, tensor<1xf8E5M2>, tensor<1xf8E4M3FN>) -> tensor<1x14x28xf16> + return %0 : tensor<1x14x28xf16> +} + // ----- // CHECK-LABEL: max_pool2d_f8E5M2 func.func @test_max_pool2d_f8E5M2(%arg0: tensor<1x32x32x8xf8E5M2>) -> tensor<1x32x32x8xf8E5M2> {