diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h index 7f44c29a98707..e14cef51185e0 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h +++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.h @@ -30,7 +30,7 @@ namespace polynomial { /// would want to specify 128-bit polynomials statically in the source code. constexpr unsigned apintBitWidth = 64; -template +template class MonomialBase { public: MonomialBase(const CoefficientType &coeff, const APInt &expo) @@ -55,12 +55,21 @@ class MonomialBase { return (exponent.ult(other.exponent)); } + Derived add(const Derived &other) { + assert(exponent == other.exponent); + CoefficientType newCoeff = coefficient + other.coefficient; + Derived result; + result.setCoefficient(newCoeff); + result.setExponent(exponent); + return result; + } + virtual bool isMonic() const = 0; virtual void coefficientToString(llvm::SmallString<16> &coeffString) const = 0; - template - friend ::llvm::hash_code hash_value(const MonomialBase &arg); + template + friend ::llvm::hash_code hash_value(const MonomialBase &arg); protected: CoefficientType coefficient; @@ -69,7 +78,7 @@ class MonomialBase { /// A class representing a monomial of a single-variable polynomial with integer /// coefficients. -class IntMonomial : public MonomialBase { +class IntMonomial : public MonomialBase { public: IntMonomial(int64_t coeff, uint64_t expo) : MonomialBase(APInt(apintBitWidth, coeff), APInt(apintBitWidth, expo)) {} @@ -77,7 +86,7 @@ class IntMonomial : public MonomialBase { IntMonomial() : MonomialBase(APInt(apintBitWidth, 0), APInt(apintBitWidth, 0)) {} - ~IntMonomial() = default; + ~IntMonomial() override = default; bool isMonic() const override { return coefficient == 1; } @@ -88,14 +97,14 @@ class IntMonomial : public MonomialBase { /// A class representing a monomial of a single-variable polynomial with integer /// coefficients. -class FloatMonomial : public MonomialBase { +class FloatMonomial : public MonomialBase { public: FloatMonomial(double coeff, uint64_t expo) : MonomialBase(APFloat(coeff), APInt(apintBitWidth, expo)) {} FloatMonomial() : MonomialBase(APFloat((double)0), APInt(apintBitWidth, 0)) {} - ~FloatMonomial() = default; + ~FloatMonomial() override = default; bool isMonic() const override { return coefficient == APFloat(1.0); } @@ -104,7 +113,7 @@ class FloatMonomial : public MonomialBase { } }; -template +template class PolynomialBase { public: PolynomialBase() = delete; @@ -149,6 +158,44 @@ class PolynomialBase { } } + Derived add(const Derived &other) { + SmallVector newTerms; + auto it1 = terms.begin(); + auto it2 = other.terms.begin(); + while (it1 != terms.end() || it2 != other.terms.end()) { + if (it1 == terms.end()) { + newTerms.emplace_back(*it2); + it2++; + continue; + } + + if (it2 == other.terms.end()) { + newTerms.emplace_back(*it1); + it1++; + continue; + } + + while (it1->getExponent().ult(it2->getExponent())) { + newTerms.emplace_back(*it1); + it1++; + if (it1 == terms.end()) + break; + } + + while (it2->getExponent().ult(it1->getExponent())) { + newTerms.emplace_back(*it2); + it2++; + if (it2 == terms.end()) + break; + } + + newTerms.emplace_back(it1->add(*it2)); + it1++; + it2++; + } + return Derived(newTerms); + } + // Prints polynomial to 'os'. void print(raw_ostream &os) const { print(os, " + ", "**"); } @@ -168,8 +215,8 @@ class PolynomialBase { ArrayRef getTerms() const { return terms; } - template - friend ::llvm::hash_code hash_value(const PolynomialBase &arg); + template + friend ::llvm::hash_code hash_value(const PolynomialBase &arg); private: // The monomial terms for this polynomial. @@ -179,7 +226,7 @@ class PolynomialBase { /// A single-variable polynomial with integer coefficients. /// /// Eg: x^1024 + x + 1 -class IntPolynomial : public PolynomialBase { +class IntPolynomial : public PolynomialBase { public: explicit IntPolynomial(ArrayRef terms) : PolynomialBase(terms) {} @@ -196,7 +243,7 @@ class IntPolynomial : public PolynomialBase { /// A single-variable polynomial with double coefficients. /// /// Eg: 1.0 x^1024 + 3.5 x + 1e-05 -class FloatPolynomial : public PolynomialBase { +class FloatPolynomial : public PolynomialBase { public: explicit FloatPolynomial(ArrayRef terms) : PolynomialBase(terms) {} @@ -212,20 +259,20 @@ class FloatPolynomial : public PolynomialBase { }; // Make Polynomials hashable. -template -inline ::llvm::hash_code hash_value(const PolynomialBase &arg) { +template +inline ::llvm::hash_code hash_value(const PolynomialBase &arg) { return ::llvm::hash_combine_range(arg.terms.begin(), arg.terms.end()); } -template -inline ::llvm::hash_code hash_value(const MonomialBase &arg) { +template +inline ::llvm::hash_code hash_value(const MonomialBase &arg) { return llvm::hash_combine(::llvm::hash_value(arg.coefficient), ::llvm::hash_value(arg.exponent)); } -template +template inline raw_ostream &operator<<(raw_ostream &os, - const PolynomialBase &polynomial) { + const PolynomialBase &polynomial) { polynomial.print(os); return os; } diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td index ae8484501a50d..14186c563beb8 100644 --- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td +++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td @@ -53,6 +53,7 @@ def Polynomial_Dialect : Dialect { let useDefaultTypePrinterParser = 1; let useDefaultAttributePrinterParser = 1; + let hasConstantMaterializer = 1; } class Polynomial_Attr traits = []> @@ -61,7 +62,7 @@ class Polynomial_Attr traits = []> } def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynomial"> { - let summary = "An attribute containing a single-variable polynomial with integer coefficients."; + let summary = "an attribute containing a single-variable polynomial with integer coefficients"; let description = [{ A polynomial attribute represents a single-variable polynomial with integer coefficients, which is used to define the modulus of a `RingAttr`, as well @@ -83,8 +84,32 @@ def Polynomial_IntPolynomialAttr : Polynomial_Attr<"IntPolynomial", "int_polynom let hasCustomAssemblyFormat = 1; } +def Polynomial_TypedIntPolynomialAttr : Polynomial_Attr< + "TypedIntPolynomial", "typed_int_polynomial", [TypedAttrInterface]> { + let summary = "A typed variant of int_polynomial for constant folding."; + let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::IntPolynomialAttr":$value); + let assemblyFormat = "`<` struct(params) `>`"; + let builders = [ + AttrBuilderWithInferredContext<(ins "Type":$type, + "const IntPolynomial &":$value), [{ + return $_get( + type.getContext(), + type, + IntPolynomialAttr::get(type.getContext(), value)); + }]>, + AttrBuilderWithInferredContext<(ins "Type":$type, + "const Attribute &":$value), [{ + return $_get(type.getContext(), type, ::llvm::cast(value)); + }]> + ]; + let extraClassDeclaration = [{ + // used for constFoldBinaryOp + using ValueType = ::mlir::Attribute; + }]; +} + def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_polynomial"> { - let summary = "An attribute containing a single-variable polynomial with double precision floating point coefficients."; + let summary = "an attribute containing a single-variable polynomial with double precision floating point coefficients"; let description = [{ A polynomial attribute represents a single-variable polynomial with double precision floating point coefficients. @@ -105,6 +130,30 @@ def Polynomial_FloatPolynomialAttr : Polynomial_Attr<"FloatPolynomial", "float_p let hasCustomAssemblyFormat = 1; } +def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr< + "TypedFloatPolynomial", "typed_float_polynomial", [TypedAttrInterface]> { + let summary = "A typed variant of float_polynomial for constant folding."; + let parameters = (ins "::mlir::Type":$type, "::mlir::polynomial::FloatPolynomialAttr":$value); + let assemblyFormat = "`<` struct(params) `>`"; + let builders = [ + AttrBuilderWithInferredContext<(ins "Type":$type, + "const FloatPolynomial &":$value), [{ + return $_get( + type.getContext(), + type, + FloatPolynomialAttr::get(type.getContext(), value)); + }]>, + AttrBuilderWithInferredContext<(ins "Type":$type, + "const Attribute &":$value), [{ + return $_get(type.getContext(), type, ::llvm::cast(value)); + }]> + ]; + let extraClassDeclaration = [{ + // used for constFoldBinaryOp + using ValueType = ::mlir::Attribute; + }]; +} + def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> { let summary = "An attribute specifying a polynomial ring."; let description = [{ @@ -221,6 +270,7 @@ def Polynomial_AddOp : Polynomial_BinaryOp<"add", [Commutative]> { %2 = polynomial.add %0, %1 : !polynomial.polynomial<#ring> ``` }]; + let hasFolder = 1; } def Polynomial_SubOp : Polynomial_BinaryOp<"sub"> { @@ -439,9 +489,28 @@ def Polynomial_AnyPolynomialAttr : AnyAttrOf<[ Polynomial_FloatPolynomialAttr, Polynomial_IntPolynomialAttr ]>; +def Polynomial_PolynomialElementsAttr : + ElementsAttrBase($_self)">, + CPred<[{ + isa<::mlir::polynomial::PolynomialType>( + ::llvm::cast<::mlir::ElementsAttr>($_self) + .getShapedType() + .getElementType()) + }]>]>, + "an elements attribute containing polynomial attributes"> { + let storageType = [{ ::mlir::ElementsAttr }]; + let returnType = [{ ::mlir::ElementsAttr }]; + let convertFromStorage = "$_self"; +} + +def Polynomial_PolynomialOrElementsAttr : AnyAttrOf<[ + Polynomial_FloatPolynomialAttr, + Polynomial_IntPolynomialAttr, + Polynomial_PolynomialElementsAttr, +]>; // Not deriving from Polynomial_Op due to need for custom assembly format -def Polynomial_ConstantOp : Op { +def Polynomial_ConstantOp : Op { let summary = "Define a constant polynomial via an attribute."; let description = [{ Example: @@ -455,9 +524,10 @@ def Polynomial_ConstantOp : Op { %0 = polynomial.constant #polynomial.float_polynomial<0.5 + 1.3e06 x**2> : !polynomial.polynomial<#float_ring> ``` }]; - let arguments = (ins Polynomial_AnyPolynomialAttr:$value); - let results = (outs Polynomial_PolynomialType:$output); + let arguments = (ins Polynomial_PolynomialOrElementsAttr:$value); + let results = (outs PolynomialLike:$output); let assemblyFormat = "attr-dict `:` type($output)"; + let hasFolder = 1; } def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> { diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp index 825b80d70f803..05cc9fd8bbc58 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp @@ -48,3 +48,17 @@ void PolynomialDialect::initialize() { #include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc" >(); } + +Operation *PolynomialDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + auto intPoly = dyn_cast(value); + auto floatPoly = dyn_cast(value); + if (!intPoly && !floatPoly) + return nullptr; + + Type ty = intPoly ? intPoly.getType() : floatPoly.getType(); + Attribute valueAttr = + intPoly ? (Attribute)intPoly.getValue() : (Attribute)floatPoly.getValue(); + return builder.create(loc, ty, valueAttr); +} diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp index 12010de348237..8cbc3b4615140 100644 --- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp +++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp @@ -7,10 +7,12 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Polynomial/IR/PolynomialOps.h" +#include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Polynomial/IR/Polynomial.h" #include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h" #include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/Support/LogicalResult.h" @@ -19,6 +21,41 @@ using namespace mlir; using namespace mlir::polynomial; +OpFoldResult ConstantOp::fold(ConstantOp::FoldAdaptor adaptor) { + PolynomialType ty = dyn_cast(getOutput().getType()); + + if (isa(ty.getRing().getPolynomialModulus())) + return TypedFloatPolynomialAttr::get( + ty, cast(getValue()).getPolynomial()); + + assert(isa(ty.getRing().getPolynomialModulus()) && + "expected float or integer polynomial"); + return TypedIntPolynomialAttr::get( + ty, cast(getValue()).getPolynomial()); +} + +OpFoldResult AddOp::fold(AddOp::FoldAdaptor adaptor) { + auto lhsElements = dyn_cast(getLhs().getType()); + PolynomialType elementType = cast( + lhsElements ? lhsElements.getElementType() : getLhs().getType()); + MLIRContext *context = getContext(); + + if (isa(elementType.getRing().getCoefficientType())) + return constFoldBinaryOp( + adaptor.getOperands(), elementType, [&](Attribute a, const Attribute &b) { + return FloatPolynomialAttr::get( + context, cast(a).getPolynomial().add( + cast(b).getPolynomial())); + }); + + return constFoldBinaryOp( + adaptor.getOperands(), elementType, [&](Attribute a, const Attribute &b) { + return IntPolynomialAttr::get( + context, cast(a).getPolynomial().add( + cast(b).getPolynomial())); + }); +} + void FromTensorOp::build(OpBuilder &builder, OperationState &result, Value input, RingAttr ring) { TensorType tensorType = dyn_cast(input.getType()); diff --git a/mlir/test/Dialect/Polynomial/folding.mlir b/mlir/test/Dialect/Polynomial/folding.mlir new file mode 100644 index 0000000000000..3e52a108644ae --- /dev/null +++ b/mlir/test/Dialect/Polynomial/folding.mlir @@ -0,0 +1,54 @@ +// RUN: mlir-opt --sccp --canonicalize %s | FileCheck %s + +// Tests for folding + +#poly_3t = #polynomial.int_polynomial<3t> +#poly_t3_plus_4t_plus_2 = #polynomial.int_polynomial +#ring = #polynomial.ring +!poly_ty = !polynomial.polynomial + +// CHECK-LABEL: test_fold_add +// CHECK-NEXT: polynomial.constant {value = #polynomial.int_polynomial<2 + 7x + x**3>} +// CHECK-NEXT: return +func.func @test_fold_add() -> !poly_ty { + %0 = polynomial.constant {value=#poly_3t} : !poly_ty + %1 = polynomial.constant {value=#poly_t3_plus_4t_plus_2} : !poly_ty + %2 = polynomial.add %0, %1 : !poly_ty + return %2 : !poly_ty +} + +// CHECK-LABEL: test_fold_add_elementwise +// CHECK-NEXT: polynomial.constant {value = dense< +// CHECK-SAME: #polynomial.typed_int_polynomial>, +// CHECK-SAME: #polynomial.typed_int_polynomial>, +// CHECK-SAME: ]>} +// CHECK-NEXT: return +#typed_poly1 = #polynomial.typed_int_polynomial> +#typed_poly2 = #polynomial.typed_int_polynomial> +!tensor_ty = tensor<2x!poly_ty> +func.func @test_fold_add_elementwise() -> !tensor_ty { + %0 = polynomial.constant {value=[#typed_poly1, #typed_poly2]} : !tensor_ty + %1 = polynomial.constant {value=[#typed_poly2, #typed_poly1]} : !tensor_ty + %2 = polynomial.add %0, %1 : !tensor_ty + return %2 : !tensor_ty +} + + +#fpoly_1 = #polynomial.float_polynomial<3.5t> +#fpoly_2 = #polynomial.float_polynomial<1.0t**3 + 1.25t + 2.0> +#fring = #polynomial.ring +!fpoly_ty = !polynomial.polynomial + +// CHECK-LABEL: test_fold_add_float +// CHECK-NEXT: polynomial.constant {value = #polynomial.float_polynomial<2 + 4.75x + x**3>} +// CHECK-NEXT: return +func.func @test_fold_add_float() -> !fpoly_ty { + %0 = polynomial.constant {value=#fpoly_1} : !fpoly_ty + %1 = polynomial.constant {value=#fpoly_2} : !fpoly_ty + %2 = polynomial.add %0, %1 : !fpoly_ty + return %2 : !fpoly_ty +} + +// Test elementwise folding of add diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index 13393569f36fe..90a75d5a46ad9 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -11,6 +11,7 @@ add_subdirectory(Index) add_subdirectory(LLVMIR) add_subdirectory(MemRef) add_subdirectory(OpenACC) +add_subdirectory(Polynomial) add_subdirectory(SCF) add_subdirectory(SparseTensor) add_subdirectory(SPIRV) diff --git a/mlir/unittests/Dialect/Polynomial/CMakeLists.txt b/mlir/unittests/Dialect/Polynomial/CMakeLists.txt new file mode 100644 index 0000000000000..807deeca41c06 --- /dev/null +++ b/mlir/unittests/Dialect/Polynomial/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_unittest(MLIRPolynomialTests + PolynomialMathTest.cpp +) +target_link_libraries(MLIRPolynomialTests + PRIVATE + MLIRIR + MLIRPolynomialDialect +) diff --git a/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp new file mode 100644 index 0000000000000..95906ad42588e --- /dev/null +++ b/mlir/unittests/Dialect/Polynomial/PolynomialMathTest.cpp @@ -0,0 +1,44 @@ +//===- PolynomialMathTest.cpp - Polynomial math Tests ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Polynomial/IR/Polynomial.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::polynomial; + +TEST(AddTest, checkSameDegreeAdditionOfIntPolynomial) { + IntPolynomial x = IntPolynomial::fromCoefficients({1, 2, 3}); + IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4}); + IntPolynomial expected = IntPolynomial::fromCoefficients({3, 5, 7}); + EXPECT_EQ(expected, x.add(y)); +} + +TEST(AddTest, checkDifferentDegreeAdditionOfIntPolynomial) { + IntMonomial term2t = IntMonomial(2, 1); + IntPolynomial x = IntPolynomial::fromMonomials({term2t}).value(); + IntPolynomial y = IntPolynomial::fromCoefficients({2, 3, 4}); + IntPolynomial expected = IntPolynomial::fromCoefficients({2, 5, 4}); + EXPECT_EQ(expected, x.add(y)); + EXPECT_EQ(expected, y.add(x)); +} + +TEST(AddTest, checkSameDegreeAdditionOfFloatPolynomial) { + FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5, 3.5}); + FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5}); + FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 8}); + EXPECT_EQ(expected, x.add(y)); +} + +TEST(AddTest, checkDifferentDegreeAdditionOfFloatPolynomial) { + FloatPolynomial x = FloatPolynomial::fromCoefficients({1.5, 2.5}); + FloatPolynomial y = FloatPolynomial::fromCoefficients({2.5, 3.5, 4.5}); + FloatPolynomial expected = FloatPolynomial::fromCoefficients({4, 6, 4.5}); + EXPECT_EQ(expected, x.add(y)); + EXPECT_EQ(expected, y.add(x)); +}