2020#include " mlir/Dialect/Linalg/Utils/Utils.h"
2121#include " mlir/IR/AffineExpr.h"
2222#include " mlir/IR/AffineMap.h"
23+ #include " mlir/IR/BuiltinTypes.h"
2324#include " mlir/Transforms/FoldUtils.h"
2425#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2526#include " llvm/Support/CommandLine.h"
@@ -256,7 +257,7 @@ struct UnitExtentReplacementInfo {
256257} // namespace
257258
258259// / Utility function for replacing operands/results to a linalg generic
259- // / operation on tensors with unit-extent dimensions. These can be replaced with
260+ // / operation with unit-extent dimensions. These can be replaced with
260261// / an operand/result with the unit-extent dimension removed. This is only done
261262// / if the indexing map used to access that didimensionmension has a
262263// / AffineConstantExpr of value 0. Given the `type` of an result/operand of a
@@ -301,10 +302,19 @@ static UnitExtentReplacementInfo replaceUnitExtents(GenericOp genericOp,
301302 ++dim;
302303 }
303304 // Compute the tensor or scalar replacement type.
305+ Type actualType = opOperand->get ().getType ();
304306 Type elementType = getElementTypeOrSelf (opOperand->get ());
305- Type replacementType = elementType == opOperand->get ().getType ()
306- ? elementType
307- : RankedTensorType::get (newShape, elementType);
307+ Type replacementType;
308+ if (elementType == opOperand->get ().getType ()) {
309+ replacementType = elementType;
310+ } else if (actualType.isa <RankedTensorType>()) {
311+ replacementType = RankedTensorType::get (newShape, elementType);
312+ } else if (actualType.isa <MemRefType>()) {
313+ assert (actualType.cast <MemRefType>().getAffineMaps ().empty () &&
314+ " unsupported strided memrefs" );
315+ replacementType = MemRefType::get (newShape, elementType);
316+ }
317+ assert (replacementType && " unsupported shaped type" );
308318 UnitExtentReplacementInfo info = {replacementType,
309319 AffineMap::get (indexingMap.getNumDims (),
310320 indexingMap.getNumSymbols (),
@@ -324,22 +334,60 @@ convertAffineMapArrayToExprs(ArrayAttr affineMapArrayAttr) {
324334 return reassociationExprs;
325335}
326336
327- // / Pattern to replace tensors operands/results that are unit extents.
328- struct ReplaceUnitExtentTensors : public OpRewritePattern <GenericOp> {
337+ // / Pattern to replace tensor/buffer operands/results that are unit extents.
338+ struct ReplaceUnitExtents : public OpRewritePattern <GenericOp> {
329339 using OpRewritePattern<GenericOp>::OpRewritePattern;
340+
341+ // Return the original value if the type is unchanged, or reshape it. Return a
342+ // nullptr if this is an unsupported type.
343+ Value maybeExpand (Value result, Type origResultType,
344+ ArrayAttr reassociationMap, Location loc,
345+ PatternRewriter &rewriter) const {
346+ if (origResultType == result.getType ())
347+ return result;
348+ if (origResultType.isa <RankedTensorType>()) {
349+ return rewriter.create <linalg::TensorExpandShapeOp>(
350+ loc, origResultType, result,
351+ convertAffineMapArrayToExprs (reassociationMap));
352+ }
353+ if (origResultType.isa <MemRefType>()) {
354+ return rewriter.create <linalg::ExpandShapeOp>(
355+ loc, origResultType, result,
356+ convertAffineMapArrayToExprs (reassociationMap));
357+ }
358+ return nullptr ;
359+ };
360+
361+ // Return the original value if the type is unchanged, or reshape it. Return a
362+ // nullptr if this is an unsupported type.
363+ Value maybeCollapse (Value operand, Type newInputOutputType,
364+ ArrayAttr reassociationMap, Location loc,
365+ PatternRewriter &rewriter) const {
366+ auto operandType = operand.getType ();
367+ if (operandType == newInputOutputType)
368+ return operand;
369+ if (operandType.isa <MemRefType>()) {
370+ return rewriter.create <linalg::CollapseShapeOp>(
371+ loc, newInputOutputType, operand,
372+ convertAffineMapArrayToExprs (reassociationMap));
373+ }
374+ if (operandType.isa <RankedTensorType>()) {
375+ return rewriter.create <linalg::TensorCollapseShapeOp>(
376+ loc, newInputOutputType, operand,
377+ convertAffineMapArrayToExprs (reassociationMap));
378+ }
379+ return nullptr ;
380+ };
381+
330382 LogicalResult matchAndRewrite (GenericOp genericOp,
331383 PatternRewriter &rewriter) const override {
332- if (!genericOp.hasTensorSemantics ())
333- return failure ();
334-
335384 MLIRContext *context = rewriter.getContext ();
336385 Location loc = genericOp.getLoc ();
337386
338387 SmallVector<AffineMap> newIndexingMaps;
339388 SmallVector<ArrayAttr> reassociationMaps;
340389 SmallVector<Type> newInputOutputTypes;
341390 bool doCanonicalization = false ;
342-
343391 for (OpOperand *opOperand : genericOp.getInputAndOutputOperands ()) {
344392 UnitExtentReplacementInfo replacementInfo =
345393 replaceUnitExtents (genericOp, opOperand, context);
@@ -362,14 +410,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
362410 auto insertReshapes = [&](ValueRange values) {
363411 SmallVector<Value, 4 > res;
364412 res.reserve (values.size ());
365- for (auto operand : llvm::enumerate (values)) {
366- if (operand.value ().getType () == newInputOutputTypes[flattenedIdx])
367- res.push_back (operand.value ());
368- else {
369- res.push_back (rewriter.create <TensorCollapseShapeOp>(
370- loc, newInputOutputTypes[flattenedIdx], operand.value (),
371- convertAffineMapArrayToExprs (reassociationMaps[flattenedIdx])));
372- }
413+ for (auto operand : values) {
414+ auto reshapedValue =
415+ maybeCollapse (operand, newInputOutputTypes[flattenedIdx],
416+ reassociationMaps[flattenedIdx], loc, rewriter);
417+ assert (reshapedValue &&
418+ " expected ranked MemRef or Tensor operand type" );
419+ res.push_back (reshapedValue);
373420 ++flattenedIdx;
374421 }
375422 return res;
@@ -396,15 +443,13 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
396443 SmallVector<Value, 4 > resultReplacements;
397444 for (auto result : llvm::enumerate (replacementOp.getResults ())) {
398445 unsigned index = result.index () + replacementOp.getNumInputs ();
399- RankedTensorType origResultType = genericOp.getResult (result.index ())
400- .getType ()
401- .template cast <RankedTensorType>();
402- if (origResultType != result.value ().getType ()) {
403- resultReplacements.push_back (rewriter.create <TensorExpandShapeOp>(
404- loc, origResultType, result.value (),
405- convertAffineMapArrayToExprs (reassociationMaps[index])));
406- } else
407- resultReplacements.push_back (result.value ());
446+ auto origResultType = genericOp.getResult (result.index ()).getType ();
447+
448+ auto newResult = maybeExpand (result.value (), origResultType,
449+ reassociationMaps[index], loc, rewriter);
450+ assert (newResult &&
451+ " unexpected output type other than ranked MemRef or Tensor" );
452+ resultReplacements.push_back (newResult);
408453 }
409454 rewriter.replaceOp (genericOp, resultReplacements);
410455 return success ();
@@ -501,9 +546,8 @@ struct UseRankReducedSubTensorInsertOp
501546void mlir::linalg::populateFoldUnitExtentDimsPatterns (
502547 RewritePatternSet &patterns) {
503548 auto *context = patterns.getContext ();
504- patterns.add <FoldUnitDimLoops, ReplaceUnitExtentTensors,
505- UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
506- context);
549+ patterns.add <FoldUnitDimLoops, ReplaceUnitExtents, UseRankReducedSubTensorOp,
550+ UseRankReducedSubTensorInsertOp>(context);
507551 TensorCollapseShapeOp::getCanonicalizationPatterns (patterns, context);
508552 TensorExpandShapeOp::getCanonicalizationPatterns (patterns, context);
509553}
0 commit comments