diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3128d5792eead..2e5a9b79e89c2 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -403,6 +403,28 @@ def countDistinct(col, *cols): return Column(jc) +def every(col): + """Aggregate function: returns true if all values in a group are true. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.every(_to_java_column(col)) + return Column(jc) + + +def any(col): + """Aggregate function: returns true if at least one value in the group is true. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.any(_to_java_column(col)) + return Column(jc) + + +def some(col): + """Aggregate function: returns true if at least one value in the group is true. + """ + return any(col) + + @since(1.3) def first(col, ignorenulls=False): """Aggregate function: returns the first value in a group. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 815772d23ceea..6c07013596ffa 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1532,6 +1532,24 @@ def test_cov(self): cov = df.stat.cov(u"a", "b") self.assertTrue(abs(cov - 55.0 / 3) < 1e-6) + def test_every_any(self): + from pyspark.sql import functions + data = [ + Row(key="a", value=False), + Row(key="a", value=True), + Row(key="a", value=False), + Row(key="b", value=True), + Row(key="b", value=True), + Row(key="c", value=False), + Row(key="d", value=True), + Row(key="d", value=None) + ] + df = self.sc.parallelize(data).toDF() + df2 = df.select(functions.every(df.value).alias('a'), + functions.any(df.value).alias('b'), + functions.some(df.value).alias('c')) + self.assertEqual([Row(a=False, b=True, c=True)], df2.collect()) + def test_crosstab(self): df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF() ct = df.stat.crosstab(u"a", "b").collect() @@ -3938,6 +3956,75 @@ def test_window_functions_cumulative_sum(self): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + def test_window_functions_every_any(self): + df = self.spark.createDataFrame([ + ("a", False), + ("a", True), + ("a", False), + ("b", True), + ("b", True), + ("c", False), + ("d", True), + ("d", None) + ], ["key", "value"]) + w = Window \ + .partitionBy("key").orderBy("value") \ + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + from pyspark.sql import functions as F + sel = df.select(df.key, + df.value, + F.every("value").over(w), + F.any("value").over(w), + F.some("value").over(w)) + rs = sel.collect() + expected = [ + ("a", False, False, True, True), + ("a", False, False, True, True), + ("a", True, False, True, True), + ("b", True, True, True, True), + ("b", True, True, True, True), + ("c", False, False, False, False), + ("d", None, True, True, True), + ("d", True, True, True, True) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + + def test_window_functions_every_any_without_partitionBy(self): + df = self.spark.createDataFrame([ + (False,), + (True,), + (False,), + (True,), + (True,), + (False,), + (True,), + (None,) + ], ["value"]) + w1 = Window.orderBy("value").rowsBetween(Window.unboundedPreceding, 0) + w2 = Window.orderBy("value").rowsBetween(-1, 0) + from pyspark.sql import functions as F + sel = df.select(df.value, + F.every("value").over(w1), + F.any("value").over(w1), + F.some("value").over(w1), + F.every("value").over(w2), + F.any("value").over(w2), + F.some("value").over(w2)) + rs = sel.collect() 
+ expected = [ + (None, None, None, None, None, None, None), + (False, False, False, False, False, False, False), + (False, False, False, False, False, False, False), + (False, False, False, False, False, False, False), + (True, False, True, True, False, True, True), + (True, False, True, True, True, True, True), + (True, False, True, True, True, True, True), + (True, False, True, True, True, True, True) + ] + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) + def test_collect_functions(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) from pyspark.sql import functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7dafebff79874..ac9ac7bbc2f66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -300,6 +300,9 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), + expression[Every]("every"), + expression[AnyAgg]("any"), + expression[AnyAgg]("some"), // string functions expression[Ascii]("ascii"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala new file mode 100644 index 0000000000000..1eac1e4ca3ffc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.") +case class AnyAgg(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForBooleanExpr(child.dataType, "function any") + + private lazy val some = AttributeReference("some", BooleanType)() + + private lazy val valueSet = AttributeReference("valueSet", BooleanType)() + + override lazy val aggBufferAttributes = some :: valueSet :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( + /* some = */ Literal.create(false, BooleanType), + /* valueSet = */ Literal.create(false, BooleanType) + ) + + override lazy val updateExpressions: Seq[Expression] = Seq( + /* some = */ Or(some, If (child.isNull, some, child)), + /* valueSet = */ valueSet || child.isNotNull + ) + + override lazy val mergeExpressions: Seq[Expression] = Seq( + /* some = */ Or(some.left, some.right), + /* valueSet */ valueSet.right || valueSet.left + ) + + override lazy val evaluateExpression: Expression = + If (valueSet, some, Literal.create(null, BooleanType)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala new file mode 100644 index 0000000000000..ad0301a3200e8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.") +case class Every(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForBooleanExpr(child.dataType, "function every") + + private lazy val every = AttributeReference("every", BooleanType)() + + private lazy val valueSet = AttributeReference("valueSet", BooleanType)() + + override lazy val aggBufferAttributes = every :: valueSet :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( + /* every = */ Literal.create(true, BooleanType), + /* valueSet = */ Literal.create(false, BooleanType) + ) + + override lazy val updateExpressions: Seq[Expression] = Seq( + /* every = */ And(every, If (child.isNull, every, child)), + /* valueSet = */ valueSet || child.isNotNull + ) + + override lazy val mergeExpressions: Seq[Expression] = Seq( + /* every = */ And(every.left, every.right), + /* valueSet */ valueSet.right || valueSet.left + ) + + override lazy val evaluateExpression: Expression = + If (valueSet, every, Literal.create(null, BooleanType)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 76218b459ef0d..9b70162b6e8b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -33,6 +33,14 @@ object TypeUtils { } } + def checkForBooleanExpr(dt: DataType, caller: String): TypeCheckResult = { + if (dt.isInstanceOf[BooleanType] || dt == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller requires boolean types, not $dt") + } + } + def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = { if (RowOrdering.isOrderable(dt)) { TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8eec14842c7e7..5c4b341b73e02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -144,6 +144,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(Sum('stringField)) assertSuccess(Average('stringField)) assertSuccess(Min('arrayField)) + assertSuccess(Every('booleanField)) + assertSuccess(AnyAgg('booleanField)) assertError(Min('mapField), "min does not support ordering on type") assertError(Max('mapField), "max does not support ordering on type") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala new file mode 100644 index 0000000000000..2f389e608113a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} +import org.apache.spark.sql.types.BooleanType + +class AnyTestSuite extends SparkFunSuite { + val input = AttributeReference("input", BooleanType, nullable = true)() + val evaluator = DeclarativeAggregateEvaluator(AnyAgg(input), Seq(input)) + + test("empty buffer") { + assert(evaluator.initialize() === InternalRow(false, false)) + } + + test("update") { + val result = evaluator.update( + InternalRow(true), + InternalRow(false), + InternalRow(true)) + assert(result === InternalRow(true, true)) + } + + test("merge") { + // Empty merge + val p0 = evaluator.initialize() + assert(evaluator.merge(p0) === InternalRow(false, false)) + + // Single merge + val p1 = evaluator.update(InternalRow(true), InternalRow(true)) + assert(evaluator.merge(p1) === InternalRow(true, true)) + + // Multiple merges. 
+ val p2 = evaluator.update(InternalRow(false), InternalRow(null)) + assert(evaluator.merge(p1, p2) === InternalRow(true, true)) + + // Empty partitions (p0 is empty) + assert(evaluator.merge(p0, p2) === InternalRow(false, true)) + assert(evaluator.merge(p2, p1, p0) === InternalRow(true, true)) + } + + test("eval") { + // Null Eval + assert(evaluator.eval(InternalRow(true, false)) === InternalRow(null)) + assert(evaluator.eval(InternalRow(false, false)) === InternalRow(null)) + + // Empty Eval + val p0 = evaluator.initialize() + assert(evaluator.eval(p0) === InternalRow(null)) + + // Update - Eval + val p1 = evaluator.update(InternalRow(true), InternalRow(null)) + assert(evaluator.eval(p1) === InternalRow(true)) + + // Update - Merge - Eval + val p2 = evaluator.update(InternalRow(false), InternalRow(false)) + val m1 = evaluator.merge(p0, p2) + assert(evaluator.eval(m1) === InternalRow(false)) + + // Update - Merge - Eval (empty partition at the end) + val m2 = evaluator.merge(p2, p1, p0) + assert(evaluator.eval(m2) === InternalRow(true)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala new file mode 100644 index 0000000000000..109eed85e1208 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} +import org.apache.spark.sql.types.BooleanType + +class EveryTestSuite extends SparkFunSuite { + val input = AttributeReference("input", BooleanType, nullable = true)() + val evaluator = DeclarativeAggregateEvaluator(Every(input), Seq(input)) + + test("empty buffer") { + assert(evaluator.initialize() === InternalRow(true, false)) + } + + test("update") { + val result = evaluator.update( + InternalRow(true), + InternalRow(false), + InternalRow(true)) + assert(result === InternalRow(false, true)) + } + + test("merge") { + // Empty merge + val p0 = evaluator.initialize() + assert(evaluator.merge(p0) === InternalRow(true, false)) + + // Single merge + val p1 = evaluator.update(InternalRow(true), InternalRow(true)) + assert(evaluator.merge(p1) === InternalRow(true, true)) + + // Multiple merges. 
+ val p2 = evaluator.update(InternalRow(true), InternalRow(null)) + assert(evaluator.merge(p1, p2) === InternalRow(true, true)) + + // Empty partitions (p0 is empty) + assert(evaluator.merge(p1, p0, p2) === InternalRow(true, true)) + assert(evaluator.merge(p2, p1, p0) === InternalRow(true, true)) + } + + test("eval") { + // Null Eval + assert(evaluator.eval(InternalRow(true, false)) === InternalRow(null)) + assert(evaluator.eval(InternalRow(false, false)) === InternalRow(null)) + + // Empty Eval + val p0 = evaluator.initialize() + assert(evaluator.eval(p0) === InternalRow(null)) + + // Update - Eval + val p1 = evaluator.update(InternalRow(true), InternalRow(true)) + assert(evaluator.eval(p1) === InternalRow(true)) + + // Update - Merge - Eval + val p2 = evaluator.update(InternalRow(false), InternalRow(true)) + val m1 = evaluator.merge(p1, p0, p2) + assert(evaluator.eval(m1) === InternalRow(false)) + + // Update - Merge - Eval (empty partition at the end) + val m2 = evaluator.merge(p2, p1, p0) + assert(evaluator.eval(m2) === InternalRow(false)) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index fa14aa14ee968..0ba0b99d38833 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -232,6 +232,21 @@ class Dataset[T] private[sql]( } } + private[sql] def booleanColumns: Seq[Expression] = { + schema.fields.filter(_.dataType.isInstanceOf[BooleanType]).map { n => + queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get + } + } + + private def aggregatableColumns: Seq[Expression] = { + schema.fields + .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType]) + .map { n => + queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver) + .get + } + } + /** * Get rows represented in Sequence by specific truncate and vertical requirement. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index d4e75b5ebd405..7d2d7144c4360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.sql.types.{BooleanType, NumericType, StructType} /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -89,7 +89,6 @@ class RelationalGroupedDataset protected[sql]( private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) : DataFrame = { - val columnExprs = if (colNames.isEmpty) { // No columns specified. Use all numeric columns. df.numericColumns @@ -100,7 +99,27 @@ class RelationalGroupedDataset protected[sql]( if (!namedExpr.dataType.isInstanceOf[NumericType]) { throw new AnalysisException( s""""$colName" is not a numeric column. 
""" + - "Aggregation function can only be applied on a numeric column.") + "Aggregation function can only be applied on a numeric column.") + } + namedExpr + } + } + toDF(columnExprs.map(expr => f(expr).toAggregateExpression())) + } + + private[this] def aggregateBooleanColumns(colNames: String*)(f: Expression => AggregateFunction) + : DataFrame = { + val columnExprs = if (colNames.isEmpty) { + // No columns specified. Use all numeric columns. + df.booleanColumns + } else { + // Make sure all specified columns are numeric. + colNames.map { colName => + val namedExpr = df.resolve(colName) + if (!namedExpr.dataType.isInstanceOf[BooleanType]) { + throw new AnalysisException( + s""""$colName" is not a boolean column. """ + + "Aggregation function can only be applied on a boolean column.") } namedExpr } @@ -297,9 +316,44 @@ class RelationalGroupedDataset protected[sql]( } /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. + * Compute the logical and of all boolean columns for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. + * + * @since 2.4.0 + */ + @scala.annotation.varargs + def every(colNames: String*): DataFrame = { + aggregateBooleanColumns(colNames : _*)(Every) + } + + /** + * Compute the logical or of all boolean columns for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. + * + * @since 2.4.0 + */ + @scala.annotation.varargs + def any(colNames: String*): DataFrame = { + aggregateBooleanColumns(colNames : _*)(AnyAgg) + } + + /** + * Compute the logical or of all boolean columns for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. * - * There are two versions of `pivot` function: one that requires the caller to specify the list + * @since 2.4.0 + */ + @scala.annotation.varargs + def some(colNames: String*): DataFrame = { + aggregateBooleanColumns(colNames : _*)(AnyAgg) + } + + /** + * Pivots a column of the current `DataFrame` and perform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 367ac66dd77f5..f8a8fcf6eaa8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -825,6 +825,53 @@ object functions { */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) + /** + * Aggregate function: returns true if all values in the expression are true. + * + * @group agg_funcs + * @since 2.3.0 + */ + def every(e: Column): Column = withAggregateFunction { Every(e.expr) } + + /** + * Aggregate function: returns true if all values in the expression are true. + * + * @group agg_funcs + * @since 2.2.0 + */ + def every(columnName: String): Column = every(Column(columnName)) + + /** + * Aggregate function: returns true if at least one value in the expression is true. 
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def any(e: Column): Column = withAggregateFunction { AnyAgg(e.expr) }
+
+  /**
+   * Aggregate function: returns true if at least one value in the expression is true.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def any(columnName: String): Column = any(Column(columnName))
+
+  /**
+   * Aggregate function: returns true if at least one value in the expression is true.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def some(e: Column): Column = any(e)
+
+  /**
+   * Aggregate function: returns true if at least one value in the expression is true.
+   *
+   * @group agg_funcs
+   * @since 2.4.0
+   */
+  def some(columnName: String): Column = any(Column(columnName))
 
   //////////////////////////////////////////////////////////////////////////////////////////////
   // Window functions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index d0106c44b7db2..6303933ebb8de 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -727,4 +727,70 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
       "grouping expressions: [current_date(None)], value: [key: int, value: string], " +
         "type: GroupBy]"))
   }
+
+  test("every") {
+    val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false))
+      .toDF("a", "b")
+    checkAnswer(
+      df.groupBy("a").agg(every('b)),
+      Seq(Row(1, false), Row(2, true), Row(3, false)))
+  }
+
+  test("every null values") {
+    val df = Seq[(java.lang.Integer, java.lang.Boolean)](
+      (1, true), (1, false),
+      (2, true),
+      (3, false), (3, null),
+      (4, null), (4, null))
+      .toDF("a", "b")
+    checkAnswer(
+      df.groupBy("a").agg(every('b)),
+      Seq(Row(1, false), Row(2, true), Row(3, false), Row(4, null)))
+  }
+
+  test("every empty table") {
+    val df = Seq.empty[(Int, Boolean)].toDF("a", "b")
+    checkAnswer(
+      df.agg(every('b)),
+      Seq(Row(null)))
+  }
+
+  test("any") {
+    val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false))
+      .toDF("a", "b")
+    checkAnswer(
+      df.groupBy("a").agg(any('b)),
+      Seq(Row(1, true), Row(2, true), Row(3, false)))
+  }
+
+  test("any empty table") {
+    val df = Seq.empty[(Int, Boolean)].toDF("a", "b")
+    checkAnswer(
+      df.agg(any('b)),
+      Seq(Row(null)))
+  }
+
+  test("any/some null values") {
+    val df = Seq[(java.lang.Integer, java.lang.Boolean)](
+      (1, true), (1, false),
+      (2, true),
+      (3, true), (3, false), (3, null),
+      (4, null), (4, null))
+      .toDF("a", "b")
+    checkAnswer(
+      df.groupBy("a").agg(any('b)),
+      Seq(Row(1, true), Row(2, true), Row(3, true), Row(4, null)))
+    checkAnswer(
+      df.groupBy("a").agg(some('b)),
+      Seq(Row(1, true), Row(2, true), Row(3, true), Row(4, null)))
+  }
+
+  test("some") {
+    val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false))
+      .toDF("a", "b")
+    checkAnswer(
+      df.groupBy("a").agg(some('b)),
+      Seq(Row(1, true), Row(2, true), Row(3, false)))
+  }
+
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 78277d7dcf757..7bf270719c3b8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -681,4 +681,37 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
       Row("S2", "P2", 300, 300, 500)))
   }
+
+  test("every/any/some") {
+    val df = Seq[(String, java.lang.Boolean)](
+      ("a", false),
+      ("a", true),
+      ("a", false),
+      ("b", true),
+      ("b", true),
+      ("c", false),
+      ("d", true),
+      ("d", null)
+    ).toDF("key", "value")
+    val window = Window.partitionBy($"key").orderBy($"value")
+      .rowsBetween(Long.MinValue, Long.MaxValue)
+    checkAnswer(
+      df.select(
+        $"key",
+        $"value",
+        every($"value").over(window),
+        any($"value").over(window),
+        some($"value").over(window)),
+      Seq(
+        Row("a", false, false, true, true),
+        Row("a", false, false, true, true),
+        Row("a", true, false, true, true),
+        Row("b", true, true, true, true),
+        Row("b", true, true, true, true),
+        Row("c", false, false, false, false),
+        Row("d", true, true, true, true),
+        Row("d", null, true, true, true)
+      )
+    )
+  }
 }
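
A small usage sketch for reviewers (illustrative only, not part of the patch). It assumes a spark-shell session built with this change, so the `spark` session and the `toDF`/`$` implicits are already in scope; the per-group results follow from the null handling above (the update expressions skip null inputs, and a group with no non-null input evaluates to null).

// Illustrative usage of the new aggregates; assumes a spark-shell with this patch applied.
import org.apache.spark.sql.functions.{any, every, some}

val df = Seq[(String, java.lang.Boolean)](
  ("a", true), ("a", false),   // mixed group
  ("b", true), ("b", null),    // null input is skipped
  ("c", null), ("c", null)     // no non-null input at all
).toDF("key", "value")

// DataFrame API
df.groupBy("key")
  .agg(every($"value"), any($"value"), some($"value"))
  .show()
// expected per group:
//   a -> every = false, any = true, some = true
//   b -> every = true,  any = true, some = true
//   c -> every = null,  any = null, some = null

// The same names are registered in FunctionRegistry, so plain SQL works as well.
df.createOrReplaceTempView("t")
spark.sql("SELECT key, every(value), any(value), some(value) FROM t GROUP BY key").show()

Note that `some` is simply an alias for the `AnyAgg` implementation, so `any` and `some` always agree.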