diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 0046d977c68f4..65e6d4f64e36c 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -1322,7 +1322,14 @@ struct OperationEquivalence { // When provided, the location attached to the operation are ignored. IgnoreLocations = 1, - LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreLocations) + // When provided, the discardable attributes attached to the operation are + // ignored. + IgnoreDiscardableAttrs = 2, + + // When provided, the properties attached to the operation are ignored. + IgnoreProperties = 4, + + LLVM_MARK_AS_BITMASK_ENUM(/* LargestValue = */ IgnoreProperties) }; /// Compute a hash for the given operation. diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index 7c9e6c89d4d8e..b591b50f2d0dc 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -680,9 +680,14 @@ llvm::hash_code OperationEquivalence::computeHash( // - Operation Name // - Attributes // - Result Types - llvm::hash_code hash = - llvm::hash_combine(op->getName(), op->getRawDictionaryAttrs(), - op->getResultTypes(), op->hashProperties()); + DictionaryAttr dictAttrs; + if (!(flags & Flags::IgnoreDiscardableAttrs)) + dictAttrs = op->getRawDictionaryAttrs(); + llvm::hash_code hashProperties; + if (!(flags & Flags::IgnoreProperties)) + hashProperties = op->hashProperties(); + llvm::hash_code hash = llvm::hash_combine( + op->getName(), dictAttrs, op->getResultTypes(), hashProperties); // - Location if required if (!(flags & Flags::IgnoreLocations)) @@ -836,14 +841,19 @@ OperationEquivalence::isRegionEquivalentTo(Region *lhs, Region *rhs, return true; // 1. Compare the operation properties. + if (!(flags & IgnoreDiscardableAttrs) && + lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs()) + return false; + if (lhs->getName() != rhs->getName() || - lhs->getRawDictionaryAttrs() != rhs->getRawDictionaryAttrs() || lhs->getNumRegions() != rhs->getNumRegions() || lhs->getNumSuccessors() != rhs->getNumSuccessors() || lhs->getNumOperands() != rhs->getNumOperands() || - lhs->getNumResults() != rhs->getNumResults() || - !lhs->getName().compareOpProperties(lhs->getPropertiesStorage(), - rhs->getPropertiesStorage())) + lhs->getNumResults() != rhs->getNumResults()) + return false; + if (!(flags & IgnoreProperties) && + !(lhs->getName().compareOpProperties(lhs->getPropertiesStorage(), + rhs->getPropertiesStorage()))) return false; if (!(flags & IgnoreLocations) && lhs->getLoc() != rhs->getLoc()) return false; diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp index bac2b72b68deb..b18512817969e 100644 --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -315,6 +315,7 @@ TEST(OperandStorageTest, PopulateDefaultAttrs) { TEST(OperationEquivalenceTest, HashWorksWithFlags) { MLIRContext context; context.getOrLoadDialect(); + OpBuilder b(&context); auto *op1 = createOp(&context); // `op1` has an unknown loc. @@ -325,12 +326,36 @@ TEST(OperationEquivalenceTest, HashWorksWithFlags) { op, OperationEquivalence::ignoreHashValue, OperationEquivalence::ignoreHashValue, flags); }; + // Check ignore location. EXPECT_EQ(getHash(op1, OperationEquivalence::IgnoreLocations), getHash(op2, OperationEquivalence::IgnoreLocations)); EXPECT_NE(getHash(op1, OperationEquivalence::None), getHash(op2, OperationEquivalence::None)); + op1->setLoc(NameLoc::get(StringAttr::get(&context, "foo"))); + // Check ignore discardable dictionary attributes. + SmallVector newAttrs = { + b.getNamedAttr("foo", b.getStringAttr("f"))}; + op1->setAttrs(newAttrs); + EXPECT_EQ(getHash(op1, OperationEquivalence::IgnoreDiscardableAttrs), + getHash(op2, OperationEquivalence::IgnoreDiscardableAttrs)); + EXPECT_NE(getHash(op1, OperationEquivalence::None), + getHash(op2, OperationEquivalence::None)); op1->destroy(); op2->destroy(); + + // Check ignore properties. + auto req1 = b.getI32IntegerAttr(10); + Operation *opWithProperty1 = b.create( + b.getUnknownLoc(), req1, nullptr, nullptr, req1); + auto req2 = b.getI32IntegerAttr(60); + Operation *opWithProperty2 = b.create( + b.getUnknownLoc(), req2, nullptr, nullptr, req2); + EXPECT_NE(getHash(op1, OperationEquivalence::None), + getHash(op2, OperationEquivalence::None)); + EXPECT_EQ(getHash(opWithProperty1, OperationEquivalence::IgnoreProperties), + getHash(opWithProperty2, OperationEquivalence::IgnoreProperties)); + opWithProperty1->destroy(); + opWithProperty2->destroy(); } } // namespace