@@ -82,6 +82,75 @@ class IfOpConverter : public OpRewritePattern<tosa::IfOp> {
8282 }
8383};
8484
85+ class ScatterOpConverter : public OpRewritePattern <tosa::ScatterOp> {
86+ static Value createTensorDim (OpBuilder &builder, Location loc, Value tensor,
87+ int64_t dim) {
88+ return builder.createOrFold <tensor::DimOp>(loc, tensor, dim);
89+ }
90+
91+ static Value createIndexConst (OpBuilder &builder, Location loc,
92+ int64_t value) {
93+ return builder.create <arith::ConstantIndexOp>(loc, value);
94+ }
95+
96+ public:
97+ using OpRewritePattern<tosa::ScatterOp>::OpRewritePattern;
98+
99+ LogicalResult matchAndRewrite (tosa::ScatterOp scatter,
100+ PatternRewriter &rewriter) const final {
101+ auto valuesIn = scatter.getValuesIn ();
102+ auto indices = scatter.getIndices ();
103+ auto input = scatter.getInput ();
104+ auto loc = scatter.getLoc ();
105+
106+ // N, W, C are chosen to match the TOSA spec
107+ auto dimN = createTensorDim (rewriter, loc, input, 0 );
108+ auto dimW = createTensorDim (rewriter, loc, input, 1 );
109+ auto dimC = createTensorDim (rewriter, loc, input, 2 );
110+
111+ auto zero = createIndexConst (rewriter, loc, 0 );
112+ auto one = createIndexConst (rewriter, loc, 1 );
113+
114+ // Loop bounds
115+ auto lbs = llvm::SmallVector<Value>(2 , zero);
116+ auto steps = llvm::SmallVector<Value>(2 , one);
117+ auto ubs = llvm::SmallVector<Value>{{dimN, dimW}};
118+
119+ auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs,
120+ ValueRange args) -> scf::ValueVector {
121+ auto n = ivs[0 ];
122+
123+ // Read the index and cast it to index type
124+ auto index = builder.create <tensor::ExtractOp>(loc, indices, ivs);
125+ auto castIndex = builder.create <arith::IndexCastOp>(
126+ loc, builder.getIndexType (), index);
127+
128+ // Offset, sizes, and strides for the input tensor
129+ auto inputOffset = llvm::to_vector (ivs);
130+ inputOffset.push_back (zero);
131+
132+ llvm::SmallVector<Value> sizes = {one, one, dimC};
133+ llvm::SmallVector<Value> strides = {one, one, one};
134+
135+ auto slice = builder.create <tensor::ExtractSliceOp>(
136+ loc, input, inputOffset, sizes, strides);
137+
138+ // Insert the slice into the output accumulator tensor.
139+ llvm::SmallVector<Value> outputOffset = {n, castIndex, zero};
140+ auto updated = builder.create <tensor::InsertSliceOp>(
141+ loc, slice, args[0 ], outputOffset, sizes, strides);
142+
143+ return {updated};
144+ };
145+
146+ auto loops = scf::buildLoopNest (rewriter, loc, lbs, ubs, steps,
147+ ValueRange{valuesIn}, buildBody);
148+ rewriter.replaceOp (scatter, loops.results );
149+
150+ return success ();
151+ }
152+ };
153+
85154class WhileOpConverter : public OpRewritePattern <tosa::WhileOp> {
86155public:
87156 using OpRewritePattern<tosa::WhileOp>::OpRewritePattern;
@@ -106,6 +175,6 @@ class WhileOpConverter : public OpRewritePattern<tosa::WhileOp> {
106175
107176void mlir::tosa::populateTosaToSCFConversionPatterns (
108177 RewritePatternSet *patterns) {
109- patterns->add <IfOpConverter>(patterns-> getContext ());
110- patterns-> add <WhileOpConverter>( patterns->getContext ());
178+ patterns->add <IfOpConverter, ScatterOpConverter, WhileOpConverter>(
179+ patterns->getContext ());
111180}
0 commit comments