From 54aa7d04c5a8c483ccfa18e5acb5b5756caa60fc Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Wed, 21 Aug 2024 10:53:09 +0100 Subject: [PATCH 1/2] Refactor tblgen-to-irdl script and support more types --- mlir/include/mlir/IR/CommonTypeConstraints.td | 5 +- mlir/test/tblgen-to-irdl/CMathDialect.td | 1 - mlir/test/tblgen-to-irdl/TestDialect.td | 58 +++++- .../tools/tblgen-to-irdl/OpDefinitionsGen.cpp | 179 +++++++++++++++++- 4 files changed, 226 insertions(+), 17 deletions(-) diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 4536d781ef674..0e076413d0d9f 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -198,7 +198,10 @@ class AllOfType allowedTypeList, string summary = "", class ConfinedType predicates, string summary = "", string cppType = type.cppType> : Type< And, - summary, cppType>; + summary, cppType> { + Type baseType = type; + list predicateList = predicates; +} // Integer types. diff --git a/mlir/test/tblgen-to-irdl/CMathDialect.td b/mlir/test/tblgen-to-irdl/CMathDialect.td index 5b9e756727cb3..454543e074c48 100644 --- a/mlir/test/tblgen-to-irdl/CMathDialect.td +++ b/mlir/test/tblgen-to-irdl/CMathDialect.td @@ -25,7 +25,6 @@ def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> { // CHECK: irdl.operation @identity { // CHECK-NEXT: %0 = irdl.base "!cmath.complex" -// CHECK-NEXT: irdl.operands() // CHECK-NEXT: irdl.results(%0) // CHECK-NEXT: } def CMath_IdentityOp : CMath_Op<"identity"> { diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td index fc40da527db00..e67ff1e8f7d6b 100644 --- a/mlir/test/tblgen-to-irdl/TestDialect.td +++ b/mlir/test/tblgen-to-irdl/TestDialect.td @@ -28,9 +28,8 @@ def Test_AndOp : Test_Op<"and"> { // CHECK-LABEL: irdl.operation @and { // CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a" // CHECK-NEXT: %[[v1:[^ ]*]] = irdl.any -// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]]) +// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]]) // CHECK-NEXT: irdl.operands(%[[v2]]) -// CHECK-NEXT: irdl.results() // CHECK-NEXT: } @@ -41,9 +40,38 @@ def Test_AnyOp : Test_Op<"any"> { // CHECK-LABEL: irdl.operation @any { // CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any // CHECK-NEXT: irdl.operands(%[[v0]]) -// CHECK-NEXT: irdl.results() // CHECK-NEXT: } +// Check confined types are converted correctly. +def Test_ConfinedOp : Test_Op<"confined"> { + let arguments = (ins ConfinedType:$confined, + ConfinedType.predicate, IntMaxValue<2>.predicate]>]>:$bounded); +} +// CHECK-LABEL: irdl.operation @confined { +// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.is i32 +// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.c_pred "{{.*}}" +// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]]) +// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.is i8 +// CHECK-NEXT: %[[v4:[^ ]*]] = irdl.c_pred "{{.*}}" +// CHECK-NEXT: %[[v5:[^ ]*]] = irdl.c_pred "{{.*}}" +// CHECK-NEXT: %[[v6:[^ ]*]] = irdl.all_of(%[[v4]], %[[v5]]) +// CHECK-NEXT: %[[v7:[^ ]*]] = irdl.all_of(%[[v3]], %[[v6]]) +// CHECK-NEXT: irdl.operands(%[[v2]], %[[v7]]) +// CHECK-NEXT: } + +// Check generic integer types are converted correctly. +def Test_Integers : Test_Op<"integers"> { + let arguments = (ins AnyI8:$any_int, + AnyInteger:$any_integer); +} +// CHECK-LABEL: irdl.operation @integers { +// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.is i8 +// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.is si8 +// CHECK-NEXT: %[[v2:[^ ]*]] = irdl.is ui8 +// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]]) +// CHECK-NEXT: %[[v4:[^ ]*]] = irdl.base "!builtin.integer" +// CHECK-NEXT: irdl.operands(%[[v3]], %[[v4]]) +// CHECK-NEXT: } // Check that AnyTypeOf is converted correctly. def Test_OrOp : Test_Op<"or"> { @@ -53,11 +81,30 @@ def Test_OrOp : Test_Op<"or"> { // CHECK-NEXT: %[[v0:[^ ]*]] = irdl.base "!test.singleton_a" // CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b" // CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c" -// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]]) +// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any_of(%[[v0]], %[[v1]], %[[v2]]) // CHECK-NEXT: irdl.operands(%[[v3]]) -// CHECK-NEXT: irdl.results() // CHECK-NEXT: } +// Check that various types are converted correctly. +def Test_TypesOp : Test_Op<"types"> { + let arguments = (ins I32:$a, + SI64:$b, + UI8:$c, + Index:$d, + F32:$e, + NoneType:$f, + Complex); +} +// CHECK-LABEL: irdl.operation @types { +// CHECK-NEXT: %{{.*}} = irdl.is i32 +// CHECK-NEXT: %{{.*}} = irdl.is si64 +// CHECK-NEXT: %{{.*}} = irdl.is ui8 +// CHECK-NEXT: %{{.*}} = irdl.is index +// CHECK-NEXT: %{{.*}} = irdl.is f32 +// CHECK-NEXT: %{{.*}} = irdl.is none +// CHECK-NEXT: %{{.*}} = irdl.is complex +// CHECK-NEXT: irdl.operands({{.*}}) +// CHECK-NEXT: } // Check that variadics and optionals are converted correctly. def Test_VariadicityOp : Test_Op<"variadicity"> { @@ -70,5 +117,4 @@ def Test_VariadicityOp : Test_Op<"variadicity"> { // CHECK-NEXT: %[[v1:[^ ]*]] = irdl.base "!test.singleton_b" // CHECK-NEXT: %[[v2:[^ ]*]] = irdl.base "!test.singleton_c" // CHECK-NEXT: irdl.operands(variadic %[[v0]], optional %[[v1]], %[[v2]]) -// CHECK-NEXT: irdl.results() // CHECK-NEXT: } diff --git a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp index a55f3539f31db..e066a73d9c458 100644 --- a/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp +++ b/mlir/tools/tblgen-to-irdl/OpDefinitionsGen.cpp @@ -39,6 +39,131 @@ llvm::cl::opt selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"), llvm::cl::cat(dialectGenCat), llvm::cl::Required); +Value createPredicate(OpBuilder &builder, tblgen::Pred pred) { + MLIRContext *ctx = builder.getContext(); + + if (pred.isCombined()) { + auto combiner = pred.getDef().getValueAsDef("kind")->getName(); + if (combiner == "PredCombinerAnd" || combiner == "PredCombinerOr") { + std::vector constraints; + for (auto *child : pred.getDef().getValueAsListOfDefs("children")) { + constraints.push_back(createPredicate(builder, tblgen::Pred(child))); + } + if (combiner == "PredCombinerAnd") { + auto op = + builder.create(UnknownLoc::get(ctx), constraints); + return op.getOutput(); + } + auto op = + builder.create(UnknownLoc::get(ctx), constraints); + return op.getOutput(); + } + } + + std::string condition = pred.getCondition(); + // Build a CPredOp to match the C constraint built. + irdl::CPredOp op = builder.create( + UnknownLoc::get(ctx), StringAttr::get(ctx, condition)); + return op; +} + +Value typeToConstraint(OpBuilder &builder, Type type) { + MLIRContext *ctx = builder.getContext(); + auto op = + builder.create(UnknownLoc::get(ctx), TypeAttr::get(type)); + return op.getOutput(); +} + +std::optional recordToType(MLIRContext *ctx, const Record &predRec) { + + if (predRec.isSubClassOf("I")) { + auto width = predRec.getValueAsInt("bitwidth"); + return IntegerType::get(ctx, width, IntegerType::Signless); + } + + if (predRec.isSubClassOf("SI")) { + auto width = predRec.getValueAsInt("bitwidth"); + return IntegerType::get(ctx, width, IntegerType::Signed); + } + + if (predRec.isSubClassOf("UI")) { + auto width = predRec.getValueAsInt("bitwidth"); + return IntegerType::get(ctx, width, IntegerType::Unsigned); + } + + // Index type + if (predRec.getName() == "Index") { + return IndexType::get(ctx); + } + + // Float types + if (predRec.isSubClassOf("F")) { + auto width = predRec.getValueAsInt("bitwidth"); + switch (width) { + case 16: + return FloatType::getF16(ctx); + case 32: + return FloatType::getF32(ctx); + case 64: + return FloatType::getF64(ctx); + case 80: + return FloatType::getF80(ctx); + case 128: + return FloatType::getF128(ctx); + } + } + + if (predRec.getName() == "NoneType") { + return NoneType::get(ctx); + } + + if (predRec.getName() == "BF16") { + return FloatType::getBF16(ctx); + } + + if (predRec.getName() == "TF32") { + return FloatType::getTF32(ctx); + } + + if (predRec.getName() == "F8E4M3FN") { + return FloatType::getFloat8E4M3FN(ctx); + } + + if (predRec.getName() == "F8E5M2") { + return FloatType::getFloat8E5M2(ctx); + } + + if (predRec.getName() == "F8E4M3") { + return FloatType::getFloat8E4M3(ctx); + } + + if (predRec.getName() == "F8E4M3FNUZ") { + return FloatType::getFloat8E4M3FNUZ(ctx); + } + + if (predRec.getName() == "F8E4M3B11FNUZ") { + return FloatType::getFloat8E4M3B11FNUZ(ctx); + } + + if (predRec.getName() == "F8E5M2FNUZ") { + return FloatType::getFloat8E5M2FNUZ(ctx); + } + + if (predRec.getName() == "F8E3M4") { + return FloatType::getFloat8E3M4(ctx); + } + + if (predRec.isSubClassOf("Complex")) { + const Record *elementRec = predRec.getValueAsDef("elementType"); + auto elementType = recordToType(ctx, *elementRec); + if (elementType.has_value()) { + return ComplexType::get(elementType.value()); + } + } + + return std::nullopt; +} + Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) { MLIRContext *ctx = builder.getContext(); const Record &predRec = constraint.getDef(); @@ -78,11 +203,45 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) { return op.getOutput(); } - std::string condition = constraint.getPredicate().getCondition(); - // Build a CPredOp to match the C constraint built. - irdl::CPredOp op = builder.create( - UnknownLoc::get(ctx), StringAttr::get(ctx, condition)); - return op; + // Integer types + if (predRec.getName() == "AnyInteger") { + auto op = builder.create( + UnknownLoc::get(ctx), StringAttr::get(ctx, "!builtin.integer")); + return op.getOutput(); + } + + if (predRec.isSubClassOf("AnyI")) { + auto width = predRec.getValueAsInt("bitwidth"); + std::vector types = { + typeToConstraint(builder, + IntegerType::get(ctx, width, IntegerType::Signless)), + typeToConstraint(builder, + IntegerType::get(ctx, width, IntegerType::Signed)), + typeToConstraint(builder, + IntegerType::get(ctx, width, IntegerType::Unsigned))}; + auto op = builder.create(UnknownLoc::get(ctx), types); + return op.getOutput(); + } + + auto type = recordToType(ctx, predRec); + + if (type.has_value()) { + return typeToConstraint(builder, type.value()); + } + + // Confined type + if (predRec.isSubClassOf("ConfinedType")) { + std::vector constraints; + constraints.push_back(createConstraint( + builder, tblgen::Constraint(predRec.getValueAsDef("baseType")))); + for (Record *child : predRec.getValueAsListOfDefs("predicateList")) { + constraints.push_back(createPredicate(builder, tblgen::Pred(child))); + } + auto op = builder.create(UnknownLoc::get(ctx), constraints); + return op.getOutput(); + } + + return createPredicate(builder, constraint.getPredicate()); } /// Returns the name of the operation without the dialect prefix. @@ -131,10 +290,12 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder, auto [results, resultVariadicity] = getValues(tblgenOp.getResults()); // Create the operands and results operations. - consBuilder.create(UnknownLoc::get(ctx), operands, - operandVariadicity); - consBuilder.create(UnknownLoc::get(ctx), results, - resultVariadicity); + if (!operands.empty()) + consBuilder.create(UnknownLoc::get(ctx), operands, + operandVariadicity); + if (!results.empty()) + consBuilder.create(UnknownLoc::get(ctx), results, + resultVariadicity); return op; } From d59956c1ec1f0bcf9d817491b1b18cf50e2d5de4 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Tue, 3 Sep 2024 10:50:28 +0100 Subject: [PATCH 2/2] Use CPred in test --- mlir/test/tblgen-to-irdl/TestDialect.td | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/mlir/test/tblgen-to-irdl/TestDialect.td b/mlir/test/tblgen-to-irdl/TestDialect.td index e67ff1e8f7d6b..2622c81776076 100644 --- a/mlir/test/tblgen-to-irdl/TestDialect.td +++ b/mlir/test/tblgen-to-irdl/TestDialect.td @@ -44,16 +44,17 @@ def Test_AnyOp : Test_Op<"any"> { // Check confined types are converted correctly. def Test_ConfinedOp : Test_Op<"confined"> { - let arguments = (ins ConfinedType:$confined, - ConfinedType.predicate, IntMaxValue<2>.predicate]>]>:$bounded); + let arguments = (ins ConfinedType($_self)">]>:$tensor, + ConfinedType($_self)"> + , CPred<"::llvm::cast<::mlir::VectorType>($_self).getRank() > 0">]>]>:$vector); } // CHECK-LABEL: irdl.operation @confined { -// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.is i32 -// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.c_pred "{{.*}}" +// CHECK-NEXT: %[[v0:[^ ]*]] = irdl.any +// CHECK-NEXT: %[[v1:[^ ]*]] = irdl.c_pred "(::llvm::isa<::mlir::TensorType>($_self))" // CHECK-NEXT: %[[v2:[^ ]*]] = irdl.all_of(%[[v0]], %[[v1]]) -// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.is i8 -// CHECK-NEXT: %[[v4:[^ ]*]] = irdl.c_pred "{{.*}}" -// CHECK-NEXT: %[[v5:[^ ]*]] = irdl.c_pred "{{.*}}" +// CHECK-NEXT: %[[v3:[^ ]*]] = irdl.any +// CHECK-NEXT: %[[v4:[^ ]*]] = irdl.c_pred "(::llvm::isa<::mlir::VectorType>($_self))" +// CHECK-NEXT: %[[v5:[^ ]*]] = irdl.c_pred "(::llvm::cast<::mlir::VectorType>($_self).getRank() > 0)" // CHECK-NEXT: %[[v6:[^ ]*]] = irdl.all_of(%[[v4]], %[[v5]]) // CHECK-NEXT: %[[v7:[^ ]*]] = irdl.all_of(%[[v3]], %[[v6]]) // CHECK-NEXT: irdl.operands(%[[v2]], %[[v7]])