Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 29 additions & 18 deletions mlir/docs/DialectConversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,15 @@ target types. If the source type is converted to itself, we say it is a "legal"
type. Type conversions are specified via the `addConversion` method described
below.

There are two kind of conversion functions: context-aware and context-unaware
conversions. A context-unaware conversion function converts a `Type` into a
`Type`. A context-aware conversion function converts a `Value` into a type. The
latter allows users to customize type conversion rules based on the IR.

Note: When there is at least one context-aware type conversion function, the
result of type conversions can no longer be cached, which can increase
compilation time. Use this feature with caution!

A `materialization` describes how a list of values should be converted to a
list of values with specific types. An important distinction from a
`conversion` is that a `materialization` can produce IR, whereas a `conversion`
Expand Down Expand Up @@ -332,29 +341,31 @@ Several of the available hooks are detailed below:
```c++
class TypeConverter {
public:
/// Register a conversion function. A conversion function defines how a given
/// source type should be converted. A conversion function must be convertible
/// to any of the following forms(where `T` is a class derived from `Type`:
/// * Optional<Type>(T)
/// Register a conversion function. A conversion function must be convertible
/// to any of the following forms (where `T` is `Value` or a class derived
/// from `Type`, including `Type` itself):
///
/// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
/// or `std::nullopt` to signify failure. If `std::nullopt` is returned, the
/// converter is allowed to try another conversion function to perform
/// the conversion.
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
/// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
/// the converter is allowed to try another conversion function to
/// perform the conversion.
/// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
/// - This form represents a 1-N type conversion. It should return
/// `failure` or `std::nullopt` to signify a failed conversion. If the new
/// set of types is empty, the type is removed and any usages of the
/// `failure` or `std::nullopt` to signify a failed conversion. If the
/// new set of types is empty, the type is removed and any usages of the
/// existing value are expected to be removed during conversion. If
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
/// - This form represents a 1-N type conversion supporting recursive
/// types. The first two arguments and the return value are the same as
/// for the regular 1-N form. The third argument is contains is the
/// "call stack" of the recursive conversion: it contains the list of
/// types currently being converted, with the current type being the
/// last one. If it is present more than once in the list, the
/// conversion concerns a recursive type.
///
/// Conversion functions that accept `Value` as the first argument are
/// context-aware. I.e., they can take into account IR when converting the
/// type of the given value. Context-unaware conversion functions accept
/// `Type` or a derived class as the first argument.
///
/// Note: Context-unaware conversions are cached, but context-aware
/// conversions are not.
///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT,
Expand Down
110 changes: 92 additions & 18 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ class TypeConverter {
};

/// Register a conversion function. A conversion function must be convertible
/// to any of the following forms (where `T` is a class derived from `Type`):
/// to any of the following forms (where `T` is `Value` or a class derived
/// from `Type`, including `Type` itself):
///
/// * std::optional<Type>(T)
/// - This form represents a 1-1 type conversion. It should return nullptr
Expand All @@ -154,6 +155,14 @@ class TypeConverter {
/// `std::nullopt` is returned, the converter is allowed to try another
/// conversion function to perform the conversion.
///
/// Conversion functions that accept `Value` as the first argument are
/// context-aware. I.e., they can take into account IR when converting the
/// type of the given value. Context-unaware conversion functions accept
/// `Type` or a derived class as the first argument.
///
/// Note: Context-unaware conversions are cached, but context-aware
/// conversions are not.
///
/// Note: When attempting to convert a type, e.g. via 'convertType', the
/// mostly recently added conversions will be invoked first.
template <typename FnT, typename T = typename llvm::function_traits<
Expand Down Expand Up @@ -242,15 +251,28 @@ class TypeConverter {
wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
}

/// Convert the given type. This function should return failure if no valid
/// Convert the given type. This function returns failure if no valid
/// conversion exists, success otherwise. If the new set of types is empty,
/// the type is removed and any usages of the existing value are expected to
/// be removed during conversion.
///
/// Note: This overload invokes only context-unaware type conversion
/// functions. Users should call the other overload if possible.
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;

/// Convert the type of the given value. This function returns failure if no
/// valid conversion exists, success otherwise. If the new set of types is
/// empty, the type is removed and any usages of the existing value are
/// expected to be removed during conversion.
///
/// Note: This overload invokes both context-aware and context-unaware type
/// conversion functions.
LogicalResult convertType(Value v, SmallVectorImpl<Type> &results) const;

/// This hook simplifies defining 1-1 type conversions. This function returns
/// the type to convert to on success, and a null type on failure.
Type convertType(Type t) const;
Type convertType(Value v) const;

/// Attempts a 1-1 type conversion, expecting the result type to be
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
Expand All @@ -259,25 +281,36 @@ class TypeConverter {
TargetType convertType(Type t) const {
return dyn_cast_or_null<TargetType>(convertType(t));
}
template <typename TargetType>
TargetType convertType(Value v) const {
return dyn_cast_or_null<TargetType>(convertType(v));
}

/// Convert the given set of types, filling 'results' as necessary. This
/// returns failure if the conversion of any of the types fails, success
/// Convert the given types, filling 'results' as necessary. This returns
/// "failure" if the conversion of any of the types fails, "success"
/// otherwise.
LogicalResult convertTypes(TypeRange types,
SmallVectorImpl<Type> &results) const;

/// Convert the types of the given values, filling 'results' as necessary.
/// This returns "failure" if the conversion of any of the types fails,
/// "success" otherwise.
LogicalResult convertTypes(ValueRange values,
SmallVectorImpl<Type> &results) const;

/// Return true if the given type is legal for this type converter, i.e. the
/// type converts to itself.
bool isLegal(Type type) const;
bool isLegal(Value value) const;

/// Return true if all of the given types are legal for this type converter.
template <typename RangeT>
std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
!std::is_convertible<RangeT, Operation *>::value,
bool>
isLegal(RangeT &&range) const {
bool isLegal(TypeRange range) const {
return llvm::all_of(range, [this](Type type) { return isLegal(type); });
}
bool isLegal(ValueRange range) const {
return llvm::all_of(range, [this](Value value) { return isLegal(value); });
}

/// Return true if the given operation has legal operand and result types.
bool isLegal(Operation *op) const;

Expand All @@ -296,6 +329,11 @@ class TypeConverter {
LogicalResult convertSignatureArgs(TypeRange types,
SignatureConversion &result,
unsigned origInputOffset = 0) const;
LogicalResult convertSignatureArg(unsigned inputNo, Value value,
SignatureConversion &result) const;
LogicalResult convertSignatureArgs(ValueRange values,
SignatureConversion &result,
unsigned origInputOffset = 0) const;

/// This function converts the type signature of the given block, by invoking
/// 'convertSignatureArg' for each argument. This function should return a
Expand Down Expand Up @@ -329,7 +367,7 @@ class TypeConverter {
/// types is empty, the type is removed and any usages of the existing value
/// are expected to be removed during conversion.
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
Type, SmallVectorImpl<Type> &)>;
PointerUnion<Type, Value>, SmallVectorImpl<Type> &)>;

/// The signature of the callback used to materialize a source conversion.
///
Expand All @@ -349,13 +387,14 @@ class TypeConverter {

/// Generate a wrapper for the given callback. This allows for accepting
/// different callback forms, that all compose into a single version.
/// With callback of form: `std::optional<Type>(T)`
/// With callback of form: `std::optional<Type>(T)`, where `T` can be a
/// `Value` or a `Type` (or a class derived from `Type`).
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
wrapCallback(FnT &&callback) const {
wrapCallback(FnT &&callback) {
return wrapCallback<T>([callback = std::forward<FnT>(callback)](
T type, SmallVectorImpl<Type> &results) {
if (std::optional<Type> resultOpt = callback(type)) {
T typeOrValue, SmallVectorImpl<Type> &results) {
if (std::optional<Type> resultOpt = callback(typeOrValue)) {
bool wasSuccess = static_cast<bool>(*resultOpt);
if (wasSuccess)
results.push_back(*resultOpt);
Expand All @@ -365,20 +404,49 @@ class TypeConverter {
});
}
/// With callback of form: `std::optional<LogicalResult>(
/// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
/// T, SmallVectorImpl<Type> &)`, where `T` is a type.
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
std::is_base_of_v<Type, T>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) const {
return [callback = std::forward<FnT>(callback)](
Type type,
PointerUnion<Type, Value> typeOrValue,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
T derivedType = dyn_cast<T>(type);
T derivedType;
if (Type t = dyn_cast<Type>(typeOrValue)) {
derivedType = dyn_cast<T>(t);
} else if (Value v = dyn_cast<Value>(typeOrValue)) {
derivedType = dyn_cast<T>(v.getType());
} else {
llvm_unreachable("unexpected variant");
}
if (!derivedType)
return std::nullopt;
return callback(derivedType, results);
};
}
/// With callback of form: `std::optional<LogicalResult>(
/// T, SmallVectorImpl<Type>)`, where `T` is a `Value`.
template <typename T, typename FnT>
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
std::is_same_v<T, Value>,
ConversionCallbackFn>
wrapCallback(FnT &&callback) {
hasContextAwareTypeConversions = true;
return [callback = std::forward<FnT>(callback)](
PointerUnion<Type, Value> typeOrValue,
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
if (Type t = dyn_cast<Type>(typeOrValue)) {
// Context-aware type conversion was called with a type.
return std::nullopt;
} else if (Value v = dyn_cast<Value>(typeOrValue)) {
return callback(v, results);
}
llvm_unreachable("unexpected variant");
return std::nullopt;
};
}

/// Register a type conversion.
void registerConversion(ConversionCallbackFn callback) {
Expand Down Expand Up @@ -505,6 +573,12 @@ class TypeConverter {
mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
/// A mutex used for cache access
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
/// Whether the type converter has context-aware type conversions. I.e.,
/// conversion rules that depend on the SSA value instead of just the type.
/// Type conversion caching is deactivated when there are context-aware
/// conversions because the type converter may return different results for
/// the same input type.
bool hasContextAwareTypeConversions = false;
};

//===----------------------------------------------------------------------===//
Expand Down
12 changes: 5 additions & 7 deletions mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
SmallVector<unsigned> offsets;
offsets.push_back(0);
// Do the type conversion and record the offsets.
for (Type type : op.getResultTypes()) {
if (failed(typeConverter->convertTypes(type, dstTypes)))
for (Value v : op.getResults()) {
if (failed(typeConverter->convertType(v, dstTypes)))
return rewriter.notifyMatchFailure(op, "could not convert result type");
offsets.push_back(dstTypes.size());
}
Expand Down Expand Up @@ -127,7 +127,6 @@ class ConvertForOpTypes
// Inline the type converted region from the original operation.
rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
newOp.getRegion().end());

return newOp;
}
};
Expand Down Expand Up @@ -226,15 +225,14 @@ void mlir::scf::populateSCFStructuralTypeConversions(

void mlir::scf::populateSCFStructuralTypeConversionTarget(
const TypeConverter &typeConverter, ConversionTarget &target) {
target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
return typeConverter.isLegal(op->getResultTypes());
});
target.addDynamicallyLegalOp<ForOp, IfOp>(
[&](Operation *op) { return typeConverter.isLegal(op->getResults()); });
target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
// We only have conversions for a subset of ops that use scf.yield
// terminators.
if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
return true;
return typeConverter.isLegal(op.getOperandTypes());
return typeConverter.isLegal(op.getOperands());
});
target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
[&](Operation *op) { return typeConverter.isLegal(op); });
Expand Down
Loading
Loading