From 48a53784b5380d1e235e01c5064074cee69fe9a2 Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Mon, 14 Dec 2015 15:02:33 -0800 Subject: [PATCH 1/2] Implement the fmod function --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/arithmetic.scala | 44 +++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 70f8777bb7b75..6ffd147247882 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -147,6 +147,7 @@ object FunctionRegistry { expression[Pow]("pow"), expression[Pow]("power"), expression[Pmod]("pmod"), + expression[Fmod]("fmod"), expression[UnaryPositive]("positive"), expression[Rint]("rint"), expression[Round]("round"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 98464edf4d390..541d634de0587 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -514,3 +514,47 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r } } + +case class Fmod(left: Expression, right: Expression) + extends BinaryExpression with Serializable with ImplicitCastInputTypes { + + override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + + override def toString: String = s"fmod($left, $right)" + + override def dataType: DataType = DoubleType + + override def eval(input: InternalRow): Any = { + val input2 = right.eval(input) + if (input2 == 
null || input2 == 0) { null } else { val input1 = left.eval(input) if (input1 == null) { null } else { input1.asInstanceOf[Double] % input2.asInstanceOf[Double] } } } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) s""" ${eval2.code} boolean ${ev.isNull} = ${eval2.isNull} || ${eval2.primitive} == 0; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${eval1.code} if (!${eval1.isNull}) { ${ev.primitive} = ${eval1.primitive} % ${eval2.primitive}; } else { ${ev.isNull} = true; } } """ } } From dd82703ee22a6e084e8f9d1798103c306c56399b Mon Sep 17 00:00:00 2001 From: Mikhail Bautin Date: Mon, 14 Dec 2015 15:56:03 -0800 Subject: [PATCH 2/2] Add fmod to DataFrame API --- .../src/main/scala/org/apache/spark/sql/functions.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 355787eb02147..82d5fb6c53e05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1533,6 +1533,13 @@ object functions { */ def toRadians(columnName: String): Column = toRadians(Column(columnName)) + /** + * Computes a floating-point remainder value. The result has the same sign as the numerator. + * + * @group math_funcs + */ + def fmod(numerator: Column, denominator: Column): Column = Fmod(numerator.expr, denominator.expr) + ////////////////////////////////////////////////////////////////////////////////////////////// // Misc functions //////////////////////////////////////////////////////////////////////////////////////////////