diff --git a/pom.xml b/pom.xml
index 989658216e5fd..c86c90639e174 100644
--- a/pom.xml
+++ b/pom.xml
@@ -699,6 +699,12 @@
        <artifactId>scalap</artifactId>
        <version>${scala.version}</version>
      </dependency>
+      <dependency>
+        <groupId>org.scalaz</groupId>
+        <artifactId>scalaz-core_2.10</artifactId>
+        <version>7.1.3</version>
+        <scope>test</scope>
+      </dependency>
      <dependency>
        <groupId>org.scalatest</groupId>
        <artifactId>scalatest_${scala.binary.version}</artifactId>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 41b7e62d8ccea..daf99f5f0f82b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -245,6 +245,12 @@ trait CheckAnalysis extends PredicateHelper {
aggregateExprs.foreach(checkValidAggregateExpression)
groupingExprs.foreach(checkValidGroupingExprs)
+ case s @ SetOperation(left, right) if left.output.length != right.output.length =>
+ failAnalysis(
+ s"${s.nodeName} can only be performed on tables with the same number of columns, " +
+ s"but the left table has ${left.output.length} columns and the right has " +
+          s"${right.output.length} columns")
+
case Sort(orders, _, _) =>
orders.foreach { order =>
if (!RowOrdering.isOrderable(order.dataType)) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 70fff51956255..8edd9bed44839 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -209,7 +209,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
}
private[this] def decimalToTimestamp(d: Decimal): Long = {
- (d.toBigDecimal * 1000000L).longValue()
+ d.toJavaBigDecimal.multiply(java.math.BigDecimal.valueOf(1000000L)).longValue()
}
private[this] def doubleToTimestamp(d: Double): Any = {
if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index abb5594bfa7f8..c0af7eaf432c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -98,13 +98,13 @@ trait ExtractValue extends Expression
/**
* Returns the value of fields in the Struct `child`.
*
- * No need to do type checking since it is handled by [[ExtractValue]].
- *
 * Note that we can pass in the field name directly so that its original case is preserved in
 * `toString`. For example, when getting the field `yEAr` from ``, we should pass in `yEAr`.
*/
case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
- extends UnaryExpression with ExtractValue {
+ extends UnaryExpression with ExtractValue with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StructType)
lazy val childSchema = child.dataType.asInstanceOf[StructType]
@@ -144,16 +144,15 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String]
/**
* For a child whose data type is an array of structs, extracts the `ordinal`-th fields of all array
* elements, and returns them as a new array.
- *
- * No need to do type checking since it is handled by [[ExtractValue]].
*/
case class GetArrayStructFields(
child: Expression,
field: StructField,
ordinal: Int,
numFields: Int,
- containsNull: Boolean) extends UnaryExpression with ExtractValue {
+ containsNull: Boolean) extends UnaryExpression with ExtractValue with ExpectsInputTypes {
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def toString: String = s"$child.${field.name}"
override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}"
@@ -215,8 +214,7 @@ case class GetArrayStructFields(
case class GetArrayItem(child: Expression, ordinal: Expression)
extends BinaryExpression with ExpectsInputTypes with ExtractValue {
- // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
- override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegralType)
override def toString: String = s"$child[$ordinal]"
override def sql: String = s"${child.sql}[${ordinal.sql}]"
@@ -264,8 +262,7 @@ case class GetMapValue(child: Expression, key: Expression)
private def keyType = child.dataType.asInstanceOf[MapType].keyType
- // We have done type checking for child in `ExtractValue`, so only need to check the `key`.
- override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
+ override def inputTypes: Seq[AbstractDataType] = Seq(MapType, keyType)
override def toString: String = s"$child[$key]"
override def sql: String = s"${child.sql}[${key.sql}]"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
index fa5dea6841149..30dce130dc094 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala
@@ -23,11 +23,11 @@ import org.apache.spark.sql.types._
/**
* Return the unscaled Long value of a Decimal, assuming it fits in a Long.
- * Note: this expression is internal and created only by the optimizer,
- * we don't need to do type check for it.
+ * Note: this expression is internal and created only by the optimizer.
*/
-case class UnscaledValue(child: Expression) extends UnaryExpression {
+case class UnscaledValue(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+ override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType)
override def dataType: DataType = LongType
override def toString: String = s"UnscaledValue($child)"
@@ -41,11 +41,15 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
/**
* Create a Decimal from an unscaled Long value.
- * Note: this expression is internal and created only by the optimizer,
- * we don't need to do type check for it.
+ * Note: this expression is internal and created only by the optimizer.
*/
-case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
+case class MakeDecimal(
+ child: Expression,
+ precision: Int,
+ scale: Int)
+ extends UnaryExpression with ExpectsInputTypes {
+ override def inputTypes: Seq[AbstractDataType] = Seq(LongType)
override def dataType: DataType = DecimalType(precision, scale)
override def nullable: Boolean = true
override def toString: String = s"MakeDecimal($child,$precision,$scale)"
@@ -80,7 +84,12 @@ case class PromotePrecision(child: Expression) extends UnaryExpression {
 * Rounds the decimal to the given scale and checks whether the decimal fits in the provided
 * precision; returns null if it does not.
*/
-case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression {
+case class CheckOverflow(
+ child: Expression,
+ dataType: DecimalType)
+ extends UnaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType)
override def nullable: Boolean = true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index 80674d9b4bc9c..3c35e6e2c7e10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -21,6 +21,15 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Attribute
object JoinType {
+
+ val supportedJoinTypes = Seq(
+ "inner",
+ "outer", "full", "fullouter",
+ "leftouter", "left",
+ "rightouter", "right",
+ "leftsemi",
+ "leftanti")
+
def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match {
case "inner" => Inner
case "outer" | "full" | "fullouter" => FullOuter
@@ -29,16 +38,8 @@ object JoinType {
case "leftsemi" => LeftSemi
case "leftanti" => LeftAnti
case _ =>
- val supported = Seq(
- "inner",
- "outer", "full", "fullouter",
- "leftouter", "left",
- "rightouter", "right",
- "leftsemi",
- "leftanti")
-
throw new IllegalArgumentException(s"Unsupported join type '$typ'. " +
- "Supported join types include: " + supported.mkString("'", "', '", "'") + ".")
+ "Supported join types include: " + supportedJoinTypes.mkString("'", "', '", "'") + ".")
}
}
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index b2752638bebd5..a759774a7b110 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -127,6 +127,18 @@
      <artifactId>xbean-asm5-shaded</artifactId>
      <scope>test</scope>
    </dependency>
+    <dependency>
+      <groupId>org.clapper</groupId>
+      <artifactId>classutil_${scala.binary.version}</artifactId>
+      <version>1.0.6</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalaz</groupId>
+      <artifactId>scalaz-core_${scala.binary.version}</artifactId>
+      <version>7.2.3</version>
+      <scope>test</scope>
+    </dependency>
  </dependencies>
  <build>
    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 547d3c1abe858..165e273e9060c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -150,7 +150,8 @@ class UDFSuite extends QueryTest with SharedSQLContext {
assert(result.count() === 2)
}
- test("UDFs everywhere") {
+ // Temporarily ignored until we implement code generation for ScalaUDF.
+ ignore("UDFs everywhere") {
spark.udf.register("groupFunction", (n: Int) => { n > 10 })
spark.udf.register("havingFilter", (n: Long) => { n > 2000 })
spark.udf.register("whereFilter", (n: Int) => { n < 150 })
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala
new file mode 100644
index 0000000000000..e049b2a2c5e83
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.fuzzing
+
+import scala.util.Random
+import scala.util.control.NonFatal
+
+import org.apache.spark.{SharedSparkContext, SparkFunSuite}
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+object DataFrameFuzzingUtils {
+
+ def randomChoice[T](values: Seq[T]): T = {
+ values(Random.nextInt(values.length))
+ }
+
+ /**
+ * Build a list of column names and types for the given StructType, taking nesting into account.
+   * For a nested struct field, this emits both a column for the struct field itself and columns
+   * for each of the nested struct's fields. The process is applied recursively in order to
+   * handle deeply-nested structs.
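+   * For example (illustrative), given the schema `struct<a: int, s: struct<b: string>>`, this
+   * returns entries for `a`, `s`, and `s.b`.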
+ */
+ def getColumnsAndTypes(struct: StructType): Seq[(String, DataType)] = {
+ struct.flatMap { field =>
+ val nestedFieldInfos: Seq[(String, DataType)] = field.dataType match {
+ case nestedStruct: StructType =>
+          getColumnsAndTypes(nestedStruct).map {
+ case (nestedColName, dataType) => (field.name + "." + nestedColName, dataType)
+ }
+ case _ => Seq.empty
+ }
+ Seq((field.name, field.dataType)) ++ nestedFieldInfos
+ }
+ }
+
+ def getRandomColumnName(
+ df: DataFrame,
+ condition: DataType => Boolean = _ => true): Option[String] = {
+ val columnsWithTypes = getColumnsAndTypes(df.schema)
+ val candidateColumns = columnsWithTypes.filter(c => condition(c._2))
+ if (candidateColumns.isEmpty) {
+ None
+ } else {
+ Some(randomChoice(candidateColumns)._1)
+ }
+ }
+}
+
+
+/**
+ * This test suite generates random data frames, then applies random sequences of operations to
+ * them in order to construct random queries. We don't have a source of truth for these random
+ * queries, but they are still useful for checking that we neither crash in bad ways nor return
+ * results that change when the same plan is re-run under different configurations.
+ */
+class DataFrameFuzzingSuite extends QueryTest with SharedSparkContext {
+
+ override protected def spark: SparkSession = sqlContext.sparkSession
+
+ val tempDir = Utils.createTempDir()
+
+ private var sqlContext: SQLContext = _
+ private var dataGenerator: RandomDataFrameGenerator = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ dataGenerator = new RandomDataFrameGenerator(123, sqlContext)
+ sqlContext.conf.setConf(SQLConf.SHUFFLE_PARTITIONS, 10)
+ }
+
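+  /**
+   * Forces execution of the given DataFrame by counting its rows. If execution fails, the query
+   * execution is printed to aid debugging before the exception is rethrown.
+   */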
+ def tryToExecute(df: DataFrame): DataFrame = {
+ try {
+ df.rdd.count()
+ df
+ } catch {
+ case NonFatal(e) =>
+ // scalastyle:off println
+ println(df.queryExecution)
+ // scalastyle:on println
+ throw e
+ }
+ }
+
+ // TODO: make these regexes.
+ val ignoredAnalysisExceptionMessages = Seq(
+ // TODO: filter only for binary type:
+ "cannot sort data type array<",
+ "cannot be used in grouping expression",
+ "cannot be used in join condition",
+ "can only be performed on tables with the same number of columns",
+ "number of columns doesn't match",
+ "unsupported join type",
+ "is neither present in the group by, nor is it an aggregate function",
+ "is ambiguous, could be:",
+ "unresolved operator 'Project", // TODO
+ "unresolved operator 'Union", // TODO: disabled to let me find new errors
+ "unresolved operator 'Except", // TODO: disabled to let me find new errors
+ "unresolved operator 'Intersect", // TODO: disabled to let me find new errors
+ "Cannot resolve column name" // TODO: only ignore for join?
+ )
+
+ def getRandomTransformation(df: DataFrame): DataFrameTransformation = {
+ (1 to 1000).iterator.map(_ => ReflectiveFuzzing.getTransformation(df)).flatten.next()
+ }
+
+ def applyRandomTransform(df: DataFrame): DataFrame = {
+ val tf = getRandomTransformation(df)
+ // scalastyle:off println
+ println(" " + tf)
+ // scalastyle:on println
+ tf.apply(df)
+ }
+
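+  /**
+   * Resets all SQL confs to their defaults, then re-enables cross joins and disables
+   * auto-broadcast joins so that randomly generated join queries can be executed.
+   */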
+ def resetConfs(): Unit = {
+ sqlContext.conf.getAllDefinedConfs.foreach { case (key, defaultValue, doc) =>
+ sqlContext.conf.setConfString(key, defaultValue)
+ }
+ sqlContext.conf.setConfString("spark.sql.crossJoin.enabled", "true")
+ sqlContext.conf.setConfString("spark.sql.autoBroadcastJoinThreshold", "-1")
+ }
+
+ private val configurations = Seq(
+ "default" -> Seq(),
+ "no optimization" -> Seq(SQLConf.OPTIMIZER_MAX_ITERATIONS.key -> "0"),
+ "disable-wholestage-codegen" -> Seq(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false"),
+ "disable-exchange-reuse" -> Seq(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false")
+ )
+
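+  /**
+   * Rebuilds the Dataset from its logical plan so that it is re-analyzed, re-optimized, and
+   * re-planned under whatever SQL configuration is currently active.
+   */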
+ def replan(df: DataFrame): DataFrame = {
+ new Dataset[Row](sqlContext.sparkSession, df.logicalPlan, RowEncoder(df.schema))
+ }
+
+ test("fuzz test") {
+ for (i <- 1 to 1000) {
+ // scalastyle:off println
+ println(s"Iteration $i")
+ // scalastyle:on println
+ try {
+ resetConfs()
+ var df = dataGenerator.randomDataFrame(
+ numCols = Random.nextInt(2) + 1,
+ numRows = 20,
+ allowComplexTypes = false)
+ var depth = 3
+ while (depth > 0) {
+ df = tryToExecute(applyRandomTransform(df))
+ depth -= 1
+ }
+ val defaultResult = replan(df).collect()
+ configurations.foreach { case (confName, confsToSet) =>
+ resetConfs()
+ withClue(s"configuration = $confName") {
+ confsToSet.foreach { case (key, value) =>
+ sqlContext.conf.setConfString(key, value)
+ }
+ checkAnswer(replan(df), defaultResult)
+ }
+ }
+        // scalastyle:off println
+        println(s"Finished all tests successfully for plan:\n${df.logicalPlan}")
+        // scalastyle:on println
+ } catch {
+        case _: UnresolvedException[_] =>
+          // Skipped: the transformation produced an unresolved expression.
+        case e: Exception
+          if ignoredAnalysisExceptionMessages.exists {
+            m => Option(e.getMessage).getOrElse("").toLowerCase.contains(m.toLowerCase)
+          } =>
+          // Skipped: an expected AnalysisException matching one of the ignored messages.
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala
new file mode 100644
index 0000000000000..99bf7bc5bc038
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.fuzzing
+
+import java.io.File
+import java.lang.reflect.Constructor
+
+import scala.util.{Random, Try}
+
+import org.clapper.classutil.ClassFinder
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.{BinaryType, DataType, DataTypeTestUtils, DecimalType}
+import org.apache.spark.util.Utils
+
+/**
+ * This test suite implements fuzz tests for expression code generation. It uses reflection to
+ * automatically discover all [[Expression]]s, then instantiates these expressions with random
+ * children/inputs. If the resulting expression passes the type checker after type coercion is
+ * performed then we attempt to compile the expression and compare its output to output generated
+ * by the interpreted expression.
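+ *
+ * For example (illustrative), a single trial for an expression such as `Add` might construct
+ * `Add(Literal(1), Literal(2.0))`, run the type coercion rules over it, and then check that the
+ * generated and interpreted projections produce the same result.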
+ */
+class ExpressionFuzzingSuite extends SparkFunSuite with Logging {
+
+ val NUM_TRIALS_PER_EXPRESSION: Int = 100
+
+ /**
+ * All evaluable subclasses of [[Expression]].
+ */
+ lazy val evaluableExpressionClasses: Seq[Class[Expression]] = {
+ val classpathEntries: Seq[File] = System.getProperty("java.class.path")
+ .split(File.pathSeparatorChar)
+ .filter(_.contains("spark"))
+ .map(new File(_))
+ .filter(_.exists()).toSeq
+ val allClasses = ClassFinder(classpathEntries).getClasses().toIterator
+ assert(allClasses.nonEmpty, "Could not find Spark classes on classpath.")
+ ClassFinder.concreteSubclasses(classOf[Expression].getName, allClasses)
+ .map(c => Utils.classForName(c.name).asInstanceOf[Class[Expression]]).toSeq
+      // We should only test evaluable expressions:
+ .filterNot(c => classOf[Unevaluable].isAssignableFrom(c))
+ // These expressions currently OOM because we try to pass in massive numeric literals:
+ .filterNot(_ == classOf[FormatNumber])
+ .filterNot(_ == classOf[StringSpace])
+ .filterNot(_ == classOf[StringLPad])
+ .filterNot(_ == classOf[StringRPad])
+ .filterNot(_ == classOf[BRound])
+ .filterNot(_ == classOf[Round])
+ }
+
+ def coerceTypes(expression: Expression): Expression = {
+ val dummyPlan: LogicalPlan = DummyPlan(expression)
+ DummyAnalyzer.execute(dummyPlan).asInstanceOf[DummyPlan].expression
+ }
+
+ /**
+ * Given an expression class, find the constructor which accepts only expressions. If there are
+ * multiple such constructors, pick the one with the most parameters.
+ *
+ * @return The matching constructor, or None if no appropriate constructor could be found.
+ */
+ def getBestConstructor(expressionClass: Class[Expression]): Option[Constructor[Expression]] = {
+ val allConstructors = expressionClass.getConstructors ++ expressionClass.getDeclaredConstructors
+ allConstructors
+ .map(_.asInstanceOf[Constructor[Expression]])
+ .filter(_.getParameterTypes.toSet == Set(classOf[Expression]))
+ .sortBy(_.getParameterTypes.length * -1)
+ .headOption
+ }
+
+ def getRandomLiteral: Literal = {
+ val allTypes = DataTypeTestUtils.atomicTypes
+ .filterNot(_.isInstanceOf[DecimalType]) // casts can lead to OOM
+ .filterNot(_.isInstanceOf[BinaryType]) // leads to spurious errors in string reverse
+ val dataTypesWithGenerators: Map[DataType, () => Any] = allTypes.map { dt =>
+ (dt, RandomDataGenerator.forType(dt, nullable = true))
+ }.filter(_._2.isDefined).toMap.mapValues(_.get)
+ val (dt, generator) =
+ dataTypesWithGenerators.toSeq(Random.nextInt(dataTypesWithGenerators.size))
+ Literal.create(generator(), dt)
+ }
+
+ def testExpression(expressionClass: Class[Expression]): Unit = {
+ // Eventually, we should add support for testing multiple constructors. For now, though, we
+ // only test the "best" one:
+ val constructor: Constructor[Expression] = {
+ val maybeBestConstructor = getBestConstructor(expressionClass)
+ assume(maybeBestConstructor.isDefined, "Could not find an Expression-only constructor")
+ maybeBestConstructor.get
+ }
+ val numChildren: Int = constructor.getParameterTypes.length
+    // Construct random literals for all child expressions and leave it up to the type coercion
+    // rules to cast them to the appropriate types. Trials whose coerced expression still fails
+    // type checking are skipped rather than treated as errors.
+ for (_ <- 1 to NUM_TRIALS_PER_EXPRESSION) {
+ val expression: Expression = {
+ val childExpressions: Seq[Expression] = Seq.fill(numChildren)(getRandomLiteral)
+ coerceTypes(constructor.newInstance(childExpressions: _*))
+ }
+ logInfo(s"After type coercion, expression is $expression")
+ // Make sure that the resulting expression passes type checks.
+ require(expression.childrenResolved)
+ val typecheckResult = expression.checkInputDataTypes()
+ if (typecheckResult.isFailure) {
+ logDebug(s"Type checks failed: $typecheckResult")
+ } else {
+ withClue(s"$expression") {
+ val inputRow = InternalRow.apply() // Can be empty since we're only using literals
+ val inputSchema = expression.children.map(c => AttributeReference("f", c.dataType)())
+
+ val interpretedProjection = new InterpretedProjection(Seq(expression), inputSchema)
+ val interpretedResult = interpretedProjection.apply(inputRow)
+
+ val maybeGenProjection =
+ Try(GenerateSafeProjection.generate(Seq(expression), inputSchema))
+ if (maybeGenProjection.isFailure) {
+            // scalastyle:off println
+            println(
+              s"Code generation for expression $expression failed with inputSchema $inputSchema")
+            // scalastyle:on println
+          }
+ maybeGenProjection.foreach { generatedProjection =>
+ val generatedResult = generatedProjection.apply(inputRow)
+ assert(generatedResult === interpretedResult)
+ }
+ }
+ }
+ }
+ }
+
+ // Run the actual tests
+ evaluableExpressionClasses.sortBy(_.getName).foreach { expressionClass =>
+ test(s"${expressionClass.getName}") {
+ testExpression(expressionClass)
+ }
+ }
+}
+
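+/**
+ * A minimal rule executor that runs only the type coercion rules, so that a single expression
+ * can be coerced without constructing a full analyzer.
+ */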
+private case object DummyAnalyzer extends RuleExecutor[LogicalPlan] {
+ override protected val batches: Seq[Batch] = Seq(
+ Batch("analysis", FixedPoint(100), TypeCoercion.typeCoercionRules: _*)
+ )
+}
+
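+/** A leaf plan that wraps a single expression so that the coercion rules can be applied to it. */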
+private case class DummyPlan(expression: Expression) extends LogicalPlan {
+ override def output: Seq[Attribute] = Seq.empty
+ override def children: Seq[LogicalPlan] = Seq.empty
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala
new file mode 100644
index 0000000000000..ef5f5e6cdb286
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.fuzzing
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import scala.util.Random
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.types._
+
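+/**
+ * Generates DataFrames with random schemas and random contents for use in fuzz testing. The
+ * `seed` parameter makes the sequence of generated schemas reproducible across runs.
+ */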
+class RandomDataFrameGenerator(
+ seed: Long,
+ @transient val sqlContext: SQLContext)
+ extends Serializable {
+
+ private val rand = new Random(seed)
+ private val nextId = new AtomicInteger()
+
+ private def hasRandomDataGenerator(dataType: DataType): Boolean = {
+ RandomDataGenerator.forType(dataType).isDefined
+ }
+
+ def randomChoice[T](values: Seq[T]): T = {
+ values(rand.nextInt(values.length))
+ }
+
+ private val simpleTypes: Set[DataType] = {
+ DataTypeTestUtils.atomicTypes
+ .filter(hasRandomDataGenerator)
+ // Ignore decimal type since it can lead to OOM (see SPARK-9303). TODO: It would be better to
+ // only generate limited precision decimals instead.
+ .filterNot(_.isInstanceOf[DecimalType])
+ }
+
+ private val arrayTypes: Set[DataType] = {
+ DataTypeTestUtils.atomicArrayTypes
+ .filter(hasRandomDataGenerator)
+      // Filter out arrays of binary elements until SPARK-10038 is fixed.
+ .filterNot(_.elementType.isInstanceOf[BinaryType])
+ // See above comment about DecimalType
+ .filterNot(_.elementType.isInstanceOf[DecimalType]).toSet
+ }
+
+ private def randomStructField(
+ allowComplexTypes: Boolean = false,
+ allowSpacesInColumnName: Boolean = false): StructField = {
+ val name = "c" + nextId.getAndIncrement + (if (allowSpacesInColumnName) " space" else "")
+ val candidateTypes: Seq[DataType] = Seq(
+ simpleTypes,
+ arrayTypes.filter(_ => allowComplexTypes),
+      // The nested struct's own fields never allow complex types, which limits recursion depth:
+ if (allowComplexTypes) {
+ Set[DataType](randomStructType(numCols = rand.nextInt(2) + 1))
+ } else {
+ Set[DataType]()
+ }
+ ).flatten
+ val dataType = randomChoice(candidateTypes)
+ val nullable = rand.nextBoolean()
+ StructField(name, dataType, nullable)
+ }
+
+ private def randomStructType(
+ numCols: Int,
+ allowComplexTypes: Boolean = false,
+ allowSpacesInColumnNames: Boolean = false): StructType = {
+ StructType(Array.fill(numCols)(randomStructField(allowComplexTypes, allowSpacesInColumnNames)))
+ }
+
+ def randomDataFrame(
+ numCols: Int,
+ numRows: Int,
+ allowComplexTypes: Boolean = false,
+ allowSpacesInColumnNames: Boolean = false): DataFrame = {
+ val schema = randomStructType(numCols, allowComplexTypes, allowSpacesInColumnNames)
+ val rows = sqlContext.sparkContext.parallelize(1 to numRows).mapPartitions { iter =>
+ val rowGenerator = RandomDataGenerator.forType(schema, nullable = false, rand = rand).get
+ iter.map(_ => rowGenerator().asInstanceOf[Row])
+ }
+ sqlContext.createDataFrame(rows, schema)
+ }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/package.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/package.scala
new file mode 100644
index 0000000000000..d3d7c69a8d7a4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/package.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+/**
+ * TODO(josh): Document this package.
+ */
+package object fuzzing {
+
+}
+
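+/**
+ * A single DataFrame-to-DataFrame transformation that the fuzzer can apply while building up a
+ * random query, for example a reflectively invoked Dataset method together with its arguments.
+ */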
+trait DataFrameTransformation extends Function[DataFrame, DataFrame]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/reflectiveFuzzing.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/reflectiveFuzzing.scala
new file mode 100644
index 0000000000000..e22e9a824e7bd
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/reflectiveFuzzing.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.fuzzing
+
+import java.lang.reflect.InvocationTargetException
+
+import scala.reflect.runtime.{universe => ru}
+import scala.util.{Random, Try}
+
+import scalaz._, Scalaz._
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.types._
+
+object ReflectiveFuzzing {
+
+ import DataFrameFuzzingUtils._
+
+ private implicit val m: ru.Mirror = ru.runtimeMirror(this.getClass.getClassLoader)
+
+ /**
+ * Method parameter types for which the fuzzer can supply random values. This list is used to
+ * filter out methods that we don't know how to call.
+ */
+ private val whitelistedParameterTypes = Set(
+ m.universe.typeOf[DataFrame],
+ m.universe.typeOf[Dataset[_]],
+ m.universe.typeOf[Seq[Column]],
+ m.universe.typeOf[Column],
+ m.universe.typeOf[String],
+ m.universe.typeOf[Seq[String]]
+ )
+
+ /**
+ * A list of candidate DataFrame methods that the fuzzer will try to call. Excludes private
+ * methods and methods with parameters that we don't know how to supply.
+ */
+ private val dataFrameTransformations: Seq[ru.MethodSymbol] = {
+ val dfType = m.universe.typeOf[Dataset[_]]
+ dfType.members
+ .filter(_.isPublic)
+ .filter(_.isMethod)
+ .map(_.asMethod)
+ .filter(_.returnType <:< dfType)
+ .filterNot(_.isConstructor)
+ .filter { m =>
+ m.paramss.flatten.forall { p =>
+ whitelistedParameterTypes.exists { t => p.typeSignature <:< t }
+ }
+ }
+ .filterNot(_.name.toString == "drop") // since this can lead to a DataFrame with no columns
+ .filterNot(_.name.toString == "describe") // since we cannot run all queries on describe output
+ .filterNot(_.name.toString == "dropDuplicates")
+ .filterNot(_.name.toString == "toDF") // since this is effectively a no-op
+ .filterNot(_.name.toString == "toSchemaRDD") // since this is effectively a no-op
+ .toSeq
+ }
+
+ /**
+   * Given a DataFrame and a method, try to choose a set of arguments with which to call it.
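+   *
+   * For example (illustrative), for a parameter of type `Column` this picks a random column of
+   * `df` that satisfies `typeConstraint`, and for a `joinType` string parameter it picks a random
+   * entry from `JoinType.supportedJoinTypes`.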
+ *
+ * @param df the data frame to transform
+ * @param method the method to call
+ * @param typeConstraint an optional type constraint governing the types of the parameters.
+   * @return the chosen argument values, or None if no suitable values could be found.
+ */
+ def getParamValues(
+ df: DataFrame,
+ method: ru.MethodSymbol,
+ typeConstraint: DataType => Boolean = _ => true): Option[List[Any]] = {
+ val params = method.paramss.flatten // We don't use multiple parameter lists
+ val maybeValues: List[Option[Any]] = params.map { p =>
+ val t = p.typeSignature
+ if (t <:< ru.typeOf[Dataset[_]]) {
+ randomChoice(
+ df ::
+ // TODO(josh): restore ability to generate new random DataFrames for use in joins.
+ // dataGenerator.randomDataFrame(numCols = Random.nextInt(4) + 1, numRows = 100) ::
+ Nil
+ ).some
+ } else if (t =:= ru.typeOf[Column]) {
+ getRandomColumnName(df, typeConstraint).map(df.col)
+ } else if (t =:= ru.typeOf[String]) {
+        if (p.name.toString == "joinType") {
+ randomChoice(JoinType.supportedJoinTypes).some
+ } else {
+          getRandomColumnName(df, typeConstraint)
+ }
+ } else if (t <:< ru.typeOf[Seq[Column]]) {
+ Seq.fill(Random.nextInt(2) + 1)(
+ getRandomColumnName(df, typeConstraint).map(df.col)).flatten.some
+ } else if (t <:< ru.typeOf[Seq[String]]) {
+ Seq.fill(Random.nextInt(2) + 1)(
+        getRandomColumnName(df, typeConstraint)).flatten.some
+ } else {
+ None
+ }
+ }
+ maybeValues.sequence
+ }
+
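+  /**
+   * Picks a random candidate Dataset method and tries to choose arguments for it, validating the
+   * chosen arguments by applying the transformation to `df`. Returns None if no valid combination
+   * of arguments could be found.
+   */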
+ def getTransformation(df: DataFrame): Option[DataFrameTransformation] = {
+ val method: ru.MethodSymbol = DataFrameFuzzingUtils.randomChoice(dataFrameTransformations)
+ val values: Option[Seq[Any]] = {
+ def validateValues(vs: Seq[Any]): Try[Seq[Any]] = {
+ Try(CallTransformReflectively(method, vs).apply(df)).map(_ => vs)
+ }
+ getParamValues(df, method).map { (vs: Seq[Any]) =>
+ validateValues(vs).recoverWith {
+ case e: AnalysisException if e.getMessage.contains("is not a boolean") =>
+ Try(getParamValues(df, method, _ == BooleanType).get).flatMap(validateValues)
+ case e: AnalysisException
+ if e.getMessage.contains("is not supported for columns of type") =>
+ Try(getParamValues(df, method, _.isInstanceOf[AtomicType]).get).flatMap(validateValues)
+ }
+ }.flatMap(_.toOption)
+ }
+ values.map(vs => CallTransformReflectively(method, vs))
+ }
+}
+
+case class CallTransformReflectively(
+ method: ru.MethodSymbol,
+ args: Seq[Any])(
+ implicit runtimeMirror: ru.Mirror) extends DataFrameTransformation {
+
+ override def apply(df: DataFrame): DataFrame = {
+ val reflectedMethod: ru.MethodMirror = runtimeMirror.reflect(df).reflectMethod(method)
+ try {
+ reflectedMethod.apply(args: _*).asInstanceOf[DataFrame]
+ } catch {
+ case e: InvocationTargetException => throw e.getCause
+ }
+ }
+
+ override def toString(): String = {
+ s"${method.name}(${args.map(_.toString).mkString(", ")})"
+ }
+}