1111
1212#include " mlir/Support/LLVM.h"
1313#include " mlir/Support/LogicalResult.h"
14+ #include " llvm/ADT/APFloat.h"
1415#include " llvm/ADT/APInt.h"
1516#include " llvm/ADT/ArrayRef.h"
1617#include " llvm/ADT/Hashing.h"
17- #include " llvm/ADT/SmallVector.h"
18+ #include " llvm/ADT/SmallString.h"
19+ #include " llvm/ADT/Twine.h"
20+ #include " llvm/Support/raw_ostream.h"
1821
1922namespace mlir {
2023
@@ -27,98 +30,202 @@ namespace polynomial {
2730// / would want to specify 128-bit polynomials statically in the source code.
2831constexpr unsigned apintBitWidth = 64 ;
2932
30- // / A class representing a monomial of a single-variable polynomial with integer
31- // / coefficients.
32- class Monomial {
33+ template <typename CoefficientType>
34+ class MonomialBase {
3335public:
34- Monomial (int64_t coeff, uint64_t expo)
35- : coefficient(apintBitWidth, coeff), exponent(apintBitWidth, expo) {}
36-
37- Monomial (const APInt &coeff, const APInt &expo)
36+ MonomialBase (const CoefficientType &coeff, const APInt &expo)
3837 : coefficient(coeff), exponent(expo) {}
38+ virtual ~MonomialBase () = 0 ;
3939
40- Monomial () : coefficient(apintBitWidth, 0 ), exponent(apintBitWidth, 0 ) {}
40+ const CoefficientType &getCoefficient () const { return coefficient; }
41+ CoefficientType &getMutableCoefficient () { return coefficient; }
42+ const APInt &getExponent () const { return exponent; }
43+ void setCoefficient (const CoefficientType &coeff) { coefficient = coeff; }
44+ void setExponent (const APInt &exp) { exponent = exp; }
4145
42- bool operator ==(const Monomial &other) const {
46+ bool operator ==(const MonomialBase &other) const {
4347 return other.coefficient == coefficient && other.exponent == exponent;
4448 }
45- bool operator !=(const Monomial &other) const {
49+ bool operator !=(const MonomialBase &other) const {
4650 return other.coefficient != coefficient || other.exponent != exponent;
4751 }
4852
4953 // / Monomials are ordered by exponent.
50- bool operator <(const Monomial &other) const {
54+ bool operator <(const MonomialBase &other) const {
5155 return (exponent.ult (other.exponent ));
5256 }
5357
54- friend ::llvm::hash_code hash_value (const Monomial &arg);
58+ virtual bool isMonic () const = 0;
59+ virtual void
60+ coefficientToString (llvm::SmallString<16 > &coeffString) const = 0 ;
5561
56- public:
57- APInt coefficient ;
62+ template < typename T>
63+ friend ::llvm::hash_code hash_value ( const MonomialBase<T> &arg) ;
5864
59- // Always unsigned
65+ protected:
66+ CoefficientType coefficient;
6067 APInt exponent;
6168};
6269
63- // / A single-variable polynomial with integer coefficients.
64- // /
65- // / Eg: x^1024 + x + 1
66- // /
67- // / The symbols used as the polynomial's indeterminate don't matter, so long as
68- // / it is used consistently throughout the polynomial.
69- class Polynomial {
70+ // / A class representing a monomial of a single-variable polynomial with integer
71+ // / coefficients.
72+ class IntMonomial : public MonomialBase <APInt> {
7073public:
71- Polynomial () = delete ;
74+ IntMonomial (int64_t coeff, uint64_t expo)
75+ : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {}
7276
73- explicit Polynomial (ArrayRef<Monomial> terms) : terms(terms){};
77+ IntMonomial ()
78+ : MonomialBase(APInt(apintBitWidth, 0 ), APInt(apintBitWidth, 0 )) {}
7479
75- // Returns a Polynomial from a list of monomials.
76- // Fails if two monomials have the same exponent.
77- static FailureOr<Polynomial> fromMonomials (ArrayRef<Monomial> monomials);
80+ ~IntMonomial () = default ;
7881
79- // / Returns a polynomial with coefficients given by `coeffs`. The value
80- // / coeffs[i] is converted to a monomial with exponent i.
81- static Polynomial fromCoefficients (ArrayRef<int64_t > coeffs);
82+ bool isMonic () const override { return coefficient == 1 ; }
83+
84+ void coefficientToString (llvm::SmallString<16 > &coeffString) const override {
85+ coefficient.toStringSigned (coeffString);
86+ }
87+ };
88+
89+ // / A class representing a monomial of a single-variable polynomial with integer
90+ // / coefficients.
91+ class FloatMonomial : public MonomialBase <APFloat> {
92+ public:
93+ FloatMonomial (double coeff, uint64_t expo)
94+ : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {}
95+
96+ FloatMonomial () : MonomialBase(APFloat((double )0 ), APInt(apintBitWidth, 0 )) {}
97+
98+ ~FloatMonomial () = default ;
99+
100+ bool isMonic () const override { return coefficient == APFloat (1.0 ); }
101+
102+ void coefficientToString (llvm::SmallString<16 > &coeffString) const override {
103+ coefficient.toString (coeffString);
104+ }
105+ };
106+
107+ template <typename Monomial>
108+ class PolynomialBase {
109+ public:
110+ PolynomialBase () = delete ;
111+
112+ explicit PolynomialBase (ArrayRef<Monomial> terms) : terms(terms){};
82113
83114 explicit operator bool () const { return !terms.empty (); }
84- bool operator ==(const Polynomial &other) const {
115+ bool operator ==(const PolynomialBase &other) const {
85116 return other.terms == terms;
86117 }
87- bool operator !=(const Polynomial &other) const {
118+ bool operator !=(const PolynomialBase &other) const {
88119 return !(other.terms == terms);
89120 }
90121
91- // Prints polynomial to 'os'.
92- void print (raw_ostream &os) const ;
93122 void print (raw_ostream &os, ::llvm::StringRef separator,
94- ::llvm::StringRef exponentiation) const ;
123+ ::llvm::StringRef exponentiation) const {
124+ bool first = true ;
125+ for (const Monomial &term : getTerms ()) {
126+ if (first) {
127+ first = false ;
128+ } else {
129+ os << separator;
130+ }
131+ std::string coeffToPrint;
132+ if (term.isMonic () && term.getExponent ().uge (1 )) {
133+ coeffToPrint = " " ;
134+ } else {
135+ llvm::SmallString<16 > coeffString;
136+ term.coefficientToString (coeffString);
137+ coeffToPrint = coeffString.str ();
138+ }
139+
140+ if (term.getExponent () == 0 ) {
141+ os << coeffToPrint;
142+ } else if (term.getExponent () == 1 ) {
143+ os << coeffToPrint << " x" ;
144+ } else {
145+ llvm::SmallString<16 > expString;
146+ term.getExponent ().toStringSigned (expString);
147+ os << coeffToPrint << " x" << exponentiation << expString;
148+ }
149+ }
150+ }
151+
152+ // Prints polynomial to 'os'.
153+ void print (raw_ostream &os) const { print (os, " + " , " **" ); }
154+
95155 void dump () const ;
96156
97157 // Prints polynomial so that it can be used as a valid identifier
98- std::string toIdentifier () const ;
158+ std::string toIdentifier () const {
159+ std::string result;
160+ llvm::raw_string_ostream os (result);
161+ print (os, " _" , " " );
162+ return os.str ();
163+ }
99164
100- unsigned getDegree () const ;
165+ unsigned getDegree () const {
166+ return terms.back ().getExponent ().getZExtValue ();
167+ }
101168
102169 ArrayRef<Monomial> getTerms () const { return terms; }
103170
104- friend ::llvm::hash_code hash_value (const Polynomial &arg);
171+ template <typename T>
172+ friend ::llvm::hash_code hash_value (const PolynomialBase<T> &arg);
105173
106174private:
107175 // The monomial terms for this polynomial.
108176 SmallVector<Monomial> terms;
109177};
110178
111- // Make Polynomial hashable.
112- inline ::llvm::hash_code hash_value (const Polynomial &arg) {
179+ // / A single-variable polynomial with integer coefficients.
180+ // /
181+ // / Eg: x^1024 + x + 1
182+ class IntPolynomial : public PolynomialBase <IntMonomial> {
183+ public:
184+ explicit IntPolynomial (ArrayRef<IntMonomial> terms) : PolynomialBase(terms) {}
185+
186+ // Returns a Polynomial from a list of monomials.
187+ // Fails if two monomials have the same exponent.
188+ static FailureOr<IntPolynomial>
189+ fromMonomials (ArrayRef<IntMonomial> monomials);
190+
191+ // / Returns a polynomial with coefficients given by `coeffs`. The value
192+ // / coeffs[i] is converted to a monomial with exponent i.
193+ static IntPolynomial fromCoefficients (ArrayRef<int64_t > coeffs);
194+ };
195+
196+ // / A single-variable polynomial with double coefficients.
197+ // /
198+ // / Eg: 1.0 x^1024 + 3.5 x + 1e-05
199+ class FloatPolynomial : public PolynomialBase <FloatMonomial> {
200+ public:
201+ explicit FloatPolynomial (ArrayRef<FloatMonomial> terms)
202+ : PolynomialBase(terms) {}
203+
204+ // Returns a Polynomial from a list of monomials.
205+ // Fails if two monomials have the same exponent.
206+ static FailureOr<FloatPolynomial>
207+ fromMonomials (ArrayRef<FloatMonomial> monomials);
208+
209+ // / Returns a polynomial with coefficients given by `coeffs`. The value
210+ // / coeffs[i] is converted to a monomial with exponent i.
211+ static FloatPolynomial fromCoefficients (ArrayRef<double > coeffs);
212+ };
213+
214+ // Make Polynomials hashable.
215+ template <typename T>
216+ inline ::llvm::hash_code hash_value (const PolynomialBase<T> &arg) {
113217 return ::llvm::hash_combine_range (arg.terms .begin (), arg.terms .end ());
114218}
115219
116- inline ::llvm::hash_code hash_value (const Monomial &arg) {
220+ template <typename T>
221+ inline ::llvm::hash_code hash_value (const MonomialBase<T> &arg) {
117222 return llvm::hash_combine (::llvm::hash_value (arg.coefficient ),
118223 ::llvm::hash_value (arg.exponent));
119224}
120225
121- inline raw_ostream &operator <<(raw_ostream &os, const Polynomial &polynomial) {
226+ template <typename T>
227+ inline raw_ostream &operator <<(raw_ostream &os,
228+ const PolynomialBase<T> &polynomial) {
122229 polynomial.print (os);
123230 return os;
124231}
0 commit comments