@@ -1548,7 +1548,6 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
15481548class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
15491549 MemRef_Op<mnemonic, !listconcat(traits,
15501550 [Pure, ViewLikeOpInterface])>,
1551- Arguments<(ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation)>,
15521551 Results<(outs AnyStridedMemRef:$result)>{
15531552
15541553 code commonExtraClassDeclaration = [{
@@ -1573,10 +1572,6 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
15731572 Value getViewSource() { return getSrc(); }
15741573 }];
15751574
1576- let assemblyFormat = [{
1577- $src $reassociation attr-dict `:` type($src) `into` type($result)
1578- }];
1579-
15801575 let hasFolder = 1;
15811576 let hasCanonicalizer = 1;
15821577 let hasVerifier = 1;
@@ -1598,14 +1593,10 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
15981593 Example:
15991594
16001595 ```mlir
1601- %r = memref.expand_shape %0 [[0, 1], [2]]
1602- : memref<?x?xf32 > into memref<?x5x?xf32 >
1596+ %r = memref.expand_shape %0 [[0, 1], [2]] output_shape [%sz0, %sz1, 32]
1597+ : memref<?x32xf32 > into memref<?x?x32xf32 >
16031598 ```
16041599
1605- At most one dimension of a reassociation group (e.g., [0, 1] above) may be
1606- dynamic in the result type. Otherwise, the op would be ambiguous, as it
1607- would not be clear how the source dimension is extended.
1608-
16091600 If an op can be statically proven to be invalid (e.g, an expansion from
16101601 `memref<10xf32>` to `memref<2x6xf32>`), it is rejected by the verifier. If
16111602 it cannot statically be proven invalid (e.g., the full example above; it is
@@ -1622,41 +1613,80 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
16221613 there must be a dynamic result dimension in the corresponding reassociation
16231614 group. Same for strides.
16241615
1616+ The representation for the output shape supports a partially-static
1617+ specification via attributes specified through the `static_output_shape`
1618+ argument. A special sentinel value `ShapedType::kDynamic` encodes that the
1619+ corresponding entry has a dynamic value. There must be exactly as many SSA
1620+ inputs in `output_shape` as there are `ShapedType::kDynamic` entries in
1621+ `static_output_shape`.
1622+
16251623 Note: This op currently assumes that the inner strides are of the
16261624 source/result layout map are the faster-varying ones.
16271625 }];
16281626
1627+ let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation,
1628+ Variadic<Index>:$output_shape,
1629+ DenseI64ArrayAttr:$static_output_shape);
1630+
1631+ let assemblyFormat = [{
1632+ $src $reassociation `output_shape`
1633+ custom<DynamicIndexList>($output_shape, $static_output_shape) attr-dict `:`
1634+ type($src) `into` type($result)
1635+ }];
1636+
16291637 let builders = [
16301638 // Builders using ReassociationIndices.
16311639 OpBuilder<(ins "Type":$resultType, "Value":$src,
16321640 "ArrayRef<ReassociationIndices>":$reassociation,
1633- CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
1641+ "ArrayRef<OpFoldResult>":$outputShape)>,
1642+
1643+ // It will infer output shape using inferOutputShape() method.
1644+ OpBuilder<(ins "Type":$resultType, "Value":$src,
1645+ "ArrayRef<ReassociationIndices>":$reassociation)>,
1646+
1647+ // Builder using ReassociationExprs.
1648+ OpBuilder<(ins "Type":$resultType, "Value":$src,
1649+ "ArrayRef<ReassociationExprs>":$reassociation),
16341650 [{
1635- build($_builder, $_state, resultType, src, attrs);
1636- $_state.addAttribute(" reassociation",
1637- getReassociationIndicesAttribute ($_builder, reassociation) );
1651+ auto reassociationIndices =
1652+ convertReassociationMapsToIndices( reassociation);
1653+ build ($_builder, $_state, resultType, src, reassociationIndices );
16381654 }]>,
16391655
1640- // Builder using ReassociationExprs.
16411656 OpBuilder<(ins "Type":$resultType, "Value":$src,
16421657 "ArrayRef<ReassociationExprs>":$reassociation,
1643- CArg< "ArrayRef<NamedAttribute>", "{}">:$attrs ),
1658+ "ArrayRef<OpFoldResult>":$outputShape ),
16441659 [{
16451660 auto reassociationMaps =
1646- convertReassociationMapsToIndices($_builder, reassociation);
1647- build($_builder, $_state, resultType, src, reassociationMaps, attrs);
1661+ convertReassociationMapsToIndices(reassociation);
1662+ build($_builder, $_state, resultType, src, reassociationMaps,
1663+ outputShape);
16481664 }]>,
16491665
1666+ // Builder that infers the result layout map. The result shape must be
1667+ // specified. Otherwise, the op may be ambiguous. The output shape for
1668+ // the op will be inferred using the inferOutputShape() method.
1669+ OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
1670+ "ArrayRef<ReassociationIndices>":$reassociation)>,
1671+
16501672 // Builder that infers the result layout map. The result shape must be
16511673 // specified. Otherwise, the op may be ambiguous.
16521674 OpBuilder<(ins "ArrayRef<int64_t>":$resultShape, "Value":$src,
1653- "ArrayRef<ReassociationIndices>":$reassociation)>
1675+ "ArrayRef<ReassociationIndices>":$reassociation,
1676+ "ArrayRef<OpFoldResult>":$outputShape)>
16541677 ];
16551678
16561679 let extraClassDeclaration = commonExtraClassDeclaration # [{
16571680 static FailureOr<MemRefType> computeExpandedType(
16581681 MemRefType srcType, ArrayRef<int64_t> resultShape,
16591682 ArrayRef<ReassociationIndices> reassociation);
1683+
1684+ // Infer the output shape for a memref.expand_shape when it is possible
1685+ // to do so.
1686+ static FailureOr<SmallVector<OpFoldResult>> inferOutputShape(
1687+ OpBuilder &b, Location loc, MemRefType expandedType,
1688+ ArrayRef<ReassociationIndices> reassociation,
1689+ ArrayRef<OpFoldResult> inputShape);
16601690 }];
16611691
16621692 let hasVerifier = 1;
@@ -1707,6 +1737,12 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17071737 source/result layout map are the faster-varying ones.
17081738 }];
17091739
1740+ let arguments = (ins AnyStridedMemRef:$src, IndexListArrayAttr:$reassociation);
1741+
1742+ let assemblyFormat = [{
1743+ $src $reassociation attr-dict `:` type($src) `into` type($result)
1744+ }];
1745+
17101746 let builders = [
17111747 // Builders for a contracting reshape whose result type is computed from
17121748 // `src` and `reassociation`.
@@ -1718,7 +1754,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17181754 CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
17191755 [{
17201756 auto reassociationMaps =
1721- convertReassociationMapsToIndices($_builder, reassociation);
1757+ convertReassociationMapsToIndices(reassociation);
17221758 build($_builder, $_state, src, reassociationMaps, attrs);
17231759 }]>,
17241760
@@ -1736,7 +1772,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape", [
17361772 CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs),
17371773 [{
17381774 auto reassociationMaps =
1739- convertReassociationMapsToIndices($_builder, reassociation);
1775+ convertReassociationMapsToIndices(reassociation);
17401776 build($_builder, $_state, resultType, src, reassociationMaps, attrs);
17411777 }]>
17421778 ];
0 commit comments