From 3f1e6a1c029709032b06236c2a8e151fe61b53b8 Mon Sep 17 00:00:00 2001 From: ptkool Date: Thu, 20 Apr 2017 11:27:41 -0400 Subject: [PATCH 1/2] Add new query hint NO_COLLAPSE. --- python/pyspark/sql/functions.py | 8 ++++ .../plans/logical/basicLogicalOperators.scala | 7 +++ .../optimizer/CollapseProjectSuite.scala | 12 +++++- .../sql/catalyst/parser/PlanParserSuite.scala | 5 +++ .../spark/sql/execution/SparkStrategies.scala | 1 + .../org/apache/spark/sql/functions.scala | 43 +++++++++++++------ 6 files changed, 61 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 843ae3816f061..a5d38abb4be79 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -466,6 +466,14 @@ def nanvl(col1, col2): return Column(sc._jvm.functions.nanvl(_to_java_column(col1), _to_java_column(col2))) +@since(2.2) +def no_collapse(df): + """Marks a DataFrame as small enough for use in broadcast joins.""" + + sc = SparkContext._active_spark_context + return DataFrame(sc._jvm.functions.no_collapse(df._jdf), df.sql_ctx) + + @since(1.4) def rand(seed=None): """Generates a random column with independent and identically distributed (i.i.d.) samples diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3ad757ebba851..656b0bba1865a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -386,6 +386,13 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode { child.stats(conf).copy(isBroadcastable = true) } +/** + * A hint for the optimizer that we should not merge two projections. + */ +case class NoCollapseHint(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + /** * A general hint for the child. This node will be eliminated post analysis. * A pair of (name, parameters). 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index 587437e9aa81d..d639f34a6eae8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.Rand
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, NoCollapseHint}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 
 class CollapseProjectSuite extends PlanTest {
@@ -119,4 +119,14 @@ class CollapseProjectSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("do not collapse projects across a NO_COLLAPSE hint") {
+    val query = NoCollapseHint(testRelation.select(('a * 10).as('a_times_10)))
+      .select(('a_times_10 + 1).as('a_times_10_plus_1), ('a_times_10 + 2).as('a_times_10_plus_2))
+
+    val optimized = Optimize.execute(query.analyze)
+    val correctAnswer = query.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 411777d6e85a2..51443f1b082cf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -537,5 +537,10 @@ class PlanParserSuite extends PlanTest {
     comparePlans(
       parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"),
       Hint("MAPJOIN", Seq("t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc))
+
+    comparePlans(
+      parsePlan("SELECT a FROM (SELECT /*+ NO_COLLAPSE */ * FROM t) t1"),
+      SubqueryAlias("t1", Hint("NO_COLLAPSE", Seq.empty, table("t").select(star())))
+        .select('a))
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ca2f6dd7a84b2..8a7b4fc88ade4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -433,6 +433,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
+      case NoCollapseHint(child) => planLater(child) :: Nil
       case _ => Nil
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index f07e04368389f..c1eb6ef6a3f3f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -19,17 +19,16 @@ package org.apache.spark.sql
 
 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
-import scala.reflect.runtime.universe.{typeTag, TypeTag}
+import scala.reflect.runtime.universe.{TypeTag, typeTag}
import scala.util.Try import scala.util.control.NonFatal - import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, NoCollapseHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.internal.SQLConf @@ -1007,21 +1006,37 @@ object functions { def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) } /** - * Marks a DataFrame as small enough for use in broadcast joins. - * - * The following example marks the right DataFrame for broadcast hash join using `joinKey`. - * {{{ - * // left and right are DataFrames - * left.join(broadcast(right), "joinKey") - * }}} - * - * @group normal_funcs - * @since 1.5.0 - */ + * Marks a DataFrame as small enough for use in broadcast joins. + * + * The following example marks the right DataFrame for broadcast hash join using `joinKey`. + * {{{ + * // left and right are DataFrames + * left.join(broadcast(right), "joinKey") + * }}} + * + * @group normal_funcs + * @since 1.5.0 + */ def broadcast[T](df: Dataset[T]): Dataset[T] = { Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc) } + /** + * Marks a DataFrame as small enough for use in broadcast joins. + * + * The following example marks the right DataFrame for broadcast hash join using `joinKey`. + * {{{ + * // left and right are DataFrames + * left.join(broadcast(right), "joinKey") + * }}} + * + * @group normal_funcs + * @since 1.5.0 + */ + def no_collapse[T](df: Dataset[T]): Dataset[T] = { + Dataset[T](df.sparkSession, NoCollapseHint(df.logicalPlan))(df.exprEnc) + } + /** * Returns the first column that is not null, or null if all inputs are null. * From 3986247b2721e058257ba92c0e418dcc42318c13 Mon Sep 17 00:00:00 2001 From: ptkool Date: Thu, 20 Apr 2017 14:16:33 -0400 Subject: [PATCH 2/2] Resolve scalastyle errors. 
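
The style errors most likely come from the import changes to
functions.scala in the first commit: Spark's scalastyle configuration
keeps the original selector order and a blank line between the scala.*
and org.apache.* import groups, i.e.

    import scala.util.Try
    import scala.util.control.NonFatal

    import org.apache.spark.annotation.{Experimental, InterfaceStability}

Besides reverting those imports, this commit re-indents the doc comments
for broadcast() and no_collapse(), rewrites the copy-pasted no_collapse
scaladoc (the hint marks a DataFrame as non-collapsible, not
broadcastable, and is available since 2.2.0, not 1.5.0), and drops the
NO_COLLAPSE parser test, leaving the hint exposed only through the
Dataset API for now.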
--- python/pyspark/sql/functions.py | 2 +- .../sql/catalyst/parser/PlanParserSuite.scala | 5 -- .../org/apache/spark/sql/functions.scala | 47 ++++++++++--------- 3 files changed, 25 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a5d38abb4be79..a83d06ad0821e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -468,7 +468,7 @@ def nanvl(col1, col2): @since(2.2) def no_collapse(df): - """Marks a DataFrame as small enough for use in broadcast joins.""" + """Marks a DataFrame as non-collapsible.""" sc = SparkContext._active_spark_context return DataFrame(sc._jvm.functions.no_collapse(df._jdf), df.sql_ctx) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 51443f1b082cf..411777d6e85a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -537,10 +537,5 @@ class PlanParserSuite extends PlanTest { comparePlans( parsePlan("SELECT /*+ MAPJOIN(t) */ a from t where true group by a order by a"), Hint("MAPJOIN", Seq("t"), table("t").where(Literal(true)).groupBy('a)('a)).orderBy('a.asc)) - - comparePlans( - parsePlan("SELECT a FROM (SELECT /*+ NO_COLLAPSE */ * FROM t) t1"), - SubqueryAlias("t1", Hint("NO_COLLAPSE", Seq.empty, table("t").select(star()))) - .select('a)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c1eb6ef6a3f3f..589be14a500f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.{TypeTag, typeTag} +import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try import scala.util.control.NonFatal + import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} @@ -1006,33 +1007,33 @@ object functions { def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) } /** - * Marks a DataFrame as small enough for use in broadcast joins. - * - * The following example marks the right DataFrame for broadcast hash join using `joinKey`. - * {{{ - * // left and right are DataFrames - * left.join(broadcast(right), "joinKey") - * }}} - * - * @group normal_funcs - * @since 1.5.0 - */ + * Marks a DataFrame as small enough for use in broadcast joins. + * + * The following example marks the right DataFrame for broadcast hash join using `joinKey`. + * {{{ + * // left and right are DataFrames + * left.join(broadcast(right), "joinKey") + * }}} + * + * @group normal_funcs + * @since 1.5.0 + */ def broadcast[T](df: Dataset[T]): Dataset[T] = { Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc) } /** - * Marks a DataFrame as small enough for use in broadcast joins. - * - * The following example marks the right DataFrame for broadcast hash join using `joinKey`. 
-   * {{{
-   *   // left and right are DataFrames
-   *   left.join(broadcast(right), "joinKey")
-   * }}}
-   *
-   * @group normal_funcs
-   * @since 1.5.0
-   */
+   * Marks a DataFrame as non-collapsible.
+   *
+   * For example:
+   * {{{
+   *   val df1 = no_collapse(df.select((df.col("qty") * lit(10)).alias("c1")))
+   *   val df2 = df1.select(col("c1") + lit(1), col("c1") + lit(2))
+   * }}}
+   *
+   * @group normal_funcs
+   * @since 2.2.0
+   */
   def no_collapse[T](df: Dataset[T]): Dataset[T] = {
     Dataset[T](df.sparkSession, NoCollapseHint(df.logicalPlan))(df.exprEnc)
   }
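
A minimal usage sketch for reviewers (not part of the patch; it assumes
a running `spark` session, and the column names are illustrative):

    import org.apache.spark.sql.functions.{col, no_collapse}

    val base = spark.range(100).select((col("id") * 10).alias("c1"))

    // Without the hint, CollapseProject would merge the two projections,
    // re-evaluating `id * 10` inside each derived column. With the hint,
    // the intermediate projection survives and `c1` is computed once.
    val hinted = no_collapse(base)
    val result = hinted.select((col("c1") + 1).alias("c2"), (col("c1") + 2).alias("c3"))

    result.explain(true)  // the optimized plan should keep both Projects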