@@ -691,6 +691,7 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
691691 // This code deals with permutations as well as non-permutations that
692692 // arise from rank changing blocking.
693693 const auto dimToLvl = stt.getDimToLvl ();
694+ const auto lvlToDim = stt.getLvlToDim ();
694695 SmallVector<Value> dim2lvlValues (lvlRank); // for each lvl, expr in dim vars
695696 SmallVector<Value> lvl2dimValues (dimRank); // for each dim, expr in lvl vars
696697 SmallVector<Value> lvlSizesValues (lvlRank);
@@ -705,34 +706,26 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
705706 Dimension d = 0 ;
706707 uint64_t cf = 0 , cm = 0 ;
707708 switch (exp.getKind ()) {
708- case AffineExprKind::DimId:
709+ case AffineExprKind::DimId: {
709710 d = exp.cast <AffineDimExpr>().getPosition ();
710711 break ;
711- case AffineExprKind::FloorDiv:
712- d = exp.cast <AffineBinaryOpExpr>()
713- .getLHS ()
714- .cast <AffineDimExpr>()
715- .getPosition ();
716- cf = exp.cast <AffineBinaryOpExpr>()
717- .getRHS ()
718- .cast <AffineConstantExpr>()
719- .getValue ();
712+ }
713+ case AffineExprKind::FloorDiv: {
714+ auto floor = exp.cast <AffineBinaryOpExpr>();
715+ d = floor.getLHS ().cast <AffineDimExpr>().getPosition ();
716+ cf = floor.getRHS ().cast <AffineConstantExpr>().getValue ();
720717 break ;
721- case AffineExprKind::Mod:
722- d = exp.cast <AffineBinaryOpExpr>()
723- .getLHS ()
724- .cast <AffineDimExpr>()
725- .getPosition ();
726- cm = exp.cast <AffineBinaryOpExpr>()
727- .getRHS ()
728- .cast <AffineConstantExpr>()
729- .getValue ();
718+ }
719+ case AffineExprKind::Mod: {
720+ auto mod = exp.cast <AffineBinaryOpExpr>();
721+ d = mod.getLHS ().cast <AffineDimExpr>().getPosition ();
722+ cm = mod.getRHS ().cast <AffineConstantExpr>().getValue ();
730723 break ;
724+ }
731725 default :
732726 llvm::report_fatal_error (" unsupported dim2lvl in sparse tensor type" );
733727 }
734728 dim2lvlValues[l] = constantIndex (builder, loc, encodeDim (d, cf, cm));
735- lvl2dimValues[d] = constantIndex (builder, loc, l); // FIXME, use lvlToDim
736729 // Compute the level sizes.
737730 // (1) l = d : size(d)
738731 // (2) l = d / c : size(d) / c
@@ -751,6 +744,35 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
751744 }
752745 lvlSizesValues[l] = lvlSz;
753746 }
747+ // Generate lvl2dim.
748+ assert (dimRank == lvlToDim.getNumResults ());
749+ for (Dimension d = 0 ; d < dimRank; d++) {
750+ AffineExpr exp = lvlToDim.getResult (d);
751+ // We expect:
752+ // (1) d = l
753+ // (2) d = l' * c + l
754+ Level l = 0 , ll = 0 ;
755+ uint64_t c = 0 ;
756+ switch (exp.getKind ()) {
757+ case AffineExprKind::DimId: {
758+ l = exp.cast <AffineDimExpr>().getPosition ();
759+ break ;
760+ }
761+ case AffineExprKind::Add: {
762+ // Always mul on lhs, symbol/constant on rhs.
763+ auto add = exp.cast <AffineBinaryOpExpr>();
764+ assert (add.getLHS ().getKind () == AffineExprKind::Mul);
765+ auto mul = add.getLHS ().cast <AffineBinaryOpExpr>();
766+ ll = mul.getLHS ().cast <AffineDimExpr>().getPosition ();
767+ c = mul.getRHS ().cast <AffineConstantExpr>().getValue ();
768+ l = add.getRHS ().cast <AffineDimExpr>().getPosition ();
769+ break ;
770+ }
771+ default :
772+ llvm::report_fatal_error (" unsupported lvl2dim in sparse tensor type" );
773+ }
774+ lvl2dimValues[d] = constantIndex (builder, loc, encodeLvl (l, c, ll));
775+ }
754776 // Return buffers.
755777 dim2lvlBuffer = allocaBuffer (builder, loc, dim2lvlValues);
756778 lvl2dimBuffer = allocaBuffer (builder, loc, lvl2dimValues);
0 commit comments