From 8956204211d90e4c47d4f85d7ddc49d0be553d42 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 30 Nov 2015 14:44:20 -0800 Subject: [PATCH 01/14] consider nullable in codegen --- .../catalyst/expressions/BoundAttribute.scala | 15 ++- .../sql/catalyst/expressions/Expression.scala | 91 ++++++++++++------- .../expressions/aggregate/Count.scala | 17 +++- .../catalyst/expressions/aggregate/Sum.scala | 24 +++-- .../expressions/codegen/CodeGenerator.scala | 2 +- .../codegen/GenerateMutableProjection.scala | 65 ++++++++----- 6 files changed, 141 insertions(+), 73 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index ff1f28ddbbf3..cbfe57032486 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -69,10 +69,17 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) - s""" - boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); - """ + if (nullable) { + s""" + boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); + $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + """ + } else { + ev.isNull = "false" + s""" + $javaType ${ev.value} = $value; + """ + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 169435a10ea2..c25cfa55997c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -336,13 +336,21 @@ abstract class UnaryExpression extends Expression { f: String => String): String = { val eval = child.gen(ctx) val resultCode = f(eval.value) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { + if (nullable) { + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${eval.isNull}) { + $resultCode + } + """ + } else { + ev.isNull = "false" + eval.code + s""" + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $resultCode - } - """ + """ + } } } @@ -419,19 +427,30 @@ abstract class BinaryExpression extends Expression { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val resultCode = f(eval1.value, eval2.value) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; + if (nullable) { + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + $resultCode + } else { + ${ev.isNull} = true; + } } - } - """ + """ + + } else { + ev.isNull = "false" + s""" + ${eval1.code} + ${eval2.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $resultCode + """ + } } } @@ -543,20 +562,30 @@ abstract class TernaryExpression extends Expression { f: (String, String, String) => String): String = { val evals = children.map(_.gen(ctx)) val resultCode = f(evals(0).value, evals(1).value, evals(2).value) - s""" - ${evals(0).code} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${evals(0).isNull}) { - ${evals(1).code} - if (!${evals(1).isNull}) { - ${evals(2).code} - if (!${evals(2).isNull}) { - ${ev.isNull} = false; // resultCode could change nullability - $resultCode + if (nullable) { + s""" + ${evals(0).code} + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${evals(0).isNull}) { + ${evals(1).code} + if (!${evals(1).isNull}) { + ${evals(2).code} + if (!${evals(2).isNull}) { + ${ev.isNull} = false; // resultCode could change nullability + $resultCode + } } } - } - """ + """ + } else { + s""" + ${evals(0).code} + ${evals(1).code} + ${evals(2).code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $resultCode + """ + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 441f52ab5ca5..6433f347211e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -39,15 +39,24 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { /* count = */ Literal(0L) ) - override lazy val updateExpressions = Seq( - /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L) - ) + override lazy val updateExpressions = { + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + Seq( + /* count = */ count + 1L + ) + } else { + Seq( + /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L) + ) + } + } override lazy val mergeExpressions = Seq( /* count = */ count.left + count.right ) - override lazy val evaluateExpression = Cast(count, LongType) + override lazy val evaluateExpression = count override def defaultResult: Option[Literal] = Option(Literal(0L)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index cfb042e0aa78..08a67ea3df51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -40,8 +40,6 @@ case class Sum(child: Expression) extends DeclarativeAggregate { private lazy val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) - // TODO: Remove this line once we remove the NullType from inputTypes. - case NullType => IntegerType case _ => child.dataType } @@ -57,18 +55,26 @@ case class Sum(child: Expression) extends DeclarativeAggregate { /* sum = */ Literal.create(null, sumDataType) ) - override lazy val updateExpressions: Seq[Expression] = Seq( - /* sum = */ - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) - ) + override lazy val updateExpressions: Seq[Expression] = { + if (child.nullable) { + Seq( + /* sum = */ + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) + ) + } else { + Seq( + /* sum = */ + Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)) + ) + } + } override lazy val mergeExpressions: Seq[Expression] = { - val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) Seq( /* sum = */ - Coalesce(Seq(add, sum.left)) + Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left)) ) } - override lazy val evaluateExpression: Expression = Cast(sum, resultType) + override lazy val evaluateExpression: Expression = sum } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 2f3d6aeb86c5..2dac1de67f0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -556,7 +556,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin def formatted = CodeFormatter.format(code) - logDebug({ + logError({ // Only add extra debugging info to byte code when we are going to print the source code. evaluator.setDebuggingInformation(true, true, false) formatted diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 40189f087776..a6ec242589fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -44,38 +44,55 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) - val isNull = s"isNull_$i" - val value = s"value_$i" - ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") - ctx.addMutableState(ctx.javaType(e.dataType), value, - s"this.$value = ${ctx.defaultValue(e.dataType)};") - s""" - ${evaluationCode.code} - this.$isNull = ${evaluationCode.isNull}; - this.$value = ${evaluationCode.value}; - """ + if (e.nullable) { + val isNull = s"isNull_$i" + val value = s"value_$i" + ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") + ctx.addMutableState(ctx.javaType(e.dataType), value, + s"this.$value = ${ctx.defaultValue(e.dataType)};") + s""" + ${evaluationCode.code} + this.$isNull = ${evaluationCode.isNull}; + this.$value = ${evaluationCode.value}; + """ + } else { + val value = s"value_$i" + ctx.addMutableState(ctx.javaType(e.dataType), value, + s"this.$value = ${ctx.defaultValue(e.dataType)};") + s""" + ${evaluationCode.code} + this.$value = ${evaluationCode.value}; + """ + } } val updates = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => - if (e.dataType.isInstanceOf[DecimalType]) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - s""" - if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, null)}; - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ + if (e.nullable) { + if (e.dataType.isInstanceOf[DecimalType]) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + s""" + if (this.isNull_$i) { + ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + } else { + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; + } + """ + } else { + s""" + if (this.isNull_$i) { + mutableRow.setNullAt($i); + } else { + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; + } + """ + } } else { s""" - if (this.isNull_$i) { - mutableRow.setNullAt($i); - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; """ } + } val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) From ddca7e23cb98db72e6a89e02feb449622e68b77e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 15 Dec 2015 14:20:01 -0800 Subject: [PATCH 02/14] remove GenerateProjection --- .../codegen/GenerateProjection.scala | 238 ------------------ .../expressions/CodeGenerationSuite.scala | 5 +- .../expressions/ExpressionEvalHelper.scala | 38 +-- .../expressions/MathFunctionsSuite.scala | 4 +- .../CodegenExpressionCachingSuite.scala | 18 -- .../sql/execution/local/ExpandNode.scala | 4 +- .../spark/sql/execution/local/LocalNode.scala | 18 -- 7 files changed, 7 insertions(+), 318 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala deleted file mode 100644 index f229f2000d8e..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ /dev/null @@ -1,238 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.expressions.codegen - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ - -/** - * Java can not access Projection (in package object) - */ -abstract class BaseProjection extends Projection {} - -abstract class CodeGenMutableRow extends MutableRow with BaseGenericInternalRow - -/** - * Generates bytecode that produces a new [[InternalRow]] object based on a fixed set of input - * [[Expression Expressions]] and a given input [[InternalRow]]. The returned [[InternalRow]] - * object is custom generated based on the output types of the [[Expression]] to avoid boxing of - * primitive values. - */ -object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { - - protected def canonicalize(in: Seq[Expression]): Seq[Expression] = - in.map(ExpressionCanonicalizer.execute) - - protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = - in.map(BindReferences.bindReference(_, inputSchema)) - - // Make Mutablility optional... - protected def create(expressions: Seq[Expression]): Projection = { - val ctx = newCodeGenContext() - val columns = expressions.zipWithIndex.map { - case (e, i) => - s"private ${ctx.javaType(e.dataType)} c$i = ${ctx.defaultValue(e.dataType)};\n" - }.mkString("\n") - - val initColumns = expressions.zipWithIndex.map { - case (e, i) => - val eval = e.gen(ctx) - s""" - { - // column$i - ${eval.code} - nullBits[$i] = ${eval.isNull}; - if (!${eval.isNull}) { - c$i = ${eval.value}; - } - } - """ - }.mkString("\n") - - val getCases = (0 until expressions.size).map { i => - s"case $i: return c$i;" - }.mkString("\n") - - val updateCases = expressions.zipWithIndex.map { case (e, i) => - s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" - }.mkString("\n") - - val specificAccessorFunctions = ctx.primitiveTypes.map { jt => - val cases = expressions.zipWithIndex.flatMap { - case (e, i) if ctx.javaType(e.dataType) == jt => - Some(s"case $i: return c$i;") - case _ => None - }.mkString("\n") - if (cases.length > 0) { - val getter = "get" + ctx.primitiveTypeName(jt) - s""" - public $jt $getter(int i) { - if (isNullAt(i)) { - return ${ctx.defaultValue(jt)}; - } - switch (i) { - $cases - } - throw new IllegalArgumentException("Invalid index: " + i - + " in $getter"); - }""" - } else { - "" - } - }.filter(_.length > 0).mkString("\n") - - val specificMutatorFunctions = ctx.primitiveTypes.map { jt => - val cases = expressions.zipWithIndex.flatMap { - case (e, i) if ctx.javaType(e.dataType) == jt => - Some(s"case $i: { c$i = value; return; }") - case _ => None - }.mkString("\n") - if (cases.length > 0) { - val setter = "set" + ctx.primitiveTypeName(jt) - s""" - public void $setter(int i, $jt value) { - nullBits[i] = false; - switch (i) { - $cases - } - throw new IllegalArgumentException("Invalid index: " + i + - " in $setter}"); - }""" - } else { - "" - } - }.filter(_.length > 0).mkString("\n") - - val hashValues = expressions.zipWithIndex.map { case (e, i) => - val col = s"c$i" - val nonNull = e.dataType match { - case BooleanType => s"$col ? 0 : 1" - case ByteType | ShortType | IntegerType | DateType => s"$col" - case LongType | TimestampType => s"$col ^ ($col >>> 32)" - case FloatType => s"Float.floatToIntBits($col)" - case DoubleType => - s"(int)(Double.doubleToLongBits($col) ^ (Double.doubleToLongBits($col) >>> 32))" - case BinaryType => s"java.util.Arrays.hashCode($col)" - case _ => s"$col.hashCode()" - } - s"isNullAt($i) ? 0 : ($nonNull)" - } - - val hashUpdates: String = hashValues.map( v => - s""" - result *= 37; result += $v;""" - ).mkString("\n") - - val columnChecks = expressions.zipWithIndex.map { case (e, i) => - s""" - if (nullBits[$i] != row.nullBits[$i] || - (!nullBits[$i] && !(${ctx.genEqual(e.dataType, s"c$i", s"row.c$i")}))) { - return false; - } - """ - }.mkString("\n") - - val copyColumns = expressions.zipWithIndex.map { case (e, i) => - s"""if (!nullBits[$i]) arr[$i] = c$i;""" - }.mkString("\n") - - val code = s""" - public SpecificProjection generate($exprType[] expr) { - return new SpecificProjection(expr); - } - - class SpecificProjection extends ${classOf[BaseProjection].getName} { - private $exprType[] expressions; - ${declareMutableStates(ctx)} - ${declareAddedFunctions(ctx)} - - public SpecificProjection($exprType[] expr) { - expressions = expr; - ${initMutableStates(ctx)} - } - - public java.lang.Object apply(java.lang.Object r) { - // GenerateProjection does not work with UnsafeRows. - assert(!(r instanceof ${classOf[UnsafeRow].getName})); - return new SpecificRow((InternalRow) r); - } - - final class SpecificRow extends ${classOf[CodeGenMutableRow].getName} { - - $columns - - public SpecificRow(InternalRow ${ctx.INPUT_ROW}) { - $initColumns - } - - public int numFields() { return ${expressions.length};} - protected boolean[] nullBits = new boolean[${expressions.length}]; - public void setNullAt(int i) { nullBits[i] = true; } - public boolean isNullAt(int i) { return nullBits[i]; } - - public java.lang.Object genericGet(int i) { - if (isNullAt(i)) return null; - switch (i) { - $getCases - } - return null; - } - public void update(int i, java.lang.Object value) { - if (value == null) { - setNullAt(i); - return; - } - nullBits[i] = false; - switch (i) { - $updateCases - } - } - $specificAccessorFunctions - $specificMutatorFunctions - - public int hashCode() { - int result = 37; - $hashUpdates - return result; - } - - public boolean equals(java.lang.Object other) { - if (other instanceof SpecificRow) { - SpecificRow row = (SpecificRow) other; - $columnChecks - return true; - } - return super.equals(other); - } - - public InternalRow copy() { - java.lang.Object[] arr = new java.lang.Object[${expressions.length}]; - ${copyColumns} - return new ${classOf[GenericInternalRow].getName}(arr); - } - } - } - """ - - logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n" + - CodeFormatter.format(code)) - - compile(code).generate(ctx.references.toArray).asInstanceOf[Projection] - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index cd2ef7dcd0cd..0c42e2fc7c5e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{Row, RandomDataGenerator} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ @@ -38,7 +38,6 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val futures = (1 to 20).map { _ => future { GeneratePredicate.generate(EqualTo(Literal(1), Literal(1))) - GenerateProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) GenerateMutableProjection.generate(EqualTo(Literal(1), Literal(1)) :: Nil) GenerateOrdering.generate(Add(Literal(1), Literal(1)).asc :: Nil) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 465f7d08aa14..eeb5567679bf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -120,42 +120,6 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } - protected def checkEvaluationWithGeneratedProjection( - expression: Expression, - expected: Any, - inputRow: InternalRow = EmptyRow): Unit = { - - val plan = generateProject( - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), - expression) - - val actual = plan(inputRow) - val expectedRow = InternalRow(expected) - - // We reimplement hashCode in generated `SpecificRow`, make sure it's consistent with our - // interpreted version. - if (actual.hashCode() != expectedRow.hashCode()) { - val ctx = new CodeGenContext - val evaluated = expression.gen(ctx) - fail( - s""" - |Mismatched hashCodes for values: $actual, $expectedRow - |Hash Codes: ${actual.hashCode()} != ${expectedRow.hashCode()} - |Expressions: $expression - |Code: $evaluated - """.stripMargin) - } - - if (actual != expectedRow) { - val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" - fail("Incorrect Evaluation in codegen mode: " + - s"$expression, actual: $actual, expected: $expectedRow$input") - } - if (actual.copy() != expectedRow) { - fail(s"Copy of generated Row is wrong: actual: ${actual.copy()}, expected: $expectedRow") - } - } - protected def checkEvalutionWithUnsafeProjection( expression: Expression, expected: Any, @@ -202,7 +166,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { checkEvaluationWithOptimization(expression, expected) var plan = generateProject( - GenerateProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), expression) var actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 88ed9fdd6465..4ad65db0977c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import com.google.common.math.LongMath import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateProjection, GenerateMutableProjection} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index 2d3f98dbbd3d..c9616cdb26c2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -34,12 +34,6 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance.apply(null).getBoolean(0) === false) } - test("GenerateProjection should initialize expressions") { - val expr = And(NondeterministicExpression(), NondeterministicExpression()) - val instance = GenerateProjection.generate(Seq(expr)) - assert(instance.apply(null).getBoolean(0) === false) - } - test("GenerateMutableProjection should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) val instance = GenerateMutableProjection.generate(Seq(expr))() @@ -64,18 +58,6 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance2.apply(null).getBoolean(0) === true) } - test("GenerateProjection should not share expression instances") { - val expr1 = MutableExpression() - val instance1 = GenerateProjection.generate(Seq(expr1)) - assert(instance1.apply(null).getBoolean(0) === false) - - val expr2 = MutableExpression() - expr2.mutableState = true - val instance2 = GenerateProjection.generate(Seq(expr2)) - assert(instance1.apply(null).getBoolean(0) === false) - assert(instance2.apply(null).getBoolean(0) === true) - } - test("GenerateMutableProjection should not share expression instances") { val expr1 = MutableExpression() val instance1 = GenerateMutableProjection.generate(Seq(expr1))() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala index 2aff156d18b5..85111bd6d1c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ExpandNode.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Projection} +import org.apache.spark.sql.catalyst.expressions._ case class ExpandNode( conf: SQLConf, @@ -36,7 +36,7 @@ case class ExpandNode( override def open(): Unit = { child.open() - groups = projections.map(ee => newProjection(ee, child.output)).toArray + groups = projections.map(ee => newMutableProjection(ee, child.output)()).toArray idx = groups.length } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index d3381eac91d4..6a882c9234df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -103,24 +103,6 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin result } - protected def newProjection( - expressions: Seq[Expression], - inputSchema: Seq[Attribute]): Projection = { - log.debug( - s"Creating Projection: $expressions, inputSchema: $inputSchema") - try { - GenerateProjection.generate(expressions, inputSchema) - } catch { - case NonFatal(e) => - if (isTesting) { - throw e - } else { - log.error("Failed to generate projection, fallback to interpret", e) - new InterpretedProjection(expressions, inputSchema) - } - } - } - protected def newMutableProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { From 830425111b12414284fc139808c66f2ffd8ded87 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 15 Dec 2015 14:57:34 -0800 Subject: [PATCH 03/14] fix build --- .../catalyst/expressions/codegen/GenerateSafeProjection.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index b7926bda3de1..13634b69457a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -22,6 +22,10 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} import org.apache.spark.sql.types._ +/** + * Java can not access Projection (in package object) + */ +abstract class BaseProjection extends Projection {} /** * Generates byte code that produces a [[MutableRow]] object (not an [[UnsafeRow]]) that can update From e3cccda936a27bb6bfae1c28d2c0184ae75b7f6f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 15 Dec 2015 15:52:44 -0800 Subject: [PATCH 04/14] fix bug --- .../scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 1 + .../spark/sql/catalyst/expressions/stringExpressions.scala | 1 + 2 files changed, 2 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index cb60d5958d53..df156a5a1156 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -89,6 +89,7 @@ object Cast { private def forceNullable(from: DataType, to: DataType) = (from, to) match { case (StringType, _: NumericType) => true case (StringType, TimestampType) => true + case (StringType, BooleanType) => true case (DoubleType, TimestampType) => true case (FloatType, TimestampType) => true case (StringType, DateType) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 8770c4b76c2e..50c8b9d59847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -924,6 +924,7 @@ case class FormatNumber(x: Expression, d: Expression) override def left: Expression = x override def right: Expression = d override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) // Associated with the pattern, for the last d value, and we will update the From 2cca3a19245b84d8b3bb442c27ac51b9072214cc Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 15 Dec 2015 15:53:31 -0800 Subject: [PATCH 05/14] fix test --- .../spark/sql/catalyst/expressions/ExpressionEvalHelper.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index eeb5567679bf..f869a96edb1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -43,7 +43,6 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expression, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expression, catalystValue, inputRow) - checkEvaluationWithGeneratedProjection(expression, catalystValue, inputRow) if (GenerateUnsafeProjection.canSupport(expression.dataType)) { checkEvalutionWithUnsafeProjection(expression, catalystValue, inputRow) } From 8aa8bb59c0ad0293f04f936902d7f04b4132492f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 15 Dec 2015 17:47:15 -0800 Subject: [PATCH 06/14] fix nullable --- .../spark/sql/catalyst/expressions/Cast.scala | 18 +++---------- .../sql/catalyst/expressions/Expression.scala | 6 ++--- .../expressions/codegen/CodegenFallback.scala | 27 ++++++++++++------- .../expressions/datetimeExpressions.scala | 4 +++ .../expressions/decimalExpressions.scala | 1 + .../expressions/jsonExpressions.scala | 15 ++++++----- .../expressions/mathExpressions.scala | 5 ++++ .../spark/sql/catalyst/expressions/misc.scala | 1 + .../sql/catalyst/expressions/CastSuite.scala | 10 +++---- .../apache/spark/sql/execution/commands.scala | 2 +- .../spark/sql/execution/datasources/ddl.scala | 2 +- 11 files changed, 52 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index df156a5a1156..054f02d7cebe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -87,20 +87,10 @@ object Cast { private def resolvableNullability(from: Boolean, to: Boolean) = !from || to private def forceNullable(from: DataType, to: DataType) = (from, to) match { - case (StringType, _: NumericType) => true - case (StringType, TimestampType) => true - case (StringType, BooleanType) => true - case (DoubleType, TimestampType) => true - case (FloatType, TimestampType) => true - case (StringType, DateType) => true - case (_: NumericType, DateType) => true - case (BooleanType, DateType) => true - case (DateType, _: NumericType) => true - case (DateType, BooleanType) => true - case (DoubleType, _: DecimalType) => true - case (FloatType, _: DecimalType) => true - case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null - case _ => false + // TODO: add more case + case (_, StringType | BinaryType) => false + case (_: NumericType, FloatType | DoubleType) => false + case _ => true } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 3efe02206cb6..6a9c12127d36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -340,20 +340,19 @@ abstract class UnaryExpression extends Expression { ev: GeneratedExpressionCode, f: String => String): String = { val eval = child.gen(ctx) - val resultCode = f(eval.value) if (nullable) { eval.code + s""" boolean ${ev.isNull} = ${eval.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${eval.isNull}) { - $resultCode + ${f(eval.value)} } """ } else { ev.isNull = "false" eval.code + s""" ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode + ${f(eval.value)} """ } } @@ -584,6 +583,7 @@ abstract class TernaryExpression extends Expression { } """ } else { + ev.isNull = "false" s""" ${evals(0).code} ${evals(1).code} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 26fb143d1e45..80c5e41baa92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -32,14 +32,23 @@ trait CodegenFallback extends Expression { ctx.references += this val objectTerm = ctx.freshName("obj") - s""" - /* expression: ${this.toCommentSafeString} */ - java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); - boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; - if (!${ev.isNull}) { - ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; - } - """ + if (nullable) { + s""" + /* expression: ${this.toCommentSafeString} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + boolean ${ev.isNull} = $objectTerm == null; + ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; + if (!${ev.isNull}) { + ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + } + """ + } else { + ev.isNull = "false" + s""" + /* expression: ${this.toCommentSafeString} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + """ + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 03c39f8404e7..311540e33576 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -340,6 +340,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { Seq(TypeCollection(StringType, DateType, TimestampType), StringType) override def dataType: DataType = LongType + override def nullable: Boolean = true private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] @@ -455,6 +456,7 @@ case class FromUnixTime(sec: Expression, format: Expression) } override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) @@ -561,6 +563,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) override def dataType: DataType = DateType + override def nullable: Boolean = true override def nullSafeEval(start: Any, dayOfW: Any): Any = { val dow = DateTimeUtils.getDayOfWeekFromString(dayOfW.asInstanceOf[UTF8String]) @@ -832,6 +835,7 @@ case class TruncDate(date: Expression, format: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) override def dataType: DataType = DateType + override def nullable: Boolean = true override def prettyName: String = "trunc" private lazy val truncLevel: Int = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 78f6631e4647..c54bcdd77402 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -47,6 +47,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { override def dataType: DataType = DecimalType(precision, scale) + override def nullable: Boolean = true override def toString: String = s"MakeDecimal($child,$precision,$scale)" protected override def nullSafeEval(input: Any): Any = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 4991b9cb54e5..72b323587c63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,18 +17,19 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{StringWriter, ByteArrayOutputStream} +import java.io.{ByteArrayOutputStream, StringWriter} + +import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{StructField, StructType, StringType, DataType} +import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -import scala.util.parsing.combinator.RegexParsers - private[this] sealed trait PathInstruction private[this] object PathInstruction { private[expressions] case object Subscript extends PathInstruction @@ -108,15 +109,17 @@ private[this] object SharedFactory { case class GetJsonObject(json: Expression, path: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback { - import SharedFactory._ + import com.fasterxml.jackson.core.JsonToken._ + import PathInstruction._ + import SharedFactory._ import WriteStyle._ - import com.fasterxml.jackson.core.JsonToken._ override def left: Expression = json override def right: Expression = path override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def dataType: DataType = StringType + override def nullable: Boolean = true override def prettyName: String = "get_json_object" @transient private lazy val parsedPath = parsePath(path.eval().asInstanceOf[UTF8String]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 28f616fbb9ca..9c1a3294def2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -75,6 +75,8 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) abstract class UnaryLogExpression(f: Double => Double, name: String) extends UnaryMathExpression(f, name) { + override def nullable: Boolean = true + // values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity protected val yAsymptote: Double = 0.0 @@ -194,6 +196,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) override def dataType: DataType = StringType + override def nullable: Boolean = true override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = { NumberConverter.convert( @@ -621,6 +624,8 @@ case class Logarithm(left: Expression, right: Expression) this(EulerNumber(), child) } + override def nullable: Boolean = true + protected override def nullSafeEval(input1: Any, input2: Any): Any = { val dLeft = input1.asInstanceOf[Double] val dRight = input2.asInstanceOf[Double] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 0f6d02f2e00c..5baab4f7e8c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -57,6 +57,7 @@ case class Sha2(left: Expression, right: Expression) extends BinaryExpression with Serializable with ImplicitCastInputTypes { override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index a98e16c25321..c99a4ac9645a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -297,7 +297,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from string") { assert(cast("abcdef", StringType).nullable === false) assert(cast("abcdef", BinaryType).nullable === false) - assert(cast("abcdef", BooleanType).nullable === false) + assert(cast("abcdef", BooleanType).nullable === true) assert(cast("abcdef", TimestampType).nullable === true) assert(cast("abcdef", LongType).nullable === true) assert(cast("abcdef", IntegerType).nullable === true) @@ -547,7 +547,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Seq(null, true, false)) } @@ -606,7 +606,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { @@ -713,7 +713,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("a", BooleanType, nullable = true), StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = false)))) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, InternalRow(null, true, false)) } @@ -754,7 +754,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructType(Seq( StructField("l", LongType, nullable = true))))))) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Row( Seq(123, null, null), Map("a" -> null, "b" -> true, "c" -> false), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 24a79f289aa8..e2dc13d66c61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -232,7 +232,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm case class ExplainCommand( logicalPlan: LogicalPlan, override val output: Seq[Attribute] = - Seq(AttributeReference("plan", StringType, nullable = false)()), + Seq(AttributeReference("plan", StringType, nullable = true)()), extended: Boolean = false) extends RunnableCommand { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index e7deeff13dc4..e759c011e75d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -42,7 +42,7 @@ case class DescribeCommand( new MetadataBuilder().putString("comment", "name of the column").build())(), AttributeReference("data_type", StringType, nullable = false, new MetadataBuilder().putString("comment", "data type of the column").build())(), - AttributeReference("comment", StringType, nullable = false, + AttributeReference("comment", StringType, nullable = true, new MetadataBuilder().putString("comment", "comment of the column").build())() ) } From f16006942e1d3b6e33a3befabc486ea5890a5f0f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 16 Dec 2015 10:21:11 -0800 Subject: [PATCH 07/14] fix nullability --- .../aggregate/CentralMomentAgg.scala | 2 +- .../catalyst/expressions/aggregate/Corr.scala | 3 +- .../expressions/aggregate/Count.scala | 2 +- .../expressions/codegen/CodeGenerator.scala | 2 +- .../codegen/GenerateUnsafeProjection.scala | 21 +++-- .../expressions/complexTypeExtractors.scala | 4 +- .../expressions/namedExpressions.scala | 2 +- .../org/apache/spark/sql/DataFrame.scala | 12 ++- .../sql/execution/joins/HashOuterJoin.scala | 80 +------------------ .../execution/joins/SortMergeOuterJoin.scala | 2 +- .../apache/spark/sql/DataFrameJoinSuite.scala | 6 +- 11 files changed, 36 insertions(+), 100 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index d07d4c338cdf..30f602227b17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -53,7 +53,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override def children: Seq[Expression] = Seq(child) - override def nullable: Boolean = false + override def nullable: Boolean = true override def dataType: DataType = DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 00d7436b710d..d25f3335ffd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -42,7 +41,7 @@ case class Corr( override def children: Seq[Expression] = Seq(left, right) - override def nullable: Boolean = false + override def nullable: Boolean = true override def dataType: DataType = DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 6433f347211e..663c69e799fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -31,7 +31,7 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType) - private lazy val count = AttributeReference("count", LongType)() + private lazy val count = AttributeReference("count", LongType, nullable = false)() override lazy val aggBufferAttributes = count :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index b57732b24428..440c7d2fc115 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -540,7 +540,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin def formatted = CodeFormatter.format(code) - logError({ + logDebug({ // Only add extra debugging info to byte code when we are going to print the source code. evaluator.setDebuggingInformation(true, true, false) formatted diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 68005afb21d2..c1defe12b0b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -135,14 +135,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$rowWriter.write($index, ${input.value});" } - s""" - ${input.code} - if (${input.isNull}) { - ${setNull.trim} - } else { + if (input.isNull == "false") { + s""" + ${input.code} ${writeField.trim} - } - """ + """ + } else { + s""" + ${input.code} + if (${input.isNull}) { + ${setNull.trim} + } else { + ${writeField.trim} + } + """ + } } s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 10ce10aaf6da..451d4a295b93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -107,7 +107,7 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] private lazy val field = child.dataType.asInstanceOf[StructType](ordinal) override def dataType: DataType = field.dataType - override def nullable: Boolean = child.nullable || field.nullable + override def nullable: Boolean = true override def toString: String = s"$child.${name.getOrElse(field.name)}" protected override def nullSafeEval(input: Any): Any = @@ -139,7 +139,7 @@ case class GetArrayStructFields( containsNull: Boolean) extends UnaryExpression { override def dataType: DataType = ArrayType(field.dataType, containsNull) - override def nullable: Boolean = child.nullable || containsNull || field.nullable + override def nullable: Boolean = child.nullable || field.nullable override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 26b6aca79971..013578ea08c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -262,7 +262,7 @@ case class AttributeReference( } } - override def toString: String = s"$name#${exprId.id}$typeSuffix" + override def toString: String = s"$name#${exprId.id}$typeSuffix ${nullable}" // Since the expression id is not in the first constructor it is missing from the default // tree string. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 497bd4826677..a8dce2bc7393 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -499,10 +499,16 @@ class DataFrame private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( - Join(logicalPlan, right.logicalPlan, joinType = Inner, None)).analyzed.asInstanceOf[Join] + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) + .analyzed.asInstanceOf[Join] - // Project only one of the join columns. - val joinedCols = usingColumns.map(col => withPlan(joined.right).resolve(col)) + // Project only one of the join columns for Inner join, outer join should have all the columns + // from both sides. + val joinedCols = if (JoinType(joinType) == Inner) { + usingColumns.map(col => withPlan(joined.right).resolve(col)) + } else { + Nil + } val condition = usingColumns.map { col => catalyst.expressions.EqualTo( withPlan(joined.left).resolve(col), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index ed626fef56af..c6e586818751 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -75,7 +75,7 @@ trait HashOuterJoin { UnsafeProjection.create(streamedKeys, streamedPlan.output) protected[this] def resultProjection: InternalRow => InternalRow = - UnsafeProjection.create(self.schema) + UnsafeProjection.create(output, output) @transient private[this] lazy val DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @@ -151,82 +151,4 @@ trait HashOuterJoin { } ret.iterator } - - protected[this] def fullOuterIterator( - key: InternalRow, - leftIter: Iterable[InternalRow], - rightIter: Iterable[InternalRow], - joinedRow: JoinedRow, - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - if (!key.anyNull) { - // Store the positions of records in right, if one of its associated row satisfy - // the join condition. - val rightMatchedSet = scala.collection.mutable.Set[Int]() - leftIter.iterator.flatMap[InternalRow] { l => - joinedRow.withLeft(l) - var matched = false - rightIter.zipWithIndex.collect { - // 1. For those matched (satisfy the join condition) records with both sides filled, - // append them directly - - case (r, idx) if boundCondition(joinedRow.withRight(r)) => - numOutputRows += 1 - matched = true - // if the row satisfy the join condition, add its index into the matched set - rightMatchedSet.add(idx) - resultProjection(joinedRow) - - } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // 2. For those unmatched records in left, append additional records with empty right. - - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all - // of the records in right side. - // If we didn't get any proper row, then append a single row with empty right. - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) - }) - } ++ rightIter.zipWithIndex.collect { - // 3. For those unmatched records in right, append additional records with empty left. - - // Re-visiting the records in right, and append additional row with empty left, if its not - // in the matched set. - case (r, idx) if !rightMatchedSet.contains(idx) => - numOutputRows += 1 - resultProjection(joinedRow(leftNullRow, r)) - } - } else { - leftIter.iterator.map[InternalRow] { l => - numOutputRows += 1 - resultProjection(joinedRow(l, rightNullRow)) - } ++ rightIter.iterator.map[InternalRow] { r => - numOutputRows += 1 - resultProjection(joinedRow(leftNullRow, r)) - } - } - } - - // This is only used by FullOuter - protected[this] def buildHashTable( - iter: Iterator[InternalRow], - numIterRows: LongSQLMetric, - keyGenerator: Projection): java.util.HashMap[InternalRow, CompactBuffer[InternalRow]] = { - val hashTable = new java.util.HashMap[InternalRow, CompactBuffer[InternalRow]]() - while (iter.hasNext) { - val currentRow = iter.next() - numIterRows += 1 - val rowKey = keyGenerator(currentRow) - - var existingMatchList = hashTable.get(rowKey) - if (existingMatchList == null) { - existingMatchList = new CompactBuffer[InternalRow]() - hashTable.put(rowKey.copy(), existingMatchList) - } - - existingMatchList += currentRow.copy() - } - - hashTable - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index efaa69c1d322..7ce38ebdb341 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -114,7 +114,7 @@ case class SortMergeOuterJoin( (r: InternalRow) => true } } - val resultProj: InternalRow => InternalRow = UnsafeProjection.create(schema) + val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output) joinType match { case LeftOuter => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index c70397f9853a..cf70056ef2fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -48,11 +48,13 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { checkAnswer( df.join(df2, Seq("int", "str"), "left"), - Row(1, 2, "1", null) :: Row(2, 3, "2", null) :: Row(3, 4, "3", null) :: Nil) + Row(1, 2, "1", null, null, null) :: Row(2, 3, "2", null, null, null) :: + Row(3, 4, "3", null, null, null) :: Nil) checkAnswer( df.join(df2, Seq("int", "str"), "right"), - Row(null, null, null, 2) :: Row(null, null, null, 3) :: Row(null, null, null, 4) :: Nil) + Row(null, null, null, 1, 2, "2") :: Row(null, null, null, 2, 3, "3") :: + Row(null, null, null, 3, 4, "4") :: Nil) } test("join - join using self join") { From 88b21072b7e645075a31546a06edd8d5ea4d5176 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 16 Dec 2015 11:11:36 -0800 Subject: [PATCH 08/14] fix bug --- .../spark/sql/catalyst/expressions/Cast.scala | 20 +++++++++++++++---- .../expressions/complexTypeExtractors.scala | 3 +-- .../expressions/namedExpressions.scala | 2 +- .../expressions/ComplexTypeSuite.scala | 1 - 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 054f02d7cebe..d138e832f9e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -87,10 +87,22 @@ object Cast { private def resolvableNullability(from: Boolean, to: Boolean) = !from || to private def forceNullable(from: DataType, to: DataType) = (from, to) match { - // TODO: add more case - case (_, StringType | BinaryType) => false - case (_: NumericType, FloatType | DoubleType) => false - case _ => true + case (NullType, _) => true + case (_, _) if from == to => false + + case (StringType, BinaryType) => false + case (StringType, _) => true + + case (FloatType | DoubleType, TimestampType) => true + case (TimestampType, DateType) => false + case (_, DateType) => true + case (DateType, TimestampType) => false + case (DateType, _) => true + case (_, CalendarIntervalType) => true + + case (_, _: DecimalType) => true // overflow + case (_: FractionalType, _: IntegralType) => true // NaN, infinity + case _ => false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 451d4a295b93..454103064742 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -107,7 +107,7 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] private lazy val field = child.dataType.asInstanceOf[StructType](ordinal) override def dataType: DataType = field.dataType - override def nullable: Boolean = true + override def nullable: Boolean = child.nullable || field.nullable override def toString: String = s"$child.${name.getOrElse(field.name)}" protected override def nullSafeEval(input: Any): Any = @@ -139,7 +139,6 @@ case class GetArrayStructFields( containsNull: Boolean) extends UnaryExpression { override def dataType: DataType = ArrayType(field.dataType, containsNull) - override def nullable: Boolean = child.nullable || field.nullable override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 013578ea08c4..26b6aca79971 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -262,7 +262,7 @@ case class AttributeReference( } } - override def toString: String = s"$name#${exprId.id}$typeSuffix ${nullable}" + override def toString: String = s"$name#${exprId.id}$typeSuffix" // Since the expression id is not in the first constructor it is missing from the default // tree string. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 62fd47234b33..8ccf4b823489 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ From e41835804b818724f8c28c12e9606b4d4052fe37 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 16 Dec 2015 11:48:41 -0800 Subject: [PATCH 09/14] fix GetStructField --- .../expressions/complexTypeExtractors.scala | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 454103064742..358c33748355 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -115,13 +115,19 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { - s""" - if ($eval.isNullAt($ordinal)) { - ${ev.isNull} = true; - } else { + if (nullable) { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + } + """ + } else { + s""" ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; - } - """ + """ + } }) } } From 9adad172f287d9b54187f6699d59adae4078c56a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 17 Dec 2015 08:41:34 -0800 Subject: [PATCH 10/14] fix style --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index b4c89721e0ff..fc718f6510ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -455,7 +455,7 @@ class DataFrame private[sql]( val joined = sqlContext.executePlan( Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) .analyzed.asInstanceOf[Join] - + val condition = usingColumns.map { col => catalyst.expressions.EqualTo( withPlan(joined.left).resolve(col), From 55bb6a9bb435a51fd86b2369a8e28257287dbb2f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 17 Dec 2015 22:06:03 -0800 Subject: [PATCH 11/14] fix nullable of lead/lag --- .../spark/sql/catalyst/expressions/BoundAttribute.scala | 4 ++-- .../sql/catalyst/expressions/namedExpressions.scala | 5 +++-- .../sql/catalyst/expressions/windowExpressions.scala | 4 ++-- .../src/main/scala/org/apache/spark/sql/DataFrame.scala | 4 ++-- .../scala/org/apache/spark/sql/execution/Window.scala | 9 +++++---- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index cbfe57032486..b6277eca6b4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends LeafExpression with NamedExpression { - override def toString: String = s"input[$ordinal, $dataType]" + override def toString: String = s"input[$ordinal, $dataType]${if (nullable) "?" else ""}" // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): Any = { @@ -99,7 +99,7 @@ object BindReferences extends Logging { sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") } } else { - BoundReference(ordinal, a.dataType, a.nullable) + BoundReference(ordinal, a.dataType, input(ordinal).nullable) } } }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 26b6aca79971..646273c0329f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -81,9 +81,10 @@ trait NamedExpression extends Expression { protected def typeSuffix = if (resolved) { + val suffix = if (nullable) "?" else "" dataType match { - case LongType => "L" - case _ => "" + case LongType => "L" + suffix + case _ => suffix } } else { "" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 06252ac4fc61..91f169e7eac4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -329,7 +329,7 @@ abstract class OffsetWindowFunction */ override def foldable: Boolean = input.foldable && (default == null || default.foldable) - override def nullable: Boolean = input.nullable && (default == null || default.nullable) + override def nullable: Boolean = default == null || default.nullable override lazy val frame = { // This will be triggered by the Analyzer. @@ -381,7 +381,7 @@ abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowF self: Product => override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow) override def dataType: DataType = IntegerType - override def nullable: Boolean = false + override def nullable: Boolean = true override def supportsPartial: Boolean = false override lazy val mergeExpressions = throw new UnsupportedOperationException("Window Functions do not support merging.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index fc718f6510ac..965eaa9efec4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -479,8 +479,8 @@ class DataFrame private[sql]( } // The nullability of output of joined could be different than original column, // so we can only compare them by exprId - val joinRefs = condition.map(_.references.toSeq.map(_.exprId)).getOrElse(Nil) - val resultCols = joinedCols ++ joined.output.filterNot(e => joinRefs.contains(e.exprId)) + val joinRefs = AttributeSet(condition.toSeq.flatMap(_.references)) + val resultCols = joinedCols ++ joined.output.filterNot(joinRefs.contains(_)) withPlan { Project( resultCols, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 9852b6e7beeb..c941d673c724 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -440,16 +440,17 @@ private[execution] final class OffsetWindowFunctionFrame( /** Create the projection. */ private[this] val projection = { // Collect the expressions and bind them. - val numInputAttributes = inputSchema.size + val inputAttrs = inputSchema.map(_.withNullability(true)) + val numInputAttributes = inputAttrs.size val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { case e: OffsetWindowFunction => - val input = BindReferences.bindReference(e.input, inputSchema) + val input = BindReferences.bindReference(e.input, inputAttrs) if (e.default == null || e.default.foldable && e.default.eval() == null) { // Without default value. input } else { // With default value. - val default = BindReferences.bindReference(e.default, inputSchema).transform { + val default = BindReferences.bindReference(e.default, inputAttrs).transform { // Shift the input reference to its default version. case BoundReference(o, dataType, nullable) => BoundReference(o + numInputAttributes, dataType, nullable) @@ -457,7 +458,7 @@ private[execution] final class OffsetWindowFunctionFrame( org.apache.spark.sql.catalyst.expressions.Coalesce(input :: default :: Nil) } case e => - BindReferences.bindReference(e, inputSchema) + BindReferences.bindReference(e, inputAttrs) } // Create the projection. From 9f7d7634b03f457d14d06363f855e3de0ae55980 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 17 Dec 2015 22:11:08 -0800 Subject: [PATCH 12/14] cast date to string --- .../scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d138e832f9e0..b18f49f3203f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -92,6 +92,7 @@ object Cast { case (StringType, BinaryType) => false case (StringType, _) => true + case (_, StringType) => false case (FloatType | DoubleType, TimestampType) => true case (TimestampType, DateType) => false From 3b1e42ffb3808c7d830cd998ba0132404641e65c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 17 Dec 2015 23:04:15 -0800 Subject: [PATCH 13/14] remove suffix --- .../spark/sql/catalyst/expressions/BoundAttribute.scala | 2 +- .../spark/sql/catalyst/expressions/namedExpressions.scala | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index b6277eca6b4e..7293d5d4472a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) extends LeafExpression with NamedExpression { - override def toString: String = s"input[$ordinal, $dataType]${if (nullable) "?" else ""}" + override def toString: String = s"input[$ordinal, $dataType]" // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 646273c0329f..26b6aca79971 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -81,10 +81,9 @@ trait NamedExpression extends Expression { protected def typeSuffix = if (resolved) { - val suffix = if (nullable) "?" else "" dataType match { - case LongType => "L" + suffix - case _ => suffix + case LongType => "L" + case _ => "" } } else { "" From 3cc4cdf3f32e84e9c3c78377c37b5e6b7eeadba1 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 18 Dec 2015 00:50:09 -0800 Subject: [PATCH 14/14] fix nullability of Expand --- .../sql/catalyst/plans/logical/basicOperators.scala | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 5665fd7e5f41..ec42b763f18e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -293,7 +293,14 @@ private[sql] object Expand { Literal.create(bitmask, IntegerType) }) } - Expand(projections, child.output :+ gid, child) + val output = child.output.map { attr => + if (groupByExprs.exists(_.semanticEquals(attr))) { + attr.withNullability(true) + } else { + attr + } + } + Expand(projections, output :+ gid, child) } }