|
13 | 13 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
14 | 14 | #include "mlir/Dialect/Arith/IR/Arith.h" |
15 | 15 | #include "mlir/Dialect/Complex/IR/Complex.h" |
| 16 | +#include "mlir/Dialect/Utils/StaticValueUtils.h" |
16 | 17 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
17 | 18 | #include "llvm/ADT/SmallBitVector.h" |
18 | 19 | #include <numeric> |
19 | 20 |
|
20 | 21 | using namespace mlir; |
21 | 22 |
|
| 23 | +std::optional<SmallVector<OpFoldResult>> |
| 24 | +mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc, |
| 25 | + ShapedType expandedType, |
| 26 | + ArrayRef<ReassociationIndices> reassociation, |
| 27 | + ArrayRef<OpFoldResult> inputShape) { |
| 28 | + |
| 29 | + SmallVector<Value> outputShapeValues; |
| 30 | + SmallVector<int64_t> outputShapeInts; |
| 31 | + // For zero-rank inputs, all dims in result shape are unit extent. |
| 32 | + if (inputShape.empty()) { |
| 33 | + outputShapeInts.resize(expandedType.getRank(), 1); |
| 34 | + return getMixedValues(outputShapeInts, outputShapeValues, b); |
| 35 | + } |
| 36 | + |
| 37 | + // Check for all static shapes. |
| 38 | + if (expandedType.hasStaticShape()) { |
| 39 | + ArrayRef<int64_t> staticShape = expandedType.getShape(); |
| 40 | + outputShapeInts.assign(staticShape.begin(), staticShape.end()); |
| 41 | + return getMixedValues(outputShapeInts, outputShapeValues, b); |
| 42 | + } |
| 43 | + |
| 44 | + outputShapeInts.resize(expandedType.getRank(), ShapedType::kDynamic); |
| 45 | + for (const auto &it : llvm::enumerate(reassociation)) { |
| 46 | + ReassociationIndices indexGroup = it.value(); |
| 47 | + |
| 48 | + int64_t indexGroupStaticSizesProductInt = 1; |
| 49 | + bool foundDynamicShape = false; |
| 50 | + for (int64_t index : indexGroup) { |
| 51 | + int64_t outputDimSize = expandedType.getDimSize(index); |
| 52 | + // Cannot infer expanded shape with multiple dynamic dims in the |
| 53 | + // same reassociation group! |
| 54 | + if (ShapedType::isDynamic(outputDimSize)) { |
| 55 | + if (foundDynamicShape) |
| 56 | + return std::nullopt; |
| 57 | + foundDynamicShape = true; |
| 58 | + } else { |
| 59 | + outputShapeInts[index] = outputDimSize; |
| 60 | + indexGroupStaticSizesProductInt *= outputDimSize; |
| 61 | + } |
| 62 | + } |
| 63 | + if (!foundDynamicShape) |
| 64 | + continue; |
| 65 | + |
| 66 | + int64_t inputIndex = it.index(); |
| 67 | + // Call get<Value>() under the assumption that we're not casting |
| 68 | + // dynamism. |
| 69 | + Value indexGroupSize = inputShape[inputIndex].get<Value>(); |
| 70 | + Value indexGroupStaticSizesProduct = |
| 71 | + b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt); |
| 72 | + Value dynamicDimSize = b.createOrFold<arith::DivUIOp>( |
| 73 | + loc, indexGroupSize, indexGroupStaticSizesProduct); |
| 74 | + outputShapeValues.push_back(dynamicDimSize); |
| 75 | + } |
| 76 | + |
| 77 | + if ((int64_t)outputShapeValues.size() != |
| 78 | + llvm::count(outputShapeInts, ShapedType::kDynamic)) |
| 79 | + return std::nullopt; |
| 80 | + |
| 81 | + return getMixedValues(outputShapeInts, outputShapeValues, b); |
| 82 | +} |
| 83 | + |
22 | 84 | /// Matches a ConstantIndexOp. |
23 | 85 | /// TODO: This should probably just be a general matcher that uses matchConstant |
24 | 86 | /// and checks the operation for an index type. |
|
0 commit comments