@@ -2507,8 +2507,9 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
25072507
25082508 for (Operation *op : dynamicSizeProducers.back ()) {
25092509 if (op->getNumResults () == 1 &&
2510- isa<IndexType>(op->getResult (0 ).getType ()))
2510+ isa<IndexType>(op->getResult (0 ).getType ())) {
25112511 continue ;
2512+ }
25122513
25132514 DiagnosedSilenceableFailure diag =
25142515 emitSilenceableError () << " expected sizes to be produced by ops "
@@ -2525,11 +2526,10 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
25252526 auto scalableSizes = getScalableSizes ();
25262527 for (auto [i, op] : llvm::enumerate (targets)) {
25272528 auto tilingInterface = dyn_cast<TilingInterface>(op);
2528- auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(op);
2529- if (!tilingInterface || !dpsInterface) {
2529+ if (!tilingInterface) {
25302530 DiagnosedSilenceableFailure diag =
2531- emitSilenceableError () << " only ops implementing TilingInterface and "
2532- " DestinationStyleOpInterface are supported" ;
2531+ emitSilenceableError ()
2532+ << " only ops implementing TilingInterface are supported" ;
25332533 diag.attachNote (op->getLoc ()) << " target op" ;
25342534 return diag;
25352535 }
@@ -2578,10 +2578,7 @@ transform::TileOp::apply(transform::TransformRewriter &rewriter,
25782578 if (failed (maybeTilingResult))
25792579 return DiagnosedSilenceableFailure::definiteFailure ();
25802580
2581- if (dpsInterface.hasBufferSemantics ())
2582- rewriter.eraseOp (op);
2583- else
2584- rewriter.replaceOp (op, maybeTilingResult->loops .front ()->getResults ());
2581+ rewriter.replaceOp (op, maybeTilingResult->replacements );
25852582
25862583 tiled.append (maybeTilingResult->tiledOps );
25872584 for (const auto &en2 : llvm::enumerate (maybeTilingResult->loops ))
@@ -2895,204 +2892,6 @@ LogicalResult TileToForallOp::verify() {
28952892 return success ();
28962893}
28972894
2898- // ===----------------------------------------------------------------------===//
2899- // TileToScfForOp
2900- // ===----------------------------------------------------------------------===//
2901-
2902- void transform::TileToScfForOp::build (OpBuilder &builder,
2903- OperationState &result, Value target,
2904- ArrayRef<OpFoldResult> mixedTileSizes,
2905- ArrayRef<int64_t > interchange) {
2906- SmallVector<int64_t > staticTileSizes;
2907- SmallVector<Value> dynamicTileSizes;
2908- dispatchIndexOpFoldResults (mixedTileSizes, dynamicTileSizes, staticTileSizes);
2909- // Call the default builder which sets up the proper operands segment sizes
2910- // attributes for multiple variadic operands. In the absence of this,
2911- // horrible bugs ensue.
2912- auto staticTileSizesAttr = builder.getDenseI64ArrayAttr (staticTileSizes);
2913- int64_t numExpectedLoops =
2914- staticTileSizes.size () - llvm::count (staticTileSizes, 0 );
2915- SmallVector<Type> resultTypes (
2916- numExpectedLoops, transform::AnyOpType::get (builder.getContext ()));
2917- build (builder, result,
2918- /* tiled_linalg_op=*/ target.getType (),
2919- /* loops=*/ resultTypes,
2920- /* target=*/ target,
2921- /* dynamic_sizes=*/ dynamicTileSizes,
2922- /* static_sizes=*/ staticTileSizesAttr,
2923- /* interchange=*/ builder.getDenseI64ArrayAttr (interchange));
2924- }
2925-
2926- DiagnosedSilenceableFailure
2927- transform::TileToScfForOp::apply (transform::TransformRewriter &rewriter,
2928- TransformResults &transformResults,
2929- TransformState &state) {
2930- ArrayRef<int64_t > tileSizes = getStaticSizes ();
2931-
2932- SmallVector<Operation *> targets =
2933- llvm::to_vector (state.getPayloadOps (getTarget ()));
2934- SmallVector<SmallVector<Operation *>> dynamicSizeProducers;
2935- dynamicSizeProducers.reserve (getDynamicSizes ().size ());
2936- for (Value dynamicSizeProducerHandle : getDynamicSizes ()) {
2937- dynamicSizeProducers.push_back (
2938- llvm::to_vector (state.getPayloadOps (dynamicSizeProducerHandle)));
2939-
2940- if (dynamicSizeProducers.back ().size () != targets.size ()) {
2941- DiagnosedSilenceableFailure diag =
2942- emitSilenceableError ()
2943- << " expected as many dynamic size-producing operations ("
2944- << dynamicSizeProducers.back ().size () << " ) as target ops ("
2945- << targets.size () << " )" ;
2946- diag.attachNote (dynamicSizeProducerHandle.getLoc ()) << " for this handle" ;
2947- return diag;
2948- }
2949-
2950- for (Operation *op : dynamicSizeProducers.back ()) {
2951- if (op->getNumResults () == 1 &&
2952- isa<IndexType>(op->getResult (0 ).getType ()))
2953- continue ;
2954- DiagnosedSilenceableFailure diag =
2955- emitSilenceableError () << " expected sizes to be produced by ops "
2956- " with a single index-type result" ;
2957- diag.attachNote (op->getLoc ()) << " size producer op" ;
2958- diag.attachNote (dynamicSizeProducerHandle.getLoc ()) << " for this handle" ;
2959- return diag;
2960- }
2961- }
2962-
2963- SmallVector<Operation *> tiled;
2964- SmallVector<SmallVector<Operation *, 4 >, 4 > loops;
2965- loops.resize (getLoops ().size ());
2966- for (auto en : llvm::enumerate (targets)) {
2967- auto tilingInterfaceOp = dyn_cast<TilingInterface>(en.value ());
2968- if (!tilingInterfaceOp) {
2969- DiagnosedSilenceableFailure diag =
2970- emitSilenceableError () << " only TilingInterface ops are supported" ;
2971- diag.attachNote (en.value ()->getLoc ()) << " target op" ;
2972- return diag;
2973- }
2974-
2975- scf::SCFTilingOptions tilingOptions;
2976- unsigned index = en.index ();
2977- if (!tileSizes.empty ()) {
2978- tilingOptions.setTileSizeComputationFunction (
2979- [&, index](OpBuilder &b, Operation *) {
2980- SmallVector<Value, 4 > sizes;
2981- sizes.reserve (tileSizes.size ());
2982- unsigned dynamicIdx = 0 ;
2983- for (OpFoldResult ofr : getMixedSizes ()) {
2984- if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) {
2985- sizes.push_back (b.create <arith::ConstantIndexOp>(
2986- getLoc (), cast<IntegerAttr>(attr).getInt ()));
2987- } else {
2988- sizes.push_back (
2989- dynamicSizeProducers[dynamicIdx++][index]->getResult (0 ));
2990- }
2991- }
2992- return sizes;
2993- });
2994- }
2995-
2996- tilingOptions.setInterchange (getInterchange ());
2997- FailureOr<scf::SCFTilingResult> tilingResult =
2998- tileUsingSCFForOp (rewriter, tilingInterfaceOp, tilingOptions);
2999- if (failed (tilingResult))
3000- return DiagnosedSilenceableFailure::definiteFailure ();
3001-
3002- rewriter.replaceOp (tilingInterfaceOp, tilingResult->replacements );
3003-
3004- tiled.append (tilingResult->tiledOps );
3005- for (const auto &en2 : llvm::enumerate (tilingResult->loops ))
3006- loops[en2.index ()].push_back (en2.value ());
3007- }
3008-
3009- transformResults.set (cast<OpResult>(getTiledLinalgOp ()), tiled);
3010- for (const auto &en : llvm::enumerate (loops))
3011- transformResults.set (cast<OpResult>(getLoops ()[en.index ()]), en.value ());
3012-
3013- return DiagnosedSilenceableFailure::success ();
3014- }
3015-
3016- SmallVector<OpFoldResult> transform::TileToScfForOp::getMixedSizes () {
3017- ValueRange dynamic = getDynamicSizes ();
3018- ArrayRef<int64_t > tileSizes = getStaticSizes ();
3019- SmallVector<OpFoldResult> results;
3020- results.reserve (tileSizes.size ());
3021- unsigned dynamicPos = 0 ;
3022- Builder builder (getContext ());
3023- for (int64_t size : tileSizes) {
3024- if (size == ShapedType::kDynamic ) {
3025- results.push_back (dynamic[dynamicPos++]);
3026- } else {
3027- results.push_back (builder.getIndexAttr (size));
3028- }
3029- }
3030- return results;
3031- }
3032-
3033- ParseResult transform::TileToScfForOp::parse (OpAsmParser &parser,
3034- OperationState &result) {
3035- OpAsmParser::UnresolvedOperand target;
3036- SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
3037- DenseI64ArrayAttr staticSizes;
3038- FunctionType trailingType;
3039- llvm::SMLoc typeLoc;
3040- if (parser.parseOperand (target) ||
3041- parseDynamicIndexList (parser, dynamicSizes, staticSizes) ||
3042- parseOptionalInterchange (parser, result) ||
3043- parser.parseOptionalAttrDict (result.attributes ) ||
3044- parser.getCurrentLocation (&typeLoc) ||
3045- parser.parseColonType (trailingType)) {
3046- return ParseResult::failure ();
3047- }
3048-
3049- result.addAttribute (getStaticSizesAttrName (result.name ), staticSizes);
3050- size_t numExpectedLoops =
3051- staticSizes.size () - llvm::count (staticSizes.asArrayRef (), 0 );
3052-
3053- unsigned numExpectedInputTypes = 1 + dynamicSizes.size ();
3054- if (trailingType.getNumInputs () != numExpectedInputTypes) {
3055- return parser.emitError (typeLoc)
3056- << " expected " << numExpectedInputTypes << " operand types, got "
3057- << trailingType.getNumInputs ();
3058- }
3059-
3060- unsigned numExpectedOutputTypes = 1 + numExpectedLoops;
3061- if (trailingType.getNumResults () != numExpectedOutputTypes) {
3062- return parser.emitError (typeLoc)
3063- << " expected " << numExpectedOutputTypes << " result types, got "
3064- << trailingType.getNumResults ();
3065- }
3066-
3067- if (parser.resolveOperand (target, trailingType.getInput (0 ),
3068- result.operands ) ||
3069- parser.resolveOperands (dynamicSizes,
3070- trailingType.getInputs ().drop_front (), typeLoc,
3071- result.operands ) ||
3072- parser.addTypesToList (trailingType.getResults (), result.types )) {
3073- return failure ();
3074- }
3075- return success ();
3076- }
3077-
3078- void TileToScfForOp::print (OpAsmPrinter &p) {
3079- p << ' ' << getTarget ();
3080- printDynamicIndexList (p, getOperation (), getDynamicSizes (), getStaticSizes ());
3081- printOptionalInterchange (p, getInterchange ());
3082- p.printOptionalAttrDict (getOperation ()->getAttrs (), getAttributeNames ());
3083- p << " : " ;
3084- p.printFunctionalType (getOperation ());
3085- }
3086-
3087- void transform::TileToScfForOp::getEffects (
3088- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
3089- consumesHandle (getTarget (), effects);
3090- onlyReadsHandle (getDynamicSizes (), effects);
3091- producesHandle (getTiledLinalgOp (), effects);
3092- producesHandle (getLoops (), effects);
3093- modifiesPayload (effects);
3094- }
3095-
30962895// ===----------------------------------------------------------------------===//
30972896// VectorizeOp
30982897// ===----------------------------------------------------------------------===//
0 commit comments