From 6f2b909e425aa8b00e386ca252f6830dc5a38e41 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 18:03:07 -0700 Subject: [PATCH 01/67] Fix SPARK-9292. --- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 5 +++++ .../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c203fcecf20fb..c23ab3c74338d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -83,6 +83,11 @@ trait CheckAnalysis { s"filter expression '${f.condition.prettyString}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + case j @ Join(_, _, _, Some(condition)) if condition.dataType != BooleanType => + failAnalysis( + s"join condition '${condition.prettyString}' " + + s"of type ${condition.dataType.simpleString} is not a boolean.") + case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case _: AggregateExpression => // OK diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index dca8c881f21ab..7bf678ebf71ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -118,6 +118,11 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { testRelation.where(Literal(1)), "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) + errorTest( + "non-boolean join conditions", + testRelation.join(testRelation, condition = Some(Literal(1))), + "condition" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil) + errorTest( "missing group by", testRelation2.groupBy('a)('b), From 03120d5b48e94e164ea4e8182c6acc0d08eb204e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 18:10:15 -0700 Subject: [PATCH 02/67] Check condition type in resolved() --- .../spark/sql/catalyst/plans/logical/basicOperators.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 6aefa9f67556a..57a12820fa4c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -128,7 +128,10 @@ case class Join( // Joins are only resolved if they don't introduce ambiguous expression ids. override lazy val resolved: Boolean = { - childrenResolved && expressions.forall(_.resolved) && selfJoinResolved + childrenResolved && + expressions.forall(_.resolved) && + selfJoinResolved && + condition.forall(_.dataType == BooleanType) } } From e1f462ef3abec729ad8a533e98a5465c5ccb57b4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 10 Jul 2015 19:14:43 -0700 Subject: [PATCH 03/67] Initial commit for SQL expression fuzzing harness --- sql/core/pom.xml | 6 ++ .../spark/sql/ExpressionFuzzingSuite.scala | 59 +++++++++++++++++++ 2 files changed, 65 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala diff --git a/sql/core/pom.xml b/sql/core/pom.xml index be0966641b5c4..7c8155bafc88a 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -36,6 +36,12 @@ + + org.clapper + classutil_${scala.binary.version} + 1.0.4 + test + org.apache.spark spark-core_${scala.binary.version} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala new file mode 100644 index 0000000000000..90f6aa3c85068 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.io.File + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types.NullType +import org.clapper.classutil.ClassFinder + +class ExpressionFuzzingSuite extends SparkFunSuite { + lazy val expressionClasses: Seq[Class[_]] = { + val finder = ClassFinder( + System.getProperty("java.class.path").split(':').map(new File(_)) .filter(_.exists)) + val classes = finder.getClasses().toIterator + ClassFinder.concreteSubclasses(classOf[Expression].getName, classes) + .map(c => Class.forName(c.name)).toSeq + } + + expressionClasses.foreach(println) + for (c <- expressionClasses) { + val singleExprConstructor = c.getConstructors.filter { c => + c.getParameterTypes.toSeq == Seq(classOf[Expression]) + } + singleExprConstructor.foreach { cons => + try { + val expr: Expression = cons.newInstance(Literal(null)).asInstanceOf[Expression] + val row: InternalRow = new GenericInternalRow(Array[Any](null)) + val gened = GenerateMutableProjection.generate(Seq(expr), Seq(AttributeReference("f", NullType)())) + gened() +// expr.eval(row) + println(s"Passed $c") + } catch { + case e: Exception => + println(s"Got exception $e while testing $c") + e.printStackTrace(System.out) + } + + } + } +} From f8daec768cd08a05a8b6564ac74d2fe8ced4b498 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 10 Jul 2015 20:03:46 -0700 Subject: [PATCH 04/67] Apply implicit casts (in a hacky way for now) --- .../spark/sql/ExpressionFuzzingSuite.scala | 45 ++++++++++++++----- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala index 90f6aa3c85068..db8f3bc5651e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala @@ -21,11 +21,31 @@ import java.io.File import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types.NullType import org.clapper.classutil.ClassFinder +case object DummyAnalyzer extends RuleExecutor[LogicalPlan] { + override protected val batches: Seq[Batch] = Seq( + Batch("analysis", FixedPoint(100), HiveTypeCoercion.typeCoercionRules: _*) + ) +} + + +case class DummyPlan(expr: Expression) extends LogicalPlan { + override def output: Seq[Attribute] = Seq.empty + + /** Returns all of the expressions present in this query plan operator. */ + override def expressions: Seq[Expression] = Seq(expr) + + override def children: Seq[LogicalPlan] = Seq.empty +} + + class ExpressionFuzzingSuite extends SparkFunSuite { lazy val expressionClasses: Seq[Class[_]] = { val finder = ClassFinder( @@ -36,24 +56,27 @@ class ExpressionFuzzingSuite extends SparkFunSuite { } expressionClasses.foreach(println) + for (c <- expressionClasses) { val singleExprConstructor = c.getConstructors.filter { c => c.getParameterTypes.toSeq == Seq(classOf[Expression]) } singleExprConstructor.foreach { cons => - try { + test(s"${c.getName}") { val expr: Expression = cons.newInstance(Literal(null)).asInstanceOf[Expression] - val row: InternalRow = new GenericInternalRow(Array[Any](null)) - val gened = GenerateMutableProjection.generate(Seq(expr), Seq(AttributeReference("f", NullType)())) - gened() -// expr.eval(row) - println(s"Passed $c") - } catch { - case e: Exception => - println(s"Got exception $e while testing $c") - e.printStackTrace(System.out) + val coercedExpr: Expression = { + val dummyPlan = DummyPlan(expr) + DummyAnalyzer.execute(dummyPlan).asInstanceOf[DummyPlan].expr + } + if (expr.checkInputDataTypes().isSuccess) { + val row: InternalRow = new GenericInternalRow(Array[Any](null)) + val gened = GenerateMutableProjection.generate(Seq(coercedExpr), Seq(AttributeReference("f", NullType)())) + gened() + // expr.eval(row) + } else { + println(s"Input types check failed for $c") + } } - } } } From df00e7a2474a1deb0bf7fb5a6a4719d1026ffbd5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 11 Jul 2015 15:40:21 -0700 Subject: [PATCH 05/67] More messy WIP prototyping on expression fuzzing --- .../spark/sql/ExpressionFuzzingSuite.scala | 32 +++++++++++-------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala index db8f3bc5651e6..d4c897c82559a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala @@ -22,7 +22,7 @@ import java.io.File import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -58,24 +58,28 @@ class ExpressionFuzzingSuite extends SparkFunSuite { expressionClasses.foreach(println) for (c <- expressionClasses) { - val singleExprConstructor = c.getConstructors.filter { c => - c.getParameterTypes.toSeq == Seq(classOf[Expression]) - } - singleExprConstructor.foreach { cons => + val exprOnlyConstructor = c.getConstructors.filter { c => + c.getParameterTypes.toSet == Set(classOf[Expression]) + }.sortBy(_.getParameterTypes.length * -1).headOption + exprOnlyConstructor.foreach { cons => + val numChildren = cons.getParameterTypes.length test(s"${c.getName}") { - val expr: Expression = cons.newInstance(Literal(null)).asInstanceOf[Expression] + val expr: Expression = + cons.newInstance(Seq.fill(numChildren)(Literal.create(null, NullType)): _*).asInstanceOf[Expression] val coercedExpr: Expression = { val dummyPlan = DummyPlan(expr) DummyAnalyzer.execute(dummyPlan).asInstanceOf[DummyPlan].expr + expr } - if (expr.checkInputDataTypes().isSuccess) { - val row: InternalRow = new GenericInternalRow(Array[Any](null)) - val gened = GenerateMutableProjection.generate(Seq(coercedExpr), Seq(AttributeReference("f", NullType)())) - gened() - // expr.eval(row) - } else { - println(s"Input types check failed for $c") - } + println(s"Before coercion: ${expr.children.map(_.dataType)}") + println(s"After coercion: ${coercedExpr.children.map(_.dataType)}") + assume(coercedExpr.checkInputDataTypes().isSuccess, coercedExpr.checkInputDataTypes().toString) + val row: InternalRow = new GenericInternalRow(Array.fill[Any](numChildren)(null)) + val inputSchema = coercedExpr.children.map(c => AttributeReference("f", c.dataType)()) + val gened = GenerateProjection.generate(Seq(coercedExpr), inputSchema) + // TODO: mutable projections + //gened().apply(row) + coercedExpr.eval(row) } } } From 2dcbc108e4da5671d03d7eefe9c91521618a94f9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 11 Jul 2015 16:02:00 -0700 Subject: [PATCH 06/67] Add some comments; speed up classpath search --- .../spark/sql/ExpressionFuzzingSuite.scala | 69 +++++++++++-------- 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala index d4c897c82559a..632fe59a3d885 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala @@ -19,45 +19,42 @@ package org.apache.spark.sql import java.io.File +import org.clapper.classutil.ClassFinder + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types.NullType -import org.clapper.classutil.ClassFinder - -case object DummyAnalyzer extends RuleExecutor[LogicalPlan] { - override protected val batches: Seq[Batch] = Seq( - Batch("analysis", FixedPoint(100), HiveTypeCoercion.typeCoercionRules: _*) - ) -} - - -case class DummyPlan(expr: Expression) extends LogicalPlan { - override def output: Seq[Attribute] = Seq.empty - - /** Returns all of the expressions present in this query plan operator. */ - override def expressions: Seq[Expression] = Seq(expr) - - override def children: Seq[LogicalPlan] = Seq.empty -} - +/** + * This test suite implements fuzz tests for expression code generation. It uses reflection to + * automatically discover all [[Expression]]s, then instantiates these expressions with random + * children/inputs. If the resulting expression passes the type checker after type coerceion is + * performed then we attempt to compile the expression and compare its output to output generated + * by the interpreted expression. + */ class ExpressionFuzzingSuite extends SparkFunSuite { - lazy val expressionClasses: Seq[Class[_]] = { - val finder = ClassFinder( - System.getProperty("java.class.path").split(':').map(new File(_)) .filter(_.exists)) - val classes = finder.getClasses().toIterator - ClassFinder.concreteSubclasses(classOf[Expression].getName, classes) - .map(c => Class.forName(c.name)).toSeq - } - expressionClasses.foreach(println) + /** + * All subclasses of [[Expression]]. + */ + lazy val expressionSubclasses: Seq[Class[Expression]] = { + val classpathEntries: Seq[File] = System.getProperty("java.class.path") + .split(File.pathSeparatorChar) + .filter(_.contains("spark")) + .map(new File(_)) + .filter(_.exists()).toSeq + val allClasses = ClassFinder(classpathEntries).getClasses() + assert(allClasses.nonEmpty, "Could not find Spark classes on classpath.") + ClassFinder.concreteSubclasses(classOf[Expression].getName, allClasses) + .map(c => Class.forName(c.name).asInstanceOf[Class[Expression]]).toSeq + } - for (c <- expressionClasses) { + for (c <- expressionSubclasses) { val exprOnlyConstructor = c.getConstructors.filter { c => c.getParameterTypes.toSet == Set(classOf[Expression]) }.sortBy(_.getParameterTypes.length * -1).headOption @@ -84,3 +81,19 @@ class ExpressionFuzzingSuite extends SparkFunSuite { } } } + +case object DummyAnalyzer extends RuleExecutor[LogicalPlan] { + override protected val batches: Seq[Batch] = Seq( + Batch("analysis", FixedPoint(100), HiveTypeCoercion.typeCoercionRules: _*) + ) +} + + +case class DummyPlan(expr: Expression) extends LogicalPlan { + override def output: Seq[Attribute] = Seq.empty + + /** Returns all of the expressions present in this query plan operator. */ + override def expressions: Seq[Expression] = Seq(expr) + + override def children: Seq[LogicalPlan] = Seq.empty +} From c20a67997a26f066b9ef9e627927977215878f5c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 11 Jul 2015 16:07:37 -0700 Subject: [PATCH 07/67] Move dummy type coercion to a helper method --- .../spark/sql/ExpressionFuzzingSuite.scala | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala index 632fe59a3d885..26cf2fbab2aa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types.NullType /** * This test suite implements fuzz tests for expression code generation. It uses reflection to * automatically discover all [[Expression]]s, then instantiates these expressions with random - * children/inputs. If the resulting expression passes the type checker after type coerceion is + * children/inputs. If the resulting expression passes the type checker after type coercion is * performed then we attempt to compile the expression and compare its output to output generated * by the interpreted expression. */ @@ -54,6 +54,11 @@ class ExpressionFuzzingSuite extends SparkFunSuite { .map(c => Class.forName(c.name).asInstanceOf[Class[Expression]]).toSeq } + def coerceTypes(expression: Expression): Expression = { + val dummyPlan: LogicalPlan = DummyPlan(expression) + DummyAnalyzer.execute(dummyPlan).asInstanceOf[DummyPlan].expression + } + for (c <- expressionSubclasses) { val exprOnlyConstructor = c.getConstructors.filter { c => c.getParameterTypes.toSet == Set(classOf[Expression]) @@ -63,11 +68,7 @@ class ExpressionFuzzingSuite extends SparkFunSuite { test(s"${c.getName}") { val expr: Expression = cons.newInstance(Seq.fill(numChildren)(Literal.create(null, NullType)): _*).asInstanceOf[Expression] - val coercedExpr: Expression = { - val dummyPlan = DummyPlan(expr) - DummyAnalyzer.execute(dummyPlan).asInstanceOf[DummyPlan].expr - expr - } + val coercedExpr: Expression = coerceTypes(expr) println(s"Before coercion: ${expr.children.map(_.dataType)}") println(s"After coercion: ${coercedExpr.children.map(_.dataType)}") assume(coercedExpr.checkInputDataTypes().isSuccess, coercedExpr.checkInputDataTypes().toString) @@ -82,18 +83,14 @@ class ExpressionFuzzingSuite extends SparkFunSuite { } } -case object DummyAnalyzer extends RuleExecutor[LogicalPlan] { +private case object DummyAnalyzer extends RuleExecutor[LogicalPlan] { override protected val batches: Seq[Batch] = Seq( Batch("analysis", FixedPoint(100), HiveTypeCoercion.typeCoercionRules: _*) ) } - -case class DummyPlan(expr: Expression) extends LogicalPlan { +private case class DummyPlan(expression: Expression) extends LogicalPlan { + override def expressions: Seq[Expression] = Seq(expression) override def output: Seq[Attribute] = Seq.empty - - /** Returns all of the expressions present in this query plan operator. */ - override def expressions: Seq[Expression] = Seq(expr) - override def children: Seq[LogicalPlan] = Seq.empty } From 95860dee6c2784869ae06a7acad1d4dc52eb7aec Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 11 Jul 2015 16:30:59 -0700 Subject: [PATCH 08/67] More code cleanup and comments --- .../spark/sql/ExpressionFuzzingSuite.scala | 70 +++++++++++++------ 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala index 26cf2fbab2aa0..ed29427c72c8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import java.io.File +import java.lang.reflect.Constructor import org.clapper.classutil.ClassFinder @@ -59,26 +60,55 @@ class ExpressionFuzzingSuite extends SparkFunSuite { DummyAnalyzer.execute(dummyPlan).asInstanceOf[DummyPlan].expression } - for (c <- expressionSubclasses) { - val exprOnlyConstructor = c.getConstructors.filter { c => - c.getParameterTypes.toSet == Set(classOf[Expression]) - }.sortBy(_.getParameterTypes.length * -1).headOption - exprOnlyConstructor.foreach { cons => - val numChildren = cons.getParameterTypes.length - test(s"${c.getName}") { - val expr: Expression = - cons.newInstance(Seq.fill(numChildren)(Literal.create(null, NullType)): _*).asInstanceOf[Expression] - val coercedExpr: Expression = coerceTypes(expr) - println(s"Before coercion: ${expr.children.map(_.dataType)}") - println(s"After coercion: ${coercedExpr.children.map(_.dataType)}") - assume(coercedExpr.checkInputDataTypes().isSuccess, coercedExpr.checkInputDataTypes().toString) - val row: InternalRow = new GenericInternalRow(Array.fill[Any](numChildren)(null)) - val inputSchema = coercedExpr.children.map(c => AttributeReference("f", c.dataType)()) - val gened = GenerateProjection.generate(Seq(coercedExpr), inputSchema) - // TODO: mutable projections - //gened().apply(row) - coercedExpr.eval(row) - } + /** + * Given an expression class, find the constructor which accepts only expressions. If there are + * multiple such constructors, pick the one with the most parameters. + * @return The matching constructor, or None if no appropriate constructor could be found. + */ + def getBestConstructor(expressionClass: Class[Expression]): Option[Constructor[Expression]] = { + val allConstructors = expressionClass.getConstructors ++ expressionClass.getDeclaredConstructors + allConstructors + .map(_.asInstanceOf[Constructor[Expression]]) + .filter(_.getParameterTypes.toSet == Set(classOf[Expression])) + .sortBy(_.getParameterTypes.length * -1) + .headOption + } + + def testExpression(expressionClass: Class[Expression]): Unit = { + // Eventually, we should add support for testing multiple constructors. For now, though, we + // only test the "best" one: + val constructor: Constructor[Expression] = { + val maybeBestConstructor = getBestConstructor(expressionClass) + assume(maybeBestConstructor.isDefined, "Could not find an Expression-only constructor") + maybeBestConstructor.get + } + val numChildren: Int = constructor.getParameterTypes.length + // Eventually, we should test with multiple types of child expressions. For now, though, we + // construct null literals for all child expressions and leave it up to the type coercion rules + // to cast them to the appropriate types. + val expression: Expression = { + val childExpressions: Seq[Expression] = Seq.fill(numChildren)(Literal.create(null, NullType)) + coerceTypes(constructor.newInstance(childExpressions: _*)) + } + // Make sure that the resulting expression passes type checks. + val typecheckResult = expression.checkInputDataTypes() + assume(typecheckResult.isSuccess, s"Type checks failed: $typecheckResult") + // Attempt to generate code for this expression by using it to generate a projection. + val inputSchema = expression.children.map(c => AttributeReference("f", c.dataType)()) + val generatedProjection = GenerateProjection.generate(Seq(expression), inputSchema) + val interpretedProjection = InterpretedMutableProjection(Seq(expression)) + // Check that the answers agree for an input row consisting entirely of nulls, since the + // implicit type casts should make this safe + val inputRow = InternalRow.apply(Seq.fill(numChildren)(null)) + val generatedResult = generatedProjection.apply(inputRow) + val interpretedResult = interpretedProjection.apply(inputRow) + assert(generatedResult === interpretedResult) + } + + // Run the actual tests + expressionSubclasses.foreach { expressionClass => + test(s"${expressionClass.getName}") { + testExpression(expressionClass) } } } From abaed51744c7183f53629b28be5ed49ecdb28fff Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 11 Jul 2015 16:39:53 -0700 Subject: [PATCH 09/67] Use non-mutable interpreted projection. --- .../scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala index ed29427c72c8c..4879b0847cebd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala @@ -96,7 +96,7 @@ class ExpressionFuzzingSuite extends SparkFunSuite { // Attempt to generate code for this expression by using it to generate a projection. val inputSchema = expression.children.map(c => AttributeReference("f", c.dataType)()) val generatedProjection = GenerateProjection.generate(Seq(expression), inputSchema) - val interpretedProjection = InterpretedMutableProjection(Seq(expression)) + val interpretedProjection = new InterpretedProjection(Seq(expression), inputSchema) // Check that the answers agree for an input row consisting entirely of nulls, since the // implicit type casts should make this safe val inputRow = InternalRow.apply(Seq.fill(numChildren)(null)) From 129ad6c0d3c3bb6682a78682e76b0133a0e41eff Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 11 Jul 2015 16:44:26 -0700 Subject: [PATCH 10/67] Log expression after coercion --- .../scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala index 4879b0847cebd..40f2a30247a77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala @@ -22,7 +22,7 @@ import java.lang.reflect.Constructor import org.clapper.classutil.ClassFinder -import org.apache.spark.SparkFunSuite +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection @@ -38,7 +38,7 @@ import org.apache.spark.sql.types.NullType * performed then we attempt to compile the expression and compare its output to output generated * by the interpreted expression. */ -class ExpressionFuzzingSuite extends SparkFunSuite { +class ExpressionFuzzingSuite extends SparkFunSuite with Logging { /** * All subclasses of [[Expression]]. @@ -90,6 +90,7 @@ class ExpressionFuzzingSuite extends SparkFunSuite { val childExpressions: Seq[Expression] = Seq.fill(numChildren)(Literal.create(null, NullType)) coerceTypes(constructor.newInstance(childExpressions: _*)) } + logInfo(s"After type coercion, expression is $expression") // Make sure that the resulting expression passes type checks. val typecheckResult = expression.checkInputDataTypes() assume(typecheckResult.isSuccess, s"Type checks failed: $typecheckResult") From e1f91df228423d95f5ad6904819ce3d8bc60f09c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 12 Jul 2015 17:50:07 -0700 Subject: [PATCH 11/67] Run tests in deterministic order --- .../scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala index 40f2a30247a77..4153f88b0266a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala @@ -107,7 +107,7 @@ class ExpressionFuzzingSuite extends SparkFunSuite with Logging { } // Run the actual tests - expressionSubclasses.foreach { expressionClass => + expressionSubclasses.sortBy(_.getName).foreach { expressionClass => test(s"${expressionClass.getName}") { testExpression(expressionClass) } From adc3c7f34c866f469d1f2d92e44568943539a784 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 13:31:06 -0700 Subject: [PATCH 12/67] Test with random inputs of all types --- .../spark/sql/ExpressionFuzzingSuite.scala | 78 +++++++++++++------ 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala index 4153f88b0266a..ad94e6cfdd0e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql import java.io.File import java.lang.reflect.Constructor +import scala.util.{Try, Random} + import org.clapper.classutil.ClassFinder import org.apache.spark.{Logging, SparkFunSuite} @@ -29,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.NullType +import org.apache.spark.sql.types.{DecimalType, DataType, DataTypeTestUtils} /** * This test suite implements fuzz tests for expression code generation. It uses reflection to @@ -40,10 +42,12 @@ import org.apache.spark.sql.types.NullType */ class ExpressionFuzzingSuite extends SparkFunSuite with Logging { + val NUM_TRIALS_PER_EXPRESSION: Int = 100 + /** - * All subclasses of [[Expression]]. + * All evaluable subclasses of [[Expression]]. */ - lazy val expressionSubclasses: Seq[Class[Expression]] = { + lazy val evaluableExpressionClasses: Seq[Class[Expression]] = { val classpathEntries: Seq[File] = System.getProperty("java.class.path") .split(File.pathSeparatorChar) .filter(_.contains("spark")) @@ -53,6 +57,14 @@ class ExpressionFuzzingSuite extends SparkFunSuite with Logging { assert(allClasses.nonEmpty, "Could not find Spark classes on classpath.") ClassFinder.concreteSubclasses(classOf[Expression].getName, allClasses) .map(c => Class.forName(c.name).asInstanceOf[Class[Expression]]).toSeq + // We should only test evalulable expressions: + .filterNot(c => classOf[Unevaluable].isAssignableFrom(c)) + // These expressions currently OOM because we try to pass in massive numeric literals: + .filterNot(_ == classOf[FormatNumber]) + .filterNot(_ == classOf[StringSpace]) + .filterNot(_ == classOf[StringLPad]) + .filterNot(_ == classOf[StringRPad]) + .filterNot(_ == classOf[Round]) } def coerceTypes(expression: Expression): Expression = { @@ -74,6 +86,17 @@ class ExpressionFuzzingSuite extends SparkFunSuite with Logging { .headOption } + def getRandomLiteral: Literal = { + // For now, filter out DecimalType since casts can lead to OOMs: + val allTypes = DataTypeTestUtils.atomicTypes.filterNot(_.isInstanceOf[DecimalType]) + val dataTypesWithGenerators: Map[DataType, () => Any] = allTypes.map { dt => + (dt, RandomDataGenerator.forType(dt, nullable = true, seed=None)) + }.filter(_._2.isDefined).toMap.mapValues(_.get) + val (dt, generator) = + dataTypesWithGenerators.toSeq(Random.nextInt(dataTypesWithGenerators.size)) + Literal.create(generator(), dt) + } + def testExpression(expressionClass: Class[Expression]): Unit = { // Eventually, we should add support for testing multiple constructors. For now, though, we // only test the "best" one: @@ -83,31 +106,38 @@ class ExpressionFuzzingSuite extends SparkFunSuite with Logging { maybeBestConstructor.get } val numChildren: Int = constructor.getParameterTypes.length - // Eventually, we should test with multiple types of child expressions. For now, though, we - // construct null literals for all child expressions and leave it up to the type coercion rules - // to cast them to the appropriate types. - val expression: Expression = { - val childExpressions: Seq[Expression] = Seq.fill(numChildren)(Literal.create(null, NullType)) - coerceTypes(constructor.newInstance(childExpressions: _*)) + // Construct random literals for all child expressions and leave it up to the type coercion + // rules to cast them to the appropriate types. Skip + for (_ <- 1 to NUM_TRIALS_PER_EXPRESSION) { + val expression: Expression = { + val childExpressions: Seq[Expression] = Seq.fill(numChildren)(getRandomLiteral) + coerceTypes(constructor.newInstance(childExpressions: _*)) + } + logInfo(s"After type coercion, expression is $expression") + // Make sure that the resulting expression passes type checks. + val typecheckResult = expression.checkInputDataTypes() + if (typecheckResult.isFailure) { + logDebug(s"Type checks failed: $typecheckResult") + } else { + withClue(s"$expression") { + val inputRow = InternalRow.apply() // Can be empty since we're only using literals + val inputSchema = expression.children.map(c => AttributeReference("f", c.dataType)()) + + val interpretedProjection = new InterpretedProjection(Seq(expression), inputSchema) + val interpretedResult = interpretedProjection.apply(inputRow) + + val maybeGenProjection = Try(GenerateProjection.generate(Seq(expression), inputSchema)) + maybeGenProjection.foreach { generatedProjection => + val generatedResult = generatedProjection.apply(inputRow) + assert(generatedResult === interpretedResult) + } + } + } } - logInfo(s"After type coercion, expression is $expression") - // Make sure that the resulting expression passes type checks. - val typecheckResult = expression.checkInputDataTypes() - assume(typecheckResult.isSuccess, s"Type checks failed: $typecheckResult") - // Attempt to generate code for this expression by using it to generate a projection. - val inputSchema = expression.children.map(c => AttributeReference("f", c.dataType)()) - val generatedProjection = GenerateProjection.generate(Seq(expression), inputSchema) - val interpretedProjection = new InterpretedProjection(Seq(expression), inputSchema) - // Check that the answers agree for an input row consisting entirely of nulls, since the - // implicit type casts should make this safe - val inputRow = InternalRow.apply(Seq.fill(numChildren)(null)) - val generatedResult = generatedProjection.apply(inputRow) - val interpretedResult = interpretedProjection.apply(inputRow) - assert(generatedResult === interpretedResult) } // Run the actual tests - expressionSubclasses.sortBy(_.getName).foreach { expressionClass => + evaluableExpressionClasses.sortBy(_.getName).foreach { expressionClass => test(s"${expressionClass.getName}") { testExpression(expressionClass) } From ae5e1510e08a40e78c7aa1ca6669f4ce414d2b70 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 15:28:48 -0700 Subject: [PATCH 13/67] Ignore BinaryType for now, since it led to some spurious failures. --- .../org/apache/spark/sql/ExpressionFuzzingSuite.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala index ad94e6cfdd0e9..530795fef2a38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.{DecimalType, DataType, DataTypeTestUtils} +import org.apache.spark.sql.types.{BinaryType, DecimalType, DataType, DataTypeTestUtils} /** * This test suite implements fuzz tests for expression code generation. It uses reflection to @@ -87,8 +87,9 @@ class ExpressionFuzzingSuite extends SparkFunSuite with Logging { } def getRandomLiteral: Literal = { - // For now, filter out DecimalType since casts can lead to OOMs: - val allTypes = DataTypeTestUtils.atomicTypes.filterNot(_.isInstanceOf[DecimalType]) + val allTypes = DataTypeTestUtils.atomicTypes + .filterNot(_.isInstanceOf[DecimalType]) // casts can lead to OOM + .filterNot(_.isInstanceOf[BinaryType]) // leads to spurious errors in string reverse val dataTypesWithGenerators: Map[DataType, () => Any] = allTypes.map { dt => (dt, RandomDataGenerator.forType(dt, nullable = true, seed=None)) }.filter(_._2.isDefined).toMap.mapValues(_.get) From a35420840bd4bbd13a678ea24feca5129fa3796d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 17:07:23 -0700 Subject: [PATCH 14/67] Begin to add a DataFrame API fuzzer. --- .../spark/sql/DataFrameFuzzingSuite.scala | 111 ++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala new file mode 100644 index 0000000000000..d985cedb67d87 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala @@ -0,0 +1,111 @@ +package org.apache.spark.sql + +import java.lang.reflect.InvocationTargetException + +import org.apache.spark.sql.test.TestSQLContext + +import scala.reflect.runtime.{universe => ru} + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + +import scala.util.control.NonFatal + +/** + * This test suite generates random data frames, then applies random sequences of operations to + * them in order to construct random queries. We don't have a source of truth for these random + * queries but nevertheless they are still useful for testing that we don't crash in bad ways. + */ +class DataFrameFuzzingSuite extends SparkFunSuite { + + def randomChoice[T](values: Seq[T]): T = { + values(Random.nextInt(values.length)) + } + + val randomValueGenerators: Map[Class[_], () => Any] = Map( + classOf[String] -> (() => Random.nextString(10)) + ) + + def generateRandomDataFrame(): DataFrame = { + val allTypes = DataTypeTestUtils.atomicTypes + .filterNot(_.isInstanceOf[DecimalType]) // casts can lead to OOM + .filterNot(_.isInstanceOf[BinaryType]) // leads to spurious errors in string reverse + val dataTypesWithGenerators = allTypes.filter { dt => + RandomDataGenerator.forType(dt, nullable = true, seed = None).isDefined + } + def randomType(): DataType = randomChoice(dataTypesWithGenerators.toSeq) + val numColumns = 1 + Random.nextInt(3) + val schema = + new StructType((1 to numColumns).map(i => new StructField(s"c$i", randomType())).toArray) + val rowGenerator = RandomDataGenerator.forType(schema).get + val rows: Seq[Row] = Seq.fill(10)(rowGenerator().asInstanceOf[Row]) + TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(rows), schema) + } + + val df = generateRandomDataFrame() + + val m = ru.runtimeMirror(this.getClass.getClassLoader) + + val whitelistedColumnTypes = Set( + m.universe.typeOf[DataFrame], + m.universe.typeOf[Column] + ) + + val dataFrameTransformations = { + val dfType = m.universe.typeOf[DataFrame] + dfType.members + .filter(_.isPublic) + .filter(_.isMethod) + .map(_.asMethod) + .filter(_.returnType =:= dfType) + .filterNot(_.isConstructor) + .filter { m => + m.paramss.flatten.forall { p => + whitelistedColumnTypes.exists { t => t =:= p.typeSignature.erasure } + } + } + .toSeq + } + + def applyRandomTransformationToDataFrame(df: DataFrame): DataFrame = { + val method = randomChoice(dataFrameTransformations) + val params = method.paramss.flatten // We don't use multiple parameter lists + val paramTypes = params.map(_.typeSignature) + val paramValues = paramTypes.map { t => + if (m.universe.typeOf[DataFrame] =:= t.erasure) { + df + } else if (m.universe.typeOf[Column] =:= t.erasure) { + df.col(randomChoice(df.columns)) + } else { + sys.error("ERROR!") + } + } + val reflectedMethod: ru.MethodMirror = m.reflect(df).reflectMethod(method) + try { + reflectedMethod.apply(paramValues: _*).asInstanceOf[DataFrame] + } catch { + case e: InvocationTargetException => + throw e.getCause + } + } + + for (_ <- 1 to 1000) { + try { + val df2 = applyRandomTransformationToDataFrame(df) + try { + df2.collectAsList() + } catch { + case NonFatal(e) => + println(df2.queryExecution) + println(df) + println(df.collectAsList()) + throw e + } + } catch { + case e: AnalysisException => null + } + } + +} From 13f8c560b7103a13d47a673eaaef152830bddeed Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 17:11:33 -0700 Subject: [PATCH 15/67] Don't puts nulls into the DataFrame --- .../test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala index d985cedb67d87..d309ca41b7c81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala @@ -39,7 +39,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite { val numColumns = 1 + Random.nextInt(3) val schema = new StructType((1 to numColumns).map(i => new StructField(s"c$i", randomType())).toArray) - val rowGenerator = RandomDataGenerator.forType(schema).get + val rowGenerator = RandomDataGenerator.forType(schema, nullable = false).get val rows: Seq[Row] = Seq.fill(10)(rowGenerator().asInstanceOf[Row]) TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(rows), schema) } From dd16f4dd5f6ffe1247c03b62bfddc5e7130d9ae5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 18:01:54 -0700 Subject: [PATCH 16/67] Print logical plans. --- .../test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala index d309ca41b7c81..341c344409157 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala @@ -98,6 +98,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite { df2.collectAsList() } catch { case NonFatal(e) => + println(df2.logicalPlan) println(df2.queryExecution) println(df) println(df.collectAsList()) From 7f2b771f7fe36bf27c8cc7e680fdbff252bad3d1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 19:33:14 -0700 Subject: [PATCH 17/67] Fuzzer improvements. --- .../org/apache/spark/sql/DataFrameFuzzingSuite.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala index 341c344409157..74dc392bec6f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala @@ -66,6 +66,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite { whitelistedColumnTypes.exists { t => t =:= p.typeSignature.erasure } } } + .filterNot(_.name.toString == "drop") // since this can lead to a DataFrame with no columns .toSeq } @@ -75,7 +76,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite { val paramTypes = params.map(_.typeSignature) val paramValues = paramTypes.map { t => if (m.universe.typeOf[DataFrame] =:= t.erasure) { - df + randomChoice(Seq(df, generateRandomDataFrame())) } else if (m.universe.typeOf[Column] =:= t.erasure) { df.col(randomChoice(df.columns)) } else { @@ -83,6 +84,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite { } } val reflectedMethod: ru.MethodMirror = m.reflect(df).reflectMethod(method) + println("Applying method " + reflectedMethod) try { reflectedMethod.apply(paramValues: _*).asInstanceOf[DataFrame] } catch { @@ -91,9 +93,10 @@ class DataFrameFuzzingSuite extends SparkFunSuite { } } - for (_ <- 1 to 1000) { + for (_ <- 1 to 10000) { + println("-" * 80) try { - val df2 = applyRandomTransformationToDataFrame(df) + val df2 = applyRandomTransformationToDataFrame(applyRandomTransformationToDataFrame(df)) try { df2.collectAsList() } catch { From 326d759c0a407a2ca9ea8a946ff84764031929b3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 19:33:50 -0700 Subject: [PATCH 18/67] Fix SPARK-9293 --- .../sql/catalyst/analysis/CheckAnalysis.scala | 6 ++++ .../catalyst/analysis/HiveTypeCoercion.scala | 14 +++----- .../plans/logical/basicOperators.scala | 36 +++++++++---------- .../analysis/AnalysisErrorSuite.scala | 17 +++++++++ 4 files changed, 44 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c203fcecf20fb..4154777c8af1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -98,6 +98,12 @@ trait CheckAnalysis { aggregateExprs.foreach(checkValidAggregateExpression) + case s @ SetOperation(left, right) if left.output.length != right.output.length => + failAnalysis( + s"${s.nodeName} can only be performed on tables with the same number of columns, " + + s"but the left table has ${left.output.length} columns and the right has " + + s"${right.output.length}") + case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index d56ceeadc9e85..eef76615f1fe8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -181,6 +181,7 @@ object HiveTypeCoercion { planName: String, left: LogicalPlan, right: LogicalPlan): (LogicalPlan, LogicalPlan) = { + require(left.output.length == right.output.length) val castedTypes = left.output.zip(right.output).map { case (lhs, rhs) if lhs.dataType != rhs.dataType => @@ -218,15 +219,10 @@ object HiveTypeCoercion { } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case u @ Union(left, right) if u.childrenResolved && !u.resolved => - val (newLeft, newRight) = widenOutputTypes(u.nodeName, left, right) - Union(newLeft, newRight) - case e @ Except(left, right) if e.childrenResolved && !e.resolved => - val (newLeft, newRight) = widenOutputTypes(e.nodeName, left, right) - Except(newLeft, newRight) - case i @ Intersect(left, right) if i.childrenResolved && !i.resolved => - val (newLeft, newRight) = widenOutputTypes(i.nodeName, left, right) - Intersect(newLeft, newRight) + case s @ SetOperation(left, right) if s.childrenResolved + && left.output.length == right.output.length && !s.resolved => + val (newLeft, newRight) = widenOutputTypes(s.nodeName, left, right) + s.makeCopy(Array(newLeft, newRight)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 6aefa9f67556a..f2a49194de8d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -89,13 +89,21 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } -case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { +abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { // TODO: These aren't really the same attributes as nullability etc might change. - override def output: Seq[Attribute] = left.output + final override def output: Seq[Attribute] = left.output - override lazy val resolved: Boolean = + final override lazy val resolved: Boolean = childrenResolved && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } + left.output.length == right.output.length && + left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } +} + +private[sql] object SetOperation { + def unapply(p: SetOperation): Option[(LogicalPlan, LogicalPlan)] = Some((p.left, p.right)) +} + +case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) { override def statistics: Statistics = { val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes @@ -103,6 +111,10 @@ case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { } } +case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) + +case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) + case class Join( left: LogicalPlan, right: LogicalPlan, @@ -139,15 +151,6 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } - -case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def output: Seq[Attribute] = left.output - - override lazy val resolved: Boolean = - childrenResolved && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } -} - case class InsertIntoTable( table: LogicalPlan, partition: Map[String, Option[String]], @@ -440,10 +443,3 @@ case object OneRowRelation extends LeafNode { override def statistics: Statistics = Statistics(sizeInBytes = 1) } -case class Intersect(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def output: Seq[Attribute] = left.output - - override lazy val resolved: Boolean = - childrenResolved && - left.output.zip(right.output).forall { case (l, r) => l.dataType == r.dataType } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index dca8c881f21ab..bc2c696f33c14 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -147,6 +147,23 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { UnresolvedTestPlan(), "unresolved" :: Nil) + errorTest( + "union with unequal number of columns", + testRelation.unionAll(testRelation2), + "union" :: "number of columns" :: testRelation2.output.length.toString :: + testRelation.output.length.toString :: Nil) + + errorTest( + "intersect with unequal number of columns", + testRelation.intersect(testRelation2), + "intersect" :: "number of columns" :: testRelation2.output.length.toString :: + testRelation.output.length.toString :: Nil) + + errorTest( + "except with unequal number of columns", + testRelation.except(testRelation2), + "except" :: "number of columns" :: testRelation2.output.length.toString :: + testRelation.output.length.toString :: Nil) test("SPARK-6452 regression test") { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) From 37e4ce82807930f207066cca7efbd3195bd3d17e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 20:31:06 -0700 Subject: [PATCH 19/67] Support methods that take varargs Column parameters. --- .../org/apache/spark/sql/DataFrameFuzzingSuite.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala index 74dc392bec6f1..97b90d100827c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala @@ -48,8 +48,9 @@ class DataFrameFuzzingSuite extends SparkFunSuite { val m = ru.runtimeMirror(this.getClass.getClassLoader) - val whitelistedColumnTypes = Set( + val whitelistedParameterTypes = Set( m.universe.typeOf[DataFrame], + m.universe.typeOf[Seq[Column]], m.universe.typeOf[Column] ) @@ -63,7 +64,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite { .filterNot(_.isConstructor) .filter { m => m.paramss.flatten.forall { p => - whitelistedColumnTypes.exists { t => t =:= p.typeSignature.erasure } + whitelistedParameterTypes.exists { t => p.typeSignature <:< t } } } .filterNot(_.name.toString == "drop") // since this can lead to a DataFrame with no columns @@ -75,10 +76,12 @@ class DataFrameFuzzingSuite extends SparkFunSuite { val params = method.paramss.flatten // We don't use multiple parameter lists val paramTypes = params.map(_.typeSignature) val paramValues = paramTypes.map { t => - if (m.universe.typeOf[DataFrame] =:= t.erasure) { + if (t =:= m.universe.typeOf[DataFrame]) { randomChoice(Seq(df, generateRandomDataFrame())) - } else if (m.universe.typeOf[Column] =:= t.erasure) { + } else if (t =:= m.universe.typeOf[Column]) { df.col(randomChoice(df.columns)) + } else if (t <:< m.universe.typeOf[Seq[Column]]) { + Seq.fill(Random.nextInt(2) + 1)(df.col(randomChoice(df.columns))) } else { sys.error("ERROR!") } From 2f1b802839b88e3850d3333892fa23127b04486f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 21:36:04 -0700 Subject: [PATCH 20/67] Add analysis rule to detect sorting on unsupported column types (SPARK-9295) --- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 4 ++++ .../spark/sql/catalyst/analysis/AnalysisErrorSuite.scala | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c203fcecf20fb..bc57dc0d79f07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -98,6 +98,10 @@ trait CheckAnalysis { aggregateExprs.foreach(checkValidAggregateExpression) + case Sort(order, _, _) if !order.forall(_.dataType.isInstanceOf[AtomicType]) => + val c = order.filterNot(_.dataType.isInstanceOf[AtomicType]).head + failAnalysis(s"Sorting is not supported for columns of type ${c.dataType.simpleString}") + case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index dca8c881f21ab..83ae6f86ff0d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -113,6 +113,11 @@ class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter { testRelation.select(Literal(1).cast(BinaryType).as('badCast)), "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) + errorTest( + "sorting by unsupported column types", + listRelation.orderBy('list.asc), + "sorting" :: "type" :: "array" :: Nil) + errorTest( "non-boolean filters", testRelation.where(Literal(1)), From d7a35358e2068eca9bdead2b93f3b96dcaf890d8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 23:02:13 -0700 Subject: [PATCH 21/67] [SPARK-9303] Decimal should use java.math.Decimal directly instead of via Scala wrapper --- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../org/apache/spark/sql/types/Decimal.scala | 50 ++++++++++--------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index c66854d52c50b..d4e319845bf6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -192,7 +192,7 @@ case class Cast(child: Expression, dataType: DataType) } private[this] def decimalToTimestamp(d: Decimal): Long = { - (d.toBigDecimal * 1000000L).longValue() + d.toJavaBigDecimal.multiply(java.math.BigDecimal.valueOf(1000000L)).longValue() } private[this] def doubleToTimestamp(d: Double): Any = { if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index bc689810bc292..3e99d2999ca24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.math.{BigDecimal => JavaBigDecimal} + import org.apache.spark.annotation.DeveloperApi /** @@ -30,7 +32,7 @@ import org.apache.spark.annotation.DeveloperApi final class Decimal extends Ordered[Decimal] with Serializable { import org.apache.spark.sql.types.Decimal.{BIG_DEC_ZERO, MAX_LONG_DIGITS, POW_10, ROUNDING_MODE} - private var decimalVal: BigDecimal = null + private var decimalVal: JavaBigDecimal = null private var longVal: Long = 0L private var _precision: Int = 1 private var _scale: Int = 0 @@ -44,7 +46,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { def set(longVal: Long): Decimal = { if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) { // We can't represent this compactly as a long without risking overflow - this.decimalVal = BigDecimal(longVal) + this.decimalVal = new JavaBigDecimal(longVal) this.longVal = 0L } else { this.decimalVal = null @@ -86,7 +88,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (precision < 19) { return null // Requested precision is too low to represent this value } - this.decimalVal = BigDecimal(unscaled) + this.decimalVal = new JavaBigDecimal(unscaled) this.longVal = 0L } else { val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) @@ -105,7 +107,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { * Set this Decimal to the given BigDecimal value, with a given precision and scale. */ def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { - this.decimalVal = decimal.setScale(scale, ROUNDING_MODE) + this.decimalVal = decimal.setScale(scale, ROUNDING_MODE).underlying() require(decimalVal.precision <= precision, "Overflowed precision") this.longVal = 0L this._precision = precision @@ -117,7 +119,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { * Set this Decimal to the given BigDecimal value, inheriting its precision and scale. */ def set(decimal: BigDecimal): Decimal = { - this.decimalVal = decimal + this.decimalVal = decimal.underlying() this.longVal = 0L this._precision = decimal.precision this._scale = decimal.scale @@ -135,19 +137,19 @@ final class Decimal extends Ordered[Decimal] with Serializable { this } - def toBigDecimal: BigDecimal = { + def toBigDecimal: BigDecimal = BigDecimal(toJavaBigDecimal) + + def toJavaBigDecimal: JavaBigDecimal = { if (decimalVal.ne(null)) { decimalVal } else { - BigDecimal(longVal, _scale) + JavaBigDecimal.valueOf(longVal, _scale) } } - def toJavaBigDecimal: java.math.BigDecimal = toBigDecimal.underlying() - def toUnscaledLong: Long = { if (decimalVal.ne(null)) { - decimalVal.underlying().unscaledValue().longValue() + decimalVal.unscaledValue().longValue() } else { longVal } @@ -164,9 +166,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } - def toDouble: Double = toBigDecimal.doubleValue() + def toDouble: Double = toJavaBigDecimal.doubleValue() - def toFloat: Float = toBigDecimal.floatValue() + def toFloat: Float = toJavaBigDecimal.floatValue() def toLong: Long = { if (decimalVal.eq(null)) { @@ -208,7 +210,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { longVal *= POW_10(diff) } else { // Give up on using Longs; switch to BigDecimal, which we'll modify below - decimalVal = BigDecimal(longVal, _scale) + decimalVal = JavaBigDecimal.valueOf(longVal, _scale) } } // In both cases, we will check whether our precision is okay below @@ -217,7 +219,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.ne(null)) { // We get here if either we started with a BigDecimal, or we switched to one because we would // have overflowed our Long; in either case we must rescale decimalVal to the new scale. - val newVal = decimalVal.setScale(scale, ROUNDING_MODE) + val newVal = decimalVal.setScale(scale, ROUNDING_MODE.id) if (newVal.precision > precision) { return false } @@ -242,7 +244,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) { if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1 } else { - toBigDecimal.compare(other.toBigDecimal) + toJavaBigDecimal.compareTo(other.toJavaBigDecimal) } } @@ -253,27 +255,27 @@ final class Decimal extends Ordered[Decimal] with Serializable { false } - override def hashCode(): Int = toBigDecimal.hashCode() + override def hashCode(): Int = toJavaBigDecimal.hashCode() def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0 - def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal) + def + (that: Decimal): Decimal = Decimal(toJavaBigDecimal.add(that.toJavaBigDecimal)) - def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal) + def - (that: Decimal): Decimal = Decimal(toJavaBigDecimal.subtract(that.toJavaBigDecimal)) - def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) + def * (that: Decimal): Decimal = Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal)) def / (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal)) def % (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) + if (that.isZero) null else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal)) def remainder(that: Decimal): Decimal = this % that def unary_- : Decimal = { if (decimalVal.ne(null)) { - Decimal(-decimalVal) + Decimal(decimalVal.negate()) } else { Decimal(-longVal, precision, scale) } @@ -290,7 +292,7 @@ object Decimal { private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) - private val BIG_DEC_ZERO = BigDecimal(0) + private val BIG_DEC_ZERO: JavaBigDecimal = JavaBigDecimal.valueOf(0) def apply(value: Double): Decimal = new Decimal().set(value) @@ -300,7 +302,7 @@ object Decimal { def apply(value: BigDecimal): Decimal = new Decimal().set(value) - def apply(value: java.math.BigDecimal): Decimal = new Decimal().set(value) + def apply(value: JavaBigDecimal): Decimal = new Decimal().set(value) def apply(value: BigDecimal, precision: Int, scale: Int): Decimal = new Decimal().set(value, precision, scale) From bfe1451ec40bcf20f85c5b6fc7bcc1f23bdc6c91 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 23:18:17 -0700 Subject: [PATCH 22/67] Update to allow sorting by null literals --- .../spark/sql/catalyst/analysis/CheckAnalysis.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index bc57dc0d79f07..3a76609867a4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -98,9 +98,14 @@ trait CheckAnalysis { aggregateExprs.foreach(checkValidAggregateExpression) - case Sort(order, _, _) if !order.forall(_.dataType.isInstanceOf[AtomicType]) => - val c = order.filterNot(_.dataType.isInstanceOf[AtomicType]).head - failAnalysis(s"Sorting is not supported for columns of type ${c.dataType.simpleString}") + case Sort(orders, _, _) => + def checkValidSortOrder(order: SortOrder): Unit = order.dataType match { + case t: AtomicType => // OK + case NullType => // OK + case t => + failAnalysis(s"Sorting is not supported for columns of type ${t.simpleString}") + } + orders.foreach(checkValidSortOrder) case _ => // Fallbacks to the following checks } From 55221fa51136920a11da22690ae53c59c865a7a7 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 24 Jul 2015 23:17:43 +0800 Subject: [PATCH 23/67] Shouldn't use SortMergeJoin when joining on unsortable columns. --- .../sql/catalyst/planning/patterns.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 19 +++++++++++++++---- .../org/apache/spark/sql/JoinSuite.scala | 12 ++++++++++++ 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index b8e3b0d53a505..1e7b2a536ac12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -184,7 +184,7 @@ object PartialAggregation { * A pattern that finds joins with equality conditions that can be evaluated using equi-join. */ object ExtractEquiJoinKeys extends Logging with PredicateHelper { - /** (joinType, rightKeys, leftKeys, condition, leftChild, rightChild) */ + /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */ type ReturnType = (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index eb4be1900b153..5a0cb77061bc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -36,9 +36,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object LeftSemiJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => + case ExtractEquiJoinKeys( + LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => joins.BroadcastLeftSemiJoinHash( leftKeys, rightKeys, planLater(left), planLater(right), condition) :: Nil // Find left semi joins where at least some predicates can be evaluated by matching join keys @@ -91,6 +90,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil } + private[this] def isValidSort( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression]): Boolean = { + !leftKeys.zip(rightKeys).exists { keys => + (keys._1.dataType, keys._2.dataType) match { + case (l: AtomicType, r: AtomicType) => false + case (NullType, NullType) => false + case _ => true + } + } + } + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) @@ -101,7 +112,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // If the sort merge join option is set, we want to use sort merge join prior to hashjoin // for now let's support inner join first, then add outer join case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.sortMergeJoinEnabled => + if sqlContext.conf.sortMergeJoinEnabled && isValidSort(leftKeys, rightKeys) => val mergeJoin = joins.SortMergeJoin(leftKeys, rightKeys, planLater(left), planLater(right)) condition.map(Filter(_, mergeJoin)).getOrElse(mergeJoin) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 8953889d1fae9..dfb2a7e099748 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -108,6 +108,18 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } } + test("SortMergeJoin shouldn't work on unsortable columns") { + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled + try { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + Seq( + ("SELECT * FROM arrayData JOIN complexData ON data = a", classOf[ShuffledHashJoin]) + ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + } finally { + ctx.conf.setConf(SQLConf.SORTMERGE_JOIN, SORTMERGEJOIN_ENABLED) + } + } + test("broadcasted hash join operator selection") { ctx.cacheManager.clearCache() ctx.sql("CACHE TABLE testData") From a2407074dc34672f71c33d671d285809969bfc78 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 24 Jul 2015 23:58:26 +0800 Subject: [PATCH 24/67] Use forall instead of exists for readability. --- .../org/apache/spark/sql/execution/SparkStrategies.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5a0cb77061bc0..8116abec46a98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -93,11 +93,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { private[this] def isValidSort( leftKeys: Seq[Expression], rightKeys: Seq[Expression]): Boolean = { - !leftKeys.zip(rightKeys).exists { keys => + leftKeys.zip(rightKeys).forall { keys => (keys._1.dataType, keys._2.dataType) match { - case (l: AtomicType, r: AtomicType) => false - case (NullType, NullType) => false - case _ => true + case (l: AtomicType, r: AtomicType) => true + case (NullType, NullType) => true + case _ => false } } } From 68c0e972e21779d70ff5bf76f691a0e64a974e39 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 Jul 2015 16:55:52 -0700 Subject: [PATCH 25/67] Commit some outstanding changes. --- .../spark/sql/DataFrameFuzzingSuite.scala | 84 +++++++++++++++---- 1 file changed, 66 insertions(+), 18 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala index 97b90d100827c..66cee60481fa3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala @@ -1,12 +1,15 @@ package org.apache.spark.sql +import java.io.File import java.lang.reflect.InvocationTargetException +import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.util.Utils import scala.reflect.runtime.{universe => ru} -import scala.util.Random +import scala.util.{Try, Random} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ @@ -20,6 +23,8 @@ import scala.util.control.NonFatal */ class DataFrameFuzzingSuite extends SparkFunSuite { + val tempDir = Utils.createTempDir() + def randomChoice[T](values: Seq[T]): T = { values(Random.nextInt(values.length)) } @@ -28,20 +33,29 @@ class DataFrameFuzzingSuite extends SparkFunSuite { classOf[String] -> (() => Random.nextString(10)) ) - def generateRandomDataFrame(): DataFrame = { - val allTypes = DataTypeTestUtils.atomicTypes - .filterNot(_.isInstanceOf[DecimalType]) // casts can lead to OOM - .filterNot(_.isInstanceOf[BinaryType]) // leads to spurious errors in string reverse - val dataTypesWithGenerators = allTypes.filter { dt => + val allTypes = DataTypeTestUtils.atomicTypes + //.filterNot(_.isInstanceOf[DecimalType]) // casts can lead to OOM + .filterNot(_.isInstanceOf[BinaryType]) // leads to spurious errors in string reverse + val dataTypesWithGenerators = allTypes.filter { dt => RandomDataGenerator.forType(dt, nullable = true, seed = None).isDefined } - def randomType(): DataType = randomChoice(dataTypesWithGenerators.toSeq) + def randomType(): DataType = randomChoice(dataTypesWithGenerators.toSeq) + + def generateRandomSchema(): StructType = { val numColumns = 1 + Random.nextInt(3) - val schema = - new StructType((1 to numColumns).map(i => new StructField(s"c$i", randomType())).toArray) + val r = Random.nextString(1) + new StructType((1 to numColumns).map(i => new StructField(s"c$i$r", randomType())).toArray) + } + + def generateRandomDataFrame(): DataFrame = { + val schema = generateRandomSchema() val rowGenerator = RandomDataGenerator.forType(schema, nullable = false).get val rows: Seq[Row] = Seq.fill(10)(rowGenerator().asInstanceOf[Row]) - TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(rows), schema) + val df = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(rows), schema) + val path = new File(tempDir, Random.nextInt(1000000).toString).getAbsolutePath + df.write.json(path) + TestSQLContext.read.json(path) + df } val df = generateRandomDataFrame() @@ -51,7 +65,9 @@ class DataFrameFuzzingSuite extends SparkFunSuite { val whitelistedParameterTypes = Set( m.universe.typeOf[DataFrame], m.universe.typeOf[Seq[Column]], - m.universe.typeOf[Column] + m.universe.typeOf[Column], + m.universe.typeOf[String], + m.universe.typeOf[Seq[String]] ) val dataFrameTransformations = { @@ -68,26 +84,46 @@ class DataFrameFuzzingSuite extends SparkFunSuite { } } .filterNot(_.name.toString == "drop") // since this can lead to a DataFrame with no columns + .filterNot(_.name.toString == "describe") // since we cannot run all queries on describe output + .filterNot(_.name.toString == "dropDuplicates") + .filter(_.name.toString == "join") .toSeq } + def getRandomColumnName(df: DataFrame): String = { + randomChoice(df.columns.zip(df.schema).map { case (colName, field) => + field.dataType match { + case StructType(fields) => + colName + "." + randomChoice(fields.map(_.name)) + case _ => colName + } + }) + } + def applyRandomTransformationToDataFrame(df: DataFrame): DataFrame = { val method = randomChoice(dataFrameTransformations) val params = method.paramss.flatten // We don't use multiple parameter lists val paramTypes = params.map(_.typeSignature) val paramValues = paramTypes.map { t => if (t =:= m.universe.typeOf[DataFrame]) { - randomChoice(Seq(df, generateRandomDataFrame())) + randomChoice(Seq( + df, + generateRandomDataFrame() + )) // ++ Try(applyRandomTransformationToDataFrame(df)).toOption.toSeq) } else if (t =:= m.universe.typeOf[Column]) { - df.col(randomChoice(df.columns)) + df.col(getRandomColumnName(df)) + } else if (t =:= m.universe.typeOf[String]) { + getRandomColumnName(df) } else if (t <:< m.universe.typeOf[Seq[Column]]) { - Seq.fill(Random.nextInt(2) + 1)(df.col(randomChoice(df.columns))) + Seq.fill(Random.nextInt(2) + 1)(df.col(getRandomColumnName(df))) + } else if (t <:< m.universe.typeOf[Seq[String]]) { + Seq.fill(Random.nextInt(2) + 1)(getRandomColumnName(df)) } else { sys.error("ERROR!") } } val reflectedMethod: ru.MethodMirror = m.reflect(df).reflectMethod(method) - println("Applying method " + reflectedMethod) + println("Applying method " + method + " with values " + paramValues) try { reflectedMethod.apply(paramValues: _*).asInstanceOf[DataFrame] } catch { @@ -96,6 +132,14 @@ class DataFrameFuzzingSuite extends SparkFunSuite { } } + //TestSQLContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) +// TestSQLContext.conf.setConf(SQLConf.UNSAFE_ENABLED, true) + TestSQLContext.conf.setConf(SQLConf.SORTMERGE_JOIN, true) + TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) + + TestSQLContext.conf.setConf(SQLConf.SHUFFLE_PARTITIONS, 10) + + for (_ <- 1 to 10000) { println("-" * 80) try { @@ -104,14 +148,18 @@ class DataFrameFuzzingSuite extends SparkFunSuite { df2.collectAsList() } catch { case NonFatal(e) => - println(df2.logicalPlan) println(df2.queryExecution) println(df) println(df.collectAsList()) - throw e + throw new Exception(e) } } catch { - case e: AnalysisException => null + case e: UnresolvedException[_] => + println("skipped due to unresolved") + case e: AnalysisException => + println("Skipped") + case e: IllegalArgumentException => + println("Skipped") } } From 2d4ed76faa34e81ba32059a0b5040e4b9a8c8517 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 Jul 2015 16:56:55 -0700 Subject: [PATCH 26/67] Move to fuzzing package. --- .../{ => fuzzing}/DataFrameFuzzingSuite.scala | 29 ++++++++++++++----- .../ExpressionFuzzingSuite.scala | 16 +++++----- 2 files changed, 30 insertions(+), 15 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/{ => fuzzing}/DataFrameFuzzingSuite.scala (86%) rename sql/core/src/test/scala/org/apache/spark/sql/{ => fuzzing}/ExpressionFuzzingSuite.scala (97%) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala similarity index 86% rename from sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala index 66cee60481fa3..3edda0853b0d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala @@ -1,19 +1,34 @@ -package org.apache.spark.sql +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.fuzzing import java.io.File import java.lang.reflect.InvocationTargetException +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import scala.reflect.runtime.{universe => ru} - -import scala.util.{Try, Random} - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types._ - +import scala.util.Random import scala.util.control.NonFatal /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala index 530795fef2a38..4cf2ebbe268f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExpressionFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala @@ -15,23 +15,23 @@ * limitations under the License. */ -package org.apache.spark.sql +package org.apache.spark.sql.fuzzing import java.io.File import java.lang.reflect.Constructor -import scala.util.{Try, Random} - -import org.clapper.classutil.ClassFinder - -import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.types.{BinaryType, DecimalType, DataType, DataTypeTestUtils} +import org.apache.spark.sql.types.{BinaryType, DataType, DataTypeTestUtils, DecimalType} +import org.apache.spark.{Logging, SparkFunSuite} +import org.clapper.classutil.ClassFinder + +import scala.util.{Random, Try} /** * This test suite implements fuzz tests for expression code generation. It uses reflection to From ac8dd74e635fb088f02cb5aee2a2856ef589af5e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 24 Jul 2015 17:48:30 -0700 Subject: [PATCH 27/67] Begin to clean up random DF generator --- .../sql/fuzzing/DataFrameFuzzingSuite.scala | 123 ++++++++++++------ 1 file changed, 85 insertions(+), 38 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala index 3edda0853b0d9..33f8e33a4d960 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.fuzzing -import java.io.File import java.lang.reflect.InvocationTargetException +import java.util.concurrent.atomic.AtomicInteger + +import scala.reflect.runtime.{universe => ru} +import scala.util.Random +import scala.util.control.NonFatal import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ @@ -27,9 +31,75 @@ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -import scala.reflect.runtime.{universe => ru} -import scala.util.Random -import scala.util.control.NonFatal +class RandomDataFrameGenerator(seed: Long, sqlContext: SQLContext) { + + private val rand = new Random(seed) + private val nextId = new AtomicInteger() + + private def hasRandomDataGenerator(dataType: DataType): Boolean = { + RandomDataGenerator.forType(dataType).isDefined + } + + def randomChoice[T](values: Seq[T]): T = { + values(rand.nextInt(values.length)) + } + + private val simpleTypes: Set[DataType] = { + DataTypeTestUtils.atomicTypes + .filter(hasRandomDataGenerator) + // Ignore decimal type since it can lead to OOM (see SPARK-9303). TODO: It would be better to + // only generate limited precision decimals instead. + .filterNot(_.isInstanceOf[DecimalType]) + } + + private val arrayTypes: Set[DataType] = { + DataTypeTestUtils.atomicArrayTypes + .filter(hasRandomDataGenerator) + // See above comment about DecimalType + .filterNot(_.elementType.isInstanceOf[DecimalType]).toSet + } + + private def randomStructField( + allowComplexTypes: Boolean = false, + allowSpacesInColumnName: Boolean = false): StructField = { + val name = "c" + nextId.getAndIncrement + (if (allowSpacesInColumnName) " space" else "") + val candidateTypes: Seq[DataType] = Seq( + simpleTypes, + arrayTypes.filter(_ => allowComplexTypes), + // This does not allow complex types, limiting the depth of recursion: + if (allowComplexTypes) { + Set[DataType](randomStructType(numCols = rand.nextInt(2) + 1)) + } else { + Set[DataType]() + } + ).flatten + val dataType = randomChoice(candidateTypes) + val nullable = rand.nextBoolean() + StructField(name, dataType, nullable) + } + + private def randomStructType( + numCols: Int, + allowComplexTypes: Boolean = false, + allowSpacesInColumnNames: Boolean = false): StructType = { + StructType(Array.fill(numCols)(randomStructField(allowComplexTypes, allowSpacesInColumnNames))) + } + + def randomDataFrame( + numCols: Int, + numRows: Int, + allowComplexTypes: Boolean = false, + allowSpacesInColumnNames: Boolean = false): DataFrame = { + val schema = randomStructType(numCols, allowComplexTypes, allowSpacesInColumnNames) + val rows = sqlContext.sparkContext.parallelize(1 to numRows).mapPartitions { iter => + val rowGenerator = RandomDataGenerator.forType(schema, nullable = false, seed = Some(42)).get + iter.map(_ => rowGenerator().asInstanceOf[Row]) + } + sqlContext.createDataFrame(rows, schema) + } + +} + /** * This test suite generates random data frames, then applies random sequences of operations to @@ -40,40 +110,12 @@ class DataFrameFuzzingSuite extends SparkFunSuite { val tempDir = Utils.createTempDir() + private val dataGenerator = new RandomDataFrameGenerator(123, TestSQLContext) + def randomChoice[T](values: Seq[T]): T = { values(Random.nextInt(values.length)) } - val randomValueGenerators: Map[Class[_], () => Any] = Map( - classOf[String] -> (() => Random.nextString(10)) - ) - - val allTypes = DataTypeTestUtils.atomicTypes - //.filterNot(_.isInstanceOf[DecimalType]) // casts can lead to OOM - .filterNot(_.isInstanceOf[BinaryType]) // leads to spurious errors in string reverse - val dataTypesWithGenerators = allTypes.filter { dt => - RandomDataGenerator.forType(dt, nullable = true, seed = None).isDefined - } - def randomType(): DataType = randomChoice(dataTypesWithGenerators.toSeq) - - def generateRandomSchema(): StructType = { - val numColumns = 1 + Random.nextInt(3) - val r = Random.nextString(1) - new StructType((1 to numColumns).map(i => new StructField(s"c$i$r", randomType())).toArray) - } - - def generateRandomDataFrame(): DataFrame = { - val schema = generateRandomSchema() - val rowGenerator = RandomDataGenerator.forType(schema, nullable = false).get - val rows: Seq[Row] = Seq.fill(10)(rowGenerator().asInstanceOf[Row]) - val df = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(rows), schema) - val path = new File(tempDir, Random.nextInt(1000000).toString).getAbsolutePath - df.write.json(path) - TestSQLContext.read.json(path) - df - } - - val df = generateRandomDataFrame() val m = ru.runtimeMirror(this.getClass.getClassLoader) @@ -101,7 +143,6 @@ class DataFrameFuzzingSuite extends SparkFunSuite { .filterNot(_.name.toString == "drop") // since this can lead to a DataFrame with no columns .filterNot(_.name.toString == "describe") // since we cannot run all queries on describe output .filterNot(_.name.toString == "dropDuplicates") - .filter(_.name.toString == "join") .toSeq } @@ -123,7 +164,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite { if (t =:= m.universe.typeOf[DataFrame]) { randomChoice(Seq( df, - generateRandomDataFrame() + dataGenerator.randomDataFrame(numCols = Random.nextInt(4) + 1, numRows = 100) )) // ++ Try(applyRandomTransformationToDataFrame(df)).toOption.toSeq) } else if (t =:= m.universe.typeOf[Column]) { df.col(getRandomColumnName(df)) @@ -158,6 +199,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite { for (_ <- 1 to 10000) { println("-" * 80) try { + val df = dataGenerator.randomDataFrame(numCols = Random.nextInt(4) + 1, numRows = 20) val df2 = applyRandomTransformationToDataFrame(applyRandomTransformationToDataFrame(df)) try { df2.collectAsList() @@ -173,8 +215,13 @@ class DataFrameFuzzingSuite extends SparkFunSuite { println("skipped due to unresolved") case e: AnalysisException => println("Skipped") - case e: IllegalArgumentException => - println("Skipped") + case e: IllegalArgumentException if e.getMessage.contains("number of columns doesn't match") => + case e: IllegalArgumentException if e.getMessage.contains("Unsupported join type") => + + + // case e: IllegalArgumentException => +// println(e) +// println("Skipped due to IOE") } } From 0b3938b1662dd889390429fadb80a59eaf58fbc0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 26 Jul 2015 13:02:38 -0700 Subject: [PATCH 28/67] Add basic backtracking to improve chance of generating executable plan. --- .../sql/fuzzing/DataFrameFuzzingSuite.scala | 137 +++++++++++++----- 1 file changed, 100 insertions(+), 37 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala index 33f8e33a4d960..464c77b3196d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.fuzzing import java.lang.reflect.InvocationTargetException import java.util.concurrent.atomic.AtomicInteger +import scala.reflect.runtime import scala.reflect.runtime.{universe => ru} import scala.util.Random import scala.util.control.NonFatal @@ -127,7 +128,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite { m.universe.typeOf[Seq[String]] ) - val dataFrameTransformations = { + val dataFrameTransformations: Seq[ru.MethodSymbol] = { val dfType = m.universe.typeOf[DataFrame] dfType.members .filter(_.isPublic) @@ -146,60 +147,125 @@ class DataFrameFuzzingSuite extends SparkFunSuite { .toSeq } - def getRandomColumnName(df: DataFrame): String = { - randomChoice(df.columns.zip(df.schema).map { case (colName, field) => - field.dataType match { - case StructType(fields) => - colName + "." + randomChoice(fields.map(_.name)) - case _ => colName + /** + * Build a list of column names and types for the given StructType, taking nesting into account. + * For nested struct fields, this will emit both the column for the struct field itself as well as + * fields for the nested struct's fields. This process will be performed recursively in order to + * handle deeply-nested structs. + */ + def getColumnsAndTypes(struct: StructType): Seq[(String, DataType)] = { + struct.flatMap { field => + val nestedFieldInfos: Seq[(String, DataType)] = field.dataType match { + case nestedStruct: StructType => + Seq((field.name, field.dataType)) ++ getColumnsAndTypes(nestedStruct).map { + case (nestedColName, dataType) => (field.name + "." + nestedColName, dataType) + } + case _ => Seq.empty } - }) + Seq((field.name, field.dataType)) ++ nestedFieldInfos + } } - def applyRandomTransformationToDataFrame(df: DataFrame): DataFrame = { - val method = randomChoice(dataFrameTransformations) + def getRandomColumnName( + df: DataFrame, + condition: DataType => Boolean = _ => true): Option[String] = { + val columnsWithTypes = getColumnsAndTypes(df.schema) + val candidateColumns = columnsWithTypes.filter(c => condition(c._2)) + if (candidateColumns.isEmpty) { + None + } else { + Some(randomChoice(candidateColumns)._1) + } + } + + class NoDataGeneratorException extends Exception + + def getParamValues( + df: DataFrame, + method: ru.MethodSymbol, + typeConstraint: DataType => Boolean = _ => true): Seq[Any] = { val params = method.paramss.flatten // We don't use multiple parameter lists val paramTypes = params.map(_.typeSignature) - val paramValues = paramTypes.map { t => - if (t =:= m.universe.typeOf[DataFrame]) { + def randColName(): String = + getRandomColumnName(df, typeConstraint).getOrElse(throw new NoDataGeneratorException) + paramTypes.map { t => + if (t =:= ru.typeOf[DataFrame]) { randomChoice(Seq( df, + applyRandomTransformationToDataFrame(df), dataGenerator.randomDataFrame(numCols = Random.nextInt(4) + 1, numRows = 100) )) // ++ Try(applyRandomTransformationToDataFrame(df)).toOption.toSeq) - } else if (t =:= m.universe.typeOf[Column]) { - df.col(getRandomColumnName(df)) - } else if (t =:= m.universe.typeOf[String]) { - getRandomColumnName(df) - } else if (t <:< m.universe.typeOf[Seq[Column]]) { - Seq.fill(Random.nextInt(2) + 1)(df.col(getRandomColumnName(df))) - } else if (t <:< m.universe.typeOf[Seq[String]]) { - Seq.fill(Random.nextInt(2) + 1)(getRandomColumnName(df)) + } else if (t =:= ru.typeOf[Column]) { + df.col(randColName()) + } else if (t =:= ru.typeOf[String]) { + randColName() + } else if (t <:< ru.typeOf[Seq[Column]]) { + Seq.fill(Random.nextInt(2) + 1)(df.col(randColName())) + } else if (t <:< ru.typeOf[Seq[String]]) { + Seq.fill(Random.nextInt(2) + 1)(randColName()) } else { sys.error("ERROR!") } } + } + + def applyRandomTransformationToDataFrame(df: DataFrame): DataFrame = { + val method = randomChoice(dataFrameTransformations) val reflectedMethod: ru.MethodMirror = m.reflect(df).reflectMethod(method) - println("Applying method " + method + " with values " + paramValues) + def callMethod(paramValues: Seq[Any]): DataFrame = { + try { + val df2 = reflectedMethod.apply(paramValues: _*).asInstanceOf[DataFrame] + println("Applied method " + method + " with values " + paramValues) + df2 + } catch { + case e: InvocationTargetException => + throw e.getCause + } + } try { - reflectedMethod.apply(paramValues: _*).asInstanceOf[DataFrame] + val paramValues = getParamValues(df, method) + try { + callMethod(paramValues) + } catch { + case NonFatal(e) => + println(s"Encountered error when calling $method with values $paramValues") + throw e + } } catch { - case e: InvocationTargetException => - throw e.getCause + case e: AnalysisException if e.getMessage.contains("is not a boolean") => + callMethod(getParamValues(df, method, _ == BooleanType)) + case e: AnalysisException if e.getMessage.contains("is not supported for columns of type") => + callMethod(getParamValues(df, method, _.isInstanceOf[AtomicType])) } } //TestSQLContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) -// TestSQLContext.conf.setConf(SQLConf.UNSAFE_ENABLED, true) + TestSQLContext.conf.setConf(SQLConf.UNSAFE_ENABLED, false) TestSQLContext.conf.setConf(SQLConf.SORTMERGE_JOIN, true) - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) + TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, false) TestSQLContext.conf.setConf(SQLConf.SHUFFLE_PARTITIONS, 10) + val ignoredAnalysisExceptionMessages = Seq( + "can only be performed on tables with the same number of columns", + "number of columns doesn't match", + "unsupported join type", + "is neither present in the group by, nor is it an aggregate function", + "is ambiguous, could be:", + "unresolved operator 'Project", //TODO + "unresolved operator 'Union", // TODO: disabled to let me find new errors + "unresolved operator 'Except", // TODO: disabled to let me find new errors + "unresolved operator 'Intersect", // TODO: disabled to let me find new errors + "Cannot resolve column name" // TODO: only ignore for join? + ) - for (_ <- 1 to 10000) { + for (_ <- 1 to 1000) { println("-" * 80) try { - val df = dataGenerator.randomDataFrame(numCols = Random.nextInt(4) + 1, numRows = 20) + val df = dataGenerator.randomDataFrame( + numCols = Random.nextInt(4) + 1, + numRows = 20, + allowComplexTypes = true) val df2 = applyRandomTransformationToDataFrame(applyRandomTransformationToDataFrame(df)) try { df2.collectAsList() @@ -211,17 +277,14 @@ class DataFrameFuzzingSuite extends SparkFunSuite { throw new Exception(e) } } catch { + case e: NoDataGeneratorException => + println("skipped due to lack of data generator") case e: UnresolvedException[_] => println("skipped due to unresolved") - case e: AnalysisException => - println("Skipped") - case e: IllegalArgumentException if e.getMessage.contains("number of columns doesn't match") => - case e: IllegalArgumentException if e.getMessage.contains("Unsupported join type") => - - - // case e: IllegalArgumentException => -// println(e) -// println("Skipped due to IOE") + case e: Exception + if ignoredAnalysisExceptionMessages.exists { + m => e.getMessage.toLowerCase.contains(m.toLowerCase) + } => println("Skipped due to expected AnalysisException") } } From c836884953241901f177931a101ca423b57f51ee Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 26 Jul 2015 13:09:11 -0700 Subject: [PATCH 29/67] Hacky approach to try to execute child plans first. --- .../sql/fuzzing/DataFrameFuzzingSuite.scala | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala index 464c77b3196d6..4e68522cb33cd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala @@ -192,7 +192,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite { if (t =:= ru.typeOf[DataFrame]) { randomChoice(Seq( df, - applyRandomTransformationToDataFrame(df), + tryToExecute(applyRandomTransformationToDataFrame(df)), dataGenerator.randomDataFrame(numCols = Random.nextInt(4) + 1, numRows = 100) )) // ++ Try(applyRandomTransformationToDataFrame(df)).toOption.toSeq) } else if (t =:= ru.typeOf[Column]) { @@ -215,7 +215,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite { def callMethod(paramValues: Seq[Any]): DataFrame = { try { val df2 = reflectedMethod.apply(paramValues: _*).asInstanceOf[DataFrame] - println("Applied method " + method + " with values " + paramValues) + println(s"Applied method $method with values $paramValues") df2 } catch { case e: InvocationTargetException => @@ -239,10 +239,20 @@ class DataFrameFuzzingSuite extends SparkFunSuite { } } + def tryToExecute(df: DataFrame): DataFrame = { + try { + df.collectAsList() + df + } catch { + case NonFatal(e) => + println(df.queryExecution) + throw new Exception(e) + } + } //TestSQLContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) - TestSQLContext.conf.setConf(SQLConf.UNSAFE_ENABLED, false) + TestSQLContext.conf.setConf(SQLConf.UNSAFE_ENABLED, true) TestSQLContext.conf.setConf(SQLConf.SORTMERGE_JOIN, true) - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, false) + TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) TestSQLContext.conf.setConf(SQLConf.SHUFFLE_PARTITIONS, 10) @@ -266,16 +276,8 @@ class DataFrameFuzzingSuite extends SparkFunSuite { numCols = Random.nextInt(4) + 1, numRows = 20, allowComplexTypes = true) - val df2 = applyRandomTransformationToDataFrame(applyRandomTransformationToDataFrame(df)) - try { - df2.collectAsList() - } catch { - case NonFatal(e) => - println(df2.queryExecution) - println(df) - println(df.collectAsList()) - throw new Exception(e) - } + val df1 = tryToExecute(applyRandomTransformationToDataFrame(df)) + val df2 = tryToExecute(applyRandomTransformationToDataFrame(df1)) } catch { case e: NoDataGeneratorException => println("skipped due to lack of data generator") From 396c2351c64cb1a5fccfe5b61be7774a4c123e6e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 20 Jul 2015 18:55:45 -0700 Subject: [PATCH 30/67] Enable Unsafe by default --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 2a641b9d64a95..8a44777c4dc53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -229,7 +229,7 @@ private[spark] object SQLConf { " a specific query.") val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(false), + defaultValue = Some(true), doc = "When true, use the new optimized Tungsten physical execution backend.") val DIALECT = stringConf( From 71fe0bb6269ffabe8ad3a28b657e02b183796588 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 21 Jul 2015 00:45:18 -0700 Subject: [PATCH 31/67] Ignore failing ScalaUDFSuite test. --- sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index c1516b450cbd4..3c171f34cc593 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -130,7 +130,8 @@ class UDFSuite extends QueryTest { assert(result.count() === 2) } - test("UDFs everywhere") { + // Temporarily ignored until we implement code generation for ScalaUDF. + ignore("UDFs everywhere") { ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) ctx.udf.register("havingFilter", (n: Long) => { n > 2000 }) ctx.udf.register("whereFilter", (n: Int) => { n < 150 }) From 04fbe3e2185f53dc71bed992b297e46d37300c00 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 21 Jul 2015 12:00:43 -0700 Subject: [PATCH 32/67] Do not use UnsafeExternalSort operator if codegen is disabled --- .../org/apache/spark/sql/execution/SparkStrategies.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index eb4be1900b153..f5558bac81691 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -380,7 +380,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { * if necessary. */ def getSortOperator(sortExprs: Seq[SortOrder], global: Boolean, child: SparkPlan): SparkPlan = { - if (sqlContext.conf.unsafeEnabled && UnsafeExternalSort.supportsSchema(child.schema)) { + if (sqlContext.conf.unsafeEnabled + && sqlContext.conf.codegenEnabled + && UnsafeExternalSort.supportsSchema(child.schema)) { execution.UnsafeExternalSort(sortExprs, global, child) } else if (sqlContext.conf.externalSortEnabled) { execution.ExternalSort(sortExprs, global, child) From a7979dcd5f60b40a3be2e55449e86353c6ff18a3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 21 Jul 2015 15:32:42 -0700 Subject: [PATCH 33/67] Disable unsafe Exchange path when RangePartitioning is used --- .../scala/org/apache/spark/sql/execution/Exchange.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 41a0c519ba527..70e5031fb63c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -47,7 +47,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = { + // Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to + // an interpreted RowOrdering being applied to an UnsafeRow, which will lead to + // ClassCastExceptions at runtime. This check can be removed after SPARK-9054 is fixed. + !newPartitioning.isInstanceOf[RangePartitioning] + } /** * Determines whether records must be defensively copied before being sent to the shuffle. From 4fcae4a816b40155702ad11088709d89e2a2e41a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 21 Jul 2015 16:10:15 -0700 Subject: [PATCH 34/67] Reduce page size to make HiveCompatibilitySuite pass. --- .../spark/util/collection/unsafe/sort/UnsafeExternalSorter.java | 2 +- .../main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 4d6731ee60af3..3b6db2080aec9 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -41,7 +41,7 @@ public final class UnsafeExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); - private static final int PAGE_SIZE = 1 << 27; // 128 megabytes + private static final int PAGE_SIZE = 1 << 22; // 4 megabytes @VisibleForTesting static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index d0bde69cc1068..3c56b321a0caf 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -78,7 +78,7 @@ public final class BytesToBytesMap { * The size of the data pages that hold key and value data. Map entries cannot span multiple * pages, so this limits the maximum entry size. */ - private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + private static final long PAGE_SIZE_BYTES = 1L << 22; // 4 megabytes /** * The maximum number of keys that BytesToBytesMap supports. The hash table has to be From 7e49d0e56bc46c109b0ba15a04d6894e4e696dea Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 26 Jul 2015 15:23:32 -0700 Subject: [PATCH 35/67] Fix use-after-free bug in UnsafeExternalSorter. --- .../spark/sql/execution/UnsafeExternalRowSorter.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index be4ff400c4754..811d595a64ecd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -141,10 +141,13 @@ public InternalRow next() { numFields, sortedIterator.getRecordLength()); if (!hasNext()) { - row.copy(); // so that we don't have dangling pointers to freed page + UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page + row.pointTo(null, 0, 0, 0); // so that we don't keep references to the base object cleanupResources(); + return copy; + } else { + return row; } - return row; } catch (IOException e) { cleanupResources(); // Scala iterators don't declare any checked exceptions, so we need to use this hack From 454c921575bccb235dc0ee63a647b5b70aa65ce6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 26 Jul 2015 16:30:52 -0700 Subject: [PATCH 36/67] Hack to enable join types to be tested --- .../spark/sql/catalyst/plans/joinTypes.scala | 17 +++++++------- .../sql/fuzzing/DataFrameFuzzingSuite.scala | 22 +++++++++++++------ 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 77dec7ca6e2b5..149e2410d3023 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -18,6 +18,14 @@ package org.apache.spark.sql.catalyst.plans object JoinType { + + val supportedJoinTypes = Seq( + "inner", + "outer", "full", "fullouter", + "leftouter", "left", + "rightouter", "right", + "leftsemi") + def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match { case "inner" => Inner case "outer" | "full" | "fullouter" => FullOuter @@ -25,15 +33,8 @@ object JoinType { case "rightouter" | "right" => RightOuter case "leftsemi" => LeftSemi case _ => - val supported = Seq( - "inner", - "outer", "full", "fullouter", - "leftouter", "left", - "rightouter", "right", - "leftsemi") - throw new IllegalArgumentException(s"Unsupported join type '$typ'. " + - "Supported join types include: " + supported.mkString("'", "', '", "'") + ".") + "Supported join types include: " + supportedJoinTypes.mkString("'", "', '", "'") + ".") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala index 4e68522cb33cd..b6b06c5f0b9b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.fuzzing import java.lang.reflect.InvocationTargetException import java.util.concurrent.atomic.AtomicInteger +import org.apache.spark.sql.catalyst.plans.JoinType + import scala.reflect.runtime import scala.reflect.runtime.{universe => ru} import scala.util.Random @@ -144,6 +146,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite { .filterNot(_.name.toString == "drop") // since this can lead to a DataFrame with no columns .filterNot(_.name.toString == "describe") // since we cannot run all queries on describe output .filterNot(_.name.toString == "dropDuplicates") + .filter(_.name.toString == "join") .toSeq } @@ -185,20 +188,24 @@ class DataFrameFuzzingSuite extends SparkFunSuite { method: ru.MethodSymbol, typeConstraint: DataType => Boolean = _ => true): Seq[Any] = { val params = method.paramss.flatten // We don't use multiple parameter lists - val paramTypes = params.map(_.typeSignature) def randColName(): String = getRandomColumnName(df, typeConstraint).getOrElse(throw new NoDataGeneratorException) - paramTypes.map { t => + params.map { p => + val t = p.typeSignature if (t =:= ru.typeOf[DataFrame]) { randomChoice(Seq( df, - tryToExecute(applyRandomTransformationToDataFrame(df)), + //tryToExecute(applyRandomTransformationToDataFrame(df)), dataGenerator.randomDataFrame(numCols = Random.nextInt(4) + 1, numRows = 100) )) // ++ Try(applyRandomTransformationToDataFrame(df)).toOption.toSeq) } else if (t =:= ru.typeOf[Column]) { df.col(randColName()) } else if (t =:= ru.typeOf[String]) { - randColName() + if (p.name == "joinType") { + randomChoice(JoinType.supportedJoinTypes) + } else { + randColName() + } } else if (t <:< ru.typeOf[Seq[Column]]) { Seq.fill(Random.nextInt(2) + 1)(df.col(randColName())) } else if (t <:< ru.typeOf[Seq[String]]) { @@ -229,6 +236,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite { } catch { case NonFatal(e) => println(s"Encountered error when calling $method with values $paramValues") + println(df.queryExecution) throw e } } catch { @@ -241,7 +249,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite { def tryToExecute(df: DataFrame): DataFrame = { try { - df.collectAsList() + df.rdd.count() df } catch { case NonFatal(e) => @@ -273,9 +281,9 @@ class DataFrameFuzzingSuite extends SparkFunSuite { println("-" * 80) try { val df = dataGenerator.randomDataFrame( - numCols = Random.nextInt(4) + 1, + numCols = Random.nextInt(2) + 1, numRows = 20, - allowComplexTypes = true) + allowComplexTypes = false) val df1 = tryToExecute(applyRandomTransformationToDataFrame(df)) val df2 = tryToExecute(applyRandomTransformationToDataFrame(df1)) } catch { From 11f80a348f02f6592f944e313005bcae9e803f05 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 26 Jul 2015 20:23:19 -0700 Subject: [PATCH 37/67] [SPARK-9368][SQL] Support get(ordinal, dataType) generic getter in UnsafeRow. --- .../sql/catalyst/expressions/UnsafeRow.java | 46 ++++++++++++++++++- .../spark/sql/catalyst/InternalRow.scala | 4 +- .../expressions/SpecificMutableRow.scala | 2 +- .../codegen/GenerateProjection.scala | 2 +- .../spark/sql/catalyst/expressions/rows.scala | 4 +- 5 files changed, 51 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 87e5a89c19658..7812e8dc9c32a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -24,7 +24,7 @@ import java.util.HashSet; import java.util.Set; -import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; @@ -235,6 +235,35 @@ public Object get(int ordinal) { throw new UnsupportedOperationException(); } + @Override + public Object get(int ordinal, DataType dataType) { + if (dataType instanceof NullType) { + return null; + } else if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof DecimalType) { + return getDecimal(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType) dataType).size()); + } else { + throw new UnsupportedOperationException("Unsupported data type " + dataType.simpleString()); + } + } + @Override public boolean isNullAt(int ordinal) { assertIndexIsValid(ordinal); @@ -436,4 +465,19 @@ public String toString() { public boolean anyNull() { return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes / 8); } + + /** + * Writes the content of this row into a memory address, identified by an object and an offset. + * The target memory address must already been allocated, and have enough space to hold all the + * bytes in this string. + */ + public void writeToMemory(Object target, long targetOffset) { + PlatformDependent.copyMemory( + baseObject, + baseOffset, + target, + targetOffset, + sizeInBytes + ); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 385d9671386dc..ad3977281d1a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -30,11 +30,11 @@ abstract class InternalRow extends Serializable { def numFields: Int - def get(ordinal: Int): Any + def get(ordinal: Int): Any = get(ordinal, null) def genericGet(ordinal: Int): Any = get(ordinal, null) - def get(ordinal: Int, dataType: DataType): Any = get(ordinal) + def get(ordinal: Int, dataType: DataType): Any def getAs[T](ordinal: Int, dataType: DataType): T = get(ordinal, dataType).asInstanceOf[T] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 5953a093dc684..b877ce47c083f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -219,7 +219,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR values(i).isNull = true } - override def get(i: Int): Any = values(i).boxed + override def get(i: Int, dataType: DataType): Any = values(i).boxed override def getStruct(ordinal: Int, numFields: Int): InternalRow = { values(ordinal).boxed.asInstanceOf[InternalRow] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index a361b216eb472..35920147105ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -183,7 +183,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - public Object get(int i) { + public Object get(int i, ${classOf[DataType].getName} dataType) { if (isNullAt(i)) return null; switch (i) { $getCases diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index daeabe8e90f1d..b7c4ece4a16fe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -99,7 +99,7 @@ class GenericInternalRow(protected[sql] val values: Array[Any]) extends Internal override def numFields: Int = values.length - override def get(i: Int): Any = values(i) + override def get(i: Int, dataType: DataType): Any = values(i) override def getStruct(ordinal: Int, numFields: Int): InternalRow = { values(ordinal).asInstanceOf[InternalRow] @@ -130,7 +130,7 @@ class GenericMutableRow(val values: Array[Any]) extends MutableRow { override def numFields: Int = values.length - override def get(i: Int): Any = values(i) + override def get(i: Int, dataType: DataType): Any = values(i) override def getStruct(ordinal: Int, numFields: Int): InternalRow = { values(ordinal).asInstanceOf[InternalRow] From 9989064b9afc2007e769d520798a6b1da27d3e21 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 26 Jul 2015 20:40:24 -0700 Subject: [PATCH 38/67] JoinedRow. --- .../org/apache/spark/sql/catalyst/expressions/Projection.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index cc89d74146b34..27d6ff587ab71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -198,7 +198,7 @@ class JoinedRow extends InternalRow { if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) } - override def get(i: Int): Any = + override def get(i: Int, dataType: DataType): Any = if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) override def isNullAt(i: Int): Boolean = From 24a3e4604acb88c1fba83b577de547662285403a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 26 Jul 2015 21:13:00 -0700 Subject: [PATCH 39/67] Added support for DateType/TimestampType. Updated ExpressionEvalHelper to avoid conversion. --- .../sql/catalyst/expressions/UnsafeRow.java | 4 ++++ .../expressions/ExpressionEvalHelper.scala | 24 ++++++++++--------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 7812e8dc9c32a..f98a53feb4167 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -255,6 +255,10 @@ public Object get(int ordinal, DataType dataType) { return getDouble(ordinal); } else if (dataType instanceof DecimalType) { return getDecimal(ordinal); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); } else if (dataType instanceof StringType) { return getUTF8String(ordinal); } else if (dataType instanceof StructType) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 8b0f90cf3a623..eedac664c6f53 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -160,17 +160,20 @@ trait ExpressionEvalHelper { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { - val plan = generateProject( + val project = generateProject( GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) - - val unsafeRow = plan(inputRow) - // UnsafeRow cannot be compared with GenericInternalRow directly - val actual = FromUnsafeProjection(expression.dataType :: Nil)(unsafeRow) - val expectedRow = InternalRow(expected) - if (actual != expectedRow) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input") + val out = project(inputRow) + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + + if (expected == null) { + if (!out.isNullAt(0)) { + val actual = out.get(0, expression.dataType) + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + } + } else if (out.get(0, expression.dataType) != expected) { + val actual = out.get(0, expression.dataType) + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") } } @@ -200,8 +203,7 @@ trait ExpressionEvalHelper { plan = generateProject( GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) - actual = FromUnsafeProjection(expression.dataType :: Nil)( - plan(inputRow)).get(0, expression.dataType) + actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected)) } } From fb6ca303bf7bafe587c24e1022b9d2ab2a4115fc Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 26 Jul 2015 21:14:22 -0700 Subject: [PATCH 40/67] Support BinaryType. --- .../org/apache/spark/sql/catalyst/expressions/UnsafeRow.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index f98a53feb4167..0fb33dd5a15a0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -259,6 +259,8 @@ public Object get(int ordinal, DataType dataType) { return getInt(ordinal); } else if (dataType instanceof TimestampType) { return getLong(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); } else if (dataType instanceof StringType) { return getUTF8String(ordinal); } else if (dataType instanceof StructType) { From 0f57c556d15f1c1294e07b44fdd9e41c461d2dfd Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 26 Jul 2015 21:16:16 -0700 Subject: [PATCH 41/67] Reset the changes in ExpressionEvalHelper. --- .../expressions/ExpressionEvalHelper.scala | 36 +++++++++---------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index eedac664c6f53..7baa0291477b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -156,24 +156,21 @@ trait ExpressionEvalHelper { } protected def checkEvalutionWithUnsafeProjection( - expression: Expression, - expected: Any, - inputRow: InternalRow = EmptyRow): Unit = { + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { - val project = generateProject( + val plan = generateProject( GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) - val out = project(inputRow) - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - - if (expected == null) { - if (!out.isNullAt(0)) { - val actual = out.get(0, expression.dataType) - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") - } - } else if (out.get(0, expression.dataType) != expected) { - val actual = out.get(0, expression.dataType) - fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expected$input") + + val unsafeRow = plan(inputRow) + // UnsafeRow cannot be compared with GenericInternalRow directly + val actual = FromUnsafeProjection(expression.dataType :: Nil)(unsafeRow) + val expectedRow = InternalRow(expected) + if (actual != expectedRow) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail(s"Incorrect Evaluation: $expression, actual: $actual, expected: $expectedRow$input") } } @@ -187,9 +184,9 @@ trait ExpressionEvalHelper { } protected def checkDoubleEvaluation( - expression: => Expression, - expected: Spread[Double], - inputRow: InternalRow = EmptyRow): Unit = { + expression: => Expression, + expected: Spread[Double], + inputRow: InternalRow = EmptyRow): Unit = { checkEvaluationWithoutCodegen(expression, expected) checkEvaluationWithGeneratedMutableProjection(expression, expected) checkEvaluationWithOptimization(expression, expected) @@ -203,7 +200,8 @@ trait ExpressionEvalHelper { plan = generateProject( GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) - actual = plan(inputRow).get(0, expression.dataType) + actual = FromUnsafeProjection(expression.dataType :: Nil)( + plan(inputRow)).get(0, expression.dataType) assert(checkResult(actual, expected)) } } From 3063788608a75c23c7f0c5691934ba3566804849 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 26 Jul 2015 21:18:07 -0700 Subject: [PATCH 42/67] Reset the change for real this time. --- .../catalyst/expressions/ExpressionEvalHelper.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 7baa0291477b5..8b0f90cf3a623 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -156,9 +156,9 @@ trait ExpressionEvalHelper { } protected def checkEvalutionWithUnsafeProjection( - expression: Expression, - expected: Any, - inputRow: InternalRow = EmptyRow): Unit = { + expression: Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { val plan = generateProject( GenerateUnsafeProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), @@ -184,9 +184,9 @@ trait ExpressionEvalHelper { } protected def checkDoubleEvaluation( - expression: => Expression, - expected: Spread[Double], - inputRow: InternalRow = EmptyRow): Unit = { + expression: => Expression, + expected: Spread[Double], + inputRow: InternalRow = EmptyRow): Unit = { checkEvaluationWithoutCodegen(expression, expected) checkEvaluationWithGeneratedMutableProjection(expression, expected) checkEvaluationWithOptimization(expression, expected) From 6214682542f7ac6410461ea647ad1016133d4e63 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 28 Jul 2015 16:01:15 -0700 Subject: [PATCH 43/67] Fixes to null handling in UnsafeRow --- .../sql/catalyst/expressions/UnsafeRow.java | 14 +++----------- .../org/apache/spark/sql/UnsafeRowSuite.scala | 17 ++++++++++++++++- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 955fb4226fc0e..64a8edc34d681 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -239,7 +239,7 @@ public Object get(int ordinal) { @Override public Object get(int ordinal, DataType dataType) { - if (dataType instanceof NullType) { + if (isNullAt(ordinal) || dataType instanceof NullType) { return null; } else if (dataType instanceof BooleanType) { return getBoolean(ordinal); @@ -313,21 +313,13 @@ public long getLong(int ordinal) { @Override public float getFloat(int ordinal) { assertIndexIsValid(ordinal); - if (isNullAt(ordinal)) { - return Float.NaN; - } else { - return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); - } + return PlatformDependent.UNSAFE.getFloat(baseObject, getFieldOffset(ordinal)); } @Override public double getDouble(int ordinal) { assertIndexIsValid(ordinal); - if (isNullAt(ordinal)) { - return Float.NaN; - } else { - return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); - } + return PlatformDependent.UNSAFE.getDouble(baseObject, getFieldOffset(ordinal)); } @Override diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index ad3bb1744cb3c..e72a1bc6c4e20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -22,7 +22,7 @@ import java.io.ByteArrayOutputStream import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} -import org.apache.spark.sql.types.{DataType, IntegerType, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.memory.MemoryAllocator import org.apache.spark.unsafe.types.UTF8String @@ -67,4 +67,19 @@ class UnsafeRowSuite extends SparkFunSuite { assert(bytesFromArrayBackedRow === bytesFromOffheapRow) } + + test("calling getDouble() and getFloat() on null columns") { + val row = InternalRow.apply(null, null) + val unsafeRow = UnsafeProjection.create(Array[DataType](FloatType, DoubleType)).apply(row) + assert(unsafeRow.getFloat(0) === row.getFloat(0)) + assert(unsafeRow.getDouble(1) === row.getDouble(1)) + } + + test("calling get(ordinal, datatype) on null columns") { + val row = InternalRow.apply(null) + val unsafeRow = UnsafeProjection.create(Array[DataType](NullType)).apply(row) + for (dataType <- DataTypeTestUtils.atomicTypes) { + assert(unsafeRow.get(0, dataType) === null) + } + } } From 4c09a78d6d95cd6bb79371907962863862cf5c1b Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 20 Jul 2015 18:55:45 -0700 Subject: [PATCH 44/67] Enable Unsafe by default --- sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 40eba33f595ca..a8022e173990d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -229,7 +229,7 @@ private[spark] object SQLConf { " a specific query.") val UNSAFE_ENABLED = booleanConf("spark.sql.unsafe.enabled", - defaultValue = Some(false), + defaultValue = Some(true), doc = "When true, use the new optimized Tungsten physical execution backend.") val DIALECT = stringConf( From 54579b13451f70adc21c36692fca2380f94268c7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 21 Jul 2015 15:32:42 -0700 Subject: [PATCH 45/67] Disable unsafe Exchange path when RangePartitioning is used --- .../scala/org/apache/spark/sql/execution/Exchange.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 41a0c519ba527..70e5031fb63c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -47,7 +47,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una override def canProcessSafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true + override def canProcessUnsafeRows: Boolean = { + // Do not use the Unsafe path if we are using a RangePartitioning, since this may lead to + // an interpreted RowOrdering being applied to an UnsafeRow, which will lead to + // ClassCastExceptions at runtime. This check can be removed after SPARK-9054 is fixed. + !newPartitioning.isInstanceOf[RangePartitioning] + } /** * Determines whether records must be defensively copied before being sent to the shuffle. From 601fcbd8415d47cece40193c94f7e83316969f72 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 21 Jul 2015 16:10:15 -0700 Subject: [PATCH 46/67] Reduce page size to make HiveCompatibilitySuite pass. --- .../spark/util/collection/unsafe/sort/UnsafeExternalSorter.java | 2 +- .../main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 80b03d7e99e2b..787555aa45c45 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -41,7 +41,7 @@ public final class UnsafeExternalSorter { private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); - private static final int PAGE_SIZE = 1 << 27; // 128 megabytes + private static final int PAGE_SIZE = 1 << 22; // 4 megabytes @VisibleForTesting static final int MAX_RECORD_SIZE = PAGE_SIZE - 4; diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index d0bde69cc1068..3c56b321a0caf 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -78,7 +78,7 @@ public final class BytesToBytesMap { * The size of the data pages that hold key and value data. Map entries cannot span multiple * pages, so this limits the maximum entry size. */ - private static final long PAGE_SIZE_BYTES = 1L << 26; // 64 megabytes + private static final long PAGE_SIZE_BYTES = 1L << 22; // 4 megabytes /** * The maximum number of keys that BytesToBytesMap supports. The hash table has to be From e5f7464623ac6d04e426941c49df0770191610c4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 26 Jul 2015 20:01:30 -0700 Subject: [PATCH 47/67] Add task completion callback to avoid leak in limit after sort --- .../unsafe/sort/UnsafeExternalSorter.java | 14 +++++++++++++ .../execution/UnsafeExternalSortSuite.scala | 20 +------------------ 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 787555aa45c45..a113fd8cae466 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -20,6 +20,9 @@ import java.io.IOException; import java.util.LinkedList; +import scala.runtime.AbstractFunction0; +import scala.runtime.BoxedUnit; + import com.google.common.annotations.VisibleForTesting; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -92,6 +95,17 @@ public UnsafeExternalSorter( // Use getSizeAsKb (not bytes) to maintain backwards compatibility for units this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; initializeForWriting(); + + // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at + // the end of the task. This is necessary to avoid memory leaks in when the downstream operator + // does not fully consume the sorter's output (e.g. sort followed by limit). + taskContext.addOnCompleteCallback(new AbstractFunction0() { + @Override + public BoxedUnit apply() { + freeMemory(); + return null; + } + }); } // TODO: metrics tracking + integration with shuffle write metrics diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 7a4baa9e4a49d..138636b0c65b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -36,10 +36,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get) } - ignore("sort followed by limit should not leak memory") { - // TODO: this test is going to fail until we implement a proper iterator interface - // with a close() method. - TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") + test("sort followed by limit") { checkThatPlansAgree( (1 to 100).map(v => Tuple1(v)).toDF("a"), (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), @@ -48,21 +45,6 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { ) } - test("sort followed by limit") { - TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") - try { - checkThatPlansAgree( - (1 to 100).map(v => Tuple1(v)).toDF("a"), - (child: SparkPlan) => Limit(10, UnsafeExternalSort('a.asc :: Nil, true, child)), - (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)), - sortAnswers = false - ) - } finally { - TestSQLContext.sparkContext.conf.set("spark.unsafe.exceptionOnMemoryLeak", "false") - - } - } - test("sorting does not crash for large inputs") { val sortOrder = 'a.asc :: Nil val stringLength = 1024 * 1024 * 2 From c8eb2ee505cef0f65b1a2865cc217270d2fcd45f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 28 Jul 2015 16:12:29 -0700 Subject: [PATCH 48/67] Fix test in UnsafeRowConverterSuite --- .../sql/catalyst/expressions/UnsafeRowConverterSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index 2834b54e8fb2e..b7bc17f89e82f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -146,8 +146,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(createdFromNull.getShort(3) === 0) assert(createdFromNull.getInt(4) === 0) assert(createdFromNull.getLong(5) === 0) - assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) - assert(java.lang.Double.isNaN(createdFromNull.getDouble(7))) + assert(createdFromNull.getFloat(6) === 0.0f) + assert(createdFromNull.getDouble(7) === 0.0d) assert(createdFromNull.getUTF8String(8) === null) assert(createdFromNull.getBinary(9) === null) // assert(createdFromNull.get(10) === null) From ef1c62d0b9e536e8e82e99b75abab23666283e58 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 27 Jul 2015 23:47:05 -0700 Subject: [PATCH 49/67] Also match TungstenProject in checkNumProjects --- .../scala/org/apache/spark/sql/ColumnExpressionSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 1f9f7118c3f04..5d8d232fd617e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.scalatest.Matchers._ -import org.apache.spark.sql.execution.Project +import org.apache.spark.sql.execution.{Project, TungstenProject} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -523,6 +523,7 @@ class ColumnExpressionSuite extends QueryTest { def checkNumProjects(df: DataFrame, expectedNumProjects: Int): Unit = { val projects = df.queryExecution.executedPlan.collect { case project: Project => project + case tungstenProject: TungstenProject => tungstenProject } assert(projects.size === expectedNumProjects) } From 203f1d85eb4e08e27eb5b5da1e51919c9b0ee4c4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 28 Jul 2015 10:58:34 -0700 Subject: [PATCH 50/67] Use TaskAttemptIds to track unroll memory --- .../apache/spark/api/python/PythonRDD.scala | 2 +- .../org/apache/spark/executor/Executor.scala | 2 +- .../apache/spark/storage/MemoryStore.scala | 74 ++++++++-------- .../spark/storage/BlockManagerSuite.scala | 84 +++++++++---------- 4 files changed, 84 insertions(+), 78 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 598953ac3bcc8..1d021a08f2da9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -267,7 +267,7 @@ private[spark] class PythonRDD( // Release memory used by this thread for shuffles env.shuffleMemoryManager.releaseMemoryForThisThread() // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() + env.blockManager.memoryStore.releaseUnrollMemoryForThisTask() } } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index e76664f1bd7b0..7885560fbd5f0 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -316,7 +316,7 @@ private[spark] class Executor( // Release memory used by this thread for shuffles env.shuffleMemoryManager.releaseMemoryForThisThread() // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() + env.blockManager.memoryStore.releaseUnrollMemoryForThisTask() runningTasks.remove(taskId) } } diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index ed609772e6979..5c126b01779ec 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -23,6 +23,7 @@ import java.util.LinkedHashMap import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.TaskContext import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -43,7 +44,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Ensure only one thread is putting, and if necessary, dropping blocks at any given time private val accountingLock = new Object - // A mapping from thread ID to amount of memory used for unrolling a block (in bytes) + // A mapping from taskAttemptId ID to amount of memory used for unrolling a block (in bytes) // All accesses of this map are assumed to have manually synchronized on `accountingLock` private val unrollMemoryMap = mutable.HashMap[Long, Long]() // Same as `unrollMemoryMap`, but for pending unroll memory as defined below. @@ -259,12 +260,12 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Memory to request as a multiple of current vector size val memoryGrowthFactor = 1.5 // Previous unroll memory held by this thread, for releasing later (only at the very end) - val previousMemoryReserved = currentUnrollMemoryForThisThread + val previousMemoryReserved = currentUnrollMemoryForThisTask // Underlying vector for unrolling the block var vector = new SizeTrackingVector[Any] // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold) + keepUnrolling = reserveUnrollMemoryForThisTask(initialMemoryThreshold) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -283,7 +284,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Hold the accounting lock, in case another thread concurrently puts a block that // takes up the unrolling space we just ensured here accountingLock.synchronized { - if (!reserveUnrollMemoryForThisThread(amountToRequest)) { + if (!reserveUnrollMemoryForThisTask(amountToRequest)) { // If the first request is not granted, try again after ensuring free space // If there is still not enough space, give up and drop the partition val spaceToEnsure = maxUnrollMemory - currentUnrollMemory @@ -291,7 +292,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) val result = ensureFreeSpace(blockId, spaceToEnsure) droppedBlocks ++= result.droppedBlocks } - keepUnrolling = reserveUnrollMemoryForThisThread(amountToRequest) + keepUnrolling = reserveUnrollMemoryForThisTask(amountToRequest) } } // New threshold is currentSize * memoryGrowthFactor @@ -317,9 +318,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // later when the task finishes. if (keepUnrolling) { accountingLock.synchronized { - val amountToRelease = currentUnrollMemoryForThisThread - previousMemoryReserved - releaseUnrollMemoryForThisThread(amountToRelease) - reservePendingUnrollMemoryForThisThread(amountToRelease) + val amountToRelease = currentUnrollMemoryForThisTask - previousMemoryReserved + releaseUnrollMemoryForThisTask(amountToRelease) + reservePendingUnrollMemoryForThisTask(amountToRelease) } } } @@ -397,7 +398,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) } } // Release the unroll memory used because we no longer need the underlying Array - releasePendingUnrollMemoryForThisThread() + releasePendingUnrollMemoryForThisTask() } ResultWithDroppedBlocks(putSuccess, droppedBlocks) } @@ -482,16 +483,20 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) entries.synchronized { entries.containsKey(blockId) } } + private def currentTaskAttemptId(): Long = { + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1) + } + /** - * Reserve additional memory for unrolling blocks used by this thread. + * Reserve additional memory for unrolling blocks used by this task. * Return whether the request is granted. */ - def reserveUnrollMemoryForThisThread(memory: Long): Boolean = { + def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { accountingLock.synchronized { val granted = freeMemory > currentUnrollMemory + memory if (granted) { - val threadId = Thread.currentThread().getId - unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, 0L) + memory + val taskAttemptId = currentTaskAttemptId() + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } granted } @@ -499,62 +504,63 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) /** * Release memory used by this thread for unrolling blocks. - * If the amount is not specified, remove the current thread's allocation altogether. + * If the amount is not specified, remove the current task's allocation altogether. */ - def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = { - val threadId = Thread.currentThread().getId + def releaseUnrollMemoryForThisTask(memory: Long = -1L): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { if (memory < 0) { - unrollMemoryMap.remove(threadId) + unrollMemoryMap.remove(taskAttemptId) } else { - unrollMemoryMap(threadId) = unrollMemoryMap.getOrElse(threadId, memory) - memory + unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, memory) - memory // If this thread claims no more unroll memory, release it completely - if (unrollMemoryMap(threadId) <= 0) { - unrollMemoryMap.remove(threadId) + if (unrollMemoryMap(taskAttemptId) <= 0) { + unrollMemoryMap.remove(taskAttemptId) } } } } /** - * Reserve the unroll memory of current unroll successful block used by this thread + * Reserve the unroll memory of current unroll successful block used by this task * until actually put the block into memory entry. */ - def reservePendingUnrollMemoryForThisThread(memory: Long): Unit = { - val threadId = Thread.currentThread().getId + def reservePendingUnrollMemoryForThisTask(memory: Long): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - pendingUnrollMemoryMap(threadId) = pendingUnrollMemoryMap.getOrElse(threadId, 0L) + memory + pendingUnrollMemoryMap(taskAttemptId) = + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } } /** - * Release pending unroll memory of current unroll successful block used by this thread + * Release pending unroll memory of current unroll successful block used by this task */ - def releasePendingUnrollMemoryForThisThread(): Unit = { - val threadId = Thread.currentThread().getId + def releasePendingUnrollMemoryForThisTask(): Unit = { + val taskAttemptId = currentTaskAttemptId() accountingLock.synchronized { - pendingUnrollMemoryMap.remove(threadId) + pendingUnrollMemoryMap.remove(taskAttemptId) } } /** - * Return the amount of memory currently occupied for unrolling blocks across all threads. + * Return the amount of memory currently occupied for unrolling blocks across all tasks. */ def currentUnrollMemory: Long = accountingLock.synchronized { unrollMemoryMap.values.sum + pendingUnrollMemoryMap.values.sum } /** - * Return the amount of memory currently occupied for unrolling blocks by this thread. + * Return the amount of memory currently occupied for unrolling blocks by this task. */ - def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized { + def currentUnrollMemoryForThisTask: Long = accountingLock.synchronized { unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L) } /** - * Return the number of threads currently unrolling blocks. + * Return the number of tasks currently unrolling blocks. */ - def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } + def numTasksUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } /** * Log information about current memory usage. @@ -566,7 +572,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo( s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " + s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " + - s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. " + + s"$numTasksUnrolling tasks(s)) = ${Utils.bytesToString(totalMemory)}. " + s"Storage limit = ${Utils.bytesToString(maxMemory)}." ) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index bcee901f5dd5f..f480fd107a0c2 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -1004,32 +1004,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store = makeBlockManager(12000) val memoryStore = store.memoryStore assert(memoryStore.currentUnrollMemory === 0) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Reserve - memoryStore.reserveUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 100) - memoryStore.reserveUnrollMemoryForThisThread(200) - assert(memoryStore.currentUnrollMemoryForThisThread === 300) - memoryStore.reserveUnrollMemoryForThisThread(500) - assert(memoryStore.currentUnrollMemoryForThisThread === 800) - memoryStore.reserveUnrollMemoryForThisThread(1000000) - assert(memoryStore.currentUnrollMemoryForThisThread === 800) // not granted + memoryStore.reserveUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 100) + memoryStore.reserveUnrollMemoryForThisTask(200) + assert(memoryStore.currentUnrollMemoryForThisTask === 300) + memoryStore.reserveUnrollMemoryForThisTask(500) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) + memoryStore.reserveUnrollMemoryForThisTask(1000000) + assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted // Release - memoryStore.releaseUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 700) - memoryStore.releaseUnrollMemoryForThisThread(100) - assert(memoryStore.currentUnrollMemoryForThisThread === 600) + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 700) + memoryStore.releaseUnrollMemoryForThisTask(100) + assert(memoryStore.currentUnrollMemoryForThisTask === 600) // Reserve again - memoryStore.reserveUnrollMemoryForThisThread(4400) - assert(memoryStore.currentUnrollMemoryForThisThread === 5000) - memoryStore.reserveUnrollMemoryForThisThread(20000) - assert(memoryStore.currentUnrollMemoryForThisThread === 5000) // not granted + memoryStore.reserveUnrollMemoryForThisTask(4400) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) + memoryStore.reserveUnrollMemoryForThisTask(20000) + assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted // Release again - memoryStore.releaseUnrollMemoryForThisThread(1000) - assert(memoryStore.currentUnrollMemoryForThisThread === 4000) - memoryStore.releaseUnrollMemoryForThisThread() // release all - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + memoryStore.releaseUnrollMemoryForThisTask(1000) + assert(memoryStore.currentUnrollMemoryForThisTask === 4000) + memoryStore.releaseUnrollMemoryForThisTask() // release all + assert(memoryStore.currentUnrollMemoryForThisTask === 0) } /** @@ -1060,24 +1060,24 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) val memoryStore = store.memoryStore val droppedBlocks = new ArrayBuffer[(BlockId, BlockStatus)] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll with all the space in the world. This should succeed and return an array. var unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) - memoryStore.releasePendingUnrollMemoryForThisThread() + assert(memoryStore.currentUnrollMemoryForThisTask === 0) + memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll with not enough space. This should succeed after kicking out someBlock1. store.putIterator("someBlock1", smallList.iterator, StorageLevel.MEMORY_ONLY) store.putIterator("someBlock2", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", smallList.iterator, droppedBlocks) verifyUnroll(smallList.iterator, unrollResult, shouldBeArray = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) assert(droppedBlocks.size === 1) assert(droppedBlocks.head._1 === TestBlockId("someBlock1")) droppedBlocks.clear() - memoryStore.releasePendingUnrollMemoryForThisThread() + memoryStore.releasePendingUnrollMemoryForThisTask() // Unroll huge block with not enough space. Even after ensuring free space of 12000 * 0.4 = // 4800 bytes, there is still not enough room to unroll this block. This returns an iterator. @@ -1085,7 +1085,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE store.putIterator("someBlock3", smallList.iterator, StorageLevel.MEMORY_ONLY) unrollResult = memoryStore.unrollSafely("unroll", bigList.iterator, droppedBlocks) verifyUnroll(bigList.iterator, unrollResult, shouldBeArray = false) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator assert(droppedBlocks.size === 1) assert(droppedBlocks.head._1 === TestBlockId("someBlock2")) droppedBlocks.clear() @@ -1099,7 +1099,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll with plenty of space. This should succeed and cache both blocks. val result1 = memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) @@ -1110,7 +1110,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(result2.size > 0) assert(result1.data.isLeft) // unroll did not drop this block to disk assert(result2.data.isLeft) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Re-put these two blocks so block manager knows about them too. Otherwise, block manager // would not know how to drop them from memory later. @@ -1126,7 +1126,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!memoryStore.contains("b1")) assert(memoryStore.contains("b2")) assert(memoryStore.contains("b3")) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.remove("b3") store.putIterator("b3", smallIterator, memOnly) @@ -1138,7 +1138,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!memoryStore.contains("b2")) assert(memoryStore.contains("b3")) assert(!memoryStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator } /** @@ -1153,7 +1153,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val bigList = List.fill(40)(new Array[Byte](1000)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] def bigIterator: Iterator[Any] = bigList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) store.putIterator("b1", smallIterator, memAndDisk) store.putIterator("b2", smallIterator, memAndDisk) @@ -1170,7 +1170,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!diskStore.contains("b3")) memoryStore.remove("b3") store.putIterator("b3", smallIterator, StorageLevel.MEMORY_ONLY) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll huge block with not enough space. This should fail and drop the new block to disk // directly in addition to kicking out b2 in the process. Memory store should contain only @@ -1186,7 +1186,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(diskStore.contains("b2")) assert(!diskStore.contains("b3")) assert(diskStore.contains("b4")) - assert(memoryStore.currentUnrollMemoryForThisThread > 0) // we returned an iterator + assert(memoryStore.currentUnrollMemoryForThisTask > 0) // we returned an iterator } test("multiple unrolls by the same thread") { @@ -1195,32 +1195,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val memoryStore = store.memoryStore val smallList = List.fill(40)(new Array[Byte](100)) def smallIterator: Iterator[Any] = smallList.iterator.asInstanceOf[Iterator[Any]] - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // All unroll memory used is released because unrollSafely returned an array memoryStore.putIterator("b1", smallIterator, memOnly, returnValues = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) memoryStore.putIterator("b2", smallIterator, memOnly, returnValues = true) - assert(memoryStore.currentUnrollMemoryForThisThread === 0) + assert(memoryStore.currentUnrollMemoryForThisTask === 0) // Unroll memory is not released because unrollSafely returned an iterator // that still depends on the underlying vector used in the process memoryStore.putIterator("b3", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB3 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB3 > 0) // The unroll memory owned by this thread builds on top of its value after the previous unrolls memoryStore.putIterator("b4", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB4 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB4 > unrollMemoryAfterB3) // ... but only to a certain extent (until we run out of free space to grant new unroll memory) memoryStore.putIterator("b5", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB5 = memoryStore.currentUnrollMemoryForThisTask memoryStore.putIterator("b6", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB6 = memoryStore.currentUnrollMemoryForThisTask memoryStore.putIterator("b7", smallIterator, memOnly, returnValues = true) - val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisThread + val unrollMemoryAfterB7 = memoryStore.currentUnrollMemoryForThisTask assert(unrollMemoryAfterB5 === unrollMemoryAfterB4) assert(unrollMemoryAfterB6 === unrollMemoryAfterB4) assert(unrollMemoryAfterB7 === unrollMemoryAfterB4) From d7a2788147315e7ace0543eafc1e650b1b3220a0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 28 Jul 2015 11:09:11 -0700 Subject: [PATCH 51/67] Use TaskAttemptIds to track shuffle memory --- .../apache/spark/api/python/PythonRDD.scala | 2 +- .../org/apache/spark/executor/Executor.scala | 2 +- .../spark/shuffle/ShuffleMemoryManager.scala | 86 ++++++++++--------- .../apache/spark/storage/MemoryStore.scala | 2 +- .../shuffle/ShuffleMemoryManagerSuite.scala | 4 +- 5 files changed, 50 insertions(+), 46 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 1d021a08f2da9..8d21a8e618d6e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -265,7 +265,7 @@ private[spark] class PythonRDD( } } finally { // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() + env.shuffleMemoryManager.releaseMemoryForThisTask() // Release memory used by this thread for unrolling blocks env.blockManager.memoryStore.releaseUnrollMemoryForThisTask() } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 7885560fbd5f0..19379918dc6cf 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -314,7 +314,7 @@ private[spark] class Executor( } finally { // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisThread() + env.shuffleMemoryManager.releaseMemoryForThisTask() // Release memory used by this thread for unrolling blocks env.blockManager.memoryStore.releaseUnrollMemoryForThisTask() runningTasks.remove(taskId) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index 3bcc7178a3d8b..ed164fbcb8ec2 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -19,95 +19,99 @@ package org.apache.spark.shuffle import scala.collection.mutable -import org.apache.spark.{Logging, SparkException, SparkConf} +import org.apache.spark.{Logging, SparkException, SparkConf, TaskContext} /** - * Allocates a pool of memory to task threads for use in shuffle operations. Each disk-spilling + * Allocates a pool of memory to tasks for use in shuffle operations. Each disk-spilling * collection (ExternalAppendOnlyMap or ExternalSorter) used by these tasks can acquire memory * from this pool and release it as it spills data out. When a task ends, all its memory will be * released by the Executor. * - * This class tries to ensure that each thread gets a reasonable share of memory, instead of some - * thread ramping up to a large amount first and then causing others to spill to disk repeatedly. - * If there are N threads, it ensures that each thread can acquire at least 1 / 2N of the memory + * This class tries to ensure that each task gets a reasonable share of memory, instead of some + * task ramping up to a large amount first and then causing others to spill to disk repeatedly. + * If there are N tasks, it ensures that each tasks can acquire at least 1 / 2N of the memory * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the - * set of active threads and redo the calculations of 1 / 2N and 1 / N in waiting threads whenever + * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever * this set changes. This is all done by synchronizing access on "this" to mutate state and using * wait() and notifyAll() to signal changes. */ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { - private val threadMemory = new mutable.HashMap[Long, Long]() // threadId -> memory bytes + private val taskMemory = new mutable.HashMap[Long, Long]() // taskAttemptId -> memory bytes def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) + private def currenttaskAttemptId(): Long = { + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) + } + /** - * Try to acquire up to numBytes memory for the current thread, and return the number of bytes + * Try to acquire up to numBytes memory for the current task, and return the number of bytes * obtained, or 0 if none can be allocated. This call may block until there is enough free memory - * in some situations, to make sure each thread has a chance to ramp up to at least 1 / 2N of the - * total memory pool (where N is the # of active threads) before it is forced to spill. This can - * happen if the number of threads increases but an older thread had a lot of memory already. + * in some situations, to make sure each task has a chance to ramp up to at least 1 / 2N of the + * total memory pool (where N is the # of active tasks) before it is forced to spill. This can + * happen if the number of tasks increases but an older task had a lot of memory already. */ def tryToAcquire(numBytes: Long): Long = synchronized { - val threadId = Thread.currentThread().getId + val taskAttemptId = currenttaskAttemptId() assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) - // Add this thread to the threadMemory map just so we can keep an accurate count of the number - // of active threads, to let other threads ramp down their memory in calls to tryToAcquire - if (!threadMemory.contains(threadId)) { - threadMemory(threadId) = 0L - notifyAll() // Will later cause waiting threads to wake up and check numThreads again + // Add this task to the taskMemory map just so we can keep an accurate count of the number + // of active tasks, to let other tasks ramp down their memory in calls to tryToAcquire + if (!taskMemory.contains(taskAttemptId)) { + taskMemory(taskAttemptId) = 0L + notifyAll() // Will later cause waiting tasks to wake up and check numThreads again } // Keep looping until we're either sure that we don't want to grant this request (because this - // thread would have more than 1 / numActiveThreads of the memory) or we have enough free - // memory to give it (we always let each thread get at least 1 / (2 * numActiveThreads)). + // task would have more than 1 / numActiveTasks of the memory) or we have enough free + // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)). while (true) { - val numActiveThreads = threadMemory.keys.size - val curMem = threadMemory(threadId) - val freeMemory = maxMemory - threadMemory.values.sum + val numActiveTasks = taskMemory.keys.size + val curMem = taskMemory(taskAttemptId) + val freeMemory = maxMemory - taskMemory.values.sum - // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads; + // How much we can grant this task; don't let it grow to more than 1 / numActiveTasks; // don't let it be negative - val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem)) + val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveTasks) - curMem)) - if (curMem < maxMemory / (2 * numActiveThreads)) { - // We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking; - // if we can't give it this much now, wait for other threads to free up memory - // (this happens if older threads allocated lots of memory before N grew) - if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveThreads) - curMem)) { + if (curMem < maxMemory / (2 * numActiveTasks)) { + // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking; + // if we can't give it this much now, wait for other tasks to free up memory + // (this happens if older tasks allocated lots of memory before N grew) + if (freeMemory >= math.min(maxToGrant, maxMemory / (2 * numActiveTasks) - curMem)) { val toGrant = math.min(maxToGrant, freeMemory) - threadMemory(threadId) += toGrant + taskMemory(taskAttemptId) += toGrant return toGrant } else { - logInfo(s"Thread $threadId waiting for at least 1/2N of shuffle memory pool to be free") + logInfo(s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") wait() } } else { // Only give it as much memory as is free, which might be none if it reached 1 / numThreads val toGrant = math.min(maxToGrant, freeMemory) - threadMemory(threadId) += toGrant + taskMemory(taskAttemptId) += toGrant return toGrant } } 0L // Never reached } - /** Release numBytes bytes for the current thread. */ + /** Release numBytes bytes for the current task. */ def release(numBytes: Long): Unit = synchronized { - val threadId = Thread.currentThread().getId - val curMem = threadMemory.getOrElse(threadId, 0L) + val taskAttemptId = currenttaskAttemptId() + val curMem = taskMemory.getOrElse(taskAttemptId, 0L) if (curMem < numBytes) { throw new SparkException( - s"Internal error: release called on ${numBytes} bytes but thread only has ${curMem}") + s"Internal error: release called on ${numBytes} bytes but task only has ${curMem}") } - threadMemory(threadId) -= numBytes + taskMemory(taskAttemptId) -= numBytes notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } - /** Release all memory for the current thread and mark it as inactive (e.g. when a task ends). */ - def releaseMemoryForThisThread(): Unit = synchronized { - val threadId = Thread.currentThread().getId - threadMemory.remove(threadId) + /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */ + def releaseMemoryForThisTask(): Unit = synchronized { + val taskAttemptId = currenttaskAttemptId() + taskMemory.remove(taskAttemptId) notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } } diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 5c126b01779ec..18f728c0484c0 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -484,7 +484,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } private def currentTaskAttemptId(): Long = { - Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1) + Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) } /** diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala index 96778c9ebafb1..8124a5e515010 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -50,7 +50,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { assert(manager.tryToAcquire(300L) === 300L) assert(manager.tryToAcquire(300L) === 200L) - manager.releaseMemoryForThisThread() + manager.releaseMemoryForThisTask() assert(manager.tryToAcquire(1000L) === 1000L) assert(manager.tryToAcquire(100L) === 0L) } @@ -253,7 +253,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make // sure the other thread blocks for some time otherwise Thread.sleep(300) - manager.releaseMemoryForThisThread() + manager.releaseMemoryForThisTask() } val t2 = startThread("t2") { From b38e70fca7584a5cf099aac12c8cbf85d6d2e4eb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 28 Jul 2015 11:09:45 -0700 Subject: [PATCH 52/67] Roll back fix in PySpark, which is no longer necessary --- .../main/scala/org/apache/spark/api/python/PythonRDD.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 8d21a8e618d6e..ae72ad1565c8e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -263,11 +263,6 @@ private[spark] class PythonRDD( if (!worker.isClosed) { Utils.tryLog(worker.shutdownOutput()) } - } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisTask() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisTask() } } } From d8bd892722ec5165fecd4f59b0722e4da13edcc9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 28 Jul 2015 11:13:06 -0700 Subject: [PATCH 53/67] Fix capitalization --- .../org/apache/spark/shuffle/ShuffleMemoryManager.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index ed164fbcb8ec2..fab2ac2f81c75 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -40,7 +40,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { def this(conf: SparkConf) = this(ShuffleMemoryManager.getMaxMemory(conf)) - private def currenttaskAttemptId(): Long = { + private def currentTaskAttemptId(): Long = { Option(TaskContext.get()).map(_.taskAttemptId()).getOrElse(-1L) } @@ -52,7 +52,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { * happen if the number of tasks increases but an older task had a lot of memory already. */ def tryToAcquire(numBytes: Long): Long = synchronized { - val taskAttemptId = currenttaskAttemptId() + val taskAttemptId = currentTaskAttemptId() assert(numBytes > 0, "invalid number of bytes requested: " + numBytes) // Add this task to the taskMemory map just so we can keep an accurate count of the number @@ -98,7 +98,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { /** Release numBytes bytes for the current task. */ def release(numBytes: Long): Unit = synchronized { - val taskAttemptId = currenttaskAttemptId() + val taskAttemptId = currentTaskAttemptId() val curMem = taskMemory.getOrElse(taskAttemptId, 0L) if (curMem < numBytes) { throw new SparkException( @@ -110,7 +110,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { /** Release all memory for the current task and mark it as inactive (e.g. when a task ends). */ def releaseMemoryForThisTask(): Unit = synchronized { - val taskAttemptId = currenttaskAttemptId() + val taskAttemptId = currentTaskAttemptId() taskMemory.remove(taskAttemptId) notifyAll() // Notify waiters who locked "this" in tryToAcquire that memory has been freed } From 56edb41f34d1dd1d736c346d77db9ded726954e1 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 28 Jul 2015 11:17:57 -0700 Subject: [PATCH 54/67] Move Executor's cleanup into Task so that TaskContext is defined when cleanup is performed --- .../scala/org/apache/spark/executor/Executor.scala | 4 ---- .../main/scala/org/apache/spark/scheduler/Task.scala | 11 +++++++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 19379918dc6cf..7bc7fce7ae8dd 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -313,10 +313,6 @@ private[spark] class Executor( } } finally { - // Release memory used by this thread for shuffles - env.shuffleMemoryManager.releaseMemoryForThisTask() - // Release memory used by this thread for unrolling blocks - env.blockManager.memoryStore.releaseUnrollMemoryForThisTask() runningTasks.remove(taskId) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index d11a00956a9a9..9ca15a6e638d6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.{TaskContextImpl, TaskContext} +import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.unsafe.memory.TaskMemoryManager @@ -86,7 +86,14 @@ private[spark] abstract class Task[T]( (runTask(context), context.collectAccumulators()) } finally { context.markTaskCompleted() - TaskContext.unset() + try { + // Release memory used by this thread for shuffles + SparkEnv.get.shuffleMemoryManager.releaseMemoryForThisTask() + // Release memory used by this thread for unrolling blocks + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() + } finally { + TaskContext.unset() + } } } From f4f5859ae6062e45e68477c9ad23a8c926b5d2e9 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 28 Jul 2015 11:22:56 -0700 Subject: [PATCH 55/67] More thread -> task changes --- .../apache/spark/storage/MemoryStore.scala | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 18f728c0484c0..a008428da0ff1 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -48,7 +48,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // All accesses of this map are assumed to have manually synchronized on `accountingLock` private val unrollMemoryMap = mutable.HashMap[Long, Long]() // Same as `unrollMemoryMap`, but for pending unroll memory as defined below. - // Pending unroll memory refers to the intermediate memory occupied by a thread + // Pending unroll memory refers to the intermediate memory occupied by a task // after the unroll but before the actual putting of the block in the cache. // This chunk of memory is expected to be released *as soon as* we finish // caching the corresponding block as opposed to until after the task finishes. @@ -251,15 +251,15 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) var elementsUnrolled = 0 // Whether there is still enough memory for us to continue unrolling this block var keepUnrolling = true - // Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing. + // Initial per-task memory to request for unrolling blocks (bytes). Exposed for testing. val initialMemoryThreshold = unrollMemoryThreshold // How often to check whether we need to request more memory val memoryCheckPeriod = 16 - // Memory currently reserved by this thread for this particular unrolling operation + // Memory currently reserved by this task for this particular unrolling operation var memoryThreshold = initialMemoryThreshold // Memory to request as a multiple of current vector size val memoryGrowthFactor = 1.5 - // Previous unroll memory held by this thread, for releasing later (only at the very end) + // Previous unroll memory held by this task, for releasing later (only at the very end) val previousMemoryReserved = currentUnrollMemoryForThisTask // Underlying vector for unrolling the block var vector = new SizeTrackingVector[Any] @@ -428,9 +428,9 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Take into account the amount of memory currently occupied by unrolling blocks // and minus the pending unroll memory for that block on current thread. - val threadId = Thread.currentThread().getId + val taskAttemptId = currentTaskAttemptId() val actualFreeMemory = freeMemory - currentUnrollMemory + - pendingUnrollMemoryMap.getOrElse(threadId, 0L) + pendingUnrollMemoryMap.getOrElse(taskAttemptId, 0L) if (actualFreeMemory < space) { val rddToAdd = getRddId(blockIdToAdd) @@ -456,7 +456,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) logInfo(s"${selectedBlocks.size} blocks selected for dropping") for (blockId <- selectedBlocks) { val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one thread should be dropping + // This should never be null as only one task should be dropping // blocks and removing entries. However the check is still here for // future safety. if (entry != null) { @@ -503,7 +503,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) } /** - * Release memory used by this thread for unrolling blocks. + * Release memory used by this task for unrolling blocks. * If the amount is not specified, remove the current task's allocation altogether. */ def releaseUnrollMemoryForThisTask(memory: Long = -1L): Unit = { @@ -513,7 +513,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) unrollMemoryMap.remove(taskAttemptId) } else { unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, memory) - memory - // If this thread claims no more unroll memory, release it completely + // If this task claims no more unroll memory, release it completely if (unrollMemoryMap(taskAttemptId) <= 0) { unrollMemoryMap.remove(taskAttemptId) } @@ -554,7 +554,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Return the amount of memory currently occupied for unrolling blocks by this task. */ def currentUnrollMemoryForThisTask: Long = accountingLock.synchronized { - unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L) + unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) } /** From e2b69c933d0bc74ad7bbad06ba4974858010b14e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 28 Jul 2015 11:40:10 -0700 Subject: [PATCH 56/67] Fix ShuffleMemoryManagerSuite --- .../shuffle/ShuffleMemoryManagerSuite.scala | 37 +++++++++++++------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala index 8124a5e515010..f495b6a037958 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleMemoryManagerSuite.scala @@ -17,26 +17,39 @@ package org.apache.spark.shuffle +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicInteger + +import org.mockito.Mockito._ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.CountDownLatch -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkFunSuite, TaskContext} class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { + + val nextTaskAttemptId = new AtomicInteger() + /** Launch a thread with the given body block and return it. */ private def startThread(name: String)(body: => Unit): Thread = { val thread = new Thread("ShuffleMemorySuite " + name) { override def run() { - body + try { + val taskAttemptId = nextTaskAttemptId.getAndIncrement + val mockTaskContext = mock(classOf[TaskContext], RETURNS_SMART_NULLS) + when(mockTaskContext.taskAttemptId()).thenReturn(taskAttemptId) + TaskContext.setTaskContext(mockTaskContext) + body + } finally { + TaskContext.unset() + } } } thread.start() thread } - test("single thread requesting memory") { + test("single task requesting memory") { val manager = new ShuffleMemoryManager(1000L) assert(manager.tryToAcquire(100L) === 100L) @@ -107,8 +120,8 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } - test("threads cannot grow past 1 / N") { - // Two threads request 250 bytes first, wait for each other to get it, and then request + test("tasks cannot grow past 1 / N") { + // Two tasks request 250 bytes first, wait for each other to get it, and then request // 500 more; we should only grant 250 bytes to each of them on this second request val manager = new ShuffleMemoryManager(1000L) @@ -158,7 +171,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { assert(state.t2Result2 === 250L) } - test("threads can block to get at least 1 / 2N memory") { + test("tasks can block to get at least 1 / 2N memory") { // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps // for a bit and releases 250 bytes, which should then be granted to t2. Further requests // by t2 will return false right away because it now has 1 / 2N of the memory. @@ -224,7 +237,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } - test("releaseMemoryForThisThread") { + test("releaseMemoryForThisTask") { // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. It sleeps // for a bit and releases all its memory. t2 should now be able to grab all the memory. @@ -251,7 +264,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } // Sleep a bit before releasing our memory; this is hacky but it would be difficult to make - // sure the other thread blocks for some time otherwise + // sure the other task blocks for some time otherwise Thread.sleep(300) manager.releaseMemoryForThisTask() } @@ -282,7 +295,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { t2.join() } - // Both threads should've been able to acquire their memory; the second one will have waited + // Both tasks should've been able to acquire their memory; the second one will have waited // until the first one acquired 1000 bytes and then released all of it state.synchronized { assert(state.t1Result === 1000L, "t1 could not allocate memory") @@ -293,7 +306,7 @@ class ShuffleMemoryManagerSuite extends SparkFunSuite with Timeouts { } } - test("threads should not be granted a negative size") { + test("tasks should not be granted a negative size") { val manager = new ShuffleMemoryManager(1000L) manager.tryToAcquire(700L) From 63492c4411922e3465d6a2ca1132cee731019d54 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 28 Jul 2015 13:52:53 -0700 Subject: [PATCH 57/67] Fix long line. --- .../scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala index fab2ac2f81c75..f3b0fcefdcd54 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala @@ -83,7 +83,8 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging { taskMemory(taskAttemptId) += toGrant return toGrant } else { - logInfo(s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") + logInfo( + s"Thread $taskAttemptId waiting for at least 1/2N of shuffle memory pool to be free") wait() } } else { From ca8168af2d4488df2838027d092c9732e523e8ba Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 16 Aug 2015 14:52:55 -0700 Subject: [PATCH 58/67] Fix compilation with latest master. --- .../org/apache/spark/sql/types/Decimal.scala | 37 +++++----- .../sql/fuzzing/DataFrameFuzzingSuite.scala | 68 +++++++++---------- 2 files changed, 52 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 084faefdcb5dd..d95805c24521c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.types -import java.math.{BigDecimal => JavaBigDecimal} import java.math.{RoundingMode, MathContext} import org.apache.spark.annotation.DeveloperApi @@ -33,7 +32,7 @@ import org.apache.spark.annotation.DeveloperApi final class Decimal extends Ordered[Decimal] with Serializable { import org.apache.spark.sql.types.Decimal._ - private var decimalVal: JavaBigDecimal = null + private var decimalVal: BigDecimal = null private var longVal: Long = 0L private var _precision: Int = 1 private var _scale: Int = 0 @@ -47,7 +46,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { def set(longVal: Long): Decimal = { if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) { // We can't represent this compactly as a long without risking overflow - this.decimalVal = new JavaBigDecimal(longVal) + this.decimalVal = BigDecimal(longVal) this.longVal = 0L } else { this.decimalVal = null @@ -89,7 +88,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (precision < 19) { return null // Requested precision is too low to represent this value } - this.decimalVal = new JavaBigDecimal(unscaled) + this.decimalVal = BigDecimal(unscaled) this.longVal = 0L } else { val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) @@ -108,7 +107,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { * Set this Decimal to the given BigDecimal value, with a given precision and scale. */ def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { - this.decimalVal = decimal.setScale(scale, ROUNDING_MODE).underlying() + this.decimalVal = decimal.setScale(scale, ROUNDING_MODE) require(decimalVal.precision <= precision, "Overflowed precision") this.longVal = 0L this._precision = precision @@ -120,7 +119,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { * Set this Decimal to the given BigDecimal value, inheriting its precision and scale. */ def set(decimal: BigDecimal): Decimal = { - this.decimalVal = decimal.underlying() + this.decimalVal = decimal this.longVal = 0L this._precision = decimal.precision this._scale = decimal.scale @@ -138,19 +137,19 @@ final class Decimal extends Ordered[Decimal] with Serializable { this } - def toBigDecimal: BigDecimal = BigDecimal(toJavaBigDecimal) - - def toJavaBigDecimal: JavaBigDecimal = { + def toBigDecimal: BigDecimal = { if (decimalVal.ne(null)) { decimalVal } else { - JavaBigDecimal.valueOf(longVal, _scale) + BigDecimal(longVal, _scale) } } + def toJavaBigDecimal: java.math.BigDecimal = toBigDecimal.underlying() + def toUnscaledLong: Long = { if (decimalVal.ne(null)) { - decimalVal.unscaledValue().longValue() + decimalVal.underlying().unscaledValue().longValue() } else { longVal } @@ -167,9 +166,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } - def toDouble: Double = toJavaBigDecimal.doubleValue() + def toDouble: Double = toBigDecimal.doubleValue() - def toFloat: Float = toJavaBigDecimal.floatValue() + def toFloat: Float = toBigDecimal.floatValue() def toLong: Long = { if (decimalVal.eq(null)) { @@ -215,7 +214,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { longVal *= POW_10(diff) } else { // Give up on using Longs; switch to BigDecimal, which we'll modify below - decimalVal = JavaBigDecimal.valueOf(longVal, _scale) + decimalVal = BigDecimal(longVal, _scale) } } // In both cases, we will check whether our precision is okay below @@ -224,7 +223,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.ne(null)) { // We get here if either we started with a BigDecimal, or we switched to one because we would // have overflowed our Long; in either case we must rescale decimalVal to the new scale. - val newVal = decimalVal.setScale(scale, ROUNDING_MODE.id) + val newVal = decimalVal.setScale(scale, ROUNDING_MODE) if (newVal.precision > precision) { return false } @@ -249,7 +248,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) { if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1 } else { - toJavaBigDecimal.compareTo(other.toJavaBigDecimal) + toBigDecimal.compare(other.toBigDecimal) } } @@ -260,7 +259,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { false } - override def hashCode(): Int = toJavaBigDecimal.hashCode() + override def hashCode(): Int = toBigDecimal.hashCode() def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0 @@ -312,7 +311,7 @@ object Decimal { private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) - private val BIG_DEC_ZERO: JavaBigDecimal = JavaBigDecimal.valueOf(0) + private val BIG_DEC_ZERO = BigDecimal(0) private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP) @@ -327,7 +326,7 @@ object Decimal { def apply(value: BigDecimal): Decimal = new Decimal().set(value) - def apply(value: JavaBigDecimal): Decimal = new Decimal().set(value) + def apply(value: java.math.BigDecimal): Decimal = new Decimal().set(value) def apply(value: BigDecimal, precision: Int, scale: Int): Decimal = new Decimal().set(value, precision, scale) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala index b6b06c5f0b9b9..4c22a25cbd544 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala @@ -20,17 +20,14 @@ package org.apache.spark.sql.fuzzing import java.lang.reflect.InvocationTargetException import java.util.concurrent.atomic.AtomicInteger -import org.apache.spark.sql.catalyst.plans.JoinType - -import scala.reflect.runtime import scala.reflect.runtime.{universe => ru} import scala.util.Random import scala.util.control.NonFatal -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -109,17 +106,24 @@ class RandomDataFrameGenerator(seed: Long, sqlContext: SQLContext) { * them in order to construct random queries. We don't have a source of truth for these random * queries but nevertheless they are still useful for testing that we don't crash in bad ways. */ -class DataFrameFuzzingSuite extends SparkFunSuite { +class DataFrameFuzzingSuite extends SparkFunSuite with SharedSparkContext { val tempDir = Utils.createTempDir() - private val dataGenerator = new RandomDataFrameGenerator(123, TestSQLContext) + private var sqlContext: SQLContext = _ + private var dataGenerator: RandomDataFrameGenerator = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + dataGenerator = new RandomDataFrameGenerator(123, sqlContext) + sqlContext.conf.setConf(SQLConf.SHUFFLE_PARTITIONS, 10) + } def randomChoice[T](values: Seq[T]): T = { values(Random.nextInt(values.length)) } - val m = ru.runtimeMirror(this.getClass.getClassLoader) val whitelistedParameterTypes = Set( @@ -257,12 +261,6 @@ class DataFrameFuzzingSuite extends SparkFunSuite { throw new Exception(e) } } - //TestSQLContext.conf.setConf(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS, false) - TestSQLContext.conf.setConf(SQLConf.UNSAFE_ENABLED, true) - TestSQLContext.conf.setConf(SQLConf.SORTMERGE_JOIN, true) - TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true) - - TestSQLContext.conf.setConf(SQLConf.SHUFFLE_PARTITIONS, 10) val ignoredAnalysisExceptionMessages = Seq( "can only be performed on tables with the same number of columns", @@ -277,25 +275,27 @@ class DataFrameFuzzingSuite extends SparkFunSuite { "Cannot resolve column name" // TODO: only ignore for join? ) - for (_ <- 1 to 1000) { - println("-" * 80) - try { - val df = dataGenerator.randomDataFrame( - numCols = Random.nextInt(2) + 1, - numRows = 20, - allowComplexTypes = false) - val df1 = tryToExecute(applyRandomTransformationToDataFrame(df)) - val df2 = tryToExecute(applyRandomTransformationToDataFrame(df1)) - } catch { - case e: NoDataGeneratorException => - println("skipped due to lack of data generator") - case e: UnresolvedException[_] => - println("skipped due to unresolved") - case e: Exception - if ignoredAnalysisExceptionMessages.exists { - m => e.getMessage.toLowerCase.contains(m.toLowerCase) - } => println("Skipped due to expected AnalysisException") - } - } + test("fuzz test") { + for (_ <- 1 to 1000) { + println("-" * 80) + try { + val df = dataGenerator.randomDataFrame( + numCols = Random.nextInt(2) + 1, + numRows = 20, + allowComplexTypes = false) + val df1 = tryToExecute(applyRandomTransformationToDataFrame(df)) + val df2 = tryToExecute(applyRandomTransformationToDataFrame(df1)) + } catch { + case e: NoDataGeneratorException => + println("skipped due to lack of data generator") + case e: UnresolvedException[_] => + println("skipped due to unresolved") + case e: Exception + if ignoredAnalysisExceptionMessages.exists { + m => Option(e.getMessage).getOrElse("").toLowerCase.contains(m.toLowerCase) + } => println("Skipped due to expected AnalysisException") + } + } + } } From 0c7e9d08497a2cc91aaca932c4a09b947508179a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 16 Aug 2015 14:55:31 -0700 Subject: [PATCH 59/67] Update to ignore some new analysis exceptions. --- .../org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala index 4c22a25cbd544..5e39d0f3eaaeb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala @@ -150,7 +150,6 @@ class DataFrameFuzzingSuite extends SparkFunSuite with SharedSparkContext { .filterNot(_.name.toString == "drop") // since this can lead to a DataFrame with no columns .filterNot(_.name.toString == "describe") // since we cannot run all queries on describe output .filterNot(_.name.toString == "dropDuplicates") - .filter(_.name.toString == "join") .toSeq } @@ -262,7 +261,11 @@ class DataFrameFuzzingSuite extends SparkFunSuite with SharedSparkContext { } } + // TODO: make these regexes. val ignoredAnalysisExceptionMessages = Seq( + // TODO: filter only for binary type: + "cannot be used in grouping expression", + "cannot be used in join condition", "can only be performed on tables with the same number of columns", "number of columns doesn't match", "unsupported join type", From fb0671f1b6e16107ce7a970ea02ece48123bf48a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 16 Aug 2015 15:01:21 -0700 Subject: [PATCH 60/67] Move RandomDataFrameGenerator to own file. --- .../sql/fuzzing/DataFrameFuzzingSuite.scala | 71 -------------- .../fuzzing/RandomDataFrameGenerator.scala | 94 +++++++++++++++++++ 2 files changed, 94 insertions(+), 71 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala index 5e39d0f3eaaeb..b287bcaec1241 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.fuzzing import java.lang.reflect.InvocationTargetException -import java.util.concurrent.atomic.AtomicInteger import scala.reflect.runtime.{universe => ru} import scala.util.Random @@ -31,76 +30,6 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class RandomDataFrameGenerator(seed: Long, sqlContext: SQLContext) { - - private val rand = new Random(seed) - private val nextId = new AtomicInteger() - - private def hasRandomDataGenerator(dataType: DataType): Boolean = { - RandomDataGenerator.forType(dataType).isDefined - } - - def randomChoice[T](values: Seq[T]): T = { - values(rand.nextInt(values.length)) - } - - private val simpleTypes: Set[DataType] = { - DataTypeTestUtils.atomicTypes - .filter(hasRandomDataGenerator) - // Ignore decimal type since it can lead to OOM (see SPARK-9303). TODO: It would be better to - // only generate limited precision decimals instead. - .filterNot(_.isInstanceOf[DecimalType]) - } - - private val arrayTypes: Set[DataType] = { - DataTypeTestUtils.atomicArrayTypes - .filter(hasRandomDataGenerator) - // See above comment about DecimalType - .filterNot(_.elementType.isInstanceOf[DecimalType]).toSet - } - - private def randomStructField( - allowComplexTypes: Boolean = false, - allowSpacesInColumnName: Boolean = false): StructField = { - val name = "c" + nextId.getAndIncrement + (if (allowSpacesInColumnName) " space" else "") - val candidateTypes: Seq[DataType] = Seq( - simpleTypes, - arrayTypes.filter(_ => allowComplexTypes), - // This does not allow complex types, limiting the depth of recursion: - if (allowComplexTypes) { - Set[DataType](randomStructType(numCols = rand.nextInt(2) + 1)) - } else { - Set[DataType]() - } - ).flatten - val dataType = randomChoice(candidateTypes) - val nullable = rand.nextBoolean() - StructField(name, dataType, nullable) - } - - private def randomStructType( - numCols: Int, - allowComplexTypes: Boolean = false, - allowSpacesInColumnNames: Boolean = false): StructType = { - StructType(Array.fill(numCols)(randomStructField(allowComplexTypes, allowSpacesInColumnNames))) - } - - def randomDataFrame( - numCols: Int, - numRows: Int, - allowComplexTypes: Boolean = false, - allowSpacesInColumnNames: Boolean = false): DataFrame = { - val schema = randomStructType(numCols, allowComplexTypes, allowSpacesInColumnNames) - val rows = sqlContext.sparkContext.parallelize(1 to numRows).mapPartitions { iter => - val rowGenerator = RandomDataGenerator.forType(schema, nullable = false, seed = Some(42)).get - iter.map(_ => rowGenerator().asInstanceOf[Row]) - } - sqlContext.createDataFrame(rows, schema) - } - -} - - /** * This test suite generates random data frames, then applies random sequences of operations to * them in order to construct random queries. We don't have a source of truth for these random diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala new file mode 100644 index 0000000000000..b0c3894336c5c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.fuzzing + +import java.util.concurrent.atomic.AtomicInteger + +import scala.util.Random + +import org.apache.spark.sql.{Row, DataFrame, RandomDataGenerator, SQLContext} +import org.apache.spark.sql.types._ + +class RandomDataFrameGenerator(seed: Long, sqlContext: SQLContext) { + + private val rand = new Random(seed) + private val nextId = new AtomicInteger() + + private def hasRandomDataGenerator(dataType: DataType): Boolean = { + RandomDataGenerator.forType(dataType).isDefined + } + + def randomChoice[T](values: Seq[T]): T = { + values(rand.nextInt(values.length)) + } + + private val simpleTypes: Set[DataType] = { + DataTypeTestUtils.atomicTypes + .filter(hasRandomDataGenerator) + // Ignore decimal type since it can lead to OOM (see SPARK-9303). TODO: It would be better to + // only generate limited precision decimals instead. + .filterNot(_.isInstanceOf[DecimalType]) + } + + private val arrayTypes: Set[DataType] = { + DataTypeTestUtils.atomicArrayTypes + .filter(hasRandomDataGenerator) + // See above comment about DecimalType + .filterNot(_.elementType.isInstanceOf[DecimalType]).toSet + } + + private def randomStructField( + allowComplexTypes: Boolean = false, + allowSpacesInColumnName: Boolean = false): StructField = { + val name = "c" + nextId.getAndIncrement + (if (allowSpacesInColumnName) " space" else "") + val candidateTypes: Seq[DataType] = Seq( + simpleTypes, + arrayTypes.filter(_ => allowComplexTypes), + // This does not allow complex types, limiting the depth of recursion: + if (allowComplexTypes) { + Set[DataType](randomStructType(numCols = rand.nextInt(2) + 1)) + } else { + Set[DataType]() + } + ).flatten + val dataType = randomChoice(candidateTypes) + val nullable = rand.nextBoolean() + StructField(name, dataType, nullable) + } + + private def randomStructType( + numCols: Int, + allowComplexTypes: Boolean = false, + allowSpacesInColumnNames: Boolean = false): StructType = { + StructType(Array.fill(numCols)(randomStructField(allowComplexTypes, allowSpacesInColumnNames))) + } + + def randomDataFrame( + numCols: Int, + numRows: Int, + allowComplexTypes: Boolean = false, + allowSpacesInColumnNames: Boolean = false): DataFrame = { + val schema = randomStructType(numCols, allowComplexTypes, allowSpacesInColumnNames) + val rows = sqlContext.sparkContext.parallelize(1 to numRows).mapPartitions { iter => + val rowGenerator = RandomDataGenerator.forType(schema, nullable = false, seed = Some(42)).get + iter.map(_ => rowGenerator().asInstanceOf[Row]) + } + sqlContext.createDataFrame(rows, schema) + } + +} \ No newline at end of file From 78a71afc023a2ccd05026637bde7eda3fef70dcb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 16 Aug 2015 17:23:43 -0700 Subject: [PATCH 61/67] WIP --- .../sql/fuzzing/DataFrameFuzzingSuite.scala | 51 ++++++++++++------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala index b287bcaec1241..c484a77810d63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala @@ -30,6 +30,29 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.types._ import org.apache.spark.util.Utils + +trait DataFrameTransformation extends Function[DataFrame, DataFrame] { + +} + +case class CallTransform( + method: ru.MethodSymbol, + args: Seq[Any])( + implicit runtimeMirror: ru.Mirror) extends DataFrameTransformation { + override def apply(df: DataFrame): DataFrame = { + val reflectedMethod: ru.MethodMirror = runtimeMirror.reflect(df).reflectMethod(method) + try { + println(s" Applying method $reflectedMethod with args $args") + val x = reflectedMethod.apply(args: _*).asInstanceOf[DataFrame] + println(s" Applied method $reflectedMethod with args $args") + x + } catch { + case e: InvocationTargetException => throw e.getCause + } + } +} + + /** * This test suite generates random data frames, then applies random sequences of operations to * them in order to construct random queries. We don't have a source of truth for these random @@ -53,7 +76,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite with SharedSparkContext { values(Random.nextInt(values.length)) } - val m = ru.runtimeMirror(this.getClass.getClassLoader) + implicit val m: ru.Mirror = ru.runtimeMirror(this.getClass.getClassLoader) val whitelistedParameterTypes = Set( m.universe.typeOf[DataFrame], @@ -149,38 +172,27 @@ class DataFrameFuzzingSuite extends SparkFunSuite with SharedSparkContext { } def applyRandomTransformationToDataFrame(df: DataFrame): DataFrame = { - val method = randomChoice(dataFrameTransformations) - val reflectedMethod: ru.MethodMirror = m.reflect(df).reflectMethod(method) - def callMethod(paramValues: Seq[Any]): DataFrame = { - try { - val df2 = reflectedMethod.apply(paramValues: _*).asInstanceOf[DataFrame] - println(s"Applied method $method with values $paramValues") - df2 - } catch { - case e: InvocationTargetException => - throw e.getCause - } - } + val method: ru.MethodSymbol = randomChoice(dataFrameTransformations) try { - val paramValues = getParamValues(df, method) try { - callMethod(paramValues) + CallTransform(method, getParamValues(df, method)).apply(df) } catch { case NonFatal(e) => - println(s"Encountered error when calling $method with values $paramValues") println(df.queryExecution) throw e } } catch { case e: AnalysisException if e.getMessage.contains("is not a boolean") => - callMethod(getParamValues(df, method, _ == BooleanType)) + CallTransform(method, getParamValues(df, method, _ == BooleanType)).apply(df) case e: AnalysisException if e.getMessage.contains("is not supported for columns of type") => - callMethod(getParamValues(df, method, _.isInstanceOf[AtomicType])) + CallTransform(method, getParamValues(df, method, _.isInstanceOf[AtomicType])).apply(df) } } def tryToExecute(df: DataFrame): DataFrame = { try { + println("Before executing:") + df.explain(true) df.rdd.count() df } catch { @@ -193,6 +205,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite with SharedSparkContext { // TODO: make these regexes. val ignoredAnalysisExceptionMessages = Seq( // TODO: filter only for binary type: + "cannot sort data type array<", "cannot be used in grouping expression", "cannot be used in join condition", "can only be performed on tables with the same number of columns", @@ -215,7 +228,7 @@ class DataFrameFuzzingSuite extends SparkFunSuite with SharedSparkContext { val df = dataGenerator.randomDataFrame( numCols = Random.nextInt(2) + 1, numRows = 20, - allowComplexTypes = false) + allowComplexTypes = true) val df1 = tryToExecute(applyRandomTransformationToDataFrame(df)) val df2 = tryToExecute(applyRandomTransformationToDataFrame(df1)) } catch { From 3b0684995d7cc783dadd4b95f04a0a1502768f72 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 16 Aug 2015 17:52:56 -0700 Subject: [PATCH 62/67] Filter failing BinaryType array test. --- .../org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala index b0c3894336c5c..9df1df9fe259e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala @@ -48,6 +48,8 @@ class RandomDataFrameGenerator(seed: Long, sqlContext: SQLContext) { private val arrayTypes: Set[DataType] = { DataTypeTestUtils.atomicArrayTypes .filter(hasRandomDataGenerator) + // Filter until SPARK-10038 is fixed. + .filterNot(_.elementType.isInstanceOf[BinaryType]) // See above comment about DecimalType .filterNot(_.elementType.isInstanceOf[DecimalType]).toSet } From a4c9b3398bacfc1a76e66d62e3e277df61d1f228 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 24 Aug 2015 13:31:12 -0700 Subject: [PATCH 63/67] WIP --- pom.xml | 6 + sql/core/pom.xml | 5 + .../sql/fuzzing/DataFrameFuzzingSuite.scala | 174 ++++-------------- .../sql/fuzzing/ExpressionFuzzingSuite.scala | 6 +- .../apache/spark/sql/fuzzing/package.scala | 27 +++ .../spark/sql/fuzzing/reflectiveFuzzing.scala | 152 +++++++++++++++ 6 files changed, 233 insertions(+), 137 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/fuzzing/package.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/fuzzing/reflectiveFuzzing.scala diff --git a/pom.xml b/pom.xml index d5945f2546d38..a3fb2770bfa2e 100644 --- a/pom.xml +++ b/pom.xml @@ -721,6 +721,12 @@ scalap ${scala.version} + + org.scalaz + scalaz-core_2.10 + 7.1.3 + test + org.scalatest scalatest_${scala.binary.version} diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 1d59c14cfbdf5..2fca4d13dd4f8 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -79,6 +79,11 @@ jackson-databind ${fasterxml.jackson.version} + + org.scalaz + scalaz-core_${scala.binary.version} + test + junit junit diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala index c484a77810d63..954801c732b35 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala @@ -17,94 +17,21 @@ package org.apache.spark.sql.fuzzing -import java.lang.reflect.InvocationTargetException - -import scala.reflect.runtime.{universe => ru} import scala.util.Random import scala.util.control.NonFatal import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.types._ import org.apache.spark.util.Utils - -trait DataFrameTransformation extends Function[DataFrame, DataFrame] { - -} - -case class CallTransform( - method: ru.MethodSymbol, - args: Seq[Any])( - implicit runtimeMirror: ru.Mirror) extends DataFrameTransformation { - override def apply(df: DataFrame): DataFrame = { - val reflectedMethod: ru.MethodMirror = runtimeMirror.reflect(df).reflectMethod(method) - try { - println(s" Applying method $reflectedMethod with args $args") - val x = reflectedMethod.apply(args: _*).asInstanceOf[DataFrame] - println(s" Applied method $reflectedMethod with args $args") - x - } catch { - case e: InvocationTargetException => throw e.getCause - } - } -} - - -/** - * This test suite generates random data frames, then applies random sequences of operations to - * them in order to construct random queries. We don't have a source of truth for these random - * queries but nevertheless they are still useful for testing that we don't crash in bad ways. - */ -class DataFrameFuzzingSuite extends SparkFunSuite with SharedSparkContext { - - val tempDir = Utils.createTempDir() - - private var sqlContext: SQLContext = _ - private var dataGenerator: RandomDataFrameGenerator = _ - - override def beforeAll(): Unit = { - super.beforeAll() - sqlContext = new SQLContext(sc) - dataGenerator = new RandomDataFrameGenerator(123, sqlContext) - sqlContext.conf.setConf(SQLConf.SHUFFLE_PARTITIONS, 10) - } +object DataFrameFuzzingUtils { def randomChoice[T](values: Seq[T]): T = { values(Random.nextInt(values.length)) } - implicit val m: ru.Mirror = ru.runtimeMirror(this.getClass.getClassLoader) - - val whitelistedParameterTypes = Set( - m.universe.typeOf[DataFrame], - m.universe.typeOf[Seq[Column]], - m.universe.typeOf[Column], - m.universe.typeOf[String], - m.universe.typeOf[Seq[String]] - ) - - val dataFrameTransformations: Seq[ru.MethodSymbol] = { - val dfType = m.universe.typeOf[DataFrame] - dfType.members - .filter(_.isPublic) - .filter(_.isMethod) - .map(_.asMethod) - .filter(_.returnType =:= dfType) - .filterNot(_.isConstructor) - .filter { m => - m.paramss.flatten.forall { p => - whitelistedParameterTypes.exists { t => p.typeSignature <:< t } - } - } - .filterNot(_.name.toString == "drop") // since this can lead to a DataFrame with no columns - .filterNot(_.name.toString == "describe") // since we cannot run all queries on describe output - .filterNot(_.name.toString == "dropDuplicates") - .toSeq - } - /** * Build a list of column names and types for the given StructType, taking nesting into account. * For nested struct fields, this will emit both the column for the struct field itself as well as @@ -135,70 +62,36 @@ class DataFrameFuzzingSuite extends SparkFunSuite with SharedSparkContext { Some(randomChoice(candidateColumns)._1) } } +} - class NoDataGeneratorException extends Exception - def getParamValues( - df: DataFrame, - method: ru.MethodSymbol, - typeConstraint: DataType => Boolean = _ => true): Seq[Any] = { - val params = method.paramss.flatten // We don't use multiple parameter lists - def randColName(): String = - getRandomColumnName(df, typeConstraint).getOrElse(throw new NoDataGeneratorException) - params.map { p => - val t = p.typeSignature - if (t =:= ru.typeOf[DataFrame]) { - randomChoice(Seq( - df, - //tryToExecute(applyRandomTransformationToDataFrame(df)), - dataGenerator.randomDataFrame(numCols = Random.nextInt(4) + 1, numRows = 100) - )) // ++ Try(applyRandomTransformationToDataFrame(df)).toOption.toSeq) - } else if (t =:= ru.typeOf[Column]) { - df.col(randColName()) - } else if (t =:= ru.typeOf[String]) { - if (p.name == "joinType") { - randomChoice(JoinType.supportedJoinTypes) - } else { - randColName() - } - } else if (t <:< ru.typeOf[Seq[Column]]) { - Seq.fill(Random.nextInt(2) + 1)(df.col(randColName())) - } else if (t <:< ru.typeOf[Seq[String]]) { - Seq.fill(Random.nextInt(2) + 1)(randColName()) - } else { - sys.error("ERROR!") - } - } - } +/** + * This test suite generates random data frames, then applies random sequences of operations to + * them in order to construct random queries. We don't have a source of truth for these random + * queries but nevertheless they are still useful for testing that we don't crash in bad ways. + */ +class DataFrameFuzzingSuite extends SparkFunSuite with SharedSparkContext { - def applyRandomTransformationToDataFrame(df: DataFrame): DataFrame = { - val method: ru.MethodSymbol = randomChoice(dataFrameTransformations) - try { - try { - CallTransform(method, getParamValues(df, method)).apply(df) - } catch { - case NonFatal(e) => - println(df.queryExecution) - throw e - } - } catch { - case e: AnalysisException if e.getMessage.contains("is not a boolean") => - CallTransform(method, getParamValues(df, method, _ == BooleanType)).apply(df) - case e: AnalysisException if e.getMessage.contains("is not supported for columns of type") => - CallTransform(method, getParamValues(df, method, _.isInstanceOf[AtomicType])).apply(df) - } + val tempDir = Utils.createTempDir() + + private var sqlContext: SQLContext = _ + private var dataGenerator: RandomDataFrameGenerator = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + dataGenerator = new RandomDataFrameGenerator(123, sqlContext) + sqlContext.conf.setConf(SQLConf.SHUFFLE_PARTITIONS, 10) } def tryToExecute(df: DataFrame): DataFrame = { try { - println("Before executing:") - df.explain(true) df.rdd.count() df } catch { case NonFatal(e) => println(df.queryExecution) - throw new Exception(e) + throw e } } @@ -220,26 +113,37 @@ class DataFrameFuzzingSuite extends SparkFunSuite with SharedSparkContext { "Cannot resolve column name" // TODO: only ignore for join? ) + def getRandomTransformation(df: DataFrame): DataFrameTransformation = { + (1 to 1000).iterator.map(_ => ReflectiveFuzzing.getTransformation(df)).flatten.next() + } + + def applyRandomTransform(df: DataFrame): DataFrame = { + val tf = getRandomTransformation(df) + println(" " + tf) + tf.apply(df) + } test("fuzz test") { - for (_ <- 1 to 1000) { - println("-" * 80) + for (i <- 1 to 1000) { + println(s"Iteration $i") try { - val df = dataGenerator.randomDataFrame( + var df = dataGenerator.randomDataFrame( numCols = Random.nextInt(2) + 1, numRows = 20, allowComplexTypes = true) - val df1 = tryToExecute(applyRandomTransformationToDataFrame(df)) - val df2 = tryToExecute(applyRandomTransformationToDataFrame(df1)) + var depth = 3 + while (depth > 0) { + df = tryToExecute(applyRandomTransform(df)) + depth -= 1 + } } catch { - case e: NoDataGeneratorException => - println("skipped due to lack of data generator") case e: UnresolvedException[_] => - println("skipped due to unresolved") +// println("skipped due to unresolved") case e: Exception if ignoredAnalysisExceptionMessages.exists { m => Option(e.getMessage).getOrElse("").toLowerCase.contains(m.toLowerCase) - } => println("Skipped due to expected AnalysisException") + } => +// println("Skipped due to expected AnalysisException " + e) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala index 4cf2ebbe268f3..30fae07fd633f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala @@ -20,6 +20,10 @@ package org.apache.spark.sql.fuzzing import java.io.File import java.lang.reflect.Constructor +import scala.util.{Random, Try} + +import org.clapper.classutil.ClassFinder + import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion @@ -29,9 +33,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types.{BinaryType, DataType, DataTypeTestUtils, DecimalType} import org.apache.spark.{Logging, SparkFunSuite} -import org.clapper.classutil.ClassFinder -import scala.util.{Random, Try} /** * This test suite implements fuzz tests for expression code generation. It uses reflection to diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/package.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/package.scala new file mode 100644 index 0000000000000..187b41e05950c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/package.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +/** + * TODO(josh): Document this package. + */ +package object fuzzing { + +} + +trait DataFrameTransformation extends Function[DataFrame, DataFrame] \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/reflectiveFuzzing.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/reflectiveFuzzing.scala new file mode 100644 index 0000000000000..9436b119f80e0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/reflectiveFuzzing.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.fuzzing + +import java.lang.reflect.InvocationTargetException + +import scala.reflect.runtime.{universe => ru} +import scala.util.{Try, Random} +import scala.util.control.NonFatal + +import scalaz._, Scalaz._ + +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.types.{BooleanType, AtomicType, DataType} +import org.apache.spark.sql.{DataFrameTransformation, AnalysisException, Column, DataFrame} + +object ReflectiveFuzzing { + + import DataFrameFuzzingUtils._ + + private implicit val m: ru.Mirror = ru.runtimeMirror(this.getClass.getClassLoader) + + /** + * Method parameter types for which the fuzzer can supply random values. This list is used to + * filter out methods that we don't know how to call. + */ + private val whitelistedParameterTypes = Set( + m.universe.typeOf[DataFrame], + m.universe.typeOf[Seq[Column]], + m.universe.typeOf[Column], + m.universe.typeOf[String], + m.universe.typeOf[Seq[String]] + ) + + /** + * A list of candidate DataFrame methods that the fuzzer will try to call. Excludes private + * methods and methods with parameters that we don't know how to supply. + */ + private val dataFrameTransformations: Seq[ru.MethodSymbol] = { + val dfType = m.universe.typeOf[DataFrame] + dfType.members + .filter(_.isPublic) + .filter(_.isMethod) + .map(_.asMethod) + .filter(_.returnType =:= dfType) + .filterNot(_.isConstructor) + .filter { m => + m.paramss.flatten.forall { p => + whitelistedParameterTypes.exists { t => p.typeSignature <:< t } + } + } + .filterNot(_.name.toString == "drop") // since this can lead to a DataFrame with no columns + .filterNot(_.name.toString == "describe") // since we cannot run all queries on describe output + .filterNot(_.name.toString == "dropDuplicates") + .filterNot(_.name.toString == "toDF") // since this is effectively a no-op + .filterNot(_.name.toString == "toSchemaRDD") // since this is effectively a no-op + .toSeq + } + + /** + * Given a Dataframe and a method, try to choose a set of arguments to call that method with. + * + * @param df the data frame to transform + * @param method the method to call + * @param typeConstraint an optional type constraint governing the types of the parameters. + * @return + */ + def getParamValues( + df: DataFrame, + method: ru.MethodSymbol, + typeConstraint: DataType => Boolean = _ => true): Option[List[Any]] = { + val params = method.paramss.flatten // We don't use multiple parameter lists + val maybeValues: List[Option[Any]] = params.map { p => + val t = p.typeSignature + if (t =:= ru.typeOf[DataFrame]) { + randomChoice( + df :: + // TODO(josh): restore ability to generate new random DataFrames for use in joins. + // dataGenerator.randomDataFrame(numCols = Random.nextInt(4) + 1, numRows = 100) :: + Nil + ).some + } else if (t =:= ru.typeOf[Column]) { + getRandomColumnName(df, typeConstraint).map(df.col) + } else if (t =:= ru.typeOf[String]) { + if (p.name == "joinType") { + randomChoice(JoinType.supportedJoinTypes).some + } else { + getRandomColumnName(df, typeConstraint).map(df.col) + } + } else if (t <:< ru.typeOf[Seq[Column]]) { + Seq.fill(Random.nextInt(2) + 1)(getRandomColumnName(df, typeConstraint).map(df.col)).flatten.some + } else if (t <:< ru.typeOf[Seq[String]]) { + Seq.fill(Random.nextInt(2) + 1)(getRandomColumnName(df, typeConstraint).map(df.col)).flatten.some + } else { + None + } + } + maybeValues.sequence + } + + def getTransformation(df: DataFrame): Option[DataFrameTransformation] = { + val method: ru.MethodSymbol = DataFrameFuzzingUtils.randomChoice(dataFrameTransformations) + val values: Option[Seq[Any]] = { + def validateValues(vs: Seq[Any]): Try[Seq[Any]] = { + Try(CallTransformReflectively(method, vs).apply(df)).map(_ => vs) + } + getParamValues(df, method).map { (vs: Seq[Any]) => + validateValues(vs).recoverWith { + case e: AnalysisException if e.getMessage.contains("is not a boolean") => + Try(getParamValues(df, method, _ == BooleanType).get).flatMap(validateValues) + case e: AnalysisException if e.getMessage.contains("is not supported for columns of type") => + Try(getParamValues(df, method, _.isInstanceOf[AtomicType]).get).flatMap(validateValues) + } + }.flatMap(_.toOption) + } + values.map(vs => CallTransformReflectively(method, vs)) + } +} + +case class CallTransformReflectively( + method: ru.MethodSymbol, + args: Seq[Any])( + implicit runtimeMirror: ru.Mirror) extends DataFrameTransformation { + + override def apply(df: DataFrame): DataFrame = { + val reflectedMethod: ru.MethodMirror = runtimeMirror.reflect(df).reflectMethod(method) + try { + reflectedMethod.apply(args: _*).asInstanceOf[DataFrame] + } catch { + case e: InvocationTargetException => throw e.getCause + } + } + + override def toString(): String = { + s"${method.name}(${args.map(_.toString).mkString(", ")})" + } +} \ No newline at end of file From ae5055ac2349ec880215ace9a8c9e8f349177d90 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 27 May 2016 12:20:38 -0700 Subject: [PATCH 64/67] Also ignore BRound expression. --- .../org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala index 3d87054914c0d..d07747fbafe0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala @@ -67,6 +67,7 @@ class ExpressionFuzzingSuite extends SparkFunSuite with Logging { .filterNot(_ == classOf[StringSpace]) .filterNot(_ == classOf[StringLPad]) .filterNot(_ == classOf[StringRPad]) + .filterNot(_ == classOf[BRound]) .filterNot(_ == classOf[Round]) } From dfdab5e51af3b509d140d60afe82ce55edcddd44 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 27 May 2016 12:25:06 -0700 Subject: [PATCH 65/67] Fix serializability. --- .../apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala index 4ff96db141e44..ef5f5e6cdb286 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala @@ -24,7 +24,10 @@ import scala.util.Random import org.apache.spark.sql._ import org.apache.spark.sql.types._ -class RandomDataFrameGenerator(seed: Long, sqlContext: SQLContext) { +class RandomDataFrameGenerator( + seed: Long, + @transient val sqlContext: SQLContext) + extends Serializable { private val rand = new Random(seed) private val nextId = new AtomicInteger() From e60f23125ae40be17655a8af9cabdaec4dbd04a3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 29 Jul 2016 14:30:07 -0700 Subject: [PATCH 66/67] More input type validation. --- .../expressions/complexTypeExtractors.scala | 17 ++++++-------- .../expressions/decimalExpressions.scala | 23 +++++++++++++------ .../sql/fuzzing/ExpressionFuzzingSuite.scala | 6 +++++ 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 3b4468f55ca73..0562d38d51d0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -98,13 +98,13 @@ trait ExtractValue extends Expression /** * Returns the value of fields in the Struct `child`. * - * No need to do type checking since it is handled by [[ExtractValue]]. - * * Note that we can pass in the field name directly to keep case preserving in `toString`. * For example, when get field `yEAr` from ``, we should pass in `yEAr`. */ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) - extends UnaryExpression with ExtractValue { + extends UnaryExpression with ExtractValue with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(StructType) private[sql] lazy val childSchema = child.dataType.asInstanceOf[StructType] @@ -144,16 +144,15 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] /** * For a child whose data type is an array of structs, extracts the `ordinal`-th fields of all array * elements, and returns them as a new array. - * - * No need to do type checking since it is handled by [[ExtractValue]]. */ case class GetArrayStructFields( child: Expression, field: StructField, ordinal: Int, numFields: Int, - containsNull: Boolean) extends UnaryExpression with ExtractValue { + containsNull: Boolean) extends UnaryExpression with ExtractValue with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) override def dataType: DataType = ArrayType(field.dataType, containsNull) override def toString: String = s"$child.${field.name}" override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}" @@ -215,8 +214,7 @@ case class GetArrayStructFields( case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryExpression with ExpectsInputTypes with ExtractValue { - // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegralType) override def toString: String = s"$child[$ordinal]" override def sql: String = s"${child.sql}[${ordinal.sql}]" @@ -264,8 +262,7 @@ case class GetMapValue(child: Expression, key: Expression) private def keyType = child.dataType.asInstanceOf[MapType].keyType - // We have done type checking for child in `ExtractValue`, so only need to check the `key`. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) + override def inputTypes: Seq[AbstractDataType] = Seq(MapType, keyType) override def toString: String = s"$child[$key]" override def sql: String = s"${child.sql}[${key.sql}]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index fa5dea6841149..30dce130dc094 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -23,11 +23,11 @@ import org.apache.spark.sql.types._ /** * Return the unscaled Long value of a Decimal, assuming it fits in a Long. - * Note: this expression is internal and created only by the optimizer, - * we don't need to do type check for it. + * Note: this expression is internal and created only by the optimizer. */ -case class UnscaledValue(child: Expression) extends UnaryExpression { +case class UnscaledValue(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType) override def dataType: DataType = LongType override def toString: String = s"UnscaledValue($child)" @@ -41,11 +41,15 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { /** * Create a Decimal from an unscaled Long value. - * Note: this expression is internal and created only by the optimizer, - * we don't need to do type check for it. + * Note: this expression is internal and created only by the optimizer. */ -case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { +case class MakeDecimal( + child: Expression, + precision: Int, + scale: Int) + extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(LongType) override def dataType: DataType = DecimalType(precision, scale) override def nullable: Boolean = true override def toString: String = s"MakeDecimal($child,$precision,$scale)" @@ -80,7 +84,12 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { * Rounds the decimal to given scale and check whether the decimal can fit in provided precision * or not, returns null if not. */ -case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression { +case class CheckOverflow( + child: Expression, + dataType: DecimalType) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType) override def nullable: Boolean = true diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala index d07747fbafe0d..99bf7bc5bc038 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala @@ -121,6 +121,7 @@ class ExpressionFuzzingSuite extends SparkFunSuite with Logging { } logInfo(s"After type coercion, expression is $expression") // Make sure that the resulting expression passes type checks. + require(expression.childrenResolved) val typecheckResult = expression.checkInputDataTypes() if (typecheckResult.isFailure) { logDebug(s"Type checks failed: $typecheckResult") @@ -134,6 +135,11 @@ class ExpressionFuzzingSuite extends SparkFunSuite with Logging { val maybeGenProjection = Try(GenerateSafeProjection.generate(Seq(expression), inputSchema)) + if (maybeGenProjection.isFailure) { + //scalastyle:off + println( + s"Code generation for expression $expression failed with inputSchema $inputSchema") + } maybeGenProjection.foreach { generatedProjection => val generatedResult = generatedProjection.apply(inputRow) assert(generatedResult === interpretedResult) From 94087cb170df89578bc2d6fb254e43b86bdcc426 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 29 Jul 2016 15:49:54 -0700 Subject: [PATCH 67/67] Updates for DataSet API. --- .../sql/fuzzing/DataFrameFuzzingSuite.scala | 77 ++++++++++++++----- .../spark/sql/fuzzing/reflectiveFuzzing.scala | 7 +- 2 files changed, 60 insertions(+), 24 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala index 0529484b2af17..e049b2a2c5e83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala @@ -23,6 +23,7 @@ import scala.util.control.NonFatal import org.apache.spark.{SharedSparkContext, SparkFunSuite} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -71,7 +72,10 @@ object DataFrameFuzzingUtils { * them in order to construct random queries. We don't have a source of truth for these random * queries but nevertheless they are still useful for testing that we don't crash in bad ways. */ -class DataFrameFuzzingSuite extends SparkFunSuite with SharedSparkContext { +class DataFrameFuzzingSuite extends QueryTest with SharedSparkContext { + + + override protected def spark: SparkSession = sqlContext.sparkSession val tempDir = Utils.createTempDir() @@ -128,30 +132,61 @@ class DataFrameFuzzingSuite extends SparkFunSuite with SharedSparkContext { tf.apply(df) } + def resetConfs(): Unit = { + sqlContext.conf.getAllDefinedConfs.foreach { case (key, defaultValue, doc) => + sqlContext.conf.setConfString(key, defaultValue) + } + sqlContext.conf.setConfString("spark.sql.crossJoin.enabled", "true") + sqlContext.conf.setConfString("spark.sql.autoBroadcastJoinThreshold", "-1") + } + + private val configurations = Seq( + "default" -> Seq(), + "no optimization" -> Seq(SQLConf.OPTIMIZER_MAX_ITERATIONS.key -> "0"), + "disable-wholestage-codegen" -> Seq(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false"), + "disable-exchange-reuse" -> Seq(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") + ) + + def replan(df: DataFrame): DataFrame = { + new Dataset[Row](sqlContext.sparkSession, df.logicalPlan, RowEncoder(df.schema)) + } + test("fuzz test") { - for (i <- 1 to 1000) { - // scalastyle:off println - println(s"Iteration $i") - // scalastyle:on println - try { - var df = dataGenerator.randomDataFrame( - numCols = Random.nextInt(2) + 1, - numRows = 20, - allowComplexTypes = true) - var depth = 3 - while (depth > 0) { - df = tryToExecute(applyRandomTransform(df)) - depth -= 1 + for (i <- 1 to 1000) { + // scalastyle:off println + println(s"Iteration $i") + // scalastyle:on println + try { + resetConfs() + var df = dataGenerator.randomDataFrame( + numCols = Random.nextInt(2) + 1, + numRows = 20, + allowComplexTypes = false) + var depth = 3 + while (depth > 0) { + df = tryToExecute(applyRandomTransform(df)) + depth -= 1 + } + val defaultResult = replan(df).collect() + configurations.foreach { case (confName, confsToSet) => + resetConfs() + withClue(s"configuration = $confName") { + confsToSet.foreach { case (key, value) => + sqlContext.conf.setConfString(key, value) + } + checkAnswer(replan(df), defaultResult) } - } catch { - case e: UnresolvedException[_] => + } + println(s"Finished all tests successfully for plan:\n${df.logicalPlan}") + } catch { + case e: UnresolvedException[_] => // println("skipped due to unresolved") - case e: Exception - if ignoredAnalysisExceptionMessages.exists { - m => Option(e.getMessage).getOrElse("").toLowerCase.contains(m.toLowerCase) - } => + case e: Exception + if ignoredAnalysisExceptionMessages.exists { + m => Option(e.getMessage).getOrElse("").toLowerCase.contains(m.toLowerCase) + } => // println("Skipped due to expected AnalysisException " + e) - } } } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/reflectiveFuzzing.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/reflectiveFuzzing.scala index 959291b32eb6b..e22e9a824e7bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/reflectiveFuzzing.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/reflectiveFuzzing.scala @@ -40,6 +40,7 @@ object ReflectiveFuzzing { */ private val whitelistedParameterTypes = Set( m.universe.typeOf[DataFrame], + m.universe.typeOf[Dataset[_]], m.universe.typeOf[Seq[Column]], m.universe.typeOf[Column], m.universe.typeOf[String], @@ -51,12 +52,12 @@ object ReflectiveFuzzing { * methods and methods with parameters that we don't know how to supply. */ private val dataFrameTransformations: Seq[ru.MethodSymbol] = { - val dfType = m.universe.typeOf[DataFrame] + val dfType = m.universe.typeOf[Dataset[_]] dfType.members .filter(_.isPublic) .filter(_.isMethod) .map(_.asMethod) - .filter(_.returnType =:= dfType) + .filter(_.returnType <:< dfType) .filterNot(_.isConstructor) .filter { m => m.paramss.flatten.forall { p => @@ -86,7 +87,7 @@ object ReflectiveFuzzing { val params = method.paramss.flatten // We don't use multiple parameter lists val maybeValues: List[Option[Any]] = params.map { p => val t = p.typeSignature - if (t =:= ru.typeOf[DataFrame]) { + if (t <:< ru.typeOf[Dataset[_]]) { randomChoice( df :: // TODO(josh): restore ability to generate new random DataFrames for use in joins.