diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 6471d15b63ab..bf41ada97916 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1932,6 +1932,14 @@ object functions {
    */
   def try_divide(left: Column, right: Column): Column = Column.fn("try_divide", left, right)
 
+  /**
+   * Returns the remainder of `dividend``/``divisor`. Its result is always null if `divisor` is 0.
+   *
+   * @group math_funcs
+   * @since 4.0.0
+   */
+  def try_remainder(left: Column, right: Column): Column = Column.fn("try_remainder", left, right)
+
   /**
    * Returns `left``*``right` and the result is null on overflow. The acceptable input types are
    * the same with the `*` operator.
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListenerBus.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListenerBus.scala
index b62a89b8417e..56a9e19a1b78 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListenerBus.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListenerBus.scala
@@ -121,8 +121,10 @@ class StreamingQueryListenerBus(sparkSession: SparkSession) extends Logging {
         }
       } catch {
         case e: Exception =>
-          logWarning("StreamingQueryListenerBus Handler thread received exception, all client" +
-            " side listeners are removed and handler thread is terminated.", e)
+          logWarning(
+            "StreamingQueryListenerBus Handler thread received exception, all client" +
+              " side listeners are removed and handler thread is terminated.",
+            e)
           lock.synchronized {
             executionThread = Option.empty
             listeners.forEach(remove(_))
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index c89dba03ed69..b6c4c42edf4a 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -451,8 +451,7 @@ object CheckConnectJvmClientCompatibility {
         "org.apache.spark.sql.streaming.RemoteStreamingQuery$"),
       // Skip client side listener specific class
       ProblemFilters.exclude[MissingClassProblem](
-        "org.apache.spark.sql.streaming.StreamingQueryListenerBus"
-      ),
+        "org.apache.spark.sql.streaming.StreamingQueryListenerBus"),
 
       // Encoders are in the wrong JAR
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Encoders"),
diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index 84416ffd5f83..bf9414675846 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -374,6 +374,7 @@ When ANSI mode is on, it throws exceptions for invalid operations. You can use t
   - `try_subtract`: identical to the add operator `-`, except that it returns `NULL` result instead of throwing an exception on integral value overflow.
   - `try_multiply`: identical to the add operator `*`, except that it returns `NULL` result instead of throwing an exception on integral value overflow.
   - `try_divide`: identical to the division operator `/`, except that it returns `NULL` result instead of throwing an exception on dividing 0.
+  - `try_remainder`: identical to the remainder operator `%`, except that it returns `NULL` result instead of throwing an exception on dividing 0.
   - `try_sum`: identical to the function `sum`, except that it returns `NULL` result instead of throwing an exception on integral/decimal/interval value overflow.
   - `try_avg`: identical to the function `avg`, except that it returns `NULL` result instead of throwing an exception on decimal/interval value overflow.
   - `try_element_at`: identical to the function `element_at`, except that it returns `NULL` result instead of throwing an exception on array's index out of bound.
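The doc bullet above pins down the contract this patch implements: `try_remainder` behaves like the `%` operator except that a zero divisor yields `NULL` rather than an ANSI-mode error. A minimal sketch of that contract, not part of the patch, assuming a local SparkSession built with this change (the object name is illustrative):

    import org.apache.spark.sql.SparkSession

    object TryRemainderAnsiSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()
        spark.conf.set("spark.sql.ansi.enabled", "true")
        // With ANSI on, `1 % 0` raises DIVIDE_BY_ZERO; try_remainder returns NULL instead.
        spark.sql("SELECT try_remainder(3, 2) AS ok, try_remainder(1, 0) AS zero").show()
        // +---+----+
        // | ok|zero|
        // +---+----+
        // |  1|NULL|
        // +---+----+
        spark.stop()
      }
    }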
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index ab3bcfcba4d3..ea4fa2ba967b 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -934,6 +934,13 @@ def try_divide(left: "ColumnOrName", right: "ColumnOrName") -> Column:
 try_divide.__doc__ = pysparkfuncs.try_divide.__doc__
 
 
+def try_remainder(left: "ColumnOrName", right: "ColumnOrName") -> Column:
+    return _invoke_function_over_columns("try_remainder", left, right)
+
+
+try_remainder.__doc__ = pysparkfuncs.try_remainder.__doc__
+
+
 def try_multiply(left: "ColumnOrName", right: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("try_multiply", left, right)
 
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 64f39e352642..bca86d907567 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -638,7 +638,7 @@ def try_divide(left: "ColumnOrName", right: "ColumnOrName") -> Column:
     |                                          4 months|
     +--------------------------------------------------+
 
-    Example 3: Exception druing division, resulting in NULL when ANSI mode is on
+    Example 3: Exception during division, resulting in NULL when ANSI mode is on
 
     >>> import pyspark.sql.functions as sf
     >>> origin = spark.conf.get("spark.sql.ansi.enabled")
@@ -657,6 +657,56 @@ def try_divide(left: "ColumnOrName", right: "ColumnOrName") -> Column:
     return _invoke_function_over_columns("try_divide", left, right)
 
 
+@_try_remote_functions
+def try_remainder(left: "ColumnOrName", right: "ColumnOrName") -> Column:
+    """
+    Returns the remainder after `dividend`/`divisor`. Its result is
+    always null if `divisor` is 0.
+
+    .. versionadded:: 4.0.0
+
+    Parameters
+    ----------
+    left : :class:`~pyspark.sql.Column` or str
+        dividend
+    right : :class:`~pyspark.sql.Column` or str
+        divisor
+
+    Examples
+    --------
+    Example 1: Integer divided by Integer.
+
+    >>> import pyspark.sql.functions as sf
+    >>> spark.createDataFrame(
+    ...     [(6000, 15), (3, 2), (1234, 0)], ["a", "b"]
+    ... ).select(sf.try_remainder("a", "b")).show()
+    +-------------------+
+    |try_remainder(a, b)|
+    +-------------------+
+    |                  0|
+    |                  1|
+    |               NULL|
+    +-------------------+
+
+    Example 2: Exception during division, resulting in NULL when ANSI mode is on
+
+    >>> import pyspark.sql.functions as sf
+    >>> origin = spark.conf.get("spark.sql.ansi.enabled")
+    >>> spark.conf.set("spark.sql.ansi.enabled", "true")
+    >>> try:
+    ...     df = spark.range(1)
+    ...     df.select(sf.try_remainder(df.id, sf.lit(0))).show()
+    ... finally:
+    ...     spark.conf.set("spark.sql.ansi.enabled", origin)
+    +--------------------+
+    |try_remainder(id, 0)|
+    +--------------------+
+    |                NULL|
+    +--------------------+
+    """
+    return _invoke_function_over_columns("try_remainder", left, right)
+
+
 @_try_remote_functions
 def try_multiply(left: "ColumnOrName", right: "ColumnOrName") -> Column:
     """
spark.conf.set("spark.sql.ansi.enabled", origin) + +--------------------+ + |try_remainder(id, 0)| + +--------------------+ + | NULL| + +--------------------+ + """ + return _invoke_function_over_columns("try_remainder", left, right) + + @_try_remote_functions def try_multiply(left: "ColumnOrName", right: "ColumnOrName") -> Column: """ diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py index 5b4c0275028b..a9e3adb972e9 100644 --- a/python/pyspark/sql/tests/connect/test_connect_column.py +++ b/python/pyspark/sql/tests/connect/test_connect_column.py @@ -772,8 +772,8 @@ def test_column_accessor(self): sdf.select(sdf.z[0], sdf.z[1], sdf["z"][2]).toPandas(), ) self.assert_eq( - cdf.select(CF.col("z")[0], cdf.z[10], CF.col("z")[-10]).toPandas(), - sdf.select(SF.col("z")[0], sdf.z[10], SF.col("z")[-10]).toPandas(), + cdf.select(CF.col("z")[0], CF.get(cdf.z, 10), CF.get(CF.col("z"), -10)).toPandas(), + sdf.select(SF.col("z")[0], SF.get(sdf.z, 10), SF.get(SF.col("z"), -10)).toPandas(), ) self.assert_eq( cdf.select(cdf.z.getItem(0), cdf.z.getItem(1), cdf["z"].getField(2)).toPandas(), @@ -824,8 +824,12 @@ def test_column_arithmetic_ops(self): ) self.assert_eq( - cdf.select(cdf.a % cdf["b"], cdf["a"] % 2, 12 % cdf.c).toPandas(), - sdf.select(sdf.a % sdf["b"], sdf["a"] % 2, 12 % sdf.c).toPandas(), + cdf.select( + cdf.a % cdf["b"], cdf["a"] % 2, CF.try_remainder(CF.lit(12), cdf.c) + ).toPandas(), + sdf.select( + sdf.a % sdf["b"], sdf["a"] % 2, SF.try_remainder(SF.lit(12), sdf.c) + ).toPandas(), ) self.assert_eq( @@ -1022,13 +1026,9 @@ def test_distributed_sequence_id(self): if __name__ == "__main__": - import os import unittest from pyspark.sql.tests.connect.test_connect_column import * # noqa: F401 - # TODO(SPARK-41794): Enable ANSI mode in this file. - os.environ["SPARK_ANSI_SQL_MODE"] = "false" - try: import xmlrunner diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 6565591b7952..a993462b14e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -451,6 +451,7 @@ object FunctionRegistry { // "try_*" function which always return Null instead of runtime error. expression[TryAdd]("try_add"), expression[TryDivide]("try_divide"), + expression[TryRemainder]("try_remainder"), expression[TrySubtract]("try_subtract"), expression[TryMultiply]("try_multiply"), expression[TryElementAt]("try_element_at"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala index 4eacd3442ed5..05eafe01906a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TryEval.scala @@ -132,6 +132,43 @@ case class TryDivide(left: Expression, right: Expression, replacement: Expressio } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(dividend, divisor) - Returns the remainder after `expr1`/`expr2`. " + + "`dividend` must be a numeric. 
`divisor` must be a numeric.", + examples = """ + Examples: + > SELECT _FUNC_(3, 2); + 1 + > SELECT _FUNC_(2L, 2L); + 0 + > SELECT _FUNC_(3.0, 2.0); + 1.0 + > SELECT _FUNC_(1, 0); + NULL + """, + since = "4.0.0", + group = "math_funcs") +// scalastyle:on line.size.limit +case class TryRemainder(left: Expression, right: Expression, replacement: Expression) + extends RuntimeReplaceable with InheritAnalysisRules { + def this(left: Expression, right: Expression) = this(left, right, + (left.dataType, right.dataType) match { + case (_: NumericType, _: NumericType) => Remainder(left, right, EvalMode.TRY) + // TODO: support TRY eval mode on datetime arithmetic expressions. + case _ => TryEval(Remainder(left, right, EvalMode.ANSI)) + } + ) + + override def prettyName: String = "try_remainder" + + override def parameters: Seq[Expression] = Seq(left, right) + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(replacement = newChild) + } +} + @ExpressionDescription( usage = "_FUNC_(expr1, expr2) - Returns `expr1`-`expr2` and the result is null on overflow. " + "The acceptable input types are the same with the `-` operator.", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index a085a4e3a8a3..65ab42f859ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -906,6 +906,10 @@ case class Remainder( override def inputType: AbstractDataType = NumericType + // `try_remainder` has exactly the same behavior as the legacy divide, so here it only executes + // the error code path when `evalMode` is `ANSI`. + protected override def failOnError: Boolean = evalMode == EvalMode.ANSI + override def symbol: String = "%" override def decimalMethod: String = "remainder" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala index 780a2692e87f..e082f2e3accc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TryEvalSuite.scala @@ -46,6 +46,19 @@ class TryEvalSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("try_remainder") { + Seq( + (3.0, 2.0, 1.0), + (1.0, 0.0, null), + (-1.0, 0.0, null) + ).foreach { case (a, b, expected) => + val left = Literal(a) + val right = Literal(b) + val input = Remainder(left, right, EvalMode.TRY) + checkEvaluation(input, expected) + } + } + test("try_subtract") { Seq( (1, 1, 0), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e6aed3f7a43d..d1b449bf27aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1937,6 +1937,15 @@ object functions { */ def try_divide(left: Column, right: Column): Column = Column.fn("try_divide", left, right) + /** + * Returns the remainder of `dividend``/``divisor`. Its result is + * always null if `divisor` is 0. 
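A hedged usage sketch for the new Scala DSL function above, assuming a local SparkSession built with this change; it mirrors the `MathFunctionsSuite` test that follows, and the object name is illustrative:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, try_remainder}

    object TryRemainderDslSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").getOrCreate()
        import spark.implicits._
        val df = Seq((10, 3), (5, 5), (5, 0)).toDF("birth", "age")
        // Rows evaluate to 1, 0 and NULL; the zero divisor no longer raises.
        df.select(try_remainder(col("birth"), col("age"))).show()
        spark.stop()
      }
    }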
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index dd223939a184..ca864dddf19b 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -349,6 +349,7 @@
 | org.apache.spark.sql.catalyst.expressions.TryElementAt | try_element_at | SELECT try_element_at(array(1, 2, 3), 2) | struct<try_element_at(array(1, 2, 3), 2):int> |
 | org.apache.spark.sql.catalyst.expressions.TryMultiply | try_multiply | SELECT try_multiply(2, 3) | struct<try_multiply(2, 3):int> |
 | org.apache.spark.sql.catalyst.expressions.TryReflect | try_reflect | SELECT try_reflect('java.util.UUID', 'randomUUID') | struct<try_reflect(java.util.UUID, randomUUID):string> |
+| org.apache.spark.sql.catalyst.expressions.TryRemainder | try_remainder | SELECT try_remainder(3, 2) | struct<try_remainder(3, 2):int> |
 | org.apache.spark.sql.catalyst.expressions.TrySubtract | try_subtract | SELECT try_subtract(2, 1) | struct<try_subtract(2, 1):int> |
 | org.apache.spark.sql.catalyst.expressions.TryToBinary | try_to_binary | SELECT try_to_binary('abc', 'utf-8') | struct<try_to_binary(abc, utf-8):binary> |
 | org.apache.spark.sql.catalyst.expressions.TryToNumber | try_to_number | SELECT try_to_number('454', '999') | struct<try_to_number(454, 999):decimal(3,0)> |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
index ba04e3b691a1..ac14b345a762 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathFunctionsSuite.scala
@@ -707,6 +707,17 @@ class MathFunctionsSuite extends QueryTest with SharedSparkSession {
       df1.select(try_divide(make_interval(col("year"), col("month")), lit(0))))
   }
 
+  test("try_remainder") {
+    val df = Seq((10, 3), (5, 5), (5, 0)).toDF("birth", "age")
+    checkAnswer(df.selectExpr("try_remainder(birth, age)"), Seq(Row(1), Row(0), Row(null)))
+
+    val dfDecimal = Seq(
+      (BigDecimal(10), BigDecimal(3)),
+      (BigDecimal(5), BigDecimal(5)),
+      (BigDecimal(5), BigDecimal(0))).toDF("birth", "age")
+    checkAnswer(dfDecimal.selectExpr("try_remainder(birth, age)"), Seq(Row(1), Row(0), Row(null)))
+  }
+
   test("try_element_at") {
     val df = Seq((Array(1, 2, 3), 2)).toDF("a", "b")
     checkAnswer(df.selectExpr("try_element_at(a, b)"), Seq(Row(2)))
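The `failOnError` override added to `Remainder` in `arithmetic.scala` is the entire runtime switch: the operator takes the error path only when `evalMode` is `ANSI`, so `TRY` (like the legacy mode) yields null for a zero divisor. A standalone toy model of that gate, not Spark code:

    object FailOnErrorGateModel {
      sealed trait EvalMode
      case object Legacy extends EvalMode
      case object Ansi extends EvalMode
      case object Try extends EvalMode

      def remainder(a: Int, b: Int, evalMode: EvalMode): Option[Int] = {
        // Mirrors `failOnError: Boolean = evalMode == EvalMode.ANSI` in the patch.
        val failOnError = evalMode == Ansi
        if (b == 0) {
          if (failOnError) throw new ArithmeticException("Division by zero")
          None
        } else Some(a % b)
      }

      def main(args: Array[String]): Unit = {
        println(remainder(5, 0, Try))    // None
        println(remainder(5, 0, Legacy)) // None
        // remainder(5, 0, Ansi) would throw ArithmeticException
      }
    }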