Skip to content

Commit a938527

Browse files
zhichao-li authored and rxin committed
[SPARK-8221][SQL] Add pmod function
https://issues.apache.org/jira/browse/SPARK-8221 One concern is that the result would be negative if the divisor is not positive (e.g. pmod(7, -3)), but this behavior is the same as Hive's. Author: zhichao.li <[email protected]> Closes #6783 from zhichao-li/pmod2 and squashes the following commits: 7083eb9 [zhichao.li] update to the latest type checking d26dba7 [zhichao.li] add pmod
1 parent fa4ec36 commit a938527

File tree

6 files changed

+170
-1
lines changed

6 files changed

+170
-1
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ object FunctionRegistry {
115115
expression[Log2]("log2"),
116116
expression[Pow]("pow"),
117117
expression[Pow]("power"),
118+
expression[Pmod]("pmod"),
118119
expression[UnaryPositive]("positive"),
119120
expression[Rint]("rint"),
120121
expression[Round]("round"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,12 @@ object HiveTypeCoercion {
426426
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
427427
)
428428

429+
case Pmod(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
430+
Cast(
431+
Pmod(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited)),
432+
DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
433+
)
434+
429435
// When we compare 2 decimal types with different precisions, cast them to the smallest
430436
// common precision.
431437
case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,97 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
377377
override def symbol: String = "min"
378378
override def prettyName: String = symbol
379379
}
380+
381+
/**
 * Binary arithmetic expression implementing `pmod(left, right)` for numeric inputs.
 *
 * Both evaluation paths (interpreted `nullSafeEval` and the generated Java code) compute
 * `r = left % right`; a negative `r` is replaced by `(r + right) % right`, any other `r` is
 * returned unchanged. For a positive divisor the result is therefore non-negative; for a
 * negative divisor the result can still be negative.
 */
case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
382+
383+
override def toString: String = s"pmod($left, $right)"
384+
385+
override def symbol: String = "pmod"
386+
387+
// Type-check hook: delegates to TypeUtils.checkForNumericExpr, passing "pmod" as the name.
protected def checkTypesInternal(t: DataType) =
388+
TypeUtils.checkForNumericExpr(t, "pmod")
389+
390+
override def inputType: AbstractDataType = NumericType
391+
392+
// Interpreted evaluation: picks the private pmod overload that matches dataType and casts
// both operands to that type. The method name suggests operands are already non-null here --
// confirm against BinaryArithmetic.
// NOTE(review): no zero-divisor guard is visible in this class (integral `%` by zero throws
// ArithmeticException on the JVM), and the match has no default case, so a dataType outside
// these seven would raise a MatchError.
protected override def nullSafeEval(left: Any, right: Any) =
393+
dataType match {
394+
case IntegerType => pmod(left.asInstanceOf[Int], right.asInstanceOf[Int])
395+
case LongType => pmod(left.asInstanceOf[Long], right.asInstanceOf[Long])
396+
case ShortType => pmod(left.asInstanceOf[Short], right.asInstanceOf[Short])
397+
case ByteType => pmod(left.asInstanceOf[Byte], right.asInstanceOf[Byte])
398+
case FloatType => pmod(left.asInstanceOf[Float], right.asInstanceOf[Float])
399+
case DoubleType => pmod(left.asInstanceOf[Double], right.asInstanceOf[Double])
400+
case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal])
401+
}
402+
403+
// Code generation: emits a Java snippet that mirrors the private pmod overloads below.
// Three templates are used: Decimal (method calls), byte/short (explicit narrowing casts),
// and a default one using the plain `%` and `+` operators for every other type.
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
404+
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
405+
dataType match {
406+
case dt: DecimalType =>
407+
// "$plus" is the JVM-encoded name of a Scala `+` method. It is held in a val so that it can
// be spliced into the interpolated template, where a bare `$plus` would be read as a
// reference to a Scala variable named `plus`.
val decimalAdd = "$plus"
408+
s"""
409+
${ctx.javaType(dataType)} r = $eval1.remainder($eval2);
410+
if (r.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
411+
${ev.primitive} = (r.$decimalAdd($eval2)).remainder($eval2);
412+
} else {
413+
${ev.primitive} = r;
414+
}
415+
"""
416+
// In Java, arithmetic on byte and short operands yields an int, so each result is cast
// back to the narrow type explicitly.
case ByteType | ShortType =>
418+
s"""
419+
${ctx.javaType(dataType)} r = (${ctx.javaType(dataType)})($eval1 % $eval2);
420+
if (r < 0) {
421+
${ev.primitive} = (${ctx.javaType(dataType)})((r + $eval2) % $eval2);
422+
} else {
423+
${ev.primitive} = r;
424+
}
425+
"""
426+
case _ =>
427+
s"""
428+
${ctx.javaType(dataType)} r = $eval1 % $eval2;
429+
if (r < 0) {
430+
${ev.primitive} = (r + $eval2) % $eval2;
431+
} else {
432+
${ev.primitive} = r;
433+
}
434+
"""
435+
}
436+
})
437+
}
438+
439+
// The overloads below all follow one scheme: r = a % n; if r is negative return (r + n) % n,
// otherwise return r.
// NOTE(review): for Int and Long, `r + n` can overflow when r and n are both negative with
// large magnitude (e.g. n = Int.MinValue, r = -1) -- confirm whether that case matters.
private def pmod(a: Int, n: Int): Int = {
440+
val r = a % n
441+
if (r < 0) {(r + n) % n} else r
442+
}
443+
444+
private def pmod(a: Long, n: Long): Long = {
445+
val r = a % n
446+
if (r < 0) {(r + n) % n} else r
447+
}
448+
449+
// Byte arithmetic is carried out in Int and narrowed back with toByte.
private def pmod(a: Byte, n: Byte): Byte = {
450+
val r = a % n
451+
if (r < 0) {((r + n) % n).toByte} else r.toByte
452+
}
453+
454+
private def pmod(a: Double, n: Double): Double = {
455+
val r = a % n
456+
if (r < 0) {(r + n) % n} else r
457+
}
458+
459+
// Short arithmetic is carried out in Int and narrowed back with toShort.
private def pmod(a: Short, n: Short): Short = {
460+
val r = a % n
461+
if (r < 0) {((r + n) % n).toShort} else r.toShort
462+
}
463+
464+
private def pmod(a: Float, n: Float): Float = {
465+
val r = a % n
466+
if (r < 0) {(r + n) % n} else r
467+
}
468+
469+
// Decimal variant: the sign test goes through compare against Decimal(0).
private def pmod(a: Decimal, n: Decimal): Decimal = {
470+
val r = a % n
471+
if (r.compare(Decimal(0)) < 0) {(r + n) % n} else r
472+
}
473+
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.types.Decimal
2323

24-
2524
class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
2625

2726
/**
@@ -158,4 +157,19 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
158157
checkEvaluation(MinOf(Array(1.toByte, 2.toByte), Array(1.toByte, 3.toByte)),
159158
Array(1.toByte, 2.toByte))
160159
}
160+
161+
// Evaluation tests for Pmod: a basic case and null operands through testNumericDataTypes
// (presumably once per numeric type -- confirm against the helper), followed by fixed
// Int, Double, Decimal and Long cases.
test("pmod") {
162+
testNumericDataTypes { convert =>
163+
val left = Literal(convert(7))
164+
val right = Literal(convert(3))
165+
checkEvaluation(Pmod(left, right), convert(1))
166+
checkEvaluation(Pmod(Literal.create(null, left.dataType), right), null)
167+
checkEvaluation(Pmod(left, Literal.create(null, right.dataType)), null)
168+
// NOTE(review): this assertion builds a Remainder, not a Pmod, so a zero divisor is never
// checked against Pmod itself. Confirm the intended pmod(x, 0) behavior before switching it.
checkEvaluation(Remainder(left, Literal(convert(0))), null) // mod by 0
169+
}
170+
// A negative dividend with a positive divisor gives a non-negative result.
checkEvaluation(Pmod(-7, 3), 2)
171+
checkEvaluation(Pmod(7.2D, 4.1D), 3.1000000000000005)
172+
checkEvaluation(Pmod(Decimal(0.7), Decimal(0.2)), Decimal(0.1))
173+
// Both operands are Long, so the expected value is written as a Long literal rather than
// relying on Int/Long numeric equality.
checkEvaluation(Pmod(2L, Long.MaxValue), 2L)
174+
}
161175
}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1371,6 +1371,23 @@ object functions {
13711371
*/
13721372
def pow(l: Double, rightName: String): Column = pow(l, Column(rightName))
13731373

1374+
/**
1375+
* Returns `dividend` mod `divisor`, shifted to be non-negative when the divisor is positive.
* With a negative divisor the result can still be negative.
*
* Builds a `Pmod` expression from the underlying expressions of the two columns.
1376+
*
1377+
* @group math_funcs
1378+
* @since 1.5.0
1379+
*/
1380+
def pmod(dividend: Column, divisor: Column): Column = Pmod(dividend.expr, divisor.expr)
1381+
1382+
/**
1383+
* Returns the dividend column mod the divisor column, shifted to be non-negative when the
* divisor is positive. With a negative divisor the result can still be negative.
*
* Both arguments are column names; each is wrapped in a `Column` and passed on to
* `pmod(Column, Column)`.
1384+
*
1385+
* @group math_funcs
1386+
* @since 1.5.0
1387+
*/
1388+
def pmod(dividendColName: String, divisorColName: String): Column =
1389+
pmod(Column(dividendColName), Column(divisorColName))
1390+
13741391
/**
13751392
* Returns the double value that is closest in value to the argument and
13761393
* is equal to a mathematical integer.

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,4 +403,41 @@ class DataFrameFunctionsSuite extends QueryTest {
403403
Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3))
404404
)
405405
}
406+
407+
// DataFrame-level checks for pmod, through the Column API and through selectExpr strings.
test("pmod") {
408+
// Two integer rows: (7, 3) and (-7, 3).
val intData = Seq((7, 3), (-7, 3)).toDF("a", "b")
409+
// Column API: column/column, column/literal, literal/column.
checkAnswer(
410+
intData.select(pmod('a, 'b)),
411+
Seq(Row(1), Row(2))
412+
)
413+
checkAnswer(
414+
intData.select(pmod('a, lit(3))),
415+
Seq(Row(1), Row(2))
416+
)
417+
checkAnswer(
418+
intData.select(pmod(lit(-7), 'b)),
419+
Seq(Row(2), Row(2))
420+
)
421+
// The same three cases written as expression strings passed to selectExpr.
checkAnswer(
422+
intData.selectExpr("pmod(a, b)"),
423+
Seq(Row(1), Row(2))
424+
)
425+
checkAnswer(
426+
intData.selectExpr("pmod(a, 3)"),
427+
Seq(Row(1), Row(2))
428+
)
429+
checkAnswer(
430+
intData.selectExpr("pmod(-7, b)"),
431+
Seq(Row(2), Row(2))
432+
)
433+
// A single row of doubles: (7.2, 4.1).
val doubleData = Seq((7.2, 4.1)).toDF("a", "b")
434+
checkAnswer(
435+
doubleData.select(pmod('a, 'b)),
436+
Seq(Row(3.1000000000000005)) // same as hive
437+
)
438+
// Both operands here are integer literals; doubleData only supplies the single row to
// select over.
checkAnswer(
439+
doubleData.select(pmod(lit(2), lit(Int.MaxValue))),
440+
Seq(Row(2))
441+
)
442+
}
406443
}

0 commit comments

Comments
 (0)