@@ -43,108 +43,6 @@ template <typename ConcreteType>
4343class ValueSemantics
4444 : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
4545
46- // ===----------------------------------------------------------------------===//
47- // TensorType
48- // ===----------------------------------------------------------------------===//
49-
50- // / Tensor types represent multi-dimensional arrays, and have two variants:
51- // / RankedTensorType and UnrankedTensorType.
52- // / Note: This class attaches the ShapedType trait to act as a mixin to
53- // / provide many useful utility functions. This inheritance has no effect
54- // / on derived tensor types.
55- class TensorType : public Type , public ShapedType ::Trait<TensorType> {
56- public:
57- using Type::Type;
58-
59- // / Returns the element type of this tensor type.
60- Type getElementType () const ;
61-
62- // / Returns if this type is ranked, i.e. it has a known number of dimensions.
63- bool hasRank () const ;
64-
65- // / Returns the shape of this tensor type.
66- ArrayRef<int64_t > getShape () const ;
67-
68- // / Clone this type with the given shape and element type. If the
69- // / provided shape is `std::nullopt`, the current shape of the type is used.
70- TensorType cloneWith (std::optional<ArrayRef<int64_t >> shape,
71- Type elementType) const ;
72-
73- // Make sure that base class overloads are visible.
74- using ShapedType::Trait<TensorType>::clone;
75-
76- // / Return a clone of this type with the given new shape and element type.
77- // / The returned type is ranked, even if this type is unranked.
78- RankedTensorType clone (ArrayRef<int64_t > shape, Type elementType) const ;
79-
80- // / Return a clone of this type with the given new shape. The returned type
81- // / is ranked, even if this type is unranked.
82- RankedTensorType clone (ArrayRef<int64_t > shape) const ;
83-
84- // / Return true if the specified element type is ok in a tensor.
85- static bool isValidElementType (Type type);
86-
87- // / Methods for support type inquiry through isa, cast, and dyn_cast.
88- static bool classof (Type type);
89-
90- // / Allow implicit conversion to ShapedType.
91- operator ShapedType () const { return llvm::cast<ShapedType>(*this ); }
92- };
93-
94- // ===----------------------------------------------------------------------===//
95- // BaseMemRefType
96- // ===----------------------------------------------------------------------===//
97-
98- // / This class provides a shared interface for ranked and unranked memref types.
99- // / Note: This class attaches the ShapedType trait to act as a mixin to
100- // / provide many useful utility functions. This inheritance has no effect
101- // / on derived memref types.
102- class BaseMemRefType : public Type , public ShapedType ::Trait<BaseMemRefType> {
103- public:
104- using Type::Type;
105-
106- // / Returns the element type of this memref type.
107- Type getElementType () const ;
108-
109- // / Returns if this type is ranked, i.e. it has a known number of dimensions.
110- bool hasRank () const ;
111-
112- // / Returns the shape of this memref type.
113- ArrayRef<int64_t > getShape () const ;
114-
115- // / Clone this type with the given shape and element type. If the
116- // / provided shape is `std::nullopt`, the current shape of the type is used.
117- BaseMemRefType cloneWith (std::optional<ArrayRef<int64_t >> shape,
118- Type elementType) const ;
119-
120- // Make sure that base class overloads are visible.
121- using ShapedType::Trait<BaseMemRefType>::clone;
122-
123- // / Return a clone of this type with the given new shape and element type.
124- // / The returned type is ranked, even if this type is unranked.
125- MemRefType clone (ArrayRef<int64_t > shape, Type elementType) const ;
126-
127- // / Return a clone of this type with the given new shape. The returned type
128- // / is ranked, even if this type is unranked.
129- MemRefType clone (ArrayRef<int64_t > shape) const ;
130-
131- // / Return true if the specified element type is ok in a memref.
132- static bool isValidElementType (Type type);
133-
134- // / Methods for support type inquiry through isa, cast, and dyn_cast.
135- static bool classof (Type type);
136-
137- // / Returns the memory space in which data referred to by this memref resides.
138- Attribute getMemorySpace () const ;
139-
140- // / [deprecated] Returns the memory space in old raw integer representation.
141- // / New `Attribute getMemorySpace()` method should be used instead.
142- unsigned getMemorySpaceAsInt () const ;
143-
144- // / Allow implicit conversion to ShapedType.
145- operator ShapedType () const { return llvm::cast<ShapedType>(*this ); }
146- };
147-
14846} // namespace mlir
14947
15048// ===----------------------------------------------------------------------===//
@@ -155,7 +53,6 @@ class BaseMemRefType : public Type, public ShapedType::Trait<BaseMemRefType> {
15553#include " mlir/IR/BuiltinTypes.h.inc"
15654
15755namespace mlir {
158- #include " mlir/IR/BuiltinTypeConstraints.h.inc"
15956
16057// ===----------------------------------------------------------------------===//
16158// MemRefType
@@ -353,7 +250,7 @@ enum class SliceVerificationResult {
353250// / code.
354251SliceVerificationResult isRankReducedType (ShapedType originalType,
355252 ShapedType candidateReducedType);
356-
253+
357254// ===----------------------------------------------------------------------===//
358255// Convenience wrappers for VectorType
359256//
@@ -390,25 +287,44 @@ class FixedVectorType : public VectorType {
390287// Deferred Method Definitions
391288// ===----------------------------------------------------------------------===//
392289
393- inline bool BaseMemRefType::classof (Type type) {
394- return llvm::isa<MemRefType, UnrankedMemRefType>(type);
395- }
396-
397290inline bool BaseMemRefType::isValidElementType (Type type) {
398291 return type.isIntOrIndexOrFloat () ||
399292 llvm::isa<ComplexType, MemRefType, VectorType, UnrankedMemRefType>(
400293 type) ||
401294 llvm::isa<MemRefElementTypeInterface>(type);
402295}
403296
404- inline bool TensorType::classof (Type type) {
405- return llvm::isa<RankedTensorType, UnrankedTensorType>(type);
406- }
407-
408297// ===----------------------------------------------------------------------===//
409298// Type Utilities
410299// ===----------------------------------------------------------------------===//
411300
301+ // / Returns the strides of the MemRef if the layout map is in strided form.
302+ // / MemRefs with a layout map in strided form include:
303+ // / 1. empty or identity layout map, in which case the stride information is
304+ // / the canonical form computed from sizes;
305+ // / 2. a StridedLayoutAttr layout;
306+ // / 3. any other layout that be converted into a single affine map layout of
307+ // / the form `K + k0 * d0 + ... kn * dn`, where K and ki's are constants or
308+ // / symbols.
309+ // /
310+ // / A stride specification is a list of integer values that are either static
311+ // / or dynamic (encoded with ShapedType::kDynamic). Strides encode
312+ // / the distance in the number of elements between successive entries along a
313+ // / particular dimension.
314+ LogicalResult getStridesAndOffset (MemRefType t,
315+ SmallVectorImpl<int64_t > &strides,
316+ int64_t &offset);
317+
318+ // / Wrapper around getStridesAndOffset(MemRefType, SmallVectorImpl<int64_t>,
319+ // / int64_t) that will assert if the logical result is not succeeded.
320+ std::pair<SmallVector<int64_t >, int64_t > getStridesAndOffset (MemRefType t);
321+
322+ // / Return a version of `t` with identity layout if it can be determined
323+ // / statically that the layout is the canonical contiguous strided layout.
324+ // / Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
325+ // / `t` with simplified layout.
326+ MemRefType canonicalizeStridedLayout (MemRefType t);
327+
412328// / Given MemRef `sizes` that are either static or dynamic, returns the
413329// / canonical "contiguous" strides AffineExpr. Strides are multiplicative and
414330// / once a dynamic dimension is encountered, all canonical strides become
@@ -431,6 +347,24 @@ AffineExpr makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
431347// / where `exprs` is {d0, d1, .., d_(sizes.size()-1)}
432348AffineExpr makeCanonicalStridedLayoutExpr (ArrayRef<int64_t > sizes,
433349 MLIRContext *context);
350+
351+ // / Return "true" if the layout for `t` is compatible with strided semantics.
352+ bool isStrided (MemRefType t);
353+
354+ // / Return "true" if the last dimension of the given type has a static unit
355+ // / stride. Also return "true" for types with no strides.
356+ bool isLastMemrefDimUnitStride (MemRefType type);
357+
358+ // / Return "true" if the last N dimensions of the given type are contiguous.
359+ // /
360+ // / Examples:
361+ // / - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when
362+ // / considering both _all_ and _only_ the trailing 3 dims,
363+ // / - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when
364+ // / considering the trailing 3 dims.
365+ // /
366+ bool trailingNDimsContiguous (MemRefType type, int64_t n);
367+
434368} // namespace mlir
435369
436370#endif // MLIR_IR_BUILTINTYPES_H
0 commit comments