Skip to content

Commit 7e5d8a3

Browse files
committed
[MLIR] Support memrefs with complex element types.
Differential Revision: https://reviews.llvm.org/D74307
1 parent 42a16da commit 7e5d8a3

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

mlir/lib/IR/StandardTypes.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,8 @@ MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
333333
auto *context = elementType.getContext();
334334

335335
// Check that memref is formed from allowed types.
336-
if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
336+
if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
337+
!elementType.isa<ComplexType>())
337338
return emitOptionalError(location, "invalid memref element type"),
338339
MemRefType();
339340

@@ -411,7 +412,8 @@ LogicalResult UnrankedMemRefType::verifyConstructionInvariants(
411412
Optional<Location> loc, MLIRContext *context, Type elementType,
412413
unsigned memorySpace) {
413414
// Check that memref is formed from allowed types.
414-
if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>())
415+
if (!elementType.isIntOrFloat() && !elementType.isa<VectorType>() &&
416+
!elementType.isa<ComplexType>())
415417
return emitOptionalError(*loc, "invalid memref element type");
416418
return success();
417419
}

mlir/test/IR/parser.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,16 @@ func @memrefs_compose_with_id(memref<2x2xi8, affine_map<(d0, d1) -> (d0, d1)>,
133133
// CHECK: func @complex_types(complex<i1>) -> complex<f32>
134134
func @complex_types(complex<i1>) -> complex<f32>
135135

136+
137+
// CHECK: func @memref_with_complex_elems(memref<1x?xcomplex<f32>>)
138+
func @memref_with_complex_elems(memref<1x?xcomplex<f32>>)
139+
140+
// CHECK: func @memref_with_vector_elems(memref<1x?xvector<10xf32>>)
141+
func @memref_with_vector_elems(memref<1x?xvector<10xf32>>)
142+
143+
// CHECK: func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
144+
func @unranked_memref_with_complex_elems(memref<*xcomplex<f32>>)
145+
136146
// CHECK: func @functions((memref<1x?x4x?x?xi32, #map0>, memref<8xi8>) -> (), () -> ())
137147
func @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<8xi8, #map1, 0>) -> (), ()->())
138148

0 commit comments

Comments
 (0)