Skip to content

Commit 51f2bc8

Browse files
xin-zhang-intelsramasit
authored andcommitted
PR #38: Replace BaseMemRef/TensorType class with TypeInterface
1 parent 92f43ec commit 51f2bc8

File tree

9 files changed

+208
-259
lines changed

9 files changed

+208
-259
lines changed

mlir/include/mlir/IR/BuiltinTypeInterfaces.td

Lines changed: 70 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -143,24 +143,20 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
143143

144144
/// Return the number of elements present in the given shape.
145145
static int64_t getNumElements(ArrayRef<int64_t> shape);
146+
}];
146147

148+
let extraSharedClassDeclaration = [{
147149
/// Return a clone of this type with the given new shape and element type.
148-
/// The returned type is ranked, even if this type is unranked.
149150
auto clone(::llvm::ArrayRef<int64_t> shape, Type elementType) {
150-
return cloneWith(shape, elementType);
151+
return $_type.cloneWith(shape, elementType);
151152
}
152153

153-
/// Return a clone of this type with the given new shape. The returned type
154-
/// is ranked, even if this type is unranked.
154+
/// Return a clone of this type with the given new shape.
155155
auto clone(::llvm::ArrayRef<int64_t> shape) {
156-
return cloneWith(shape, getElementType());
156+
return $_type.cloneWith(shape, $_type.getElementType());
157157
}
158-
}];
159158

160-
let extraSharedClassDeclaration = [{
161-
/// Return a clone of this type with the given new element type. The
162-
/// returned type is ranked if and only if this type is ranked. In that
163-
/// case, the returned type has the same shape as this type.
159+
/// Return a clone of this type with the given new element type.
164160
auto clone(::mlir::Type elementType) {
165161
return $_type.cloneWith(/*shape=*/std::nullopt, elementType);
166162
}
@@ -227,4 +223,68 @@ def ShapedTypeInterface : TypeInterface<"ShapedType"> {
227223
}];
228224
}
229225

226+
//===----------------------------------------------------------------------===//
227+
// TensorTypeInterface
228+
//===----------------------------------------------------------------------===//
229+
230+
def TensorTypeInterface : TypeInterface<"TensorType", [
231+
ShapedTypeInterface
232+
]
233+
> {
234+
let cppNamespace = "::mlir";
235+
let description = [{
236+
This interface provides a shared interface for ranked and unranked type
237+
and customized tensor types.
238+
239+
This class attaches the ShapedTypeInterface to act as a mixin to
240+
provide many useful utility functions.
241+
}];
242+
243+
let extraClassDeclaration = [{
244+
// Return true if the specified element type is ok in a tensor.
245+
static bool isValidElementType(::mlir::Type type);
246+
}];
247+
248+
let extraClassOf = [{
249+
return $_type.hasTrait<::mlir::TensorType::Trait>();
250+
}];
251+
252+
}
253+
254+
//===----------------------------------------------------------------------===//
255+
// BaseMemRefTypeInterface
256+
//===----------------------------------------------------------------------===//
257+
258+
def BaseMemRefTypeInterface : TypeInterface<"BaseMemRefType", [
259+
ShapedTypeInterface
260+
]
261+
> {
262+
let cppNamespace = "::mlir";
263+
let description = [{
264+
This interface provides a shared interface for ranked and unranked memref and
265+
customized memref types.
266+
267+
This interface attaches the ShapedTypeInterface to act as a mixin to
268+
provide many useful utility functions.
269+
}];
270+
271+
let methods = [
272+
InterfaceMethod<[{
273+
Returns the memory space in which data referred to by this memref resides.
274+
}],
275+
"::mlir::Attribute", "getMemorySpace">,
276+
];
277+
278+
let extraClassDeclaration = [{
279+
// Return true if the specified element type is ok in a memref.
280+
static bool isValidElementType(::mlir::Type type);
281+
282+
unsigned getMemorySpaceAsInt() const;
283+
}];
284+
285+
let extraClassOf = [{
286+
return $_type.hasTrait<::mlir::BaseMemRefType::Trait>();
287+
}];
288+
}
289+
230290
#endif // MLIR_IR_BUILTINTYPEINTERFACES_TD_

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 46 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -43,108 +43,6 @@ template <typename ConcreteType>
4343
class ValueSemantics
4444
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
4545

46-
//===----------------------------------------------------------------------===//
47-
// TensorType
48-
//===----------------------------------------------------------------------===//
49-
50-
/// Tensor types represent multi-dimensional arrays, and have two variants:
51-
/// RankedTensorType and UnrankedTensorType.
52-
/// Note: This class attaches the ShapedType trait to act as a mixin to
53-
/// provide many useful utility functions. This inheritance has no effect
54-
/// on derived tensor types.
55-
class TensorType : public Type, public ShapedType::Trait<TensorType> {
56-
public:
57-
using Type::Type;
58-
59-
/// Returns the element type of this tensor type.
60-
Type getElementType() const;
61-
62-
/// Returns if this type is ranked, i.e. it has a known number of dimensions.
63-
bool hasRank() const;
64-
65-
/// Returns the shape of this tensor type.
66-
ArrayRef<int64_t> getShape() const;
67-
68-
/// Clone this type with the given shape and element type. If the
69-
/// provided shape is `std::nullopt`, the current shape of the type is used.
70-
TensorType cloneWith(std::optional<ArrayRef<int64_t>> shape,
71-
Type elementType) const;
72-
73-
// Make sure that base class overloads are visible.
74-
using ShapedType::Trait<TensorType>::clone;
75-
76-
/// Return a clone of this type with the given new shape and element type.
77-
/// The returned type is ranked, even if this type is unranked.
78-
RankedTensorType clone(ArrayRef<int64_t> shape, Type elementType) const;
79-
80-
/// Return a clone of this type with the given new shape. The returned type
81-
/// is ranked, even if this type is unranked.
82-
RankedTensorType clone(ArrayRef<int64_t> shape) const;
83-
84-
/// Return true if the specified element type is ok in a tensor.
85-
static bool isValidElementType(Type type);
86-
87-
/// Methods for support type inquiry through isa, cast, and dyn_cast.
88-
static bool classof(Type type);
89-
90-
/// Allow implicit conversion to ShapedType.
91-
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
92-
};
93-
94-
//===----------------------------------------------------------------------===//
95-
// BaseMemRefType
96-
//===----------------------------------------------------------------------===//
97-
98-
/// This class provides a shared interface for ranked and unranked memref types.
99-
/// Note: This class attaches the ShapedType trait to act as a mixin to
100-
/// provide many useful utility functions. This inheritance has no effect
101-
/// on derived memref types.
102-
class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
103-
public:
104-
using Type::Type;
105-
106-
/// Returns the element type of this memref type.
107-
Type getElementType() const;
108-
109-
/// Returns if this type is ranked, i.e. it has a known number of dimensions.
110-
bool hasRank() const;
111-
112-
/// Returns the shape of this memref type.
113-
ArrayRef<int64_t> getShape() const;
114-
115-
/// Clone this type with the given shape and element type. If the
116-
/// provided shape is `std::nullopt`, the current shape of the type is used.
117-
BaseMemRefType cloneWith(std::optional<ArrayRef<int64_t>> shape,
118-
Type elementType) const;
119-
120-
// Make sure that base class overloads are visible.
121-
using ShapedType::Trait<BaseMemRefType>::clone;
122-
123-
/// Return a clone of this type with the given new shape and element type.
124-
/// The returned type is ranked, even if this type is unranked.
125-
MemRefType clone(ArrayRef<int64_t> shape, Type elementType) const;
126-
127-
/// Return a clone of this type with the given new shape. The returned type
128-
/// is ranked, even if this type is unranked.
129-
MemRefType clone(ArrayRef<int64_t> shape) const;
130-
131-
/// Return true if the specified element type is ok in a memref.
132-
static bool isValidElementType(Type type);
133-
134-
/// Methods for support type inquiry through isa, cast, and dyn_cast.
135-
static bool classof(Type type);
136-
137-
/// Returns the memory space in which data referred to by this memref resides.
138-
Attribute getMemorySpace() const;
139-
140-
/// [deprecated] Returns the memory space in old raw integer representation.
141-
/// New `Attribute getMemorySpace()` method should be used instead.
142-
unsigned getMemorySpaceAsInt() const;
143-
144-
/// Allow implicit conversion to ShapedType.
145-
operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
146-
};
147-
14846
} // namespace mlir
14947

15048
//===----------------------------------------------------------------------===//
@@ -155,7 +53,6 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
15553
#include "mlir/IR/BuiltinTypes.h.inc"
15654

15755
namespace mlir {
158-
#include "mlir/IR/BuiltinTypeConstraints.h.inc"
15956

16057
//===----------------------------------------------------------------------===//
16158
// MemRefType
@@ -353,7 +250,7 @@ enum class SliceVerificationResult {
353250
/// code.
354251
SliceVerificationResult isRankReducedType(ShapedType originalType,
355252
ShapedType candidateReducedType);
356-
253+
357254
//===----------------------------------------------------------------------===//
358255
// Convenience wrappers for VectorType
359256
//
@@ -390,25 +287,44 @@ class FixedVectorType : public VectorType {
390287
// Deferred Method Definitions
391288
//===----------------------------------------------------------------------===//
392289

393-
inline bool BaseMemRefType::classof(Type type) {
394-
return llvm::isa<MemRefType, UnrankedMemRefType>(type);
395-
}
396-
397290
inline bool BaseMemRefType::isValidElementType(Type type) {
398291
return type.isIntOrIndexOrFloat() ||
399292
llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
400293
type) ||
401294
llvm::isa<MemRefElementTypeInterface>(type);
402295
}
403296

404-
inline bool TensorType::classof(Type type) {
405-
return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
406-
}
407-
408297
//===----------------------------------------------------------------------===//
409298
// Type Utilities
410299
//===----------------------------------------------------------------------===//
411300

301+
/// Returns the strides of the MemRef if the layout map is in strided form.
302+
/// MemRefs with a layout map in strided form include:
303+
/// 1. empty or identity layout map, in which case the stride information is
304+
/// the canonical form computed from sizes;
305+
/// 2. a StridedLayoutAttr layout;
306+
/// 3. any other layout that be converted into a single affine map layout of
307+
/// the form `K + k0 * d0 + ... kn * dn`, where K and ki's are constants or
308+
/// symbols.
309+
///
310+
/// A stride specification is a list of integer values that are either static
311+
/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
312+
/// the distance in the number of elements between successive entries along a
313+
/// particular dimension.
314+
LogicalResult getStridesAndOffset(MemRefType t,
315+
SmallVectorImpl<int64_t> &strides,
316+
int64_t &offset);
317+
318+
/// Wrapper around getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t>,
319+
/// int64_t) that will assert if the logical result is not succeeded.
320+
std::pair<SmallVector<int64_t>, int64_t> getStridesAndOffset(MemRefType t);
321+
322+
/// Return a version of `t` with identity layout if it can be determined
323+
/// statically that the layout is the canonical contiguous strided layout.
324+
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
325+
/// `t` with simplified layout.
326+
MemRefType canonicalizeStridedLayout(MemRefType t);
327+
412328
/// Given MemRef `sizes` that are either static or dynamic, returns the
413329
/// canonical "contiguous" strides AffineExpr. Strides are multiplicative and
414330
/// once a dynamic dimension is encountered, all canonical strides become
@@ -431,6 +347,24 @@ AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
431347
/// where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
432348
AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
433349
MLIRContext *context);
350+
351+
/// Return "true" if the layout for `t` is compatible with strided semantics.
352+
bool isStrided(MemRefType t);
353+
354+
/// Return "true" if the last dimension of the given type has a static unit
355+
/// stride. Also return "true" for types with no strides.
356+
bool isLastMemrefDimUnitStride(MemRefType type);
357+
358+
/// Return "true" if the last N dimensions of the given type are contiguous.
359+
///
360+
/// Examples:
361+
/// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when
362+
/// considering both _all_ and _only_ the trailing 3 dims,
363+
/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when
364+
/// considering the trailing 3 dims.
365+
///
366+
bool trailingNDimsContiguous(MemRefType type, int64_t n);
367+
434368
} // namespace mlir
435369

436370
#endif // MLIR_IR_BUILTINTYPES_H

0 commit comments

Comments
 (0)