@@ -71,6 +71,25 @@ static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
7171 return t1.compose (fusedConsumerArgIndexMap);
7272}
7373
74+ // / Returns a set of indices of the producer's results which would
75+ // / be preserved after the fusion.
76+ llvm::SmallDenseSet<int >
77+ ElementwiseOpFusionResult::getPreservedProducerResults (GenericOp producer,
78+ GenericOp consumer) {
79+ llvm::SmallDenseSet<int > preservedProducerResults;
80+ for (const auto &producerResult : llvm::enumerate (producer->getResults ())) {
81+ auto *outputOperand = producer.getDpsInitOperand (producerResult.index ());
82+ if (producer.payloadUsesValueFromOperand (outputOperand) ||
83+ !producer.canOpOperandsBeDropped (outputOperand) ||
84+ llvm::any_of (producerResult.value ().getUsers (), [&](Operation *user) {
85+ return user != consumer.getOperation ();
86+ })) {
87+ preservedProducerResults.insert (producerResult.index ());
88+ }
89+ }
90+ return preservedProducerResults;
91+ }
92+
7493// / Conditions for elementwise fusion of generic operations.
7594bool mlir::linalg::areElementwiseOpsFusable (OpOperand *fusedOperand) {
7695 if (!fusedOperand)
@@ -285,17 +304,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
285304 assert (consumer.isDpsInput (fusedOperand) &&
286305 " expected producer of input operand" );
287306 // / Find the results of the producer that have uses outside of the consumer.
288- llvm::SmallDenseSet<int > preservedProducerResults;
289- for (const auto &producerResult : llvm::enumerate (producer->getResults ())) {
290- auto *outputOperand = producer.getDpsInitOperand (producerResult.index ());
291- if (producer.payloadUsesValueFromOperand (outputOperand) ||
292- !producer.canOpOperandsBeDropped (outputOperand) ||
293- llvm::any_of (producerResult.value ().getUsers (), [&](Operation *user) {
294- return user != consumer.getOperation ();
295- })) {
296- preservedProducerResults.insert (producerResult.index ());
297- }
298- }
307+ llvm::SmallDenseSet<int > preservedProducerResults =
308+ ElementwiseOpFusionResult::getPreservedProducerResults (producer,
309+ consumer);
299310
300311 // Compute the fused operands list and indexing maps.
301312 SmallVector<Value> fusedInputOperands, fusedOutputOperands;
0 commit comments