Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.security.{MessageDigest, NoSuchAlgorithmException}
import java.util.zip.CRC32

import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.sql.catalyst.InternalRow

import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -160,3 +161,22 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp
})
}
}

/** An expression that returns the hashCode of the input row. */
case object RowHashCode extends LeafExpression {
override def dataType: DataType = IntegerType

/** hashCode will never be null. */
override def nullable: Boolean = false

override def eval(input: InternalRow): Any = {
input.hashCode
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
s"""
boolean ${ev.isNull} = false;
${ctx.javaType(dataType)} ${ev.primitive} = i.hashCode();
"""
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,58 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
}
}

/**
* A predicate that is evaluated to be true if there are at least `n` null values.
*/
case class AtLeastNNulls(n: Int, children: Seq[Expression]) extends Predicate {
override def nullable: Boolean = false
override def foldable: Boolean = children.forall(_.foldable)
override def toString: String = s"AtLeastNNulls($n, ${children.mkString(",")})"

private[this] val childrenArray = children.toArray

override def eval(input: InternalRow): Boolean = {
var numNulls = 0
var i = 0
while (i < childrenArray.length && numNulls < n) {
val evalC = childrenArray(i).eval(input)
if (evalC == null) {
numNulls += 1
}
i += 1
}
numNulls >= n
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val numNulls = ctx.freshName("numNulls")
val code = children.map { e =>
val eval = e.gen(ctx)
s"""
if ($numNulls < $n) {
${eval.code}
if (${eval.isNull}) {
$numNulls += 1;
}
}
"""
}.mkString("\n")
s"""
int $numNulls = 0;
$code
boolean ${ev.isNull} = false;
boolean ${ev.primitive} = $numNulls >= $n;
"""
}
}

/**
* A predicate that is evaluated to be true if there are at least `n` non-null and non-NaN values.
*/
case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate {
case class AtLeastNNonNullNans(n: Int, children: Seq[Expression]) extends Predicate {
override def nullable: Boolean = false
override def foldable: Boolean = children.forall(_.foldable)
override def toString: String = s"AtLeastNNulls(n, ${children.mkString(",")})"
override def toString: String = s"AtLeastNNonNullNans($n, ${children.mkString(",")})"

private[this] val childrenArray = children.toArray

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,14 @@ import org.apache.spark.sql.types._

abstract class Optimizer extends RuleExecutor[LogicalPlan]

object DefaultOptimizer extends Optimizer {
val batches =
class DefaultOptimizer extends Optimizer {

/**
* Override to provide additional rules for the "Operator Optimizations" batch.
*/
val extendedOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil

lazy val batches =
// SubQueries are only needed for analysis and can be removed before execution.
Batch("Remove SubQueries", FixedPoint(100),
EliminateSubQueries) ::
Expand All @@ -41,26 +47,27 @@ object DefaultOptimizer extends Optimizer {
RemoveLiteralFromGroupExpressions) ::
Batch("Operator Optimizations", FixedPoint(100),
// Operator push down
SetOperationPushDown,
SamplePushDown,
PushPredicateThroughJoin,
PushPredicateThroughProject,
PushPredicateThroughGenerate,
ColumnPruning,
SetOperationPushDown ::
SamplePushDown ::
PushPredicateThroughJoin ::
PushPredicateThroughProject ::
PushPredicateThroughGenerate ::
ColumnPruning ::
// Operator combine
ProjectCollapsing,
CombineFilters,
CombineLimits,
ProjectCollapsing ::
CombineFilters ::
CombineLimits ::
// Constant folding
NullPropagation,
OptimizeIn,
ConstantFolding,
LikeSimplification,
BooleanSimplification,
RemovePositive,
SimplifyFilters,
SimplifyCasts,
SimplifyCaseConversionExpressions) ::
NullPropagation ::
OptimizeIn ::
ConstantFolding ::
LikeSimplification ::
BooleanSimplification ::
RemovePositive ::
SimplifyFilters ::
SimplifyCasts ::
SimplifyCaseConversionExpressions ::
extendedOperatorOptimizationRules.toList : _*) ::
Batch("Decimal Optimizations", FixedPoint(100),
DecimalAggregates) ::
Batch("LocalRelation", FixedPoint(100),
Expand Down Expand Up @@ -222,12 +229,18 @@ object ColumnPruning extends Rule[LogicalPlan] {
}

/** Applies a projection only when the child is producing unnecessary attributes */
private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) = {
if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
Project(allReferences.filter(c.outputSet.contains).toSeq, c)
// We need to preserve the nullability of c's output.
// So, we first create a outputMap and if a reference is from the output of
// c, we use that output attribute from c.
val outputMap = AttributeMap(c.output.map(attr => (attr, attr)))
val projectList = allReferences.filter(outputMap.contains).map(outputMap).toSeq
Project(projectList, c)
} else {
c
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,37 @@ case class Generate(
}

case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
/**
* Indicates if `atLeastNNulls` is used to check if atLeastNNulls.children
* have at least one null value and atLeastNNulls.children are all attributes.
*/
private def isAtLeastOneNullOutputAttributes(atLeastNNulls: AtLeastNNulls): Boolean = {
val expressions = atLeastNNulls.children
val n = atLeastNNulls.n
if (n != 1) {
// AtLeastNNulls is not used to check if atLeastNNulls.children have
// at least one null value.
false
} else {
// AtLeastNNulls is used to check if atLeastNNulls.children have
// at least one null value. We need to make sure all atLeastNNulls.children
// are attributes.
expressions.forall(_.isInstanceOf[Attribute])
}
}

override def output: Seq[Attribute] = condition match {
case Not(a: AtLeastNNulls) if isAtLeastOneNullOutputAttributes(a) =>
// The condition is used to make sure that there is no null value in
// a.children.
val nonNullableAttributes = AttributeSet(a.children.asInstanceOf[Seq[Attribute]])
child.output.map {
case attr if nonNullableAttributes.contains(attr) =>
attr.withNullability(false)
case attr => attr
}
case _ => child.output
}
}

case class Union(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
Expand Down
Loading