@@ -2755,93 +2755,19 @@ LogicalResult WinogradFilterTransformOp::verify() {
27552755 return success ();
27562756}
27572757
2758- SmallVector<Range>
2759- WinogradFilterTransformOp::getIterationDomain (OpBuilder &builder) {
2760- Location loc = getLoc ();
2761- Value zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
2762- Value one = builder.create <arith::ConstantIndexOp>(loc, 1 );
2763- Value output = getOutput ();
2764- SmallVector<Range> loopBounds (6 );
2765- for (unsigned dim = 0 ; dim < 6 ; ++dim) {
2766- loopBounds[dim].offset = zero;
2767- loopBounds[dim].size = getDimValue (builder, loc, output, dim);
2768- loopBounds[dim].stride = one;
2769- }
2770- return loopBounds;
2771- }
2772-
2773- SmallVector<utils::IteratorType>
2774- WinogradFilterTransformOp::getLoopIteratorTypes () {
2775- SmallVector<utils::IteratorType> iteratorTypes (6 ,
2776- utils::IteratorType::parallel);
2777- return iteratorTypes;
2778- }
2758+ // ===----------------------------------------------------------------------===//
2759+ // WinogradInputTransformOp
2760+ // ===----------------------------------------------------------------------===//
27792761
27802762Value getValueFromOpFoldResult (OpFoldResult opFoldResult, OpBuilder &builder,
27812763 Location loc) {
2782- if (auto val = opFoldResult.dyn_cast <Value>()) {
2783- return val;
2784- } else if (auto attr = opFoldResult.dyn_cast <Attribute>()) {
2764+ if (auto attr = opFoldResult.dyn_cast <Attribute>()) {
27852765 auto intAttr = cast<IntegerAttr>(attr);
27862766 return builder.create <arith::ConstantOp>(loc, intAttr);
27872767 }
2788- // This should never happen if OpFoldResult is correctly formed.
2789- return nullptr ;
2768+ return opFoldResult.get <Value>();
27902769}
27912770
2792- LogicalResult WinogradFilterTransformOp::getResultTilePosition (
2793- OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2794- ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2795- SmallVector<OpFoldResult> &resultSizes) {
2796- auto zeroAttr = builder.getI64IntegerAttr (0 );
2797- auto oneAttr = builder.getI64IntegerAttr (1 );
2798-
2799- resultOffsets.push_back (offsets[0 ]);
2800- resultOffsets.push_back (offsets[1 ]);
2801- resultOffsets.push_back (zeroAttr);
2802- resultOffsets.push_back (zeroAttr);
2803- resultOffsets.push_back (zeroAttr);
2804- resultOffsets.push_back (zeroAttr);
2805- resultSizes.push_back (oneAttr);
2806- resultSizes.push_back (oneAttr);
2807- resultSizes.push_back (sizes[2 ]);
2808- resultSizes.push_back (sizes[3 ]);
2809- resultSizes.push_back (sizes[4 ]);
2810- resultSizes.push_back (sizes[5 ]);
2811-
2812- return success ();
2813- }
2814-
2815- FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation (
2816- OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
2817- ArrayRef<OpFoldResult> sizes) {
2818- auto oneAttr = builder.getI64IntegerAttr (1 );
2819-
2820- Location loc = getLoc ();
2821- SmallVector<OpFoldResult> strides (6 , oneAttr);
2822- SmallVector<Value> tiledOperands;
2823- tiledOperands.emplace_back (getFilter ());
2824-
2825- SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
2826- if (failed (getResultTilePosition (builder, 1 , offsets, sizes, sliceOffsets,
2827- sliceSizes)))
2828- return failure ();
2829-
2830- tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
2831- loc, getOutput (), sliceOffsets, sliceSizes, strides));
2832-
2833- SmallVector<Type, 4 > resultTypes;
2834- resultTypes.push_back (tiledOperands[1 ].getType ());
2835- Operation *tiledOp =
2836- mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
2837-
2838- return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
2839- }
2840-
2841- // ===----------------------------------------------------------------------===//
2842- // WinogradInputTransformOp
2843- // ===----------------------------------------------------------------------===//
2844-
28452771LogicalResult WinogradInputTransformOp::verify () {
28462772 auto inputType = cast<ShapedType>(getInput ().getType ());
28472773 ArrayRef<int64_t > inputShape = inputType.getShape ();
@@ -2887,14 +2813,15 @@ LogicalResult WinogradInputTransformOp::verify() {
28872813SmallVector<Range>
28882814WinogradInputTransformOp::getIterationDomain (OpBuilder &builder) {
28892815 Location loc = getLoc ();
2890- Value zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
2891- Value one = builder.create <arith::ConstantIndexOp>(loc, 1 );
2816+ auto indexType = builder.getIndexType ();
2817+ auto zeroAttr = builder.getIntegerAttr (indexType, 0 );
2818+ auto oneAttr = builder.getIntegerAttr (indexType, 1 );
28922819 Value output = getOutput ();
28932820 SmallVector<Range> loopBounds (6 );
28942821 for (unsigned dim = 0 ; dim < 6 ; ++dim) {
2895- loopBounds[dim].offset = zero ;
2822+ loopBounds[dim].offset = zeroAttr ;
28962823 loopBounds[dim].size = getDimValue (builder, loc, output, dim);
2897- loopBounds[dim].stride = one ;
2824+ loopBounds[dim].stride = oneAttr ;
28982825 }
28992826 return loopBounds;
29002827}
@@ -2913,16 +2840,16 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
29132840 auto zeroAttr = builder.getI64IntegerAttr (0 );
29142841 auto oneAttr = builder.getI64IntegerAttr (1 );
29152842
2916- resultOffsets.push_back (offsets[0 ]);
2917- resultOffsets.push_back (offsets[1 ]);
29182843 resultOffsets.push_back (zeroAttr);
29192844 resultOffsets.push_back (zeroAttr);
2845+ resultOffsets.push_back (offsets[2 ]);
2846+ resultOffsets.push_back (offsets[3 ]);
29202847 resultOffsets.push_back (zeroAttr);
29212848 resultOffsets.push_back (zeroAttr);
2849+ resultSizes.push_back (sizes[0 ]);
2850+ resultSizes.push_back (sizes[1 ]);
29222851 resultSizes.push_back (oneAttr);
29232852 resultSizes.push_back (oneAttr);
2924- resultSizes.push_back (sizes[2 ]);
2925- resultSizes.push_back (sizes[3 ]);
29262853 resultSizes.push_back (sizes[4 ]);
29272854 resultSizes.push_back (sizes[5 ]);
29282855
@@ -2956,9 +2883,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
29562883 auto affineMap =
29572884 AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
29582885 Value mappedOffset1 = builder.create <affine::AffineApplyOp>(
2959- loc, affineMap, getValueFromOpFoldResult (offsets[0 ], builder, loc));
2886+ loc, affineMap, getValueFromOpFoldResult (offsets[2 ], builder, loc));
29602887 Value mappedOffset2 = builder.create <affine::AffineApplyOp>(
2961- loc, affineMap, getValueFromOpFoldResult (offsets[1 ], builder, loc));
2888+ loc, affineMap, getValueFromOpFoldResult (offsets[3 ], builder, loc));
29622889
29632890 sliceOffsets.push_back (zeroAttr);
29642891 sliceOffsets.push_back (mappedOffset1);
@@ -3033,14 +2960,15 @@ LogicalResult WinogradOutputTransformOp::verify() {
30332960SmallVector<Range>
30342961WinogradOutputTransformOp::getIterationDomain (OpBuilder &builder) {
30352962 Location loc = getLoc ();
3036- Value zero = builder.create <arith::ConstantIndexOp>(loc, 0 );
3037- Value one = builder.create <arith::ConstantIndexOp>(loc, 1 );
2963+ auto indexType = builder.getIndexType ();
2964+ auto zeroAttr = builder.getIntegerAttr (indexType, 0 );
2965+ auto oneAttr = builder.getIntegerAttr (indexType, 1 );
30382966 Value value = getValue ();
30392967 SmallVector<Range> loopBounds (6 );
30402968 for (unsigned dim = 0 ; dim < 6 ; ++dim) {
3041- loopBounds[dim].offset = zero ;
2969+ loopBounds[dim].offset = zeroAttr ;
30422970 loopBounds[dim].size = getDimValue (builder, loc, value, dim);
3043- loopBounds[dim].stride = one ;
2971+ loopBounds[dim].stride = oneAttr ;
30442972 }
30452973 return loopBounds;
30462974}
@@ -3071,9 +2999,9 @@ LogicalResult WinogradOutputTransformOp::getResultTilePosition(
30712999 auto affineMap =
30723000 AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
30733001 Value mappedOffset1 = builder.create <affine::AffineApplyOp>(
3074- loc, affineMap, getValueFromOpFoldResult (offsets[0 ], builder, loc));
3002+ loc, affineMap, getValueFromOpFoldResult (offsets[2 ], builder, loc));
30753003 Value mappedOffset2 = builder.create <affine::AffineApplyOp>(
3076- loc, affineMap, getValueFromOpFoldResult (offsets[1 ], builder, loc));
3004+ loc, affineMap, getValueFromOpFoldResult (offsets[3 ], builder, loc));
30773005
30783006 resultOffsets.push_back (zeroAttr);
30793007 resultOffsets.push_back (mappedOffset1);
@@ -3095,16 +3023,16 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation(
30953023 SmallVector<Value> tiledOperands;
30963024 SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
30973025
3098- sliceOffsets.push_back (offsets[0 ]);
3099- sliceOffsets.push_back (offsets[1 ]);
31003026 sliceOffsets.push_back (zeroAttr);
31013027 sliceOffsets.push_back (zeroAttr);
3028+ sliceOffsets.push_back (offsets[2 ]);
3029+ sliceOffsets.push_back (offsets[3 ]);
31023030 sliceOffsets.push_back (zeroAttr);
31033031 sliceOffsets.push_back (zeroAttr);
3032+ sliceSizes.push_back (sizes[0 ]);
3033+ sliceSizes.push_back (sizes[1 ]);
31043034 sliceSizes.push_back (oneAttr);
31053035 sliceSizes.push_back (oneAttr);
3106- sliceSizes.push_back (sizes[2 ]);
3107- sliceSizes.push_back (sizes[3 ]);
31083036 sliceSizes.push_back (sizes[4 ]);
31093037 sliceSizes.push_back (sizes[5 ]);
31103038 SmallVector<OpFoldResult> sliceStrides (6 , oneAttr);
0 commit comments