@@ -186,56 +186,63 @@ static Value createLinalgBodyCalculationForElementwiseOp(
186186 if (isa<tosa::NegateOp>(op)) {
187187 auto negate = cast<tosa::NegateOp>(op);
188188
189+ int64_t inZp = 0 , outZp = 0 ;
189190 FailureOr<int64_t > maybeInZp = negate.getInput1ZeroPoint ();
190- if (failed (maybeInZp)) {
191- (void )rewriter.notifyMatchFailure (
192- op, " input1 zero point cannot be statically determined" );
193- return nullptr ;
194- }
195-
196191 FailureOr<int64_t > maybeOutZp = negate.getOutputZeroPoint ();
197- if (failed (maybeOutZp)) {
198- (void )rewriter.notifyMatchFailure (
199- op, " output zero point cannot be statically determined" );
200- return nullptr ;
201- }
202-
203- int64_t inZp = *maybeInZp;
204- int64_t outZp = *maybeOutZp;
192+ bool hasInZp = !failed (maybeInZp);
193+ bool hasOutZp = !failed (maybeOutZp);
194+ if (hasInZp)
195+ inZp = *maybeInZp;
196+ if (hasOutZp)
197+ outZp = *maybeOutZp;
205198
206199 if (isa<FloatType>(elementTy))
207200 return arith::NegFOp::create (rewriter, loc, resultTypes, args[0 ]);
208201
209202 if (isa<IntegerType>(elementTy)) {
210- if (!inZp && !outZp) {
203+ if (hasInZp && hasOutZp && !inZp && !outZp) {
211204 auto constant = arith::ConstantOp::create (
212205 rewriter, loc, IntegerAttr::get (elementTy, 0 ));
213206 return arith::SubIOp::create (rewriter, loc, resultTypes, constant,
214207 args[0 ]);
215208 }
216209
210+ Value zpAddValue;
211+ Type intermediateType;
217212 // Compute the maximum value that can occur in the intermediate buffer.
218213 const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth ();
219- const int64_t zpAdd = inZp + outZp;
220- const int64_t maxValue =
221- APInt::getSignedMaxValue (inputBitWidth).getSExtValue () +
222- std::abs (zpAdd) + 1 ;
223-
224- // Convert that maximum value into the maximum bitwidth needed to
225- // represent it. We assume 48-bit numbers may be supported further in
226- // the pipeline.
227214 int intermediateBitWidth = 64 ;
228- if (maxValue <= APInt::getSignedMaxValue (16 ).getSExtValue ()) {
229- intermediateBitWidth = 16 ;
230- } else if (maxValue <= APInt::getSignedMaxValue (32 ).getSExtValue ()) {
231- intermediateBitWidth = 32 ;
232- } else if (maxValue <= APInt::getSignedMaxValue (48 ).getSExtValue ()) {
233- intermediateBitWidth = 48 ;
234- }
235215
236- Type intermediateType = rewriter.getIntegerType (intermediateBitWidth);
237- Value zpAddValue = arith::ConstantOp::create (
238- rewriter, loc, rewriter.getIntegerAttr (intermediateType, zpAdd));
216+ if (hasInZp && hasOutZp) {
217+ // Compute the maximum value that can occur in the intermediate buffer.
218+ const int64_t zpAdd = inZp + outZp;
219+ const int64_t maxValue =
220+ APInt::getSignedMaxValue (inputBitWidth).getSExtValue () +
221+ std::abs (zpAdd) + 1 ;
222+
223+ // Convert that maximum value into the maximum bitwidth needed to
224+ // represent it. We assume 48-bit numbers may be supported further in
225+ // the pipeline.
226+ if (maxValue <= APInt::getSignedMaxValue (16 ).getSExtValue ()) {
227+ intermediateBitWidth = 16 ;
228+ } else if (maxValue <= APInt::getSignedMaxValue (32 ).getSExtValue ()) {
229+ intermediateBitWidth = 32 ;
230+ } else if (maxValue <= APInt::getSignedMaxValue (48 ).getSExtValue ()) {
231+ intermediateBitWidth = 48 ;
232+ }
233+
234+ intermediateType = rewriter.getIntegerType (intermediateBitWidth);
235+ zpAddValue = rewriter.create <arith::ConstantOp>(
236+ loc, rewriter.getIntegerAttr (intermediateType, zpAdd));
237+ } else {
238+ intermediateType = rewriter.getIntegerType (intermediateBitWidth);
239+ auto arg1 =
240+ rewriter.create <arith::ExtSIOp>(loc, intermediateType, args[1 ]);
241+ auto arg2 =
242+ rewriter.create <arith::ExtSIOp>(loc, intermediateType, args[2 ]);
243+ zpAddValue =
244+ rewriter.create <arith::AddIOp>(loc, intermediateType, arg1, arg2);
245+ }
239246
240247 // The negation can be applied by doing:
241248 // outputValue = inZp + outZp - inputValue
@@ -1013,9 +1020,14 @@ static ValueRange getBroadcastableOperands(Operation *operation,
10131020 else
10141021 return operands.take_front (3 );
10151022 }
1016- // Input1_zp and output_zp cannot broadcast
1017- if (isa<tosa::NegateOp>(operation))
1023+ if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
1024+ FailureOr<int64_t > maybeInZp = negate.getInput1ZeroPoint ();
1025+ FailureOr<int64_t > maybeOutZp = negate.getOutputZeroPoint ();
1026+ if (failed (maybeOutZp) && failed (maybeInZp))
1027+ return operands;
1028+ // Input1_zp and output_zp cannot broadcast when they are constants.
10181029 return operands.take_front (1 );
1030+ }
10191031 return operands;
10201032}
10211033
0 commit comments