diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 2c8c8e2d80f09..eb8d39a77c864 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -169,6 +169,12 @@ def _(): 'measured in radians.', } +_functions_2_2 = { + 'to_date': 'Converts a string date into a DateType using the (optionally) specified format.', + 'to_timestamp': 'Converts a string timestamp into a timestamp type using the ' + + '(optionally) specified format.', +} + # math functions that take two arguments as input _binary_mathfunctions = { 'atan2': 'Returns the angle theta from the conversion of rectangular coordinates (x, y) to' + @@ -350,6 +356,28 @@ def countDistinct(col, *cols): return Column(jc) +def every(col): + """Aggregate function: returns true if all values in a group are true. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.every(_to_java_column(col)) + return Column(jc) + + +def any(col): + """Aggregate function: returns true if at least one value in the group is true. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.any(_to_java_column(col)) + return Column(jc) + + +def some(col): + """Aggregate function: returns true if at least one value in the group is true. + """ + return any(col) + + @since(1.3) def first(col, ignorenulls=False): """Aggregate function: returns the first value in a group. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index be5495ca019a2..2cfa8349af230 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1158,6 +1158,24 @@ def test_cov(self): cov = df.stat.cov("a", "b") self.assertTrue(abs(cov - 55.0 / 3) < 1e-6) + def test_every_any(self): + from pyspark.sql import functions + data = [ + Row(key="a", value=False), + Row(key="a", value=True), + Row(key="a", value=False), + Row(key="b", value=True), + Row(key="b", value=True), + Row(key="c", value=False), + Row(key="d", value=True), + Row(key="d", value=None) + ] + df = self.sc.parallelize(data).toDF() + df2 = df.select(functions.every(df.value).alias('a'), + functions.any(df.value).alias('b'), + functions.some(df.value).alias('c')) + self.assertEqual([Row(a=False, b=True, c=True)], df2.collect()) + def test_crosstab(self): df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF() ct = df.stat.crosstab("a", "b").collect() @@ -2631,6 +2649,75 @@ def test_window_functions_cumulative_sum(self): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + def test_window_functions_every_any(self): + df = self.spark.createDataFrame([ + ("a", False), + ("a", True), + ("a", False), + ("b", True), + ("b", True), + ("c", False), + ("d", True), + ("d", None) + ], ["key", "value"]) + w = Window \ + .partitionBy("key").orderBy("value") \ + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + from pyspark.sql import functions as F + sel = df.select(df.key, + df.value, + F.every("value").over(w), + F.any("value").over(w), + F.some("value").over(w)) + rs = sel.collect() + expected = [ + ("a", False, False, True, True), + ("a", False, False, True, True), + ("a", True, False, True, True), + ("b", True, True, True, True), + ("b", True, True, True, True), + ("c", False, False, False, False), + ("d", None, False, True, True), + ("d", True, False, True, True) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + def test_window_functions_every_any_without_partitionBy(self): + df = self.spark.createDataFrame([ + (False,), + 
(True,), + (False,), + (True,), + (True,), + (False,), + (True,), + (None,) + ], ["value"]) + w1 = Window.orderBy("value").rowsBetween(Window.unboundedPreceding, 0) + w2 = Window.orderBy("value").rowsBetween(-1, 0) + from pyspark.sql import functions as F + sel = df.select(df.value, + F.every("value").over(w1), + F.any("value").over(w1), + F.some("value").over(w1), + F.every("value").over(w2), + F.any("value").over(w2), + F.some("value").over(w2)) + rs = sel.collect() + expected = [ + (None, False, False, False, False, False, False), + (False, False, False, False, False, False, False), + (False, False, False, False, False, False, False), + (False, False, False, False, False, False, False), + (True, False, True, True, False, True, True), + (True, False, True, True, True, True, True), + (True, False, True, True, True, True, True), + (True, False, True, True, True, True, True) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + def test_collect_functions(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) from pyspark.sql import functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 10b22ae562bcf..03070fc16bfa8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -299,6 +299,9 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), + expression[Every]("every"), + expression[AnyAgg]("any"), + expression[AnyAgg]("some"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala new file mode 100644 index 0000000000000..bf599fb23f4a2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.") +case class AnyAgg(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForBooleanExpr(child.dataType, "function any") + + private lazy val some = AttributeReference("some", BooleanType)() + + private lazy val emptySet = AttributeReference("emptySet", BooleanType)() + + override lazy val aggBufferAttributes = some :: emptySet :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( + Literal(false), + Literal(true) + ) + + override lazy val updateExpressions: Seq[Expression] = Seq( + Or(some, Coalesce(Seq(child, Literal(false)))), + Literal(false) + ) + + override lazy val mergeExpressions: Seq[Expression] = Seq( + Or(some.left, some.right), + And(emptySet.left, emptySet.right) + ) + + override lazy val evaluateExpression: Expression = And(!emptySet, some) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala new file mode 100644 index 0000000000000..508fc3dae4939 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.") +case class Every(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForBooleanExpr(child.dataType, "function every") + + private lazy val every = AttributeReference("every", BooleanType)() + + private lazy val emptySet = AttributeReference("emptySet", BooleanType)() + + override lazy val aggBufferAttributes = every :: emptySet :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( + Literal(true), + Literal(true) + ) + + override lazy val updateExpressions: Seq[Expression] = Seq( + And(every, Coalesce(Seq(child, Literal(false)))), + Literal(false) + ) + + override lazy val mergeExpressions: Seq[Expression] = Seq( + And(every.left, every.right), + And(emptySet.left, emptySet.right) + ) + + override lazy val evaluateExpression: Expression = And(!emptySet, every) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 45225779bffcb..ecad25406b968 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -33,6 +33,14 @@ object TypeUtils { } } + def checkForBooleanExpr(dt: DataType, caller: String): TypeCheckResult = { + if (dt.isInstanceOf[BooleanType] || dt == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller requires boolean types, not $dt") + } + } + def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = { if (RowOrdering.isOrderable(dt)) { TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 30725773a37b1..0cd781855b6be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -143,6 +143,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(Sum('stringField)) assertSuccess(Average('stringField)) assertSuccess(Min('arrayField)) + assertSuccess(Every('booleanField)) + assertSuccess(AnyAgg('booleanField)) assertError(Min('mapField), "min does not support ordering on type") assertError(Max('mapField), "max does not support ordering on type") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala new file mode 100644 index 0000000000000..b0f5962ac27b2 --- /dev/null +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} +import org.apache.spark.sql.types.BooleanType + +class AnyTestSuite extends SparkFunSuite { + val input = AttributeReference("input", BooleanType, nullable = true)() + val evaluator = DeclarativeAggregateEvaluator(AnyAgg(input), Seq(input)) + + test("empty buffer") { + assert(evaluator.initialize() === InternalRow(false, true)) + } + + test("update") { + val result = evaluator.update( + InternalRow(true), + InternalRow(false), + InternalRow(true)) + assert(result === InternalRow(true, false)) + } + + test("merge") { + // Empty merge + val p0 = evaluator.initialize() + assert(evaluator.merge(p0) === InternalRow(false, true)) + + // Single merge + val p1 = evaluator.update(InternalRow(true), InternalRow(true)) + assert(evaluator.merge(p1) === InternalRow(true, false)) + + // Multiple merges. + val p2 = evaluator.update(InternalRow(false), InternalRow(null)) + assert(evaluator.merge(p1, p2) === InternalRow(true, false)) + + // Empty partitions (p0 is empty) + assert(evaluator.merge(p0, p2) === InternalRow(false, false)) + assert(evaluator.merge(p2, p1, p0) === InternalRow(true, false)) + } + + test("eval") { + // Null Eval + assert(evaluator.eval(InternalRow(null, true)) === InternalRow(false)) + + // Empty Eval + val p0 = evaluator.initialize() + assert(evaluator.eval(p0) === InternalRow(false)) + + // Update - Eval + val p1 = evaluator.update(InternalRow(true), InternalRow(null)) + assert(evaluator.eval(p1) === InternalRow(true)) + + // Update - Merge - Eval + val p2 = evaluator.update(InternalRow(false), InternalRow(false)) + val m1 = evaluator.merge(p0, p2) + assert(evaluator.eval(m1) === InternalRow(false)) + + // Update - Merge - Eval (empty partition at the end) + val m2 = evaluator.merge(p2, p1, p0) + assert(evaluator.eval(m2) === InternalRow(true)) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala new file mode 100644 index 0000000000000..c01fcb11fee78 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} +import org.apache.spark.sql.types.BooleanType + +class EveryTestSuite extends SparkFunSuite { + val input = AttributeReference("input", BooleanType, nullable = true)() + val evaluator = DeclarativeAggregateEvaluator(Every(input), Seq(input)) + + test("empty buffer") { + assert(evaluator.initialize() === InternalRow(true, true)) + } + + test("update") { + val result = evaluator.update( + InternalRow(true), + InternalRow(false), + InternalRow(true)) + assert(result === InternalRow(false, false)) + } + + test("merge") { + // Empty merge + val p0 = evaluator.initialize() + assert(evaluator.merge(p0) === InternalRow(true, true)) + + // Single merge + val p1 = evaluator.update(InternalRow(true), InternalRow(true)) + assert(evaluator.merge(p1) === InternalRow(true, false)) + + // Multiple merges. + val p2 = evaluator.update(InternalRow(true), InternalRow(null)) + assert(evaluator.merge(p1, p2) === InternalRow(false, false)) + + // Empty partitions (p0 is empty) + assert(evaluator.merge(p1, p0, p2) === InternalRow(false, false)) + assert(evaluator.merge(p2, p1, p0) === InternalRow(false, false)) + } + + test("eval") { + // Null Eval + assert(evaluator.eval(InternalRow(null, true)) === InternalRow(false)) + + // Empty Eval + val p0 = evaluator.initialize() + assert(evaluator.eval(p0) === InternalRow(false)) + + // Update - Eval + val p1 = evaluator.update(InternalRow(true), InternalRow(true)) + assert(evaluator.eval(p1) === InternalRow(true)) + + // Update - Merge - Eval + val p2 = evaluator.update(InternalRow(false), InternalRow(true)) + val m1 = evaluator.merge(p1, p0, p2) + assert(evaluator.eval(m1) === InternalRow(false)) + + // Update - Merge - Eval (empty partition at the end) + val m2 = evaluator.merge(p2, p1, p0) + assert(evaluator.eval(m2) === InternalRow(false)) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b825b6cd6160f..1c95adc40dd87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -226,6 +226,21 @@ class Dataset[T] private[sql]( } } + private[sql] def booleanColumns: Seq[Expression] = { + schema.fields.filter(_.dataType.isInstanceOf[BooleanType]).map { n => + queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get + } + } + + private def aggregatableColumns: Seq[Expression] = { + schema.fields + .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType]) + .map { n => + 
queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver)
+          .get
+      }
+  }
+
   /**
    * Compose the string representing rows for output
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 147b549964913..6ad10b22ab136 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -31,8 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR,
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
 import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.NumericType
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{BooleanType, NumericType, StructType}
 
 /**
  * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]],
@@ -89,7 +88,7 @@ class RelationalGroupedDataset protected[sql](
   }
 
   private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
-    : DataFrame = {
+      : DataFrame = {
 
     val columnExprs = if (colNames.isEmpty) {
       // No columns specified. Use all numeric columns.
@@ -101,7 +100,28 @@ class RelationalGroupedDataset protected[sql](
         if (!namedExpr.dataType.isInstanceOf[NumericType]) {
           throw new AnalysisException(
             s""""$colName" is not a numeric column. """ +
-            "Aggregation function can only be applied on a numeric column.")
+              "Aggregation function can only be applied on a numeric column.")
         }
         namedExpr
       }
     }
+    toDF(columnExprs.map(expr => f(expr).toAggregateExpression()))
+  }
+
+  private[this] def aggregateBooleanColumns(colNames: String*)(f: Expression => AggregateFunction)
+      : DataFrame = {
+
+    val columnExprs = if (colNames.isEmpty) {
+      // No columns specified. Use all boolean columns.
+      df.booleanColumns
+    } else {
+      // Make sure all specified columns are boolean.
+      colNames.map { colName =>
+        val namedExpr = df.resolve(colName)
+        if (!namedExpr.dataType.isInstanceOf[BooleanType]) {
+          throw new AnalysisException(
+            s""""$colName" is not a boolean column. """ +
+              "Aggregation function can only be applied on a boolean column.")
         }
         namedExpr
       }
     }
@@ -298,9 +318,45 @@ class RelationalGroupedDataset protected[sql](
   }
 
   /**
-   * Pivots a column of the current `DataFrame` and performs the specified aggregation.
+   * Compute the logical and of all boolean columns for each group.
+   * The resulting `DataFrame` will also contain the grouping columns.
+   * When specified columns are given, only compute the logical and for them.
+   *
+   * @since 2.2.0
+   */
+  @scala.annotation.varargs
+  def every(colNames: String*): DataFrame = {
+    aggregateBooleanColumns(colNames : _*)(Every)
+  }
+
+  /**
+   * Compute the logical or of all boolean columns for each group.
+   * The resulting `DataFrame` will also contain the grouping columns.
+   * When specified columns are given, only compute the logical or for them.
+   *
+   * @since 2.2.0
+   */
+  @scala.annotation.varargs
+  def any(colNames: String*): DataFrame = {
+    aggregateBooleanColumns(colNames : _*)(AnyAgg)
+  }
+
+  /**
+   * Compute the logical or of all boolean columns for each group.
+   * The resulting `DataFrame` will also contain the grouping columns.
+   * When specified columns are given, only compute the logical or for them.
   *
-   * There are two versions of `pivot` function: one that requires the caller to specify the list
+   * @since 2.2.0
+   */
+  @scala.annotation.varargs
+  def some(colNames: String*): DataFrame = {
+    aggregateBooleanColumns(colNames : _*)(AnyAgg)
+  }
+
+  /**
+   * Pivots a column of the current `DataFrame` and performs the specified aggregation.
+   *
+   * There are two versions of `pivot` function: one that requires the caller to specify the list
    * of distinct values to pivot on, and one that does not. The latter is more concise but less
    * efficient, because Spark needs to first compute the list of distinct values internally.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index ebdeb42b0bfb1..a654e1250fd23 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -774,6 +774,54 @@ object functions {
    */
   def var_pop(columnName: String): Column = var_pop(Column(columnName))
 
+  /**
+   * Aggregate function: returns true if all values in the expression are true.
+   *
+   * @group agg_funcs
+   * @since 2.2.0
+   */
+  def every(e: Column): Column = withAggregateFunction { Every(e.expr) }
+
+  /**
+   * Aggregate function: returns true if all values in the expression are true.
+   *
+   * @group agg_funcs
+   * @since 2.2.0
+   */
+  def every(columnName: String): Column = every(Column(columnName))
+
+  /**
+   * Aggregate function: returns true if at least one value in the expression is true.
+   *
+   * @group agg_funcs
+   * @since 2.2.0
+   */
+  def any(e: Column): Column = withAggregateFunction { AnyAgg(e.expr) }
+
+  /**
+   * Aggregate function: returns true if at least one value in the expression is true.
+   *
+   * @group agg_funcs
+   * @since 2.2.0
+   */
+  def any(columnName: String): Column = any(Column(columnName))
+
+  /**
+   * Aggregate function: returns true if at least one value in the expression is true.
+   *
+   * @group agg_funcs
+   * @since 2.2.0
+   */
+  def some(e: Column): Column = any(e)
+
+  /**
+   * Aggregate function: returns true if at least one value in the expression is true.
+   *
+   * @group agg_funcs
+   * @since 2.2.0
+   */
+  def some(columnName: String): Column = any(Column(columnName))
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Window functions
   //////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 4568b67024acb..2aaf04dfba4f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -557,4 +557,43 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
     }
     assert(e.message.contains("aggregate functions are not allowed in GROUP BY"))
   }
+
+  test("every") {
+    val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false))
+      .toDF("a", "b")
+    checkAnswer(
+      df.groupBy("a").agg(every('b)),
+      Seq(Row(1, false), Row(2, true), Row(3, false)))
+  }
+
+  test("every empty table") {
+    val df = Seq.empty[(Int, Boolean)].toDF("a", "b")
+    checkAnswer(
+      df.agg(every('b)),
+      Seq(Row(false)))
+  }
+
+  test("any") {
+    val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false))
+      .toDF("a", "b")
+    checkAnswer(
+      df.groupBy("a").agg(any('b)),
+      Seq(Row(1, true), Row(2, true), Row(3, false)))
+  }
+
+  test("any empty table") {
+    val df = Seq.empty[(Int, Boolean)].toDF("a", "b")
+    checkAnswer(
+      df.agg(any('b)),
+      Seq(Row(false)))
+  }
+
+  test("some") {
+    val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false))
+      .toDF("a", "b")
+    checkAnswer(
+      df.groupBy("a").agg(some('b)),
+      Seq(Row(1, true), Row(2, true), Row(3, false)))
+  }
+
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 204858fa29787..818310a762678 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -468,4 +468,37 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
       spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () }
     }
   }
+
+  test("every/any/some") {
+    val df = Seq[(String, java.lang.Boolean)](
+      ("a", false),
+      ("a", true),
+      ("a", false),
+      ("b", true),
+      ("b", true),
+      ("c", false),
+      ("d", true),
+      ("d", null)
+    ).toDF("key", "value")
+    val window = Window.partitionBy($"key").orderBy($"value")
+      .rowsBetween(Long.MinValue, Long.MaxValue)
+    checkAnswer(
+      df.select(
+        $"key",
+        $"value",
+        every($"value").over(window),
+        any($"value").over(window),
+        some($"value").over(window)),
+      Seq(
+        Row("a", false, false, true, true),
+        Row("a", false, false, true, true),
+        Row("a", true, false, true, true),
+        Row("b", true, true, true, true),
+        Row("b", true, true, true, true),
+        Row("c", false, false, false, false),
+        Row("d", true, false, true, true),
+        Row("d", null, false, true, true)
+      )
+    )
+  }
 }
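
Usage sketch (not part of the patch): a quick end-to-end illustration of the three surfaces this change adds — the `functions.every/any/some` columns, the `RelationalGroupedDataset` shorthands, and the SQL registrations. The SparkSession setup, the object name, and the `k`/`v` column names below are assumptions made for the example, not code from this PR.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{any, every, some}

    object EveryAnyExample {
      def main(args: Array[String]): Unit = {
        // Assumes a local session; names are illustrative only.
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("every-any-sketch")
          .getOrCreate()
        import spark.implicits._

        val df = Seq(("a", true), ("a", false), ("b", true), ("b", true)).toDF("k", "v")

        // Column-based API added to org.apache.spark.sql.functions.
        df.groupBy("k").agg(every($"v"), any($"v"), some($"v")).show()

        // RelationalGroupedDataset shorthands added by this patch.
        df.groupBy("k").every("v").show()
        df.groupBy("k").any("v").show()

        // SQL path, available through the new FunctionRegistry entries.
        df.createOrReplaceTempView("t")
        spark.sql("SELECT k, every(v), any(v), some(v) FROM t GROUP BY k").show()

        spark.stop()
      }
    }

Note that, because of the `emptySet` flag in `Every`/`AnyAgg`, an empty group (or one containing only nulls) evaluates to false rather than NULL, which is what the `every empty table` and `any empty table` tests above assert.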