From e112855a28f220b24b465c6c1b78a526f56a3e46 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 15 Dec 2024 17:36:49 +0100 Subject: [PATCH 1/8] ex --- .../lib/Optimizer/CodeGen/BoxedProcedure.cpp | 1 - mlir/docs/DialectConversion.md | 35 +- .../mlir/Transforms/DialectConversion.h | 18 +- .../Conversion/LLVMCommon/TypeConverter.cpp | 16 +- .../EmitC/Transforms/TypeConversions.cpp | 1 - .../Dialect/Linalg/Transforms/Detensorize.cpp | 1 - .../Quant/Transforms/StripFuncQuantTypes.cpp | 1 - .../Utils/SparseTensorDescriptor.cpp | 3 - .../Vector/Transforms/VectorLinearize.cpp | 1 - .../Transforms/Utils/DialectConversion.cpp | 432 +++++++++--------- mlir/test/Transforms/test-legalizer.mlir | 7 +- .../Func/TestDecomposeCallGraphTypes.cpp | 2 +- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 1 - .../lib/Transforms/TestDialectConversion.cpp | 1 - 14 files changed, 224 insertions(+), 296 deletions(-) diff --git a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp index 1bb91d252529f..104ae7408b80c 100644 --- a/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp +++ b/flang/lib/Optimizer/CodeGen/BoxedProcedure.cpp @@ -172,7 +172,6 @@ class BoxprocTypeRewriter : public mlir::TypeConverter { addConversion([&](TypeDescType ty) { return TypeDescType::get(convertType(ty.getOfTy())); }); - addArgumentMaterialization(materializeProcedure); addSourceMaterialization(materializeProcedure); addTargetMaterialization(materializeProcedure); } diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md index 3168f5e13c751..abacd5a82c61e 100644 --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -242,19 +242,6 @@ cannot. These materializations are used by the conversion framework to ensure type safety during the conversion process. There are several types of materializations depending on the situation. -* Argument Materialization - - - An argument materialization is used when converting the type of a block - argument during a [signature conversion](#region-signature-conversion). - The new block argument types are specified in a `SignatureConversion` - object. An original block argument can be converted into multiple - block arguments, which is not supported everywhere in the dialect - conversion. (E.g., adaptors support only a single replacement value for - each original value.) Therefore, an argument materialization is used to - convert potentially multiple new block arguments back into a single SSA - value. An argument materialization is also used when replacing an op - result with multiple values. - * Source Materialization - A source materialization is used when a value was replaced with a value @@ -343,17 +330,6 @@ class TypeConverter { /// Materialization functions must be provided when a type conversion may /// persist after the conversion has finished. - /// This method registers a materialization that will be called when - /// converting (potentially multiple) block arguments that were the result of - /// a signature conversion of a single block argument, to a single SSA value - /// with the old argument type. - template ::template arg_t<1>> - void addArgumentMaterialization(FnT &&callback) { - argumentMaterializations.emplace_back( - wrapMaterialization(std::forward(callback))); - } - /// This method registers a materialization that will be called when /// converting a replacement value back to its original source type. /// This is used when some uses of the original value persist beyond the main @@ -406,12 +382,11 @@ done explicitly via a conversion pattern. To convert the types of block arguments within a Region, a custom hook on the `ConversionPatternRewriter` must be invoked; `convertRegionTypes`. This hook uses a provided type converter to apply type conversions to all blocks of a -given region. As noted above, the conversions performed by this method use the -argument materialization hook on the `TypeConverter`. This hook also takes an -optional `TypeConverter::SignatureConversion` parameter that applies a custom -conversion to the entry block of the region. The types of the entry block -arguments are often tied semantically to the operation, e.g., -`func::FuncOp`, `AffineForOp`, etc. +given region. This hook also takes an optional +`TypeConverter::SignatureConversion` parameter that applies a custom conversion +to the entry block of the region. The types of the entry block arguments are +often tied semantically to the operation, e.g., `func::FuncOp`, `AffineForOp`, +etc. To convert the signature of just one given block, the `applySignatureConversion` hook can be used. diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 28150e886913e..9a6975dcf8dfa 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -181,6 +181,10 @@ class TypeConverter { /// converting (potentially multiple) block arguments that were the result of /// a signature conversion of a single block argument, to a single SSA value /// with the old block argument type. + /// + /// Note: Argument materializations are used only with the 1:N dialect + /// conversion driver. The 1:N dialect conversion driver will be removed soon + /// and so will be argument materializations. template >::template arg_t<1>> void addArgumentMaterialization(FnT &&callback) { @@ -880,15 +884,7 @@ class ConversionPatternRewriter final : public PatternRewriter { void replaceOp(Operation *op, Operation *newOp) override; /// Replace the given operation with the new value ranges. The number of op - /// results and value ranges must match. If an original SSA value is replaced - /// by multiple SSA values (i.e., a value range has more than 1 element), the - /// conversion driver will insert an argument materialization to convert the - /// N SSA values back into 1 SSA value of the original type. The given - /// operation is erased. - /// - /// Note: The argument materialization is a workaround until we have full 1:N - /// support in the dialect conversion. (It is going to disappear from both - /// `replaceOpWithMultiple` and `applySignatureConversion`.) + /// results and value ranges must match. The given operation is erased. void replaceOpWithMultiple(Operation *op, ArrayRef newValues); /// PatternRewriter hook for erasing a dead operation. The uses of this @@ -1285,8 +1281,8 @@ struct ConversionConfig { // represented at the moment. RewriterBase::Listener *listener = nullptr; - /// If set to "true", the dialect conversion attempts to build source/target/ - /// argument materializations through the type converter API in lieu of + /// If set to "true", the dialect conversion attempts to build source/target + /// materializations through the type converter API in lieu of /// "builtin.unrealized_conversion_cast ops". The conversion process fails if /// at least one materialization could not be built. /// diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 49e2d94328664..72799e42cf3fd 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -85,7 +85,7 @@ static Value unrankedMemRefMaterialization(OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter) { - // An argument materialization must return a value of type + // A source materialization must return a value of type // `resultType`, so insert a cast from the memref descriptor type // (!llvm.struct) to the original memref type. Value packed = @@ -101,7 +101,7 @@ static Value rankedMemRefMaterialization(OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc, const LLVMTypeConverter &converter) { - // An argument materialization must return a value of type `resultType`, + // A source materialization must return a value of type `resultType`, // so insert a cast from the memref descriptor type (!llvm.struct) to the // original memref type. Value packed = @@ -234,19 +234,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, .getResult(0); }); - // Argument materializations convert from the new block argument types + // Source materializations convert from the new block argument types // (multiple SSA values that make up a memref descriptor) back to the // original block argument type. - addArgumentMaterialization([&](OpBuilder &builder, - UnrankedMemRefType resultType, - ValueRange inputs, Location loc) { - return unrankedMemRefMaterialization(builder, resultType, inputs, loc, - *this); - }); - addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType, - ValueRange inputs, Location loc) { - return rankedMemRefMaterialization(builder, resultType, inputs, loc, *this); - }); addSourceMaterialization([&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc) { diff --git a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp index 0b3a494794f3f..72c8fd0f32485 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp @@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) { converter.addSourceMaterialization(materializeAsUnrealizedCast); converter.addTargetMaterialization(materializeAsUnrealizedCast); - converter.addArgumentMaterialization(materializeAsUnrealizedCast); } /// Get an unsigned integer or size data type corresponding to \p ty. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index 0e651f4cee4c3..fc6671ef81175 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter { }); addSourceMaterialization(sourceMaterializationCallback); - addArgumentMaterialization(sourceMaterializationCallback); } }; diff --git a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp index 6191272266283..71b88d1be1b05 100644 --- a/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp @@ -56,7 +56,6 @@ class QuantizedTypeConverter : public TypeConverter { addConversion(convertQuantizedType); addConversion(convertTensorType); - addArgumentMaterialization(materializeConversion); addSourceMaterialization(materializeConversion); addTargetMaterialization(materializeConversion); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp index 834e3634cc130..8bbb2cac5efdf 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp @@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() { // Required by scf.for 1:N type conversion. addSourceMaterialization(materializeTuple); - - // Required as a workaround until we have full 1:N support. - addArgumentMaterialization(materializeTuple); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 757631944f224..68535ae5a7a5c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( return builder.create(loc, type, inputs.front()); }; - typeConverter.addArgumentMaterialization(materializeCast); typeConverter.addSourceMaterialization(materializeCast); typeConverter.addTargetMaterialization(materializeCast); target.markUnknownOpDynamicallyLegal( diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 2b006430d3817..96cbe07f0f12f 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/Block.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Iterators.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -53,6 +54,55 @@ static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) { }); } +/// Given two insertion points in the same block, choose the later one. +static OpBuilder::InsertPoint +chooseLaterInsertPointInBlock(OpBuilder::InsertPoint a, + OpBuilder::InsertPoint b) { + assert(a.getBlock() == b.getBlock() && "expected same block"); + Block *block = a.getBlock(); + if (a.getPoint() == block->begin()) + return b; + if (b.getPoint() == block->begin()) + return a; + if (a.getPoint()->isBeforeInBlock(&*b.getPoint())) + return b; + return a; +} + +/// Helper function that chooses the insertion point among the two given ones +/// that is later. +// TODO: Extend DominanceInfo API to work with block iterators. +static OpBuilder::InsertPoint chooseLaterInsertPoint(OpBuilder::InsertPoint a, + OpBuilder::InsertPoint b) { + // Case 1: Same block. + if (a.getBlock() == b.getBlock()) + return chooseLaterInsertPointInBlock(a, b); + + // Case 2: Different block, but same region. + if (a.getBlock()->getParent() == b.getBlock()->getParent()) { + DominanceInfo domInfo; + if (domInfo.properlyDominates(a.getBlock(), b.getBlock())) + return b; + if (domInfo.properlyDominates(b.getBlock(), a.getBlock())) + return a; + // Neither of the two blocks dominante each other. + llvm_unreachable("unable to find valid insertion point"); + } + + // Case 3: b's region contains a: choose a. + if (Operation *aParent = b.getBlock()->getParent()->findAncestorOpInRegion( + *a.getPoint()->getParentOp())) + return a; + + // Case 4: a's region contains b: choose b. + if (Operation *bParent = a.getBlock()->getParent()->findAncestorOpInRegion( + *b.getPoint()->getParentOp())) + return b; + + // Neither of the two operations contain each other. + llvm_unreachable("unable to find valid insertion point"); +} + /// Helper function that computes an insertion point where the given value is /// defined and can be used without a dominance violation. static OpBuilder::InsertPoint computeInsertPoint(Value value) { @@ -63,11 +113,36 @@ static OpBuilder::InsertPoint computeInsertPoint(Value value) { return OpBuilder::InsertPoint(insertBlock, insertPt); } +/// Helper function that computes an insertion point where the given values are +/// defined and can be used without a dominance violation. +static OpBuilder::InsertPoint computeInsertPoint(ArrayRef vals) { + assert(!vals.empty() && "expected at least one value"); + OpBuilder::InsertPoint pt = computeInsertPoint(vals.front()); + for (Value v : vals.drop_front()) + pt = chooseLaterInsertPoint(pt, computeInsertPoint(v)); + return pt; +} + //===----------------------------------------------------------------------===// // ConversionValueMapping //===----------------------------------------------------------------------===// +/// A vector of SSA values, optimized for the most common case of a single +/// value. +using ValueVector = SmallVector; + namespace { + +/// Helper class to make it possible to use `ValueVector` as a key in DenseMap. +struct ValueVectorMapInfo { + static ValueVector getEmptyKey() { return ValueVector{}; } + static ValueVector getTombstoneKey() { return ValueVector{}; } + static ::llvm::hash_code getHashValue(ValueVector val) { + return ::llvm::hash_combine_range(val.begin(), val.end()); + } + static bool isEqual(ValueVector LHS, ValueVector RHS) { return LHS == RHS; } +}; + /// This class wraps a IRMapping to provide recursive lookup /// functionality, i.e. we will traverse if the mapped value also has a mapping. struct ConversionValueMapping { @@ -75,68 +150,103 @@ struct ConversionValueMapping { /// false positives. bool isMappedTo(Value value) const { return mappedTo.contains(value); } - /// Lookup the most recently mapped value with the desired type in the + /// Lookup the most recently mapped values with the desired types in the /// mapping. /// /// Special cases: - /// - If the desired type is "null", simply return the most recently mapped - /// value. - /// - If there is no mapping to the desired type, also return the most - /// recently mapped value. - /// - If there is no mapping for the given value at all, return the given - /// value. - Value lookupOrDefault(Value from, Type desiredType = nullptr) const; - - /// Lookup a mapped value within the map, or return null if a mapping does not - /// exist. If a mapping exists, this follows the same behavior of - /// `lookupOrDefault`. - Value lookupOrNull(Value from, Type desiredType = nullptr) const; + /// - If the desired type range is empty, simply return the most recently + /// mapped values. + /// - If there is no mapping to the desired types, also return the most + /// recently mapped values. + /// - If there is no mapping for the given values at all, return the given + /// values. + ValueVector lookupOrDefault(ValueVector from, + TypeRange desiredTypes = {}) const; + + /// Lookup the given values within the map, or return an empty vector if the + /// values are not mapped. If they are mapped, this follows the same behavior + /// as `lookupOrDefault`. + ValueVector lookupOrNull(const ValueVector &from, + TypeRange desiredTypes = {}) const; /// Map a value to the one provided. - void map(Value oldVal, Value newVal) { + void map(const ValueVector &oldVal, const ValueVector &newVal) { LLVM_DEBUG({ - for (Value it = newVal; it; it = mapping.lookupOrNull(it)) - assert(it != oldVal && "inserting cyclic mapping"); + ValueVector next = newVal; + while (true) { + assert(next != oldVal && "inserting cyclic mapping"); + auto it = mapping.find(next); + if (it == mapping.end()) + break; + next = it->second; + } }); - mapping.map(oldVal, newVal); - mappedTo.insert(newVal); + mapping[oldVal] = newVal; + for (Value v : newVal) + mappedTo.insert(v); } - /// Drop the last mapping for the given value. - void erase(Value value) { mapping.erase(value); } + /// Drop the last mapping for the given values. + void erase(ValueVector value) { mapping.erase(value); } private: /// Current value mappings. - IRMapping mapping; + DenseMap mapping; /// All SSA values that are mapped to. May contain false positives. DenseSet mappedTo; }; } // namespace -Value ConversionValueMapping::lookupOrDefault(Value from, - Type desiredType) const { - // Try to find the deepest value that has the desired type. If there is no - // such value, simply return the deepest value. - Value desiredValue; +ValueVector +ConversionValueMapping::lookupOrDefault(ValueVector from, + TypeRange desiredTypes) const { + // Try to find the deepest values that have the desired types. If there is no + // such mapping, simply return the deepest values. + ValueVector desiredValue; do { - if (!desiredType || from.getType() == desiredType) + // Store the current value if the types match. + if (desiredTypes.empty() || TypeRange(from) == desiredTypes) desiredValue = from; - Value mappedValue = mapping.lookupOrNull(from); - if (!mappedValue) + // If possible, Replace each value with (one or multiple) mapped values. + ValueVector next; + for (Value v : from) { + auto it = mapping.find({v}); + if (it != mapping.end()) { + llvm::append_range(next, it->second); + } else { + next.push_back(v); + } + } + if (next != from) { + // If at least one value was replaced, continue the lookup from there. + from = next; + continue; + } + + // Otherwise: Check if there is a mapping for the entire vector. Such + // mappings are materializations. (N:M mapping are not supported for value + // replacements.) + auto it = mapping.find(from); + if (it == mapping.end()) { + // No mapping found: The lookup stops here. break; - from = mappedValue; + } + from = it->second; } while (true); - // If the desired value was found use it, otherwise default to the leaf value. - return desiredValue ? desiredValue : from; + // If the desired values were found use them, otherwise default to the leaf + // values. + return !desiredValue.empty() ? desiredValue : from; } -Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const { - Value result = lookupOrDefault(from, desiredType); - if (result == from || (desiredType && result.getType() != desiredType)) - return nullptr; +ValueVector ConversionValueMapping::lookupOrNull(const ValueVector &from, + TypeRange desiredTypes) const { + ValueVector result = lookupOrDefault(from, desiredTypes); + TypeRange resultTypes(result); + if (result == from || (!desiredTypes.empty() && resultTypes != desiredTypes)) + return {}; return result; } @@ -651,10 +761,6 @@ class CreateOperationRewrite : public OperationRewrite { /// The type of materialization. enum MaterializationKind { - /// This materialization materializes a conversion for an illegal block - /// argument type, to the original one. - Argument, - /// This materialization materializes a conversion from an illegal type to a /// legal one. Target, @@ -673,7 +779,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite { UnrealizedConversionCastOp op, const TypeConverter *converter, MaterializationKind kind, Type originalType, - Value mappedValue); + ValueVector mappedValues); static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::UnresolvedMaterialization; @@ -708,9 +814,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite { /// materializations. Type originalType; - /// The value in the conversion value mapping that is being replaced by the + /// The values in the conversion value mapping that are being replaced by the /// results of this unresolved materialization. - Value mappedValue; + ValueVector mappedValues; }; } // namespace @@ -779,7 +885,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { LogicalResult remapValues(StringRef valueDiagTag, std::optional inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVector> &remapped); + SmallVector &remapped); /// Return "true" if the given operation is ignored, and does not need to be /// converted. @@ -820,39 +926,14 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// If a cast op was built, it can optionally be returned with the `castOp` /// output argument. /// - /// If `valueToMap` is set to a non-null Value, then that value is mapped to + /// If `valuesToMap` is set to a non-null Value, then that value is mapped to /// the results of the unresolved materialization in the conversion value /// mapping. ValueRange buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - Value valueToMap, ValueRange inputs, TypeRange outputTypes, + ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp = nullptr); - Value buildUnresolvedMaterialization( - MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - Value valueToMap, ValueRange inputs, Type outputType, Type originalType, - const TypeConverter *converter, - UnrealizedConversionCastOp *castOp = nullptr) { - return buildUnresolvedMaterialization(kind, ip, loc, valueToMap, inputs, - TypeRange(outputType), originalType, - converter, castOp) - .front(); - } - - /// Build an N:1 materialization for the given original value that was - /// replaced with the given replacement values. - /// - /// This is a workaround around incomplete 1:N support in the dialect - /// conversion driver. The conversion mapping can store only 1:1 replacements - /// and the conversion patterns only support single Value replacements in the - /// adaptor, so N values must be converted back to a single value. This - /// function will be deleted when full 1:N support has been added. - /// - /// This function inserts an argument materialization back to the original - /// type. - void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc, - ValueRange replacements, Value originalValue, - const TypeConverter *converter); /// Find a replacement value for the given SSA value in the conversion value /// mapping. The replacement value must have the same type as the given SSA @@ -862,16 +943,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { Value findOrBuildReplacementValue(Value value, const TypeConverter *converter); - /// Unpack an N:1 materialization and return the inputs of the - /// materialization. This function unpacks only those materializations that - /// were built with `insertNTo1Materialization`. - /// - /// This is a workaround around incomplete 1:N support in the dialect - /// conversion driver. It allows us to write 1:N conversion patterns while - /// 1:N support is still missing in the conversion value mapping. This - /// function will be deleted when full 1:N support has been added. - SmallVector unpackNTo1Materialization(Value value); - //===--------------------------------------------------------------------===// // Rewriter Notification Hooks //===--------------------------------------------------------------------===// @@ -1041,7 +1112,7 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { }); } -void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); } +void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); } void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { auto *listener = @@ -1082,7 +1153,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { void ReplaceOperationRewrite::rollback() { for (auto result : op->getResults()) - rewriterImpl.mapping.erase(result); + rewriterImpl.mapping.erase({result}); } void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) { @@ -1101,18 +1172,18 @@ void CreateOperationRewrite::rollback() { UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite( ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op, const TypeConverter *converter, MaterializationKind kind, Type originalType, - Value mappedValue) + ValueVector mappedValues) : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), converterAndKind(converter, kind), originalType(originalType), - mappedValue(mappedValue) { + mappedValues(mappedValues) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); rewriterImpl.unresolvedMaterializations[op] = this; } void UnresolvedMaterializationRewrite::rollback() { - if (mappedValue) - rewriterImpl.mapping.erase(mappedValue); + if (!mappedValues.empty()) + rewriterImpl.mapping.erase(mappedValues); rewriterImpl.unresolvedMaterializations.erase(getOperation()); rewriterImpl.nTo1TempMaterializations.erase(getOperation()); op->erase(); @@ -1160,7 +1231,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) { LogicalResult ConversionPatternRewriterImpl::remapValues( StringRef valueDiagTag, std::optional inputLoc, PatternRewriter &rewriter, ValueRange values, - SmallVector> &remapped) { + SmallVector &remapped) { remapped.reserve(llvm::size(values)); for (const auto &it : llvm::enumerate(values)) { @@ -1168,18 +1239,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( Type origType = operand.getType(); Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); - // Find the most recently mapped value. Unpack all temporary N:1 - // materializations. Such conversions are a workaround around missing - // 1:N support in the ConversionValueMapping. (The conversion patterns - // already support 1:N replacements.) - Value repl = mapping.lookupOrDefault(operand); - SmallVector unpacked = unpackNTo1Materialization(repl); - if (!currentTypeConverter) { // The current pattern does not have a type converter. I.e., it does not // distinguish between legal and illegal types. For each operand, simply - // pass through the most recently mapped value. - remapped.push_back(std::move(unpacked)); + // pass through the most recently mapped values. + remapped.push_back(mapping.lookupOrDefault({operand})); continue; } @@ -1192,51 +1256,28 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( }); return failure(); } - // If a type is converted to 0 types, there is nothing to do. if (legalTypes.empty()) { remapped.push_back({}); continue; } - if (legalTypes.size() != 1) { - // TODO: This is a 1:N conversion. The conversion value mapping does not - // store such materializations yet. If the types of the most recently - // mapped values do not match, build a target materialization. - ValueRange unpackedRange(unpacked); - if (TypeRange(unpackedRange) == legalTypes) { - remapped.push_back(std::move(unpacked)); - continue; - } - - // Insert a target materialization if the current pattern expects - // different legalized types. - ValueRange targetMat = buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(repl), operandLoc, - /*valueToMap=*/Value(), /*inputs=*/unpacked, - /*outputTypes=*/legalTypes, /*originalType=*/origType, - currentTypeConverter); - remapped.push_back(targetMat); + ValueVector repl = mapping.lookupOrDefault({operand}, legalTypes); + if (!repl.empty() && TypeRange(repl) == legalTypes) { + // Mapped values have the correct type or there is an existing + // materialization. Or the opreand is not mapped at all and has the + // correct type. + remapped.push_back(repl); continue; } - // Handle 1->1 type conversions. - Type desiredType = legalTypes.front(); - // Try to find a mapped value with the desired type. (Or the operand itself - // if the value is not mapped at all.) - Value newOperand = mapping.lookupOrDefault(operand, desiredType); - if (newOperand.getType() != desiredType) { - // If the looked up value's type does not have the desired type, it means - // that the value was replaced with a value of different type and no - // target materialization was created yet. - Value castValue = buildUnresolvedMaterialization( - MaterializationKind::Target, computeInsertPoint(newOperand), - operandLoc, /*valueToMap=*/newOperand, /*inputs=*/unpacked, - /*outputType=*/desiredType, /*originalType=*/origType, - currentTypeConverter); - newOperand = castValue; - } - remapped.push_back({newOperand}); + // Create a materialization for the most recently mapped values. + repl = mapping.lookupOrDefault({operand}); + ValueRange castValues = buildUnresolvedMaterialization( + MaterializationKind::Target, computeInsertPoint(repl), operandLoc, + /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes, + /*originalType=*/origType, currentTypeConverter); + remapped.push_back(castValues); } return success(); } @@ -1353,7 +1394,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( buildUnresolvedMaterialization( MaterializationKind::Source, OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), - /*valueToMap=*/origArg, /*inputs=*/ValueRange(), + /*valuesToMap=*/{origArg}, /*inputs=*/ValueRange(), /*outputType=*/origArgType, /*originalType=*/Type(), converter); appendRewrite(block, origArg, converter); continue; @@ -1364,7 +1405,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( assert(inputMap->size == 0 && "invalid to provide a replacement value when the argument isn't " "dropped"); - mapping.map(origArg, repl); + mapping.map({origArg}, {repl}); appendRewrite(block, origArg, converter); continue; } @@ -1375,13 +1416,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // used as a replacement. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - if (replArgs.size() == 1) { - mapping.map(origArg, replArgs.front()); - } else { - insertNTo1Materialization( - OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(), - /*replacements=*/replArgs, /*outputValue=*/origArg, converter); - } + ValueVector replArgVals = llvm::map_to_vector<1>( + replArgs, [](BlockArgument arg) -> Value { return arg; }); + mapping.map({origArg}, replArgVals); appendRewrite(block, origArg, converter); } @@ -1402,7 +1439,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /// of input operands. ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, - Value valueToMap, ValueRange inputs, TypeRange outputTypes, + ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, UnrealizedConversionCastOp *castOp) { assert((!originalType || kind == MaterializationKind::Target) && @@ -1410,10 +1447,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( // Avoid materializing an unnecessary cast. if (TypeRange(inputs) == outputTypes) { - if (valueToMap) { - assert(inputs.size() == 1 && "1:N mapping is not supported"); - mapping.map(valueToMap, inputs.front()); - } + if (!valuesToMap.empty()) + mapping.map(valuesToMap, inputs); return inputs; } @@ -1423,37 +1458,21 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); auto convertOp = builder.create(loc, outputTypes, inputs); - if (valueToMap) { - assert(outputTypes.size() == 1 && "1:N mapping is not supported"); - mapping.map(valueToMap, convertOp.getResult(0)); - } + if (!valuesToMap.empty()) + mapping.map(valuesToMap, convertOp.getResults()); if (castOp) *castOp = convertOp; appendRewrite(convertOp, converter, kind, - originalType, valueToMap); + originalType, valuesToMap); return convertOp.getResults(); } -void ConversionPatternRewriterImpl::insertNTo1Materialization( - OpBuilder::InsertPoint ip, Location loc, ValueRange replacements, - Value originalValue, const TypeConverter *converter) { - // Insert argument materialization back to the original type. - Type originalType = originalValue.getType(); - UnrealizedConversionCastOp argCastOp; - buildUnresolvedMaterialization( - MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue, - /*inputs=*/replacements, originalType, - /*originalType=*/Type(), converter, &argCastOp); - if (argCastOp) - nTo1TempMaterializations.insert(argCastOp); -} - Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( Value value, const TypeConverter *converter) { // Find a replacement value with the same type. - Value repl = mapping.lookupOrNull(value, value.getType()); - if (repl) - return repl; + ValueVector repl = mapping.lookupOrNull({value}, value.getType()); + if (!repl.empty()) + return repl.front(); // Check if the value is dead. No replacement value is needed in that case. // This is an approximate check that may have false negatives but does not @@ -1467,8 +1486,8 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // No replacement value was found. Get the latest replacement value // (regardless of the type) and build a source materialization to the // original type. - repl = mapping.lookupOrNull(value); - if (!repl) { + repl = mapping.lookupOrNull({value}); + if (repl.empty()) { // No replacement value is registered in the mapping. This means that the // value is dropped and no longer needed. (If the value were still needed, // a source materialization producing a replacement value "out of thin air" @@ -1478,34 +1497,12 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( } Value castValue = buildUnresolvedMaterialization( MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(), - /*valueToMap=*/value, /*inputs=*/repl, /*outputType=*/value.getType(), - /*originalType=*/Type(), converter); - mapping.map(value, castValue); + /*valuesToMap=*/{value}, /*inputs=*/repl, /*outputType=*/value.getType(), + /*originalType=*/Type(), converter)[0]; + mapping.map({value}, {castValue}); return castValue; } -SmallVector -ConversionPatternRewriterImpl::unpackNTo1Materialization(Value value) { - // Unpack unrealized_conversion_cast ops that were inserted as a N:1 - // workaround. - auto castOp = value.getDefiningOp(); - if (!castOp) - return {value}; - if (!nTo1TempMaterializations.contains(castOp)) - return {value}; - assert(castOp->getNumResults() == 1 && "expected single result"); - - SmallVector result; - for (Value v : castOp.getOperands()) { - // Keep unpacking if possible. This is needed because during block - // signature conversions and 1:N op replacements, the driver may have - // inserted two materializations back-to-back: first an argument - // materialization, then a target materialization. - llvm::append_range(result, unpackNTo1Materialization(v)); - } - return result; -} - //===----------------------------------------------------------------------===// // Rewriter Notification Hooks @@ -1554,7 +1551,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( // Materialize a replacement value "out of thin air". buildUnresolvedMaterialization( MaterializationKind::Source, computeInsertPoint(result), - result.getLoc(), /*valueToMap=*/result, /*inputs=*/ValueRange(), + result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(), /*outputType=*/result.getType(), /*originalType=*/Type(), currentTypeConverter); continue; @@ -1572,16 +1569,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( // Remap result to replacement value. if (repl.empty()) continue; - - if (repl.size() == 1) { - // Single replacement value: replace directly. - mapping.map(result, repl.front()); - } else { - // Multiple replacement values: insert N:1 materialization. - insertNTo1Materialization(computeInsertPoint(result), result.getLoc(), - /*replacements=*/repl, /*outputValue=*/result, - currentTypeConverter); - } + mapping.map({result}, repl); } appendRewrite(op, currentTypeConverter); @@ -1660,8 +1648,13 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); SmallVector newVals; - for (size_t i = 0; i < newValues.size(); ++i) - newVals.push_back(newValues.slice(i, 1)); + for (size_t i = 0; i < newValues.size(); ++i) { + if (newValues[i]) { + newVals.push_back(newValues.slice(i, 1)); + } else { + newVals.push_back(ValueRange()); + } + } impl->notifyOpReplaced(op, newVals); } @@ -1729,11 +1722,11 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, }); impl->appendRewrite(from.getOwner(), from, impl->currentTypeConverter); - impl->mapping.map(impl->mapping.lookupOrDefault(from), to); + impl->mapping.map(impl->mapping.lookupOrDefault({from}), {to}); } Value ConversionPatternRewriter::getRemappedValue(Value key) { - SmallVector> remappedValues; + SmallVector remappedValues; if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, remappedValues))) return nullptr; @@ -1746,7 +1739,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, SmallVectorImpl &results) { if (keys.empty()) return success(); - SmallVector> remapped; + SmallVector remapped; if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, remapped))) return failure(); @@ -1872,7 +1865,7 @@ ConversionPattern::matchAndRewrite(Operation *op, getTypeConverter()); // Remap the operands of the operation. - SmallVector> remapped; + SmallVector remapped; if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, op->getOperands(), remapped))) { return failure(); @@ -2625,19 +2618,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, rewriter.setInsertionPoint(op); SmallVector newMaterialization; switch (rewrite->getMaterializationKind()) { - case MaterializationKind::Argument: { - // Try to materialize an argument conversion. - assert(op->getNumResults() == 1 && "expected single result"); - Value argMat = converter->materializeArgumentConversion( - rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands); - if (argMat) { - newMaterialization.push_back(argMat); - break; - } - } - // If an argument materialization failed, fallback to trying a target - // materialization. - [[fallthrough]]; case MaterializationKind::Target: newMaterialization = converter->materializeTargetConversion( rewriter, op->getLoc(), op.getResultTypes(), inputOperands, diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 297eb5acef21b..4cd196c5b44b3 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -64,9 +64,6 @@ func.func @remap_call_1_to_1(%arg0: i64) { // Contents of the old block are moved to the new block. // CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown -// The new block arguments are used in "test.return". -// CHECK-NEXT: notifyOperationModified: test.return - // The old block is erased. // CHECK-NEXT: notifyBlockErased @@ -390,8 +387,8 @@ func.func @caller() { // CHECK: %[[call:.*]]:2 = call @callee() : () -> (f16, f16) %0:2 = func.call @callee() : () -> (f32, i24) - // CHECK: %[[cast1:.*]] = "test.cast"() : () -> i24 - // CHECK: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32 + // CHECK-DAG: %[[cast1:.*]] = "test.cast"() : () -> i24 + // CHECK-DAG: %[[cast0:.*]] = "test.cast"(%[[call]]#0, %[[call]]#1) : (f16, f16) -> f32 // CHECK: "test.some_user"(%[[cast0]], %[[cast1]]) : (f32, i24) -> () // expected-remark @below{{'test.some_user' is not legalizable}} "test.some_user"(%0#0, %0#1) : (f32, i24) -> () diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp index 09c5b4b2a0ad5..d0b62e71ab0cf 100644 --- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp +++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp @@ -139,7 +139,7 @@ struct TestDecomposeCallGraphTypes tupleType.getFlattenedTypes(types); return success(); }); - typeConverter.addArgumentMaterialization(buildMakeTupleOp); + typeConverter.addSourceMaterialization(buildMakeTupleOp); typeConverter.addTargetMaterialization(buildDecomposeTuple); populateFunctionOpInterfaceTypeConversionPattern( diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 826c222990be4..eae9b887e9d49 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1284,7 +1284,6 @@ struct TestTypeConverter : public TypeConverter { using TypeConverter::TypeConverter; TestTypeConverter() { addConversion(convertType); - addArgumentMaterialization(materializeCast); addSourceMaterialization(materializeCast); } diff --git a/mlir/test/lib/Transforms/TestDialectConversion.cpp b/mlir/test/lib/Transforms/TestDialectConversion.cpp index 2cc1fb5d39d78..a03bf0a1023d5 100644 --- a/mlir/test/lib/Transforms/TestDialectConversion.cpp +++ b/mlir/test/lib/Transforms/TestDialectConversion.cpp @@ -28,7 +28,6 @@ namespace { struct PDLLTypeConverter : public TypeConverter { PDLLTypeConverter() { addConversion(convertType); - addArgumentMaterialization(materializeCast); addSourceMaterialization(materializeCast); } From 6db4cdf705222cc00cb72f3ce5bdefbfc7507fc8 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 23 Dec 2024 14:03:02 +0100 Subject: [PATCH 2/8] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Markus Böck --- .../Transforms/Utils/DialectConversion.cpp | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 96cbe07f0f12f..8d6291f0f4f0d 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -187,7 +187,7 @@ struct ConversionValueMapping { } /// Drop the last mapping for the given values. - void erase(ValueVector value) { mapping.erase(value); } + void erase(const ValueVector &value) { mapping.erase(value); } private: /// Current value mappings. @@ -221,7 +221,7 @@ ConversionValueMapping::lookupOrDefault(ValueVector from, } if (next != from) { // If at least one value was replaced, continue the lookup from there. - from = next; + from = std::move(next); continue; } @@ -1175,7 +1175,7 @@ UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite( ValueVector mappedValues) : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op), converterAndKind(converter, kind), originalType(originalType), - mappedValues(mappedValues) { + mappedValues(std::move(mappedValues)) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); rewriterImpl.unresolvedMaterializations[op] = this; @@ -1265,9 +1265,9 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( ValueVector repl = mapping.lookupOrDefault({operand}, legalTypes); if (!repl.empty() && TypeRange(repl) == legalTypes) { // Mapped values have the correct type or there is an existing - // materialization. Or the opreand is not mapped at all and has the + // materialization. Or the operand is not mapped at all and has the // correct type. - remapped.push_back(repl); + remapped.push_back(std::move(repl)); continue; } @@ -1416,8 +1416,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // used as a replacement. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - ValueVector replArgVals = llvm::map_to_vector<1>( - replArgs, [](BlockArgument arg) -> Value { return arg; }); + ValueVector replArgVals = llvm::to_vector_of(replArgs); mapping.map({origArg}, replArgVals); appendRewrite(block, origArg, converter); } @@ -1462,8 +1461,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( mapping.map(valuesToMap, convertOp.getResults()); if (castOp) *castOp = convertOp; - appendRewrite(convertOp, converter, kind, - originalType, valuesToMap); + appendRewrite( + convertOp, converter, kind, originalType, std::move(valuesToMap)); return convertOp.getResults(); } @@ -1495,10 +1494,13 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // `applySignatureConversion`.) return Value(); } - Value castValue = buildUnresolvedMaterialization( - MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(), - /*valuesToMap=*/{value}, /*inputs=*/repl, /*outputType=*/value.getType(), - /*originalType=*/Type(), converter)[0]; + Value castValue = + buildUnresolvedMaterialization(MaterializationKind::Source, + computeInsertPoint(repl), value.getLoc(), + /*valuesToMap=*/{value}, /*inputs=*/repl, + /*outputType=*/value.getType(), + /*originalType=*/Type(), converter) + .front(); mapping.map({value}, {castValue}); return castValue; } From e96dbaaddf546eea31bc15932337479c9d063c0f Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sun, 29 Dec 2024 13:25:20 +0100 Subject: [PATCH 3/8] rebase fixes --- mlir/test/Transforms/test-legalizer.mlir | 9 ++------- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 8 -------- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 4cd196c5b44b3..ae7d344b7167f 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -491,13 +491,8 @@ func.func @test_1_to_n_block_signature_conversion() { // CHECK-LABEL: func @test_multiple_1_to_n_replacement() // CHECK: %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, f16, f16) -// TODO: There should be a single cast (i.e., a single target materialization). -// This is currently not possible due to 1:N limitations of the conversion -// mapping. Instead, we have 3 argument materializations. -// CHECK: %[[cast1:.*]] = "test.cast"(%[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16) -> f16 -// CHECK: %[[cast2:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1) : (f16, f16) -> f16 -// CHECK: %[[cast3:.*]] = "test.cast"(%[[cast2]], %[[cast1]]) : (f16, f16) -> f16 -// CHECK: "test.valid"(%[[cast3]]) : (f16) -> () +// CHECK: %[[cast:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1, %[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16, f16, f16) -> f16 +// CHECK: "test.valid"(%[[cast]]) : (f16) -> () func.func @test_multiple_1_to_n_replacement() { %0 = "test.multiple_1_to_n_replacement"() : () -> (f16) "test.invalid"(%0) : (f16) -> () diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index eae9b887e9d49..5b7c36c9b97bf 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1264,14 +1264,6 @@ class TestMultiple1ToNReplacement : public ConversionPattern { // Replace test.multiple_1_to_n_replacement with test.step_1. Operation *repl1 = replaceWithDoubleResults(op, "test.step_1"); // Now replace test.step_1 with test.legal_op. - // TODO: Ideally, it should not be necessary to reset the insertion point - // here. Based on the API calls, it looks like test.step_1 is entirely - // erased. But that's not the case: an argument materialization will - // survive. And that argument materialization will be used by the users of - // `op`. If we don't reset the insertion point here, we get dominance - // errors. This will be fixed when we have 1:N support in the conversion - // value mapping. - rewriter.setInsertionPoint(repl1); replaceWithDoubleResults(repl1, "test.legal_op"); return success(); } From cf1c9a99e9fdf6a1d5c194b3442f9b0ccc3e08af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Markus=20B=C3=B6ck?= Date: Sat, 21 Dec 2024 19:01:19 +0100 Subject: [PATCH 4/8] use universal references for `map` --- .../Transforms/Utils/DialectConversion.cpp | 36 ++++++++++++++----- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 8d6291f0f4f0d..2a5c11c3d32ec 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -169,10 +169,15 @@ struct ConversionValueMapping { ValueVector lookupOrNull(const ValueVector &from, TypeRange desiredTypes = {}) const; + template + struct IsValueVector : std::is_same, ValueVector> {}; + /// Map a value to the one provided. - void map(const ValueVector &oldVal, const ValueVector &newVal) { + template + std::enable_if_t{} && IsValueVector{}> + map(OldVal &&oldVal, NewVal &&newVal) { LLVM_DEBUG({ - ValueVector next = newVal; + ValueVector next(newVal); while (true) { assert(next != oldVal && "inserting cyclic mapping"); auto it = mapping.find(next); @@ -181,9 +186,22 @@ struct ConversionValueMapping { next = it->second; } }); - mapping[oldVal] = newVal; for (Value v : newVal) mappedTo.insert(v); + + mapping[std::forward(oldVal)] = std::forward(newVal); + } + + template + std::enable_if_t{} || !IsValueVector{}> + map(OldVal &&oldVal, NewVal &&newVal) { + if constexpr (IsValueVector{}) { + map(std::forward(oldVal), ValueVector{newVal}); + } else if constexpr (IsValueVector{}) { + map(ValueVector{oldVal}, std::forward(newVal)); + } else { + map(ValueVector{oldVal}, ValueVector{newVal}); + } } /// Drop the last mapping for the given values. @@ -1405,7 +1423,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( assert(inputMap->size == 0 && "invalid to provide a replacement value when the argument isn't " "dropped"); - mapping.map({origArg}, {repl}); + mapping.map(origArg, repl); appendRewrite(block, origArg, converter); continue; } @@ -1417,7 +1435,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); ValueVector replArgVals = llvm::to_vector_of(replArgs); - mapping.map({origArg}, replArgVals); + mapping.map(origArg, std::move(replArgVals)); appendRewrite(block, origArg, converter); } @@ -1447,7 +1465,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( // Avoid materializing an unnecessary cast. if (TypeRange(inputs) == outputTypes) { if (!valuesToMap.empty()) - mapping.map(valuesToMap, inputs); + mapping.map(std::move(valuesToMap), inputs); return inputs; } @@ -1501,7 +1519,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( /*outputType=*/value.getType(), /*originalType=*/Type(), converter) .front(); - mapping.map({value}, {castValue}); + mapping.map(value, castValue); return castValue; } @@ -1571,7 +1589,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced( // Remap result to replacement value. if (repl.empty()) continue; - mapping.map({result}, repl); + mapping.map(result, repl); } appendRewrite(op, currentTypeConverter); @@ -1724,7 +1742,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, }); impl->appendRewrite(from.getOwner(), from, impl->currentTypeConverter); - impl->mapping.map(impl->mapping.lookupOrDefault({from}), {to}); + impl->mapping.map(impl->mapping.lookupOrDefault({from}), to); } Value ConversionPatternRewriter::getRemappedValue(Value key) { From 3ad5421dbc100dcdb19d763e84eb0ae367ef8fc0 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 30 Dec 2024 19:35:01 +0100 Subject: [PATCH 5/8] address comments --- .../Transforms/Utils/DialectConversion.cpp | 63 ++++++++++--------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 2a5c11c3d32ec..3571e017158be 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -137,10 +137,12 @@ namespace { struct ValueVectorMapInfo { static ValueVector getEmptyKey() { return ValueVector{}; } static ValueVector getTombstoneKey() { return ValueVector{}; } - static ::llvm::hash_code getHashValue(ValueVector val) { + static ::llvm::hash_code getHashValue(const ValueVector &val) { return ::llvm::hash_combine_range(val.begin(), val.end()); } - static bool isEqual(ValueVector LHS, ValueVector RHS) { return LHS == RHS; } + static bool isEqual(const ValueVector &LHS, const ValueVector &RHS) { + return LHS == RHS; + } }; /// This class wraps a IRMapping to provide recursive lookup @@ -159,20 +161,18 @@ struct ConversionValueMapping { /// - If there is no mapping to the desired types, also return the most /// recently mapped values. /// - If there is no mapping for the given values at all, return the given - /// values. - ValueVector lookupOrDefault(ValueVector from, - TypeRange desiredTypes = {}) const; + /// value. + ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; - /// Lookup the given values within the map, or return an empty vector if the - /// values are not mapped. If they are mapped, this follows the same behavior + /// Lookup the given value within the map, or return an empty vector if the + /// value is not mapped. If it is mapped, this follows the same behavior /// as `lookupOrDefault`. - ValueVector lookupOrNull(const ValueVector &from, - TypeRange desiredTypes = {}) const; + ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; template struct IsValueVector : std::is_same, ValueVector> {}; - /// Map a value to the one provided. + /// Map a value vector to the one provided. template std::enable_if_t{} && IsValueVector{}> map(OldVal &&oldVal, NewVal &&newVal) { @@ -192,6 +192,7 @@ struct ConversionValueMapping { mapping[std::forward(oldVal)] = std::forward(newVal); } + /// Map a value vector or single value to the one provided. template std::enable_if_t{} || !IsValueVector{}> map(OldVal &&oldVal, NewVal &&newVal) { @@ -217,19 +218,20 @@ struct ConversionValueMapping { } // namespace ValueVector -ConversionValueMapping::lookupOrDefault(ValueVector from, +ConversionValueMapping::lookupOrDefault(Value from, TypeRange desiredTypes) const { // Try to find the deepest values that have the desired types. If there is no // such mapping, simply return the deepest values. ValueVector desiredValue; + ValueVector current{from}; do { // Store the current value if the types match. - if (desiredTypes.empty() || TypeRange(from) == desiredTypes) - desiredValue = from; + if (TypeRange(current) == desiredTypes) + desiredValue = current; // If possible, Replace each value with (one or multiple) mapped values. ValueVector next; - for (Value v : from) { + for (Value v : current) { auto it = mapping.find({v}); if (it != mapping.end()) { llvm::append_range(next, it->second); @@ -237,33 +239,35 @@ ConversionValueMapping::lookupOrDefault(ValueVector from, next.push_back(v); } } - if (next != from) { + if (next != current) { // If at least one value was replaced, continue the lookup from there. - from = std::move(next); + current = std::move(next); continue; } // Otherwise: Check if there is a mapping for the entire vector. Such // mappings are materializations. (N:M mapping are not supported for value // replacements.) - auto it = mapping.find(from); + auto it = mapping.find(current); if (it == mapping.end()) { // No mapping found: The lookup stops here. break; } - from = it->second; + current = it->second; } while (true); // If the desired values were found use them, otherwise default to the leaf // values. - return !desiredValue.empty() ? desiredValue : from; + // Note: If `desiredTypes` is empty, this function always returns `current`. + return !desiredValue.empty() ? desiredValue : current; } -ValueVector ConversionValueMapping::lookupOrNull(const ValueVector &from, +ValueVector ConversionValueMapping::lookupOrNull(Value from, TypeRange desiredTypes) const { ValueVector result = lookupOrDefault(from, desiredTypes); TypeRange resultTypes(result); - if (result == from || (!desiredTypes.empty() && resultTypes != desiredTypes)) + if (result == ValueVector{from} || + (!desiredTypes.empty() && resultTypes != desiredTypes)) return {}; return result; } @@ -1261,7 +1265,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // The current pattern does not have a type converter. I.e., it does not // distinguish between legal and illegal types. For each operand, simply // pass through the most recently mapped values. - remapped.push_back(mapping.lookupOrDefault({operand})); + remapped.push_back(mapping.lookupOrDefault(operand)); continue; } @@ -1280,7 +1284,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( continue; } - ValueVector repl = mapping.lookupOrDefault({operand}, legalTypes); + ValueVector repl = mapping.lookupOrDefault(operand, legalTypes); if (!repl.empty() && TypeRange(repl) == legalTypes) { // Mapped values have the correct type or there is an existing // materialization. Or the operand is not mapped at all and has the @@ -1290,7 +1294,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( } // Create a materialization for the most recently mapped values. - repl = mapping.lookupOrDefault({operand}); + repl = mapping.lookupOrDefault(operand); ValueRange castValues = buildUnresolvedMaterialization( MaterializationKind::Target, computeInsertPoint(repl), operandLoc, /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes, @@ -1428,10 +1432,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( continue; } - // This is a 1->1+ mapping. 1->N mappings are not fully supported in the - // dialect conversion. Therefore, we need an argument materialization to - // turn the replacement block arguments into a single SSA value that can be - // used as a replacement. + // This is a 1->1+ mapping. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); ValueVector replArgVals = llvm::to_vector_of(replArgs); @@ -1487,7 +1488,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( Value value, const TypeConverter *converter) { // Find a replacement value with the same type. - ValueVector repl = mapping.lookupOrNull({value}, value.getType()); + ValueVector repl = mapping.lookupOrNull(value, value.getType()); if (!repl.empty()) return repl.front(); @@ -1503,7 +1504,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // No replacement value was found. Get the latest replacement value // (regardless of the type) and build a source materialization to the // original type. - repl = mapping.lookupOrNull({value}); + repl = mapping.lookupOrNull(value); if (repl.empty()) { // No replacement value is registered in the mapping. This means that the // value is dropped and no longer needed. (If the value were still needed, @@ -1742,7 +1743,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, }); impl->appendRewrite(from.getOwner(), from, impl->currentTypeConverter); - impl->mapping.map(impl->mapping.lookupOrDefault({from}), to); + impl->mapping.map(impl->mapping.lookupOrDefault(from), to); } Value ConversionPatternRewriter::getRemappedValue(Value key) { From c411bd0947be37a69d4d2a41dac70eebb5c1c749 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 3 Jan 2025 13:33:38 +0100 Subject: [PATCH 6/8] fix windows build --- mlir/lib/Transforms/Utils/DialectConversion.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 3571e017158be..470a5b843135e 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -174,7 +174,7 @@ struct ConversionValueMapping { /// Map a value vector to the one provided. template - std::enable_if_t{} && IsValueVector{}> + std::enable_if_t::value && IsValueVector::value> map(OldVal &&oldVal, NewVal &&newVal) { LLVM_DEBUG({ ValueVector next(newVal); @@ -194,7 +194,8 @@ struct ConversionValueMapping { /// Map a value vector or single value to the one provided. template - std::enable_if_t{} || !IsValueVector{}> + std::enable_if_t::value || + !IsValueVector::value> map(OldVal &&oldVal, NewVal &&newVal) { if constexpr (IsValueVector{}) { map(std::forward(oldVal), ValueVector{newVal}); From 209d9226dec8a41a7194bb663fbad181e47666ca Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 3 Jan 2025 13:36:51 +0100 Subject: [PATCH 7/8] address comments --- mlir/lib/Transforms/Utils/DialectConversion.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 470a5b843135e..23aa5473be63d 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -260,7 +260,7 @@ ConversionValueMapping::lookupOrDefault(Value from, // If the desired values were found use them, otherwise default to the leaf // values. // Note: If `desiredTypes` is empty, this function always returns `current`. - return !desiredValue.empty() ? desiredValue : current; + return !desiredValue.empty() ? std::move(desiredValue) : std::move(current); } ValueVector ConversionValueMapping::lookupOrNull(Value from, From f43e564b7bb625d0fb67aa8e16d9a3259555b270 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 3 Jan 2025 15:51:55 +0100 Subject: [PATCH 8/8] address comments --- .../Transforms/Utils/DialectConversion.cpp | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 23aa5473be63d..0c5520988eff3 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -74,8 +74,8 @@ chooseLaterInsertPointInBlock(OpBuilder::InsertPoint a, // TODO: Extend DominanceInfo API to work with block iterators. static OpBuilder::InsertPoint chooseLaterInsertPoint(OpBuilder::InsertPoint a, OpBuilder::InsertPoint b) { - // Case 1: Same block. - if (a.getBlock() == b.getBlock()) + // Case 1: Fast path: Same block. This is the most common case. + if (LLVM_LIKELY(a.getBlock() == b.getBlock())) return chooseLaterInsertPointInBlock(a, b); // Case 2: Different block, but same region. @@ -90,12 +90,12 @@ static OpBuilder::InsertPoint chooseLaterInsertPoint(OpBuilder::InsertPoint a, } // Case 3: b's region contains a: choose a. - if (Operation *aParent = b.getBlock()->getParent()->findAncestorOpInRegion( + if (b.getBlock()->getParent()->findAncestorOpInRegion( *a.getPoint()->getParentOp())) return a; // Case 4: a's region contains b: choose b. - if (Operation *bParent = a.getBlock()->getParent()->findAncestorOpInRegion( + if (a.getBlock()->getParent()->findAncestorOpInRegion( *b.getPoint()->getParentOp())) return b; @@ -249,6 +249,11 @@ ConversionValueMapping::lookupOrDefault(Value from, // Otherwise: Check if there is a mapping for the entire vector. Such // mappings are materializations. (N:M mapping are not supported for value // replacements.) + // + // Note: From a correctness point of view, materializations do not have to + // be stored (and looked up) in the mapping. But for performance reasons, + // we choose to reuse existing IR (when possible) instead of creating it + // multiple times. auto it = mapping.find(current); if (it == mapping.end()) { // No mapping found: The lookup stops here. @@ -1514,6 +1519,18 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // `applySignatureConversion`.) return Value(); } + + // Note: `computeInsertPoint` computes the "earliest" insertion point at + // which all values in `repl` are defined. It is important to emit the + // materialization at that location because the same materialization may be + // reused in a different context. (That's because materializations are cached + // in the conversion value mapping.) The insertion point of the + // materialization must be valid for all future users that may be created + // later in the conversion process. + // + // Note: Instead of creating new IR, `buildUnresolvedMaterialization` may + // return an already existing, cached materialization from the conversion + // value mapping. Value castValue = buildUnresolvedMaterialization(MaterializationKind::Source, computeInsertPoint(repl), value.getLoc(),