@@ -122,19 +122,10 @@ struct TensorExp final {
 ///
 /// The `kLoopVar` leaf kind is for representing `linalg::IndexOp`.
 /// That is, its argument is a `LoopId` identifying the loop-variable
-/// in question, and its value will be the current iteration's value
-/// of that loop-variable. See the `LoopId` documentation for more details.
-///
-/// The `kSynZero` leaf kind is for representing a synthetic zero value, which
-/// can be introduced when sparsifying operations like `arith::cmp` to generate
-/// `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
-//
-// TODO: Modify this definition so that the numeric values already encode
-// the `ExpArity` (while extending the notion of "arity" to include not
-// just the number of `ExprId` children the node has, but also whether the
-// node has a `Value` and/or `Operation*`). Doing this will avoid needing
-// to enumerate all the kinds in `getExpArity` and in the `TensorExp` ctor,
-// and should help clean up a few other places as well.
+/// in question, and its value will be the current iteration's value.
+/// The `kSynZero` leaf kind is for representing a synthetic zero value,
+/// which can be introduced when sparsifying operations like `arith::cmp`
+/// to generate `arith::cmp %lhs, %syn_zero` when the rhs operand is absent.
 enum class TensorExp::Kind {
   // Leaf.
   kTensor = 0,
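To make the leaf-kind distinction above concrete, here is a minimal standalone sketch (not the MLIR class itself; all names are illustrative) of a tagged expression node in which `kLoopVar` carries the `LoopId` it refers to, while `kSynZero` carries no payload at all and is only materialized later as a constant zero, e.g. as the missing rhs of an `arith::cmp`.

```cpp
#include <cassert>
#include <iostream>

// Illustrative stand-ins; the real header defines richer TensorId/LoopId types.
using LoopId = unsigned;
using TensorId = unsigned;

// Miniature tagged leaf node: kTensor carries a tensor id, kLoopVar carries
// the loop whose induction variable it denotes, and kSynZero carries nothing
// at all -- codegen materializes it as a constant zero on demand.
struct MiniLeaf {
  enum class Kind { kTensor, kLoopVar, kSynZero } kind;
  unsigned payload; // TensorId or LoopId; unused for kSynZero

  static MiniLeaf tensor(TensorId t) { return {Kind::kTensor, t}; }
  static MiniLeaf loopVar(LoopId i) { return {Kind::kLoopVar, i}; }
  static MiniLeaf synZero() { return {Kind::kSynZero, 0}; }
};

int main() {
  MiniLeaf iv = MiniLeaf::loopVar(2); // value of loop-variable #2 at runtime
  MiniLeaf z = MiniLeaf::synZero();   // placeholder for `%syn_zero`
  assert(iv.kind == MiniLeaf::Kind::kLoopVar && iv.payload == 2);
  assert(z.kind == MiniLeaf::Kind::kSynZero);
  std::cout << "kLoopVar leaf refers to loop " << iv.payload
            << "; kSynZero carries no payload\n";
}
```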
@@ -253,15 +244,6 @@ class Merger {
   ///
   /// The maxLvlRank specifies the max level rank of all inputs/output tensors.
   /// It is used to pre-allocate sufficient memory for internal storage.
-  //
-  // TODO: we want to make the filter loop more efficient in the future,
-  // e.g., by avoiding scanning the full list of stored coordinates (keeping
-  // the last position in ordered list) or even apply binary search to find
-  // the coordinate.
-  //
-  // TODO: would be cleaner to understand/document if the first argument
-  // gave the number of input tensors, instead of the current number of
-  // input+output tensors.
   Merger(unsigned numInputOutputTensors, unsigned numNativeLoops,
          unsigned numFilterLoops, unsigned maxLvlRank);

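The constructor signature shown in the context lines is enough to sketch typical usage. This is a minimal sketch, assuming an MLIR build where `mlir/Dialect/SparseTensor/Utils/Merger.h` is on the include path; the kernel and all counts below are hypothetical, chosen mainly to show that the first argument counts inputs and the output together.

```cpp
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

using namespace mlir::sparse_tensor;

// Hypothetical element-wise kernel C(i,j) = A(i,j) * B(i,j):
//   - 3 tensors in total (2 inputs + 1 output); the first constructor
//     argument counts inputs *and* the output,
//   - 2 native loops (i and j) and no filter loops,
//   - every tensor has level rank 2, hence maxLvlRank = 2.
static void buildMergerForElementwiseMul() {
  Merger merger(/*numInputOutputTensors=*/3, /*numNativeLoops=*/2,
                /*numFilterLoops=*/0, /*maxLvlRank=*/2);
  (void)merger; // expression and lattice construction would follow here
}
```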
@@ -383,12 +365,15 @@ class Merger {
 
   /// Gets the total number of loops (native loops + filter loops).
   constexpr unsigned getNumLoops() const { return numLoops; }
+
   /// Gets the number of native loops.
   constexpr unsigned getNumNativeLoops() const { return numNativeLoops; }
+
   /// Gets the number of filter loops.
   constexpr unsigned getNumFilterLoops() const {
     return numLoops - numNativeLoops;
   }
+
   /// Gets the identifier of the first filter-loop.
   constexpr LoopId getStartingFilterLoopId() const {
     return getNumNativeLoops();
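The added blank lines only improve readability; the accessors themselves encode a simple numbering invariant that a short sketch can state directly (constructor arguments again hypothetical).

```cpp
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include <cassert>

using namespace mlir::sparse_tensor;

// Hypothetical counts: 3 tensors, 2 native loops, 1 filter loop, rank 2.
static void checkLoopNumbering() {
  Merger merger(3, 2, 1, 2);
  // Total loop count is native loops plus filter loops.
  assert(merger.getNumLoops() ==
         merger.getNumNativeLoops() + merger.getNumFilterLoops());
  // Filter-loop identifiers are numbered right after the native loops.
  assert(merger.getStartingFilterLoopId() == merger.getNumNativeLoops());
}
```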
@@ -473,8 +458,7 @@ class Merger {
     lvlTypes[t][i] = dlt;
     loopToLvl[t][i] = lvl;
     lvlToLoop[t][lvl] = i;
-    // TODO: Maybe we should favor a constant loop bound when there are multiple
-    // choices.
+    // TODO: favor a constant loop bound when there are multiple choices.
     loopBounds[i] = std::make_pair(t, lvl);
   }
 
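This hunk shows only the body of the setter that links a loop to a tensor level; the enclosing signature lies outside the diff context, so the call below assumes it is `Merger::setLevelAndType(TensorId t, LoopId i, Level lvl, DimLevelType dlt)` as declared in this header. The indices are hypothetical.

```cpp
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

using namespace mlir::sparse_tensor;

// Record that level 1 of tensor #0 is stored compressed and is traversed by
// loop #1. One call updates all three maps touched above (lvlTypes, loopToLvl,
// lvlToLoop) and remembers (tensor, level) as a candidate loop bound.
// NOTE: the setter name/signature is assumed from context; only its body is
// visible in the hunk above.
static void recordLevelForLoop(Merger &merger) {
  merger.setLevelAndType(/*t=*/0, /*i=*/1, /*lvl=*/1, DimLevelType::Compressed);
}
```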
@@ -600,43 +584,19 @@ class Merger {
   /// Checks whether the given expression has an associated value.
   bool hasExprValue(ExprId e) const { return static_cast<bool>(exp(e).val); }
 
-  /// Sets the expression to have the associated value. Asserts that
-  /// the new value is defined, and that the expression does not already
-  /// have a value. If you want to overwrite a previous associated value,
-  /// use `updateExprValue` instead.
+  /// Sets the expression to have the associated value. Asserts that the new
+  /// value is defined, and that the expression does not already have a value.
   void setExprValue(ExprId e, Value v) {
-    assert(isValidExprId(e));
-    assert(v && "Got an undefined value");
-    auto &val = tensorExps[e].val;
-    assert(!val && "Expression already has an associated value");
-    val = v;
+    assert(!exp(e).val && "Expression already has an associated value");
+    assert(v && "Trying to assign an undefined value");
+    tensorExps[e].val = v;
   }
 
-  /// Clears the value associated with the expression. Asserts that the
+  /// Clears the value associated with the expression. Asserts that the
   /// expression does indeed have an associated value before clearing it.
-  /// If you don't want to check for a previous associated value first,
-  /// then use `updateExprValue` instead.
   void clearExprValue(ExprId e) {
-    assert(isValidExprId(e));
-    auto &val = tensorExps[e].val;
-    assert(val && "Expression does not have an associated value to clear");
-    val = Value();
-  }
-
-  /// Unilaterally updates the expression to have the associated value.
-  /// That is, unlike `setExprValue` and `clearExprValue`, this method
-  /// does not perform any checks on whether the expression had a
-  /// previously associated value nor whether the new value is defined.
-  //
-  // TODO: The unilateral update semantics are required by the
-  // current implementation of `CodegenEnv::genLoopBoundary`; however,
-  // that implementation seems a bit dubious. We would much rather have
-  // the semantics `{ clearExprValue(e); setExprValue(e, v); }` or
-  // `{ clearExprValue(e); if (v) setExprValue(e, v); }` since those
-  // provide better invariants.
-  void updateExprValue(ExprId e, Value v) {
-    assert(isValidExprId(e));
-    tensorExps[e].val = v;
+    assert(exp(e).val && "Expression does not have an associated value");
+    tensorExps[e].val = Value();
   }
 
 #ifndef NDEBUG
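With `updateExprValue` removed, the value slot of an expression follows a strict set/clear discipline. A small sketch of the intended usage, where `e` is assumed to be a valid `ExprId` and `v1`/`v2` are defined SSA values produced elsewhere:

```cpp
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

using namespace mlir::sparse_tensor;

// Set the value once, and clear it before attaching a different one; both
// setExprValue and clearExprValue assert their precondition in debug builds.
static void reassignExprValue(Merger &merger, ExprId e, mlir::Value v1,
                              mlir::Value v2) {
  merger.setExprValue(e, v1); // asserts: e had no value yet, v1 is defined
  if (merger.hasExprValue(e)) // query first when the current state is unknown
    merger.clearExprValue(e); // asserts: e currently has a value
  merger.setExprValue(e, v2); // e is now associated with v2
}
```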
@@ -706,12 +666,10 @@ class Merger {
   // `operator[]`: `SmallVector` performs OOB checks, whereas `std::vector`
   // does not.
 
-  /// Map that converts pair<TensorId, LoopId> to the corresponding
-  /// level-type.
+  /// Map that converts pair<TensorId, LoopId> to the corresponding lvl-type.
   std::vector<std::vector<DimLevelType>> lvlTypes;
 
-  /// Map that converts pair<TensorId, LoopId> to the corresponding
-  /// level.
+  /// Map that converts pair<TensorId, LoopId> to the corresponding lvl.
   std::vector<std::vector<std::optional<Level>>> loopToLvl;
 
   /// Map that converts pair<TensorId, Level> to the corresponding LoopId.
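The reworded comments describe dense per-tensor/per-loop storage rather than an associative container: a `pair<TensorId, LoopId>` key is simply two vector subscripts, which is also why the constructor pre-allocates the internal storage. A plain-C++ sketch of that layout (hypothetical types, not the MLIR code):

```cpp
#include <optional>
#include <vector>

// Hypothetical stand-ins for the real Level/DimLevelType types.
using MiniLevel = unsigned;
enum class MiniLevelType { Undef, Dense, Compressed, Singleton };

// The "maps" above are dense numTensors x numLoops tables allocated up front,
// so a pair<TensorId, LoopId> lookup is just lvlTypes[t][i] / loopToLvl[t][i].
struct MiniMaps {
  std::vector<std::vector<MiniLevelType>> lvlTypes;             // [t][i] -> lvl-type
  std::vector<std::vector<std::optional<MiniLevel>>> loopToLvl; // [t][i] -> lvl

  MiniMaps(unsigned numTensors, unsigned numLoops)
      : lvlTypes(numTensors,
                 std::vector<MiniLevelType>(numLoops, MiniLevelType::Undef)),
        loopToLvl(numTensors,
                  std::vector<std::optional<MiniLevel>>(numLoops)) {}
};
```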