@@ -39,6 +39,130 @@ llvm::cl::opt<std::string>
3939 selectedDialect (" dialect" , llvm::cl::desc(" The dialect to gen for" ),
4040 llvm::cl::cat(dialectGenCat), llvm::cl::Required);
4141
42+ Value createPredicate (OpBuilder &builder, tblgen::Pred pred) {
43+ MLIRContext *ctx = builder.getContext ();
44+
45+ if (pred.isCombined ()) {
46+ auto combiner = pred.getDef ().getValueAsDef (" kind" )->getName ();
47+ if (combiner == " PredCombinerAnd" || combiner == " PredCombinerOr" ) {
48+ std::vector<Value> constraints;
49+ for (auto *child : pred.getDef ().getValueAsListOfDefs (" children" )) {
50+ constraints.push_back (createPredicate (builder, tblgen::Pred (child)));
51+ }
52+ if (combiner == " PredCombinerAnd" ) {
53+ auto op =
54+ builder.create <irdl::AllOfOp>(UnknownLoc::get (ctx), constraints);
55+ return op.getOutput ();
56+ }
57+ auto op =
58+ builder.create <irdl::AnyOfOp>(UnknownLoc::get (ctx), constraints);
59+ return op.getOutput ();
60+ }
61+ }
62+
63+ std::string condition = pred.getCondition ();
64+ // Build a CPredOp to match the C constraint built.
65+ irdl::CPredOp op = builder.create <irdl::CPredOp>(
66+ UnknownLoc::get (ctx), StringAttr::get (ctx, condition));
67+ return op;
68+ }
69+
70+ Value typeToConstraint (OpBuilder &builder, MLIRContext *ctx, Type type) {
71+ auto op =
72+ builder.create <irdl::IsOp>(UnknownLoc::get (ctx), TypeAttr::get (type));
73+ return op.getOutput ();
74+ }
75+
76+ std::optional<Type> recordToType (MLIRContext *ctx, const Record &predRec) {
77+
78+ if (predRec.isSubClassOf (" I" )) {
79+ auto width = predRec.getValueAsInt (" bitwidth" );
80+ return IntegerType::get (ctx, width, IntegerType::Signless);
81+ }
82+
83+ if (predRec.isSubClassOf (" SI" )) {
84+ auto width = predRec.getValueAsInt (" bitwidth" );
85+ return IntegerType::get (ctx, width, IntegerType::Signed);
86+ }
87+
88+ if (predRec.isSubClassOf (" UI" )) {
89+ auto width = predRec.getValueAsInt (" bitwidth" );
90+ return IntegerType::get (ctx, width, IntegerType::Unsigned);
91+ }
92+
93+ // Index type
94+ if (predRec.getName () == " Index" ) {
95+ return IndexType::get (ctx);
96+ }
97+
98+ // Float types
99+ if (predRec.isSubClassOf (" F" )) {
100+ auto width = predRec.getValueAsInt (" bitwidth" );
101+ switch (width) {
102+ case 16 :
103+ return FloatType::getF16 (ctx);
104+ case 32 :
105+ return FloatType::getF32 (ctx);
106+ case 64 :
107+ return FloatType::getF64 (ctx);
108+ case 80 :
109+ return FloatType::getF80 (ctx);
110+ case 128 :
111+ return FloatType::getF128 (ctx);
112+ }
113+ }
114+
115+ if (predRec.getName () == " NoneType" ) {
116+ return NoneType::get (ctx);
117+ }
118+
119+ if (predRec.getName () == " BF16" ) {
120+ return FloatType::getBF16 (ctx);
121+ }
122+
123+ if (predRec.getName () == " TF32" ) {
124+ return FloatType::getTF32 (ctx);
125+ }
126+
127+ if (predRec.getName () == " F8E4M3FN" ) {
128+ return FloatType::getFloat8E4M3FN (ctx);
129+ }
130+
131+ if (predRec.getName () == " F8E5M2" ) {
132+ return FloatType::getFloat8E5M2 (ctx);
133+ }
134+
135+ if (predRec.getName () == " F8E4M3" ) {
136+ return FloatType::getFloat8E4M3 (ctx);
137+ }
138+
139+ if (predRec.getName () == " F8E4M3FNUZ" ) {
140+ return FloatType::getFloat8E4M3FNUZ (ctx);
141+ }
142+
143+ if (predRec.getName () == " F8E4M3B11FNUZ" ) {
144+ return FloatType::getFloat8E4M3B11FNUZ (ctx);
145+ }
146+
147+ if (predRec.getName () == " F8E5M2FNUZ" ) {
148+ return FloatType::getFloat8E5M2FNUZ (ctx);
149+ }
150+
151+ if (predRec.getName () == " F8E3M4" ) {
152+ return FloatType::getFloat8E3M4 (ctx);
153+ }
154+
155+ if (predRec.isSubClassOf (" Complex" )) {
156+ const Record *elementRec = predRec.getValueAsDef (" elementType" );
157+ auto elementType = recordToType (ctx, *elementRec);
158+ if (elementType.has_value ()) {
159+ return ComplexType::get (elementType.value ());
160+ }
161+ }
162+
163+ return std::nullopt ;
164+ }
165+
42166Value createConstraint (OpBuilder &builder, tblgen::Constraint constraint) {
43167 MLIRContext *ctx = builder.getContext ();
44168 const Record &predRec = constraint.getDef ();
@@ -78,11 +202,45 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
78202 return op.getOutput ();
79203 }
80204
81- std::string condition = constraint.getPredicate ().getCondition ();
82- // Build a CPredOp to match the C constraint built.
83- irdl::CPredOp op = builder.create <irdl::CPredOp>(
84- UnknownLoc::get (ctx), StringAttr::get (ctx, condition));
85- return op;
205+ // Integer types
206+ if (predRec.getName () == " AnyInteger" ) {
207+ auto op = builder.create <irdl::BaseOp>(
208+ UnknownLoc::get (ctx), StringAttr::get (ctx, " !builtin.integer" ));
209+ return op.getOutput ();
210+ }
211+
212+ if (predRec.isSubClassOf (" AnyI" )) {
213+ auto width = predRec.getValueAsInt (" bitwidth" );
214+ std::vector<Value> types = {
215+ typeToConstraint (builder, ctx,
216+ IntegerType::get (ctx, width, IntegerType::Signless)),
217+ typeToConstraint (builder, ctx,
218+ IntegerType::get (ctx, width, IntegerType::Signed)),
219+ typeToConstraint (builder, ctx,
220+ IntegerType::get (ctx, width, IntegerType::Unsigned))};
221+ auto op = builder.create <irdl::AnyOfOp>(UnknownLoc::get (ctx), types);
222+ return op.getOutput ();
223+ }
224+
225+ auto type = recordToType (ctx, predRec);
226+
227+ if (type.has_value ()) {
228+ return typeToConstraint (builder, ctx, type.value ());
229+ }
230+
231+ // Confined type
232+ if (predRec.isSubClassOf (" ConfinedType" )) {
233+ std::vector<Value> constraints;
234+ constraints.push_back (createConstraint (
235+ builder, tblgen::Constraint (predRec.getValueAsDef (" baseType" ))));
236+ for (Record *child : predRec.getValueAsListOfDefs (" predicateList" )) {
237+ constraints.push_back (createPredicate (builder, tblgen::Pred (child)));
238+ }
239+ auto op = builder.create <irdl::AllOfOp>(UnknownLoc::get (ctx), constraints);
240+ return op.getOutput ();
241+ }
242+
243+ return createPredicate (builder, constraint.getPredicate ());
86244}
87245
88246// / Returns the name of the operation without the dialect prefix.
@@ -131,10 +289,12 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
131289 auto [results, resultVariadicity] = getValues (tblgenOp.getResults ());
132290
133291 // Create the operands and results operations.
134- consBuilder.create <irdl::OperandsOp>(UnknownLoc::get (ctx), operands,
135- operandVariadicity);
136- consBuilder.create <irdl::ResultsOp>(UnknownLoc::get (ctx), results,
137- resultVariadicity);
292+ if (!operands.empty ())
293+ consBuilder.create <irdl::OperandsOp>(UnknownLoc::get (ctx), operands,
294+ operandVariadicity);
295+ if (!results.empty ())
296+ consBuilder.create <irdl::ResultsOp>(UnknownLoc::get (ctx), results,
297+ resultVariadicity);
138298
139299 return op;
140300}
0 commit comments