@@ -927,12 +927,17 @@ class Analyzer(
       // from LogicalPlan, currently we only do it for UnaryNode which has same output
       // schema with its child.
       case p: UnaryNode if p.output == p.child.output && p.expressions.exists(!_.deterministic) =>
-        val nondeterministicExprs = p.expressions.filterNot(_.deterministic).map { e =>
-          val ne = e match {
-            case n: NamedExpression => n
-            case _ => Alias(e, "_nondeterministic")()
+        val nondeterministicExprs = p.expressions.filterNot(_.deterministic).flatMap { expr =>
+          val leafNondeterministic = expr.collect {
+            case n: Nondeterministic => n
+          }
+          leafNondeterministic.map { e =>
+            val ne = e match {
+              case n: NamedExpression => n
+              case _ => Alias(e, "_nondeterministic")()
+            }
+            new TreeNodeRef(e) -> ne
+          }
-          new TreeNodeRef(e) -> ne
         }.toMap
         val newPlan = p.transformExpressions { case e =>
           nondeterministicExprs.get(new TreeNodeRef(e)).map(_.toAttribute).getOrElse(e)
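Reviewer note: the key behavioral change above is that the rule now collects only the leaf `Nondeterministic` nodes inside each non-deterministic expression, instead of aliasing the whole expression tree. A minimal, self-contained sketch of that extraction step, using simplified stand-in types rather than the real catalyst classes (the `Expr`/`Rand`/`Add` names here are hypothetical):

```scala
// Stand-ins for catalyst's TreeNode/Expression hierarchy -- illustration only.
sealed trait Expr {
  def children: Seq[Expr]
  // Pre-order collect over the tree, like catalyst's TreeNode.collect.
  def collect[B](pf: PartialFunction[Expr, B]): Seq[B] =
    (if (pf.isDefinedAt(this)) Seq(pf(this)) else Nil) ++ children.flatMap(_.collect(pf))
}
case class Rand(seed: Long) extends Expr { def children: Seq[Expr] = Nil }
case class Add(left: Expr, right: Expr) extends Expr { def children: Seq[Expr] = Seq(left, right) }

object CollectLeavesDemo extends App {
  // Add(Rand(33), Rand(33)) is non-deterministic as a whole, but only the
  // leaf Rand nodes get pulled out into the child Project by the new code.
  val expr: Expr = Add(Rand(33), Rand(33))
  val leaves = expr.collect { case r: Rand => r }
  println(leaves) // List(Rand(33), Rand(33))
}
```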
@@ -201,11 +201,9 @@ trait Nondeterministic extends Expression {
 
   private[this] var initialized = false
 
-  final def initialize(): Unit = {
-    if (!initialized) {
-      initInternal()
-      initialized = true
-    }
+  final def setInitialValues(): Unit = {
+    initInternal()
+    initialized = true
   }
 
   protected def initInternal(): Unit
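Reviewer note: besides the rename, dropping the `if (!initialized)` guard changes behavior: each call now re-runs `initInternal()`, so a projection can re-seed a `rand(seed)` expression and reproduce the same value sequence. A hedged sketch of that effect with stand-in types (`Reinitializable`/`SeededRand` are hypothetical, not the catalyst API):

```scala
import scala.util.Random

// Stand-in for the Nondeterministic trait after this change -- illustration only.
trait Reinitializable {
  protected def initInternal(): Unit
  // No `if (!initialized)` guard: every call re-runs initInternal().
  final def setInitialValues(): Unit = initInternal()
}

class SeededRand(seed: Long) extends Reinitializable {
  private var rng: Random = _
  protected def initInternal(): Unit = { rng = new Random(seed) }
  def eval(): Double = rng.nextDouble()
}

object ReseedDemo extends App {
  val r = new SeededRand(33)
  r.setInitialValues()
  val first = Seq.fill(3)(r.eval())
  r.setInitialValues() // re-seeding restores the same sequence
  val second = Seq.fill(3)(r.eval())
  assert(first == second) // same values, like calling show() twice on rand(33)
}
```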
@@ -32,7 +32,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection {
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 
   expressions.foreach(_.foreach {
-    case n: Nondeterministic => n.initialize()
+    case n: Nondeterministic => n.setInitialValues()
     case _ =>
   })
 
@@ -63,7 +63,7 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends MutableProjection {
     this(expressions.map(BindReferences.bindReference(_, inputSchema)))
 
   expressions.foreach(_.foreach {
-    case n: Nondeterministic => n.initialize()
+    case n: Nondeterministic => n.setInitialValues()
     case _ =>
   })
 
@@ -31,7 +31,7 @@ object InterpretedPredicate {
 
   def create(expression: Expression): (InternalRow => Boolean) = {
     expression.foreach {
-      case n: Nondeterministic => n.initialize()
+      case n: Nondeterministic => n.setInitialValues()
       case _ =>
     }
     (r: InternalRow) => expression.eval(r).asInstanceOf[Boolean]
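Reviewer note: this hunk, the two projection hunks above, and the ExpressionEvalHelper hunk below all follow the same pattern: walk the bound expression tree and seed every `Nondeterministic` node before evaluating any row. A self-contained sketch of that pattern under simplified, hypothetical types (`Node`/`NonDet`/`RandLeaf` are not the real catalyst classes):

```scala
import scala.util.Random

object InitAllDemo extends App {
  trait Node {
    def children: Seq[Node]
    // Pre-order traversal, like catalyst's TreeNode.foreach.
    def foreachNode(f: Node => Unit): Unit = { f(this); children.foreach(_.foreachNode(f)) }
  }
  trait NonDet extends Node { def setInitialValues(): Unit }

  class RandLeaf(seed: Long) extends NonDet {
    val children: Seq[Node] = Nil
    private var rng: Random = _
    def setInitialValues(): Unit = { rng = new Random(seed) }
    def eval(): Double = rng.nextDouble()
  }
  case class Plus(left: Node, right: Node) extends Node {
    val children: Seq[Node] = Seq(left, right)
  }

  // The recurring pattern in these hunks: seed every nondeterministic node
  // up front, then start evaluating rows.
  def initAll(root: Node): Unit = root.foreachNode {
    case n: NonDet => n.setInitialValues()
    case _ =>
  }

  val leaf = new RandLeaf(33)
  initAll(Plus(leaf, leaf))
  println(leaf.eval())
}
```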
@@ -153,7 +153,7 @@ class AnalysisSuite extends AnalysisTest {
     assert(pl(4).dataType == DoubleType)
   }
 
-  test("pull out nondeterministic expressions from unary LogicalPlan") {
+  test("pull out nondeterministic expressions from RepartitionByExpression") {
     val plan = RepartitionByExpression(Seq(Rand(33)), testRelation)
     val projected = Alias(Rand(33), "_nondeterministic")()
     val expected =
@@ -162,4 +162,14 @@
       Project(testRelation.output :+ projected, testRelation)))
     checkAnalysis(plan, expected)
   }
+
+  test("pull out nondeterministic expressions from Sort") {
+    val plan = Sort(Seq(SortOrder(Rand(33), Ascending)), false, testRelation)
+    val projected = Alias(Rand(33), "_nondeterministic")()
+    val expected =
+      Project(testRelation.output,
+        Sort(Seq(SortOrder(projected.toAttribute, Ascending)), false,
+          Project(testRelation.output :+ projected, testRelation)))
+    checkAnalysis(plan, expected)
+  }
 }
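Reviewer note: the plan shape these tests pin down exists because sorting on an expression that produces a fresh value at every evaluation is not well-defined. The rule therefore evaluates the random value once per row in a child Project and sorts on the materialized attribute. A plain-Scala analogy of the two behaviors (an assumed simplification, not catalyst code):

```scala
import scala.util.Random

object SortByRandDemo extends App {
  val rows = (1 to 10).toSeq
  val rng = new Random(33)

  // Broken analogue: Scala's sortBy evaluates the key function during
  // comparisons, so a stateful key is like evaluating rand() inside Sort.
  // rows.sortBy(_ => rng.nextDouble()) // order depends on how often keys are re-read

  // What PullOutNondeterministic arranges: evaluate once per row (the child
  // Project), then sort on the materialized value (the attribute).
  val keyed = rows.map(r => r -> rng.nextDouble())
  val shuffled = keyed.sortBy(_._2).map(_._1)
  println(shuffled) // a stable, seed-determined permutation of rows
}
```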
@@ -65,7 +65,7 @@ trait ExpressionEvalHelper {
 
   protected def evaluate(expression: Expression, inputRow: InternalRow = EmptyRow): Any = {
     expression.foreach {
-      case n: Nondeterministic => n.initialize()
+      case n: Nondeterministic => n.setInitialValues()
       case _ =>
     }
     expression.eval(inputRow)
153 changes: 92 additions & 61 deletions sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -30,33 +30,28 @@ import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils}
 class DataFrameSuite extends QueryTest with SQLTestUtils {
   import org.apache.spark.sql.TestData._
 
-  lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-  import ctx.implicits._
-
-  def sqlContext: SQLContext = ctx
+  lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
+  import sqlContext.implicits._
 
   test("analysis error should be eagerly reported") {
-    val oldSetting = ctx.conf.dataFrameEagerAnalysis
-    // Eager analysis.
-    ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true)
-
-    intercept[Exception] { testData.select('nonExistentName) }
-    intercept[Exception] {
-      testData.groupBy('key).agg(Map("nonExistentName" -> "sum"))
-    }
-    intercept[Exception] {
-      testData.groupBy("nonExistentName").agg(Map("key" -> "sum"))
-    }
-    intercept[Exception] {
-      testData.groupBy($"abcd").agg(Map("key" -> "sum"))
+    withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") {
+      intercept[Exception] { testData.select('nonExistentName) }
+      intercept[Exception] {
+        testData.groupBy('key).agg(Map("nonExistentName" -> "sum"))
+      }
+      intercept[Exception] {
+        testData.groupBy("nonExistentName").agg(Map("key" -> "sum"))
+      }
+      intercept[Exception] {
+        testData.groupBy($"abcd").agg(Map("key" -> "sum"))
+      }
     }
 
     // No more eager analysis once the flag is turned off
-    ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, false)
-    testData.select('nonExistentName)
-
-    // Set the flag back to original value before this test.
-    ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting)
+    withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "false") {
+      testData.select('nonExistentName)
+    }
   }
 
   test("dataframe toString") {
@@ -74,21 +69,18 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
   }
 
   test("invalid plan toString, debug mode") {
-    val oldSetting = ctx.conf.dataFrameEagerAnalysis
-    ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, true)
-
     // Turn on debug mode so we can see invalid query plans.
     import org.apache.spark.sql.execution.debug._
-    ctx.debug()
-
-    val badPlan = testData.select('badColumn)
+    withSQLConf(SQLConf.DATAFRAME_EAGER_ANALYSIS.key -> "true") {
+      sqlContext.debug()
 
-    assert(badPlan.toString contains badPlan.queryExecution.toString,
-      "toString on bad query plans should include the query execution but was:\n" +
-      badPlan.toString)
+      val badPlan = testData.select('badColumn)
 
-    // Set the flag back to original value before this test.
-    ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting)
+      assert(badPlan.toString contains badPlan.queryExecution.toString,
+        "toString on bad query plans should include the query execution but was:\n" +
+        badPlan.toString)
+    }
   }
 
   test("access complex data") {
@@ -104,8 +96,8 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
   }
 
   test("empty data frame") {
-    assert(ctx.emptyDataFrame.columns.toSeq === Seq.empty[String])
-    assert(ctx.emptyDataFrame.count() === 0)
+    assert(sqlContext.emptyDataFrame.columns.toSeq === Seq.empty[String])
+    assert(sqlContext.emptyDataFrame.count() === 0)
   }
 
   test("head and take") {
@@ -341,7 +333,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
   }
 
   test("replace column using withColumn") {
-    val df2 = ctx.sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
+    val df2 = sqlContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x")
     val df3 = df2.withColumn("x", df2("x") + 1)
     checkAnswer(
       df3.select("x"),
@@ -422,7 +414,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
 
   test("randomSplit") {
     val n = 600
-    val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
+    val data = sqlContext.sparkContext.parallelize(1 to n, 2).toDF("id")
     for (seed <- 1 to 5) {
       val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
       assert(splits.length == 3, "wrong number of splits")
@@ -499,7 +491,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
 
   test("showString: truncate = [true, false]") {
     val longString = Array.fill(21)("1").mkString
-    val df = ctx.sparkContext.parallelize(Seq("1", longString)).toDF()
+    val df = sqlContext.sparkContext.parallelize(Seq("1", longString)).toDF()
     val expectedAnswerForFalse = """+---------------------+
                                    ||_1                   |
                                    |+---------------------+
@@ -589,21 +581,17 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
   }
 
   test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") {
-    val rowRDD = ctx.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
+    val rowRDD = sqlContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0))))
     val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false)))
-    val df = ctx.createDataFrame(rowRDD, schema)
+    val df = sqlContext.createDataFrame(rowRDD, schema)
     df.rdd.collect()
   }
 
-  test("SPARK-6899") {
-    val originalValue = ctx.conf.codegenEnabled
-    ctx.setConf(SQLConf.CODEGEN_ENABLED, true)
-    try{
+  test("SPARK-6899: type should match when using codegen") {
+    withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
       checkAnswer(
         decimalData.agg(avg('a)),
         Row(new java.math.BigDecimal(2.0)))
-    } finally {
-      ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue)
     }
   }
 
@@ -615,14 +603,14 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
   }
 
   test("SPARK-7551: support backticks for DataFrame attribute resolution") {
-    val df = ctx.read.json(ctx.sparkContext.makeRDD(
+    val df = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
       """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil))
     checkAnswer(
       df.select(df("`a.b`.c.`d..e`.`f`")),
       Row(1)
     )
 
-    val df2 = ctx.read.json(ctx.sparkContext.makeRDD(
+    val df2 = sqlContext.read.json(sqlContext.sparkContext.makeRDD(
      """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil))
     checkAnswer(
       df2.select(df2("`a b`.c.d e.f")),
@@ -642,7 +630,7 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
   }
 
   test("SPARK-7324 dropDuplicates") {
-    val testData = ctx.sparkContext.parallelize(
+    val testData = sqlContext.sparkContext.parallelize(
       (2, 1, 2) :: (1, 1, 1) ::
       (1, 2, 1) :: (2, 1, 2) ::
       (2, 2, 2) :: (2, 2, 1) ::
@@ -690,49 +678,49 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
 
   test("SPARK-7150 range api") {
     // numSlice is greater than length
-    val res1 = ctx.range(0, 10, 1, 15).select("id")
+    val res1 = sqlContext.range(0, 10, 1, 15).select("id")
     assert(res1.count == 10)
     assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
 
-    val res2 = ctx.range(3, 15, 3, 2).select("id")
+    val res2 = sqlContext.range(3, 15, 3, 2).select("id")
     assert(res2.count == 4)
     assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30)))
 
-    val res3 = ctx.range(1, -2).select("id")
+    val res3 = sqlContext.range(1, -2).select("id")
     assert(res3.count == 0)
 
     // start is positive, end is negative, step is negative
-    val res4 = ctx.range(1, -2, -2, 6).select("id")
+    val res4 = sqlContext.range(1, -2, -2, 6).select("id")
     assert(res4.count == 2)
     assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0)))
 
     // start, end, step are negative
-    val res5 = ctx.range(-3, -8, -2, 1).select("id")
+    val res5 = sqlContext.range(-3, -8, -2, 1).select("id")
     assert(res5.count == 3)
     assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15)))
 
     // start, end are negative, step is positive
-    val res6 = ctx.range(-8, -4, 2, 1).select("id")
+    val res6 = sqlContext.range(-8, -4, 2, 1).select("id")
     assert(res6.count == 2)
     assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14)))
 
-    val res7 = ctx.range(-10, -9, -20, 1).select("id")
+    val res7 = sqlContext.range(-10, -9, -20, 1).select("id")
     assert(res7.count == 0)
 
-    val res8 = ctx.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
+    val res8 = sqlContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id")
     assert(res8.count == 3)
     assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3)))
 
-    val res9 = ctx.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
+    val res9 = sqlContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id")
     assert(res9.count == 2)
     assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1)))
 
     // only end provided as argument
-    val res10 = ctx.range(10).select("id")
+    val res10 = sqlContext.range(10).select("id")
     assert(res10.count == 10)
     assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45)))
 
-    val res11 = ctx.range(-1).select("id")
+    val res11 = sqlContext.range(-1).select("id")
     assert(res11.count == 0)
   }
 
Expand Down Expand Up @@ -799,13 +787,13 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {

// pass case: parquet table (HadoopFsRelation)
df.write.mode(SaveMode.Overwrite).parquet(tempParquetFile.getCanonicalPath)
val pdf = ctx.read.parquet(tempParquetFile.getCanonicalPath)
val pdf = sqlContext.read.parquet(tempParquetFile.getCanonicalPath)
pdf.registerTempTable("parquet_base")
insertion.write.insertInto("parquet_base")

// pass case: json table (InsertableRelation)
df.write.mode(SaveMode.Overwrite).json(tempJsonFile.getCanonicalPath)
val jdf = ctx.read.json(tempJsonFile.getCanonicalPath)
val jdf = sqlContext.read.json(tempJsonFile.getCanonicalPath)
jdf.registerTempTable("json_base")
insertion.write.mode(SaveMode.Overwrite).insertInto("json_base")

@@ -825,11 +813,54 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
       assert(e2.getMessage.contains("Inserting into an RDD-based table is not allowed."))
 
       // error case: insert into an OneRowRelation
-      new DataFrame(ctx, OneRowRelation).registerTempTable("one_row")
+      new DataFrame(sqlContext, OneRowRelation).registerTempTable("one_row")
       val e3 = intercept[AnalysisException] {
         insertion.write.insertInto("one_row")
       }
       assert(e3.getMessage.contains("Inserting into an RDD-based table is not allowed."))
     }
   }
+
+  test("SPARK-8608: call `show` on local DataFrame with random columns should return same value") {
+    // Make sure we can pass this test for both codegen mode and interpreted mode.
+    withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
+      val df = testData.select(rand(33))
+      assert(df.showString(5) == df.showString(5))
+    }
+
+    withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") {
+      val df = testData.select(rand(33))
+      assert(df.showString(5) == df.showString(5))
+    }
+
+    // We will reuse the same Expression object for LocalRelation.
+    val df = (1 to 10).map(Tuple1.apply).toDF().select(rand(33))
+    assert(df.showString(5) == df.showString(5))
+  }
+
+  test("SPARK-8609: local DataFrame with random columns should return same value after sort") {
[Inline review comment, PR author]: This case should be fixed by #7593. However, the PullOutNondeterministic rule can also handle it, so we can add this test before #7593.

[Inline review comment, PR author]: To make it clearer: this PR now aims at the bug fix, i.e. fixing nondeterministic expression handling for local projection and sort. #7593 is left for the optimization, i.e. removing expressions from Sort that still need to be evaluated.

+    // Make sure we can pass this test for both codegen mode and interpreted mode.
+    withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "true") {
+      checkAnswer(testData.sort(rand(33)), testData.sort(rand(33)))
+    }
+
+    withSQLConf(SQLConf.CODEGEN_ENABLED.key -> "false") {
+      checkAnswer(testData.sort(rand(33)), testData.sort(rand(33)))
+    }
+
+    // We will reuse the same Expression object for LocalRelation.
+    val df = (1 to 10).map(Tuple1.apply).toDF()
+    checkAnswer(df.sort(rand(33)), df.sort(rand(33)))
+  }
+
+  test("SPARK-9083: sort with non-deterministic expressions") {
+    import org.apache.spark.util.random.XORShiftRandom
+
+    val seed = 33
+    val df = (1 to 100).map(Tuple1.apply).toDF("i")
+    val random = new XORShiftRandom(seed)
+    val expected = (1 to 100).map(_ -> random.nextDouble()).sortBy(_._2).map(_._1)
+    val actual = df.sort(rand(seed)).collect().map(_.getInt(0))
+    assert(expected === actual)
+  }
 }