@@ -2120,6 +2120,120 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
21202120 const bool aggressiveReduceConstant;
21212121};
21222122
2123+ template <typename ElementStorageType>
2124+ DenseElementsAttr
2125+ concatenateAttrs (const ShapedType outputType, ArrayRef<ElementsAttr> inputAttrs,
2126+ const uint32_t concatAxis, PatternRewriter &rewriter,
2127+ const Type elementType) {
2128+
2129+ static_assert (std::is_same<ElementStorageType, APInt>::value ||
2130+ std::is_same<ElementStorageType, APFloat>::value,
2131+ " ElementStorageType must be either APInt or APFloat" );
2132+
2133+ SmallVector<ElementStorageType> resultValues;
2134+ if constexpr (std::is_same<ElementStorageType, APInt>::value) {
2135+ resultValues.resize_for_overwrite (outputType.getNumElements ());
2136+ } else {
2137+ resultValues.resize (
2138+ outputType.getNumElements (),
2139+ APFloat::getZero (cast<FloatType>(elementType).getFloatSemantics ()));
2140+ }
2141+ const auto outputShape = outputType.getShape ();
2142+
2143+ int64_t concatDimOffset = 0 ;
2144+ for (const auto &inputAttr : inputAttrs) {
2145+ const auto inputShape = cast<ShapedType>(inputAttr.getType ()).getShape ();
2146+ const auto inputValues = inputAttr.getValues <ElementStorageType>();
2147+
2148+ for (const auto &[inputLinearIdx, val] : llvm::enumerate (inputValues)) {
2149+ // TODO: Could be optimized to work on slices instead of single value
2150+ SmallVector<int64_t > multiDimIndex =
2151+ offsetToIndex (inputShape, inputLinearIdx);
2152+ multiDimIndex[concatAxis] += concatDimOffset;
2153+
2154+ const int64_t outputLinearIndex =
2155+ indexToOffset (outputShape, multiDimIndex);
2156+ resultValues[outputLinearIndex] = val;
2157+ }
2158+ concatDimOffset += inputShape[concatAxis];
2159+ }
2160+ return DenseElementsAttr::get (outputType, resultValues);
2161+ }
2162+
2163+ struct TosaFoldConstantConcat : public TosaFoldConstantBase <tosa::ConcatOp> {
2164+ using TosaFoldConstantBase::TosaFoldConstantBase;
2165+
2166+ LogicalResult matchAndRewrite (tosa::ConcatOp op,
2167+ PatternRewriter &rewriter) const override {
2168+ auto inputs = op->getOperands ();
2169+ const uint32_t concatAxis = op.getAxis ();
2170+ const auto outputType = cast<ShapedType>(op.getType ());
2171+ if (!outputType.hasStaticShape ()) {
2172+ return rewriter.notifyMatchFailure (
2173+ op, " Output type must have static shape for concat folding." );
2174+ }
2175+ if (llvm::any_of (inputs, [](Value v) {
2176+ return !cast<ShapedType>(v.getType ()).hasStaticShape ();
2177+ })) {
2178+ return rewriter.notifyMatchFailure (
2179+ op, " All inputs to ConcatOp must have static shape for folding." );
2180+ }
2181+
2182+ const Type elementType = outputType.getElementType ();
2183+ if (!elementType.isIntOrIndexOrFloat ()) {
2184+ // Sanity check, this should always be the case
2185+ return rewriter.notifyMatchFailure (
2186+ op, " Output element type must be int, index, or float for folding." );
2187+ }
2188+
2189+ SmallVector<ElementsAttr> inputAttrs;
2190+ inputAttrs.reserve (inputs.size ());
2191+
2192+ for (Value inputVal : inputs) {
2193+ ElementsAttr inputAsAttr;
2194+ if (!matchPattern (inputVal, m_Constant (&inputAsAttr))) {
2195+ // TODO: This could be extended to handle partial non-const inputs
2196+ return rewriter.notifyMatchFailure (
2197+ op, " All inputs to ConcatOp must be constant for folding." );
2198+ }
2199+
2200+ if (inputAsAttr.isSplat ()) {
2201+ const ShapedType inputType = cast<ShapedType>(inputAsAttr.getType ());
2202+ if (isa<IntegerType>(elementType)) {
2203+ inputAsAttr = DenseElementsAttr::get (
2204+ inputType, inputAsAttr.getSplatValue <APInt>());
2205+ } else {
2206+ inputAsAttr = DenseElementsAttr::get (
2207+ inputType, inputAsAttr.getSplatValue <APFloat>());
2208+ }
2209+ }
2210+ if (foldSplatOrSingleUseOnly && !inputVal.hasOneUse () &&
2211+ !inputAsAttr.isSplat ()) {
2212+ return rewriter.notifyMatchFailure (
2213+ op, " Concat folding heuristic: non-splat constant inputs must have "
2214+ " only a single use." );
2215+ }
2216+ inputAttrs.push_back (inputAsAttr);
2217+ }
2218+
2219+ DenseElementsAttr resultAttr;
2220+ if (auto intType = dyn_cast<IntegerType>(elementType)) {
2221+ // TODO: This could be optimized to not go to APInt if the int size
2222+ // matches c++ native types
2223+ resultAttr = concatenateAttrs<APInt>(outputType, inputAttrs, concatAxis,
2224+ rewriter, elementType);
2225+ } else {
2226+ resultAttr = concatenateAttrs<APFloat>(outputType, inputAttrs, concatAxis,
2227+ rewriter, elementType);
2228+ }
2229+
2230+ assert (resultAttr && " Result attribute should not be null." );
2231+
2232+ rewriter.replaceOpWithNewOp <tosa::ConstOp>(op, outputType, resultAttr);
2233+ return success ();
2234+ }
2235+ };
2236+
21232237} // namespace
21242238
21252239void mlir::tosa::populateTosaFoldConstantPatterns (
@@ -2167,6 +2281,7 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
21672281 patterns.add <TosaFoldConstantPad>(ctx, options.foldSplatOrSingleUseOnly );
21682282 patterns.add <TosaFoldConstantSlice>(ctx, options.foldSplatOrSingleUseOnly );
21692283 patterns.add <TosaFoldConstantMatMul>(ctx, options.foldSplatOrSingleUseOnly );
2284+ patterns.add <TosaFoldConstantConcat>(ctx, options.foldSplatOrSingleUseOnly );
21702285 if (options.enableTileFolding )
21712286 patterns.add <TosaFoldConstantTile>(ctx, options.foldSplatOrSingleUseOnly );
21722287}
0 commit comments