From 213ada833645f43a23b04204f503627c4cf3a945 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 Aug 2014 16:11:49 -0700 Subject: [PATCH 01/24] First draft of partially aggregated and code generated count distinct / max --- .../sql/catalyst/expressions/Projection.scala | 3 +- .../sql/catalyst/expressions/aggregates.scala | 103 ++++++++++++++-- .../sql/catalyst/expressions/arithmetic.scala | 14 +++ .../expressions/codegen/CodeGenerator.scala | 80 +++++++++++- .../codegen/GenerateProjection.scala | 9 +- .../spark/sql/catalyst/expressions/sets.scala | 115 ++++++++++++++++++ .../sql/execution/GeneratedAggregate.scala | 34 ++++++ .../sql/execution/SparkSqlSerializer.scala | 80 ++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 3 +- 9 files changed, 428 insertions(+), 13 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 8fc5896974438..52904ff0221d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -27,7 +27,8 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) - protected val exprArray = expressions.toArray + // null check is required for when Kryo invokes the no-arg constructor. + protected val exprArray = if (expressions != null) expressions.toArray else null def apply(input: Row): Row = { val outputArray = new Array[Any](exprArray.length) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 01947273b6ccc..ea8ce757055d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -22,6 +22,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.util.collection.OpenHashSet abstract class AggregateExpression extends Expression { self: Product => @@ -161,13 +162,96 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def newInstance() = new CountFunction(child, this) } -case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression { +case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { + def this() = this(null) + override def children = expressions override def references = expressions.flatMap(_.references).toSet override def nullable = false override def dataType = LongType override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")})" override def newInstance() = new CountDistinctFunction(expressions, this) + + override def asPartial = { + val partialSet = Alias(CollectHashSet(expressions), "partialSets")() + SplitEvaluation( + CombineSetsAndCount(partialSet.toAttribute), + partialSet :: Nil) + } +} + +case class CollectHashSet(expressions: Seq[Expression]) 
extends AggregateExpression { + def this() = this(null) + + override def children = expressions + override def references = expressions.flatMap(_.references).toSet + override def nullable = false + override def dataType = ArrayType(expressions.head.dataType) + override def toString = s"AddToHashSet(${expressions.mkString(",")})" + override def newInstance() = new CollectHashSetFunction(expressions, this) +} + +case class CollectHashSetFunction( + @transient expr: Seq[Expression], + @transient base: AggregateExpression) + extends MergableAggregateFunction { + + def this() = this(null, null) // Required for serialization. + + val seen = new OpenHashSet[Any]() + + @transient + val distinctValue = new InterpretedProjection(expr) + + override def merge(other: MergableAggregateFunction): MergableAggregateFunction = { + val otherSetIterator = other.asInstanceOf[CountDistinctFunction].seen.iterator + while(otherSetIterator.hasNext) { + seen.add(otherSetIterator.next()) + } + this + } + + override def update(input: Row): Unit = { + val evaluatedExpr = distinctValue(input) + if (!evaluatedExpr.anyNull) { + seen.add(evaluatedExpr) + } + } + + override def eval(input: Row): Any = { + seen + } +} + +case class CombineSetsAndCount(inputSet: Expression) extends AggregateExpression { + def this() = this(null) + + override def children = inputSet :: Nil + override def references = inputSet.references + override def nullable = false + override def dataType = LongType + override def toString = s"CombineAndCount($inputSet)" + override def newInstance() = new CombineSetsAndCountFunction(inputSet, this) +} + +case class CombineSetsAndCountFunction( + @transient inputSet: Expression, + @transient base: AggregateExpression) + extends AggregateFunction { + + def this() = this(null, null) // Required for serialization. + + val seen = new OpenHashSet[Any]() + + override def update(input: Row): Unit = { + val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] + val inputIterator = inputSetEval.iterator + while (inputIterator.hasNext) { + seen.add(inputIterator.next) + } + } + + override def eval(input: Row): Any = seen.size.toLong } case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) @@ -379,17 +463,22 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus) } -case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpression) - extends AggregateFunction { +case class CountDistinctFunction( + @transient expr: Seq[Expression], + @transient base: AggregateExpression) + extends MergableAggregateFunction { def this() = this(null, null) // Required for serialization. 
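+  // OpenHashSet is Spark's open-addressing set, optimized for insertion-only
+  // workloads; it replaces scala.collection.mutable.HashSet on this hot update path.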
- val seen = new scala.collection.mutable.HashSet[Any]() + val seen = new OpenHashSet[Any]() + + @transient + val distinctValue = new InterpretedProjection(expr) override def update(input: Row): Unit = { - val evaluatedExpr = expr.map(_.eval(input)) - if (evaluatedExpr.map(_ != null).reduceLeft(_ && _)) { - seen += evaluatedExpr + val evaluatedExpr = distinctValue(input) + if (!evaluatedExpr.anyNull) { + seen.add(evaluatedExpr) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index c79c1847cedf5..3d2dbac24eae8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -85,3 +85,17 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _)) } + +case class MaxOf(left: Expression, right: Expression) extends Expression { + type EvaluatedType = Any + + override def nullable = left.nullable && right.nullable + + override def children = left :: right :: Nil + + override def references = (left.flatMap(_.references) ++ right.flatMap(_.references)).toSet + + override def dataType = left.dataType + + override def eval(input: Row): Any = ??? +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 5b398695bf560..6da33b5f6efba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -26,6 +26,9 @@ import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ +class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] +class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] + /** * A base class for generators of byte code to perform expression evaluation. Includes a set of * helpers for referring to Catalyst types and building trees that perform evaluation of individual @@ -71,7 +74,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * From the Guava Docs: A Cache is similar to ConcurrentMap, but not quite the same. The most * fundamental difference is that a ConcurrentMap persists all elements that are added to it until * they are explicitly removed. A Cache on the other hand is generally configured to evict entries - * automatically, in order to constrain its memory footprint + * automatically, in order to constrain its memory footprint. Note that this cache does not use + * weak keys/values and thus does not respond to memory pressure. 
*/ protected val cache = CacheBuilder.newBuilder() .maximumSize(1000) @@ -398,6 +402,75 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin $primitiveTerm = ${falseEval.primitiveTerm} } """.children + + case NewSet(elementType) => + q""" + val $nullTerm = false + val $primitiveTerm = new ${hashSetForType(elementType)}() + """.children + + case AddItemToSet(item, set) => + val itemEval = expressionEvaluator(item) + val setEval = expressionEvaluator(set) + + val ArrayType(elementType, _) = set.dataType + + itemEval.code ++ setEval.code ++ + q""" + if (!${itemEval.nullTerm}) { + ${setEval.primitiveTerm} + .asInstanceOf[${hashSetForType(elementType)}] + .add(${itemEval.primitiveTerm}) + } + + val $nullTerm = false + val $primitiveTerm = ${setEval.primitiveTerm} + """.children + + case CombineSets(left, right) => + val leftEval = expressionEvaluator(left) + val rightEval = expressionEvaluator(right) + + val ArrayType(elementType, _) = left.dataType + + leftEval.code ++ rightEval.code ++ + q""" + val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] + val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] + val iterator = rightSet.iterator + while (iterator.hasNext) { + leftSet.add(iterator.next()) + } + + val $nullTerm = false + val $primitiveTerm = leftSet + """.children + + case MaxOf(e1, e2) => + val eval1 = expressionEvaluator(e1) + val eval2 = expressionEvaluator(e2) + + eval1.code ++ eval2.code ++ + q""" + var $nullTerm = false + var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)} + + if (${eval1.nullTerm}) { + $nullTerm = ${eval2.nullTerm} + $primitiveTerm = ${eval2.primitiveTerm} + } else if (${eval2.nullTerm}) { + $nullTerm = ${eval1.nullTerm} + $primitiveTerm = ${eval1.primitiveTerm} + } else { + $nullTerm = false + if (${eval1.primitiveTerm} > ${eval2.primitiveTerm}) { + $primitiveTerm = ${eval1.primitiveTerm} + } else { + $primitiveTerm = ${eval2.primitiveTerm} + } + } + """.children + } // If there was no match in the partial function above, we fall back on calling the interpreted @@ -437,6 +510,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin protected def accessorForType(dt: DataType) = newTermName(s"get${primitiveForType(dt)}") protected def mutatorForType(dt: DataType) = newTermName(s"set${primitiveForType(dt)}") + protected def hashSetForType(dt: DataType) = dt match { + case IntegerType => typeOf[IntegerHashSet] + case LongType => typeOf[LongHashSet] + } + protected def primitiveForType(dt: DataType) = dt match { case IntegerType => "Int" case LongType => "Long" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 77fa02c13de30..7871a62620478 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -69,8 +69,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { ..${evaluatedExpression.code} if(${evaluatedExpression.nullTerm}) setNullAt($iLit) - else + else { + nullBits($iLit) = false $elementName = ${evaluatedExpression.primitiveTerm} + } } """.children : Seq[Tree] } @@ -106,9 +108,10 @@ object GenerateProjection extends 
CodeGenerator[Seq[Expression], Projection] { if(value == null) { setNullAt(i) } else { + nullBits(i) = false $elementName = value.asInstanceOf[${termForType(e.dataType)}] - return } + return }""" } q"final def update(i: Int, value: Any): Unit = { ..$cases; $accessorFailure }" @@ -137,7 +140,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? - q"if(i == $i) { $elementName = value; return }" :: Nil + q"if(i == $i) { nullBits($i) = false; $elementName = value; return }" :: Nil case _ => Nil } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala new file mode 100644 index 0000000000000..14a95328c0a3a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.util.collection.OpenHashSet + +case class NewSet(elementType: DataType) extends LeafExpression { + type EvaluatedType = Any + + def references = Set.empty + + def nullable = false + + // This is not completely accurate.. 
+  def dataType = ArrayType(elementType)
+
+  def eval(input: Row): Any = {
+    new OpenHashSet[Any]()
+  }
+
+  override def toString = s"new Set($dataType)"
+}
+
+// THIS MUTATES ITS ARGUMENTS
+case class AddItemToSet(item: Expression, set: Expression) extends Expression {
+  type EvaluatedType = Any
+
+  def children = item :: set :: Nil
+
+  def nullable = set.nullable
+
+  def dataType = set.dataType
+
+  def references = (item.flatMap(_.references) ++ set.flatMap(_.references)).toSet
+
+  def eval(input: Row): Any = {
+    val itemEval = item.eval(input)
+    if (itemEval != null) {
+      val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]]
+      if (setEval != null) {
+        setEval.add(itemEval)
+        setEval
+      } else {
+        null
+      }
+    } else {
+      null
+    }
+  }
+
+  override def toString = s"$set += $item"
+}
+
+// THIS MUTATES ITS ARGUMENTS
+case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
+  type EvaluatedType = Any
+
+  def nullable = left.nullable || right.nullable
+
+  def dataType = left.dataType
+
+  def symbol = "++="
+
+  def eval(input: Row): Any = {
+    val leftEval = left.eval(input).asInstanceOf[OpenHashSet[Any]]
+    if(leftEval != null) {
+      val rightEval = right.eval(input).asInstanceOf[OpenHashSet[Any]]
+      if (rightEval != null) {
+        val iterator = rightEval.iterator
+        while(iterator.hasNext) {
+          val rightValue = iterator.next()
+          leftEval.add(rightValue)
+        }
+        leftEval
+      } else {
+        null
+      }
+    } else {
+      null
+    }
+  }
+}
+
+case class CountSet(child: Expression) extends UnaryExpression {
+  type EvaluatedType = Any
+
+  def nullable = child.nullable
+
+  def dataType = LongType
+
+  def eval(input: Row): Any = {
+    val childEval = child.eval(input).asInstanceOf[OpenHashSet[Any]]
+    if (childEval != null) {
+      childEval.size.toLong
+    }
+  }
+
+  override def toString = s"$child.count()"
+}
\ No newline at end of file
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 4a26934c49c93..39a96b2dd218f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -103,6 +103,40 @@ case class GeneratedAggregate(
         updateCount :: updateSum :: Nil,
         result
       )
+
+      case m @ Max(expr) =>
+        val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)()
+        val initialValue = Literal(null, expr.dataType)
+        val updateMax = MaxOf(currentMax, expr)
+        //If(IsNull(currentMax), expr, If(GreaterThan(currentMax, expr), currentMax, expr))
+
+        AggregateEvaluation(
+          currentMax :: Nil,
+          initialValue :: Nil,
+          updateMax :: Nil,
+          currentMax)
+
+      case CollectHashSet(Seq(expr)) =>
+        val set = AttributeReference("hashSet", ArrayType(expr.dataType), nullable = false)()
+        val initialValue = NewSet(expr.dataType)
+        val addToSet = AddItemToSet(expr, set)
+
+        AggregateEvaluation(
+          set :: Nil,
+          initialValue :: Nil,
+          addToSet :: Nil,
+          set)
+
+      case CombineSetsAndCount(inputSet) =>
+        val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)()
+        val initialValue = NewSet(IntegerType) // NOT TRUE
+        val collectSets = CombineSets(set, inputSet)
+
+        AggregateEvaluation(
+          set :: Nil,
+          initialValue :: Nil,
+          collectSets :: Nil,
+          CountSet(set))
     }
 
     val computationSchema = computeFunctions.flatMap(_.schema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 34654447a5f4b..540013518cf86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.nio.ByteBuffer +import org.apache.spark.util.collection.OpenHashSet + import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLog @@ -31,6 +33,8 @@ import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} import org.apache.spark.util.MutablePair import org.apache.spark.util.Utils +import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet} + private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { val kryo = new Kryo() @@ -41,6 +45,10 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog], new HyperLogLogSerializer) kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer) + // Specific hashset must come first + kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer) + kryo.register(classOf[LongHashSet], new LongHashSetSerializer) + kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], new OpenHashSetSerializer) kryo.setReferences(false) kryo.setClassLoader(Utils.getSparkClassLoader) new AllScalaRegistrar().apply(kryo) @@ -109,3 +117,75 @@ private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] { HyperLogLog.Builder.build(bytes) } } + +private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { + def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) { + output.writeInt(hs.size) + val rowSerializer = kryo.getSerializer(classOf[Any]).asInstanceOf[Serializer[Any]] + val iterator = hs.iterator + while(iterator.hasNext) { + val row = iterator.next() + rowSerializer.write(kryo, output, row) + } + } + + def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = { + val rowSerializer = kryo.getSerializer(classOf[Any]).asInstanceOf[Serializer[Any]] + + val numItems = input.readInt() + val set = new OpenHashSet[Any](numItems) + var i = 0 + while (i < numItems) { + val row = rowSerializer.read(kryo, input, classOf[Any].asInstanceOf[Class[Any]]) + set.add(row) + i += 1 + } + set + } +} + +private[sql] class IntegerHashSetSerializer extends Serializer[IntegerHashSet] { + def write(kryo: Kryo, output: Output, hs: IntegerHashSet) { + output.writeInt(hs.size) + val iterator = hs.iterator + while(iterator.hasNext) { + val value: Int = iterator.next() + output.writeInt(value) + } + } + + def read(kryo: Kryo, input: Input, tpe: Class[IntegerHashSet]): IntegerHashSet = { + val numItems = input.readInt() + val set = new IntegerHashSet + var i = 0 + while (i < numItems) { + val value = input.readInt() + set.add(value) + i += 1 + } + set + } +} + +private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] { + def write(kryo: Kryo, output: Output, hs: LongHashSet) { + output.writeInt(hs.size) + val iterator = hs.iterator + while(iterator.hasNext) { + val value = iterator.next() + output.writeLong(value) + } + } + + def read(kryo: Kryo, input: Input, tpe: Class[LongHashSet]): LongHashSet = { + val numItems = input.readInt() + val set = new LongHashSet + var i = 0 + while (i < numItems) { + val 
value = input.readLong() + set.add(value) + i += 1 + } + set + } +} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f0c958fdb537f..5a8dbc3916e4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -148,7 +148,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists { - case _: Sum | _: Count => false + case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false + case CollectHashSet(exprs) if exprs.size == 1 => false case _ => true } From bd0823901502a1cb52d55fb19039a9847694d907 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 Aug 2014 16:28:35 -0700 Subject: [PATCH 02/24] WIP --- .../apache/spark/sql/catalyst/expressions/aggregates.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index ea8ce757055d5..5caf032fdf8a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -194,7 +194,7 @@ case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpress case class CollectHashSetFunction( @transient expr: Seq[Expression], @transient base: AggregateExpression) - extends MergableAggregateFunction { + extends AggregateFunction { def this() = this(null, null) // Required for serialization. @@ -203,13 +203,14 @@ case class CollectHashSetFunction( @transient val distinctValue = new InterpretedProjection(expr) +/* override def merge(other: MergableAggregateFunction): MergableAggregateFunction = { val otherSetIterator = other.asInstanceOf[CountDistinctFunction].seen.iterator while(otherSetIterator.hasNext) { seen.add(otherSetIterator.next()) } this - } + }*/ override def update(input: Row): Unit = { val evaluatedExpr = distinctValue(input) @@ -466,7 +467,7 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) case class CountDistinctFunction( @transient expr: Seq[Expression], @transient base: AggregateExpression) - extends MergableAggregateFunction { + extends AggregateFunction { def this() = this(null, null) // Required for serialization. 
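Taken together, the first two patches turn COUNT(DISTINCT ...) into a two-phase aggregate: CollectHashSet builds a per-partition OpenHashSet of distinct values, the partial sets flow to the reduce side as the "partialSets" attribute, and CombineSetsAndCount unions them and returns the final size. A minimal sketch of that evaluation order in plain Scala (illustrative only; the names here are made up, and the real path runs through the Catalyst expressions and OpenHashSet shown above):

import scala.collection.mutable

object PartialCountDistinctSketch {
  // Map side (CollectHashSet): reduce each partition's rows to the set of
  // distinct non-null keys. update() skips rows where any grouped expression
  // evaluates to null, modeled here with Option.
  def collectHashSet(partition: Seq[Option[Int]]): mutable.Set[Int] =
    mutable.Set(partition.flatten: _*)

  // Reduce side (CombineSetsAndCount): union the partial sets, then count.
  def combineSetsAndCount(partials: Seq[mutable.Set[Int]]): Long =
    partials.reduce(_ ++= _).size.toLong

  def main(args: Array[String]): Unit = {
    val partitions = Seq(
      Seq(Some(1), Some(1), None),
      Seq(Some(2), Some(3)),
      Seq(Some(3), Some(3)))
    val partials = partitions.map(collectHashSet)
    println(combineSetsAndCount(partials)) // prints 3; the null row is ignored
  }
}

The null-skipping behavior on the map side is what the "count distinct ... + null" golden tests added later in this series exercise.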
From 050bb9725d72affaa2a11552d131618384b435be Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 Aug 2014 16:37:34 -0700 Subject: [PATCH 03/24] Skip no-arg constructors for kryo, --- .../scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index cd04bdf02cf84..96ce35939e2cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -280,7 +280,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { */ def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") { try { - val defaultCtor = getClass.getConstructors.head + // Skip no-arg constructors that are just there for kryo. + val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head if (otherCopyArgs.isEmpty) { defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type] } else { From 41fbd1db14f0fbc11846548ed0ee1c8c941b6faf Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 Aug 2014 16:53:08 -0700 Subject: [PATCH 04/24] Never try and create an empty hash set. --- .../org/apache/spark/sql/execution/SparkSqlSerializer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 540013518cf86..1c5995f82d1e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -133,7 +133,7 @@ private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { val rowSerializer = kryo.getSerializer(classOf[Any]).asInstanceOf[Serializer[Any]] val numItems = input.readInt() - val set = new OpenHashSet[Any](numItems) + val set = new OpenHashSet[Any](numItems + 1) var i = 0 while (i < numItems) { val row = rowSerializer.read(kryo, input, classOf[Any].asInstanceOf[Class[Any]]) From d4945987a87b6b020b3845a47bea7e949fe82cdd Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 Aug 2014 17:04:38 -0700 Subject: [PATCH 05/24] Fix tests now that the planner is better --- .../org/apache/spark/sql/execution/PlannerSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 76b1724471442..37d64f0de7bab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -45,16 +45,16 @@ class PlannerSuite extends FunSuite { assert(aggregations.size === 2) } - test("count distinct is not partially aggregated") { + test("count distinct is partially aggregated") { val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed val planned = HashAggregation(query) - assert(planned.isEmpty) + assert(planned.nonEmpty) } - test("mixed aggregates are not partially aggregated") { + test("mixed aggregates are partially aggregated") { val query = testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed val 
planned = HashAggregation(query) - assert(planned.isEmpty) + assert(planned.nonEmpty) } } From 915365256212dae5e4ae43b1ec24238dbb85914e Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 Aug 2014 17:28:15 -0700 Subject: [PATCH 06/24] better toString --- .../org/apache/spark/sql/catalyst/expressions/arithmetic.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 3d2dbac24eae8..d9784882de805 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -98,4 +98,6 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def dataType = left.dataType override def eval(input: Row): Any = ??? + + override def toString = s"MaxOf($left, $right)" } From 38c7449a3eeb65a4246e19eaca69c9a8bc0d838d Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 Aug 2014 18:01:45 -0700 Subject: [PATCH 07/24] comments and style --- .../sql/catalyst/expressions/aggregates.scala | 9 -------- .../sql/catalyst/expressions/arithmetic.scala | 4 +++- .../spark/sql/catalyst/expressions/sets.scala | 21 +++++++++++++++---- .../sql/execution/GeneratedAggregate.scala | 4 ++-- .../sql/execution/SparkSqlSerializer.scala | 9 +++++--- 5 files changed, 28 insertions(+), 19 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 5caf032fdf8a8..613b87ca98d97 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -203,15 +203,6 @@ case class CollectHashSetFunction( @transient val distinctValue = new InterpretedProjection(expr) -/* - override def merge(other: MergableAggregateFunction): MergableAggregateFunction = { - val otherSetIterator = other.asInstanceOf[CountDistinctFunction].seen.iterator - while(otherSetIterator.hasNext) { - seen.add(otherSetIterator.next()) - } - this - }*/ - override def update(input: Row): Unit = { val evaluatedExpr = distinctValue(input) if (!evaluatedExpr.anyNull) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d9784882de805..923a9b1445d6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -97,7 +97,9 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def dataType = left.dataType - override def eval(input: Row): Any = ??? 
+  override def eval(input: Row): Any = {
+    val leftEval = left.eval(input)
+    val rightEval = right.eval(input)
+    // Null inputs lose to non-null inputs, matching the generated MaxOf code.
+    if (leftEval == null) {
+      rightEval
+    } else if (rightEval == null) {
+      leftEval
+    } else {
+      val ordering = left.dataType match {
+        case n: NativeType => n.ordering.asInstanceOf[Ordering[Any]]
+        case other => sys.error(s"Type $other does not support ordered operations")
+      }
+      if (ordering.gt(leftEval, rightEval)) leftEval else rightEval
+    }
+  }
 
   override def toString = s"MaxOf($left, $right)"
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
index 14a95328c0a3a..b7433b48b82af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala
@@ -20,6 +20,9 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.types._
 import org.apache.spark.util.collection.OpenHashSet
 
+/**
+ * Creates a new set of the specified type.
+ */
 case class NewSet(elementType: DataType) extends LeafExpression {
   type EvaluatedType = Any
 
@@ -27,7 +30,8 @@ case class NewSet(elementType: DataType) extends LeafExpression {
 
   def nullable = false
 
-  // This is not completely accurate..
+  // We are currently only using these Expressions internally for aggregation. However, if we ever
+  // expose these to users we'll want to create a proper type instead of hijacking ArrayType.
   def dataType = ArrayType(elementType)
 
   def eval(input: Row): Any = {
@@ -37,7 +41,10 @@ case class NewSet(elementType: DataType) extends LeafExpression {
   override def toString = s"new Set($dataType)"
 }
 
-// THIS MUTATES ITS ARGUMENTS
+/**
+ * Adds an item to a set.
+ * For performance, this expression mutates its input during evaluation.
+ */
 case class AddItemToSet(item: Expression, set: Expression) extends Expression {
   type EvaluatedType = Any
 
@@ -67,7 +74,10 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
   override def toString = s"$set += $item"
 }
 
-// THIS MUTATES ITS ARGUMENTS
+/**
+ * Combines the elements of two sets.
+ * For performance, this expression mutates its left input set during evaluation.
+ */
 case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
   type EvaluatedType = Any
 
@@ -97,6 +107,9 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres
 }
 
+/**
+ * Returns the number of elements in the input set.
+ */ case class CountSet(child: Expression) extends UnaryExpression { type EvaluatedType = Any @@ -112,4 +125,4 @@ case class CountSet(child: Expression) extends UnaryExpression { } override def toString = s"$child.count()" -} \ No newline at end of file +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 39a96b2dd218f..17418726aa625 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -108,7 +108,6 @@ case class GeneratedAggregate( val currentMax = AttributeReference("currentMax", expr.dataType, nullable = true)() val initialValue = Literal(null, expr.dataType) val updateMax = MaxOf(currentMax, expr) - //If(IsNull(currentMax), expr, If(GreaterThan(currentMax, expr), currentMax, expr)) AggregateEvaluation( currentMax :: Nil, @@ -128,8 +127,9 @@ case class GeneratedAggregate( set) case CombineSetsAndCount(inputSet) => + val ArrayType(inputType) = inputSet.dataType val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)() - val initialValue = NewSet(IntegerType) // NOT TRUE + val initialValue = NewSet(inputType) val collectSets = CombineSets(set, inputSet) AggregateEvaluation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 1c5995f82d1e1..4956948e83914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -45,10 +45,13 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog], new HyperLogLogSerializer) kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer) - // Specific hashset must come first + + // Specific hashsets must come first kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer) kryo.register(classOf[LongHashSet], new LongHashSetSerializer) - kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], new OpenHashSetSerializer) + kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], + new OpenHashSetSerializer) + kryo.setReferences(false) kryo.setClassLoader(Utils.getSparkClassLoader) new AllScalaRegistrar().apply(kryo) @@ -188,4 +191,4 @@ private[sql] class LongHashSetSerializer extends Serializer[LongHashSet] { } set } -} \ No newline at end of file +} From f31b8add6755b3b1c74a43e7feadb3e96cdaeba1 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 Aug 2014 18:41:39 -0700 Subject: [PATCH 08/24] more fixes --- .../org/apache/spark/sql/execution/GeneratedAggregate.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 17418726aa625..54c9d033cd2ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -127,7 +127,7 @@ case class GeneratedAggregate( set) case CombineSetsAndCount(inputSet) => - val ArrayType(inputType) = inputSet.dataType + val ArrayType(inputType, _) = 
inputSet.dataType val set = AttributeReference("hashSet", inputSet.dataType, nullable = false)() val initialValue = NewSet(inputType) val collectSets = CombineSets(set, inputSet) From c1f7114387c4e97310de0d0616ad84111025b01a Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 Aug 2014 22:31:01 -0700 Subject: [PATCH 09/24] Improve tests / fix serialization. --- .../spark/sql/catalyst/expressions/Row.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 19 +++--- .../spark/sql/catalyst/expressions/sets.scala | 5 +- .../sql/execution/SparkSqlSerializer.scala | 10 +-- .../spark/sql/execution/SparkStrategies.scala | 5 +- .../sql/hive/execution/HiveQuerySuite.scala | 65 +++++++++++++++++++ 6 files changed, 89 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index c9a63e201ef60..d68a4fabeac77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -127,7 +127,7 @@ object EmptyRow extends Row { * the array is not copied, and thus could technically be mutated after creation, this is not * allowed. */ -class GenericRow(protected[catalyst] val values: Array[Any]) extends Row { +class GenericRow(protected[sql] val values: Array[Any]) extends Row { /** No-arg constructor for serialization. */ def this() = this(null) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 6da33b5f6efba..bcabbddd381d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -435,15 +435,18 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin leftEval.code ++ rightEval.code ++ q""" - val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] - val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] - val iterator = rightSet.iterator - while (iterator.hasNext) { - leftSet.add(iterator.next()) - } - val $nullTerm = false - val $primitiveTerm = leftSet + var $primitiveTerm: ${hashSetForType(elementType)} = null + + { + val leftSet = ${leftEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] + val rightSet = ${rightEval.primitiveTerm}.asInstanceOf[${hashSetForType(elementType)}] + val iterator = rightSet.iterator + while (iterator.hasNext) { + leftSet.add(iterator.next()) + } + $primitiveTerm = leftSet + } """.children case MaxOf(e1, e2) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index b7433b48b82af..e6c570b47bee2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -58,8 +58,9 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { def eval(input: Row): Any = { val itemEval = item.eval(input) + val setEval = set.eval(input).asInstanceOf[OpenHashSet[Any]] + if (itemEval != null) { - val setEval = 
set.eval(input).asInstanceOf[OpenHashSet[Any]] if (setEval != null) { setEval.add(itemEval) setEval @@ -67,7 +68,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { null } } else { - null + setEval } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 4956948e83914..b18f1c8d32f1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import java.nio.ByteBuffer +import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.util.collection.OpenHashSet import scala.reflect.ClassTag @@ -123,23 +124,22 @@ private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] { private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { def write(kryo: Kryo, output: Output, hs: OpenHashSet[_]) { + val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]] output.writeInt(hs.size) - val rowSerializer = kryo.getSerializer(classOf[Any]).asInstanceOf[Serializer[Any]] val iterator = hs.iterator while(iterator.hasNext) { val row = iterator.next() - rowSerializer.write(kryo, output, row) + rowSerializer.write(kryo, output, row.asInstanceOf[GenericRow].values) } } def read(kryo: Kryo, input: Input, tpe: Class[OpenHashSet[_]]): OpenHashSet[_] = { - val rowSerializer = kryo.getSerializer(classOf[Any]).asInstanceOf[Serializer[Any]] - + val rowSerializer = kryo.getDefaultSerializer(classOf[Array[Any]]).asInstanceOf[Serializer[Any]] val numItems = input.readInt() val set = new OpenHashSet[Any](numItems + 1) var i = 0 while (i < numItems) { - val row = rowSerializer.read(kryo, input, classOf[Any].asInstanceOf[Class[Any]]) + val row = new GenericRow(rowSerializer.read(kryo, input, classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]]) set.add(row) i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5a8dbc3916e4b..517b77804ae2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan} import org.apache.spark.sql.parquet._ @@ -149,7 +150,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def canBeCodeGened(aggs: Seq[AggregateExpression]) = !aggs.exists { case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false - case CollectHashSet(exprs) if exprs.size == 1 => false + // The generated set implementation is pretty limited ATM. 
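+      // Today that means a single expression whose type is one of the two that
+      // hashSetForType can specialize (IntegerType / LongType), as checked below.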
+ case CollectHashSet(exprs) if exprs.size == 1 && + Seq(IntegerType, LongType).contains(exprs.head.dataType) => false case _ => true } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index fdb2f41f5a5b6..26e4ec6e6dcce 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -32,6 +32,71 @@ case class TestData(a: Int, b: String) */ class HiveQuerySuite extends HiveComparisonTest { + createQueryTest("count distinct 0 values", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 'a' AS a FROM src LIMIT 0) table + """.stripMargin) + + createQueryTest("count distinct 1 value strings", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 'a' AS a FROM src LIMIT 1 UNION ALL + | SELECT 'b' AS a FROM src LIMIT 1) table + """.stripMargin) + + createQueryTest("count distinct 1 value", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 1 AS a FROM src LIMIT 1 UNION ALL + | SELECT 1 AS a FROM src LIMIT 1) table + """.stripMargin) + + createQueryTest("count distinct 2 values", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 1 AS a FROM src LIMIT 1 UNION ALL + | SELECT 2 AS a FROM src LIMIT 1) table + """.stripMargin) + + createQueryTest("count distinct 2 values including null", + """ + |SELECT COUNT(DISTINCT a, 1) FROM ( + | SELECT 1 AS a FROM src LIMIT 1 UNION ALL + | SELECT 1 AS a FROM src LIMIT 1 UNION ALL + | SELECT null AS a FROM src LIMIT 1) table + """.stripMargin) + + createQueryTest("count distinct 1 value + null", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 1 AS a FROM src LIMIT 1 UNION ALL + | SELECT 1 AS a FROM src LIMIT 1 UNION ALL + | SELECT null AS a FROM src LIMIT 1) table + """.stripMargin) + + createQueryTest("count distinct 1 value long", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 1L AS a FROM src LIMIT 1 UNION ALL + | SELECT 1L AS a FROM src LIMIT 1) table + """.stripMargin) + + createQueryTest("count distinct 2 values long", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 1L AS a FROM src LIMIT 1 UNION ALL + | SELECT 2L AS a FROM src LIMIT 1) table + """.stripMargin) + + createQueryTest("count distinct 1 value + null long", + """ + |SELECT COUNT(DISTINCT a) FROM ( + | SELECT 1L AS a FROM src LIMIT 1 UNION ALL + | SELECT 1L AS a FROM src LIMIT 1 UNION ALL + | SELECT null AS a FROM src LIMIT 1) table + """.stripMargin) + createQueryTest("null case", "SELECT case when(true) then 1 else null end FROM src LIMIT 1") From b3d0f6489ecc8fe9de505b3250599ab741351697 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sun, 17 Aug 2014 22:56:02 -0700 Subject: [PATCH 10/24] Add golden files. 
--- .../count distinct 0 values-0-1843b7947729b771fee3a4abd050bfdc | 1 + ...stinct 1 value + null long-0-89b850197b326239d60a5e1d5db7c9c9 | 1 + ...nt distinct 1 value + null-0-a014038c00fb81e88041ed4a8368e6f7 | 1 + ...ount distinct 1 value long-0-77b9ed1d7ae65fa53830a3bc586856ff | 1 + ...t distinct 1 value strings-0-c68e75ec4c884b93765a466e992e391d | 1 + .../count distinct 1 value-0-a4047b06a324fb5ea400c94350c9e038 | 1 + ...ct 2 values including null-0-75672236a30e10dab13b9b246c5a3a1e | 1 + ...unt distinct 2 values long-0-f4ec7d767ba8c49d41edf5d6f58cf6d1 | 1 + .../count distinct 2 values-0-c61df65af167acaf7edb174e77898f3e | 1 + ...how_create_table_delimited-0-52b0e534c7df544258a1c59df9f816ce | 0 10 files changed, 9 insertions(+) create mode 100644 sql/hive/src/test/resources/golden/count distinct 0 values-0-1843b7947729b771fee3a4abd050bfdc create mode 100644 sql/hive/src/test/resources/golden/count distinct 1 value + null long-0-89b850197b326239d60a5e1d5db7c9c9 create mode 100644 sql/hive/src/test/resources/golden/count distinct 1 value + null-0-a014038c00fb81e88041ed4a8368e6f7 create mode 100644 sql/hive/src/test/resources/golden/count distinct 1 value long-0-77b9ed1d7ae65fa53830a3bc586856ff create mode 100644 sql/hive/src/test/resources/golden/count distinct 1 value strings-0-c68e75ec4c884b93765a466e992e391d create mode 100644 sql/hive/src/test/resources/golden/count distinct 1 value-0-a4047b06a324fb5ea400c94350c9e038 create mode 100644 sql/hive/src/test/resources/golden/count distinct 2 values including null-0-75672236a30e10dab13b9b246c5a3a1e create mode 100644 sql/hive/src/test/resources/golden/count distinct 2 values long-0-f4ec7d767ba8c49d41edf5d6f58cf6d1 create mode 100644 sql/hive/src/test/resources/golden/count distinct 2 values-0-c61df65af167acaf7edb174e77898f3e create mode 100644 sql/hive/src/test/resources/golden/show_create_table_delimited-0-52b0e534c7df544258a1c59df9f816ce diff --git a/sql/hive/src/test/resources/golden/count distinct 0 values-0-1843b7947729b771fee3a4abd050bfdc b/sql/hive/src/test/resources/golden/count distinct 0 values-0-1843b7947729b771fee3a4abd050bfdc new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 0 values-0-1843b7947729b771fee3a4abd050bfdc @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/count distinct 1 value + null long-0-89b850197b326239d60a5e1d5db7c9c9 b/sql/hive/src/test/resources/golden/count distinct 1 value + null long-0-89b850197b326239d60a5e1d5db7c9c9 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 1 value + null long-0-89b850197b326239d60a5e1d5db7c9c9 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/count distinct 1 value + null-0-a014038c00fb81e88041ed4a8368e6f7 b/sql/hive/src/test/resources/golden/count distinct 1 value + null-0-a014038c00fb81e88041ed4a8368e6f7 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 1 value + null-0-a014038c00fb81e88041ed4a8368e6f7 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/count distinct 1 value long-0-77b9ed1d7ae65fa53830a3bc586856ff b/sql/hive/src/test/resources/golden/count distinct 1 value long-0-77b9ed1d7ae65fa53830a3bc586856ff new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 1 value long-0-77b9ed1d7ae65fa53830a3bc586856ff @@ -0,0 +1 @@ +1 diff --git 
a/sql/hive/src/test/resources/golden/count distinct 1 value strings-0-c68e75ec4c884b93765a466e992e391d b/sql/hive/src/test/resources/golden/count distinct 1 value strings-0-c68e75ec4c884b93765a466e992e391d new file mode 100644 index 0000000000000..0cfbf08886fca --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 1 value strings-0-c68e75ec4c884b93765a466e992e391d @@ -0,0 +1 @@ +2 diff --git a/sql/hive/src/test/resources/golden/count distinct 1 value-0-a4047b06a324fb5ea400c94350c9e038 b/sql/hive/src/test/resources/golden/count distinct 1 value-0-a4047b06a324fb5ea400c94350c9e038 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 1 value-0-a4047b06a324fb5ea400c94350c9e038 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/count distinct 2 values including null-0-75672236a30e10dab13b9b246c5a3a1e b/sql/hive/src/test/resources/golden/count distinct 2 values including null-0-75672236a30e10dab13b9b246c5a3a1e new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 2 values including null-0-75672236a30e10dab13b9b246c5a3a1e @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/count distinct 2 values long-0-f4ec7d767ba8c49d41edf5d6f58cf6d1 b/sql/hive/src/test/resources/golden/count distinct 2 values long-0-f4ec7d767ba8c49d41edf5d6f58cf6d1 new file mode 100644 index 0000000000000..0cfbf08886fca --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 2 values long-0-f4ec7d767ba8c49d41edf5d6f58cf6d1 @@ -0,0 +1 @@ +2 diff --git a/sql/hive/src/test/resources/golden/count distinct 2 values-0-c61df65af167acaf7edb174e77898f3e b/sql/hive/src/test/resources/golden/count distinct 2 values-0-c61df65af167acaf7edb174e77898f3e new file mode 100644 index 0000000000000..0cfbf08886fca --- /dev/null +++ b/sql/hive/src/test/resources/golden/count distinct 2 values-0-c61df65af167acaf7edb174e77898f3e @@ -0,0 +1 @@ +2 diff --git a/sql/hive/src/test/resources/golden/show_create_table_delimited-0-52b0e534c7df544258a1c59df9f816ce b/sql/hive/src/test/resources/golden/show_create_table_delimited-0-52b0e534c7df544258a1c59df9f816ce new file mode 100644 index 0000000000000..e69de29bb2d1d From 57ae3b1f4ce0ae4c4522ada5c1275d55e4e88563 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 18 Aug 2014 13:21:37 -0700 Subject: [PATCH 11/24] Fix order dependent test --- .../apache/spark/sql/hive/StatisticsSuite.scala | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 7c82964b5ecdc..206f514e9de00 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import org.scalatest.BeforeAndAfterAll + import scala.reflect.ClassTag @@ -26,7 +28,16 @@ import org.apache.spark.sql.execution.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -class StatisticsSuite extends QueryTest { +class StatisticsSuite extends QueryTest with BeforeAndAfterAll { + + override def beforeAll() = { + // HACK: Cached tables do not currently preserve statistics... 
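+    // With caching enabled, the plans below would contain InMemoryRelations rather
+    // than the MetastoreRelations whose sizeInBytes these tests assert on.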
+ TestHive.cacheTables = false + } + + override def afterAll() = { + TestHive.cacheTables = true + } test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { @@ -126,7 +137,7 @@ class StatisticsSuite extends QueryTest { val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation => mr.statistics.sizeInBytes } - assert(sizes.size === 1) + assert(sizes.size === 1, s"Size wrong for:\n ${rdd.queryExecution}") assert(sizes(0).equals(BigInt(5812)), s"expected exact size 5812 for test table 'src', got: ${sizes(0)}") } @@ -147,7 +158,7 @@ class StatisticsSuite extends QueryTest { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } assert(sizes.size === 2 && sizes(0) <= autoBroadcastJoinThreshold, - s"query should contain two relations, each of which has size smaller than autoConvertSize") + s"query should contain two relations, each of which has size smaller than autoConvertSize instead ${rdd.queryExecution}") // Using `sparkPlan` because for relevant patterns in HashJoin to be // matched, other strategies need to be applied. From abee26dd0276d107531a41c09744b21f0c49aa95 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 18 Aug 2014 14:05:36 -0700 Subject: [PATCH 12/24] WIP --- .../scala/org/apache/spark/sql/parquet/ParquetTypes.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index c79a9ac2dad81..27d18f245b125 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -376,6 +376,8 @@ private[parquet] object ParquetTypesConverter extends Logging { } ParquetRelation.enableLogForwarding() + println(fs.listStatus(path).toSeq) + val children = fs.listStatus(path).filterNot { status => val name = status.getPath.getName name(0) == '.' 
|| name == FileOutputCommitter.SUCCEEDED_FILE_NAME @@ -393,7 +395,7 @@ private[parquet] object ParquetTypesConverter extends Logging { .orElse(children.find(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE)) .map(ParquetFileReader.readFooter(conf, _)) .getOrElse( - throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path")) + throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path in ${children.map(_.getPath.getName).toSeq}")) } /** From 87d101d396e8ed9149b3e7e8503ae0f598d5f4b9 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 18 Aug 2014 18:48:56 -0700 Subject: [PATCH 13/24] Fix isNullAt bug --- .../org/apache/spark/sql/catalyst/expressions/Projection.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 52904ff0221d4..968c5f1fe8f4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -110,7 +110,8 @@ class JoinedRow extends Row { def apply(i: Int) = if (i < row1.size) row1(i) else row2(i - row1.size) - def isNullAt(i: Int) = apply(i) == null + def isNullAt(i: Int) = + if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) def getInt(i: Int): Int = if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) From 58d15f1e260c175fef8b08447361e04435dc381f Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 18 Aug 2014 18:49:19 -0700 Subject: [PATCH 14/24] disable codegen logging --- .../spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2004af95d514d..7362d3d6b0e7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -496,7 +496,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin // Only inject debugging code if debugging is turned on. val debugCode = - if (log.isDebugEnabled) { + if (false) { val localLogger = log val localLoggerTree = reify { localLogger } q""" From 8ff6402ac17664e020eadaa414416c97d8a24a41 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 18 Aug 2014 18:57:24 -0700 Subject: [PATCH 15/24] Add specific row. --- .../catalyst/expressions/SpecificRow.scala | 297 ++++++++++++++++++ .../spark/sql/parquet/ParquetConverter.scala | 8 +- 2 files changed, 301 insertions(+), 4 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala new file mode 100644 index 0000000000000..2b0f351654774 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.types._ + +/** + * + * + * +{{{ +val types = "Int,Float,Boolean,Double,Short,Long,Byte,Any".split(",") +types.map {tpe => +s""" +final class Mutable$tpe extends MutableValue { + var value: $tpe = 0 + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[$tpe] + } + def copy() = { + val newCopy = new Mutable$tpe + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +}""" +}.foreach(println) + +types.map { tpe => +s""" + override def set$tpe(ordinal: Int, value: $tpe): Unit = { + val currentValue = values(ordinal).asInstanceOf[Mutable$tpe] + currentValue.isNull = false + currentValue.value = value + } + + override def get$tpe(i: Int): $tpe = { + values(i).asInstanceOf[Mutable$tpe].value + }""" +}.foreach(println) +}}} + */ +abstract class MutableValue { + var isNull: Boolean = true + def boxed: Any + def update(v: Any) + def copy(): this.type +} + +final class MutableInt extends MutableValue { + var value: Int = 0 + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Int] + } + def copy() = { + val newCopy = new MutableInt + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableFloat extends MutableValue { + var value: Float = 0 + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Float] + } + def copy() = { + val newCopy = new MutableFloat + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableBoolean extends MutableValue { + var value: Boolean = false + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Boolean] + } + def copy() = { + val newCopy = new MutableBoolean + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableDouble extends MutableValue { + var value: Double = 0 + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Double] + } + def copy() = { + val newCopy = new MutableDouble + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableShort extends MutableValue { + var value: Short = 0 + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Short] + } + def copy() = { + val newCopy = new MutableShort + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableLong extends MutableValue { + var value: Long = 0 + def boxed = if (isNull) null else value + def update(v: Any) 
= value = { + isNull = false + v.asInstanceOf[Long] + } + def copy() = { + val newCopy = new MutableLong + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableByte extends MutableValue { + var value: Byte = 0 + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Byte] + } + def copy() = { + val newCopy = new MutableByte + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +final class MutableAny extends MutableValue { + var value: Any = 0 + def boxed = if (isNull) null else value + def update(v: Any) = value = { + isNull = false + v.asInstanceOf[Any] + } + def copy() = { + val newCopy = new MutableAny + newCopy.isNull = isNull + newCopy.value = value + newCopy.asInstanceOf[this.type] + } +} + +class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { + + def this(dataTypes: Seq[DataType]) = + this( + dataTypes.map { + case IntegerType => new MutableInt + case ByteType => new MutableByte + case FloatType => new MutableFloat + case ShortType => new MutableShort + case LongType => new MutableLong + case _ => new MutableAny + }.toArray) + + override def length: Int = values.length + + override def setNullAt(i: Int): Unit = { + values(i).isNull = true + } + + override def apply(i: Int): Any = values(i).boxed + + override def isNullAt(i: Int): Boolean = values(i).isNull + + override def copy(): Row = { + val newValues = new Array[MutableValue](values.length) + var i = 0 + while (i < values.length) { + newValues(i) = values(i).copy() + i += 1 + } + new SpecificMutableRow(newValues) + } + + override def update(ordinal: Int, value: Any): Unit = values(ordinal).update(value) + + override def iterator: Iterator[Any] = values.map(_.boxed).iterator + + def setString(ordinal: Int, value: String) = update(ordinal, value) + + def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String] + + override def setInt(ordinal: Int, value: Int): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableInt] + currentValue.isNull = false + currentValue.value = value + } + + override def getInt(i: Int): Int = { + values(i).asInstanceOf[MutableInt].value + } + + override def setFloat(ordinal: Int, value: Float): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableFloat] + currentValue.isNull = false + currentValue.value = value + } + + override def getFloat(i: Int): Float = { + values(i).asInstanceOf[MutableFloat].value + } + + override def setBoolean(ordinal: Int, value: Boolean): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableBoolean] + currentValue.isNull = false + currentValue.value = value + } + + override def getBoolean(i: Int): Boolean = { + values(i).asInstanceOf[MutableBoolean].value + } + + override def setDouble(ordinal: Int, value: Double): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableDouble] + currentValue.isNull = false + currentValue.value = value + } + + override def getDouble(i: Int): Double = { + values(i).asInstanceOf[MutableDouble].value + } + + override def setShort(ordinal: Int, value: Short): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableShort] + currentValue.isNull = false + currentValue.value = value + } + + override def getShort(i: Int): Short = { + values(i).asInstanceOf[MutableShort].value + } + + override def setLong(ordinal: Int, value: Long): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableLong] + 
currentValue.isNull = false + currentValue.value = value + } + + override def getLong(i: Int): Long = { + values(i).asInstanceOf[MutableLong].value + } + + override def setByte(ordinal: Int, value: Byte): Unit = { + val currentValue = values(ordinal).asInstanceOf[MutableByte] + currentValue.isNull = false + currentValue.value = value + } + + override def getByte(i: Int): Byte = { + values(i).asInstanceOf[MutableByte].value + } +} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 0a3b59cbc233a..ef4526ec03439 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -23,7 +23,7 @@ import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} import parquet.schema.MessageType import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row, Attribute} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.parquet.CatalystConverter.FieldType /** @@ -278,14 +278,14 @@ private[parquet] class CatalystGroupConverter( */ private[parquet] class CatalystPrimitiveRowConverter( protected[parquet] val schema: Array[FieldType], - protected[parquet] var current: ParquetRelation.RowType) + protected[parquet] var current: MutableRow) extends CatalystConverter { // This constructor is used for the root converter only def this(attributes: Array[Attribute]) = this( attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), - new ParquetRelation.RowType(attributes.length)) + new SpecificMutableRow(attributes.map(_.dataType))) protected [parquet] val converters: Array[Converter] = schema.zipWithIndex.map { @@ -299,7 +299,7 @@ private[parquet] class CatalystPrimitiveRowConverter( override val parent = null // Should be only called in root group converter! 
- override def getCurrentRecord: ParquetRelation.RowType = current + override def getCurrentRecord: Row = current override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) From c9e67dec15947c133fbe3813b71640f420315289 Mon Sep 17 00:00:00 2001 From: Gregory Owen Date: Mon, 18 Aug 2014 20:31:24 -0700 Subject: [PATCH 16/24] Made SpecificRow and types serializable by Kryo --- .../apache/spark/sql/catalyst/expressions/SpecificRow.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala index 2b0f351654774..cfb292a41dda7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala @@ -57,7 +57,7 @@ s""" }.foreach(println) }}} */ -abstract class MutableValue { +abstract class MutableValue extends Serializable { var isNull: Boolean = true def boxed: Any def update(v: Any) @@ -197,6 +197,8 @@ class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { case _ => new MutableAny }.toArray) + def this() = this(Seq.empty) + override def length: Int = values.length override def setNullAt(i: Int): Unit = { From db44a30088821725561a719a2a656b6664e7e447 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 19 Aug 2014 00:24:38 -0700 Subject: [PATCH 17/24] JIT hax. --- .../sql/catalyst/expressions/Projection.scala | 332 ++++++++++++++++++ .../spark/sql/execution/Aggregate.scala | 2 +- .../sql/execution/GeneratedAggregate.scala | 2 +- .../apache/spark/sql/execution/joins.scala | 2 +- .../sql/parquet/ParquetTableOperations.scala | 2 +- 5 files changed, 336 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 968c5f1fe8f4e..a7c8d1a8273a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -153,3 +153,335 @@ class JoinedRow extends Row { s"[${row.mkString(",")}]" } } + +/** + * JIT HACK: Replace with macros + */ +class JoinedRow2 extends Row { + private[this] var row1: Row = _ + private[this] var row2: Row = _ + + def this(left: Row, right: Row) = { + this() + row1 = left + row2 = right + } + + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ + def apply(r1: Row, r2: Row): Row = { + row1 = r1 + row2 = r2 + this + } + + /** Updates this JoinedRow by updating its left base row. Returns itself. */ + def withLeft(newLeft: Row): Row = { + row1 = newLeft + this + } + + /** Updates this JoinedRow by updating its right base row. Returns itself. 
*/ + def withRight(newRight: Row): Row = { + row2 = newRight + this + } + + def iterator = row1.iterator ++ row2.iterator + + def length = row1.length + row2.length + + def apply(i: Int) = + if (i < row1.size) row1(i) else row2(i - row1.size) + + def isNullAt(i: Int) = + if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + + def getInt(i: Int): Int = + if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + + def getLong(i: Int): Long = + if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + + def getDouble(i: Int): Double = + if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + + def getBoolean(i: Int): Boolean = + if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + + def getShort(i: Int): Short = + if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + + def getByte(i: Int): Byte = + if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + + def getFloat(i: Int): Float = + if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + + def getString(i: Int): String = + if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def copy() = { + val totalSize = row1.size + row2.size + val copiedValues = new Array[Any](totalSize) + var i = 0 + while(i < totalSize) { + copiedValues(i) = apply(i) + i += 1 + } + new GenericRow(copiedValues) + } + + override def toString() = { + val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) + s"[${row.mkString(",")}]" + } +} + +/** + * JIT HACK: Replace with macros + */ +class JoinedRow3 extends Row { + private[this] var row1: Row = _ + private[this] var row2: Row = _ + + def this(left: Row, right: Row) = { + this() + row1 = left + row2 = right + } + + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ + def apply(r1: Row, r2: Row): Row = { + row1 = r1 + row2 = r2 + this + } + + /** Updates this JoinedRow by updating its left base row. Returns itself. */ + def withLeft(newLeft: Row): Row = { + row1 = newLeft + this + } + + /** Updates this JoinedRow by updating its right base row. Returns itself. 
*/ + def withRight(newRight: Row): Row = { + row2 = newRight + this + } + + def iterator = row1.iterator ++ row2.iterator + + def length = row1.length + row2.length + + def apply(i: Int) = + if (i < row1.size) row1(i) else row2(i - row1.size) + + def isNullAt(i: Int) = + if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + + def getInt(i: Int): Int = + if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + + def getLong(i: Int): Long = + if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + + def getDouble(i: Int): Double = + if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + + def getBoolean(i: Int): Boolean = + if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + + def getShort(i: Int): Short = + if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + + def getByte(i: Int): Byte = + if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + + def getFloat(i: Int): Float = + if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + + def getString(i: Int): String = + if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def copy() = { + val totalSize = row1.size + row2.size + val copiedValues = new Array[Any](totalSize) + var i = 0 + while(i < totalSize) { + copiedValues(i) = apply(i) + i += 1 + } + new GenericRow(copiedValues) + } + + override def toString() = { + val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) + s"[${row.mkString(",")}]" + } +} + +/** + * JIT HACK: Replace with macros + */ +class JoinedRow4 extends Row { + private[this] var row1: Row = _ + private[this] var row2: Row = _ + + def this(left: Row, right: Row) = { + this() + row1 = left + row2 = right + } + + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ + def apply(r1: Row, r2: Row): Row = { + row1 = r1 + row2 = r2 + this + } + + /** Updates this JoinedRow by updating its left base row. Returns itself. */ + def withLeft(newLeft: Row): Row = { + row1 = newLeft + this + } + + /** Updates this JoinedRow by updating its right base row. Returns itself. 
*/ + def withRight(newRight: Row): Row = { + row2 = newRight + this + } + + def iterator = row1.iterator ++ row2.iterator + + def length = row1.length + row2.length + + def apply(i: Int) = + if (i < row1.size) row1(i) else row2(i - row1.size) + + def isNullAt(i: Int) = + if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + + def getInt(i: Int): Int = + if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + + def getLong(i: Int): Long = + if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + + def getDouble(i: Int): Double = + if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + + def getBoolean(i: Int): Boolean = + if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + + def getShort(i: Int): Short = + if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + + def getByte(i: Int): Byte = + if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + + def getFloat(i: Int): Float = + if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + + def getString(i: Int): String = + if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def copy() = { + val totalSize = row1.size + row2.size + val copiedValues = new Array[Any](totalSize) + var i = 0 + while(i < totalSize) { + copiedValues(i) = apply(i) + i += 1 + } + new GenericRow(copiedValues) + } + + override def toString() = { + val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) + s"[${row.mkString(",")}]" + } +} + +/** + * JIT HACK: Replace with macros + */ +class JoinedRow5 extends Row { + private[this] var row1: Row = _ + private[this] var row2: Row = _ + + def this(left: Row, right: Row) = { + this() + row1 = left + row2 = right + } + + /** Updates this JoinedRow to used point at two new base rows. Returns itself. */ + def apply(r1: Row, r2: Row): Row = { + row1 = r1 + row2 = r2 + this + } + + /** Updates this JoinedRow by updating its left base row. Returns itself. */ + def withLeft(newLeft: Row): Row = { + row1 = newLeft + this + } + + /** Updates this JoinedRow by updating its right base row. Returns itself. 
*/ + def withRight(newRight: Row): Row = { + row2 = newRight + this + } + + def iterator = row1.iterator ++ row2.iterator + + def length = row1.length + row2.length + + def apply(i: Int) = + if (i < row1.size) row1(i) else row2(i - row1.size) + + def isNullAt(i: Int) = + if (i < row1.size) row1.isNullAt(i) else row2.isNullAt(i - row1.size) + + def getInt(i: Int): Int = + if (i < row1.size) row1.getInt(i) else row2.getInt(i - row1.size) + + def getLong(i: Int): Long = + if (i < row1.size) row1.getLong(i) else row2.getLong(i - row1.size) + + def getDouble(i: Int): Double = + if (i < row1.size) row1.getDouble(i) else row2.getDouble(i - row1.size) + + def getBoolean(i: Int): Boolean = + if (i < row1.size) row1.getBoolean(i) else row2.getBoolean(i - row1.size) + + def getShort(i: Int): Short = + if (i < row1.size) row1.getShort(i) else row2.getShort(i - row1.size) + + def getByte(i: Int): Byte = + if (i < row1.size) row1.getByte(i) else row2.getByte(i - row1.size) + + def getFloat(i: Int): Float = + if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + + def getString(i: Int): String = + if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def copy() = { + val totalSize = row1.size + row2.size + val copiedValues = new Array[Any](totalSize) + var i = 0 + while(i < totalSize) { + copiedValues(i) = apply(i) + i += 1 + } + new GenericRow(copiedValues) + } + + override def toString() = { + val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) + s"[${row.mkString(",")}]" + } +} \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 463a1d32d7fd7..be9f155253d77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -175,7 +175,7 @@ case class Aggregate( private[this] val resultProjection = new InterpretedMutableProjection( resultExpressions, computedSchema ++ namedGroups.map(_._2)) - private[this] val joinedRow = new JoinedRow + private[this] val joinedRow = new JoinedRow4 override final def hasNext: Boolean = hashTableIter.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 54c9d033cd2ff..31ad5e8aabb0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -185,7 +185,7 @@ case class GeneratedAggregate( (namedGroups.map(_._2.toAttribute) ++ computationSchema).toSeq) log.info(s"Result Projection: ${resultExpressions.mkString(",")}") - val joinedRow = new JoinedRow + val joinedRow = new JoinedRow3 if (groupingExpressions.isEmpty) { // TODO: Codegening anything other than the updateProjection is probably over kill. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index b08f9aacc1fcb..2890a563bed48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -92,7 +92,7 @@ trait HashJoin { private[this] var currentMatchPosition: Int = -1 // Mutable per row objects. 
- private[this] val joinRow = new JoinedRow + private[this] val joinRow = new JoinedRow2 private[this] val joinKeys = streamSideKeyGenerator() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index f6cfab736d98a..a5a5d139a65cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -139,7 +139,7 @@ case class ParquetTableScan( partOutput.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow)) new Iterator[Row] { - private[this] val joinedRow = new JoinedRow(Row(partitionRowValues:_*), null) + private[this] val joinedRow = new JoinedRow5(Row(partitionRowValues:_*), null) def hasNext = iter.hasNext From 93d0f642aa3c246f10c8b665b846721c08b61bd3 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 19 Aug 2014 01:07:17 -0700 Subject: [PATCH 18/24] metastore concurrency fix. --- .../org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 3b371211e14cd..6571c35499ef4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -265,9 +265,9 @@ private[hive] case class MetastoreRelation // org.apache.hadoop.hive.ql.metadata.Partition will cause a NotSerializableException // which indicates the SerDe we used is not Serializable. - @transient lazy val hiveQlTable = new Table(table) + @transient val hiveQlTable = new Table(table) - def hiveQlPartitions = partitions.map { p => + @transient val hiveQlPartitions = partitions.map { p => new Partition(hiveQlTable, p) } From fdca8967f828a728ad9438673ae2a39b50585db7 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 19 Aug 2014 10:36:54 -0700 Subject: [PATCH 19/24] cleanup --- .../apache/spark/sql/catalyst/expressions/SpecificRow.scala | 2 ++ .../scala/org/apache/spark/sql/parquet/ParquetTypes.scala | 4 +--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala index cfb292a41dda7..5470fc9dfee52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala @@ -193,6 +193,8 @@ class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { case ByteType => new MutableByte case FloatType => new MutableFloat case ShortType => new MutableShort + case DoubleType => new MutableDouble + case BooleanType => new MutableBoolean case LongType => new MutableLong case _ => new MutableAny }.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 27d18f245b125..c79a9ac2dad81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -376,8 +376,6 @@ private[parquet] object ParquetTypesConverter extends 
Logging { } ParquetRelation.enableLogForwarding() - println(fs.listStatus(path).toSeq) - val children = fs.listStatus(path).filterNot { status => val name = status.getPath.getName name(0) == '.' || name == FileOutputCommitter.SUCCEEDED_FILE_NAME @@ -395,7 +393,7 @@ private[parquet] object ParquetTypesConverter extends Logging { .orElse(children.find(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE)) .map(ParquetFileReader.readFooter(conf, _)) .getOrElse( - throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path in ${children.map(_.getPath.getName).toSeq}")) + throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path")) } /** From fae38f487400b01a9e0455917bf272080798d131 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 19 Aug 2014 13:39:39 -0700 Subject: [PATCH 20/24] Fix style --- .../apache/spark/sql/catalyst/expressions/Projection.scala | 2 +- .../apache/spark/sql/catalyst/expressions/SpecificRow.scala | 2 +- .../org/apache/spark/sql/execution/SparkSqlSerializer.scala | 6 +++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index a7c8d1a8273a4..7af42295373b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -484,4 +484,4 @@ class JoinedRow5 extends Row { val row = (if (row1 != null) row1 else Seq[Any]()) ++ (if (row2 != null) row2 else Seq[Any]()) s"[${row.mkString(",")}]" } -} \ No newline at end of file +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala index 5470fc9dfee52..4a7a563cce5e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala @@ -298,4 +298,4 @@ class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { override def getByte(i: Int): Byte = { values(i).asInstanceOf[MutableByte].value } -} \ No newline at end of file +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index b18f1c8d32f1e..1ed6c340fc7d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -139,7 +139,11 @@ private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { val set = new OpenHashSet[Any](numItems + 1) var i = 0 while (i < numItems) { - val row = new GenericRow(rowSerializer.read(kryo, input, classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]]) + val row = + new GenericRow(rowSerializer.read( + kryo, + input, + classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]]) set.add(row) i += 1 } From c122cca102f789e020d2708fecd484d6b40df616 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 20 Aug 2014 15:03:39 -0700 Subject: [PATCH 21/24] Address comments, add tests --- .../sql/catalyst/expressions/Projection.scala | 6 ++ .../catalyst/expressions/SpecificRow.scala | 74 ++++++++++--------- 
.../sql/catalyst/expressions/arithmetic.scala | 15 +++- .../expressions/codegen/CodeGenerator.scala | 8 +- .../ExpressionEvaluationSuite.scala | 10 +++ .../spark/sql/hive/StatisticsSuite.scala | 2 +- 6 files changed, 78 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 7af42295373b5..ef1d12531f109 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -156,6 +156,12 @@ class JoinedRow extends Row { /** * JIT HACK: Replace with macros + * The `JoinedRow` class is used in many performance critical situation. Unfortunately, since there + * are multiple different types of `Rows` that could be stored as `row1` and `row2` most of the + * calls in the critical path are polymorphic. By creating special versions of this class that are + * used in only a single location of the code, we increase the chance that only a single type of + * Row will be referenced, increasing the opportunity for the JIT to play tricks. This sounds + * crazy but in benchmarks it had noticeable effects. */ class JoinedRow2 extends Row { private[this] var row1: Row = _ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala index 4a7a563cce5e8..75ea0e8459df8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala @@ -20,42 +20,43 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.types._ /** + * A parent class for mutable container objects that are reused when the values are changed, + * resulting in less garbage. These values are held by a [[SpecificMutableRow]]. 
* + * The following code was roughly used to generate these objects: + * {{{ + * val types = "Int,Float,Boolean,Double,Short,Long,Byte,Any".split(",") + * types.map {tpe => + * s""" + * final class Mutable$tpe extends MutableValue { + * var value: $tpe = 0 + * def boxed = if (isNull) null else value + * def update(v: Any) = value = { + * isNull = false + * v.asInstanceOf[$tpe] + * } + * def copy() = { + * val newCopy = new Mutable$tpe + * newCopy.isNull = isNull + * newCopy.value = value + * newCopy.asInstanceOf[this.type] + * } + * }""" + * }.foreach(println) * + * types.map { tpe => + * s""" + * override def set$tpe(ordinal: Int, value: $tpe): Unit = { + * val currentValue = values(ordinal).asInstanceOf[Mutable$tpe] + * currentValue.isNull = false + * currentValue.value = value + * } * -{{{ -val types = "Int,Float,Boolean,Double,Short,Long,Byte,Any".split(",") -types.map {tpe => -s""" -final class Mutable$tpe extends MutableValue { - var value: $tpe = 0 - def boxed = if (isNull) null else value - def update(v: Any) = value = { - isNull = false - v.asInstanceOf[$tpe] - } - def copy() = { - val newCopy = new Mutable$tpe - newCopy.isNull = isNull - newCopy.value = value - newCopy.asInstanceOf[this.type] - } -}""" -}.foreach(println) - -types.map { tpe => -s""" - override def set$tpe(ordinal: Int, value: $tpe): Unit = { - val currentValue = values(ordinal).asInstanceOf[Mutable$tpe] - currentValue.isNull = false - currentValue.value = value - } - - override def get$tpe(i: Int): $tpe = { - values(i).asInstanceOf[Mutable$tpe].value - }""" -}.foreach(println) -}}} + * override def get$tpe(i: Int): $tpe = { + * values(i).asInstanceOf[Mutable$tpe].value + * }""" + * }.foreach(println) + * }}} */ abstract class MutableValue extends Serializable { var isNull: Boolean = true @@ -184,7 +185,12 @@ final class MutableAny extends MutableValue { } } -class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { +/** + * A row type that holds an array specialized container objects, of type [[MutableValue]], chosen + * based on the dataTypes of each column. The intent is to decrease garbage when modifying the + * values of primitive columns. 
+ */ +final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow { def this(dataTypes: Seq[DataType]) = this( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 923a9b1445d6b..8d90614e4501a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -93,12 +93,25 @@ case class MaxOf(left: Expression, right: Expression) extends Expression { override def children = left :: right :: Nil - override def references = (left.flatMap(_.references) ++ right.flatMap(_.references)).toSet + override def references = left.references ++ right.references override def dataType = left.dataType override def eval(input: Row): Any = { val leftEval = left.eval(input) + val rightEval = right.eval(input) + if (leftEval == null) { + rightEval + } else if (rightEval == null) { + leftEval + } else { + val numeric = left.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] + if (numeric.compare(leftEval, rightEval) < 0) { + rightEval + } else { + leftEval + } + } } override def toString = s"MaxOf($left, $right)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 7362d3d6b0e7c..f06b8c78a1be9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ +// These classes are here to avoid issues with serialization and integration with quasiquotes. class IntegerHashSet extends org.apache.spark.util.collection.OpenHashSet[Int] class LongHashSet extends org.apache.spark.util.collection.OpenHashSet[Long] @@ -53,6 +54,11 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin private val curId = new java.util.concurrent.atomic.AtomicInteger() private val javaSeparator = "$" + /** + * Can be flipped on manually in the console to add (expensive) expression evaluation trace code. + */ + var debugLogging = false + /** * Generates a class for a given input expression. Called when there is not cached code * already available. @@ -496,7 +502,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin // Only inject debugging code if debugging is turned on. 
val debugCode = - if (false) { + if (debugLogging) { val localLogger = log val localLoggerTree = reify { localLogger } q""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 999c9fff38d60..f1df817c41362 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -136,6 +136,16 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))) && In(Literal(2), Seq(Literal(1), Literal(2))), true) } + test("MaxOf") { + checkEvaluation(MaxOf(1, 2), 2) + checkEvaluation(MaxOf(2, 1), 2) + checkEvaluation(MaxOf(1L, 2L), 2L) + checkEvaluation(MaxOf(2L, 1L), 2L) + + checkEvaluation(MaxOf(Literal(null, IntegerType), 2), 2) + checkEvaluation(MaxOf(2, Literal(null, IntegerType)), 2) + } + test("LIKE literal Regular Expression") { checkEvaluation(Literal(null, StringType).like("a"), null) checkEvaluation(Literal("a", StringType).like(Literal(null, StringType)), null) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 206f514e9de00..2653871c2a7eb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -158,7 +158,7 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes } assert(sizes.size === 2 && sizes(0) <= autoBroadcastJoinThreshold, - s"query should contain two relations, each of which has size smaller than autoConvertSize instead ${rdd.queryExecution}") + s"query should contain two relations, each of which has size smaller than autoConvertSize") // Using `sparkPlan` because for relevant patterns in HashJoin to be // matched, other strategies need to be applied. 
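Note on the MaxOf change in PATCH 21: eval now treats null as the identity for max, returning the non-null operand when exactly one side is null and falling back to the type's Numeric comparison otherwise. A minimal standalone Scala sketch of that contract (MaxOfSketch, its boxed java.lang.Integer signature, and the maxOf helper are illustrative stand-ins, not code from the patch):

object MaxOfSketch {
  // Mirrors the null handling MaxOf.eval gains above: a null input never
  // wins, and the result is null only when both inputs are null.
  def maxOf(left: java.lang.Integer, right: java.lang.Integer): java.lang.Integer =
    if (left == null) right
    else if (right == null) left
    else if (left.intValue < right.intValue) right
    else left

  def main(args: Array[String]): Unit = {
    assert(maxOf(1, 2) == 2)          // ordinary comparison
    assert(maxOf(null, 2) == 2)       // a null operand loses to any value
    assert(maxOf(2, null) == 2)
    assert(maxOf(null, null) == null) // null only when both operands are null
  }
}

These are the same cases exercised by the new MaxOf tests added to ExpressionEvaluationSuite in this patch.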
From 32d216f9dde2fa8deef857927ce007a4e429400c Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Wed, 20 Aug 2014 16:03:24 -0700
Subject: [PATCH 22/24] Address Reynold's comments

---
 .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 ++
 .../apache/spark/sql/execution/SparkSqlSerializer.scala | 7 +++----
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index f06b8c78a1be9..5a3f013c34579 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -539,6 +539,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
   protected def hashSetForType(dt: DataType) = dt match {
     case IntegerType => typeOf[IntegerHashSet]
     case LongType => typeOf[LongHashSet]
+    case unsupportedType =>
+      sys.error(s"Code generation not supported for hashset of type $unsupportedType")
   }
 
   protected def primitiveForType(dt: DataType) = dt match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 1ed6c340fc7d7..077e6ebc5f11e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -19,9 +19,6 @@ package org.apache.spark.sql.execution
 
 import java.nio.ByteBuffer
 
-import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.util.collection.OpenHashSet
-
 import scala.reflect.ClassTag
 
 import com.clearspring.analytics.stream.cardinality.HyperLogLog
@@ -31,6 +28,8 @@ import com.twitter.chill.{AllScalaRegistrar, ResourcePool}
 
 import org.apache.spark.{SparkEnv, SparkConf}
 import org.apache.spark.serializer.{SerializerInstance, KryoSerializer}
+import org.apache.spark.sql.catalyst.expressions.GenericRow
+import org.apache.spark.util.collection.OpenHashSet
 import org.apache.spark.util.MutablePair
 import org.apache.spark.util.Utils
 
@@ -47,7 +46,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
       new HyperLogLogSerializer)
     kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
 
-    // Specific hashsets must come first
+    // Specific hashsets must come first. TODO: Move to core.
kryo.register(classOf[IntegerHashSet], new IntegerHashSetSerializer) kryo.register(classOf[LongHashSet], new LongHashSetSerializer) kryo.register(classOf[org.apache.spark.util.collection.OpenHashSet[_]], From 8074a80b54cd788ae1552afa1af8a6a76cf226d7 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 20 Aug 2014 16:45:48 -0700 Subject: [PATCH 23/24] fix tests --- .../spark/sql/hive/execution/HiveSerDeSuite.scala | 11 ++++++++++- .../spark/sql/hive/execution/PruningSuite.scala | 8 +++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala index df9bae96494d5..8bc72384a64ee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSerDeSuite.scala @@ -17,10 +17,19 @@ package org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.hive.test.TestHive + /** * A set of tests that validates support for Hive SerDe. */ -class HiveSerDeSuite extends HiveComparisonTest { +class HiveSerDeSuite extends HiveComparisonTest with BeforeAndAfterAll { + + override def beforeAll() = { + TestHive.cacheTables = false + } + createQueryTest( "Read and write with LazySimpleSerDe (tab separated)", "SELECT * from serdeins") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 1a6dbc0ce0c0d..3804e09b943b4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfter + import org.apache.spark.sql.hive.test.TestHive /* Implicit conversions */ @@ -25,10 +27,14 @@ import scala.collection.JavaConversions._ /** * A set of test cases that validate partition and column pruning. */ -class PruningSuite extends HiveComparisonTest { +class PruningSuite extends HiveComparisonTest with BeforeAndAfter { // MINOR HACK: You must run a query before calling reset the first time. TestHive.sql("SHOW TABLES") + override def beforeAll() = { + TestHive.cacheTables = false + } + // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset // the environment to ensure all referenced tables in this suites are not cached in-memory. // Refer to https://issues.apache.org/jira/browse/SPARK-2283 for details. 
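Note on PATCH 23: these beforeAll overrides disable TestHive.cacheTables but, unlike the earlier StatisticsSuite change, never restore it afterwards, so the flag stays off for any suites that run later in the same JVM. A generic ScalaTest sketch of the save-and-restore variant of this pattern (FlagHolder and ExampleSuite are illustrative stand-ins for TestHive and the Hive suites, not code from the patch):

import org.scalatest.{BeforeAndAfterAll, FunSuite}

// Stand-in for a process-wide flag such as TestHive.cacheTables.
object FlagHolder { var cacheTables = true }

class ExampleSuite extends FunSuite with BeforeAndAfterAll {
  private var saved = true

  override def beforeAll(): Unit = {
    saved = FlagHolder.cacheTables  // remember the global state
    FlagHolder.cacheTables = false  // disable caching while this suite runs
  }

  override def afterAll(): Unit = {
    FlagHolder.cacheTables = saved  // restore it so later suites are unaffected
  }

  test("caching is disabled here") {
    assert(!FlagHolder.cacheTables)
  }
}

The BeforeAndAfter import added to PruningSuite appears stray: the beforeAll hook it overrides comes from BeforeAndAfterAll, presumably already mixed into HiveComparisonTest.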
From 5c7848d52070b639e23088972eb2a8316cddc54f Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 23 Aug 2014 12:40:05 -0700 Subject: [PATCH 24/24] turn off caching in the constructor --- .../org/apache/spark/sql/hive/StatisticsSuite.scala | 11 ++--------- .../spark/sql/hive/execution/PruningSuite.scala | 5 +---- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 2653871c2a7eb..8d6ca9939a730 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -29,15 +29,8 @@ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ class StatisticsSuite extends QueryTest with BeforeAndAfterAll { - - override def beforeAll() = { - // HACK: Cached tables do not currently preserve statistics... - TestHive.cacheTables = false - } - - override def afterAll() = { - TestHive.cacheTables = true - } + TestHive.reset() + TestHive.cacheTables = false test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 3804e09b943b4..8275e2d3bcce3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -30,10 +30,7 @@ import scala.collection.JavaConversions._ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { // MINOR HACK: You must run a query before calling reset the first time. TestHive.sql("SHOW TABLES") - - override def beforeAll() = { - TestHive.cacheTables = false - } + TestHive.cacheTables = false // Column/partition pruning is not implemented for `InMemoryColumnarTableScan` yet, need to reset // the environment to ensure all referenced tables in this suites are not cached in-memory.
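Note on PATCH 24: ScalaTest runs a suite's constructor at instantiation time, which is when createQueryTest registers tests and when PruningSuite issues its SHOW TABLES query, and only later invokes beforeAll(). A flag consulted during registration therefore has to be set in the constructor, which is why the beforeAll overrides are replaced by class-body assignments. A small sketch of that ordering (OrderingDemo is an illustrative name, not code from the patch):

import org.scalatest.{BeforeAndAfterAll, FunSuite}

class OrderingDemo extends FunSuite with BeforeAndAfterAll {
  // Constructor body: runs at suite instantiation, while tests are being
  // registered and before any ScalaTest lifecycle hook fires.
  println("1. constructor")

  override def beforeAll(): Unit =
    println("2. beforeAll, just before the first test executes")

  test("ordering") {
    println("3. test body")
  }
}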