@@ -2776,6 +2776,15 @@ LogicalResult WinogradFilterTransformOp::verify() {
27762776// WinogradInputTransformOp
27772777// ===----------------------------------------------------------------------===//
27782778
2779+ Value getValueFromOpFoldResult (OpFoldResult opFoldResult, OpBuilder &builder,
2780+ Location loc) {
2781+ if (auto attr = opFoldResult.dyn_cast <Attribute>()) {
2782+ auto intAttr = cast<IntegerAttr>(attr);
2783+ return builder.create <arith::ConstantOp>(loc, intAttr);
2784+ }
2785+ return opFoldResult.get <Value>();
2786+ }
2787+
27792788LogicalResult WinogradInputTransformOp::verify () {
27802789 auto inputType = cast<ShapedType>(getInput ().getType ());
27812790 ArrayRef<int64_t > inputShape = inputType.getShape ();
@@ -2813,6 +2822,113 @@ LogicalResult WinogradInputTransformOp::verify() {
28132822 return success ();
28142823}
28152824
2825+ SmallVector<Range>
2826+ WinogradInputTransformOp::getIterationDomain (OpBuilder &builder) {
2827+ Location loc = getLoc ();
2828+ auto indexType = builder.getIndexType ();
2829+ auto zeroAttr = builder.getIntegerAttr (indexType, 0 );
2830+ auto oneAttr = builder.getIntegerAttr (indexType, 1 );
2831+ Value output = getOutput ();
2832+ SmallVector<Range> loopBounds (6 );
2833+ for (unsigned dim = 0 ; dim < 6 ; ++dim) {
2834+ loopBounds[dim].offset = zeroAttr;
2835+ loopBounds[dim].size = getDimValue (builder, loc, output, dim);
2836+ loopBounds[dim].stride = oneAttr;
2837+ }
2838+ return loopBounds;
2839+ }
2840+
2841+ SmallVector<utils::IteratorType>
2842+ WinogradInputTransformOp::getLoopIteratorTypes () {
2843+ SmallVector<utils::IteratorType> iteratorTypes (6 ,
2844+ utils::IteratorType::parallel);
2845+ return iteratorTypes;
2846+ }
2847+
2848+ LogicalResult WinogradInputTransformOp::getResultTilePosition (
2849+ OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2850+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
2851+ SmallVector<OpFoldResult> &resultSizes) {
2852+ auto zeroAttr = builder.getI64IntegerAttr (0 );
2853+ auto oneAttr = builder.getI64IntegerAttr (1 );
2854+
2855+ resultOffsets.push_back (zeroAttr);
2856+ resultOffsets.push_back (zeroAttr);
2857+ resultOffsets.push_back (offsets[2 ]);
2858+ resultOffsets.push_back (offsets[3 ]);
2859+ resultOffsets.push_back (zeroAttr);
2860+ resultOffsets.push_back (zeroAttr);
2861+ resultSizes.push_back (sizes[0 ]);
2862+ resultSizes.push_back (sizes[1 ]);
2863+ resultSizes.push_back (oneAttr);
2864+ resultSizes.push_back (oneAttr);
2865+ resultSizes.push_back (sizes[4 ]);
2866+ resultSizes.push_back (sizes[5 ]);
2867+
2868+ return success ();
2869+ }
2870+
2871+ FailureOr<TilingResult>
2872+ WinogradInputTransformOp::getTiledImplementation (OpBuilder &builder,
2873+ ArrayRef<OpFoldResult> offsets,
2874+ ArrayRef<OpFoldResult> sizes) {
2875+ auto oneAttr = builder.getI64IntegerAttr (1 );
2876+ auto zeroAttr = builder.getI64IntegerAttr (0 );
2877+ Value input = getInput ();
2878+ auto inputType = cast<ShapedType>(input.getType ());
2879+ auto inputShape = inputType.getShape ();
2880+ int64_t inputH = inputShape[1 ];
2881+ int64_t inputW = inputShape[2 ];
2882+ int64_t m = getM ();
2883+ int64_t r = getR ();
2884+ int64_t alpha = m + r - 1 ;
2885+ int64_t alphaH = inputH != 1 ? alpha : 1 ;
2886+ int64_t alphaW = inputW != 1 ? alpha : 1 ;
2887+ auto alphaHAttr = builder.getI64IntegerAttr (alphaH);
2888+ auto alphaWAttr = builder.getI64IntegerAttr (alphaW);
2889+
2890+ Location loc = getLoc ();
2891+ SmallVector<Value> tiledOperands;
2892+ SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
2893+
2894+ auto context = builder.getContext ();
2895+ auto affineMap =
2896+ AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
2897+ Value mappedOffset1 = builder.create <affine::AffineApplyOp>(
2898+ loc, affineMap, getValueFromOpFoldResult (offsets[2 ], builder, loc));
2899+ Value mappedOffset2 = builder.create <affine::AffineApplyOp>(
2900+ loc, affineMap, getValueFromOpFoldResult (offsets[3 ], builder, loc));
2901+
2902+ sliceOffsets.push_back (zeroAttr);
2903+ sliceOffsets.push_back (mappedOffset1);
2904+ sliceOffsets.push_back (mappedOffset2);
2905+ sliceOffsets.push_back (zeroAttr);
2906+ sliceSizes.push_back (sizes[4 ]);
2907+ sliceSizes.push_back (alphaHAttr);
2908+ sliceSizes.push_back (alphaWAttr);
2909+ sliceSizes.push_back (sizes[5 ]);
2910+ SmallVector<OpFoldResult> inputStrides (4 , oneAttr);
2911+ tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
2912+ loc, getInput (), sliceOffsets, sliceSizes, inputStrides));
2913+
2914+ sliceOffsets.clear ();
2915+ sliceSizes.clear ();
2916+ if (failed (getResultTilePosition (builder, 1 , offsets, sizes, sliceOffsets,
2917+ sliceSizes)))
2918+ return failure ();
2919+
2920+ SmallVector<OpFoldResult> outputStrides (6 , oneAttr);
2921+ tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
2922+ loc, getOutput (), sliceOffsets, sliceSizes, outputStrides));
2923+
2924+ SmallVector<Type, 4 > resultTypes;
2925+ resultTypes.push_back (tiledOperands[1 ].getType ());
2926+ Operation *tiledOp =
2927+ mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
2928+
2929+ return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
2930+ }
2931+
28162932// ===----------------------------------------------------------------------===//
28172933// WinogradOutputTransformOp
28182934// ===----------------------------------------------------------------------===//
@@ -2855,6 +2971,106 @@ LogicalResult WinogradOutputTransformOp::verify() {
28552971 return success ();
28562972}
28572973
2974+ SmallVector<Range>
2975+ WinogradOutputTransformOp::getIterationDomain (OpBuilder &builder) {
2976+ Location loc = getLoc ();
2977+ auto indexType = builder.getIndexType ();
2978+ auto zeroAttr = builder.getIntegerAttr (indexType, 0 );
2979+ auto oneAttr = builder.getIntegerAttr (indexType, 1 );
2980+ Value value = getValue ();
2981+ SmallVector<Range> loopBounds (6 );
2982+ for (unsigned dim = 0 ; dim < 6 ; ++dim) {
2983+ loopBounds[dim].offset = zeroAttr;
2984+ loopBounds[dim].size = getDimValue (builder, loc, value, dim);
2985+ loopBounds[dim].stride = oneAttr;
2986+ }
2987+ return loopBounds;
2988+ }
2989+
2990+ SmallVector<utils::IteratorType>
2991+ WinogradOutputTransformOp::getLoopIteratorTypes () {
2992+ SmallVector<utils::IteratorType> iteratorTypes (6 ,
2993+ utils::IteratorType::parallel);
2994+ return iteratorTypes;
2995+ }
2996+
2997+ LogicalResult WinogradOutputTransformOp::getResultTilePosition (
2998+ OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
2999+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
3000+ SmallVector<OpFoldResult> &resultSizes) {
3001+ auto zeroAttr = builder.getI64IntegerAttr (0 );
3002+ Value output = getOutput ();
3003+ auto outputType = cast<ShapedType>(output.getType ());
3004+ auto outputShape = outputType.getShape ();
3005+ int64_t outputH = outputShape[1 ];
3006+ int64_t outputW = outputShape[2 ];
3007+ int64_t m = getM ();
3008+ auto heightM = builder.getI64IntegerAttr (outputH != 1 ? m : 1 );
3009+ auto widthM = builder.getI64IntegerAttr (outputW != 1 ? m : 1 );
3010+
3011+ Location loc = getLoc ();
3012+ auto context = builder.getContext ();
3013+ auto affineMap =
3014+ AffineMap::get (1 , 0 , {builder.getAffineDimExpr (0 ) * m}, context);
3015+ Value mappedOffset1 = builder.create <affine::AffineApplyOp>(
3016+ loc, affineMap, getValueFromOpFoldResult (offsets[2 ], builder, loc));
3017+ Value mappedOffset2 = builder.create <affine::AffineApplyOp>(
3018+ loc, affineMap, getValueFromOpFoldResult (offsets[3 ], builder, loc));
3019+
3020+ resultOffsets.push_back (zeroAttr);
3021+ resultOffsets.push_back (mappedOffset1);
3022+ resultOffsets.push_back (mappedOffset2);
3023+ resultOffsets.push_back (zeroAttr);
3024+ resultSizes.push_back (sizes[4 ]);
3025+ resultSizes.push_back (heightM);
3026+ resultSizes.push_back (widthM);
3027+ resultSizes.push_back (sizes[5 ]);
3028+ return success ();
3029+ }
3030+
3031+ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation (
3032+ OpBuilder &builder, ArrayRef<OpFoldResult> offsets,
3033+ ArrayRef<OpFoldResult> sizes) {
3034+ auto oneAttr = builder.getI64IntegerAttr (1 );
3035+ auto zeroAttr = builder.getI64IntegerAttr (0 );
3036+ Location loc = getLoc ();
3037+ SmallVector<Value> tiledOperands;
3038+ SmallVector<OpFoldResult> sliceOffsets, sliceSizes;
3039+
3040+ sliceOffsets.push_back (zeroAttr);
3041+ sliceOffsets.push_back (zeroAttr);
3042+ sliceOffsets.push_back (offsets[2 ]);
3043+ sliceOffsets.push_back (offsets[3 ]);
3044+ sliceOffsets.push_back (zeroAttr);
3045+ sliceOffsets.push_back (zeroAttr);
3046+ sliceSizes.push_back (sizes[0 ]);
3047+ sliceSizes.push_back (sizes[1 ]);
3048+ sliceSizes.push_back (oneAttr);
3049+ sliceSizes.push_back (oneAttr);
3050+ sliceSizes.push_back (sizes[4 ]);
3051+ sliceSizes.push_back (sizes[5 ]);
3052+ SmallVector<OpFoldResult> sliceStrides (6 , oneAttr);
3053+ tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
3054+ loc, getValue (), sliceOffsets, sliceSizes, sliceStrides));
3055+
3056+ sliceOffsets.clear ();
3057+ sliceSizes.clear ();
3058+ if (failed (getResultTilePosition (builder, 1 , offsets, sizes, sliceOffsets,
3059+ sliceSizes)))
3060+ return failure ();
3061+
3062+ SmallVector<OpFoldResult> strides (4 , oneAttr);
3063+ tiledOperands.emplace_back (builder.create <tensor::ExtractSliceOp>(
3064+ loc, getOutput (), sliceOffsets, sliceSizes, strides));
3065+
3066+ SmallVector<Type, 4 > resultTypes;
3067+ resultTypes.push_back (tiledOperands[1 ].getType ());
3068+ Operation *tiledOp =
3069+ mlir::clone (builder, getOperation (), resultTypes, tiledOperands);
3070+
3071+ return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults ())};
3072+ }
3073+
28583074// ===----------------------------------------------------------------------===//
28593075// LinalgDialect
28603076// ===----------------------------------------------------------------------===//
0 commit comments