@@ -43,6 +43,108 @@ template <typename ConcreteType>
4343class ValueSemantics
4444 : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
4545
46+ // ===----------------------------------------------------------------------===//
47+ // TensorType
48+ // ===----------------------------------------------------------------------===//
49+
50+ // / Tensor types represent multi-dimensional arrays, and have two variants:
51+ // / RankedTensorType and UnrankedTensorType.
52+ // / Note: This class attaches the ShapedType trait to act as a mixin to
53+ // / provide many useful utility functions. This inheritance has no effect
54+ // / on derived tensor types.
55+ class TensorType : public Type , public ShapedType ::Trait<TensorType> {
56+ public:
57+ using Type::Type;
58+
59+ // / Returns the element type of this tensor type.
60+ Type getElementType () const ;
61+
62+ // / Returns if this type is ranked, i.e. it has a known number of dimensions.
63+ bool hasRank () const ;
64+
65+ // / Returns the shape of this tensor type.
66+ ArrayRef<int64_t > getShape () const ;
67+
68+ // / Clone this type with the given shape and element type. If the
69+ // / provided shape is `std::nullopt`, the current shape of the type is used.
70+ TensorType cloneWith (std::optional<ArrayRef<int64_t >> shape,
71+ Type elementType) const ;
72+
73+ // Make sure that base class overloads are visible.
74+ using ShapedType::Trait<TensorType>::clone;
75+
76+ // / Return a clone of this type with the given new shape and element type.
77+ // / The returned type is ranked, even if this type is unranked.
78+ RankedTensorType clone (ArrayRef<int64_t > shape, Type elementType) const ;
79+
80+ // / Return a clone of this type with the given new shape. The returned type
81+ // / is ranked, even if this type is unranked.
82+ RankedTensorType clone (ArrayRef<int64_t > shape) const ;
83+
84+ // / Return true if the specified element type is ok in a tensor.
85+ static bool isValidElementType (Type type);
86+
87+ // / Methods for support type inquiry through isa, cast, and dyn_cast.
88+ static bool classof (Type type);
89+
90+ // / Allow implicit conversion to ShapedType.
91+ operator ShapedType () const { return llvm::cast<ShapedType>(*this ); }
92+ };
93+
94+ // ===----------------------------------------------------------------------===//
95+ // BaseMemRefType
96+ // ===----------------------------------------------------------------------===//
97+
98+ // / This class provides a shared interface for ranked and unranked memref types.
99+ // / Note: This class attaches the ShapedType trait to act as a mixin to
100+ // / provide many useful utility functions. This inheritance has no effect
101+ // / on derived memref types.
102+ class BaseMemRefType : public Type , public ShapedType ::Trait<BaseMemRefType> {
103+ public:
104+ using Type::Type;
105+
106+ // / Returns the element type of this memref type.
107+ Type getElementType () const ;
108+
109+ // / Returns if this type is ranked, i.e. it has a known number of dimensions.
110+ bool hasRank () const ;
111+
112+ // / Returns the shape of this memref type.
113+ ArrayRef<int64_t > getShape () const ;
114+
115+ // / Clone this type with the given shape and element type. If the
116+ // / provided shape is `std::nullopt`, the current shape of the type is used.
117+ BaseMemRefType cloneWith (std::optional<ArrayRef<int64_t >> shape,
118+ Type elementType) const ;
119+
120+ // Make sure that base class overloads are visible.
121+ using ShapedType::Trait<BaseMemRefType>::clone;
122+
123+ // / Return a clone of this type with the given new shape and element type.
124+ // / The returned type is ranked, even if this type is unranked.
125+ MemRefType clone (ArrayRef<int64_t > shape, Type elementType) const ;
126+
127+ // / Return a clone of this type with the given new shape. The returned type
128+ // / is ranked, even if this type is unranked.
129+ MemRefType clone (ArrayRef<int64_t > shape) const ;
130+
131+ // / Return true if the specified element type is ok in a memref.
132+ static bool isValidElementType (Type type);
133+
134+ // / Methods for support type inquiry through isa, cast, and dyn_cast.
135+ static bool classof (Type type);
136+
137+ // / Returns the memory space in which data referred to by this memref resides.
138+ Attribute getMemorySpace () const ;
139+
140+ // / [deprecated] Returns the memory space in old raw integer representation.
141+ // / New `Attribute getMemorySpace()` method should be used instead.
142+ unsigned getMemorySpaceAsInt () const ;
143+
144+ // / Allow implicit conversion to ShapedType.
145+ operator ShapedType () const { return llvm::cast<ShapedType>(*this ); }
146+ };
147+
46148} // namespace mlir
47149
48150// ===----------------------------------------------------------------------===//
@@ -53,6 +155,7 @@ class ValueSemantics
53155#include " mlir/IR/BuiltinTypes.h.inc"
54156
55157namespace mlir {
158+ #include " mlir/IR/BuiltinTypeConstraints.h.inc"
56159
57160// ===----------------------------------------------------------------------===//
58161// MemRefType
@@ -250,7 +353,7 @@ enum class SliceVerificationResult {
250353// / code.
251354SliceVerificationResult isRankReducedType (ShapedType originalType,
252355 ShapedType candidateReducedType);
253-
356+
254357// ===----------------------------------------------------------------------===//
255358// Convenience wrappers for VectorType
256359//
@@ -287,44 +390,25 @@ class FixedVectorType : public VectorType {
287390// Deferred Method Definitions
288391// ===----------------------------------------------------------------------===//
289392
393+ inline bool BaseMemRefType::classof (Type type) {
394+ return llvm::isa<MemRefType, UnrankedMemRefType>(type);
395+ }
396+
290397inline bool BaseMemRefType::isValidElementType (Type type) {
291398 return type.isIntOrIndexOrFloat () ||
292399 llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
293400 type) ||
294401 llvm::isa<MemRefElementTypeInterface>(type);
295402}
296403
404+ inline bool TensorType::classof (Type type) {
405+ return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
406+ }
407+
297408// ===----------------------------------------------------------------------===//
298409// Type Utilities
299410// ===----------------------------------------------------------------------===//
300411
301- // / Returns the strides of the MemRef if the layout map is in strided form.
302- // / MemRefs with a layout map in strided form include:
303- // / 1. empty or identity layout map, in which case the stride information is
304- // / the canonical form computed from sizes;
305- // / 2. a StridedLayoutAttr layout;
306- // / 3. any other layout that be converted into a single affine map layout of
307- // / the form `K + k0 * d0 + ... kn * dn`, where K and ki's are constants or
308- // / symbols.
309- // /
310- // / A stride specification is a list of integer values that are either static
311- // / or dynamic (encoded with ShapedType::kDynamic). Strides encode
312- // / the distance in the number of elements between successive entries along a
313- // / particular dimension.
314- LogicalResult getStridesAndOffset (MemRefType t,
315- SmallVectorImpl<int64_t > &strides,
316- int64_t &offset);
317-
318- // / Wrapper around getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t>,
319- // / int64_t) that will assert if the logical result is not succeeded.
320- std::pair<SmallVector<int64_t >, int64_t > getStridesAndOffset (MemRefType t);
321-
322- // / Return a version of `t` with identity layout if it can be determined
323- // / statically that the layout is the canonical contiguous strided layout.
324- // / Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
325- // / `t` with simplified layout.
326- MemRefType canonicalizeStridedLayout (MemRefType t);
327-
328412// / Given MemRef `sizes` that are either static or dynamic, returns the
329413// / canonical "contiguous" strides AffineExpr. Strides are multiplicative and
330414// / once a dynamic dimension is encountered, all canonical strides become
@@ -347,24 +431,6 @@ AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
347431// / where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
348432AffineExpr makeCanonicalStridedLayoutExpr (ArrayRef<int64_t > sizes,
349433 MLIRContext *context);
350-
351- // / Return "true" if the layout for `t` is compatible with strided semantics.
352- bool isStrided (MemRefType t);
353-
354- // / Return "true" if the last dimension of the given type has a static unit
355- // / stride. Also return "true" for types with no strides.
356- bool isLastMemrefDimUnitStride (MemRefType type);
357-
358- // / Return "true" if the last N dimensions of the given type are contiguous.
359- // /
360- // / Examples:
361- // / - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when
362- // / considering both _all_ and _only_ the trailing 3 dims,
363- // / - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when
364- // / considering the trailing 3 dims.
365- // /
366- bool trailingNDimsContiguous (MemRefType type, int64_t n);
367-
368434} // namespace mlir
369435
370436#endif // MLIR_IR_BUILTINTYPES_H
0 commit comments