From e6c5c84ea189b1fdbac8408594c880950b6b7398 Mon Sep 17 00:00:00 2001
From: Dilip Biswal
Date: Mon, 15 Oct 2018 23:23:41 -0700
Subject: [PATCH 1/7] tests

---
 .../ExpressionTypeCheckingSuite.scala         |   2 +
 .../resources/sql-tests/inputs/group-by.sql   |  66 ++++++
 .../sql-tests/results/group-by.sql.out        | 214 +++++++++++++++++-
 .../spark/sql/DataFrameAggregateSuite.scala   |  64 ++++++
 4 files changed, 345 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 8eec14842c7e7..b8643d0243f59 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -144,6 +144,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertSuccess(Sum('stringField))
     assertSuccess(Average('stringField))
     assertSuccess(Min('arrayField))
+    assertSuccess(new EveryAgg('booleanField))
+    assertSuccess(new AnyAgg('booleanField))
 
     assertError(Min('mapField), "min does not support ordering on type")
     assertError(Max('mapField), "max does not support ordering on type")
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index 433db71527437..92d73561d27ec 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -80,3 +80,69 @@ SELECT 1 FROM range(10) HAVING true;
 SELECT 1 FROM range(10) HAVING MAX(id) > 0;
 
 SELECT id FROM range(10) HAVING id > 0;
+
+-- Test data
+CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES
+  (1, true), (1, false),
+  (2, true),
+  (3, false), (3, null),
+  (4, null), (4, null),
+  (5, null), (5, true), (5, false) AS test_agg(k, v);
+
+-- empty table
+SELECT every(v), some(v), any(v) FROM test_agg WHERE 1 = 0;
+
+-- all null values
+SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 4;
+
+-- aggregates are null filtering
+SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 5;
+
+-- group by
+SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k;
+
+-- having
+SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) = false;
+SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) IS NULL;
+
+-- basic subquery path to make sure rewrite happens in both parent and child plans.
+SELECT k,
+       Every(v) AS every
+FROM   test_agg
+WHERE  k = 2
+       AND v IN (SELECT Any(v)
+                 FROM   test_agg
+                 WHERE  k = 1)
+GROUP  BY k;
+
+-- basic subquery path to make sure rewrite happens in both parent and child plans.
+SELECT k,
+       Every(v) AS every
+FROM   test_agg
+WHERE  k = 2
+       AND v IN (SELECT Every(v)
+                 FROM   test_agg
+                 WHERE  k = 1)
+GROUP  BY k;
+
+-- input type checking Int
+SELECT every(1);
+
+-- input type checking Short
+SELECT some(1S);
+
+-- input type checking Long
+SELECT any(1L);
+
+-- input type checking String
+SELECT every("true");
+
+-- every/some/any aggregates are not supported as windows expression.
+SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
+SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
+SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
+
+-- simple explain of queries having every/some/any aggregates. Optimized
+-- plan should show the rewritten aggregate expression.
+EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k;
+
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index f9d1ee8a6bcdb..9a8d025331b67 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 30
+-- Number of queries: 47
 
 
 -- !query 0
@@ -275,3 +275,215 @@ struct<>
 -- !query 29 output
 org.apache.spark.sql.AnalysisException
 grouping expressions sequence is empty, and '`id`' is not an aggregate function. Wrap '()' in windowing function(s) or wrap '`id`' in first() (or first_value) if you don't care which value you get.;
+
+
+-- !query 30
+CREATE OR REPLACE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES
+  (1, true), (1, false),
+  (2, true),
+  (3, false), (3, null),
+  (4, null), (4, null),
+  (5, null), (5, true), (5, false) AS test_agg(k, v)
+-- !query 30 schema
+struct<>
+-- !query 30 output
+
+
+
+-- !query 31
+SELECT every(v), some(v), any(v) FROM test_agg WHERE 1 = 0
+-- !query 31 schema
+struct<every(v):boolean,some(v):boolean,any(v):boolean>
+-- !query 31 output
+NULL NULL NULL
+
+
+-- !query 32
+SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 4
+-- !query 32 schema
+struct<every(v):boolean,some(v):boolean,any(v):boolean>
+-- !query 32 output
+NULL NULL NULL
+
+
+-- !query 33
+SELECT every(v), some(v), any(v) FROM test_agg WHERE k = 5
+-- !query 33 schema
+struct<every(v):boolean,some(v):boolean,any(v):boolean>
+-- !query 33 output
+false true true
+
+
+-- !query 34
+SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k
+-- !query 34 schema
+struct<k:int,every(v):boolean,some(v):boolean,any(v):boolean>
+-- !query 34 output
+1 false true true
+2 true true true
+3 false false false
+4 NULL NULL NULL
+5 false true true
+
+
+-- !query 35
+SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) = false
+-- !query 35 schema
+struct<k:int,every(v):boolean>
+-- !query 35 output
+1 false
+3 false
+5 false
+
+
+-- !query 36
+SELECT k, every(v) FROM test_agg GROUP BY k HAVING every(v) IS NULL
+-- !query 36 schema
+struct<k:int,every(v):boolean>
+-- !query 36 output
+4 NULL
+
+
+-- !query 37
+SELECT k,
+       Every(v) AS every
+FROM   test_agg
+WHERE  k = 2
+       AND v IN (SELECT Any(v)
+                 FROM   test_agg
+                 WHERE  k = 1)
+GROUP  BY k
+-- !query 37 schema
+struct<k:int,every:boolean>
+-- !query 37 output
+2 true
+
+
+-- !query 38
+SELECT k,
+       Every(v) AS every
+FROM   test_agg
+WHERE  k = 2
+       AND v IN (SELECT Every(v)
+                 FROM   test_agg
+                 WHERE  k = 1)
+GROUP  BY k
+-- !query 38 schema
+struct<k:int,every:boolean>
+-- !query 38 output
+
+
+
+-- !query 39
+SELECT every(1)
+-- !query 39 schema
+struct<>
+-- !query 39 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'every(1)' due to data type mismatch: Input to function 'every' should have been boolean, but it's [int].; line 1 pos 7
+
+
+-- !query 40
+SELECT some(1S)
+-- !query 40 schema
+struct<>
+-- !query 40 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'some(1S)' due to data type mismatch: Input to function 'some' should have been boolean, but it's [smallint].; line 1 pos 7
+
+
+-- !query 41
+SELECT any(1L)
+-- !query 41 schema
+struct<>
+-- !query 41 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'any(1L)' due to data type mismatch: Input to function 'any' should have been boolean, but it's [bigint].; line 1 pos 7
+
+
+-- !query 42
+SELECT every("true")
+-- !query 42 schema
+struct<>
+-- !query 42 output
+org.apache.spark.sql.AnalysisException
+cannot resolve 'every('true')' due to data type mismatch: Input to function 'every' should have been boolean, but it's [string].; line 1 pos 7
+
+
+-- !query 43
+SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg
+-- !query 43 schema
+struct<k:int,v:boolean,every(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
+-- !query 43 output
+1 false false
+1 true false
+2 true true
+3 NULL NULL
+3 false false
+4 NULL NULL
+4 NULL NULL
+5 NULL NULL
+5 false false
+5 true false
+
+
+-- !query 44
+SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg
+-- !query 44 schema
+struct<k:int,v:boolean,some(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
+-- !query 44 output
+1 false false
+1 true true
+2 true true
+3 NULL NULL
+3 false false
+4 NULL NULL
+4 NULL NULL
+5 NULL NULL
+5 false false
+5 true true
+
+
+-- !query 45
+SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg
+-- !query 45 schema
+struct<k:int,v:boolean,any(v) OVER (PARTITION BY k ORDER BY v ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):boolean>
+-- !query 45 output
+1 false false
+1 true true
+2 true true
+3 NULL NULL
+3 false false
+4 NULL NULL
+4 NULL NULL
+5 NULL NULL
+5 false false
+5 true true
+
+
+-- !query 46
+EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k
+-- !query 46 schema
+struct<plan:string>
+-- !query 46 output
+== Parsed Logical Plan ==
+'Aggregate ['k], ['k, unresolvedalias('every('v), None), unresolvedalias('some('v), None), unresolvedalias('any('v), None)]
++- 'UnresolvedRelation `test_agg`
+
+== Analyzed Logical Plan ==
+k: int, every(v): boolean, some(v): boolean, any(v): boolean
+Aggregate [k#x], [k#x, every(v#x) AS every(v)#x, some(v#x) AS some(v)#x, any(v#x) AS any(v)#x]
++- SubqueryAlias `test_agg`
+   +- Project [k#x, v#x]
+      +- SubqueryAlias `test_agg`
+         +- LocalRelation [k#x, v#x]
+
+== Optimized Logical Plan ==
+Aggregate [k#x], [k#x, min(v#x) AS every(v)#x, max(v#x) AS some(v)#x, max(v#x) AS any(v)#x]
++- LocalRelation [k#x, v#x]
+
+== Physical Plan ==
+*HashAggregate(keys=[k#x], functions=[min(v#x), max(v#x)], output=[k#x, every(v)#x, some(v)#x, any(v)#x])
++- Exchange hashpartitioning(k#x, 200)
+   +- *HashAggregate(keys=[k#x], functions=[partial_min(v#x), partial_max(v#x)], output=[k#x, min#x, max#x])
+      +- LocalTableScan [k#x, v#x]
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index d0106c44b7db2..b5ab5cf737e25 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
 
 import org.scalatest.Matchers.the
 
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AnyAgg, EveryAgg}
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
@@ -727,4 +728,67 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       "grouping expressions: [current_date(None)], value: [key: int, value: string], " +
       "type: GroupBy]"))
   }
+
+  def getEveryAggColumn(columnName: String): Column = {
+    Column(new EveryAgg(Column(columnName).expr).toAggregateExpression(false))
+  }
+
+  def getAnyAggColumn(columnName: String): Column = {
+    Column(new AnyAgg(Column(columnName).expr).toAggregateExpression(false))
+  }
+
+  test("every") {
+    val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false))
+      .toDF("a", "b")
+
+    checkAnswer(
+      df.groupBy("a").agg(getEveryAggColumn("b")),
+      Seq(Row(1, false), Row(2, true), Row(3, false)))
+  }
+
+  test("every null values") {
+    val df = Seq[(java.lang.Integer, java.lang.Boolean)](
+      (1, true), (1, false),
+      (2, true),
+      (3, false), (3, null),
+      (4,
null), (4, null)) + .toDF("a", "b") + checkAnswer( + df.groupBy("a").agg(getEveryAggColumn("b")), + Seq(Row(1, false), Row(2, true), Row(3, false), Row(4, null))) + } + + test("every empty table") { + val df = Seq.empty[(Int, Boolean)].toDF("a", "b") + checkAnswer( + df.agg(getEveryAggColumn("b")), + Seq(Row(null))) + } + + test("any") { + val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false)) + .toDF("a", "b") + checkAnswer( + df.groupBy("a").agg(getAnyAggColumn("b")), + Seq(Row(1, true), Row(2, true), Row(3, false))) + } + + test("any empty table") { + val df = Seq.empty[(Int, Boolean)].toDF("a", "b") + checkAnswer( + df.agg(getAnyAggColumn("b")), + Seq(Row(null))) + } + + test("any null values") { + val df = Seq[(java.lang.Integer, java.lang.Boolean)]( + (1, true), (1, false), + (2, true), + (3, true), (3, false), (3, null), + (4, null), (4, null)) + .toDF("a", "b") + checkAnswer( + df.groupBy("a").agg(getAnyAggColumn("b")), + Seq(Row(1, true), Row(2, true), Row(3, true), Row(4, null))) + } } From b793d06cb937db300c78e4eb4cd143c385419e57 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 23 Oct 2018 09:43:04 -0700 Subject: [PATCH 2/7] Code changes --- .../catalyst/analysis/FunctionRegistry.scala | 4 +++ .../sql/catalyst/expressions/Expression.scala | 26 ++++++++++++++++ .../catalyst/expressions/aggregate/Max.scala | 31 +++++++++++++++++++ .../catalyst/expressions/aggregate/Min.scala | 24 ++++++++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 4 ++- .../catalyst/optimizer/finishAnalysis.scala | 13 ++++++++ 6 files changed, 101 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 38f5c02910f79..3dfca3f0561cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -300,6 +300,10 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), + expression[EveryAgg]("every"), + expression[AnyAgg]("any"), + expression[SomeAgg]("some"), + // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c215735ab1c98..da4225465dd94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode @@ -282,6 +283,31 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable { override lazy val canonicalized: Expression = child.canonicalized } +/** + * An aggregate expression that gets rewritten (currently by the optimizer) into a + * different aggregate expression for evaluation. 
This is mainly used to provide compatibility + * with other databases. For example, we use this to support every, any/some aggregates by rewriting + * them with Min and Max respectively. + */ +trait UnevaluableAggrgate extends DeclarativeAggregate { + + override def nullable: Boolean = true + + override lazy val aggBufferAttributes = + throw new UnsupportedOperationException(s"Cannot evaluate aggBufferAttributes: $this") + + override lazy val initialValues: Seq[Expression] = + throw new UnsupportedOperationException(s"Cannot evaluate initialValues: $this") + + override lazy val updateExpressions: Seq[Expression] = + throw new UnsupportedOperationException(s"Cannot evaluate updateExpressions: $this") + + override lazy val mergeExpressions: Seq[Expression] = + throw new UnsupportedOperationException(s"Cannot evaluate mergeExpressions: $this") + + override lazy val evaluateExpression: Expression = + throw new UnsupportedOperationException(s"Cannot evaluate evaluateExpression: $this") +} /** * Expressions that don't have SQL representation should extend this trait. Examples are diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 71099eba0fc75..6815e2f75a877 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -57,3 +57,34 @@ case class Max(child: Expression) extends DeclarativeAggregate { override lazy val evaluateExpression: AttributeReference = max } + +abstract class AnyAggBase(arg: Expression) + extends UnevaluableAggrgate with ImplicitCastInputTypes { + + override def children: Seq[Expression] = arg :: Nil + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = { + arg.dataType match { + case dt if dt != BooleanType => + TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " + + s"${BooleanType.simpleString}, but it's [${arg.dataType.catalogString}].") + case _ => TypeCheckResult.TypeCheckSuccess + } + } +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.") +case class AnyAgg(arg: Expression) extends AnyAggBase(arg) { + override def nodeName: String = "Any" +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.") +case class SomeAgg(arg: Expression) extends AnyAggBase(arg) { + override def nodeName: String = "Some" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 8c4ba93231cbe..34ca861b21d9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -57,3 +57,27 @@ case class Min(child: Expression) extends DeclarativeAggregate { override lazy val evaluateExpression: AttributeReference = min } + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.") +case class EveryAgg(arg: Expression) + extends UnevaluableAggrgate with ImplicitCastInputTypes { + + override def nodeName: String = "Every" + + 
override def children: Seq[Expression] = arg :: Nil + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = { + arg.dataType match { + case dt if dt != BooleanType => + TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " + + s"${BooleanType.simpleString}, but it's [${arg.dataType.catalogString}].") + case _ => TypeCheckResult.TypeCheckSuccess + } + } +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index da8009d50b5ec..1d51cac786b90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -118,6 +118,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) ReplaceExpressions, ComputeCurrentTime, GetCurrentDatabase(sessionCatalog), + RewriteUnevaluableAggregates, RewriteDistinctAggregates, ReplaceDeduplicateWithAggregate) :: ////////////////////////////////////////////////////////////////////////////////////////// @@ -206,7 +207,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) PullupCorrelatedPredicates.ruleName :: RewriteCorrelatedScalarSubquery.ruleName :: RewritePredicateSubquery.ruleName :: - PullOutPythonUDFInJoinCondition.ruleName :: Nil + PullOutPythonUDFInJoinCondition.ruleName :: + RewriteUnevaluableAggregates.ruleName :: Nil /** * Optimize all the subqueries inside expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index af0837e36e8ad..25069b132d47c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -38,6 +39,18 @@ object ReplaceExpressions extends Rule[LogicalPlan] { } } +/** + * Rewrites the aggregates expressions by replacing them with another. This is mainly used to + * provide compatibiity with other databases. For example, we use this to support + * Every, Any/Some by rewriting them to Min, Max respectively. + */ +object RewriteUnevaluableAggregates extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case SomeAgg(arg) => Max(arg) + case AnyAgg(arg) => Max(arg) + case EveryAgg(arg) => Min(arg) + } +} /** * Computes the current date and time to make sure we return the same result in a single query. 
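
The rewrite rule added in this patch is sound because BooleanType values order with false < true: Min over the non-null values of a boolean column behaves as a logical AND (every), and Max behaves as a logical OR (any/some). A minimal standalone Scala sketch of that semantics follows; the object and method names are illustrative only, not Spark APIs, and SQL's NULL-skipping is modeled with Option:

object EveryAnySemanticsSketch {
  // every == AND over the non-null values; min(false, true) == false.
  def every(vs: Seq[Option[Boolean]]): Option[Boolean] = {
    val nonNull = vs.flatten                          // aggregates skip NULLs
    if (nonNull.isEmpty) None else Some(nonNull.min)  // None models SQL NULL
  }

  // any/some == OR over the non-null values; max(false, true) == true.
  def any(vs: Seq[Option[Boolean]]): Option[Boolean] = {
    val nonNull = vs.flatten
    if (nonNull.isEmpty) None else Some(nonNull.max)
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the k = 5 group of test_agg: (null, true, false).
    assert(every(Seq(None, Some(true), Some(false))) == Some(false))
    assert(any(Seq(None, Some(true), Some(false))) == Some(true))
    // Mirrors the k = 4 group: all NULLs yield NULL.
    assert(every(Seq(None, None)).isEmpty)
  }
}
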
From 9e194b57cb30f022392c3fb01959984b9fc17f27 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 24 Oct 2018 12:57:37 -0700 Subject: [PATCH 3/7] fix --- .../sql/catalyst/optimizer/Optimizer.scala | 4 +--- .../catalyst/optimizer/finishAnalysis.scala | 21 +++++++------------ .../resources/sql-tests/inputs/group-by.sql | 2 +- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 1d51cac786b90..da8009d50b5ec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -118,7 +118,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) ReplaceExpressions, ComputeCurrentTime, GetCurrentDatabase(sessionCatalog), - RewriteUnevaluableAggregates, RewriteDistinctAggregates, ReplaceDeduplicateWithAggregate) :: ////////////////////////////////////////////////////////////////////////////////////////// @@ -207,8 +206,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) PullupCorrelatedPredicates.ruleName :: RewriteCorrelatedScalarSubquery.ruleName :: RewritePredicateSubquery.ruleName :: - PullOutPythonUDFInJoinCondition.ruleName :: - RewriteUnevaluableAggregates.ruleName :: Nil + PullOutPythonUDFInJoinCondition.ruleName :: Nil /** * Optimize all the subqueries inside expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 25069b132d47c..5e458b4076bf0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -29,23 +29,18 @@ import org.apache.spark.sql.types._ /** - * Finds all [[RuntimeReplaceable]] expressions and replace them with the expressions that can - * be evaluated. This is mainly used to provide compatibility with other databases. - * For example, we use this to support "nvl" by replacing it with "coalesce". + * Finds all the expressions that are unevaluable and replace/rewrite them with semantically + * equivalent expressions that can be evaluated. Currently we replace two kinds of expressions : + * 1) [[RuntimeReplaceable]] expressions + * 2) [[UnevaluableAggrgate]] expressions such as Every, Some, Any + * This is mainly used to provide compatibility with other databases. + * Few examples are : + * we use this to support "nvl" by replacing it with "coalesce". + * we use this to replace Every and Any with Min and Max respectively. */ object ReplaceExpressions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case e: RuntimeReplaceable => e.child - } -} - -/** - * Rewrites the aggregates expressions by replacing them with another. This is mainly used to - * provide compatibiity with other databases. For example, we use this to support - * Every, Any/Some by rewriting them to Min, Max respectively. 
- */ -object RewriteUnevaluableAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case SomeAgg(arg) => Max(arg) case AnyAgg(arg) => Max(arg) case EveryAgg(arg) => Min(arg) diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 92d73561d27ec..ec263ea70bd4a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -137,7 +137,7 @@ SELECT any(1L); -- input type checking String SELECT every("true"); --- every/some/any aggregates are not supported as windows expression. +-- every/some/any aggregates are supported as windows expression. SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg; From 6abf844f728471b737a047a81c3f287060506473 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 24 Oct 2018 20:23:08 -0700 Subject: [PATCH 4/7] Code review --- .../spark/sql/DataFrameAggregateSuite.scala | 64 ------------------- 1 file changed, 64 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index b5ab5cf737e25..d0106c44b7db2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -21,7 +21,6 @@ import scala.util.Random import org.scalatest.Matchers.the -import org.apache.spark.sql.catalyst.expressions.aggregate.{AnyAgg, EveryAgg} import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -728,67 +727,4 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { "grouping expressions: [current_date(None)], value: [key: int, value: string], " + "type: GroupBy]")) } - - def getEveryAggColumn(columnName: String): Column = { - Column(new EveryAgg(Column(columnName).expr).toAggregateExpression(false)) - } - - def getAnyAggColumn(columnName: String): Column = { - Column(new AnyAgg(Column(columnName).expr).toAggregateExpression(false)) - } - - test("every") { - val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false)) - .toDF("a", "b") - - checkAnswer( - df.groupBy("a").agg(getEveryAggColumn("b")), - Seq(Row(1, false), Row(2, true), Row(3, false))) - } - - test("every null values") { - val df = Seq[(java.lang.Integer, java.lang.Boolean)]( - (1, true), (1, false), - (2, true), - (3, false), (3, null), - (4, null), (4, null)) - .toDF("a", "b") - checkAnswer( - df.groupBy("a").agg(getEveryAggColumn("b")), - Seq(Row(1, false), Row(2, true), Row(3, false), Row(4, null))) - } - - test("every empty table") { - val df = Seq.empty[(Int, Boolean)].toDF("a", "b") - checkAnswer( - df.agg(getEveryAggColumn("b")), - Seq(Row(null))) - } - - test("any") { - val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false)) - .toDF("a", "b") - checkAnswer( - df.groupBy("a").agg(getAnyAggColumn("b")), - Seq(Row(1, true), Row(2, true), Row(3, false))) - } - - test("any empty table") { - val df = Seq.empty[(Int, Boolean)].toDF("a", "b") - checkAnswer( - 
df.agg(getAnyAggColumn("b")), - Seq(Row(null))) - } - - test("any null values") { - val df = Seq[(java.lang.Integer, java.lang.Boolean)]( - (1, true), (1, false), - (2, true), - (3, true), (3, false), (3, null), - (4, null), (4, null)) - .toDF("a", "b") - checkAnswer( - df.groupBy("a").agg(getAnyAggColumn("b")), - Seq(Row(1, true), Row(2, true), Row(3, true), Row(4, null))) - } } From 08999f98a3af6c7a30c545cec1c3657498fb39c0 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 25 Oct 2018 10:22:06 -0700 Subject: [PATCH 5/7] Code review --- .../catalyst/analysis/FunctionRegistry.scala | 1 - .../sql/catalyst/expressions/Expression.scala | 2 +- .../catalyst/expressions/aggregate/Max.scala | 31 ---------- .../catalyst/expressions/aggregate/Min.scala | 24 ------- .../aggregate/UnevaluableAggs.scala | 62 +++++++++++++++++++ .../catalyst/optimizer/finishAnalysis.scala | 5 +- .../ExpressionTypeCheckingSuite.scala | 1 + 7 files changed, 67 insertions(+), 59 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 3dfca3f0561cd..af6166bcb8692 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -304,7 +304,6 @@ object FunctionRegistry { expression[AnyAgg]("any"), expression[SomeAgg]("some"), - // string functions expression[Ascii]("ascii"), expression[Chr]("char"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index da4225465dd94..ccc5b9043a0aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -289,7 +289,7 @@ trait RuntimeReplaceable extends UnaryExpression with Unevaluable { * with other databases. For example, we use this to support every, any/some aggregates by rewriting * them with Min and Max respectively. 
*/ -trait UnevaluableAggrgate extends DeclarativeAggregate { +trait UnevaluableAggregate extends DeclarativeAggregate { override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 6815e2f75a877..71099eba0fc75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -57,34 +57,3 @@ case class Max(child: Expression) extends DeclarativeAggregate { override lazy val evaluateExpression: AttributeReference = max } - -abstract class AnyAggBase(arg: Expression) - extends UnevaluableAggrgate with ImplicitCastInputTypes { - - override def children: Seq[Expression] = arg :: Nil - - override def dataType: DataType = BooleanType - - override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) - - override def checkInputDataTypes(): TypeCheckResult = { - arg.dataType match { - case dt if dt != BooleanType => - TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " + - s"${BooleanType.simpleString}, but it's [${arg.dataType.catalogString}].") - case _ => TypeCheckResult.TypeCheckSuccess - } - } -} - -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.") -case class AnyAgg(arg: Expression) extends AnyAggBase(arg) { - override def nodeName: String = "Any" -} - -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.") -case class SomeAgg(arg: Expression) extends AnyAggBase(arg) { - override def nodeName: String = "Some" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 34ca861b21d9e..8c4ba93231cbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -57,27 +57,3 @@ case class Min(child: Expression) extends DeclarativeAggregate { override lazy val evaluateExpression: AttributeReference = min } - -@ExpressionDescription( - usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.") -case class EveryAgg(arg: Expression) - extends UnevaluableAggrgate with ImplicitCastInputTypes { - - override def nodeName: String = "Every" - - override def children: Seq[Expression] = arg :: Nil - - override def dataType: DataType = BooleanType - - override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) - - override def checkInputDataTypes(): TypeCheckResult = { - arg.dataType match { - case dt if dt != BooleanType => - TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " + - s"${BooleanType.simpleString}, but it's [${arg.dataType.catalogString}].") - case _ => TypeCheckResult.TypeCheckSuccess - } - } -} - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala new file mode 100644 index 0000000000000..fc33ef919498b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/UnevaluableAggs.scala @@ -0,0 +1,62 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +abstract class UnevaluableBooleanAggBase(arg: Expression) + extends UnevaluableAggregate with ImplicitCastInputTypes { + + override def children: Seq[Expression] = arg :: Nil + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = { + arg.dataType match { + case dt if dt != BooleanType => + TypeCheckResult.TypeCheckFailure(s"Input to function '$prettyName' should have been " + + s"${BooleanType.simpleString}, but it's [${arg.dataType.catalogString}].") + case _ => TypeCheckResult.TypeCheckSuccess + } + } +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.", + since = "3.0.0") +case class EveryAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = "Every" +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.", + since = "3.0.0") +case class AnyAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = "Any" +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.", + since = "3.0.0") +case class SomeAgg(arg: Expression) extends UnevaluableBooleanAggBase(arg) { + override def nodeName: String = "Some" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala index 5e458b4076bf0..630513ede123a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala @@ -32,9 +32,9 @@ import org.apache.spark.sql.types._ * Finds all the expressions that are unevaluable and replace/rewrite them with semantically * equivalent expressions that can be evaluated. Currently we replace two kinds of expressions : * 1) [[RuntimeReplaceable]] expressions - * 2) [[UnevaluableAggrgate]] expressions such as Every, Some, Any + * 2) [[UnevaluableAggregate]] expressions such as Every, Some, Any * This is mainly used to provide compatibility with other databases. - * Few examples are : + * Few examples are: * we use this to support "nvl" by replacing it with "coalesce". * we use this to replace Every and Any with Min and Max respectively. 
 */
@@ -47,6 +47,7 @@ object ReplaceExpressions extends Rule[LogicalPlan] {
   }
 }
 
+
 /**
  * Computes the current date and time to make sure we return the same result in a single query.
  */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index b8643d0243f59..3eb3fe66cebc5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -146,6 +146,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertSuccess(Min('arrayField))
     assertSuccess(new EveryAgg('booleanField))
     assertSuccess(new AnyAgg('booleanField))
+    assertSuccess(new SomeAgg('booleanField))
 
     assertError(Min('mapField), "min does not support ordering on type")
     assertError(Max('mapField), "max does not support ordering on type")

From 2bc996515ec1947b8a1b82f942bd6ebecc473277 Mon Sep 17 00:00:00 2001
From: Dilip Biswal
Date: Thu, 25 Oct 2018 10:34:07 -0700
Subject: [PATCH 6/7] fix

---
 .../apache/spark/sql/catalyst/optimizer/finishAnalysis.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index 630513ede123a..13081d4dbc48a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types._
 
 /**
  * Finds all the expressions that are unevaluable and replace/rewrite them with semantically
- * equivalent expressions that can be evaluated. Currently we replace two kinds of expressions :
+ * equivalent expressions that can be evaluated. Currently we replace two kinds of expressions:
  * 1) [[RuntimeReplaceable]] expressions
  * 2) [[UnevaluableAggregate]] expressions such as Every, Some, Any
 * This is mainly used to provide compatibility with other databases.

From 07205dea343539cb812622205fd0534b77f183d0 Mon Sep 17 00:00:00 2001
From: Dilip Biswal
Date: Sat, 27 Oct 2018 13:31:20 -0700
Subject: [PATCH 7/7] Added todo

---
 .../apache/spark/sql/catalyst/optimizer/finishAnalysis.scala | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
index 13081d4dbc48a..fe196ec7c9d54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala
@@ -37,6 +37,9 @@ import org.apache.spark.sql.types._
  * Few examples are:
  * we use this to support "nvl" by replacing it with "coalesce".
  * we use this to replace Every and Any with Min and Max respectively.
+ *
+ * TODO: In the future, explore an option to replace aggregate functions similar to
+ * how RuntimeReplaceable is handled.
  */
 object ReplaceExpressions extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
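
As an end-to-end sanity check of what this series implements, here is a hedged sketch, assuming a local Spark build that includes these patches; the object and view names are made up for illustration:

import org.apache.spark.sql.SparkSession

object EveryAnySomeSmokeTest {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("every-any-some").getOrCreate()
    import spark.implicits._

    Seq((1, true), (1, false), (2, true)).toDF("k", "v").createOrReplaceTempView("t")

    // ReplaceExpressions rewrites every -> min and any/some -> max before
    // execution, so each aggregate must agree with its rewrite target.
    spark.sql("SELECT k, every(v), min(v), any(v), some(v), max(v) FROM t GROUP BY k ORDER BY k").show()
    // Expected: k=1 -> every/min = false, any/some/max = true
    //           k=2 -> all columns = true

    spark.stop()
  }
}

Querying min(v) and max(v) directly is legal here because BooleanType is an ordered atomic type, which is exactly the property the rewrite exploits.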