diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h index 7f44c29a98707..e14cef51185e0 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h +++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h @@ -30,7 +30,7 @@ namespace polynomial { /// would want to specify 128-bit polynomials statically in the source code. constexpr unsigned apintBitWidth = 64; -template +template class MonomialBase { public: MonomialBase(const CoefficientType &coeff, const APInt &expo) @@ -55,12 +55,21 @@ class MonomialBase { return (exponent.ult(other.exponent)); } + Derived add(const Derived &other) { + assert(exponent == other.exponent); + CoefficientType newCoeff = coefficient + other.coefficient; + Derived result; + result.setCoefficient(newCoeff); + result.setExponent(exponent); + return result; + } + virtual bool isMonic() const = 0; virtual void coefficientToString(llvm::SmallString<16> &coeffString) const = 0; - template - friend ::llvm::hash_code hash_value(const MonomialBase &arg); + template + friend ::llvm::hash_code hash_value(const MonomialBase &arg); protected: CoefficientType coefficient; @@ -69,7 +78,7 @@ class MonomialBase { /// A class representing a monomial of a single-variable polynomial with integer /// coefficients. -class IntMonomial : public MonomialBase { +class IntMonomial : public MonomialBase { public: IntMonomial(int64_t coeff, uint64_t expo) : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {} @@ -77,7 +86,7 @@ class IntMonomial : public MonomialBase { IntMonomial() : MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {} - ~IntMonomial() = default; + ~IntMonomial() override = default; bool isMonic() const override { return coefficient == 1; } @@ -88,14 +97,14 @@ class IntMonomial : public MonomialBase { /// A class representing a monomial of a single-variable polynomial with integer /// coefficients. -class FloatMonomial : public MonomialBase { +class FloatMonomial : public MonomialBase { public: FloatMonomial(double coeff, uint64_t expo) : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {} FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {} - ~FloatMonomial() = default; + ~FloatMonomial() override = default; bool isMonic() const override { return coefficient == APFloat(1.0); } @@ -104,7 +113,7 @@ class FloatMonomial : public MonomialBase { } }; -template +template class PolynomialBase { public: PolynomialBase() = delete; @@ -149,6 +158,44 @@ class PolynomialBase { } } + Derived add(const Derived &other) { + SmallVector newTerms; + auto it1 = terms.begin(); + auto it2 = other.terms.begin(); + while (it1 != terms.end() || it2 != other.terms.end()) { + if (it1 == terms.end()) { + newTerms.emplace_back(*it2); + it2++; + continue; + } + + if (it2 == other.terms.end()) { + newTerms.emplace_back(*it1); + it1++; + continue; + } + + while (it1->getExponent().ult(it2->getExponent())) { + newTerms.emplace_back(*it1); + it1++; + if (it1 == terms.end()) + break; + } + + while (it2->getExponent().ult(it1->getExponent())) { + newTerms.emplace_back(*it2); + it2++; + if (it2 == terms.end()) + break; + } + + newTerms.emplace_back(it1->add(*it2)); + it1++; + it2++; + } + return Derived(newTerms); + } + // Prints polynomial to 'os'. void print(raw_ostream &os) const { print(os, " + ", "**"); } @@ -168,8 +215,8 @@ class PolynomialBase { ArrayRef getTerms() const { return terms; } - template - friend ::llvm::hash_code hash_value(const PolynomialBase &arg); + template + friend ::llvm::hash_code hash_value(const PolynomialBase &arg); private: // The monomial terms for this polynomial. @@ -179,7 +226,7 @@ class PolynomialBase { /// A single-variable polynomial with integer coefficients. /// /// Eg: x^1024 + x + 1 -class IntPolynomial : public PolynomialBase { +class IntPolynomial : public PolynomialBase { public: explicit IntPolynomial(ArrayRef terms) : PolynomialBase(terms) {} @@ -196,7 +243,7 @@ class IntPolynomial : public PolynomialBase { /// A single-variable polynomial with double coefficients. /// /// Eg: 1.0 x^1024 + 3.5 x + 1e-05 -class FloatPolynomial : public PolynomialBase { +class FloatPolynomial : public PolynomialBase { public: explicit FloatPolynomial(ArrayRef terms) : PolynomialBase(terms) {} @@ -212,20 +259,20 @@ class FloatPolynomial : public PolynomialBase { }; // Make Polynomials hashable. -template -inline ::llvm::hash_code hash_value(const PolynomialBase &arg) { +template +inline ::llvm::hash_code hash_value(const PolynomialBase &arg) { return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end()); } -template -inline ::llvm::hash_code hash_value(const MonomialBase &arg) { +template +inline ::llvm::hash_code hash_value(const MonomialBase &arg) { return llvm::hash_combine(::llvm::hash_value(arg.coefficient), ::llvm::hash_value(arg.exponent)); } -template +template inline raw_ostream &operator<<(raw_ostream &os, - const PolynomialBase &polynomial) { + const PolynomialBase &polynomial) { polynomial.print(os); return os; } diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index 13393569f36fe..90a75d5a46ad9 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -11,6 +11,7 @@ add_subdirectory(Index) add_subdirectory(LLVMIR) add_subdirectory(MemRef) add_subdirectory(OpenACC) +add_subdirectory(Polynomial) add_subdirectory(SCF) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) diff --git a/mlir/unittests/Dialect/Polynomial/CMakeLists.txt b/mlir/unittests/Dialect/Polynomial/CMakeLists.txt new file mode 100644 index 0000000000000..807deeca41c06 --- /dev/null +++ b/mlir/unittests/Dialect/Polynomial/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_unittest(MLIRPolynomialTests + PolynomialMathTest.cpp +) +target_link_libraries(MLIRPolynomialTests + PRIVATE + MLIRIR + MLIRPolynomialDialect +) diff --git a/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp new file mode 100644 index 0000000000000..95906ad42588e --- /dev/null +++ b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp @@ -0,0 +1,44 @@ +//===- PolynomialMathTest.cpp - Polynomial math Tests ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Polynomial/IR/Polynomial.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::polynomial; + +TEST(AddTest, checkSameDegreeAdditionOfIntPolynomial) { + IntPolynomial x = IntPolynomial::fromCoefficients({1, 2, 3}); + IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4}); + IntPolynomial expected = IntPolynomial::fromCoefficients({3, 5, 7}); + EXPECT_EQ(expected, x.add(y)); +} + +TEST(AddTest, checkDifferentDegreeAdditionOfIntPolynomial) { + IntMonomial term2t = IntMonomial(2, 1); + IntPolynomial x = IntPolynomial::fromMonomials({term2t}).value(); + IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4}); + IntPolynomial expected = IntPolynomial::fromCoefficients({2, 5, 4}); + EXPECT_EQ(expected, x.add(y)); + EXPECT_EQ(expected, y.add(x)); +} + +TEST(AddTest, checkSameDegreeAdditionOfFloatPolynomial) { + FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5, 3.5}); + FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5}); + FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 8}); + EXPECT_EQ(expected, x.add(y)); +} + +TEST(AddTest, checkDifferentDegreeAdditionOfFloatPolynomial) { + FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5}); + FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5}); + FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 4.5}); + EXPECT_EQ(expected, x.add(y)); + EXPECT_EQ(expected, y.add(x)); +}