diff --git a/pom.xml b/pom.xml index 989658216e5fd..c86c90639e174 100644 --- a/pom.xml +++ b/pom.xml @@ -699,6 +699,12 @@ scalap ${scala.version} + + org.scalaz + scalaz-core_2.10 + 7.1.3 + test + org.scalatest scalatest_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 41b7e62d8ccea..daf99f5f0f82b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -245,6 +245,12 @@ trait CheckAnalysis extends PredicateHelper { aggregateExprs.foreach(checkValidAggregateExpression) groupingExprs.foreach(checkValidGroupingExprs) + case s @ SetOperation(left, right) if left.output.length != right.output.length => + failAnalysis( + s"${s.nodeName} can only be performed on tables with the same number of columns, " + + s"but the left table has ${left.output.length} columns and the right has " + + s"${right.output.length}") + case Sort(orders, _, _) => orders.foreach { order => if (!RowOrdering.isOrderable(order.dataType)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 70fff51956255..8edd9bed44839 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -209,7 +209,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } private[this] def decimalToTimestamp(d: Decimal): Long = { - (d.toBigDecimal * 1000000L).longValue() + d.toJavaBigDecimal.multiply(java.math.BigDecimal.valueOf(1000000L)).longValue() } private[this] def doubleToTimestamp(d: Double): Any = { if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index abb5594bfa7f8..c0af7eaf432c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -98,13 +98,13 @@ trait ExtractValue extends Expression /** * Returns the value of fields in the Struct `child`. * - * No need to do type checking since it is handled by [[ExtractValue]]. - * * Note that we can pass in the field name directly to keep case preserving in `toString`. * For example, when get field `yEAr` from ``, we should pass in `yEAr`. */ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None) - extends UnaryExpression with ExtractValue { + extends UnaryExpression with ExtractValue with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(StructType) lazy val childSchema = child.dataType.asInstanceOf[StructType] @@ -144,16 +144,15 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] /** * For a child whose data type is an array of structs, extracts the `ordinal`-th fields of all array * elements, and returns them as a new array. - * - * No need to do type checking since it is handled by [[ExtractValue]]. 
*/ case class GetArrayStructFields( child: Expression, field: StructField, ordinal: Int, numFields: Int, - containsNull: Boolean) extends UnaryExpression with ExtractValue { + containsNull: Boolean) extends UnaryExpression with ExtractValue with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) override def dataType: DataType = ArrayType(field.dataType, containsNull) override def toString: String = s"$child.${field.name}" override def sql: String = s"${child.sql}.${quoteIdentifier(field.name)}" @@ -215,8 +214,7 @@ case class GetArrayStructFields( case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryExpression with ExpectsInputTypes with ExtractValue { - // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegralType) override def toString: String = s"$child[$ordinal]" override def sql: String = s"${child.sql}[${ordinal.sql}]" @@ -264,8 +262,7 @@ case class GetMapValue(child: Expression, key: Expression) private def keyType = child.dataType.asInstanceOf[MapType].keyType - // We have done type checking for child in `ExtractValue`, so only need to check the `key`. - override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) + override def inputTypes: Seq[AbstractDataType] = Seq(MapType, keyType) override def toString: String = s"$child[$key]" override def sql: String = s"${child.sql}[${key.sql}]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index fa5dea6841149..30dce130dc094 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -23,11 +23,11 @@ import org.apache.spark.sql.types._ /** * Return the unscaled Long value of a Decimal, assuming it fits in a Long. - * Note: this expression is internal and created only by the optimizer, - * we don't need to do type check for it. + * Note: this expression is internal and created only by the optimizer. */ -case class UnscaledValue(child: Expression) extends UnaryExpression { +case class UnscaledValue(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType) override def dataType: DataType = LongType override def toString: String = s"UnscaledValue($child)" @@ -41,11 +41,15 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { /** * Create a Decimal from an unscaled Long value. - * Note: this expression is internal and created only by the optimizer, - * we don't need to do type check for it. + * Note: this expression is internal and created only by the optimizer. 
*/ -case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { +case class MakeDecimal( + child: Expression, + precision: Int, + scale: Int) + extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(LongType) override def dataType: DataType = DecimalType(precision, scale) override def nullable: Boolean = true override def toString: String = s"MakeDecimal($child,$precision,$scale)" @@ -80,7 +84,12 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { * Rounds the decimal to given scale and check whether the decimal can fit in provided precision * or not, returns null if not. */ -case class CheckOverflow(child: Expression, dataType: DecimalType) extends UnaryExpression { +case class CheckOverflow( + child: Expression, + dataType: DecimalType) + extends UnaryExpression with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType) override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 80674d9b4bc9c..3c35e6e2c7e10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -21,6 +21,15 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.Attribute object JoinType { + + val supportedJoinTypes = Seq( + "inner", + "outer", "full", "fullouter", + "leftouter", "left", + "rightouter", "right", + "leftsemi", + "leftanti") + def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match { case "inner" => Inner case "outer" | "full" | "fullouter" => FullOuter @@ -29,16 +38,8 @@ object JoinType { case "leftsemi" => LeftSemi case "leftanti" => LeftAnti case _ => - val supported = Seq( - "inner", - "outer", "full", "fullouter", - "leftouter", "left", - "rightouter", "right", - "leftsemi", - "leftanti") - throw new IllegalArgumentException(s"Unsupported join type '$typ'. " + - "Supported join types include: " + supported.mkString("'", "', '", "'") + ".") + "Supported join types include: " + supportedJoinTypes.mkString("'", "', '", "'") + ".") } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index b2752638bebd5..a759774a7b110 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -127,6 +127,18 @@ xbean-asm5-shaded test + + org.clapper + classutil_${scala.binary.version} + 1.0.6 + test + + + org.scalaz + scalaz-core_${scala.binary.version} + 7.2.3 + test + target/scala-${scala.binary.version}/classes diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 547d3c1abe858..165e273e9060c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -150,7 +150,8 @@ class UDFSuite extends QueryTest with SharedSQLContext { assert(result.count() === 2) } - test("UDFs everywhere") { + // Temporarily ignored until we implement code generation for ScalaUDF. 
+ ignore("UDFs everywhere") { spark.udf.register("groupFunction", (n: Int) => { n > 10 }) spark.udf.register("havingFilter", (n: Long) => { n > 2000 }) spark.udf.register("whereFilter", (n: Int) => { n < 150 }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala new file mode 100644 index 0000000000000..e049b2a2c5e83 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/DataFrameFuzzingSuite.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.fuzzing + +import scala.util.Random +import scala.util.control.NonFatal + +import org.apache.spark.{SharedSparkContext, SparkFunSuite} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +object DataFrameFuzzingUtils { + + def randomChoice[T](values: Seq[T]): T = { + values(Random.nextInt(values.length)) + } + + /** + * Build a list of column names and types for the given StructType, taking nesting into account. + * For nested struct fields, this will emit both the column for the struct field itself as well as + * fields for the nested struct's fields. This process will be performed recursively in order to + * handle deeply-nested structs. + */ + def getColumnsAndTypes(struct: StructType): Seq[(String, DataType)] = { + struct.flatMap { field => + val nestedFieldInfos: Seq[(String, DataType)] = field.dataType match { + case nestedStruct: StructType => + Seq((field.name, field.dataType)) ++ getColumnsAndTypes(nestedStruct).map { + case (nestedColName, dataType) => (field.name + "." + nestedColName, dataType) + } + case _ => Seq.empty + } + Seq((field.name, field.dataType)) ++ nestedFieldInfos + } + } + + def getRandomColumnName( + df: DataFrame, + condition: DataType => Boolean = _ => true): Option[String] = { + val columnsWithTypes = getColumnsAndTypes(df.schema) + val candidateColumns = columnsWithTypes.filter(c => condition(c._2)) + if (candidateColumns.isEmpty) { + None + } else { + Some(randomChoice(candidateColumns)._1) + } + } +} + + +/** + * This test suite generates random data frames, then applies random sequences of operations to + * them in order to construct random queries. We don't have a source of truth for these random + * queries but nevertheless they are still useful for testing that we don't crash in bad ways. 
+ */ +class DataFrameFuzzingSuite extends QueryTest with SharedSparkContext { + + + override protected def spark: SparkSession = sqlContext.sparkSession + + val tempDir = Utils.createTempDir() + + private var sqlContext: SQLContext = _ + private var dataGenerator: RandomDataFrameGenerator = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + dataGenerator = new RandomDataFrameGenerator(123, sqlContext) + sqlContext.conf.setConf(SQLConf.SHUFFLE_PARTITIONS, 10) + } + + def tryToExecute(df: DataFrame): DataFrame = { + try { + df.rdd.count() + df + } catch { + case NonFatal(e) => + // scalastyle:off println + println(df.queryExecution) + // scalastyle:on println + throw e + } + } + + // TODO: make these regexes. + val ignoredAnalysisExceptionMessages = Seq( + // TODO: filter only for binary type: + "cannot sort data type array<", + "cannot be used in grouping expression", + "cannot be used in join condition", + "can only be performed on tables with the same number of columns", + "number of columns doesn't match", + "unsupported join type", + "is neither present in the group by, nor is it an aggregate function", + "is ambiguous, could be:", + "unresolved operator 'Project", // TODO + "unresolved operator 'Union", // TODO: disabled to let me find new errors + "unresolved operator 'Except", // TODO: disabled to let me find new errors + "unresolved operator 'Intersect", // TODO: disabled to let me find new errors + "Cannot resolve column name" // TODO: only ignore for join? + ) + + def getRandomTransformation(df: DataFrame): DataFrameTransformation = { + (1 to 1000).iterator.map(_ => ReflectiveFuzzing.getTransformation(df)).flatten.next() + } + + def applyRandomTransform(df: DataFrame): DataFrame = { + val tf = getRandomTransformation(df) + // scalastyle:off println + println(" " + tf) + // scalastyle:on println + tf.apply(df) + } + + def resetConfs(): Unit = { + sqlContext.conf.getAllDefinedConfs.foreach { case (key, defaultValue, doc) => + sqlContext.conf.setConfString(key, defaultValue) + } + sqlContext.conf.setConfString("spark.sql.crossJoin.enabled", "true") + sqlContext.conf.setConfString("spark.sql.autoBroadcastJoinThreshold", "-1") + } + + private val configurations = Seq( + "default" -> Seq(), + "no optimization" -> Seq(SQLConf.OPTIMIZER_MAX_ITERATIONS.key -> "0"), + "disable-wholestage-codegen" -> Seq(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false"), + "disable-exchange-reuse" -> Seq(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") + ) + + def replan(df: DataFrame): DataFrame = { + new Dataset[Row](sqlContext.sparkSession, df.logicalPlan, RowEncoder(df.schema)) + } + + test("fuzz test") { + for (i <- 1 to 1000) { + // scalastyle:off println + println(s"Iteration $i") + // scalastyle:on println + try { + resetConfs() + var df = dataGenerator.randomDataFrame( + numCols = Random.nextInt(2) + 1, + numRows = 20, + allowComplexTypes = false) + var depth = 3 + while (depth > 0) { + df = tryToExecute(applyRandomTransform(df)) + depth -= 1 + } + val defaultResult = replan(df).collect() + configurations.foreach { case (confName, confsToSet) => + resetConfs() + withClue(s"configuration = $confName") { + confsToSet.foreach { case (key, value) => + sqlContext.conf.setConfString(key, value) + } + checkAnswer(replan(df), defaultResult) + } + } + println(s"Finished all tests successfully for plan:\n${df.logicalPlan}") + } catch { + case e: UnresolvedException[_] => +// println("skipped due to unresolved") + case e: Exception + if 
ignoredAnalysisExceptionMessages.exists { + m => Option(e.getMessage).getOrElse("").toLowerCase.contains(m.toLowerCase) + } => +// println("Skipped due to expected AnalysisException " + e) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala new file mode 100644 index 0000000000000..99bf7bc5bc038 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/ExpressionFuzzingSuite.scala @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.fuzzing + +import java.io.File +import java.lang.reflect.Constructor + +import scala.util.{Random, Try} + +import org.clapper.classutil.ClassFinder + +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCoercion +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.{BinaryType, DataType, DataTypeTestUtils, DecimalType} +import org.apache.spark.util.Utils + +/** + * This test suite implements fuzz tests for expression code generation. It uses reflection to + * automatically discover all [[Expression]]s, then instantiates these expressions with random + * children/inputs. If the resulting expression passes the type checker after type coercion is + * performed then we attempt to compile the expression and compare its output to output generated + * by the interpreted expression. + */ +class ExpressionFuzzingSuite extends SparkFunSuite with Logging { + + val NUM_TRIALS_PER_EXPRESSION: Int = 100 + + /** + * All evaluable subclasses of [[Expression]]. 
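+   * Discovered by scanning the test classpath with ClassFinder, then filtering out
+   * [[Unevaluable]] expressions and a handful of expressions that OOM when given huge random
+   * numeric literals.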
   */
+  lazy val evaluableExpressionClasses: Seq[Class[Expression]] = {
+    val classpathEntries: Seq[File] = System.getProperty("java.class.path")
+      .split(File.pathSeparatorChar)
+      .filter(_.contains("spark"))
+      .map(new File(_))
+      .filter(_.exists()).toSeq
+    val allClasses = ClassFinder(classpathEntries).getClasses().toIterator
+    assert(allClasses.nonEmpty, "Could not find Spark classes on classpath.")
+    ClassFinder.concreteSubclasses(classOf[Expression].getName, allClasses)
+      .map(c => Utils.classForName(c.name).asInstanceOf[Class[Expression]]).toSeq
+      // We should only test evaluable expressions:
+      .filterNot(c => classOf[Unevaluable].isAssignableFrom(c))
+      // These expressions currently OOM because we try to pass in massive numeric literals:
+      .filterNot(_ == classOf[FormatNumber])
+      .filterNot(_ == classOf[StringSpace])
+      .filterNot(_ == classOf[StringLPad])
+      .filterNot(_ == classOf[StringRPad])
+      .filterNot(_ == classOf[BRound])
+      .filterNot(_ == classOf[Round])
+  }
+
+  def coerceTypes(expression: Expression): Expression = {
+    val dummyPlan: LogicalPlan = DummyPlan(expression)
+    DummyAnalyzer.execute(dummyPlan).asInstanceOf[DummyPlan].expression
+  }
+
+  /**
+   * Given an expression class, find the constructor which accepts only expressions. If there are
+   * multiple such constructors, pick the one with the most parameters.
+   *
+   * @return The matching constructor, or None if no appropriate constructor could be found.
+   */
+  def getBestConstructor(expressionClass: Class[Expression]): Option[Constructor[Expression]] = {
+    val allConstructors = expressionClass.getConstructors ++ expressionClass.getDeclaredConstructors
+    allConstructors
+      .map(_.asInstanceOf[Constructor[Expression]])
+      .filter(_.getParameterTypes.toSet == Set(classOf[Expression]))
+      .sortBy(_.getParameterTypes.length * -1)
+      .headOption
+  }
+
+  def getRandomLiteral: Literal = {
+    val allTypes = DataTypeTestUtils.atomicTypes
+      .filterNot(_.isInstanceOf[DecimalType]) // casts can lead to OOM
+      .filterNot(_.isInstanceOf[BinaryType]) // leads to spurious errors in string reverse
+    val dataTypesWithGenerators: Map[DataType, () => Any] = allTypes.map { dt =>
+      (dt, RandomDataGenerator.forType(dt, nullable = true))
+    }.filter(_._2.isDefined).toMap.mapValues(_.get)
+    val (dt, generator) =
+      dataTypesWithGenerators.toSeq(Random.nextInt(dataTypesWithGenerators.size))
+    Literal.create(generator(), dt)
+  }
+
+  def testExpression(expressionClass: Class[Expression]): Unit = {
+    // Eventually, we should add support for testing multiple constructors. For now, though, we
+    // only test the "best" one:
+    val constructor: Constructor[Expression] = {
+      val maybeBestConstructor = getBestConstructor(expressionClass)
+      assume(maybeBestConstructor.isDefined, "Could not find an Expression-only constructor")
+      maybeBestConstructor.get
+    }
+    val numChildren: Int = constructor.getParameterTypes.length
+    // Construct random literals for all child expressions and leave it up to the type coercion
+    // rules to cast them to the appropriate types. Skip trials whose expressions fail type checks.
+    for (_ <- 1 to NUM_TRIALS_PER_EXPRESSION) {
+      val expression: Expression = {
+        val childExpressions: Seq[Expression] = Seq.fill(numChildren)(getRandomLiteral)
+        coerceTypes(constructor.newInstance(childExpressions: _*))
+      }
+      logInfo(s"After type coercion, expression is $expression")
+      // Make sure that the resulting expression passes type checks.
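+      // The children are all literals, so they are already resolved; checkInputDataTypes()
+      // inspects the children's dataTypes, so it is only meaningful once this holds.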
+ require(expression.childrenResolved) + val typecheckResult = expression.checkInputDataTypes() + if (typecheckResult.isFailure) { + logDebug(s"Type checks failed: $typecheckResult") + } else { + withClue(s"$expression") { + val inputRow = InternalRow.apply() // Can be empty since we're only using literals + val inputSchema = expression.children.map(c => AttributeReference("f", c.dataType)()) + + val interpretedProjection = new InterpretedProjection(Seq(expression), inputSchema) + val interpretedResult = interpretedProjection.apply(inputRow) + + val maybeGenProjection = + Try(GenerateSafeProjection.generate(Seq(expression), inputSchema)) + if (maybeGenProjection.isFailure) { + //scalastyle:off + println( + s"Code generation for expression $expression failed with inputSchema $inputSchema") + } + maybeGenProjection.foreach { generatedProjection => + val generatedResult = generatedProjection.apply(inputRow) + assert(generatedResult === interpretedResult) + } + } + } + } + } + + // Run the actual tests + evaluableExpressionClasses.sortBy(_.getName).foreach { expressionClass => + test(s"${expressionClass.getName}") { + testExpression(expressionClass) + } + } +} + +private case object DummyAnalyzer extends RuleExecutor[LogicalPlan] { + override protected val batches: Seq[Batch] = Seq( + Batch("analysis", FixedPoint(100), TypeCoercion.typeCoercionRules: _*) + ) +} + +private case class DummyPlan(expression: Expression) extends LogicalPlan { + override def output: Seq[Attribute] = Seq.empty + override def children: Seq[LogicalPlan] = Seq.empty +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala new file mode 100644 index 0000000000000..ef5f5e6cdb286 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/RandomDataFrameGenerator.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.fuzzing + +import java.util.concurrent.atomic.AtomicInteger + +import scala.util.Random + +import org.apache.spark.sql._ +import org.apache.spark.sql.types._ + +class RandomDataFrameGenerator( + seed: Long, + @transient val sqlContext: SQLContext) + extends Serializable { + + private val rand = new Random(seed) + private val nextId = new AtomicInteger() + + private def hasRandomDataGenerator(dataType: DataType): Boolean = { + RandomDataGenerator.forType(dataType).isDefined + } + + def randomChoice[T](values: Seq[T]): T = { + values(rand.nextInt(values.length)) + } + + private val simpleTypes: Set[DataType] = { + DataTypeTestUtils.atomicTypes + .filter(hasRandomDataGenerator) + // Ignore decimal type since it can lead to OOM (see SPARK-9303). 
TODO: It would be better to + // only generate limited precision decimals instead. + .filterNot(_.isInstanceOf[DecimalType]) + } + + private val arrayTypes: Set[DataType] = { + DataTypeTestUtils.atomicArrayTypes + .filter(hasRandomDataGenerator) + // Filter until SPARK-10038 is fixed. + .filterNot(_.elementType.isInstanceOf[BinaryType]) + // See above comment about DecimalType + .filterNot(_.elementType.isInstanceOf[DecimalType]).toSet + } + + private def randomStructField( + allowComplexTypes: Boolean = false, + allowSpacesInColumnName: Boolean = false): StructField = { + val name = "c" + nextId.getAndIncrement + (if (allowSpacesInColumnName) " space" else "") + val candidateTypes: Seq[DataType] = Seq( + simpleTypes, + arrayTypes.filter(_ => allowComplexTypes), + // This does not allow complex types, limiting the depth of recursion: + if (allowComplexTypes) { + Set[DataType](randomStructType(numCols = rand.nextInt(2) + 1)) + } else { + Set[DataType]() + } + ).flatten + val dataType = randomChoice(candidateTypes) + val nullable = rand.nextBoolean() + StructField(name, dataType, nullable) + } + + private def randomStructType( + numCols: Int, + allowComplexTypes: Boolean = false, + allowSpacesInColumnNames: Boolean = false): StructType = { + StructType(Array.fill(numCols)(randomStructField(allowComplexTypes, allowSpacesInColumnNames))) + } + + def randomDataFrame( + numCols: Int, + numRows: Int, + allowComplexTypes: Boolean = false, + allowSpacesInColumnNames: Boolean = false): DataFrame = { + val schema = randomStructType(numCols, allowComplexTypes, allowSpacesInColumnNames) + val rows = sqlContext.sparkContext.parallelize(1 to numRows).mapPartitions { iter => + val rowGenerator = RandomDataGenerator.forType(schema, nullable = false, rand = rand).get + iter.map(_ => rowGenerator().asInstanceOf[Row]) + } + sqlContext.createDataFrame(rows, schema) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/package.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/package.scala new file mode 100644 index 0000000000000..d3d7c69a8d7a4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/package.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +/** + * TODO(josh): Document this package. 
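+ *
+ * For now: this package contains fuzz tests and supporting utilities that generate random
+ * DataFrames, query plans, and expressions in order to stress-test analysis, optimization,
+ * and code generation.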
+ */ +package object fuzzing { + +} + +trait DataFrameTransformation extends Function[DataFrame, DataFrame] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/reflectiveFuzzing.scala b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/reflectiveFuzzing.scala new file mode 100644 index 0000000000000..e22e9a824e7bd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/fuzzing/reflectiveFuzzing.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.fuzzing + +import java.lang.reflect.InvocationTargetException + +import scala.reflect.runtime.{universe => ru} +import scala.util.{Random, Try} + +import scalaz._, Scalaz._ + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.types._ + +object ReflectiveFuzzing { + + import DataFrameFuzzingUtils._ + + private implicit val m: ru.Mirror = ru.runtimeMirror(this.getClass.getClassLoader) + + /** + * Method parameter types for which the fuzzer can supply random values. This list is used to + * filter out methods that we don't know how to call. + */ + private val whitelistedParameterTypes = Set( + m.universe.typeOf[DataFrame], + m.universe.typeOf[Dataset[_]], + m.universe.typeOf[Seq[Column]], + m.universe.typeOf[Column], + m.universe.typeOf[String], + m.universe.typeOf[Seq[String]] + ) + + /** + * A list of candidate DataFrame methods that the fuzzer will try to call. Excludes private + * methods and methods with parameters that we don't know how to supply. + */ + private val dataFrameTransformations: Seq[ru.MethodSymbol] = { + val dfType = m.universe.typeOf[Dataset[_]] + dfType.members + .filter(_.isPublic) + .filter(_.isMethod) + .map(_.asMethod) + .filter(_.returnType <:< dfType) + .filterNot(_.isConstructor) + .filter { m => + m.paramss.flatten.forall { p => + whitelistedParameterTypes.exists { t => p.typeSignature <:< t } + } + } + .filterNot(_.name.toString == "drop") // since this can lead to a DataFrame with no columns + .filterNot(_.name.toString == "describe") // since we cannot run all queries on describe output + .filterNot(_.name.toString == "dropDuplicates") + .filterNot(_.name.toString == "toDF") // since this is effectively a no-op + .filterNot(_.name.toString == "toSchemaRDD") // since this is effectively a no-op + .toSeq + } + + /** + * Given a Dataframe and a method, try to choose a set of arguments to call that method with. + * + * @param df the data frame to transform + * @param method the method to call + * @param typeConstraint an optional type constraint governing the types of the parameters. 
   * @return the chosen argument values, or None if no suitable values could be found.
+   */
+  def getParamValues(
+      df: DataFrame,
+      method: ru.MethodSymbol,
+      typeConstraint: DataType => Boolean = _ => true): Option[List[Any]] = {
+    val params = method.paramss.flatten // We don't use multiple parameter lists
+    val maybeValues: List[Option[Any]] = params.map { p =>
+      val t = p.typeSignature
+      if (t <:< ru.typeOf[Dataset[_]]) {
+        randomChoice(
+          df ::
+          // TODO(josh): restore ability to generate new random DataFrames for use in joins.
+          // dataGenerator.randomDataFrame(numCols = Random.nextInt(4) + 1, numRows = 100) ::
+          Nil
+        ).some
+      } else if (t =:= ru.typeOf[Column]) {
+        getRandomColumnName(df, typeConstraint).map(df.col)
+      } else if (t =:= ru.typeOf[String]) {
+        if (p.name.toString == "joinType") {
+          randomChoice(JoinType.supportedJoinTypes).some
+        } else {
+          getRandomColumnName(df, typeConstraint)
+        }
+      } else if (t <:< ru.typeOf[Seq[Column]]) {
+        Seq.fill(Random.nextInt(2) + 1)(
+          getRandomColumnName(df, typeConstraint).map(df.col)).flatten.some
+      } else if (t <:< ru.typeOf[Seq[String]]) {
+        Seq.fill(Random.nextInt(2) + 1)(
+          getRandomColumnName(df, typeConstraint)).flatten.some
+      } else {
+        None
+      }
+    }
+    maybeValues.sequence
+  }
+
+  def getTransformation(df: DataFrame): Option[DataFrameTransformation] = {
+    val method: ru.MethodSymbol = DataFrameFuzzingUtils.randomChoice(dataFrameTransformations)
+    val values: Option[Seq[Any]] = {
+      def validateValues(vs: Seq[Any]): Try[Seq[Any]] = {
+        Try(CallTransformReflectively(method, vs).apply(df)).map(_ => vs)
+      }
+      getParamValues(df, method).map { (vs: Seq[Any]) =>
+        validateValues(vs).recoverWith {
+          case e: AnalysisException if e.getMessage.contains("is not a boolean") =>
+            Try(getParamValues(df, method, _ == BooleanType).get).flatMap(validateValues)
+          case e: AnalysisException
+              if e.getMessage.contains("is not supported for columns of type") =>
+            Try(getParamValues(df, method, _.isInstanceOf[AtomicType]).get).flatMap(validateValues)
+        }
+      }.flatMap(_.toOption)
+    }
+    values.map(vs => CallTransformReflectively(method, vs))
+  }
+}
+
+case class CallTransformReflectively(
+    method: ru.MethodSymbol,
+    args: Seq[Any])(
+    implicit runtimeMirror: ru.Mirror) extends DataFrameTransformation {
+
+  override def apply(df: DataFrame): DataFrame = {
+    val reflectedMethod: ru.MethodMirror = runtimeMirror.reflect(df).reflectMethod(method)
+    try {
+      reflectedMethod.apply(args: _*).asInstanceOf[DataFrame]
+    } catch {
+      case e: InvocationTargetException => throw e.getCause
+    }
+  }
+
+  override def toString(): String = {
+    s"${method.name}(${args.map(_.toString).mkString(", ")})"
+  }
+}