@@ -39,6 +39,131 @@ llvm::cl::opt<std::string>
3939 selectedDialect (" dialect" , llvm::cl::desc(" The dialect to gen for" ),
4040 llvm::cl::cat(dialectGenCat), llvm::cl::Required);
4141
42+ Value createPredicate (OpBuilder &builder, tblgen::Pred pred) {
43+ MLIRContext *ctx = builder.getContext ();
44+
45+ if (pred.isCombined ()) {
46+ auto combiner = pred.getDef ().getValueAsDef (" kind" )->getName ();
47+ if (combiner == " PredCombinerAnd" || combiner == " PredCombinerOr" ) {
48+ std::vector<Value> constraints;
49+ for (auto *child : pred.getDef ().getValueAsListOfDefs (" children" )) {
50+ constraints.push_back (createPredicate (builder, tblgen::Pred (child)));
51+ }
52+ if (combiner == " PredCombinerAnd" ) {
53+ auto op =
54+ builder.create <irdl::AllOfOp>(UnknownLoc::get (ctx), constraints);
55+ return op.getOutput ();
56+ }
57+ auto op =
58+ builder.create <irdl::AnyOfOp>(UnknownLoc::get (ctx), constraints);
59+ return op.getOutput ();
60+ }
61+ }
62+
63+ std::string condition = pred.getCondition ();
64+ // Build a CPredOp to match the C constraint built.
65+ irdl::CPredOp op = builder.create <irdl::CPredOp>(
66+ UnknownLoc::get (ctx), StringAttr::get (ctx, condition));
67+ return op;
68+ }
69+
70+ Value typeToConstraint (OpBuilder &builder, Type type) {
71+ MLIRContext *ctx = builder.getContext ();
72+ auto op =
73+ builder.create <irdl::IsOp>(UnknownLoc::get (ctx), TypeAttr::get (type));
74+ return op.getOutput ();
75+ }
76+
77+ std::optional<Type> recordToType (MLIRContext *ctx, const Record &predRec) {
78+
79+ if (predRec.isSubClassOf (" I" )) {
80+ auto width = predRec.getValueAsInt (" bitwidth" );
81+ return IntegerType::get (ctx, width, IntegerType::Signless);
82+ }
83+
84+ if (predRec.isSubClassOf (" SI" )) {
85+ auto width = predRec.getValueAsInt (" bitwidth" );
86+ return IntegerType::get (ctx, width, IntegerType::Signed);
87+ }
88+
89+ if (predRec.isSubClassOf (" UI" )) {
90+ auto width = predRec.getValueAsInt (" bitwidth" );
91+ return IntegerType::get (ctx, width, IntegerType::Unsigned);
92+ }
93+
94+ // Index type
95+ if (predRec.getName () == " Index" ) {
96+ return IndexType::get (ctx);
97+ }
98+
99+ // Float types
100+ if (predRec.isSubClassOf (" F" )) {
101+ auto width = predRec.getValueAsInt (" bitwidth" );
102+ switch (width) {
103+ case 16 :
104+ return FloatType::getF16 (ctx);
105+ case 32 :
106+ return FloatType::getF32 (ctx);
107+ case 64 :
108+ return FloatType::getF64 (ctx);
109+ case 80 :
110+ return FloatType::getF80 (ctx);
111+ case 128 :
112+ return FloatType::getF128 (ctx);
113+ }
114+ }
115+
116+ if (predRec.getName () == " NoneType" ) {
117+ return NoneType::get (ctx);
118+ }
119+
120+ if (predRec.getName () == " BF16" ) {
121+ return FloatType::getBF16 (ctx);
122+ }
123+
124+ if (predRec.getName () == " TF32" ) {
125+ return FloatType::getTF32 (ctx);
126+ }
127+
128+ if (predRec.getName () == " F8E4M3FN" ) {
129+ return FloatType::getFloat8E4M3FN (ctx);
130+ }
131+
132+ if (predRec.getName () == " F8E5M2" ) {
133+ return FloatType::getFloat8E5M2 (ctx);
134+ }
135+
136+ if (predRec.getName () == " F8E4M3" ) {
137+ return FloatType::getFloat8E4M3 (ctx);
138+ }
139+
140+ if (predRec.getName () == " F8E4M3FNUZ" ) {
141+ return FloatType::getFloat8E4M3FNUZ (ctx);
142+ }
143+
144+ if (predRec.getName () == " F8E4M3B11FNUZ" ) {
145+ return FloatType::getFloat8E4M3B11FNUZ (ctx);
146+ }
147+
148+ if (predRec.getName () == " F8E5M2FNUZ" ) {
149+ return FloatType::getFloat8E5M2FNUZ (ctx);
150+ }
151+
152+ if (predRec.getName () == " F8E3M4" ) {
153+ return FloatType::getFloat8E3M4 (ctx);
154+ }
155+
156+ if (predRec.isSubClassOf (" Complex" )) {
157+ const Record *elementRec = predRec.getValueAsDef (" elementType" );
158+ auto elementType = recordToType (ctx, *elementRec);
159+ if (elementType.has_value ()) {
160+ return ComplexType::get (elementType.value ());
161+ }
162+ }
163+
164+ return std::nullopt ;
165+ }
166+
42167Value createConstraint (OpBuilder &builder, tblgen::Constraint constraint) {
43168 MLIRContext *ctx = builder.getContext ();
44169 const Record &predRec = constraint.getDef ();
@@ -78,11 +203,45 @@ Value createConstraint(OpBuilder &builder, tblgen::Constraint constraint) {
78203 return op.getOutput ();
79204 }
80205
81- std::string condition = constraint.getPredicate ().getCondition ();
82- // Build a CPredOp to match the C constraint built.
83- irdl::CPredOp op = builder.create <irdl::CPredOp>(
84- UnknownLoc::get (ctx), StringAttr::get (ctx, condition));
85- return op;
206+ // Integer types
207+ if (predRec.getName () == " AnyInteger" ) {
208+ auto op = builder.create <irdl::BaseOp>(
209+ UnknownLoc::get (ctx), StringAttr::get (ctx, " !builtin.integer" ));
210+ return op.getOutput ();
211+ }
212+
213+ if (predRec.isSubClassOf (" AnyI" )) {
214+ auto width = predRec.getValueAsInt (" bitwidth" );
215+ std::vector<Value> types = {
216+ typeToConstraint (builder,
217+ IntegerType::get (ctx, width, IntegerType::Signless)),
218+ typeToConstraint (builder,
219+ IntegerType::get (ctx, width, IntegerType::Signed)),
220+ typeToConstraint (builder,
221+ IntegerType::get (ctx, width, IntegerType::Unsigned))};
222+ auto op = builder.create <irdl::AnyOfOp>(UnknownLoc::get (ctx), types);
223+ return op.getOutput ();
224+ }
225+
226+ auto type = recordToType (ctx, predRec);
227+
228+ if (type.has_value ()) {
229+ return typeToConstraint (builder, type.value ());
230+ }
231+
232+ // Confined type
233+ if (predRec.isSubClassOf (" ConfinedType" )) {
234+ std::vector<Value> constraints;
235+ constraints.push_back (createConstraint (
236+ builder, tblgen::Constraint (predRec.getValueAsDef (" baseType" ))));
237+ for (Record *child : predRec.getValueAsListOfDefs (" predicateList" )) {
238+ constraints.push_back (createPredicate (builder, tblgen::Pred (child)));
239+ }
240+ auto op = builder.create <irdl::AllOfOp>(UnknownLoc::get (ctx), constraints);
241+ return op.getOutput ();
242+ }
243+
244+ return createPredicate (builder, constraint.getPredicate ());
86245}
87246
88247// / Returns the name of the operation without the dialect prefix.
@@ -131,10 +290,12 @@ irdl::OperationOp createIRDLOperation(OpBuilder &builder,
131290 auto [results, resultVariadicity] = getValues (tblgenOp.getResults ());
132291
133292 // Create the operands and results operations.
134- consBuilder.create <irdl::OperandsOp>(UnknownLoc::get (ctx), operands,
135- operandVariadicity);
136- consBuilder.create <irdl::ResultsOp>(UnknownLoc::get (ctx), results,
137- resultVariadicity);
293+ if (!operands.empty ())
294+ consBuilder.create <irdl::OperandsOp>(UnknownLoc::get (ctx), operands,
295+ operandVariadicity);
296+ if (!results.empty ())
297+ consBuilder.create <irdl::ResultsOp>(UnknownLoc::get (ctx), results,
298+ resultVariadicity);
138299
139300 return op;
140301}
0 commit comments