Skip to content

Commit fe917f5

Browse files
brkyvzrxin
authored andcommitted
[SPARK-7188] added python support for math DataFrame functions
Adds support for the math functions for DataFrames in PySpark. rxin I love Davies. Author: Burak Yavuz <[email protected]> Closes #5750 from brkyvz/python-math-udfs and squashes the following commits: 7c4f563 [Burak Yavuz] removed is_math 3c4adde [Burak Yavuz] cleanup imports d5dca3f [Burak Yavuz] moved math functions to mathfunctions 25e6534 [Burak Yavuz] addressed comments v2.0 d3f7e0f [Burak Yavuz] addressed comments and added tests 7b7d7c4 [Burak Yavuz] remove tests for removed methods 33c2c15 [Burak Yavuz] fixed python style 3ee0c05 [Burak Yavuz] added python functions
1 parent 8dee274 commit fe917f5

File tree

8 files changed

+276
-423
lines changed

8 files changed

+276
-423
lines changed

python/pyspark/sql/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _(col):
5454
'upper': 'Converts a string expression to upper case.',
5555
'lower': 'Converts a string expression to upper case.',
5656
'sqrt': 'Computes the square root of the specified float value.',
57-
'abs': 'Computes the absolutle value.',
57+
'abs': 'Computes the absolute value.',
5858

5959
'max': 'Aggregate function: returns the maximum value of the expression in a group.',
6060
'min': 'Aggregate function: returns the minimum value of the expression in a group.',
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""
19+
A collection of builtin math functions
20+
"""
21+
22+
from pyspark import SparkContext
23+
from pyspark.sql.dataframe import Column
24+
25+
__all__ = []
26+
27+
28+
def _create_unary_mathfunction(name, doc=""):
29+
""" Create a unary mathfunction by name"""
30+
def _(col):
31+
sc = SparkContext._active_spark_context
32+
jc = getattr(sc._jvm.mathfunctions, name)(col._jc if isinstance(col, Column) else col)
33+
return Column(jc)
34+
_.__name__ = name
35+
_.__doc__ = doc
36+
return _
37+
38+
39+
def _create_binary_mathfunction(name, doc=""):
40+
""" Create a binary mathfunction by name"""
41+
def _(col1, col2):
42+
sc = SparkContext._active_spark_context
43+
# users might write ints for simplicity. This would throw an error on the JVM side.
44+
if type(col1) is int:
45+
col1 = col1 * 1.0
46+
if type(col2) is int:
47+
col2 = col2 * 1.0
48+
jc = getattr(sc._jvm.mathfunctions, name)(col1._jc if isinstance(col1, Column) else col1,
49+
col2._jc if isinstance(col2, Column) else col2)
50+
return Column(jc)
51+
_.__name__ = name
52+
_.__doc__ = doc
53+
return _
54+
55+
56+
# math functions are found under another object therefore, they need to be handled separately
57+
_mathfunctions = {
58+
'acos': 'Computes the cosine inverse of the given value; the returned angle is in the range' +
59+
'0.0 through pi.',
60+
'asin': 'Computes the sine inverse of the given value; the returned angle is in the range' +
61+
'-pi/2 through pi/2.',
62+
'atan': 'Computes the tangent inverse of the given value.',
63+
'cbrt': 'Computes the cube-root of the given value.',
64+
'ceil': 'Computes the ceiling of the given value.',
65+
'cos': 'Computes the cosine of the given value.',
66+
'cosh': 'Computes the hyperbolic cosine of the given value.',
67+
'exp': 'Computes the exponential of the given value.',
68+
'expm1': 'Computes the exponential of the given value minus one.',
69+
'floor': 'Computes the floor of the given value.',
70+
'log': 'Computes the natural logarithm of the given value.',
71+
'log10': 'Computes the logarithm of the given value in Base 10.',
72+
'log1p': 'Computes the natural logarithm of the given value plus one.',
73+
'rint': 'Returns the double value that is closest in value to the argument and' +
74+
' is equal to a mathematical integer.',
75+
'signum': 'Computes the signum of the given value.',
76+
'sin': 'Computes the sine of the given value.',
77+
'sinh': 'Computes the hyperbolic sine of the given value.',
78+
'tan': 'Computes the tangent of the given value.',
79+
'tanh': 'Computes the hyperbolic tangent of the given value.',
80+
'toDeg': 'Converts an angle measured in radians to an approximately equivalent angle ' +
81+
'measured in degrees.',
82+
'toRad': 'Converts an angle measured in degrees to an approximately equivalent angle ' +
83+
'measured in radians.'
84+
}
85+
86+
# math functions that take two arguments as input
87+
_binary_mathfunctions = {
88+
'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' +
89+
'polar coordinates (r, theta).',
90+
'hypot': 'Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow.',
91+
'pow': 'Returns the value of the first argument raised to the power of the second argument.'
92+
}
93+
94+
for _name, _doc in _mathfunctions.items():
95+
globals()[_name] = _create_unary_mathfunction(_name, _doc)
96+
for _name, _doc in _binary_mathfunctions.items():
97+
globals()[_name] = _create_binary_mathfunction(_name, _doc)
98+
del _name, _doc
99+
__all__ += _mathfunctions.keys()
100+
__all__ += _binary_mathfunctions.keys()
101+
__all__.sort()

python/pyspark/sql/tests.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,35 @@ def test_aggregator(self):
387387
self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
388388
self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
389389

390+
def test_math_functions(self):
391+
df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF()
392+
from pyspark.sql import mathfunctions as functions
393+
import math
394+
395+
def get_values(l):
396+
return [j[0] for j in l]
397+
398+
def assert_close(a, b):
399+
c = get_values(b)
400+
diff = [abs(v - c[k]) < 1e-6 for k, v in enumerate(a)]
401+
return sum(diff) == len(a)
402+
assert_close([math.cos(i) for i in range(10)],
403+
df.select(functions.cos(df.a)).collect())
404+
assert_close([math.cos(i) for i in range(10)],
405+
df.select(functions.cos("a")).collect())
406+
assert_close([math.sin(i) for i in range(10)],
407+
df.select(functions.sin(df.a)).collect())
408+
assert_close([math.sin(i) for i in range(10)],
409+
df.select(functions.sin(df['a'])).collect())
410+
assert_close([math.pow(i, 2 * i) for i in range(10)],
411+
df.select(functions.pow(df.a, df.b)).collect())
412+
assert_close([math.pow(i, 2) for i in range(10)],
413+
df.select(functions.pow(df.a, 2)).collect())
414+
assert_close([math.pow(i, 2) for i in range(10)],
415+
df.select(functions.pow(df.a, 2.0)).collect())
416+
assert_close([math.hypot(i, 2 * i) for i in range(10)],
417+
df.select(functions.hypot(df.a, df.b)).collect())
418+
390419
def test_save_and_load(self):
391420
df = self.df
392421
tmpPath = tempfile.mkdtemp()

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/binary.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,6 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
6565
}
6666
}
6767

68-
case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER")
69-
70-
case class Hypot(
71-
left: Expression,
72-
right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT")
73-
7468
case class Atan2(
7569
left: Expression,
7670
right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") {
@@ -91,3 +85,9 @@ case class Atan2(
9185
}
9286
}
9387
}
88+
89+
case class Hypot(
90+
left: Expression,
91+
right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT")
92+
93+
case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathfuncs/unary.scala

Lines changed: 23 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -25,27 +25,16 @@ import org.apache.spark.sql.types._
2525
* input format, therefore these functions extend `ExpectsInputTypes`.
2626
* @param name The short name of the function
2727
*/
28-
abstract class MathematicalExpression(name: String)
28+
abstract class MathematicalExpression(f: Double => Double, name: String)
2929
extends UnaryExpression with Serializable with ExpectsInputTypes {
3030
self: Product =>
3131
type EvaluatedType = Any
3232

33+
override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
3334
override def dataType: DataType = DoubleType
3435
override def foldable: Boolean = child.foldable
3536
override def nullable: Boolean = true
3637
override def toString: String = s"$name($child)"
37-
}
38-
39-
/**
40-
* A unary expression specifically for math functions that take a `Double` as input and return
41-
* a `Double`.
42-
* @param f The math function.
43-
* @param name The short name of the function
44-
*/
45-
abstract class MathematicalExpressionForDouble(f: Double => Double, name: String)
46-
extends MathematicalExpression(name) { self: Product =>
47-
48-
override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
4938

5039
override def eval(input: Row): Any = {
5140
val evalE = child.eval(input)
@@ -58,111 +47,46 @@ abstract class MathematicalExpressionForDouble(f: Double => Double, name: String
5847
}
5948
}
6049

61-
/**
62-
* A unary expression specifically for math functions that take an `Int` as input and return
63-
* an `Int`.
64-
* @param f The math function.
65-
* @param name The short name of the function
66-
*/
67-
abstract class MathematicalExpressionForInt(f: Int => Int, name: String)
68-
extends MathematicalExpression(name) { self: Product =>
50+
case class Acos(child: Expression) extends MathematicalExpression(math.acos, "ACOS")
6951

70-
override def dataType: DataType = IntegerType
71-
override def expectedChildTypes: Seq[DataType] = Seq(IntegerType)
52+
case class Asin(child: Expression) extends MathematicalExpression(math.asin, "ASIN")
7253

73-
override def eval(input: Row): Any = {
74-
val evalE = child.eval(input)
75-
if (evalE == null) null else f(evalE.asInstanceOf[Int])
76-
}
77-
}
54+
case class Atan(child: Expression) extends MathematicalExpression(math.atan, "ATAN")
7855

79-
/**
80-
* A unary expression specifically for math functions that take a `Float` as input and return
81-
* a `Float`.
82-
* @param f The math function.
83-
* @param name The short name of the function
84-
*/
85-
abstract class MathematicalExpressionForFloat(f: Float => Float, name: String)
86-
extends MathematicalExpression(name) { self: Product =>
56+
case class Cbrt(child: Expression) extends MathematicalExpression(math.cbrt, "CBRT")
8757

88-
override def dataType: DataType = FloatType
89-
override def expectedChildTypes: Seq[DataType] = Seq(FloatType)
58+
case class Ceil(child: Expression) extends MathematicalExpression(math.ceil, "CEIL")
9059

91-
override def eval(input: Row): Any = {
92-
val evalE = child.eval(input)
93-
if (evalE == null) {
94-
null
95-
} else {
96-
val result = f(evalE.asInstanceOf[Float])
97-
if (result.isNaN) null else result
98-
}
99-
}
100-
}
101-
102-
/**
103-
* A unary expression specifically for math functions that take a `Long` as input and return
104-
* a `Long`.
105-
* @param f The math function.
106-
* @param name The short name of the function
107-
*/
108-
abstract class MathematicalExpressionForLong(f: Long => Long, name: String)
109-
extends MathematicalExpression(name) { self: Product =>
110-
111-
override def dataType: DataType = LongType
112-
override def expectedChildTypes: Seq[DataType] = Seq(LongType)
113-
114-
override def eval(input: Row): Any = {
115-
val evalE = child.eval(input)
116-
if (evalE == null) null else f(evalE.asInstanceOf[Long])
117-
}
118-
}
119-
120-
case class Sin(child: Expression) extends MathematicalExpressionForDouble(math.sin, "SIN")
121-
122-
case class Asin(child: Expression) extends MathematicalExpressionForDouble(math.asin, "ASIN")
123-
124-
case class Sinh(child: Expression) extends MathematicalExpressionForDouble(math.sinh, "SINH")
125-
126-
case class Cos(child: Expression) extends MathematicalExpressionForDouble(math.cos, "COS")
60+
case class Cos(child: Expression) extends MathematicalExpression(math.cos, "COS")
12761

128-
case class Acos(child: Expression) extends MathematicalExpressionForDouble(math.acos, "ACOS")
62+
case class Cosh(child: Expression) extends MathematicalExpression(math.cosh, "COSH")
12963

130-
case class Cosh(child: Expression) extends MathematicalExpressionForDouble(math.cosh, "COSH")
64+
case class Exp(child: Expression) extends MathematicalExpression(math.exp, "EXP")
13165

132-
case class Tan(child: Expression) extends MathematicalExpressionForDouble(math.tan, "TAN")
66+
case class Expm1(child: Expression) extends MathematicalExpression(math.expm1, "EXPM1")
13367

134-
case class Atan(child: Expression) extends MathematicalExpressionForDouble(math.atan, "ATAN")
68+
case class Floor(child: Expression) extends MathematicalExpression(math.floor, "FLOOR")
13569

136-
case class Tanh(child: Expression) extends MathematicalExpressionForDouble(math.tanh, "TANH")
70+
case class Log(child: Expression) extends MathematicalExpression(math.log, "LOG")
13771

138-
case class Ceil(child: Expression) extends MathematicalExpressionForDouble(math.ceil, "CEIL")
72+
case class Log10(child: Expression) extends MathematicalExpression(math.log10, "LOG10")
13973

140-
case class Floor(child: Expression) extends MathematicalExpressionForDouble(math.floor, "FLOOR")
74+
case class Log1p(child: Expression) extends MathematicalExpression(math.log1p, "LOG1P")
14175

142-
case class Rint(child: Expression) extends MathematicalExpressionForDouble(math.rint, "ROUND")
76+
case class Rint(child: Expression) extends MathematicalExpression(math.rint, "ROUND")
14377

144-
case class Cbrt(child: Expression) extends MathematicalExpressionForDouble(math.cbrt, "CBRT")
78+
case class Signum(child: Expression) extends MathematicalExpression(math.signum, "SIGNUM")
14579

146-
case class Signum(child: Expression) extends MathematicalExpressionForDouble(math.signum, "SIGNUM")
80+
case class Sin(child: Expression) extends MathematicalExpression(math.sin, "SIN")
14781

148-
case class ISignum(child: Expression) extends MathematicalExpressionForInt(math.signum, "ISIGNUM")
82+
case class Sinh(child: Expression) extends MathematicalExpression(math.sinh, "SINH")
14983

150-
case class FSignum(child: Expression) extends MathematicalExpressionForFloat(math.signum, "FSIGNUM")
84+
case class Tan(child: Expression) extends MathematicalExpression(math.tan, "TAN")
15185

152-
case class LSignum(child: Expression) extends MathematicalExpressionForLong(math.signum, "LSIGNUM")
86+
case class Tanh(child: Expression) extends MathematicalExpression(math.tanh, "TANH")
15387

15488
case class ToDegrees(child: Expression)
155-
extends MathematicalExpressionForDouble(math.toDegrees, "DEGREES")
89+
extends MathematicalExpression(math.toDegrees, "DEGREES")
15690

15791
case class ToRadians(child: Expression)
158-
extends MathematicalExpressionForDouble(math.toRadians, "RADIANS")
159-
160-
case class Log(child: Expression) extends MathematicalExpressionForDouble(math.log, "LOG")
161-
162-
case class Log10(child: Expression) extends MathematicalExpressionForDouble(math.log10, "LOG10")
163-
164-
case class Log1p(child: Expression) extends MathematicalExpressionForDouble(math.log1p, "LOG1P")
165-
166-
case class Exp(child: Expression) extends MathematicalExpressionForDouble(math.exp, "EXP")
167-
168-
case class Expm1(child: Expression) extends MathematicalExpressionForDouble(math.expm1, "EXPM1")
92+
extends MathematicalExpression(math.toRadians, "RADIANS")

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,18 +1253,6 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
12531253
unaryMathFunctionEvaluation[Double](Signum, math.signum)
12541254
}
12551255

1256-
test("isignum") {
1257-
unaryMathFunctionEvaluation[Int](ISignum, math.signum, (-5 to 5))
1258-
}
1259-
1260-
test("fsignum") {
1261-
unaryMathFunctionEvaluation[Float](FSignum, math.signum, (-5 to 5).map(_.toFloat))
1262-
}
1263-
1264-
test("lsignum") {
1265-
unaryMathFunctionEvaluation[Long](LSignum, math.signum, (5 to 5).map(_.toLong))
1266-
}
1267-
12681256
test("log") {
12691257
unaryMathFunctionEvaluation(Log, math.log, (0 to 20).map(_ * 0.1))
12701258
unaryMathFunctionEvaluation(Log, math.log, (-5 to -1).map(_ * 0.1), true)

0 commit comments

Comments
 (0)