
Commit b5347a4

brkyvz authored and rxin committed
[SPARK-7248] implemented random number generators for DataFrames
Adds the functions `rand` (Uniform Dist) and `randn` (Normal Dist.) as expressions to DataFrames.

cc mengxr rxin

Author: Burak Yavuz <[email protected]>

Closes apache#5819 from brkyvz/df-rng and squashes the following commits:

50d69d4 [Burak Yavuz] add seed for test that failed
4234c3a [Burak Yavuz] fix Rand expression
13cad5c [Burak Yavuz] couple fixes
7d53953 [Burak Yavuz] waiting for hive tests
b453716 [Burak Yavuz] move radn with seed down
03637f0 [Burak Yavuz] fix broken hive func
c5909eb [Burak Yavuz] deleted old implementation of Rand
6d43895 [Burak Yavuz] implemented random generators
1 parent 69a739c commit b5347a4
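As a quick orientation to what this commit adds, here is a minimal usage sketch of the new Scala API. The DataFrame `df` and its "key" column are assumed for illustration and are not part of the diff:

import org.apache.spark.sql.functions.{rand, randn}

// Sketch: `df` is an assumed DataFrame with a column "key".
// Seeded calls are reproducible for a fixed partitioning;
// seedless calls draw a fresh seed per Column.
val withUniform = df.select(df("key"), rand(5L).as("uniform"))  // i.i.d. samples from U[0.0, 1.0]
val withNormal  = df.select(df("key"), randn(5L).as("normal"))  // i.i.d. standard normal samples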

File tree

10 files changed, +149 -46 lines changed


python/pyspark/sql/functions.py

Lines changed: 24 additions & 1 deletion
@@ -67,14 +67,37 @@ def _(col):
     'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
 }

-
 for _name, _doc in _functions.items():
     globals()[_name] = _create_function(_name, _doc)
 del _name, _doc
 __all__ += _functions.keys()
 __all__.sort()


+def rand(seed=None):
+    """
+    Generate a random column with i.i.d. samples from U[0.0, 1.0].
+    """
+    sc = SparkContext._active_spark_context
+    if seed:
+        jc = sc._jvm.functions.rand(seed)
+    else:
+        jc = sc._jvm.functions.rand()
+    return Column(jc)
+
+
+def randn(seed=None):
+    """
+    Generate a column with i.i.d. samples from the standard normal distribution.
+    """
+    sc = SparkContext._active_spark_context
+    if seed:
+        jc = sc._jvm.functions.randn(seed)
+    else:
+        jc = sc._jvm.functions.randn()
+    return Column(jc)
+
+
 def approxCountDistinct(col, rsd=None):
     """Returns a new :class:`Column` for approximate distinct count of ``col``.

python/pyspark/sql/tests.py

Lines changed: 10 additions & 0 deletions
@@ -416,6 +416,16 @@ def assert_close(a, b):
         assert_close([math.hypot(i, 2 * i) for i in range(10)],
                      df.select(functions.hypot(df.a, df.b)).collect())

+    def test_rand_functions(self):
+        df = self.df
+        from pyspark.sql import functions
+        rnd = df.select('key', functions.rand()).collect()
+        for row in rnd:
+            assert row[1] >= 0.0 and row[1] <= 1.0, "got: %s" % row[1]
+        rndn = df.select('key', functions.randn(5)).collect()
+        for row in rndn:
+            assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
+
     def test_save_and_load(self):
         df = self.df
         tmpPath = tempfile.mkdtemp()

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala

Lines changed: 0 additions & 36 deletions
This file was deleted. A replacement file defining the RDG base class and the Rand and Randn expressions was added in the same package (56 additions & 0 deletions):

@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.TaskContext
+import org.apache.spark.sql.types.{DataType, DoubleType}
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * A Random distribution generating expression.
+ * TODO: This can be made generic to generate any type of random distribution, or any type of
+ * StructType.
+ *
+ * Since this expression is stateful, it cannot be a case object.
+ */
+abstract class RDG(seed: Long) extends LeafExpression with Serializable {
+  self: Product =>
+
+  /**
+   * Record ID within each partition. By being transient, the Random Number Generator is
+   * reset every time we serialize and deserialize it.
+   */
+  @transient protected lazy val rng = new XORShiftRandom(seed + TaskContext.get().partitionId())
+
+  override type EvaluatedType = Double
+
+  override def nullable: Boolean = false
+
+  override def dataType: DataType = DoubleType
+}
+
+/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
+case class Rand(seed: Long = Utils.random.nextLong()) extends RDG(seed) {
+  override def eval(input: Row): Double = rng.nextDouble()
+}
+
+/** Generate a random column with i.i.d. gaussian random distribution. */
+case class Randn(seed: Long = Utils.random.nextLong()) extends RDG(seed) {
+  override def eval(input: Row): Double = rng.nextGaussian()
+}
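The `seed + TaskContext.get().partitionId()` seeding above is what keeps the per-partition streams distinct yet reproducible. A standalone sketch of the same scheme (illustrative only: the partition ids are made up, and XORShiftRandom is a Spark-internal class, so this would only compile inside the org.apache.spark namespace):

import org.apache.spark.util.random.XORShiftRandom

// Each partition derives its generator from the expression's seed plus
// its partition id, so different partitions draw different streams while
// the same (seed, partitionId) pair replays identically across runs.
val seed = 5L
for (partitionId <- 0 until 3) {
  val rng = new XORShiftRandom(seed + partitionId)
  println(s"partition $partitionId -> ${rng.nextDouble()}")
}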

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -160,15 +160,15 @@ class ConstantFoldingSuite extends PlanTest {
     val originalQuery =
       testRelation
         .select(
-          Rand + Literal(1) as Symbol("c1"),
+          Rand(5L) + Literal(1) as Symbol("c1"),
           Sum('a) as Symbol("c2"))

     val optimized = Optimize.execute(originalQuery.analyze)

     val correctAnswer =
       testRelation
         .select(
-          Rand + Literal(1.0) as Symbol("c1"),
+          Rand(5L) + Literal(1.0) as Symbol("c1"),
           Sum('a) as Symbol("c2"))
         .analyze

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

Lines changed: 29 additions & 1 deletion
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
-
+import org.apache.spark.util.Utils

 /**
  * :: Experimental ::
@@ -346,6 +346,34 @@
   */
  def not(e: Column): Column = !e

+  /**
+   * Generate a random column with i.i.d. samples from U[0.0, 1.0].
+   *
+   * @group normal_funcs
+   */
+  def rand(seed: Long): Column = Rand(seed)
+
+  /**
+   * Generate a random column with i.i.d. samples from U[0.0, 1.0].
+   *
+   * @group normal_funcs
+   */
+  def rand(): Column = rand(Utils.random.nextLong)
+
+  /**
+   * Generate a column with i.i.d. samples from the standard normal distribution.
+   *
+   * @group normal_funcs
+   */
+  def randn(seed: Long): Column = Randn(seed)
+
+  /**
+   * Generate a column with i.i.d. samples from the standard normal distribution.
+   *
+   * @group normal_funcs
+   */
+  def randn(): Column = randn(Utils.random.nextLong)
+
 /**
  * Partition ID of the Spark task.
  *
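Note that the seedless overloads above simply draw a seed via `Utils.random.nextLong` at Column-construction time, so each call produces an independently seeded generator. A brief sketch of the consequence (assumes `import org.apache.spark.sql.functions._` and an example DataFrame `df`):

// Each seedless call fixes its own seed when the Column is built,
// so c1 and c2 are independent streams; rand(42L) is fixed by its argument.
val c1 = rand()
val c2 = rand()
df.select(c1, c2, rand(42L))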

sql/core/src/main/scala/org/apache/spark/sql/mathfunctions.scala

Lines changed: 0 additions & 2 deletions
@@ -27,8 +27,6 @@ import org.apache.spark.sql.functions.lit
 /**
  * :: Experimental ::
  * Mathematical Functions available for [[DataFrame]].
- *
- * @groupname double_funcs Functions that require DoubleType as an input
  */
 @Experimental
 // scalastyle:off

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 3 additions & 0 deletions
@@ -104,6 +104,9 @@ public void testVarargMethods() {
     df2.select(pow("a", "a"), pow("b", 2.0));
     df2.select(pow(col("a"), col("b")), exp("b"));
     df2.select(sin("a"), acos("b"));
+
+    df2.select(rand(), acos("b"));
+    df2.select(col("*"), randn(5L));
   }

   @Ignore

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 22 additions & 0 deletions
@@ -17,6 +17,8 @@

 package org.apache.spark.sql

+import org.scalatest.Matchers._
+
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.TestSQLContext
 import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -349,4 +351,24 @@ class ColumnExpressionSuite extends QueryTest {
     assert(schema("value").metadata === Metadata.empty)
     assert(schema("abc").metadata === metadata)
   }
+
+  test("rand") {
+    val randCol = testData.select('key, rand(5L).as("rand"))
+    randCol.columns.length should be (2)
+    val rows = randCol.collect()
+    rows.foreach { row =>
+      assert(row.getDouble(1) <= 1.0)
+      assert(row.getDouble(1) >= 0.0)
+    }
+  }
+
+  test("randn") {
+    val randCol = testData.select('key, randn(5L).as("rand"))
+    randCol.columns.length should be (2)
+    val rows = randCol.collect()
+    rows.foreach { row =>
+      assert(row.getDouble(1) <= 4.0)
+      assert(row.getDouble(1) >= -4.0)
+    }
+  }
 }

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala

Lines changed: 3 additions & 4 deletions
@@ -19,13 +19,11 @@ package org.apache.spark.sql.hive

 import java.sql.Date

-
-import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
-
 import scala.collection.mutable.ArrayBuffer

 import org.apache.hadoop.hive.conf.HiveConf
 import org.apache.hadoop.hive.ql.Context
+import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
 import org.apache.hadoop.hive.ql.lib.Node
 import org.apache.hadoop.hive.ql.metadata.Table
 import org.apache.hadoop.hive.ql.parse._
@@ -1244,7 +1242,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
       /* Other functions */
       case Token("TOK_FUNCTION", Token(ARRAY(), Nil) :: children) =>
         CreateArray(children.map(nodeToExpr))
-      case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand
+      case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand()
+      case Token("TOK_FUNCTION", Token(RAND(), Nil) :: seed :: Nil) => Rand(seed.toString.toLong)
       case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) =>
         Substring(nodeToExpr(string), nodeToExpr(pos), Literal.create(Integer.MAX_VALUE, IntegerType))
       case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) =>
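With the two parser cases above, both Hive forms of `rand` now resolve to the new expression. A hedged sketch of how this surfaces through SQL (the HiveContext `hiveCtx` and table `src` are assumed, not part of the diff):

// Sketch: assumes a HiveContext `hiveCtx` and a registered table `src`.
hiveCtx.sql("SELECT rand() FROM src")   // parses to Rand() with a fresh seed
hiveCtx.sql("SELECT rand(5) FROM src")  // parses to Rand(5L)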
