Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
6aa9f0d
Add support for transform.param values in `PadOp`s pad_to_multiple_of
srcarroll May 1, 2024
a1a7b17
fix for param pad_to_multiple_of
srcarroll May 1, 2024
5159eeb
format
srcarroll May 1, 2024
23ef22a
cleanup diagnostic messages
srcarroll May 1, 2024
7912431
Merge branch 'main' into pad-op-parameter
srcarroll May 1, 2024
2dade05
refactor paramhandle reification
srcarroll May 2, 2024
124957b
Add python test for new functionality
srcarroll May 2, 2024
f69f052
fix typo
srcarroll May 2, 2024
a9c511e
use tablegen assembly format
srcarroll May 2, 2024
3008e5d
address some comments
srcarroll May 2, 2024
6549268
make transform ops with param/handle inputs have consistent assembly
srcarroll May 2, 2024
ce17605
address review comments
srcarroll May 3, 2024
f65cc75
address review comments
srcarroll May 3, 2024
d45b1b0
modify PackedOrDyanmicList and change assembly for tile_using_forall
srcarroll May 3, 2024
e31bd9d
Merge branch 'main' into pad-op-parameter
srcarroll May 3, 2024
9b7b215
Merge branch 'pad-op-parameter' into consistent-transform-syntax
srcarroll May 3, 2024
bc5fd5a
fix formatting
srcarroll May 3, 2024
27f671d
Merge branch 'main' into pad-op-parameter
srcarroll May 4, 2024
eaf0e38
Update mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
srcarroll May 4, 2024
1c4ee55
Merge branch 'main' into pad-op-parameter
srcarroll May 4, 2024
66f9d4d
fix typo
srcarroll May 4, 2024
85fcf89
Merge branch 'pad-op-parameter' into consistent-transform-syntax
srcarroll May 4, 2024
79631a3
Merge branch 'main' into consistent-transform-syntax
srcarroll May 4, 2024
12b2194
change pack and pack_greedily
srcarroll May 4, 2024
df455ee
address review comments
srcarroll May 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -783,10 +783,9 @@ def PackOp : Op<Transform_Dialect, "structured.pack", [
let assemblyFormat = [{
$target
`packed_sizes` `=` custom<DynamicIndexList>($packed_sizes,
$static_packed_sizes,
type($packed_sizes))
$static_packed_sizes)
attr-dict
`:` functional-type($target, results)
`:` functional-type(operands, results)
}];

let builders = [
Expand Down Expand Up @@ -890,14 +889,13 @@ def PackGreedilyOp : Op<Transform_Dialect, "structured.pack_greedily", [
$target
oilist(
`matmul_packed_sizes` `=` custom<DynamicIndexList>($matmul_packed_sizes,
$static_matmul_packed_sizes,
type($matmul_packed_sizes))
$static_matmul_packed_sizes)
(`matmul_padded_sizes_next_multiple_of` `=`
$matmul_padded_sizes_next_multiple_of^)?
`matmul_inner_dims_order` `=` $matmul_inner_dims_order
)
attr-dict
`:` functional-type($target, results)
`:` functional-type(operands, results)
}];
let hasVerifier = 1;

Expand Down Expand Up @@ -1899,7 +1897,17 @@ def TileUsingForOp : Op<Transform_Dialect, "structured.tile_using_for",
$scalableSizes)>,
];

let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
$target
`tile_sizes` custom<DynamicIndexList>(
$dynamic_sizes,
$static_sizes,
$scalable_sizes)
(`interchange` `=` $interchange^)?
attr-dict
`:` functional-type(operands, results)
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
Expand Down Expand Up @@ -2017,17 +2025,13 @@ def TileUsingForallOp :
let assemblyFormat = [{
$target oilist(
`num_threads` custom<PackedOrDynamicIndexList>($packed_num_threads,
type($packed_num_threads),
$num_threads,
type($num_threads),
$static_num_threads) |
`tile_sizes` custom<PackedOrDynamicIndexList>($packed_tile_sizes,
type($packed_tile_sizes),
$tile_sizes,
type($tile_sizes),
$static_tile_sizes))
(`(` `mapping` `=` $mapping^ `)`)? attr-dict
`:` functional-type($target, results)
`:` functional-type(operands, results)
}];
let hasVerifier = 1;

Expand Down Expand Up @@ -2162,7 +2166,18 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",

let results = (outs);

let hasCustomAssemblyFormat = 1;
// We use oilist here to elide the optional `vector_sizes` when empty list
// is passed.
let assemblyFormat = [{
$target oilist(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i did remove for this as i believe the intention is that tile_sizes is never optional, nor does it make sense to allow empty list

`vector_sizes` custom<DynamicIndexList>(
$vector_sizes,
$static_vector_sizes,
$scalable_sizes))
attr-dict
`:` type($target)(`,`type($vector_sizes)^)?
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
Expand Down
16 changes: 15 additions & 1 deletion mlir/include/mlir/Dialect/Transform/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ void printPackedOrDynamicIndexList(OpAsmPrinter &printer, Operation *op,
Value packed, Type packedType,
OperandRange values, TypeRange valueTypes,
DenseI64ArrayAttr integers);
inline void printPackedOrDynamicIndexList(OpAsmPrinter &printer, Operation *op,
Value packed, OperandRange values,
DenseI64ArrayAttr integers) {
printPackedOrDynamicIndexList(printer, op, packed, Type(), values,
TypeRange{}, integers);
}

/// Parser hook for custom directive in assemblyFormat.
///
Expand All @@ -47,7 +53,15 @@ void printPackedOrDynamicIndexList(OpAsmPrinter &printer, Operation *op,
ParseResult parsePackedOrDynamicIndexList(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
Type &packedType, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
SmallVectorImpl<Type> &valueTypes, DenseI64ArrayAttr &integers);
SmallVectorImpl<Type> *valueTypes, DenseI64ArrayAttr &integers);
inline ParseResult parsePackedOrDynamicIndexList(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers) {
Type packedType;
return parsePackedOrDynamicIndexList(parser, packed, packedType, values,
nullptr, integers);
}
} // namespace transform
} // namespace mlir

Expand Down
11 changes: 9 additions & 2 deletions mlir/include/mlir/Interfaces/ViewLikeInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,16 @@ class OpWithOffsetSizesAndStridesConstantArgumentFolder final
/// empty then assume that all indices are non-scalable.
void printDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
ArrayRef<bool> scalables = {},
ArrayRef<int64_t> integers, ArrayRef<bool> scalables,
TypeRange valueTypes = TypeRange(),
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square);
inline void printDynamicIndexList(
OpAsmPrinter &printer, Operation *op, OperandRange values,
ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
return printDynamicIndexList(printer, op, values, integers, {}, valueTypes,
delimiter);
}

/// Parser hook for custom directive in assemblyFormat.
///
Expand Down
154 changes: 0 additions & 154 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2825,86 +2825,6 @@ SmallVector<OpFoldResult> transform::TileUsingForOp::getMixedSizes() {
return results;
}

// We want to parse `DenseI64ArrayAttr` using the short form without the
// `array` prefix to be consistent in the IR with `parseDynamicIndexList`.
ParseResult parseOptionalInterchange(OpAsmParser &parser,
OperationState &result) {
if (failed(parser.parseOptionalKeyword("interchange")))
return success();
if (failed(parser.parseEqual()))
return failure();
result.addAttribute(
transform::TileUsingForOp::getInterchangeAttrName(result.name),
DenseI64ArrayAttr::parse(parser, Type{}));
return success();
}

void printOptionalInterchange(OpAsmPrinter &p,
ArrayRef<int64_t> interchangeVals) {
if (!interchangeVals.empty()) {
p << " interchange = [";
llvm::interleaveComma(interchangeVals, p,
[&](int64_t integer) { p << integer; });
p << "]";
}
}

ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand target;
SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
DenseI64ArrayAttr staticSizes;
FunctionType functionalType;
llvm::SMLoc operandLoc;
DenseBoolArrayAttr scalableVals;

if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) ||
parseOptionalInterchange(parser, result) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(functionalType))
return ParseResult::failure();

size_t numExpectedLoops =
staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
if (functionalType.getNumResults() != numExpectedLoops + 1) {
return parser.emitError(parser.getNameLoc())
<< "expected " << (numExpectedLoops + 1) << " result type(s)";
}
if (functionalType.getNumInputs() != dynamicSizes.size() + 1) {
return parser.emitError(operandLoc)
<< "expected " << dynamicSizes.size() + 1 << " operand type(s)";
}
if (parser.resolveOperand(target, functionalType.getInputs().front(),
result.operands) ||
parser.resolveOperands(dynamicSizes,
functionalType.getInputs().drop_front(),
operandLoc, result.operands)) {
return failure();
}

result.addAttribute(getScalableSizesAttrName(result.name), scalableVals);

result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
result.addTypes(functionalType.getResults());
return success();
}

void TileUsingForOp::print(OpAsmPrinter &p) {
p << ' ' << getTarget();
printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
/*valueTypes=*/{}, getScalableSizesAttr(),
OpAsmParser::Delimiter::Square);
printOptionalInterchange(p, getInterchange());
p.printOptionalAttrDict(
(*this)->getAttrs(),
/*elidedAttrs=*/{getInterchangeAttrName(getOperation()->getName()),
getScalableSizesAttrName(getOperation()->getName()),
getStaticSizesAttrName(getOperation()->getName())});
p << " : ";
p.printFunctionalType(getOperands().getTypes(), getResults().getTypes());
}

void transform::TileUsingForOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
consumesHandle(getTarget(), effects);
Expand Down Expand Up @@ -3221,80 +3141,6 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
// VectorizeOp
//===----------------------------------------------------------------------===//

static const StringLiteral kVectorSizesKeyword = "vector_sizes";

ParseResult transform::VectorizeOp::parse(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::UnresolvedOperand target;
SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
DenseI64ArrayAttr staticSizes;
SmallVector<Type> operandTypes;
llvm::SMLoc operandLoc;
DenseBoolArrayAttr scalableVals;

if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc))
return ParseResult::failure();

if (succeeded(parser.parseOptionalKeyword(kVectorSizesKeyword))) {
if (failed(parseDynamicIndexList(parser, dynamicSizes, staticSizes,
scalableVals)))
return ParseResult::failure();
}

if (succeeded(parser.parseOptionalKeyword(
getVectorizeNdExtractAttrName(result.name))))
result.addAttribute(getVectorizeNdExtractAttrName(result.name),
parser.getBuilder().getUnitAttr());

if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonTypeList(operandTypes))
return ParseResult::failure();

if (operandTypes.size() != dynamicSizes.size() + 1) {
return parser.emitError(operandLoc)
<< "expected " << dynamicSizes.size() + 1 << " operand type(s)";
}
if (parser.resolveOperand(target, operandTypes.front(), result.operands) ||
parser.resolveOperands(dynamicSizes, ArrayRef(operandTypes).drop_front(),
operandLoc, result.operands)) {
return failure();
}

if (scalableVals)
result.addAttribute(getScalableSizesAttrName(result.name), scalableVals);
if (staticSizes)
result.addAttribute(getStaticVectorSizesAttrName(result.name), staticSizes);

return success();
}

void transform::VectorizeOp::print(OpAsmPrinter &p) {
p << ' ' << getTarget() << ' ';
if (!getMixedVectorSizes().empty()) {
p << kVectorSizesKeyword << ' ';
printDynamicIndexList(p, getOperation(), getVectorSizes(),
getStaticVectorSizesAttr(),
/*valueTypes=*/{}, getScalableSizesAttr(),
OpAsmParser::Delimiter::Square);
}

if (getVectorizeNdExtract())
p << getVectorizeNdExtractAttrName() << ' ';

p.printOptionalAttrDict(
(*this)->getAttrs(),
/*elidedAttrs=*/{
getScalableSizesAttrName(getOperation()->getName()),
getStaticVectorSizesAttrName(getOperation()->getName())});
p << " : ";
p << getTarget().getType();
if (!getVectorSizes().empty()) {
p << ", ";
llvm::interleaveComma(getVectorSizes(), p,
[&](Value operand) { p << operand.getType(); });
}
}

DiagnosedSilenceableFailure transform::VectorizeOp::apply(
transform::TransformRewriter &rewriter,
mlir::transform::TransformResults &transformResults,
Expand Down
19 changes: 12 additions & 7 deletions mlir/lib/Dialect/Transform/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ void mlir::transform::printPackedOrDynamicIndexList(
if (packed) {
assert(values.empty() && (!integers || integers.empty()) &&
"expected no values/integers");
printer << "*(" << packed << " : " << packedType << ")";
printer << "*(" << packed;
if (packedType) {
printer << " : " << packedType;
}
printer << ")";
return;
}
printDynamicIndexList(printer, op, values, integers, valueTypes);
Expand All @@ -29,19 +33,20 @@ void mlir::transform::printPackedOrDynamicIndexList(
ParseResult mlir::transform::parsePackedOrDynamicIndexList(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &packed,
Type &packedType, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
SmallVectorImpl<Type> &valueTypes, DenseI64ArrayAttr &integers) {
SmallVectorImpl<Type> *valueTypes, DenseI64ArrayAttr &integers) {
OpAsmParser::UnresolvedOperand packedOperand;
if (parser.parseOptionalStar().succeeded()) {
if (parser.parseLParen().failed() ||
parser.parseOperand(packedOperand).failed() ||
parser.parseColonType(packedType).failed() ||
parser.parseRParen().failed()) {
parser.parseOperand(packedOperand).failed())
return failure();
if (packedType && (parser.parseColonType(packedType).failed()))
return failure();
if (parser.parseRParen().failed())
return failure();
}
packed.emplace(packedOperand);
integers = parser.getBuilder().getDenseI64ArrayAttr({});
return success();
}

return parseDynamicIndexList(parser, values, integers, &valueTypes);
return parseDynamicIndexList(parser, values, integers, valueTypes);
}
2 changes: 1 addition & 1 deletion mlir/lib/Interfaces/ViewLikeInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ static char getRightDelimiter(AsmParser::Delimiter delimiter) {
void mlir::printDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
ArrayRef<int64_t> integers,
TypeRange valueTypes, ArrayRef<bool> scalables,
ArrayRef<bool> scalables, TypeRange valueTypes,
AsmParser::Delimiter delimiter) {
char leftDelimiter = getLeftDelimiter(delimiter);
char rightDelimiter = getRightDelimiter(delimiter);
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/LLVM/transform-e2e.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func.func @matmul_tensors(
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.consumed}) {
%0 = transform.structured.match ops{["linalg.matmul"]} in %module_op : (!transform.any_op) -> !transform.any_op
%1, %loops:3 = transform.structured.tile_using_for %0 [2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
%1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
%2 = transform.get_parent_op %1 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
transform.structured.vectorize_children_and_apply_patterns %2 : (!transform.any_op) -> !transform.any_op
%b = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap}
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:4 = transform.structured.tile_using_for %0 [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
%1, %loops:4 = transform.structured.tile_using_for %0 tile_sizes [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
Expand All @@ -54,7 +54,7 @@ func.func @pad_and_pack(%arg0: tensor<13x15xf32>, %arg1: tensor<2x8x8x2xf32>, %a
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.tile_using_for %0 [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
Expand Down Expand Up @@ -85,7 +85,7 @@ func.func @KC_to_CKkc(%arg0: tensor<128x256xf32>, %arg1: tensor<32x4x32x8xf32>)
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1, %loops:2 = transform.structured.tile_using_for %0 [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
%1, %loops:2 = transform.structured.tile_using_for %0 tile_sizes [1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
Loading