
Commit 5dcb849

Author: Peiming Liu
Commit message: address comments.
Parent: f2f55d4

File tree: 5 files changed (+38, -38 lines)


mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 3 additions & 3 deletions
@@ -1150,9 +1150,9 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
 
     Operation &last = rewriter.getBlock()->back();
     if (llvm::isa<scf::YieldOp>(last)) {
-      // scf.for inserts a implicit yield op when there is no reduction
-      // variable upon creation, in this case we need to merge the block
-      // *before* the yield op.
+      // Because `scf.for` inserts an implicit yield op when there is no
+      // reduction variable upon creation, we reset the insertion point such
+      // that the block is inlined *before* the yield op.
       rewriter.setInsertionPoint(&last);
     }
 
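
For readers outside the sparse compiler, a minimal sketch of the idiom this comment describes (not part of the patch; `lo`, `hi`, `step`, and `emitBody` are hypothetical names): when an `scf.for` is created without iteration arguments, the builder already appends an implicit `scf.yield` terminator, so any code generated afterwards has to land in front of that terminator.

  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step);
  Block *body = forOp.getBody();
  Operation &last = body->back();
  if (llvm::isa<scf::YieldOp>(last)) {
    // The builder already appended an implicit scf.yield, so generate the
    // inlined block in front of that terminator rather than after it.
    rewriter.setInsertionPoint(&last);
  } else {
    rewriter.setInsertionPointToEnd(body);
  }
  emitBody(rewriter, loc); // hypothetical helper that emits the inlined block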

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 6 additions & 5 deletions
@@ -1032,9 +1032,9 @@ static bool getAllTidLvlsInLatPoints(
 
   if (isDenseLT(env.lt(outTid, curr))) {
     auto stt = getSparseTensorType(env.op().getOutputs().front());
-    // Note that we generate dense indices of the output tensor
-    // unconditionally, since they may not appear in the lattice, but may be
-    // needed for linearized env.
+    // Note that we generate dense indices of the output tensor unconditionally,
+    // since they may not appear in the lattice, but may be needed for
+    // linearized env.
     // TODO: we should avoid introducing corner cases for all-dense sparse
     // tensors.
     if (stt.hasEncoding() && stt.isAllDense())
@@ -1067,8 +1067,9 @@ static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
 
   SmallVector<TensorLevel> tidLvls;
   getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
-    // TODO: remove this! Duplication can be introduced due to the speical
-    // handling for all-dense "sparse" output tensor.
+    // TODO: remove this! The same tensor level might be added multiple
+    // times due to the special handling for all-dense "sparse" output tensor
+    // (see L1038).
     if (llvm::find(tidLvls, tl) != tidLvls.end())
       return;
     tidLvls.emplace_back(tl);
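
The deduplication guard above is the usual find-before-insert pattern; a self-contained scalar sketch (plain C++; `TensorLevel` stood in by an integer id and `addTidLvl` a hypothetical name, not the MLIR types):

  #include <algorithm>
  #include <cstdint>
  #include <vector>

  using TensorLevel = uint64_t; // stand-in for the packed (tid, lvl) id

  // Record a tensor level only the first time it is reported; duplicates can
  // show up because the all-dense output level is emitted unconditionally.
  void addTidLvl(std::vector<TensorLevel> &tidLvls, TensorLevel tl) {
    if (std::find(tidLvls.begin(), tidLvls.end(), tl) != tidLvls.end())
      return;
    tidLvls.push_back(tl);
  }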

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.h

Lines changed: 1 addition & 1 deletion
@@ -408,7 +408,7 @@ class LoopEmitter {
   /// alive.
   std::vector<LoopInfo> loopStack;
 
-  // Loop Sequence Stack, stores the unversial index for the current loop
+  // Loop Sequence Stack, stores the universal index for the current loop
   // sequence. and a list of tid level that the loop sequence traverse.
   std::vector<std::pair<Value, std::vector<TensorLevel>>> loopSeqStack;
 };
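
A hedged sketch of how such a stack is usually maintained (`enterLoopSeq` and `exitLoopSeq` are illustrative names, not necessarily the LoopEmitter API): each loop sequence pushes its universal index together with the tensor levels it traverses, and the entry is popped when the sequence ends.

  // Illustrative only; Value, TensorLevel and loopSeqStack are the members above.
  void enterLoopSeq(Value universalIndex, ArrayRef<TensorLevel> tidLvls) {
    loopSeqStack.emplace_back(
        universalIndex, std::vector<TensorLevel>(tidLvls.begin(), tidLvls.end()));
  }
  void exitLoopSeq() {
    assert(!loopSeqStack.empty());
    loopSeqStack.pop_back();
  }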

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.cpp

Lines changed: 6 additions & 18 deletions
@@ -164,17 +164,6 @@ static scf::ValueVector genWhenInBound(
     OpBuilder &b, Location l, SparseIterator &it, ValueRange elseRet,
     llvm::function_ref<scf::ValueVector(OpBuilder &, Location, Value)>
         builder) {
-  // Value isNotEnd = it.genNotEnd(b, l);
-  // Value crd = it.deref(b, l);
-  // scf::ValueVector ret = builder(b, l, crd);
-
-  // scf::ValueVector res;
-  // for (auto [notEnd, end] : llvm::zip_equal(ret, elseRet)) {
-  //   res.push_back(SELECT(isNotEnd, notEnd, end));
-  // };
-  // return res;
-
-  // !it.end() ? callback(*crd) : resOOB;
   TypeRange ifRetTypes = elseRet.getTypes();
   auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, it.genNotEnd(b, l), true);
 
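
The surviving body builds the `!it.end() ? builder(*crd) : elseRet` selection as an scf.if with results; a hedged sketch of that pattern (insertion-point handling simplified, `deref` taken from the deleted comments):

  TypeRange ifRetTypes = elseRet.getTypes();
  auto ifOp = b.create<scf::IfOp>(l, ifRetTypes, it.genNotEnd(b, l),
                                  /*withElseRegion=*/true);

  // Then region: still in bounds, dereference the coordinate and let the
  // caller produce the in-bound values.
  b.setInsertionPointToStart(ifOp.thenBlock());
  Value crd = it.deref(b, l);
  b.create<scf::YieldOp>(l, builder(b, l, crd));

  // Else region: out of bounds, simply forward the fallback values.
  b.setInsertionPointToStart(ifOp.elseBlock());
  b.create<scf::YieldOp>(l, elseRet);

  b.setInsertionPointAfter(ifOp);
  // ifOp.getResults() now holds the merged values.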

@@ -204,7 +193,7 @@ static scf::ValueVector genWhenInBound(
 static Value offsetFromMinCrd(OpBuilder &b, Location l, Value minCrd,
                               Value size) {
   Value geSize = CMPI(uge, minCrd, size);
-  // Computes minCrd - size + 1
+  // Compute minCrd - size + 1.
   Value mms = SUBI(ADDI(minCrd, C_IDX(1)), size);
   // This is the absolute offset related to the actual tensor.
   return SELECT(geSize, mms, C_IDX(0));
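
The select above clamps the offset at zero, i.e. offset = (minCrd >= size) ? minCrd + 1 - size : 0. A standalone scalar version with two worked cases (illustrative only; the real code emits arith ops on SSA values):

  #include <cstdint>

  constexpr uint64_t offsetFromMinCrd(uint64_t minCrd, uint64_t size) {
    // SELECT(minCrd >= size, (minCrd + 1) - size, 0)
    return minCrd >= size ? minCrd + 1 - size : 0;
  }
  static_assert(offsetFromMinCrd(7, 5) == 3, "minCrd past the window");
  static_assert(offsetFromMinCrd(3, 5) == 0, "minCrd inside the window");
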
@@ -627,7 +616,7 @@ class NonEmptySubSectIterator : public SparseIterator {
 
 class SubSectIterator;
 
-// A simple helper that helps generating code to traverse a subsection, used
+// A wrapper that helps generate code to traverse a subsection, used
 // by both `NonEmptySubSectIterator`and `SubSectIterator`.
 struct SubSectIterHelper {
   explicit SubSectIterHelper(const SubSectIterator &iter);
@@ -778,7 +767,7 @@ class SubSectIterator : public SparseIterator {
 } // namespace
 
 //===----------------------------------------------------------------------===//
-// Complex SparseIterator derived classes impl.
+// SparseIterator derived classes implementation.
 //===----------------------------------------------------------------------===//
 
 ValueRange SparseIterator::forwardIf(OpBuilder &b, Location l, Value cond) {
@@ -819,7 +808,6 @@ Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
       },
       /*afterBuilder=*/
       [](OpBuilder &b, Location l, ValueRange ivs) {
-        // pos ++
         Value nxPos = ADDI(ivs[0], C_IDX(1));
         YIELD(nxPos);
       });
@@ -830,11 +818,11 @@ Value DedupIterator::genSegmentHigh(OpBuilder &b, Location l, Value pos) {
 Value FilterIterator::genCrdNotLegitPredicate(OpBuilder &b, Location l,
                                               Value wrapCrd) {
   Value crd = fromWrapCrd(b, l, wrapCrd);
-  // not on stride
+  // Test whether the coordinate is on stride.
   Value notlegit = CMPI(ne, toWrapCrd(b, l, crd), wrapCrd);
-  // wrapCrd < offset
+  // Test wrapCrd < offset.
   notlegit = ORI(CMPI(ult, wrapCrd, offset), notlegit);
-  // crd >= length
+  // Test crd >= length.
   notlegit = ORI(CMPI(uge, crd, size), notlegit);
   return notlegit;
 }
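
For reference, a scalar model of the predicate assembled above, assuming fromWrapCrd computes crd = (wrapCrd - offset) / stride and toWrapCrd is its inverse (an assumption for illustration; the real code emits arith and select ops on SSA values):

  #include <cstdint>

  constexpr bool crdNotLegit(uint64_t wrapCrd, uint64_t offset, uint64_t stride,
                             uint64_t size) {
    uint64_t crd = (wrapCrd - offset) / stride;            // fromWrapCrd
    bool notOnStride = (crd * stride + offset) != wrapCrd; // toWrapCrd(crd) != wrapCrd
    bool beforeOffset = wrapCrd < offset;                  // wrapCrd < offset
    bool pastLength = crd >= size;                         // crd >= length
    return notOnStride || beforeOffset || pastLength;
  }
  static_assert(!crdNotLegit(7, 3, 2, 4)); // crd = 2: on stride and in range
  static_assert(crdNotLegit(8, 3, 2, 4));  // 8 is not on the stride-2 grid from 3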

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorLevel.h

Lines changed: 22 additions & 11 deletions
@@ -29,12 +29,12 @@ class SparseTensorLevel {
   /// the given position `p` that the immediate parent level is current at.
   /// Returns a pair of values for *posLo* and *loopHi* respectively.
   ///
-  /// For dense level, the *posLo* is the linearized position at beginning,
+  /// For a dense level, the *posLo* is the linearized position at beginning,
   /// while *loopHi* is the largest *coordinate*, it also implies that the
   /// smallest *coordinate* to start the loop is 0.
   ///
-  /// For sparse level, [posLo, loopHi) specifies the range of index pointer to
-  /// load coordinate from the coordinate buffer.
+  /// For a sparse level, [posLo, loopHi) specifies the range of index pointer
+  /// to load coordinate from the coordinate buffer.
   ///
   /// `bound` is only used when the level is `non-unique` and deduplication is
   /// required. It specifies the max upper bound of the non-unique segment.
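
A scalar model of the two documented cases may help (illustrative only; peekDenseRangeAt and peekSparseRangeAt are hypothetical names, and the real method returns IR Values): for a dense level of size n at parent position p the range starts at the linearized position p * n with coordinates in [0, n); for a compressed level it is the window of the positions buffer at p.

  #include <cstdint>
  #include <utility>
  #include <vector>

  // Dense level: {posLo, loopHi} = {linearized start, coordinate upper bound}.
  std::pair<uint64_t, uint64_t> peekDenseRangeAt(uint64_t p, uint64_t n) {
    return {p * n, n};
  }

  // Sparse (compressed) level: [posLo, posHi) indexes the coordinates buffer.
  std::pair<uint64_t, uint64_t>
  peekSparseRangeAt(const std::vector<uint64_t> &positions, uint64_t p) {
    return {positions[p], positions[p + 1]};
  }
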
@@ -68,7 +68,7 @@ enum class IterKind : uint8_t {
   kFilter,
 };
 
-/// Helper class that helps generating loop conditions, etc, to traverse a
+/// Helper class that generates loop conditions, etc., to traverse a
 /// sparse tensor level.
 class SparseIterator {
   SparseIterator(SparseIterator &&) = delete;
@@ -103,17 +103,18 @@ class SparseIterator {
   //
 
   // Whether the iterator support random access (i.e., support look up by
-  // *coordinate*).
-  // A random access iterator also traverses a dense space.
+  // *coordinate*). A random access iterator must also traverse a dense space.
   virtual bool randomAccessible() const = 0;
+
   // Whether the iterator can simply traversed by a for loop.
   virtual bool iteratableByFor() const { return false; };
+
   // Get the upper bound of the sparse space that the iterator might visited. A
   // sparse space is a subset of a dense space [0, bound), this function returns
   // *bound*.
   virtual Value upperBound(OpBuilder &b, Location l) const = 0;
 
-  // Serialize and deserialize the current status to/from a set of values. The
+  // Serializes and deserializes the current status to/from a set of values. The
   // ValueRange should contain values that specifies the current postion and
   // loop bound.
   //
@@ -131,7 +132,7 @@ class SparseIterator {
   // Core functions.
   //
 
-  // Get the current position and the optional *position high* (for non-unique
+  // Gets the current position and the optional *position high* (for non-unique
   // iterators), the value is essentially the number of sparse coordinate that
   // the iterator is current visiting. It should be able to uniquely identify
   // the sparse range for the next level. See SparseTensorLevel::peekRangeAt();
@@ -143,16 +144,17 @@ class SparseIterator {
     llvm_unreachable("unsupported");
   };
 
-  // Initialize the iterator according to the parent iterator's state.
+  // Initializes the iterator according to the parent iterator's state.
   virtual void genInit(OpBuilder &, Location, const SparseIterator *) = 0;
 
-  // Return a pair of values for *upper*, *lower* bound respectively.
+  // Returns a pair of values for *upper*, *lower* bound respectively.
   virtual std::pair<Value, Value> genForCond(OpBuilder &b, Location l) {
     assert(randomAccessible());
     // Random-access iterator is traversed by coordinate, i.e., [curCrd, UB).
     return {getCrd(), upperBound(b, l)};
   }
 
+  // Returns a boolean value that equals `!it.end()`.
   virtual Value genNotEnd(OpBuilder &b, Location l) = 0;
   std::pair<Value, ValueRange> genWhileCond(OpBuilder &b, Location l,
                                             ValueRange vs) {
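
A hedged usage sketch of the interface above (not code from this patch; emitForOverRandomAccessIt and emitBody are hypothetical): a random-accessible iterator can be lowered to a plain scf.for over the bounds from genForCond(), while other iterators need an scf.while driven by genWhileCond()/genNotEnd().

  static void emitForOverRandomAccessIt(
      OpBuilder &b, Location l, SparseIterator &it,
      llvm::function_ref<void(OpBuilder &, Location, Value)> emitBody) {
    assert(it.iteratableByFor() && "non-random-access iterators need scf.while");
    // genForCond() yields the loop range [curCrd, UB).
    auto [lo, hi] = it.genForCond(b, l);
    Value step = b.create<arith::ConstantIndexOp>(l, 1);
    auto forOp = b.create<scf::ForOp>(l, lo, hi, step);
    b.setInsertionPointToStart(forOp.getBody());
    emitBody(b, l, forOp.getInductionVar());
    b.setInsertionPointAfter(forOp);
  }
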
@@ -221,21 +223,30 @@ std::unique_ptr<SparseTensorLevel> makeSparseTensorLevel(OpBuilder &builder,
                                                          Location loc, Value t,
                                                          unsigned tid, Level l);
 
-/// Helper function to create a SparseIterator object.
+/// Helper function to create a simple SparseIterator object that iterates over
+/// the SparseTensorLevel.
 std::unique_ptr<SparseIterator>
 makeSimpleIterator(const SparseTensorLevel &stl);
 
+/// Helper function to create a synthetic SparseIterator object that iterates
+/// over a dense space specified by [0, `sz`).
 std::pair<std::unique_ptr<SparseTensorLevel>, std::unique_ptr<SparseIterator>>
 makeSynLevelAndIterator(Value sz, unsigned tid, unsigned lvl);
 
+/// Helper function to create a SparseIterator object that iterates over a
+/// sliced space; the original space (before slicing) is traversed by `sit`.
 std::unique_ptr<SparseIterator>
 makeSlicedLevelIterator(std::unique_ptr<SparseIterator> &&sit, Value offset,
                         Value stride, Value size);
 
+/// Helper function to create a SparseIterator object that iterates over the
+/// set of non-empty subsections.
 std::unique_ptr<SparseIterator> makeNonEmptySubSectIterator(
     OpBuilder &b, Location l, const SparseIterator *parent,
     std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride);
 
+/// Helper function to create a SparseIterator object that iterates over a
+/// non-empty subsection created by NonEmptySubSectIterator.
 std::unique_ptr<SparseIterator> makeTraverseSubSectIterator(
     const SparseIterator &subsectIter, const SparseIterator &parent,
     std::unique_ptr<SparseIterator> &&delegate, Value size, unsigned stride);
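
A short usage sketch for the factory functions above (hedged; builder, loc, t, tid, lvl, offset, stride, and size are assumed to be in scope): wrap a level's plain iterator into a sliced iterator when the level is accessed through a strided slice.

  std::unique_ptr<SparseTensorLevel> stl =
      makeSparseTensorLevel(builder, loc, t, tid, lvl);
  // Plain iterator over the whole level.
  std::unique_ptr<SparseIterator> it = makeSimpleIterator(*stl);
  // Iterator over offset + stride * [0, size), delegating to `it` underneath.
  std::unique_ptr<SparseIterator> slicedIt =
      makeSlicedLevelIterator(std::move(it), offset, stride, size);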
