From 0eb599282cce39245a5c9f047231cf319f44fa31 Mon Sep 17 00:00:00 2001 From: ptkool Date: Tue, 7 Mar 2017 14:09:32 -0500 Subject: [PATCH 01/13] Add new aggregates EVERY and ANY (SOME). --- python/pyspark/sql/functions.py | 11 +++ .../catalyst/analysis/FunctionRegistry.scala | 3 + .../expressions/aggregate/AnyAgg.scala | 62 ++++++++++++++ .../expressions/aggregate/Every.scala | 62 ++++++++++++++ .../spark/sql/catalyst/util/TypeUtils.scala | 8 ++ .../ExpressionTypeCheckingSuite.scala | 2 + .../expressions/aggregate/AnyTestSuite.scala | 80 ++++++++++++++++++ .../aggregate/EveryTestSuite.scala | 80 ++++++++++++++++++ .../scala/org/apache/spark/sql/Dataset.scala | 15 ++++ .../spark/sql/RelationalGroupedDataset.scala | 82 ++++++++++++++++--- .../org/apache/spark/sql/functions.scala | 57 +++++++++++-- .../spark/sql/DataFrameAggregateSuite.scala | 39 +++++++++ .../sql/DataFrameWindowFunctionsSuite.scala | 32 ++++++++ 13 files changed, 515 insertions(+), 18 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3128d5792eead..5f57874af217a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -202,6 +202,15 @@ def _(): """, } +_functions_2_2 = { + 'to_date': 'Converts a string date into a DateType using the (optionally) specified format.', + 'to_timestamp': 'Converts a string timestamp into a timestamp type using the ' + + '(optionally) specified format.', + 'every': 'Aggregate function: returns true if all values in the expression are true.', + 'any': 'Aggregate function: returns true if at least one value in the expression is true.', + 'some': 'Aggregate function: returns true if at least one value in the expression is true.', +} + # math functions that take two arguments as input _binary_mathfunctions = { 'atan2': """ @@ -265,6 +274,8 @@ def _(): globals()[_name] = since(1.6)(_create_function(_name, _doc)) for _name, _doc in _functions_2_1.items(): globals()[_name] = since(2.1)(_create_function(_name, _doc)) +for _name, _doc in _functions_2_2.items(): + globals()[_name] = since(2.1)(_create_function(_name, _doc)) for _name, _message in _functions_deprecated.items(): globals()[_name] = _wrap_deprecated_function(globals()[_name], _message) for _name, _doc in _functions_2_4.items(): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7dafebff79874..ac9ac7bbc2f66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -300,6 +300,9 @@ object FunctionRegistry { expression[CollectList]("collect_list"), expression[CollectSet]("collect_set"), expression[CountMinSketchAgg]("count_min_sketch"), + expression[Every]("every"), + expression[AnyAgg]("any"), + expression[AnyAgg]("some"), // string functions expression[Ascii]("ascii"), diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala new file mode 100644 index 0000000000000..8c7f9bcd5f656 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if at least one value of `expr` is true.") +case class AnyAgg(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForBooleanExpr(child.dataType, "function some") + + private lazy val some = AttributeReference("some", BooleanType)() + private lazy val emptySet = AttributeReference("emptySet", BooleanType)() + + override lazy val aggBufferAttributes = some :: emptySet :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( + Literal(false), + Literal(true) + ) + + override lazy val updateExpressions: Seq[Expression] = Seq( + Or(some, Coalesce(Seq(child, Literal(false)))), + Literal(false) + ) + + override lazy val mergeExpressions: Seq[Expression] = Seq( + Or(some.left, some.right), + And(emptySet.left, emptySet.right) + ) + + override lazy val evaluateExpression: Expression = And(!emptySet, some) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala new file mode 100644 index 0000000000000..e8004e4c134e7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.types._ + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns true if all values of `expr` are true.") +case class Every(child: Expression) extends DeclarativeAggregate with ImplicitCastInputTypes { + + override def children: Seq[Expression] = child :: Nil + + override def nullable: Boolean = true + + override def dataType: DataType = BooleanType + + override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForBooleanExpr(child.dataType, "function every") + + private lazy val every = AttributeReference("every", BooleanType)() + private lazy val emptySet = AttributeReference("emptySet", BooleanType)() + + override lazy val aggBufferAttributes = every :: emptySet :: Nil + + override lazy val initialValues: Seq[Expression] = Seq( + Literal(true), + Literal(true) + ) + + override lazy val updateExpressions: Seq[Expression] = Seq( + And(every, Coalesce(Seq(child, Literal(false)))), + Literal(false) + ) + + override lazy val mergeExpressions: Seq[Expression] = Seq( + And(every.left, every.right), + And(emptySet.left, emptySet.right) + ) + + override lazy val evaluateExpression: Expression = And(!emptySet, every) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 76218b459ef0d..9b70162b6e8b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -33,6 +33,14 @@ object TypeUtils { } } + def checkForBooleanExpr(dt: DataType, caller: String): TypeCheckResult = { + if (dt.isInstanceOf[BooleanType] || dt == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"$caller requires boolean types, not $dt") + } + } + def checkForOrderingExpr(dt: DataType, caller: String): TypeCheckResult = { if (RowOrdering.isOrderable(dt)) { TypeCheckResult.TypeCheckSuccess diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 8eec14842c7e7..5c4b341b73e02 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -144,6 +144,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(Sum('stringField)) assertSuccess(Average('stringField)) assertSuccess(Min('arrayField)) + assertSuccess(Every('booleanField)) + assertSuccess(AnyAgg('booleanField)) 
assertError(Min('mapField), "min does not support ordering on type") assertError(Max('mapField), "max does not support ordering on type") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala new file mode 100644 index 0000000000000..b0f5962ac27b2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} +import org.apache.spark.sql.types.BooleanType + +class AnyTestSuite extends SparkFunSuite { + val input = AttributeReference("input", BooleanType, nullable = true)() + val evaluator = DeclarativeAggregateEvaluator(AnyAgg(input), Seq(input)) + + test("empty buffer") { + assert(evaluator.initialize() === InternalRow(false, true)) + } + + test("update") { + val result = evaluator.update( + InternalRow(true), + InternalRow(false), + InternalRow(true)) + assert(result === InternalRow(true, false)) + } + + test("merge") { + // Empty merge + val p0 = evaluator.initialize() + assert(evaluator.merge(p0) === InternalRow(false, true)) + + // Single merge + val p1 = evaluator.update(InternalRow(true), InternalRow(true)) + assert(evaluator.merge(p1) === InternalRow(true, false)) + + // Multiple merges. 
+ val p2 = evaluator.update(InternalRow(false), InternalRow(null)) + assert(evaluator.merge(p1, p2) === InternalRow(true, false)) + + // Empty partitions (p0 is empty) + assert(evaluator.merge(p0, p2) === InternalRow(false, false)) + assert(evaluator.merge(p2, p1, p0) === InternalRow(true, false)) + } + + test("eval") { + // Null Eval + assert(evaluator.eval(InternalRow(null, true)) === InternalRow(false)) + + // Empty Eval + val p0 = evaluator.initialize() + assert(evaluator.eval(p0) === InternalRow(false)) + + // Update - Eval + val p1 = evaluator.update(InternalRow(true), InternalRow(null)) + assert(evaluator.eval(p1) === InternalRow(true)) + + // Update - Merge - Eval + val p2 = evaluator.update(InternalRow(false), InternalRow(false)) + val m1 = evaluator.merge(p0, p2) + assert(evaluator.eval(m1) === InternalRow(false)) + + // Update - Merge - Eval (empty partition at the end) + val m2 = evaluator.merge(p2, p1, p0) + assert(evaluator.eval(m2) === InternalRow(true)) + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala new file mode 100644 index 0000000000000..c01fcb11fee78 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} +import org.apache.spark.sql.types.BooleanType + +class EveryTestSuite extends SparkFunSuite { + val input = AttributeReference("input", BooleanType, nullable = true)() + val evaluator = DeclarativeAggregateEvaluator(Every(input), Seq(input)) + + test("empty buffer") { + assert(evaluator.initialize() === InternalRow(true, true)) + } + + test("update") { + val result = evaluator.update( + InternalRow(true), + InternalRow(false), + InternalRow(true)) + assert(result === InternalRow(false, false)) + } + + test("merge") { + // Empty merge + val p0 = evaluator.initialize() + assert(evaluator.merge(p0) === InternalRow(true, true)) + + // Single merge + val p1 = evaluator.update(InternalRow(true), InternalRow(true)) + assert(evaluator.merge(p1) === InternalRow(true, false)) + + // Multiple merges. 
+ val p2 = evaluator.update(InternalRow(true), InternalRow(null)) + assert(evaluator.merge(p1, p2) === InternalRow(false, false)) + + // Empty partitions (p0 is empty) + assert(evaluator.merge(p1, p0, p2) === InternalRow(false, false)) + assert(evaluator.merge(p2, p1, p0) === InternalRow(false, false)) + } + + test("eval") { + // Null Eval + assert(evaluator.eval(InternalRow(null, true)) === InternalRow(false)) + + // Empty Eval + val p0 = evaluator.initialize() + assert(evaluator.eval(p0) === InternalRow(false)) + + // Update - Eval + val p1 = evaluator.update(InternalRow(true), InternalRow(true)) + assert(evaluator.eval(p1) === InternalRow(true)) + + // Update - Merge - Eval + val p2 = evaluator.update(InternalRow(false), InternalRow(true)) + val m1 = evaluator.merge(p1, p0, p2) + assert(evaluator.eval(m1) === InternalRow(false)) + + // Update - Merge - Eval (empty partition at the end) + val m2 = evaluator.merge(p2, p1, p0) + assert(evaluator.eval(m2) === InternalRow(false)) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index fa14aa14ee968..0ba0b99d38833 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -232,6 +232,21 @@ class Dataset[T] private[sql]( } } + private[sql] def booleanColumns: Seq[Expression] = { + schema.fields.filter(_.dataType.isInstanceOf[BooleanType]).map { n => + queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver).get + } + } + + private def aggregatableColumns: Seq[Expression] = { + schema.fields + .filter(f => f.dataType.isInstanceOf[NumericType] || f.dataType.isInstanceOf[StringType]) + .map { n => + queryExecution.analyzed.resolveQuoted(n.name, sparkSession.sessionState.analyzer.resolver) + .get + } + } + /** * Get rows represented in Sequence by specific truncate and vertical requirement. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index d4e75b5ebd405..91dd5a8af3656 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -21,7 +21,6 @@ import java.util.Locale import scala.collection.JavaConverters._ import scala.language.implicitConversions - import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast @@ -32,7 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.sql.types.{BooleanType, NumericType, StructType} /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -88,7 +87,7 @@ class RelationalGroupedDataset protected[sql]( } private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) - : DataFrame = { + : DataFrame = { val columnExprs = if (colNames.isEmpty) { // No columns specified. Use all numeric columns. 
@@ -100,7 +99,28 @@ class RelationalGroupedDataset protected[sql]( if (!namedExpr.dataType.isInstanceOf[NumericType]) { throw new AnalysisException( s""""$colName" is not a numeric column. """ + - "Aggregation function can only be applied on a numeric column.") + "Aggregation function can only be applied on a numeric column.") + } + namedExpr + } + } + toDF(columnExprs.map(expr => f(expr).toAggregateExpression())) + } + + private[this] def aggregateBooleanColumns(colNames: String*)(f: Expression => AggregateFunction) + : DataFrame = { + + val columnExprs = if (colNames.isEmpty) { + // No columns specified. Use all numeric columns. + df.booleanColumns + } else { + // Make sure all specified columns are numeric. + colNames.map { colName => + val namedExpr = df.resolve(colName) + if (!namedExpr.dataType.isInstanceOf[BooleanType]) { + throw new AnalysisException( + s""""$colName" is not a boolean column. """ + + "Aggregation function can only be applied on a boolean column.") } namedExpr } @@ -285,21 +305,57 @@ class RelationalGroupedDataset protected[sql]( } /** - * Compute the sum for each numeric columns for each group. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the sum for them. - * - * @since 1.3.0 - */ + * Compute the sum for each numeric columns for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. + * + * @since 1.3.0 + */ @scala.annotation.varargs def sum(colNames: String*): DataFrame = { aggregateNumericColumns(colNames : _*)(Sum) } /** - * Pivots a column of the current `DataFrame` and performs the specified aggregation. - * - * There are two versions of `pivot` function: one that requires the caller to specify the list + * Compute the logical and of all boolean columns for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. + * + * @since 2.2.0 + */ + @scala.annotation.varargs + def every(colNames: String*): DataFrame = { + aggregateBooleanColumns(colNames : _*)(Every) + } + + /** + * Compute the logical or of all boolean columns for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. + * + * @since 2.2.0 + */ + @scala.annotation.varargs + def any(colNames: String*): DataFrame = { + aggregateBooleanColumns(colNames : _*)(AnyAgg) + } + + /** + * Compute the logical or of all boolean columns for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. + * + * @since 2.2.0 + */ + @scala.annotation.varargs + def some(colNames: String*): DataFrame = { + aggregateBooleanColumns(colNames : _*)(AnyAgg) + } + + /** + * Pivots a column of the current `DataFrame` and perform the specified aggregation. + * There are two versions of pivot function: one that requires the caller to specify the list +>>>>>>> Add new aggregates EVERY and ANY (SOME). * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 367ac66dd77f5..b1380964056f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -818,13 +818,60 @@ object functions { def var_pop(e: Column): Column = withAggregateFunction { VariancePop(e.expr) } /** - * Aggregate function: returns the population variance of the values in a group. - * - * @group agg_funcs - * @since 1.6.0 - */ + * Aggregate function: returns the population variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) + /** + * Aggregate function: returns true if all values in the expression are true. + * + * @group agg_funcs + * @since 2.3.0 + */ + def every(e: Column): Column = withAggregateFunction { Every(e.expr) } + + /** + * Aggregate function: returns true if all values in the expression are true. + * + * @group agg_funcs + * @since 2.2.0 + */ + def every(columnName: String): Column = every(Column(columnName)) + + /** + * Aggregate function: returns true if at least one value in the expression is true. + * + * @group agg_funcs + * @since 2.2.0 + */ + def any(e: Column): Column = withAggregateFunction { AnyAgg(e.expr) } + + /** + * Aggregate function: returns true if at least one value in the expression is true. + * + * @group agg_funcs + * @since 2.2.0 + */ + def any(columnName: String): Column = any(Column(columnName)) + + /** + * Aggregate function: returns true if at least one value in the expression is true. + * + * @group agg_funcs + * @since 2.2.0 + */ + def some(e: Column): Column = any(e) + + /** + * Aggregate function: returns true if at least one value in the expression is true. 
+ * + * @group agg_funcs + * @since 2.2.0 + */ + def some(columnName: String): Column = any(Column(columnName)) ////////////////////////////////////////////////////////////////////////////////////////////// // Window functions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index d0106c44b7db2..7e290395cecf1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -727,4 +727,43 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { "grouping expressions: [current_date(None)], value: [key: int, value: string], " + "type: GroupBy]")) } + + test("every") { + val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false)) + .toDF("a", "b") + checkAnswer( + df.groupBy("a").agg(every('b)), + Seq(Row(1, false), Row(2, true), Row(3, false))) + } + + test("every empty table") { + val df = Seq.empty[(Int, Boolean)].toDF("a", "b") + checkAnswer( + df.agg(every('b)), + Seq(Row(false))) + } + + test("any") { + val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false)) + .toDF("a", "b") + checkAnswer( + df.groupBy("a").agg(any('b)), + Seq(Row(1, true), Row(2, true), Row(3, false))) + } + + test("any empty table") { + val df = Seq.empty[(Int, Boolean)].toDF("a", "b") + checkAnswer( + df.agg(any('b)), + Seq(Row(false))) + } + + test("some") { + val df = Seq((1, true), (1, true), (1, false), (2, true), (2, true), (3, false), (3, false)) + .toDF("a", "b") + checkAnswer( + df.groupBy("a").agg(some('b)), + Seq(Row(1, true), Row(2, true), Row(3, false))) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 78277d7dcf757..a26e970e55e1e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -681,4 +681,36 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row("S2", "P2", 300, 300, 500))) } + + test("every") { + val nullStr: String = null + val df = Seq( + ("a", 0, nullStr), + ("a", 1, "x"), + ("a", 2, "y"), + ("a", 3, "z"), + ("a", 4, nullStr), + ("b", 1, nullStr), + ("b", 2, nullStr)). + toDF("key", "order", "value") + val window = Window.partitionBy($"key").orderBy($"order") + checkAnswer( + df.select( + $"key", + $"order", + first($"value").over(window), + first($"value", ignoreNulls = false).over(window), + first($"value", ignoreNulls = true).over(window), + last($"value").over(window), + last($"value", ignoreNulls = false).over(window), + last($"value", ignoreNulls = true).over(window)), + Seq( + Row("a", 0, null, null, null, null, null, null), + Row("a", 1, null, null, "x", "x", "x", "x"), + Row("a", 2, null, null, "x", "y", "y", "y"), + Row("a", 3, null, null, "x", "z", "z", "z"), + Row("a", 4, null, null, "x", null, null, "z"), + Row("b", 1, null, null, null, null, null, null), + Row("b", 2, null, null, null, null, null, null))) + } } From a8cf7e16b6723d6e6831bf781c3fa22278563f64 Mon Sep 17 00:00:00 2001 From: ptkool Date: Tue, 7 Mar 2017 17:59:48 -0500 Subject: [PATCH 02/13] Fix Scala style check errors. 
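
The first patch in this series wires `every`, `any`, and `some` through the SQL
function registry, the Scala `functions` object, and the PySpark `functions`
module. A minimal PySpark sketch of the intended grouped-aggregation usage
(the data, names, and local master are illustrative only, and it assumes a
Spark build that already contains these changes):

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = (SparkSession.builder
             .master("local[*]")
             .appName("every-any-some")
             .getOrCreate())

    df = spark.createDataFrame(
        [(1, True), (1, False), (2, True), (2, True)], ["a", "b"])

    # every(b) is true only when all values of b in the group are true;
    # any(b) and its alias some(b) are true when at least one value is true.
    df.groupBy("a").agg(F.every(df.b), F.any(df.b), F.some(df.b)).show()

    spark.stop()

On the Scala side the same aggregations are also reachable through the new
`RelationalGroupedDataset` methods, e.g. `df.groupBy("a").every("b")`, which
fall back to all boolean columns when no column names are given.
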
--- .../spark/sql/RelationalGroupedDataset.scala | 49 ++++++------- .../org/apache/spark/sql/functions.scala | 70 +++++++++---------- 2 files changed, 60 insertions(+), 59 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 91dd5a8af3656..f98437aecfde7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -21,6 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ import scala.language.implicitConversions + import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.python.PythonEvalType import org.apache.spark.broadcast.Broadcast @@ -305,48 +306,48 @@ class RelationalGroupedDataset protected[sql]( } /** - * Compute the sum for each numeric columns for each group. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the sum for them. - * - * @since 1.3.0 - */ + * Compute the sum for each numeric columns for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. + * + * @since 1.3.0 + */ @scala.annotation.varargs def sum(colNames: String*): DataFrame = { aggregateNumericColumns(colNames : _*)(Sum) } /** - * Compute the logical and of all boolean columns for each group. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the sum for them. - * - * @since 2.2.0 - */ + * Compute the logical and of all boolean columns for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. + * + * @since 2.2.0 + */ @scala.annotation.varargs def every(colNames: String*): DataFrame = { aggregateBooleanColumns(colNames : _*)(Every) } /** - * Compute the logical or of all boolean columns for each group. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the sum for them. - * - * @since 2.2.0 - */ + * Compute the logical or of all boolean columns for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. + * + * @since 2.2.0 + */ @scala.annotation.varargs def any(colNames: String*): DataFrame = { aggregateBooleanColumns(colNames : _*)(AnyAgg) } /** - * Compute the logical or of all boolean columns for each group. - * The resulting `DataFrame` will also contain the grouping columns. - * When specified columns are given, only compute the sum for them. - * - * @since 2.2.0 - */ + * Compute the logical or of all boolean columns for each group. + * The resulting `DataFrame` will also contain the grouping columns. + * When specified columns are given, only compute the sum for them. 
+ * + * @since 2.2.0 + */ @scala.annotation.varargs def some(colNames: String*): DataFrame = { aggregateBooleanColumns(colNames : _*)(AnyAgg) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index b1380964056f9..f8a8fcf6eaa8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -818,59 +818,59 @@ object functions { def var_pop(e: Column): Column = withAggregateFunction { VariancePop(e.expr) } /** - * Aggregate function: returns the population variance of the values in a group. - * - * @group agg_funcs - * @since 1.6.0 - */ + * Aggregate function: returns the population variance of the values in a group. + * + * @group agg_funcs + * @since 1.6.0 + */ def var_pop(columnName: String): Column = var_pop(Column(columnName)) /** - * Aggregate function: returns true if all values in the expression are true. - * - * @group agg_funcs - * @since 2.3.0 - */ + * Aggregate function: returns true if all values in the expression are true. + * + * @group agg_funcs + * @since 2.3.0 + */ def every(e: Column): Column = withAggregateFunction { Every(e.expr) } /** - * Aggregate function: returns true if all values in the expression are true. - * - * @group agg_funcs - * @since 2.2.0 - */ + * Aggregate function: returns true if all values in the expression are true. + * + * @group agg_funcs + * @since 2.2.0 + */ def every(columnName: String): Column = every(Column(columnName)) /** - * Aggregate function: returns true if at least one value in the expression is true. - * - * @group agg_funcs - * @since 2.2.0 - */ + * Aggregate function: returns true if at least one value in the expression is true. + * + * @group agg_funcs + * @since 2.2.0 + */ def any(e: Column): Column = withAggregateFunction { AnyAgg(e.expr) } /** - * Aggregate function: returns true if at least one value in the expression is true. - * - * @group agg_funcs - * @since 2.2.0 - */ + * Aggregate function: returns true if at least one value in the expression is true. + * + * @group agg_funcs + * @since 2.2.0 + */ def any(columnName: String): Column = any(Column(columnName)) /** - * Aggregate function: returns true if at least one value in the expression is true. - * - * @group agg_funcs - * @since 2.2.0 - */ + * Aggregate function: returns true if at least one value in the expression is true. + * + * @group agg_funcs + * @since 2.2.0 + */ def some(e: Column): Column = any(e) /** - * Aggregate function: returns true if at least one value in the expression is true. - * - * @group agg_funcs - * @since 2.2.0 - */ + * Aggregate function: returns true if at least one value in the expression is true. + * + * @group agg_funcs + * @since 2.2.0 + */ def some(columnName: String): Column = any(Column(columnName)) ////////////////////////////////////////////////////////////////////////////////////////////// From 519a4559b1df71b11969754fc0d7854347adf775 Mon Sep 17 00:00:00 2001 From: ptkool Date: Mon, 13 Mar 2017 09:37:02 -0400 Subject: [PATCH 03/13] Resolved issue with Any aggregate and added window function test. 
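
The window-function test updated below evaluates the new aggregates over an
unbounded frame, so each row of a partition carries the whole group's result.
A compact PySpark sketch of that pattern (illustrative data; same assumptions
as the earlier sketch):

    from pyspark.sql import SparkSession, Window
    from pyspark.sql import functions as F

    spark = SparkSession.builder.master("local[*]").getOrCreate()

    df = spark.createDataFrame(
        [("a", True), ("a", False), ("b", True)], ["key", "value"])

    # An unbounded frame lets each row see its entire partition, so the
    # windowed every/any behave like a per-key GROUP BY result on every row.
    w = (Window.partitionBy("key")
         .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))

    df.select("key", "value",
              F.every("value").over(w),
              F.any("value").over(w)).show()
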
--- .../expressions/aggregate/AnyAgg.scala | 2 +- .../sql/DataFrameWindowFunctionsSuite.scala | 53 ++++++++++--------- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala index 8c7f9bcd5f656..869e53ac88fa1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala @@ -36,7 +36,7 @@ case class AnyAgg(child: Expression) extends DeclarativeAggregate with ImplicitC override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType) override def checkInputDataTypes(): TypeCheckResult = - TypeUtils.checkForBooleanExpr(child.dataType, "function some") + TypeUtils.checkForBooleanExpr(child.dataType, "function any") private lazy val some = AttributeReference("some", BooleanType)() private lazy val emptySet = AttributeReference("emptySet", BooleanType)() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index a26e970e55e1e..3de789e72d7c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -682,35 +682,36 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { } - test("every") { - val nullStr: String = null - val df = Seq( - ("a", 0, nullStr), - ("a", 1, "x"), - ("a", 2, "y"), - ("a", 3, "z"), - ("a", 4, nullStr), - ("b", 1, nullStr), - ("b", 2, nullStr)). - toDF("key", "order", "value") - val window = Window.partitionBy($"key").orderBy($"order") + test("every/any/some") { + val df = Seq[(String, java.lang.Boolean)]( + ("a", false), + ("a", true), + ("a", false), + ("b", true), + ("b", true), + ("c", false), + ("d", true), + ("d", null) + ).toDF("key", "value") + val window = Window.partitionBy($"key").orderBy(s"value") + .rowsBetween(Long.MinValue, Long.MaxValue) checkAnswer( df.select( $"key", - $"order", - first($"value").over(window), - first($"value", ignoreNulls = false).over(window), - first($"value", ignoreNulls = true).over(window), - last($"value").over(window), - last($"value", ignoreNulls = false).over(window), - last($"value", ignoreNulls = true).over(window)), + $"value", + every($"value").over(window), + any($"value").over(window), + some($"value").over(window)), Seq( - Row("a", 0, null, null, null, null, null, null), - Row("a", 1, null, null, "x", "x", "x", "x"), - Row("a", 2, null, null, "x", "y", "y", "y"), - Row("a", 3, null, null, "x", "z", "z", "z"), - Row("a", 4, null, null, "x", null, null, "z"), - Row("b", 1, null, null, null, null, null, null), - Row("b", 2, null, null, null, null, null, null))) + Row("a", false, false, true, true), + Row("a", false, false, true, true), + Row("a", true, false, true, true), + Row("b", true, true, true, true), + Row("b", true, true, true, true), + Row("c", false, false, false, false), + Row("d", true, false, true, true), + Row("d", null, false, true, true) + ) + ) } } From a929914a25baaf1c60f41e5ba35e937d292681a9 Mon Sep 17 00:00:00 2001 From: ptkool Date: Mon, 13 Mar 2017 13:49:21 -0400 Subject: [PATCH 04/13] Added additional pyspark.sql tests. 
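
The additional tests below also pin down how NULL inputs behave at this point
in the series: the update expressions coalesce a NULL to false, so a group
containing a NULL can never satisfy every(), while any() still returns true as
long as some other value in the group is true. (A later patch in this series
adds null-filtering logic so that NULLs are ignored instead.) A small PySpark
sketch of that behavior, with illustrative data:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.master("local[*]").getOrCreate()

    # Group "d" holds one true value and one NULL: the NULL is treated as
    # false, so every() yields False for the group while any() yields True.
    df = spark.createDataFrame([("d", True), ("d", None)], ["key", "value"])

    df.groupBy("key").agg(F.every("value"), F.any("value")).show()
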
--- python/pyspark/sql/tests.py | 80 +++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 815772d23ceea..4715702a3b2ed 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1532,6 +1532,24 @@ def test_cov(self): cov = df.stat.cov(u"a", "b") self.assertTrue(abs(cov - 55.0 / 3) < 1e-6) + def test_every_any(self): + from pyspark.sql import functions + data = [ + Row(key="a", value=False), + Row(key="a", value=True), + Row(key="a", value=False), + Row(key="b", value=True), + Row(key="b", value=True), + Row(key="c", value=False), + Row(key="d", value=True), + Row(key="d", value=None) + ] + df = self.sc.parallelize(data).toDF() + df2 = df.select(functions.every(df.value).alias('a'), + functions.any(df.value).alias('b'), + functions.some(df.value).alias('c')) + self.assertEqual([Row(a=False, b=True, c=True)], df2.collect()) + def test_crosstab(self): df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF() ct = df.stat.crosstab(u"a", "b").collect() @@ -3938,6 +3956,68 @@ def test_window_functions_cumulative_sum(self): for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) + def test_window_functions_every_any(self): + df = self.spark.createDataFrame([ + ("a", False), + ("a", True), + ("a", False), + ("b", True), + ("b", True), + ("c", False), + ("d", True), + ("d", lit(None)) + ], ["key", "value"]) + w = Window.partitionBy("key").orderBy("value") + from pyspark.sql import functions as F + sel = df.select(df.key, + df.value, + F.every().over(w), + F.any().over(w), + F.some().over(w)) + rs = sel.collect() + expected = [ + ("a", False, False, True, True), + ("a", False, False, True, True), + ("a", True, False, True, True), + ("b", True, True, True, True), + ("b", True, True, True, True), + ("c", False, False, False, False), + ("d", True, False, True, True), + ("d", None, False, True, True) + ] + self.assertEqual(rs, expected) + + def test_window_functions_every_any_without_partitionBy(self): + df = self.spark.createDataFrame([ + ("a", False), + ("a", True), + ("a", False), + ("b", True), + ("b", True), + ("c", False), + ("d", True), + ("d", lit(None)) + ], ["key", "value"]) + w = Window.orderBy("value").rowsBetween(Window.unboundedPreceding, 0) + from pyspark.sql import functions as F + sel = df.select(df.key, + df.value, + F.every().over(w), + F.any().over(w), + F.some().over(w)) + rs = sel.collect() + expected = [ + ("a", False, False, False, False), + ("a", False, False, False, False), + ("a", True, False, True, True), + ("b", True, True, True, True), + ("b", True, True, True, True), + ("c", False, False, False, False), + ("d", True, False, True, True), + ("d", None, False, True, True) + ] + self.assertEqual(rs, expected) + def test_collect_functions(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) from pyspark.sql import functions From d65ef4a61f43d8b1da5a7f385b37b02c92d7a635 Mon Sep 17 00:00:00 2001 From: ptkool Date: Mon, 13 Mar 2017 21:34:40 -0400 Subject: [PATCH 05/13] Fix pyspark window function tests. 
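
The corrected tests below use a running frame (unbounded preceding up to the
current row), under which the windowed aggregates are cumulative rather than
group-wide. A small sketch of that variant (illustrative data; same
assumptions as the earlier sketches):

    from pyspark.sql import SparkSession, Window
    from pyspark.sql import functions as F

    spark = SparkSession.builder.master("local[*]").getOrCreate()

    df = spark.createDataFrame([(False,), (True,), (True,)], ["value"])

    # The frame runs from the start of the ordered rows up to the current
    # row, so the result accumulates: every() stays false once a false has
    # been seen, and any() becomes true once a true has been seen.
    w = Window.orderBy("value").rowsBetween(Window.unboundedPreceding, 0)

    df.select("value",
              F.every("value").over(w),
              F.any("value").over(w)).show()
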
--- python/pyspark/sql/tests.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4715702a3b2ed..0ffba0ac01e4e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3965,15 +3965,15 @@ def test_window_functions_every_any(self): ("b", True), ("c", False), ("d", True), - ("d", lit(None)) + ("d", None) ], ["key", "value"]) w = Window.partitionBy("key").orderBy("value") from pyspark.sql import functions as F sel = df.select(df.key, df.value, - F.every().over(w), - F.any().over(w), - F.some().over(w)) + F.every("value").over(w), + F.any("value").over(w), + F.some("value").over(w)) rs = sel.collect() expected = [ ("a", False, False, True, True), @@ -3985,7 +3985,8 @@ def test_window_functions_every_any(self): ("d", True, False, True, True), ("d", None, False, True, True) ] - self.assertEqual(rs, expected) + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) def test_window_functions_every_any_without_partitionBy(self): df = self.spark.createDataFrame([ @@ -3996,15 +3997,15 @@ def test_window_functions_every_any_without_partitionBy(self): ("b", True), ("c", False), ("d", True), - ("d", lit(None)) + ("d", None) ], ["key", "value"]) w = Window.orderBy("value").rowsBetween(Window.unboundedPreceding, 0) from pyspark.sql import functions as F sel = df.select(df.key, df.value, - F.every().over(w), - F.any().over(w), - F.some().over(w)) + F.every("value").over(w), + F.any("value").over(w), + F.some("value").over(w)) rs = sel.collect() expected = [ ("a", False, False, False, False), @@ -4016,7 +4017,8 @@ def test_window_functions_every_any_without_partitionBy(self): ("d", True, False, True, True), ("d", None, False, True, True) ] - self.assertEqual(rs, expected) + for r, ex in zip(rs, expected): + self.assertEqual(tuple(r), ex[:len(r)]) def test_collect_functions(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) From 3217636cf8b80dc59fdd51dc5b64a05325b74b8d Mon Sep 17 00:00:00 2001 From: ptkool Date: Wed, 15 Mar 2017 09:21:47 -0400 Subject: [PATCH 06/13] Resolve several issues with Pyspark tests. --- python/pyspark/sql/functions.py | 25 +++++++-- python/pyspark/sql/tests.py | 55 ++++++++++--------- .../expressions/aggregate/AnyAgg.scala | 1 + .../expressions/aggregate/Every.scala | 1 + 4 files changed, 52 insertions(+), 30 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 5f57874af217a..ab9aa1e6fa810 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -206,9 +206,6 @@ def _(): 'to_date': 'Converts a string date into a DateType using the (optionally) specified format.', 'to_timestamp': 'Converts a string timestamp into a timestamp type using the ' + '(optionally) specified format.', - 'every': 'Aggregate function: returns true if all values in the expression are true.', - 'any': 'Aggregate function: returns true if at least one value in the expression is true.', - 'some': 'Aggregate function: returns true if at least one value in the expression is true.', } # math functions that take two arguments as input @@ -413,6 +410,27 @@ def countDistinct(col, *cols): jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column)) return Column(jc) +def every(col): + """Aggregate function: returns true if all values in a group are true. 
+ """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.every(_to_java_column(col)) + return Column(jc) + + +def any(col): + """Aggregate function: returns true if at least one value in the group is true. + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.any(_to_java_column(col)) + return Column(jc) + + +def some(col): + """Aggregate function: returns true if at least one value in the group is true. + """ + return any(col) + @since(1.3) def first(col, ignorenulls=False): @@ -472,7 +490,6 @@ def grouping_id(*cols): jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column)) return Column(jc) - @since(1.6) def input_file_name(): """Creates a string column for the file name of the current Spark task. diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 0ffba0ac01e4e..213beba3139dc 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3967,7 +3967,7 @@ def test_window_functions_every_any(self): ("d", True), ("d", None) ], ["key", "value"]) - w = Window.partitionBy("key").orderBy("value") + w = Window.partitionBy("key").orderBy("value").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) from pyspark.sql import functions as F sel = df.select(df.key, df.value, @@ -3982,40 +3982,43 @@ def test_window_functions_every_any(self): ("b", True, True, True, True), ("b", True, True, True, True), ("c", False, False, False, False), - ("d", True, False, True, True), - ("d", None, False, True, True) + ("d", None, False, True, True), + ("d", True, False, True, True) ] for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) def test_window_functions_every_any_without_partitionBy(self): df = self.spark.createDataFrame([ - ("a", False), - ("a", True), - ("a", False), - ("b", True), - ("b", True), - ("c", False), - ("d", True), - ("d", None) - ], ["key", "value"]) - w = Window.orderBy("value").rowsBetween(Window.unboundedPreceding, 0) + (False,), + (True,), + (False,), + (True,), + (True,), + (False,), + (True,), + (None,) + ], ["value"]) + w1 = Window.orderBy("value").rowsBetween(Window.unboundedPreceding, 0) + w2 = Window.orderBy("value").rowsBetween(-1, 0) from pyspark.sql import functions as F - sel = df.select(df.key, - df.value, - F.every("value").over(w), - F.any("value").over(w), - F.some("value").over(w)) + sel = df.select(df.value, + F.every("value").over(w1), + F.any("value").over(w1), + F.some("value").over(w1), + F.every("value").over(w2), + F.any("value").over(w2), + F.some("value").over(w2)) rs = sel.collect() expected = [ - ("a", False, False, False, False), - ("a", False, False, False, False), - ("a", True, False, True, True), - ("b", True, True, True, True), - ("b", True, True, True, True), - ("c", False, False, False, False), - ("d", True, False, True, True), - ("d", None, False, True, True) + (None, False, False, False, False, False, False), + (False, False, False, False, False, False, False), + (False, False, False, False, False, False, False), + (False, False, False, False, False, False, False), + (True, False, True, True, False, True, True), + (True, False, True, True, True, True, True), + (True, False, True, True, True, True, True), + (True, False, True, True, True, True, True) ] for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala index 
869e53ac88fa1..fd3a492848194 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala @@ -39,6 +39,7 @@ case class AnyAgg(child: Expression) extends DeclarativeAggregate with ImplicitC TypeUtils.checkForBooleanExpr(child.dataType, "function any") private lazy val some = AttributeReference("some", BooleanType)() + private lazy val emptySet = AttributeReference("emptySet", BooleanType)() override lazy val aggBufferAttributes = some :: emptySet :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala index e8004e4c134e7..508fc3dae4939 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala @@ -39,6 +39,7 @@ case class Every(child: Expression) extends DeclarativeAggregate with ImplicitCa TypeUtils.checkForBooleanExpr(child.dataType, "function every") private lazy val every = AttributeReference("every", BooleanType)() + private lazy val emptySet = AttributeReference("emptySet", BooleanType)() override lazy val aggBufferAttributes = every :: emptySet :: Nil From 309d4b64a61427c6b0311b3e9ef98f18df246fc3 Mon Sep 17 00:00:00 2001 From: ptkool Date: Wed, 15 Mar 2017 09:27:53 -0400 Subject: [PATCH 07/13] Resolve Scala style issues. --- .../spark/sql/catalyst/expressions/aggregate/AnyAgg.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala index fd3a492848194..bf599fb23f4a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala @@ -39,7 +39,7 @@ case class AnyAgg(child: Expression) extends DeclarativeAggregate with ImplicitC TypeUtils.checkForBooleanExpr(child.dataType, "function any") private lazy val some = AttributeReference("some", BooleanType)() - + private lazy val emptySet = AttributeReference("emptySet", BooleanType)() override lazy val aggBufferAttributes = some :: emptySet :: Nil From b5e9afbe60167870b92be7a6a83620e43898f24f Mon Sep 17 00:00:00 2001 From: ptkool Date: Wed, 15 Mar 2017 09:42:11 -0400 Subject: [PATCH 08/13] Fix Python style errors. --- python/pyspark/sql/functions.py | 2 ++ python/pyspark/sql/tests.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index ab9aa1e6fa810..6fb1e0532f2ed 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -410,6 +410,7 @@ def countDistinct(col, *cols): jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column)) return Column(jc) + def every(col): """Aggregate function: returns true if all values in a group are true. """ @@ -490,6 +491,7 @@ def grouping_id(*cols): jc = sc._jvm.functions.grouping_id(_to_seq(sc, cols, _to_java_column)) return Column(jc) + @since(1.6) def input_file_name(): """Creates a string column for the file name of the current Spark task. 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 213beba3139dc..0c35db931500a 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3967,7 +3967,9 @@ def test_window_functions_every_any(self): ("d", True), ("d", None) ], ["key", "value"]) - w = Window.partitionBy("key").orderBy("value").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) + w = Window \ + .partitionBy("key").orderBy("value") \ + .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing) from pyspark.sql import functions as F sel = df.select(df.key, df.value, From d05ca690c4a35c4e8382eee9da7e269955bf38e8 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Wed, 8 Aug 2018 15:05:27 -0700 Subject: [PATCH 09/13] generatedoc fix --- .../scala/org/apache/spark/sql/RelationalGroupedDataset.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index f98437aecfde7..e2d0cdab6007f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -356,7 +356,6 @@ class RelationalGroupedDataset protected[sql]( /** * Pivots a column of the current `DataFrame` and perform the specified aggregation. * There are two versions of pivot function: one that requires the caller to specify the list ->>>>>>> Add new aggregates EVERY and ANY (SOME). * of distinct values to pivot on, and one that does not. The latter is more concise but less * efficient, because Spark needs to first compute the list of distinct values internally. * From d050193e144f3cc1f1e5ef08b684f72fffdb24f0 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 10 Aug 2018 13:50:55 -0700 Subject: [PATCH 10/13] code review --- python/pyspark/sql/functions.py | 6 ------ .../apache/spark/sql/RelationalGroupedDataset.scala | 12 +++++------- 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 6fb1e0532f2ed..14041dd59bfa4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -202,12 +202,6 @@ def _(): """, } -_functions_2_2 = { - 'to_date': 'Converts a string date into a DateType using the (optionally) specified format.', - 'to_timestamp': 'Converts a string timestamp into a timestamp type using the ' + - '(optionally) specified format.', -} - # math functions that take two arguments as input _binary_mathfunctions = { 'atan2': """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index e2d0cdab6007f..7d2d7144c4360 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -88,8 +88,7 @@ class RelationalGroupedDataset protected[sql]( } private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction) - : DataFrame = { - + : DataFrame = { val columnExprs = if (colNames.isEmpty) { // No columns specified. Use all numeric columns. 
df.numericColumns @@ -109,8 +108,7 @@ class RelationalGroupedDataset protected[sql]( } private[this] def aggregateBooleanColumns(colNames: String*)(f: Expression => AggregateFunction) - : DataFrame = { - + : DataFrame = { val columnExprs = if (colNames.isEmpty) { // No columns specified. Use all numeric columns. df.booleanColumns @@ -322,7 +320,7 @@ class RelationalGroupedDataset protected[sql]( * The resulting `DataFrame` will also contain the grouping columns. * When specified columns are given, only compute the sum for them. * - * @since 2.2.0 + * @since 2.4.0 */ @scala.annotation.varargs def every(colNames: String*): DataFrame = { @@ -334,7 +332,7 @@ class RelationalGroupedDataset protected[sql]( * The resulting `DataFrame` will also contain the grouping columns. * When specified columns are given, only compute the sum for them. * - * @since 2.2.0 + * @since 2.4.0 */ @scala.annotation.varargs def any(colNames: String*): DataFrame = { @@ -346,7 +344,7 @@ class RelationalGroupedDataset protected[sql]( * The resulting `DataFrame` will also contain the grouping columns. * When specified columns are given, only compute the sum for them. * - * @since 2.2.0 + * @since 2.4.0 */ @scala.annotation.varargs def some(colNames: String*): DataFrame = { From 291a13d53a62a414020d4b5a9723255626218863 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 10 Aug 2018 14:40:48 -0700 Subject: [PATCH 11/13] python codestyle --- python/pyspark/sql/functions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 14041dd59bfa4..2e5a9b79e89c2 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -265,8 +265,6 @@ def _(): globals()[_name] = since(1.6)(_create_function(_name, _doc)) for _name, _doc in _functions_2_1.items(): globals()[_name] = since(2.1)(_create_function(_name, _doc)) -for _name, _doc in _functions_2_2.items(): - globals()[_name] = since(2.1)(_create_function(_name, _doc)) for _name, _message in _functions_deprecated.items(): globals()[_name] = _wrap_deprecated_function(globals()[_name], _message) for _name, _doc in _functions_2_4.items(): From b378fffd160a24d337fb3acbd63ce0d4db784afe Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Thu, 4 Oct 2018 22:01:02 -0700 Subject: [PATCH 12/13] Add null filtering logic to the aggregate function along with tests --- .../expressions/aggregate/AnyAgg.scala | 19 ++++++------ .../expressions/aggregate/Every.scala | 19 ++++++------ .../expressions/aggregate/AnyTestSuite.scala | 20 ++++++------ .../aggregate/EveryTestSuite.scala | 19 ++++++------ .../spark/sql/DataFrameAggregateSuite.scala | 31 +++++++++++++++++-- .../sql/DataFrameWindowFunctionsSuite.scala | 4 +-- 6 files changed, 71 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala index bf599fb23f4a2..1eac1e4ca3ffc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyAgg.scala @@ -40,24 +40,25 @@ case class AnyAgg(child: Expression) extends DeclarativeAggregate with ImplicitC private lazy val some = AttributeReference("some", BooleanType)() - private lazy val emptySet = AttributeReference("emptySet", BooleanType)() + private lazy val valueSet = AttributeReference("valueSet", BooleanType)() 
- override lazy val aggBufferAttributes = some :: emptySet :: Nil + override lazy val aggBufferAttributes = some :: valueSet :: Nil override lazy val initialValues: Seq[Expression] = Seq( - Literal(false), - Literal(true) + /* some = */ Literal.create(false, BooleanType), + /* valueSet = */ Literal.create(false, BooleanType) ) override lazy val updateExpressions: Seq[Expression] = Seq( - Or(some, Coalesce(Seq(child, Literal(false)))), - Literal(false) + /* some = */ Or(some, If (child.isNull, some, child)), + /* valueSet = */ valueSet || child.isNotNull ) override lazy val mergeExpressions: Seq[Expression] = Seq( - Or(some.left, some.right), - And(emptySet.left, emptySet.right) + /* some = */ Or(some.left, some.right), + /* valueSet */ valueSet.right || valueSet.left ) - override lazy val evaluateExpression: Expression = And(!emptySet, some) + override lazy val evaluateExpression: Expression = + If (valueSet, some, Literal.create(null, BooleanType)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala index 508fc3dae4939..ad0301a3200e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Every.scala @@ -40,24 +40,25 @@ case class Every(child: Expression) extends DeclarativeAggregate with ImplicitCa private lazy val every = AttributeReference("every", BooleanType)() - private lazy val emptySet = AttributeReference("emptySet", BooleanType)() + private lazy val valueSet = AttributeReference("valueSet", BooleanType)() - override lazy val aggBufferAttributes = every :: emptySet :: Nil + override lazy val aggBufferAttributes = every :: valueSet :: Nil override lazy val initialValues: Seq[Expression] = Seq( - Literal(true), - Literal(true) + /* every = */ Literal.create(true, BooleanType), + /* valueSet = */ Literal.create(false, BooleanType) ) override lazy val updateExpressions: Seq[Expression] = Seq( - And(every, Coalesce(Seq(child, Literal(false)))), - Literal(false) + /* every = */ And(every, If (child.isNull, every, child)), + /* valueSet = */ valueSet || child.isNotNull ) override lazy val mergeExpressions: Seq[Expression] = Seq( - And(every.left, every.right), - And(emptySet.left, emptySet.right) + /* every = */ And(every.left, every.right), + /* valueSet */ valueSet.right || valueSet.left ) - override lazy val evaluateExpression: Expression = And(!emptySet, every) + override lazy val evaluateExpression: Expression = + If (valueSet, every, Literal.create(null, BooleanType)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala index b0f5962ac27b2..2f389e608113a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/AnyTestSuite.scala @@ -26,7 +26,7 @@ class AnyTestSuite extends SparkFunSuite { val evaluator = DeclarativeAggregateEvaluator(AnyAgg(input), Seq(input)) test("empty buffer") { - assert(evaluator.initialize() === InternalRow(false, true)) + assert(evaluator.initialize() === InternalRow(false, false)) } test("update") { @@ -34,34 +34,35 @@ class AnyTestSuite extends SparkFunSuite { InternalRow(true), 
InternalRow(false), InternalRow(true)) - assert(result === InternalRow(true, false)) + assert(result === InternalRow(true, true)) } test("merge") { // Empty merge val p0 = evaluator.initialize() - assert(evaluator.merge(p0) === InternalRow(false, true)) + assert(evaluator.merge(p0) === InternalRow(false, false)) // Single merge val p1 = evaluator.update(InternalRow(true), InternalRow(true)) - assert(evaluator.merge(p1) === InternalRow(true, false)) + assert(evaluator.merge(p1) === InternalRow(true, true)) // Multiple merges. val p2 = evaluator.update(InternalRow(false), InternalRow(null)) - assert(evaluator.merge(p1, p2) === InternalRow(true, false)) + assert(evaluator.merge(p1, p2) === InternalRow(true, true)) // Empty partitions (p0 is empty) - assert(evaluator.merge(p0, p2) === InternalRow(false, false)) - assert(evaluator.merge(p2, p1, p0) === InternalRow(true, false)) + assert(evaluator.merge(p0, p2) === InternalRow(false, true)) + assert(evaluator.merge(p2, p1, p0) === InternalRow(true, true)) } test("eval") { // Null Eval - assert(evaluator.eval(InternalRow(null, true)) === InternalRow(false)) + assert(evaluator.eval(InternalRow(true, false)) === InternalRow(null)) + assert(evaluator.eval(InternalRow(false, false)) === InternalRow(null)) // Empty Eval val p0 = evaluator.initialize() - assert(evaluator.eval(p0) === InternalRow(false)) + assert(evaluator.eval(p0) === InternalRow(null)) // Update - Eval val p1 = evaluator.update(InternalRow(true), InternalRow(null)) @@ -76,5 +77,4 @@ class AnyTestSuite extends SparkFunSuite { val m2 = evaluator.merge(p2, p1, p0) assert(evaluator.eval(m2) === InternalRow(true)) } - } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala index c01fcb11fee78..109eed85e1208 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/EveryTestSuite.scala @@ -26,7 +26,7 @@ class EveryTestSuite extends SparkFunSuite { val evaluator = DeclarativeAggregateEvaluator(Every(input), Seq(input)) test("empty buffer") { - assert(evaluator.initialize() === InternalRow(true, true)) + assert(evaluator.initialize() === InternalRow(true, false)) } test("update") { @@ -34,34 +34,35 @@ class EveryTestSuite extends SparkFunSuite { InternalRow(true), InternalRow(false), InternalRow(true)) - assert(result === InternalRow(false, false)) + assert(result === InternalRow(false, true)) } test("merge") { // Empty merge val p0 = evaluator.initialize() - assert(evaluator.merge(p0) === InternalRow(true, true)) + assert(evaluator.merge(p0) === InternalRow(true, false)) // Single merge val p1 = evaluator.update(InternalRow(true), InternalRow(true)) - assert(evaluator.merge(p1) === InternalRow(true, false)) + assert(evaluator.merge(p1) === InternalRow(true, true)) // Multiple merges. 
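    // Illustrative trace under the Every buffer layout above (a sketch, not an assertion
    // from the patch): p1 has seen two non-null `true` rows, so its buffer is
    // (every = true, valueSet = true); p2 below sees `true` and a null, which the update
    // expression ignores, so it is also (true, true). Merging ANDs `every` and ORs
    // `valueSet`, yielding (true, true) regardless of merge order.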
val p2 = evaluator.update(InternalRow(true), InternalRow(null)) - assert(evaluator.merge(p1, p2) === InternalRow(false, false)) + assert(evaluator.merge(p1, p2) === InternalRow(true, true)) // Empty partitions (p0 is empty) - assert(evaluator.merge(p1, p0, p2) === InternalRow(false, false)) - assert(evaluator.merge(p2, p1, p0) === InternalRow(false, false)) + assert(evaluator.merge(p1, p0, p2) === InternalRow(true, true)) + assert(evaluator.merge(p2, p1, p0) === InternalRow(true, true)) } test("eval") { // Null Eval - assert(evaluator.eval(InternalRow(null, true)) === InternalRow(false)) + assert(evaluator.eval(InternalRow(true, false)) === InternalRow(null)) + assert(evaluator.eval(InternalRow(false, false)) === InternalRow(null)) // Empty Eval val p0 = evaluator.initialize() - assert(evaluator.eval(p0) === InternalRow(false)) + assert(evaluator.eval(p0) === InternalRow(null)) // Update - Eval val p1 = evaluator.update(InternalRow(true), InternalRow(true)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 7e290395cecf1..6303933ebb8de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -736,11 +736,23 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(1, false), Row(2, true), Row(3, false))) } + test("every null values") { + val df = Seq[(java.lang.Integer, java.lang.Boolean)]( + (1, true), (1, false), + (2, true), + (3, false), (3, null), + (4, null), (4, null)) + .toDF("a", "b") + checkAnswer( + df.groupBy("a").agg(every('b)), + Seq(Row(1, false), Row(2, true), Row(3, false), Row(4, null))) + } + test("every empty table") { val df = Seq.empty[(Int, Boolean)].toDF("a", "b") checkAnswer( df.agg(every('b)), - Seq(Row(false))) + Seq(Row(null))) } test("any") { @@ -755,7 +767,22 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { val df = Seq.empty[(Int, Boolean)].toDF("a", "b") checkAnswer( df.agg(any('b)), - Seq(Row(false))) + Seq(Row(null))) + } + + test("any/some null values") { + val df = Seq[(java.lang.Integer, java.lang.Boolean)] ( + (1, true), (1, false), + (2, true), + (3, true), (3, false), (3, null), + (4, null), (4, null)) + .toDF("a", "b") + checkAnswer( + df.groupBy("a").agg(any('b)), + Seq(Row(1, true), Row(2, true), Row(3, true), Row(4, null))) + checkAnswer( + df.groupBy("a").agg(some('b)), + Seq(Row(1, true), Row(2, true), Row(3, true), Row(4, null))) } test("some") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 3de789e72d7c2..7bf270719c3b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -709,8 +709,8 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { Row("b", true, true, true, true), Row("b", true, true, true, true), Row("c", false, false, false, false), - Row("d", true, false, true, true), - Row("d", null, false, true, true) + Row("d", true, true, true, true), + Row("d", null, true, true, true) ) ) } From e1764dfe284fe21f8ece17239cd964db64b0dc92 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 5 Oct 2018 06:05:10 -0700 Subject: [PATCH 13/13] Fix --- python/pyspark/sql/tests.py | 
6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 0c35db931500a..6c07013596ffa 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3984,8 +3984,8 @@ def test_window_functions_every_any(self): ("b", True, True, True, True), ("b", True, True, True, True), ("c", False, False, False, False), - ("d", None, False, True, True), - ("d", True, False, True, True) + ("d", None, True, True, True), + ("d", True, True, True, True) ] for r, ex in zip(rs, expected): self.assertEqual(tuple(r), ex[:len(r)]) @@ -4013,7 +4013,7 @@ def test_window_functions_every_any_without_partitionBy(self): F.some("value").over(w2)) rs = sel.collect() expected = [ - (None, False, False, False, False, False, False), + (None, None, None, None, None, None, None), (False, False, False, False, False, False, False), (False, False, False, False, False, False, False), (False, False, False, False, False, False, False),