-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[mlir][spirv] Drop support for SPV_NV_cooperative_matrix #76782
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This extension has been superseeded by SPV_KHR_cooperative_matrix which is supported across major vendors GPU like Nvidia, AMD, and Intel. Given that the KHR version has been supported for nearly half a year, drop the NV-specific extension to reduce the maintanance burden and code duplication.
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir-gpu Author: Jakub Kuderski (kuhar) ChangesThis extension has been superseded by SPV_KHR_cooperative_matrix which is supported across major vendors GPU like Nvidia, AMD, and Intel. Given that the KHR version has been supported for nearly half a year, drop the NV-specific extension to reduce the maintenance burden and code duplication. Patch is 93.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/76782.diff 26 Files Affected:
diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index cd650345f1daa2..d34549432161db 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -31,16 +31,10 @@ void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
void populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
-/// Collect a set of patterns to convert WMMA ops from GPU dialect to SPIRV,
-/// using the NV Cooperative Matrix extension.
-void populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(
- SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns);
-
-/// Adds `MMAMatrixType` conversions to SPIR-V cooperative matrix type
-/// conversion to the type converter. Defaults to KHR cooperative matrix types.
-/// When `useNVTypes` is `true`, uses the NV cooperative matrix types.
+/// Adds `MMAMatrixType` conversions to SPIR-V cooperative matrix KHR type
+/// conversion to the type converter.
void populateMMAToSPIRVCoopMatrixTypeConversion(
- SPIRVTypeConverter &typeConverter, bool useNVTypes = false);
+ SPIRVTypeConverter &typeConverter);
} // namespace mlir
#endif // MLIR_CONVERSION_GPUTOSPIRV_GPUTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 6193aeb545bc6b..71be8841ca7c03 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -564,10 +564,6 @@ def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
Option<"use64bitIndex", "use-64bit-index",
"bool", /*default=*/"false",
"Use 64-bit integers to convert index types">,
- Option<"useCoopMatrixNV", "use-coop-matrix-nv",
- "bool", /*default=*/"false",
- "Use the NV cooperative matrix extension insted of the KHR extension"
- " to lower GPU WMMA ops">,
];
}
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index ee1fbba1e2844e..6ec97e17c5dcc8 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -1253,12 +1253,6 @@ def SPIRV_C_RayTracingProvisionalKHR : I32EnumAttrCase<"RayTr
Extension<[SPV_KHR_ray_tracing]>
];
}
-def SPIRV_C_CooperativeMatrixNV : I32EnumAttrCase<"CooperativeMatrixNV", 5357> {
- list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
- list<Availability> availability = [
- Extension<[SPV_NV_cooperative_matrix]>
- ];
-}
def SPIRV_C_FragmentShaderSampleInterlockEXT : I32EnumAttrCase<"FragmentShaderSampleInterlockEXT", 5363> {
list<I32EnumAttrCase> implies = [SPIRV_C_Shader];
list<Availability> availability = [
@@ -1501,7 +1495,7 @@ def SPIRV_CapabilityAttr :
SPIRV_C_ShaderNonUniform, SPIRV_C_RuntimeDescriptorArray,
SPIRV_C_StorageTexelBufferArrayDynamicIndexing, SPIRV_C_RayTracingNV,
SPIRV_C_RayTracingMotionBlurNV, SPIRV_C_PhysicalStorageBufferAddresses,
- SPIRV_C_RayTracingProvisionalKHR, SPIRV_C_CooperativeMatrixNV,
+ SPIRV_C_RayTracingProvisionalKHR,
SPIRV_C_FragmentShaderSampleInterlockEXT,
SPIRV_C_FragmentShaderShadingRateInterlockEXT, SPIRV_C_ShaderSMBuiltinsNV,
SPIRV_C_FragmentShaderPixelInterlockEXT, SPIRV_C_DemoteToHelperInvocation,
@@ -4123,8 +4117,6 @@ class SignlessOrUnsignedIntOfWidths<list<int> widths> :
def SPIRV_IsArrayType : CPred<"::llvm::isa<::mlir::spirv::ArrayType>($_self)">;
def SPIRV_IsCooperativeMatrixType :
CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixType>($_self)">;
-def SPIRV_IsCooperativeMatrixNVType :
- CPred<"::llvm::isa<::mlir::spirv::CooperativeMatrixNVType>($_self)">;
def SPIRV_IsImageType : CPred<"::llvm::isa<::mlir::spirv::ImageType>($_self)">;
def SPIRV_IsJointMatrixType :
CPred<"::llvm::isa<::mlir::spirv::JointMatrixINTELType>($_self)">;
@@ -4157,9 +4149,6 @@ def SPIRV_AnyArray : DialectType<SPIRV_Dialect, SPIRV_IsArrayType,
def SPIRV_AnyCooperativeMatrix : DialectType<SPIRV_Dialect,
SPIRV_IsCooperativeMatrixType,
"any SPIR-V cooperative matrix type">;
-def SPIRV_AnyCooperativeMatrixNV : DialectType<SPIRV_Dialect,
- SPIRV_IsCooperativeMatrixNVType,
- "any SPIR-V NV cooperative matrix type">;
def SPIRV_AnyImage : DialectType<SPIRV_Dialect, SPIRV_IsImageType,
"any SPIR-V image type">;
def SPIRV_AnyJointMatrix : DialectType<SPIRV_Dialect, SPIRV_IsJointMatrixType,
@@ -4178,13 +4167,12 @@ def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
def SPIRV_Aggregate : AnyTypeOf<[SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct]>;
def SPIRV_Composite :
AnyTypeOf<[SPIRV_Vector, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
- SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV,
- SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix]>;
def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_Float, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
- SPIRV_AnyCooperativeMatrix, SPIRV_AnyCooperativeMatrixNV,
- SPIRV_AnyJointMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage
+ SPIRV_AnyCooperativeMatrix, SPIRV_AnyJointMatrix, SPIRV_AnyMatrix,
+ SPIRV_AnySampledImage
]>;
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4195,11 +4183,6 @@ class SPIRV_CoopMatrixOfType<list<Type> allowedTypes> :
"::llvm::cast<::mlir::spirv::CooperativeMatrixType>($_self).getElementType()",
"Cooperative Matrix">;
-class SPIRV_CoopMatrixNVOfType<list<Type> allowedTypes> :
- ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsCooperativeMatrixNVType,
- "::llvm::cast<::mlir::spirv::CooperativeMatrixNVType>($_self).getElementType()",
- "Cooperative Matrix NV">;
-
class SPIRV_JointMatrixOfType<list<Type> allowedTypes> :
ContainerType<AnyTypeOf<allowedTypes>, SPIRV_IsJointMatrixType,
"::llvm::cast<::mlir::spirv::JointMatrixINTELType>($_self).getElementType()",
@@ -4213,12 +4196,11 @@ class SPIRV_ScalarOrVectorOf<Type type> :
class SPIRV_ScalarOrVectorOrCoopMatrixOf<Type type> :
AnyTypeOf<[type, SPIRV_VectorOf<type>,
- SPIRV_CoopMatrixOfType<[type]>, SPIRV_CoopMatrixNVOfType<[type]>]>;
+ SPIRV_CoopMatrixOfType<[type]>]>;
class SPIRV_MatrixOrCoopMatrixOf<Type type> :
AnyTypeOf<[SPIRV_AnyMatrix,
- SPIRV_CoopMatrixOfType<[type]>,
- SPIRV_CoopMatrixNVOfType<[type]>]>;
+ SPIRV_CoopMatrixOfType<[type]>]>;
def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>;
def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>;
@@ -4480,11 +4462,6 @@ def SPIRV_OC_OpCooperativeMatrixLoadKHR : I32EnumAttrCase<"OpCooperativeMatrix
def SPIRV_OC_OpCooperativeMatrixStoreKHR : I32EnumAttrCase<"OpCooperativeMatrixStoreKHR", 4458>;
def SPIRV_OC_OpCooperativeMatrixMulAddKHR : I32EnumAttrCase<"OpCooperativeMatrixMulAddKHR", 4459>;
def SPIRV_OC_OpCooperativeMatrixLengthKHR : I32EnumAttrCase<"OpCooperativeMatrixLengthKHR", 4460>;
-def SPIRV_OC_OpTypeCooperativeMatrixNV : I32EnumAttrCase<"OpTypeCooperativeMatrixNV", 5358>;
-def SPIRV_OC_OpCooperativeMatrixLoadNV : I32EnumAttrCase<"OpCooperativeMatrixLoadNV", 5359>;
-def SPIRV_OC_OpCooperativeMatrixStoreNV : I32EnumAttrCase<"OpCooperativeMatrixStoreNV", 5360>;
-def SPIRV_OC_OpCooperativeMatrixMulAddNV : I32EnumAttrCase<"OpCooperativeMatrixMulAddNV", 5361>;
-def SPIRV_OC_OpCooperativeMatrixLengthNV : I32EnumAttrCase<"OpCooperativeMatrixLengthNV", 5362>;
def SPIRV_OC_OpSubgroupBlockReadINTEL : I32EnumAttrCase<"OpSubgroupBlockReadINTEL", 5575>;
def SPIRV_OC_OpSubgroupBlockWriteINTEL : I32EnumAttrCase<"OpSubgroupBlockWriteINTEL", 5576>;
def SPIRV_OC_OpAssumeTrueKHR : I32EnumAttrCase<"OpAssumeTrueKHR", 5630>;
@@ -4585,9 +4562,6 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpTypeCooperativeMatrixKHR, SPIRV_OC_OpCooperativeMatrixLoadKHR,
SPIRV_OC_OpCooperativeMatrixStoreKHR, SPIRV_OC_OpCooperativeMatrixMulAddKHR,
SPIRV_OC_OpCooperativeMatrixLengthKHR,
- SPIRV_OC_OpTypeCooperativeMatrixNV, SPIRV_OC_OpCooperativeMatrixLoadNV,
- SPIRV_OC_OpCooperativeMatrixStoreNV, SPIRV_OC_OpCooperativeMatrixMulAddNV,
- SPIRV_OC_OpCooperativeMatrixLengthNV,
SPIRV_OC_OpSubgroupBlockReadINTEL, SPIRV_OC_OpSubgroupBlockWriteINTEL,
SPIRV_OC_OpAssumeTrueKHR, SPIRV_OC_OpAtomicFAddEXT, SPIRV_OC_OpGroupIMulKHR,
SPIRV_OC_OpGroupFMulKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
index 29ad45bddd5529..46732ba19afed5 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td
@@ -338,253 +338,6 @@ def SPIRV_KHRCooperativeMatrixMulAddOp : SPIRV_KhrVendorOp<"CooperativeMatrixMul
];
}
-//===----------------------------------------------------------------------===//
-// SPV_NV_cooperative_matrix extension ops.
-//===----------------------------------------------------------------------===//
-
-// -----
-
-def SPIRV_NVCooperativeMatrixLengthOp : SPIRV_NvVendorOp<"CooperativeMatrixLength",
- [Pure]> {
- let summary = "See extension SPV_NV_cooperative_matrix";
-
- let description = [{
- Number of components of a cooperative matrix type accessible to each
- invocation when treated as a composite.
-
- Result Type must be an OpTypeInt with 32-bit Width and 0 Signedness.
-
- Type is a cooperative matrix type.
-
- #### Example:
-
- ```
- %0 = spirv.NV.CooperativeMatrixLength : !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- ```
- }];
-
- let assemblyFormat = "attr-dict `:` $cooperative_matrix_type";
-
- let availability = [
- MinVersion<SPIRV_V_1_0>,
- MaxVersion<SPIRV_V_1_6>,
- Extension<[SPV_NV_cooperative_matrix]>,
- Capability<[SPIRV_C_CooperativeMatrixNV]>
- ];
-
- let arguments = (ins
- TypeAttr:$cooperative_matrix_type
- );
-
- let results = (outs
- SPIRV_Int32:$result
- );
-}
-
-// -----
-
-def SPIRV_NVCooperativeMatrixLoadOp : SPIRV_NvVendorOp<"CooperativeMatrixLoad", []> {
- let summary = "See extension SPV_NV_cooperative_matrix";
-
- let description = [{
- Load a cooperative matrix through a pointer.
-
- Result Type is the type of the loaded object. It must be a cooperative
- matrix type.
-
- Pointer is a pointer into an array. Its type must be an OpTypePointer whose
- Type operand is a scalar or vector type. The storage class of Pointer must
- be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
- supported) PhysicalStorageBufferEXT.
-
- Stride is the number of elements in the array in memory between the first
- component of consecutive rows (or columns) in the result. It must be a
- scalar integer type.
-
- ColumnMajor indicates whether the values loaded from memory are arranged in
- column-major or row-major order. It must be a boolean constant instruction,
- with false indicating row major and true indicating column major.
-
- Memory Access must be a Memory Access literal. If not present, it is the
- same as specifying None.
-
- If ColumnMajor is false, then elements (row,*) of the result are taken in
- order from contiguous locations starting at Pointer[row*Stride]. If
- ColumnMajor is true, then elements (*,col) of the result are taken in order
- from contiguous locations starting from Pointer[col*Stride]. Any ArrayStride
- decoration on Pointer is ignored.
-
- For a given dynamic instance of this instruction, all operands of this
- instruction must be the same for all invocations in a given scope instance
- (where the scope is the scope the cooperative matrix type was created with).
- All invocations in a given scope instance must be active or all must be
- inactive.
-
- ### Custom assembly form
-
- ``` {.ebnf}
- cooperative-matrixload-op ::= ssa-id `=` `spirv.NV.CooperativeMatrixLoad`
- ssa-use `,` ssa-use `,` ssa-use
- (`[` memory-access `]`)? ` : `
- pointer-type `as`
- cooperative-matrix-type
- ```
-
- #### Example:
-
- ```
- %0 = spirv.NV.CooperativeMatrixLoad %ptr, %stride, %colMajor
- : !spirv.ptr<i32, StorageBuffer> as !spirv.NV.coopmatrix<16x8xi32, Workgroup>
- ```
- }];
-
- let availability = [
- MinVersion<SPIRV_V_1_0>,
- MaxVersion<SPIRV_V_1_6>,
- Extension<[SPV_NV_cooperative_matrix]>,
- Capability<[SPIRV_C_CooperativeMatrixNV]>
- ];
-
- let arguments = (ins
- SPIRV_AnyPtr:$pointer,
- SPIRV_Integer:$stride,
- SPIRV_Bool:$columnmajor,
- OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access
- );
-
- let results = (outs
- SPIRV_AnyCooperativeMatrixNV:$result
- );
-}
-
-// -----
-
-def SPIRV_NVCooperativeMatrixMulAddOp : SPIRV_NvVendorOp<"CooperativeMatrixMulAdd",
- [Pure, AllTypesMatch<["c", "result"]>]> {
- let summary = "See extension SPV_NV_cooperative_matrix";
-
- let description = [{
- Linear-algebraic matrix multiply of A by B and then component-wise add C.
- The order of the operations is implementation-dependent. The internal
- precision of floating-point operations is defined by the client API.
- Integer operations are performed at the precision of the Result Type and are
- exact unless there is overflow or underflow, in which case the result is
- undefined.
-
- Result Type must be a cooperative matrix type with M rows and N columns.
-
- A is a cooperative matrix with M rows and K columns.
-
- B is a cooperative matrix with K rows and N columns.
-
- C is a cooperative matrix with M rows and N columns.
-
- The values of M, N, and K must be consistent across the result and operands.
- This is referred to as an MxNxK matrix multiply.
-
- A, B, C, and Result Type must have the same scope, and this defines the
- scope of the operation. A, B, C, and Result Type need not necessarily have
- the same component type, this is defined by the client API.
-
- If the Component Type of any matrix operand is an integer type, then its
- components are treated as signed if its Component Type has Signedness of 1
- and are treated as unsigned otherwise.
-
- For a given dynamic instance of this instruction, all invocations in a given
- scope instance must be active or all must be inactive (where the scope is
- the scope of the operation).
-
- #### Example:
-
- ```
- %0 = spirv.NV.CooperativeMatrixMulAdd %arg0, %arg1, %arg2, :
- !spirv.NV.coopmatrix<8x16xi32, Subgroup>
- ```
- }];
-
- let assemblyFormat = [{
- operands attr-dict `:` type($a) `,` type($b) `->` type($c)
- }];
-
- let availability = [
- MinVersion<SPIRV_V_1_0>,
- MaxVersion<SPIRV_V_1_6>,
- Extension<[SPV_NV_cooperative_matrix]>,
- Capability<[SPIRV_C_CooperativeMatrixNV]>
- ];
-
- let arguments = (ins
- SPIRV_AnyCooperativeMatrixNV:$a,
- SPIRV_AnyCooperativeMatrixNV:$b,
- SPIRV_AnyCooperativeMatrixNV:$c
- );
-
- let results = (outs
- SPIRV_AnyCooperativeMatrixNV:$result
- );
-}
-
-// -----
-
-def SPIRV_NVCooperativeMatrixStoreOp : SPIRV_NvVendorOp<"CooperativeMatrixStore", []> {
- let summary = "See extension SPV_NV_cooperative_matrix";
-
- let description = [{
- Store a cooperative matrix through a pointer.
-
- Pointer is a pointer into an array. Its type must be an OpTypePointer whose
- Type operand is a scalar or vector type. The storage class of Pointer must
- be Workgroup, StorageBuffer, or (if SPV_EXT_physical_storage_buffer is
- supported) PhysicalStorageBufferEXT.
-
- Object is the object to store. Its type must be an
- OpTypeCooperativeMatrixNV.
-
- Stride is the number of elements in the array in memory between the first
- component of consecutive rows (or columns) in the result. It must be a
- scalar integer type.
-
- ColumnMajor indicates whether the values stored to memory are arranged in
- column-major or row-major order. It must be a boolean constant instruction,
- with false indicating row major and true indicating column major.
-
- Memory Access must be a Memory Access literal. If not present, it is the
- same as specifying None.
-
- ``` {.ebnf}
- coop-matrix-store-op ::= `spirv.NV.CooperativeMatrixStore `
- ssa-use `, ` ssa-use `, `
- ssa-use `, ` ssa-use `, `
- (`[` memory-access `]`)? `:`
- pointer-type `,` coop-matrix-type
- ```
-
- #### Example:
-
- ```
- spirv.NV.CooperativeMatrixStore %arg0, %arg2, %arg1, %arg3 :
- !spirv.ptr<i32, StorageBuffer>, !spirv.NV.coopmatrix<16x8xi32, Workgroup>
- ```
- }];
-
- let availability = [
- MinVersion<SPIRV_V_1_0>,
- MaxVersion<SPIRV_V_1_6>,
- Extension<[SPV_NV_cooperative_matrix]>,
- Capability<[SPIRV_C_CooperativeMatrixNV]>
- ];
-
- let arguments = (ins
- SPIRV_AnyPtr:$pointer,
- SPIRV_AnyCooperativeMatrixNV:$object,
- SPIRV_Integer:$stride,
- SPIRV_Bool:$columnmajor,
- OptionalAttr<SPIRV_MemoryAccessAttr>:$memory_access
- );
-
- let results = (outs);
-}
-
// -----
#endif // MLIR_DIALECT_SPIRV_IR_COOPERATIVE_MATRIX_OPS
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index d946d936d4e6cf..55f0c787b44403 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -29,7 +29,6 @@ namespace spirv {
namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
-struct CooperativeMatrixNVTypeStorage;
struct ImageTypeStorage;
struct JointMatrixTypeStorage;
struct MatrixTypeStorage;
@@ -421,32 +420,6 @@ class CooperativeMatrixType
std::optional<StorageClass> storage = std::nullopt);
};
-// SPIR-V NV cooperative matrix type
-class CooperativeMatrixNVType
- : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
- detail::CooperativeMatrixNVTypeStorage> {
-public:
- using Base::Base;
-
- static constexpr StringLiteral name = "spirv.NV.coopmatrix";
-
- static CooperativeMatrixNVType get(Type elementType, Scope scope,
- unsigned rows, unsigned columns);
- Type getElementType() const;
-
- /// Returns the scope of the matrix.
- Scope getScope() const;
- /// Returns the number of rows of the matrix.
- unsigned getRows() const;
- /// Returns the number of columns of the matrix.
- unsigned getColumns() const;
-
- void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
- std::optional<StorageClass> storage = std::nullopt);
- void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
- std::optional<StorageClass> storage = std::nullopt);
-};
-
// SPIR-V joint matrix type
class JointMatrixINTELType
: public Type::TypeBase<JointMatrixINTELType, CompositeType,
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 8279b3408a6e66..0dd0e7e21b0553 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -112,18 +112,12 @@ void GPUToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.use64bitIndex = this->use64bitIndex;
SPIRVTypeConverter typeConverter(targetAttr, options);
- populateMMAToSPIRVCoopMatrixTypeConversion(typeConverter,
- this->useCoopMatrixNV);
+ populateMMAToSPIRVCoopMatrixTypeConversion(typeConverter);
RewritePatternSet patterns(context);
populateGPUToSPIRVPatterns(typeConverter, patterns);
- if (this->useCoopMatrixNV) {
- populateGpuWMMAToSPIRVCoopMatrixNVConversionPatterns(typ...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
I'm going to leave this open until Monday to get folks time to comment, considering some may be still on vacation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This extension has been superseded by SPV_KHR_cooperative_matrix which is supported across major vendors GPU like Nvidia, AMD, and Intel. Given that the KHR version has been supported for nearly half a year, drop the NV-specific extension to reduce the maintenance burden and code duplication.
This extension has been superseded by SPV_KHR_cooperative_matrix which is supported across major vendors GPU like Nvidia, AMD, and Intel.
Given that the KHR version has been supported for nearly half a year, drop the NV-specific extension to reduce the maintenance burden and code duplication.