Skip to content

Commit 4c6c10d

Browse files
committed
try constraints
1 parent 528778e commit 4c6c10d

File tree

3 files changed

+58
-9
lines changed

3 files changed

+58
-9
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
6262
}
6363

6464
def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> {
65-
let summary = "An attribute containing a single-variable polynomial with integer coefficients.";
65+
let summary = "an attribute containing a single-variable polynomial with integer coefficients";
6666
let description = [{
6767
A polynomial attribute represents a single-variable polynomial with integer
6868
coefficients, which is used to define the modulus of a `RingAttr`, as well
@@ -109,7 +109,7 @@ def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr<
109109
}
110110

111111
def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> {
112-
let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients.";
112+
let summary = "an attribute containing a single-variable polynomial with double precision floating point coefficients";
113113
let description = [{
114114
A polynomial attribute represents a single-variable polynomial with double
115115
precision floating point coefficients.
@@ -489,6 +489,25 @@ def Polynomial_AnyPolynomialAttr : AnyAttrOf<[
489489
Polynomial_FloatPolynomialAttr,
490490
Polynomial_IntPolynomialAttr
491491
]>;
492+
def Polynomial_PolynomialElementsAttr :
493+
ElementsAttrBase<And<[//CPred<"::llvm::isa<::mlir::ElementsAttr>($_self)">,
494+
CPred<[{
495+
isa<::mlir::polynomial::PolynomialType>(
496+
::llvm::cast<::mlir::ElementsAttr>($_self)
497+
.getShapedType()
498+
.getElementType())
499+
}]>]>,
500+
"an elements attribute containing polynomial attributes"> {
501+
let storageType = [{ ::mlir::ElementsAttr }];
502+
let returnType = [{ ::mlir::ElementsAttr }];
503+
let convertFromStorage = "$_self";
504+
}
505+
506+
def Polynomial_PolynomialOrElementsAttr : AnyAttrOf<[
507+
Polynomial_FloatPolynomialAttr,
508+
Polynomial_IntPolynomialAttr,
509+
Polynomial_PolynomialElementsAttr,
510+
]>;
492511

493512
// Not deriving from Polynomial_Op due to need for custom assembly format
494513
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure, ConstantLike]> {
@@ -505,8 +524,8 @@ def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant", [Pure, ConstantLi
505524
%0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring>
506525
```
507526
}];
508-
let arguments = (ins Polynomial_AnyPolynomialAttr:$value);
509-
let results = (outs Polynomial_PolynomialType:$output);
527+
let arguments = (ins Polynomial_PolynomialOrElementsAttr:$value);
528+
let results = (outs PolynomialLike:$output);
510529
let assemblyFormat = "attr-dict `:` type($output)";
511530
let hasFolder = 1;
512531
}

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
10-
#include "mlir/Dialect/Arith/IR/Arith.h"
1110
#include "mlir/Dialect/CommonFolders.h"
1211
#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
1312
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"

mlir/test/Dialect/Polynomial/folding.mlir

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@
22

33
// Tests for folding
44

5-
#my_poly = #polynomial.int_polynomial<1 + x**1024>
65
#poly_3t = #polynomial.int_polynomial<3t>
76
#poly_t3_plus_4t_plus_2 = #polynomial.int_polynomial<t**3 + 4t + 2>
8-
#modulus = #polynomial.int_polynomial<-1 + x**1024>
9-
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#modulus, primitiveRoot=193>
7+
#ring = #polynomial.ring<coefficientType=i32>
108
!poly_ty = !polynomial.polynomial<ring=#ring>
119

1210
// CHECK-LABEL: test_fold_add
@@ -19,5 +17,38 @@ func.func @test_fold_add() -> !poly_ty {
1917
return %2 : !poly_ty
2018
}
2119

20+
// CHECK-LABEL: test_fold_add_elementwise
21+
// CHECK-NEXT: polynomial.constant {value = dense<
22+
// CHECK-SAME: #polynomial.typed_int_polynomial<type=
23+
// CHECK-SAME: value = <2 + 7x + x**3>>,
24+
// CHECK-SAME: #polynomial.typed_int_polynomial<type=
25+
// CHECK-SAME: value = <2 + 7x + x**3>>,
26+
// CHECK-SAME: ]>}
27+
// CHECK-NEXT: return
28+
#typed_poly1 = #polynomial.typed_int_polynomial<type=!poly_ty, value=<3t>>
29+
#typed_poly2 = #polynomial.typed_int_polynomial<type=!poly_ty, value=<t**3 + 4t + 2>>
30+
!tensor_ty = tensor<2x!poly_ty>
31+
func.func @test_fold_add_elementwise() -> !tensor_ty {
32+
%0 = polynomial.constant {value=[#typed_poly1, #typed_poly2]} : !tensor_ty
33+
%1 = polynomial.constant {value=[#typed_poly2, #typed_poly1]} : !tensor_ty
34+
%2 = polynomial.add %0, %1 : !tensor_ty
35+
return %2 : !tensor_ty
36+
}
37+
38+
39+
#fpoly_1 = #polynomial.float_polynomial<3.5t>
40+
#fpoly_2 = #polynomial.float_polynomial<1.0t**3 + 1.25t + 2.0>
41+
#fring = #polynomial.ring<coefficientType=f32>
42+
!fpoly_ty = !polynomial.polynomial<ring=#fring>
43+
44+
// CHECK-LABEL: test_fold_add_float
45+
// CHECK-NEXT: polynomial.constant {value = #polynomial.float_polynomial<2 + 4.75x + x**3>}
46+
// CHECK-NEXT: return
47+
func.func @test_fold_add_float() -> !fpoly_ty {
48+
%0 = polynomial.constant {value=#fpoly_1} : !fpoly_ty
49+
%1 = polynomial.constant {value=#fpoly_2} : !fpoly_ty
50+
%2 = polynomial.add %0, %1 : !fpoly_ty
51+
return %2 : !fpoly_ty
52+
}
53+
2254
// Test elementwise folding of add
23-
// Test float folding of add

0 commit comments

Comments
 (0)