Skip to content

Commit 2e2eabb

Browse files
xin-zhang-intelrayngun
authored andcommitted
PR intel#38: Replace BaseMemRef/TensorType class with TypeInterface
1 parent a196b6e commit 2e2eabb

File tree

8 files changed

+156
-255
lines changed

8 files changed

+156
-255
lines changed

mlir/include/mlir/IR/BuiltinTypeInterfaces.td

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,24 +104,20 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
104104

105105
/// Return the number of elements present in the given shape.
106106
static int64_t getNumElements(ArrayRef<int64_t> shape);
107+
}];
107108

109+
let extraSharedClassDeclaration = [{
108110
/// Return a clone of this type with the given new shape and element type.
109-
/// The returned type is ranked, even if this type is unranked.
110111
auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
111-
return cloneWith(shape, elementType);
112+
return $_type.cloneWith(shape, elementType);
112113
}
113114

114-
/// Return a clone of this type with the given new shape. The returned type
115-
/// is ranked, even if this type is unranked.
115+
/// Return a clone of this type with the given new shape.
116116
auto clone(::llvm::ArrayRef<int64_t> shape) {
117-
return cloneWith(shape, getElementType());
117+
return $_type.cloneWith(shape, $_type.getElementType());
118118
}
119-
}];
120119

121-
let extraSharedClassDeclaration = [{
122-
/// Return a clone of this type with the given new element type. The
123-
/// returned type is ranked if and only if this type is ranked. In that
124-
/// case, the returned type has the same shape as this type.
120+
/// Return a clone of this type with the given new element type.
125121
auto clone(::mlir::Type elementType) {
126122
return $_type.cloneWith(/*shape=*/std::nullopt, elementType);
127123
}
@@ -188,4 +184,66 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
188184
}];
189185
}
190186

187+
//===----------------------------------------------------------------------===//
188+
// TensorTypeInterface
189+
//===----------------------------------------------------------------------===//
190+
191+
def TensorTypeInterface : TypeInterface<"TensorType", [
192+
ShapedTypeInterface
193+
]
194+
> {
195+
let cppNamespace = "::mlir";
196+
let description = [{
197+
This interface provides a shared interface for ranked and unranked type
198+
and customized tensor types.
199+
200+
This class attaches the ShapedTypeInterface to act as a mixin to
201+
provide many useful utility functions.
202+
}];
203+
204+
let extraClassDeclaration = [{
205+
// Return true if the specified element type is ok in a tensor.
206+
static bool isValidElementType(::mlir::Type type);
207+
}];
208+
209+
let extraClassOf = [{
210+
return $_type.hasTrait<::mlir::TensorType::Trait>();
211+
}];
212+
213+
}
214+
215+
//===----------------------------------------------------------------------===//
216+
// BaseMemRefTypeInterface
217+
//===----------------------------------------------------------------------===//
218+
219+
def BaseMemRefTypeInterface : TypeInterface<"BaseMemRefType", [
220+
ShapedTypeInterface
221+
]
222+
> {
223+
let cppNamespace = "::mlir";
224+
let description = [{
225+
This interface provides a shared interface for ranked and unranked memref and
226+
customized memref types.
227+
228+
This interface attaches the ShapedTypeInterface to act as a mixin to
229+
provide many useful utility functions.
230+
}];
231+
232+
let methods = [
233+
InterfaceMethod<[{
234+
Returns the memory space in which data referred to by this memref resides.
235+
}],
236+
"::mlir::Attribute", "getMemorySpace">,
237+
];
238+
239+
let extraClassDeclaration = [{
240+
// Return true if the specified element type is ok in a memref.
241+
static bool isValidElementType(::mlir::Type type);
242+
}];
243+
244+
let extraClassOf = [{
245+
return $_type.hasTrait<::mlir::BaseMemRefType::Trait>();
246+
}];
247+
}
248+
191249
#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 0 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -85,108 +85,6 @@ class FloatType : public Type {
8585
const llvm::fltSemantics &getFloatSemantics();
8686
};
8787

88-
//===----------------------------------------------------------------------===//
89-
// TensorType
90-
//===----------------------------------------------------------------------===//
91-
92-
/// Tensor types represent multi-dimensional arrays, and have two variants:
93-
/// RankedTensorType and UnrankedTensorType.
94-
/// Note: This class attaches the ShapedType trait to act as a mixin to
95-
/// provide many useful utility functions. This inheritance has no effect
96-
/// on derived tensor types.
97-
class TensorType : public Type, public ShapedType::Trait<TensorType> {
98-
public:
99-
using Type::Type;
100-
101-
/// Returns the element type of this tensor type.
102-
Type getElementType() const;
103-
104-
/// Returns if this type is ranked, i.e. it has a known number of dimensions.
105-
bool hasRank() const;
106-
107-
/// Returns the shape of this tensor type.
108-
ArrayRef<int64_t> getShape() const;
109-
110-
/// Clone this type with the given shape and element type. If the
111-
/// provided shape is `std::nullopt`, the current shape of the type is used.
112-
TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
113-
Type elementType) const;
114-
115-
// Make sure that base class overloads are visible.
116-
using ShapedType::Trait<TensorType>::clone;
117-
118-
/// Return a clone of this type with the given new shape and element type.
119-
/// The returned type is ranked, even if this type is unranked.
120-
RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
121-
122-
/// Return a clone of this type with the given new shape. The returned type
123-
/// is ranked, even if this type is unranked.
124-
RankedTensorType clone(ArrayRef<int64_t> shape) const;
125-
126-
/// Return true if the specified element type is ok in a tensor.
127-
static bool isValidElementType(Type type);
128-
129-
/// Methods for support type inquiry through isa, cast, and dyn_cast.
130-
static bool classof(Type type);
131-
132-
/// Allow implicit conversion to ShapedType.
133-
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
134-
};
135-
136-
//===----------------------------------------------------------------------===//
137-
// BaseMemRefType
138-
//===----------------------------------------------------------------------===//
139-
140-
/// This class provides a shared interface for ranked and unranked memref types.
141-
/// Note: This class attaches the ShapedType trait to act as a mixin to
142-
/// provide many useful utility functions. This inheritance has no effect
143-
/// on derived memref types.
144-
class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
145-
public:
146-
using Type::Type;
147-
148-
/// Returns the element type of this memref type.
149-
Type getElementType() const;
150-
151-
/// Returns if this type is ranked, i.e. it has a known number of dimensions.
152-
bool hasRank() const;
153-
154-
/// Returns the shape of this memref type.
155-
ArrayRef<int64_t> getShape() const;
156-
157-
/// Clone this type with the given shape and element type. If the
158-
/// provided shape is `std::nullopt`, the current shape of the type is used.
159-
BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
160-
Type elementType) const;
161-
162-
// Make sure that base class overloads are visible.
163-
using ShapedType::Trait<BaseMemRefType>::clone;
164-
165-
/// Return a clone of this type with the given new shape and element type.
166-
/// The returned type is ranked, even if this type is unranked.
167-
MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
168-
169-
/// Return a clone of this type with the given new shape. The returned type
170-
/// is ranked, even if this type is unranked.
171-
MemRefType clone(ArrayRef<int64_t> shape) const;
172-
173-
/// Return true if the specified element type is ok in a memref.
174-
static bool isValidElementType(Type type);
175-
176-
/// Methods for support type inquiry through isa, cast, and dyn_cast.
177-
static bool classof(Type type);
178-
179-
/// Returns the memory space in which data referred to by this memref resides.
180-
Attribute getMemorySpace() const;
181-
182-
/// [deprecated] Returns the memory space in old raw integer representation.
183-
/// New `Attribute getMemorySpace()` method should be used instead.
184-
unsigned getMemorySpaceAsInt() const;
185-
186-
/// Allow implicit conversion to ShapedType.
187-
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
188-
};
189-
19088
} // namespace mlir
19189

19290
//===----------------------------------------------------------------------===//
@@ -399,10 +297,6 @@ SliceVerificationResult isRankReducedType(ShapedType originalType,
399297
// Deferred Method Definitions
400298
//===----------------------------------------------------------------------===//
401299

402-
inline bool BaseMemRefType::classof(Type type) {
403-
return llvm::isa<MemRefType, UnrankedMemRefType>(type);
404-
}
405-
406300
inline bool BaseMemRefType::isValidElementType(Type type) {
407301
return type.isIntOrIndexOrFloat() ||
408302
llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
@@ -469,10 +363,6 @@ inline FloatType FloatType::getF128(MLIRContext *ctx) {
469363
return Float128Type::get(ctx);
470364
}
471365

472-
inline bool TensorType::classof(Type type) {
473-
return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
474-
}
475-
476366
//===----------------------------------------------------------------------===//
477367
// Type Utilities
478368
//===----------------------------------------------------------------------===//

mlir/include/mlir/IR/BuiltinTypes.td

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,8 @@ def Builtin_Integer : Builtin_Type<"Integer", "integer"> {
423423
//===----------------------------------------------------------------------===//
424424

425425
def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
426-
ShapedTypeInterface
427-
], "BaseMemRefType"> {
426+
BaseMemRefTypeInterface
427+
]> {
428428
let summary = "Shaped reference to a region of memory";
429429
let description = [{
430430
Syntax:
@@ -675,7 +675,7 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
675675
"unsigned":$memorySpaceInd)>
676676
];
677677
let extraClassDeclaration = [{
678-
using BaseMemRefType::clone;
678+
using ShapedType::Trait<MemRefType>::clone;
679679
using ShapedType::Trait<MemRefType>::getElementTypeBitWidth;
680680
using ShapedType::Trait<MemRefType>::getRank;
681681
using ShapedType::Trait<MemRefType>::getNumElements;
@@ -693,6 +693,13 @@ def Builtin_MemRef : Builtin_Type<"MemRef", "memref", [
693693
/// New `Attribute getMemorySpace()` method should be used instead.
694694
unsigned getMemorySpaceAsInt() const;
695695

696+
/// Returns if this type is ranked (always true).
697+
bool hasRank() const { return true; }
698+
699+
/// Returns a clone of this type with the given shape and element
700+
/// type. If a shape is not provided, the current shape of the type is used.
701+
MemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
702+
Type elementType) const;
696703
}];
697704
let skipDefaultBuilders = 1;
698705
let genVerifyDecl = 1;
@@ -773,8 +780,8 @@ def Builtin_Opaque : Builtin_Type<"Opaque", "opaque"> {
773780
//===----------------------------------------------------------------------===//
774781

775782
def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
776-
ShapedTypeInterface, ValueSemantics
777-
], "TensorType"> {
783+
TensorTypeInterface, ValueSemantics
784+
]> {
778785
let summary = "Multi-dimensional array with a fixed number of dimensions";
779786
let description = [{
780787
Syntax:
@@ -855,7 +862,7 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
855862
}]>
856863
];
857864
let extraClassDeclaration = [{
858-
using TensorType::clone;
865+
using ShapedType::Trait<RankedTensorType>::clone;
859866
using ShapedType::Trait<RankedTensorType>::getElementTypeBitWidth;
860867
using ShapedType::Trait<RankedTensorType>::getRank;
861868
using ShapedType::Trait<RankedTensorType>::getNumElements;
@@ -869,11 +876,12 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "tensor", [
869876
/// Arguments that are passed into the builder must outlive the builder.
870877
class Builder;
871878

872-
/// Return a clone of this type with the given new element type and the same
873-
/// shape as this type.
874-
RankedTensorType clone(::mlir::Type elementType) {
875-
return ::llvm::cast<RankedTensorType>(cloneWith(getShape(), elementType));
876-
}
879+
/// Returns if this type is ranked (always true).
880+
bool hasRank() const { return true; }
881+
882+
/// Returns a clone of this type with the given shape and element type.
883+
RankedTensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
884+
Type elementType) const;
877885
}];
878886
let skipDefaultBuilders = 1;
879887
let genVerifyDecl = 1;
@@ -951,8 +959,8 @@ def Builtin_Tuple : Builtin_Type<"Tuple", "tuple"> {
951959
//===----------------------------------------------------------------------===//
952960

953961
def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
954-
ShapedTypeInterface
955-
], "BaseMemRefType"> {
962+
BaseMemRefTypeInterface
963+
]> {
956964
let summary = "Shaped reference, with unknown rank, to a region of memory";
957965
let description = [{
958966
Syntax:
@@ -998,7 +1006,7 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
9981006
}]>
9991007
];
10001008
let extraClassDeclaration = [{
1001-
using BaseMemRefType::clone;
1009+
using ShapedType::Trait<UnrankedMemRefType>::clone;
10021010
using ShapedType::Trait<UnrankedMemRefType>::getElementTypeBitWidth;
10031011
using ShapedType::Trait<UnrankedMemRefType>::getRank;
10041012
using ShapedType::Trait<UnrankedMemRefType>::getNumElements;
@@ -1014,11 +1022,12 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
10141022
/// New `Attribute getMemorySpace()` method should be used instead.
10151023
unsigned getMemorySpaceAsInt() const;
10161024

1017-
/// Return a clone of this type with the given new element type and the same
1018-
/// shape as this type.
1019-
MemRefType clone(::mlir::Type elementType) {
1020-
return ::llvm::cast<MemRefType>(cloneWith(getShape(), elementType));
1021-
}
1025+
/// Returns if this type is ranked (always false).
1026+
bool hasRank() const { return false; }
1027+
1028+
/// Returns a clone of this type with the given shape and element type.
1029+
BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
1030+
Type elementType) const;
10221031
}];
10231032
let skipDefaultBuilders = 1;
10241033
let genVerifyDecl = 1;
@@ -1029,8 +1038,8 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "unranked_memref", [
10291038
//===----------------------------------------------------------------------===//
10301039

10311040
def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
1032-
ShapedTypeInterface, ValueSemantics
1033-
], "TensorType"> {
1041+
TensorTypeInterface, ValueSemantics
1042+
]> {
10341043
let summary = "Multi-dimensional array with unknown dimensions";
10351044
let description = [{
10361045
Syntax:
@@ -1057,7 +1066,7 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
10571066
}]>
10581067
];
10591068
let extraClassDeclaration = [{
1060-
using TensorType::clone;
1069+
using ShapedType::Trait<UnrankedTensorType>::clone;
10611070
using ShapedType::Trait<UnrankedTensorType>::getElementTypeBitWidth;
10621071
using ShapedType::Trait<UnrankedTensorType>::getRank;
10631072
using ShapedType::Trait<UnrankedTensorType>::getNumElements;
@@ -1068,6 +1077,13 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "unranked_tensor", [
10681077
using ShapedType::Trait<UnrankedTensorType>::getDynamicDimIndex;
10691078

10701079
ArrayRef<int64_t> getShape() const { return std::nullopt; }
1080+
1081+
/// Returns if this type is ranked (always false).
1082+
bool hasRank() const { return false; }
1083+
1084+
/// Returns a clone of this type with the given shape and element type.
1085+
TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
1086+
Type elementType) const;
10711087
}];
10721088
let skipDefaultBuilders = 1;
10731089
let genVerifyDecl = 1;

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -711,8 +711,7 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
711711
if (llvm::isa<TensorType>(opResult.getType())) {
712712
// The OpResult is a tensor. Such values are replaced with memrefs during
713713
// bufferization.
714-
assert((llvm::isa<MemRefType>(replacement.getType()) ||
715-
llvm::isa<UnrankedMemRefType>(replacement.getType())) &&
714+
assert(llvm::isa<BaseMemRefType>(replacement.getType()) &&
716715
"tensor op result should be replaced with a memref value");
717716
// The existing uses of the OpResult still expect a tensor. Insert a
718717
// ToTensorOp. Throughout bufferization, this ToTensorOp will gradually

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
1515
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1616
#include "mlir/IR/Matchers.h"
17+
#include "mlir/IR/BuiltinTypeInterfaces.h"
1718
#include <optional>
1819

1920
using namespace mlir;

mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
153153

154154
/// Returns true if the given type has the default memory space.
155155
static bool hasDefaultMemorySpace(BaseMemRefType type) {
156-
return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
156+
return !type.getMemorySpace();
157157
}
158158

159159
/// Returns true if the given type has the shared (workgroup) memory space.

0 commit comments

Comments
 (0)