Skip to content

Commit 371d0f5

Browse files
Revert "PR #38: Replace BaseMemRef/TensorType class with TypeInterface"
This reverts commit 51f2bc8.
1 parent 68190c6 commit 371d0f5

File tree

9 files changed

+259
-208
lines changed

9 files changed

+259
-208
lines changed

mlir/include/mlir/IR/BuiltinTypeInterfaces.td

Lines changed: 10 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -143,20 +143,24 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
143143

144144
/// Return the number of elements present in the given shape.
145145
static int64_t getNumElements(ArrayRef<int64_t> shape);
146-
}];
147146

148-
let extraSharedClassDeclaration = [{
149147
/// Return a clone of this type with the given new shape and element type.
148+
/// The returned type is ranked, even if this type is unranked.
150149
auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
151-
return $_type.cloneWith(shape, elementType);
150+
return cloneWith(shape, elementType);
152151
}
153152

154-
/// Return a clone of this type with the given new shape.
153+
/// Return a clone of this type with the given new shape. The returned type
154+
/// is ranked, even if this type is unranked.
155155
auto clone(::llvm::ArrayRef<int64_t> shape) {
156-
return $_type.cloneWith(shape, $_type.getElementType());
156+
return cloneWith(shape, getElementType());
157157
}
158+
}];
158159

159-
/// Return a clone of this type with the given new element type.
160+
let extraSharedClassDeclaration = [{
161+
/// Return a clone of this type with the given new element type. The
162+
/// returned type is ranked if and only if this type is ranked. In that
163+
/// case, the returned type has the same shape as this type.
160164
auto clone(::mlir::Type elementType) {
161165
return $_type.cloneWith(/*shape=*/std::nullopt, elementType);
162166
}
@@ -223,68 +227,4 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
223227
}];
224228
}
225229

226-
//===----------------------------------------------------------------------===//
227-
// TensorTypeInterface
228-
//===----------------------------------------------------------------------===//
229-
230-
def TensorTypeInterface : TypeInterface<"TensorType", [
231-
ShapedTypeInterface
232-
]
233-
> {
234-
let cppNamespace = "::mlir";
235-
let description = [{
236-
This interface provides a shared interface for ranked and unranked type
237-
and customized tensor types.
238-
239-
This class attaches the ShapedTypeInterface to act as a mixin to
240-
provide many useful utility functions.
241-
}];
242-
243-
let extraClassDeclaration = [{
244-
// Return true if the specified element type is ok in a tensor.
245-
static bool isValidElementType(::mlir::Type type);
246-
}];
247-
248-
let extraClassOf = [{
249-
return $_type.hasTrait<::mlir::TensorType::Trait>();
250-
}];
251-
252-
}
253-
254-
//===----------------------------------------------------------------------===//
255-
// BaseMemRefTypeInterface
256-
//===----------------------------------------------------------------------===//
257-
258-
def BaseMemRefTypeInterface : TypeInterface<"BaseMemRefType", [
259-
ShapedTypeInterface
260-
]
261-
> {
262-
let cppNamespace = "::mlir";
263-
let description = [{
264-
This interface provides a shared interface for ranked and unranked memref and
265-
customized memref types.
266-
267-
This interface attaches the ShapedTypeInterface to act as a mixin to
268-
provide many useful utility functions.
269-
}];
270-
271-
let methods = [
272-
InterfaceMethod<[{
273-
Returns the memory space in which data referred to by this memref resides.
274-
}],
275-
"::mlir::Attribute", "getMemorySpace">,
276-
];
277-
278-
let extraClassDeclaration = [{
279-
// Return true if the specified element type is ok in a memref.
280-
static bool isValidElementType(::mlir::Type type);
281-
282-
unsigned getMemorySpaceAsInt() const;
283-
}];
284-
285-
let extraClassOf = [{
286-
return $_type.hasTrait<::mlir::BaseMemRefType::Trait>();
287-
}];
288-
}
289-
290230
#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 112 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,108 @@ template <typename ConcreteType>
4343
class ValueSemantics
4444
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
4545

46+
//===----------------------------------------------------------------------===//
47+
// TensorType
48+
//===----------------------------------------------------------------------===//
49+
50+
/// Tensor types represent multi-dimensional arrays, and have two variants:
51+
/// RankedTensorType and UnrankedTensorType.
52+
/// Note: This class attaches the ShapedType trait to act as a mixin to
53+
/// provide many useful utility functions. This inheritance has no effect
54+
/// on derived tensor types.
55+
class TensorType : public Type, public ShapedType::Trait<TensorType> {
56+
public:
57+
using Type::Type;
58+
59+
/// Returns the element type of this tensor type.
60+
Type getElementType() const;
61+
62+
/// Returns if this type is ranked, i.e. it has a known number of dimensions.
63+
bool hasRank() const;
64+
65+
/// Returns the shape of this tensor type.
66+
ArrayRef<int64_t> getShape() const;
67+
68+
/// Clone this type with the given shape and element type. If the
69+
/// provided shape is `std::nullopt`, the current shape of the type is used.
70+
TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
71+
Type elementType) const;
72+
73+
// Make sure that base class overloads are visible.
74+
using ShapedType::Trait<TensorType>::clone;
75+
76+
/// Return a clone of this type with the given new shape and element type.
77+
/// The returned type is ranked, even if this type is unranked.
78+
RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
79+
80+
/// Return a clone of this type with the given new shape. The returned type
81+
/// is ranked, even if this type is unranked.
82+
RankedTensorType clone(ArrayRef<int64_t> shape) const;
83+
84+
/// Return true if the specified element type is ok in a tensor.
85+
static bool isValidElementType(Type type);
86+
87+
/// Methods for support type inquiry through isa, cast, and dyn_cast.
88+
static bool classof(Type type);
89+
90+
/// Allow implicit conversion to ShapedType.
91+
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
92+
};
93+
94+
//===----------------------------------------------------------------------===//
95+
// BaseMemRefType
96+
//===----------------------------------------------------------------------===//
97+
98+
/// This class provides a shared interface for ranked and unranked memref types.
99+
/// Note: This class attaches the ShapedType trait to act as a mixin to
100+
/// provide many useful utility functions. This inheritance has no effect
101+
/// on derived memref types.
102+
class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
103+
public:
104+
using Type::Type;
105+
106+
/// Returns the element type of this memref type.
107+
Type getElementType() const;
108+
109+
/// Returns if this type is ranked, i.e. it has a known number of dimensions.
110+
bool hasRank() const;
111+
112+
/// Returns the shape of this memref type.
113+
ArrayRef<int64_t> getShape() const;
114+
115+
/// Clone this type with the given shape and element type. If the
116+
/// provided shape is `std::nullopt`, the current shape of the type is used.
117+
BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
118+
Type elementType) const;
119+
120+
// Make sure that base class overloads are visible.
121+
using ShapedType::Trait<BaseMemRefType>::clone;
122+
123+
/// Return a clone of this type with the given new shape and element type.
124+
/// The returned type is ranked, even if this type is unranked.
125+
MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
126+
127+
/// Return a clone of this type with the given new shape. The returned type
128+
/// is ranked, even if this type is unranked.
129+
MemRefType clone(ArrayRef<int64_t> shape) const;
130+
131+
/// Return true if the specified element type is ok in a memref.
132+
static bool isValidElementType(Type type);
133+
134+
/// Methods for support type inquiry through isa, cast, and dyn_cast.
135+
static bool classof(Type type);
136+
137+
/// Returns the memory space in which data referred to by this memref resides.
138+
Attribute getMemorySpace() const;
139+
140+
/// [deprecated] Returns the memory space in old raw integer representation.
141+
/// New `Attribute getMemorySpace()` method should be used instead.
142+
unsigned getMemorySpaceAsInt() const;
143+
144+
/// Allow implicit conversion to ShapedType.
145+
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
146+
};
147+
46148
} // namespace mlir
47149

48150
//===----------------------------------------------------------------------===//
@@ -53,6 +155,7 @@ class ValueSemantics
53155
#include "mlir/IR/BuiltinTypes.h.inc"
54156

55157
namespace mlir {
158+
#include "mlir/IR/BuiltinTypeConstraints.h.inc"
56159

57160
//===----------------------------------------------------------------------===//
58161
// MemRefType
@@ -250,7 +353,7 @@ enum class SliceVerificationResult {
250353
/// code.
251354
SliceVerificationResult isRankReducedType(ShapedType originalType,
252355
ShapedType candidateReducedType);
253-
356+
254357
//===----------------------------------------------------------------------===//
255358
// Convenience wrappers for VectorType
256359
//
@@ -287,44 +390,25 @@ class FixedVectorType : public VectorType {
287390
// Deferred Method Definitions
288391
//===----------------------------------------------------------------------===//
289392

393+
inline bool BaseMemRefType::classof(Type type) {
394+
return llvm::isa<MemRefType, UnrankedMemRefType>(type);
395+
}
396+
290397
inline bool BaseMemRefType::isValidElementType(Type type) {
291398
return type.isIntOrIndexOrFloat() ||
292399
llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
293400
type) ||
294401
llvm::isa<MemRefElementTypeInterface>(type);
295402
}
296403

404+
inline bool TensorType::classof(Type type) {
405+
return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
406+
}
407+
297408
//===----------------------------------------------------------------------===//
298409
// Type Utilities
299410
//===----------------------------------------------------------------------===//
300411

301-
/// Returns the strides of the MemRef if the layout map is in strided form.
302-
/// MemRefs with a layout map in strided form include:
303-
/// 1. empty or identity layout map, in which case the stride information is
304-
/// the canonical form computed from sizes;
305-
/// 2. a StridedLayoutAttr layout;
306-
/// 3. any other layout that be converted into a single affine map layout of
307-
/// the form `K + k0 * d0 + ... kn * dn`, where K and ki's are constants or
308-
/// symbols.
309-
///
310-
/// A stride specification is a list of integer values that are either static
311-
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
312-
/// the distance in the number of elements between successive entries along a
313-
/// particular dimension.
314-
LogicalResult getStridesAndOffset(MemRefType t,
315-
SmallVectorImpl<int64_t> &strides,
316-
int64_t &offset);
317-
318-
/// Wrapper around getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t>,
319-
/// int64_t) that will assert if the logical result is not succeeded.
320-
std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset(MemRefType t);
321-
322-
/// Return a version of `t` with identity layout if it can be determined
323-
/// statically that the layout is the canonical contiguous strided layout.
324-
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
325-
/// `t` with simplified layout.
326-
MemRefType canonicalizeStridedLayout(MemRefType t);
327-
328412
/// Given MemRef `sizes` that are either static or dynamic, returns the
329413
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
330414
/// once a dynamic dimension is encountered, all canonical strides become
@@ -347,24 +431,6 @@ AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
347431
/// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
348432
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
349433
MLIRContext *context);
350-
351-
/// Return "true" if the layout for `t` is compatible with strided semantics.
352-
bool isStrided(MemRefType t);
353-
354-
/// Return "true" if the last dimension of the given type has a static unit
355-
/// stride. Also return "true" for types with no strides.
356-
bool isLastMemrefDimUnitStride(MemRefType type);
357-
358-
/// Return "true" if the last N dimensions of the given type are contiguous.
359-
///
360-
/// Examples:
361-
/// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when
362-
/// considering both _all_ and _only_ the trailing 3 dims,
363-
/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when
364-
/// considering the trailing 3 dims.
365-
///
366-
bool trailingNDimsContiguous(MemRefType type, int64_t n);
367-
368434
} // namespace mlir
369435

370436
#endif // MLIR_IR_BUILTINTYPES_H

0 commit comments

Comments
 (0)