99#include " mlir/Dialect/Linalg/EDSC/Builders.h"
1010#include " mlir/Dialect/Linalg/EDSC/Intrinsics.h"
1111#include " mlir/Dialect/Linalg/IR/LinalgOps.h"
12+ #include " mlir/Dialect/Utils/StructuredOpsUtils.h"
1213#include " mlir/EDSC/Builders.h"
1314#include " mlir/EDSC/Intrinsics.h"
1415#include " mlir/IR/AffineExpr.h"
@@ -144,7 +145,7 @@ static void getMaxDimIndex(ArrayRef<StructuredIndexed> structuredIndices,
144145}
145146
146147Operation *mlir::edsc::makeGenericLinalgOp (
147- ArrayRef<IterType > iteratorTypes, ArrayRef<StructuredIndexed> inputs,
148+ ArrayRef<IteratorType > iteratorTypes, ArrayRef<StructuredIndexed> inputs,
148149 ArrayRef<StructuredIndexed> outputs,
149150 function_ref<void (ArrayRef<BlockArgument>)> regionBuilder,
150151 ArrayRef<Value> otherValues, ArrayRef<Attribute> otherAttributes) {
@@ -240,8 +241,8 @@ void mlir::edsc::ops::macRegionBuilder(ArrayRef<BlockArgument> args) {
240241Operation *mlir::edsc::ops::linalg_pointwise (UnaryPointwiseOpBuilder unaryOp,
241242 StructuredIndexed I,
242243 StructuredIndexed O) {
243- SmallVector<edsc::IterType , 4 > iterTypes (O.getExprs ().size (),
244- edsc::IterType ::Parallel);
244+ SmallVector<IteratorType , 4 > iterTypes (O.getExprs ().size (),
245+ IteratorType ::Parallel);
245246 if (O.getType ().isa <RankedTensorType>()) {
246247 auto fun = [&unaryOp](ArrayRef<BlockArgument> args) {
247248 assert (args.size () == 1 && " expected 1 block arguments" );
@@ -270,8 +271,8 @@ Operation *mlir::edsc::ops::linalg_pointwise(BinaryPointwiseOpBuilder binaryOp,
270271 StructuredIndexed I1,
271272 StructuredIndexed I2,
272273 StructuredIndexed O) {
273- SmallVector<edsc::IterType , 4 > iterTypes (O.getExprs ().size (),
274- edsc::IterType ::Parallel);
274+ SmallVector<IteratorType , 4 > iterTypes (O.getExprs ().size (),
275+ IteratorType ::Parallel);
275276 if (O.getType ().isa <RankedTensorType>()) {
276277 auto fun = [&binaryOp](ArrayRef<BlockArgument> args) {
277278 assert (args.size () == 2 && " expected 2 block arguments" );
@@ -315,7 +316,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
315316 bindDims (ScopedContext::getContext (), m, n, k);
316317 StructuredIndexed A (vA), B (vB), C (vC);
317318 return makeGenericLinalgOp (
318- {IterType ::Parallel, IterType ::Parallel, IterType ::Reduction},
319+ {IteratorType ::Parallel, IteratorType ::Parallel, IteratorType ::Reduction},
319320 {A ({m, k}), B ({k, n})},
320321 {C ({m, n})},
321322 macRegionBuilder);
@@ -329,7 +330,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
329330 bindDims (ScopedContext::getContext (), m, n, k);
330331 StructuredIndexed A (vA), B (vB), C (tC);
331332 return makeGenericLinalgOp (
332- {IterType ::Parallel, IterType ::Parallel, IterType ::Reduction},
333+ {IteratorType ::Parallel, IteratorType ::Parallel, IteratorType ::Reduction},
333334 {A ({m, k}), B ({k, n})},
334335 {C ({m, n})},
335336 mulRegionBuilder);
@@ -343,7 +344,7 @@ Operation *mlir::edsc::ops::linalg_matmul(ValueHandle vA, ValueHandle vB,
343344 bindDims (ScopedContext::getContext (), m, n, k);
344345 StructuredIndexed A (vA), B (vB), C (vC), D (tD);
345346 return makeGenericLinalgOp (
346- {IterType ::Parallel, IterType ::Parallel, IterType ::Reduction},
347+ {IteratorType ::Parallel, IteratorType ::Parallel, IteratorType ::Reduction},
347348 {A ({m, k}), B ({k, n}), C ({m, n})},
348349 {D ({m, n})},
349350 macRegionBuilder);
@@ -360,8 +361,8 @@ Operation *mlir::edsc::ops::linalg_conv_nhwc(ValueHandle vI, ValueHandle vW,
360361 assert ((strides.empty () || strides.size () == 2 ) && " only 2-D conv atm" );
361362
362363 // Some short names.
363- auto par = IterType ::Parallel;
364- auto red = IterType ::Reduction;
364+ auto par = IteratorType ::Parallel;
365+ auto red = IteratorType ::Reduction;
365366 auto s = strides;
366367 auto d = dilations;
367368
@@ -393,8 +394,8 @@ Operation *mlir::edsc::ops::linalg_dilated_conv_nhwc(
393394 assert ((strides.empty () || strides.size () == 2 ) && " only 2-D conv atm" );
394395
395396 // Some short names.
396- auto par = IterType ::Parallel;
397- auto red = IterType ::Reduction;
397+ auto par = IteratorType ::Parallel;
398+ auto red = IteratorType ::Reduction;
398399 auto s = strides;
399400 auto d = dilations;
400401
0 commit comments