From b86e5d2ab1fb17f8dcbb5b4d50f3361494270438 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Wed, 18 Sep 2024 07:44:42 -0700 Subject: [PATCH 001/250] [SPARK-49495][DOCS][FOLLOWUP] Enable GitHub Pages settings via .asf.yml ### What changes were proposed in this pull request? A followup of SPARK-49495 to enable GitHub Pages settings via [.asf.yaml](https://cwiki.apache.org/confluence/pages/viewpage.action?spaceKey=INFRA&title=git+-+.asf.yaml+features#Git.asf.yamlfeatures-GitHubPages) ### Why are the changes needed? Meet the requirement for `actions/configure-pagesv5` action ``` Run actions/configure-pagesv5 with: token: *** enablement: false env: SPARK_TESTING: 1 RELEASE_VERSION: In-Progress JAVA_HOME: /opt/hostedtoolcache/Java_Zulu_jdk/17.0.1[2](https://github.com/apache/spark/actions/runs/10916383676/job/30297716064#step:10:2)-7/x64 JAVA_HOME_17_X64: /opt/hostedtoolcache/Java_Zulu_jdk/17.0.12-7/x64 pythonLocation: /opt/hostedtoolcache/Python/[3](https://github.com/apache/spark/actions/runs/10916383676/job/30297716064#step:10:3).9.19/x64 PKG_CONFIG_PATH: /opt/hostedtoolcache/Python/3.9.19/x6[4](https://github.com/apache/spark/actions/runs/10916383676/job/30297716064#step:10:4)/lib/pkgconfig Python_ROOT_DIR: /opt/hostedtoolcache/Python/3.9.19/x[6](https://github.com/apache/spark/actions/runs/10916383676/job/30297716064#step:10:6)4 Python2_ROOT_DIR: /opt/hostedtoolcache/Python/3.9.19/x64 Python3_ROOT_DIR: /opt/hostedtoolcache/Python/3.[9](https://github.com/apache/spark/actions/runs/10916383676/job/30297716064#step:10:9).19/x64 LD_LIBRARY_PATH: /opt/hostedtoolcache/Python/3.9.19/x64/lib Error: Get Pages site failed. Please verify that the repository has Pages enabled and configured to build using GitHub Actions, or consider exploring the `enablement` parameter for this action. Error: Not Found - https://docs.github.com/rest/pages/pages#get-a-apiname-pages-site Error: HttpError: Not Found - https://docs.github.com/rest/pages/pages#get-a-apiname-pages-site ``` ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? NA ### Was this patch authored or co-authored using generative AI tooling? no Closes #48141 from yaooqinn/SPARK-49495-FF. Authored-by: Kent Yao Signed-off-by: Dongjoon Hyun --- .asf.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.asf.yaml b/.asf.yaml index 22042b355b2fa..91a5f9b2bb1a2 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -31,6 +31,8 @@ github: merge: false squash: true rebase: true + ghp_branch: master + ghp_path: /docs/_site notifications: pullrequests: reviews@spark.apache.org From ed3a9b1aa92957015592b399167a960b68b73beb Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 18 Sep 2024 09:28:09 -0700 Subject: [PATCH 002/250] [SPARK-49691][PYTHON][CONNECT] Function `substring` should accept column names ### What changes were proposed in this pull request? Function `substring` should accept column names ### Why are the changes needed? 
Bug fix: ``` In [1]: >>> import pyspark.sql.functions as sf ...: >>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l']) ...: >>> df.select('*', sf.substring('s', 'p', 'l')).show() ``` works in PySpark Classic, but fail in Connect with: ``` NumberFormatException Traceback (most recent call last) Cell In[2], line 1 ----> 1 df.select('*', sf.substring('s', 'p', 'l')).show() File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:1170, in DataFrame.show(self, n, truncate, vertical) 1169 def show(self, n: int = 20, truncate: Union[bool, int] = True, vertical: bool = False) -> None: -> 1170 print(self._show_string(n, truncate, vertical)) File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:927, in DataFrame._show_string(self, n, truncate, vertical) 910 except ValueError: 911 raise PySparkTypeError( 912 errorClass="NOT_BOOL", 913 messageParameters={ (...) 916 }, 917 ) 919 table, _ = DataFrame( 920 plan.ShowString( 921 child=self._plan, 922 num_rows=n, 923 truncate=_truncate, 924 vertical=vertical, 925 ), 926 session=self._session, --> 927 )._to_table() 928 return table[0][0].as_py() File ~/Dev/spark/python/pyspark/sql/connect/dataframe.py:1844, in DataFrame._to_table(self) 1842 def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]: 1843 query = self._plan.to_proto(self._session.client) -> 1844 table, schema, self._execution_info = self._session.client.to_table( 1845 query, self._plan.observations 1846 ) 1847 assert table is not None 1848 return (table, schema) File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:892, in SparkConnectClient.to_table(self, plan, observations) 890 req = self._execute_plan_request_with_metadata() 891 req.plan.CopyFrom(plan) --> 892 table, schema, metrics, observed_metrics, _ = self._execute_and_fetch(req, observations) 894 # Create a query execution object. 895 ei = ExecutionInfo(metrics, observed_metrics) File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1517, in SparkConnectClient._execute_and_fetch(self, req, observations, self_destruct) 1514 properties: Dict[str, Any] = {} 1516 with Progress(handlers=self._progress_handlers, operation_id=req.operation_id) as progress: -> 1517 for response in self._execute_and_fetch_as_iterator( 1518 req, observations, progress=progress 1519 ): 1520 if isinstance(response, StructType): 1521 schema = response File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1494, in SparkConnectClient._execute_and_fetch_as_iterator(self, req, observations, progress) 1492 raise kb 1493 except Exception as error: -> 1494 self._handle_error(error) File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1764, in SparkConnectClient._handle_error(self, error) 1762 self.thread_local.inside_error_handling = True 1763 if isinstance(error, grpc.RpcError): -> 1764 self._handle_rpc_error(error) 1765 elif isinstance(error, ValueError): 1766 if "Cannot invoke RPC" in str(error) and "closed" in str(error): File ~/Dev/spark/python/pyspark/sql/connect/client/core.py:1840, in SparkConnectClient._handle_rpc_error(self, rpc_error) 1837 if info.metadata["errorClass"] == "INVALID_HANDLE.SESSION_CHANGED": 1838 self._closed = True -> 1840 raise convert_exception( 1841 info, 1842 status.message, 1843 self._fetch_enriched_error(info), 1844 self._display_server_stack_trace(), 1845 ) from None 1847 raise SparkConnectGrpcException(status.message) from None 1848 else: NumberFormatException: [CAST_INVALID_INPUT] The value 'p' of the type "STRING" cannot be cast to "INT" because it is malformed. 
Correct the value as per the syntax, or change its target type. Use `try_cast` to tolerate malformed input and return NULL instead. SQLSTATE: 22018 ... ``` ### Does this PR introduce _any_ user-facing change? yes, Function `substring` in Connect can properly handle column names ### How was this patch tested? new doctests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48135 from zhengruifeng/py_substring_fix. Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- .../pyspark/sql/connect/functions/builtin.py | 10 ++- python/pyspark/sql/functions/builtin.py | 63 ++++++++++++++++--- 2 files changed, 62 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 031e7c22542d2..2870d9c408b6b 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -2488,8 +2488,14 @@ def sentences( sentences.__doc__ = pysparkfuncs.sentences.__doc__ -def substring(str: "ColumnOrName", pos: int, len: int) -> Column: - return _invoke_function("substring", _to_col(str), lit(pos), lit(len)) +def substring( + str: "ColumnOrName", + pos: Union["ColumnOrName", int], + len: Union["ColumnOrName", int], +) -> Column: + _pos = lit(pos) if isinstance(pos, int) else _to_col(pos) + _len = lit(len) if isinstance(len, int) else _to_col(len) + return _invoke_function("substring", _to_col(str), _pos, _len) substring.__doc__ = pysparkfuncs.substring.__doc__ diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 781bf3d9f83a2..c0730b193bc72 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -11309,7 +11309,9 @@ def sentences( @_try_remote_functions def substring( - str: "ColumnOrName", pos: Union["ColumnOrName", int], len: Union["ColumnOrName", int] + str: "ColumnOrName", + pos: Union["ColumnOrName", int], + len: Union["ColumnOrName", int], ) -> Column: """ Substring starts at `pos` and is of length `len` when str is String type or @@ -11348,16 +11350,59 @@ def substring( Examples -------- + Example 1: Using literal integers as arguments + + >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([('abcd',)], ['s',]) - >>> df.select(substring(df.s, 1, 2).alias('s')).collect() - [Row(s='ab')] + >>> df.select('*', sf.substring(df.s, 1, 2)).show() + +----+------------------+ + | s|substring(s, 1, 2)| + +----+------------------+ + |abcd| ab| + +----+------------------+ + + Example 2: Using columns as arguments + + >>> import pyspark.sql.functions as sf + >>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l']) + >>> df.select('*', sf.substring(df.s, 2, df.l)).show() + +-----+---+---+------------------+ + | s| p| l|substring(s, 2, l)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ + + >>> df.select('*', sf.substring(df.s, df.p, 3)).show() + +-----+---+---+------------------+ + | s| p| l|substring(s, p, 3)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ + + >>> df.select('*', sf.substring(df.s, df.p, df.l)).show() + +-----+---+---+------------------+ + | s| p| l|substring(s, p, l)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ + + Example 3: Using column names as arguments + + >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l']) - >>> 
df.select(substring(df.s, 2, df.l).alias('s')).collect() - [Row(s='par')] - >>> df.select(substring(df.s, df.p, 3).alias('s')).collect() - [Row(s='par')] - >>> df.select(substring(df.s, df.p, df.l).alias('s')).collect() - [Row(s='par')] + >>> df.select('*', sf.substring(df.s, 2, 'l')).show() + +-----+---+---+------------------+ + | s| p| l|substring(s, 2, l)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ + + >>> df.select('*', sf.substring('s', 'p', 'l')).show() + +-----+---+---+------------------+ + | s| p| l|substring(s, p, l)| + +-----+---+---+------------------+ + |Spark| 2| 3| par| + +-----+---+---+------------------+ """ pos = _enum_to_value(pos) pos = lit(pos) if isinstance(pos, int) else pos From fbf81ebaef49baa4c19a936fb3884c2e62e6a49b Mon Sep 17 00:00:00 2001 From: xuping <13289341606@163.com> Date: Wed, 18 Sep 2024 22:06:00 +0200 Subject: [PATCH 003/250] [SPARK-47263][SQL] Assign names to the legacy conditions _LEGACY_ERROR_TEMP_13[44-46] ### What changes were proposed in this pull request? rename err class _LEGACY_ERROR_TEMP_13[44-46]: 44 removed, 45 to DEFAULT_UNSUPPORTED, 46 to ADD_DEFAULT_UNSUPPORTED ### Why are the changes needed? replace legacy err class name with semantically explicits. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Re run the UT class modified in the PR (org.apache.spark.sql.sources.InsertSuite & org.apache.spark.sql.types.StructTypeSuite) ### Was this patch authored or co-authored using generative AI tooling? No Closes #46320 from PaysonXu/SPARK-47263. Authored-by: xuping <13289341606@163.com> Signed-off-by: Max Gekk --- .../resources/error/error-conditions.json | 27 +++++++++---------- .../util/ResolveDefaultColumnsUtil.scala | 9 ++++--- .../sql/errors/QueryCompilationErrors.scala | 16 +++-------- .../spark/sql/types/StructTypeSuite.scala | 23 +++++++++++----- .../spark/sql/sources/InsertSuite.scala | 6 ++--- 5 files changed, 41 insertions(+), 40 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 25dd676c4aff9..6463cc2c12da7 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1,4 +1,10 @@ { + "ADD_DEFAULT_UNSUPPORTED" : { + "message" : [ + "Failed to execute command because DEFAULT values are not supported when adding new columns to previously existing target data source with table provider: \"\"." + ], + "sqlState" : "42623" + }, "AGGREGATE_FUNCTION_WITH_NONDETERMINISTIC_EXPRESSION" : { "message" : [ "Non-deterministic expression should not appear in the arguments of an aggregate function." @@ -1096,6 +1102,12 @@ ], "sqlState" : "42608" }, + "DEFAULT_UNSUPPORTED" : { + "message" : [ + "Failed to execute command because DEFAULT values are not supported for target data source with table provider: \"\"." + ], + "sqlState" : "42623" + }, "DISTINCT_WINDOW_FUNCTION_UNSUPPORTED" : { "message" : [ "Distinct window functions are not supported: ." @@ -6673,21 +6685,6 @@ "Sinks cannot request distribution and ordering in continuous execution mode." ] }, - "_LEGACY_ERROR_TEMP_1344" : { - "message" : [ - "Invalid DEFAULT value for column : fails to parse as a valid literal value." - ] - }, - "_LEGACY_ERROR_TEMP_1345" : { - "message" : [ - "Failed to execute command because DEFAULT values are not supported for target data source with table provider: \"\"." 
- ] - }, - "_LEGACY_ERROR_TEMP_1346" : { - "message" : [ - "Failed to execute command because DEFAULT values are not supported when adding new columns to previously existing target data source with table provider: \"\"." - ] - }, "_LEGACY_ERROR_TEMP_2000" : { "message" : [ ". If necessary set to false to bypass this error." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala index 8b7392e71249e..693ac8d94dbcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkThrowable, SparkUnsupportedOperationException} +import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException} import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.AnalysisException @@ -412,8 +412,11 @@ object ResolveDefaultColumns extends QueryErrorsBase case _: ExprLiteral | _: Cast => expr } } catch { - case _: AnalysisException | _: MatchError => - throw QueryCompilationErrors.failedToParseExistenceDefaultAsLiteral(field.name, text) + // AnalysisException thrown from analyze is already formatted, throw it directly. + case ae: AnalysisException => throw ae + case _: MatchError => + throw SparkException.internalError(s"parse existence default as literal err," + + s" field name: ${field.name}, value: $text") } // The expression should be a literal value by this point, possibly wrapped in a cast // function. This is enforced by the execution of commands that assign default values. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index f268ef85ef1dd..e324d4e9d2edb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3516,29 +3516,21 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "cond" -> toSQLExpr(cond))) } - def failedToParseExistenceDefaultAsLiteral(fieldName: String, defaultValue: String): Throwable = { - new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1344", - messageParameters = Map( - "fieldName" -> fieldName, - "defaultValue" -> defaultValue)) - } - def defaultReferencesNotAllowedInDataSource( statementType: String, dataSource: String): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1345", + errorClass = "DEFAULT_UNSUPPORTED", messageParameters = Map( - "statementType" -> statementType, + "statementType" -> toSQLStmt(statementType), "dataSource" -> dataSource)) } def addNewDefaultColumnToExistingTableNotAllowed( statementType: String, dataSource: String): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1346", + errorClass = "ADD_DEFAULT_UNSUPPORTED", messageParameters = Map( - "statementType" -> statementType, + "statementType" -> toSQLStmt(statementType), "dataSource" -> dataSource)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index 5ec1525bf9b61..6a67525dd02d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -564,7 +564,6 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { .putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY, "1 + 1") .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, "1 + 1") .build()))) - val error = "fails to parse as a valid literal value" assert(ResolveDefaultColumns.existenceDefaultValues(source2).length == 1) assert(ResolveDefaultColumns.existenceDefaultValues(source2)(0) == 2) @@ -576,9 +575,13 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { .putString(ResolveDefaultColumns.EXISTS_DEFAULT_COLUMN_METADATA_KEY, "invalid") .putString(ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, "invalid") .build()))) - assert(intercept[AnalysisException] { - ResolveDefaultColumns.existenceDefaultValues(source3) - }.getMessage.contains(error)) + + checkError( + exception = intercept[AnalysisException]{ + ResolveDefaultColumns.existenceDefaultValues(source3) + }, + condition = "INVALID_DEFAULT_VALUE.UNRESOLVED_EXPRESSION", + parameters = Map("statement" -> "", "colName" -> "`c1`", "defaultValue" -> "invalid")) // Negative test: StructType.defaultValues fails because the existence default value fails to // resolve. 
@@ -592,9 +595,15 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper { ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY, "(SELECT 'abc' FROM missingtable)") .build()))) - assert(intercept[AnalysisException] { - ResolveDefaultColumns.existenceDefaultValues(source4) - }.getMessage.contains(error)) + + checkError( + exception = intercept[AnalysisException]{ + ResolveDefaultColumns.existenceDefaultValues(source4) + }, + condition = "INVALID_DEFAULT_VALUE.SUBQUERY_EXPRESSION", + parameters = Map("statement" -> "", + "colName" -> "`c1`", + "defaultValue" -> "(SELECT 'abc' FROM missingtable)")) } test("SPARK-46629: Test STRUCT DDL with NOT NULL round trip") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 57655a58a694d..41447d8af5740 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -1998,7 +1998,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql(s"create table t(a string default 'abc') using parquet") }, - condition = "_LEGACY_ERROR_TEMP_1345", + condition = "DEFAULT_UNSUPPORTED", parameters = Map("statementType" -> "CREATE TABLE", "dataSource" -> "parquet")) withTable("t") { sql(s"create table t(a string, b int) using parquet") @@ -2006,7 +2006,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { exception = intercept[AnalysisException] { sql("alter table t add column s bigint default 42") }, - condition = "_LEGACY_ERROR_TEMP_1345", + condition = "DEFAULT_UNSUPPORTED", parameters = Map( "statementType" -> "ALTER TABLE ADD COLUMNS", "dataSource" -> "parquet")) @@ -2314,7 +2314,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { // provider is now in the denylist. sql(s"alter table t1 add column (b string default 'abc')") }, - condition = "_LEGACY_ERROR_TEMP_1346", + condition = "ADD_DEFAULT_UNSUPPORTED", parameters = Map( "statementType" -> "ALTER TABLE ADD COLUMNS", "dataSource" -> provider)) From a6f6e07b70311fb843670b89f6546ae675359feb Mon Sep 17 00:00:00 2001 From: Yuchen Liu Date: Wed, 18 Sep 2024 15:45:17 -0700 Subject: [PATCH 004/250] [SPARK-48939][AVRO] Support reading Avro with recursive schema reference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Continue the discussion from https://github.com/apache/spark/pull/47425 to this PR because I can't push to Yuchen's account ### What changes were proposed in this pull request? The builtin ProtoBuf connector first supports recursive schema reference. It is approached by letting users specify an option “recursive.fields.max.depth”, and at the start of the execution, unroll the recursive field by this level. It converts a problem of dynamic schema for each row to a fixed schema which is supported by Spark. Avro can just adopt a similar method. This PR defines an option "recursiveFieldMaxDepth" to both Avro data source and from_avro function. With this option, Spark can support Avro recursive schema up to certain depth. ### Why are the changes needed? Recursive reference denotes the case that the type of a field can be defined before in the parent nodes. 
A simple example is: ``` { "type": "record", "name": "LongList", "fields" : [ {"name": "value", "type": "long"}, {"name": "next", "type": ["null", "LongList"]} ] } ``` This is written in Avro Schema DSL and represents a linked list data structure. Spark currently will throw an error on this schema. Many users used schema like this, so we should support it. ### Does this PR introduce any user-facing change? Yes. Previously, it will throw error on recursive schemas like above. With this change, it will still throw the same error by default but when users specify the option to a number greater than 0, the schema will be unrolled to that depth. ### How was this patch tested? Added new unit tests and integration tests to AvroSuite and AvroFunctionSuite. ### Was this patch authored or co-authored using generative AI tooling? No. Co-authored-by: Wei Liu Closes #48043 from WweiL/yuchen-avro-recursive-schema. Lead-authored-by: Yuchen Liu Co-authored-by: Wei Liu Co-authored-by: Yuchen Liu <170372783+eason-yuchen-liu@users.noreply.github.com> Signed-off-by: Gengliang Wang --- .../org/apache/spark/internal/LogKey.scala | 2 + .../spark/sql/avro/AvroDataToCatalyst.scala | 6 +- .../spark/sql/avro/AvroDeserializer.scala | 12 +- .../spark/sql/avro/AvroFileFormat.scala | 3 +- .../apache/spark/sql/avro/AvroOptions.scala | 31 +++ .../org/apache/spark/sql/avro/AvroUtils.scala | 3 +- .../spark/sql/avro/SchemaConverters.scala | 198 +++++++++++----- .../v2/avro/AvroPartitionReaderFactory.scala | 3 +- .../AvroCatalystDataConversionSuite.scala | 3 +- .../spark/sql/avro/AvroFunctionsSuite.scala | 33 ++- .../spark/sql/avro/AvroRowReaderSuite.scala | 3 +- .../spark/sql/avro/AvroSerdeSuite.scala | 3 +- .../org/apache/spark/sql/avro/AvroSuite.scala | 223 +++++++++++++++++- docs/sql-data-sources-avro.md | 45 ++++ .../sql/errors/QueryCompilationErrors.scala | 7 + 15 files changed, 488 insertions(+), 87 deletions(-) diff --git a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala index a7e4f186000b5..12d456a371d07 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/LogKey.scala @@ -266,6 +266,7 @@ private[spark] object LogKeys { case object FEATURE_NAME extends LogKey case object FETCH_SIZE extends LogKey case object FIELD_NAME extends LogKey + case object FIELD_TYPE extends LogKey case object FILES extends LogKey case object FILE_ABSOLUTE_PATH extends LogKey case object FILE_END_OFFSET extends LogKey @@ -652,6 +653,7 @@ private[spark] object LogKeys { case object RECEIVER_IDS extends LogKey case object RECORDS extends LogKey case object RECOVERY_STATE extends LogKey + case object RECURSIVE_DEPTH extends LogKey case object REDACTED_STATEMENT extends LogKey case object REDUCE_ID extends LogKey case object REGEX extends LogKey diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 7d80998d96eb1..0b85b208242cb 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -42,7 +42,8 @@ private[sql] case class AvroDataToCatalyst( val dt = SchemaConverters.toSqlType( expectedSchema, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType).dataType + avroOptions.stableIdPrefixForUnionType, + 
avroOptions.recursiveFieldMaxDepth).dataType parseMode match { // With PermissiveMode, the output Catalyst row might contain columns of null values for // corrupt records, even if some of the columns are not nullable in the user-provided schema. @@ -69,7 +70,8 @@ private[sql] case class AvroDataToCatalyst( dataType, avroOptions.datetimeRebaseModeInRead, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType) + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth) @transient private var decoder: BinaryDecoder = _ diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 877c3f89e88c0..ac20614553ca2 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -51,14 +51,16 @@ private[sql] class AvroDeserializer( datetimeRebaseSpec: RebaseSpec, filters: StructFilters, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String) { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int) { def this( rootAvroType: Schema, rootCatalystType: DataType, datetimeRebaseMode: String, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String) = { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int) = { this( rootAvroType, rootCatalystType, @@ -66,7 +68,8 @@ private[sql] class AvroDeserializer( RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)), new NoopFilters, useStableIdForUnionType, - stableIdPrefixForUnionType) + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) } private lazy val decimalConversions = new DecimalConversion() @@ -128,7 +131,8 @@ private[sql] class AvroDeserializer( s"schema is incompatible (avroType = $avroType, sqlType = ${catalystType.sql})" val realDataType = SchemaConverters.toSqlType( - avroType, useStableIdForUnionType, stableIdPrefixForUnionType).dataType + avroType, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth).dataType (avroType.getType, catalystType) match { case (NULL, NullType) => (updater, ordinal, _) => diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala index 372f24b54f5c4..264c3a1f48abe 100755 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroFileFormat.scala @@ -145,7 +145,8 @@ private[sql] class AvroFileFormat extends FileFormat datetimeRebaseMode, avroFilters, parsedOptions.useStableIdForUnionType, - parsedOptions.stableIdPrefixForUnionType) + parsedOptions.stableIdPrefixForUnionType, + parsedOptions.recursiveFieldMaxDepth) override val stopPosition = file.start + file.length override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index 4332904339f19..e0c6ad3ee69d3 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import 
org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf /** @@ -136,6 +137,15 @@ private[sql] class AvroOptions( val stableIdPrefixForUnionType: String = parameters .getOrElse(STABLE_ID_PREFIX_FOR_UNION_TYPE, "member_") + + val recursiveFieldMaxDepth: Int = + parameters.get(RECURSIVE_FIELD_MAX_DEPTH).map(_.toInt).getOrElse(-1) + + if (recursiveFieldMaxDepth > RECURSIVE_FIELD_MAX_DEPTH_LIMIT) { + throw QueryCompilationErrors.avroOptionsException( + RECURSIVE_FIELD_MAX_DEPTH, + s"Should not be greater than $RECURSIVE_FIELD_MAX_DEPTH_LIMIT.") + } } private[sql] object AvroOptions extends DataSourceOptions { @@ -170,4 +180,25 @@ private[sql] object AvroOptions extends DataSourceOptions { // When STABLE_ID_FOR_UNION_TYPE is enabled, the option allows to configure the prefix for fields // of Avro Union type. val STABLE_ID_PREFIX_FOR_UNION_TYPE = newOption("stableIdentifierPrefixForUnionType") + + /** + * Adds support for recursive fields. If this option is not specified or is set to 0, recursive + * fields are not permitted. Setting it to 1 drops all recursive fields, 2 allows recursive + * fields to be recursed once, and 3 allows it to be recursed twice and so on, up to 15. + * Values larger than 15 are not allowed in order to avoid inadvertently creating very large + * schemas. If an avro message has depth beyond this limit, the Spark struct returned is + * truncated after the recursion limit. + * + * Examples: Consider an Avro schema with a recursive field: + * {"type" : "record", "name" : "Node", "fields" : [{"name": "Id", "type": "int"}, + * {"name": "Next", "type": ["null", "Node"]}]} + * The following lists the parsed schema with different values for this setting. + * 1: `struct` + * 2: `struct>` + * 3: `struct>>` + * and so on. 
+ */ + val RECURSIVE_FIELD_MAX_DEPTH = newOption("recursiveFieldMaxDepth") + + val RECURSIVE_FIELD_MAX_DEPTH_LIMIT: Int = 15 } diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala index 7cbc30f1fb3dc..594ebb4716c41 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroUtils.scala @@ -65,7 +65,8 @@ private[sql] object AvroUtils extends Logging { SchemaConverters.toSqlType( avroSchema, parsedOptions.useStableIdForUnionType, - parsedOptions.stableIdPrefixForUnionType).dataType match { + parsedOptions.stableIdPrefixForUnionType, + parsedOptions.recursiveFieldMaxDepth).dataType match { case t: StructType => Some(t) case _ => throw new RuntimeException( s"""Avro schema cannot be converted to a Spark SQL StructType: diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala index b2285aa966ddb..1168a887abd8e 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/SchemaConverters.scala @@ -27,6 +27,10 @@ import org.apache.avro.LogicalTypes.{Date, Decimal, LocalTimestampMicros, LocalT import org.apache.avro.Schema.Type._ import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging +import org.apache.spark.internal.LogKeys.{FIELD_NAME, FIELD_TYPE, RECURSIVE_DEPTH} +import org.apache.spark.internal.MDC +import org.apache.spark.sql.avro.AvroOptions.RECURSIVE_FIELD_MAX_DEPTH_LIMIT import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.types._ import org.apache.spark.sql.types.Decimal.minBytesForPrecision @@ -36,7 +40,7 @@ import org.apache.spark.sql.types.Decimal.minBytesForPrecision * versa. */ @DeveloperApi -object SchemaConverters { +object SchemaConverters extends Logging { private lazy val nullSchema = Schema.create(Schema.Type.NULL) /** @@ -48,14 +52,27 @@ object SchemaConverters { /** * Converts an Avro schema to a corresponding Spark SQL schema. - * + * + * @param avroSchema The Avro schema to convert. + * @param useStableIdForUnionType If true, Avro schema is deserialized into Spark SQL schema, + * and the Avro Union type is transformed into a structure where + * the field names remain consistent with their respective types. + * @param stableIdPrefixForUnionType The prefix to use to configure the prefix for fields of + * Avro Union type + * @param recursiveFieldMaxDepth The maximum depth to recursively process fields in Avro schema. + * -1 means not supported. * @since 4.0.0 */ def toSqlType( avroSchema: Schema, useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String): SchemaType = { - toSqlTypeHelper(avroSchema, Set.empty, useStableIdForUnionType, stableIdPrefixForUnionType) + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int = -1): SchemaType = { + val schema = toSqlTypeHelper(avroSchema, Map.empty, useStableIdForUnionType, + stableIdPrefixForUnionType, recursiveFieldMaxDepth) + // the top level record should never return null + assert(schema != null) + schema } /** * Converts an Avro schema to a corresponding Spark SQL schema. 
@@ -63,17 +80,17 @@ object SchemaConverters { * @since 2.4.0 */ def toSqlType(avroSchema: Schema): SchemaType = { - toSqlType(avroSchema, false, "") + toSqlType(avroSchema, false, "", -1) } @deprecated("using toSqlType(..., useStableIdForUnionType: Boolean) instead", "4.0.0") def toSqlType(avroSchema: Schema, options: Map[String, String]): SchemaType = { val avroOptions = AvroOptions(options) - toSqlTypeHelper( + toSqlType( avroSchema, - Set.empty, avroOptions.useStableIdForUnionType, - avroOptions.stableIdPrefixForUnionType) + avroOptions.stableIdPrefixForUnionType, + avroOptions.recursiveFieldMaxDepth) } // The property specifies Catalyst type of the given field @@ -81,9 +98,10 @@ object SchemaConverters { private def toSqlTypeHelper( avroSchema: Schema, - existingRecordNames: Set[String], + existingRecordNames: Map[String, Int], useStableIdForUnionType: Boolean, - stableIdPrefixForUnionType: String): SchemaType = { + stableIdPrefixForUnionType: String, + recursiveFieldMaxDepth: Int): SchemaType = { avroSchema.getType match { case INT => avroSchema.getLogicalType match { case _: Date => SchemaType(DateType, nullable = false) @@ -128,62 +146,110 @@ object SchemaConverters { case NULL => SchemaType(NullType, nullable = true) case RECORD => - if (existingRecordNames.contains(avroSchema.getFullName)) { + val recursiveDepth: Int = existingRecordNames.getOrElse(avroSchema.getFullName, 0) + if (recursiveDepth > 0 && recursiveFieldMaxDepth <= 0) { throw new IncompatibleSchemaException(s""" - |Found recursive reference in Avro schema, which can not be processed by Spark: - |${avroSchema.toString(true)} + |Found recursive reference in Avro schema, which can not be processed by Spark by + | default: ${avroSchema.toString(true)}. Try setting the option `recursiveFieldMaxDepth` + | to 1 - $RECURSIVE_FIELD_MAX_DEPTH_LIMIT. """.stripMargin) - } - val newRecordNames = existingRecordNames + avroSchema.getFullName - val fields = avroSchema.getFields.asScala.map { f => - val schemaType = toSqlTypeHelper( - f.schema(), - newRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType) - StructField(f.name, schemaType.dataType, schemaType.nullable) - } + } else if (recursiveDepth > 0 && recursiveDepth >= recursiveFieldMaxDepth) { + logInfo( + log"The field ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} is dropped at recursive depth " + + log"${MDC(RECURSIVE_DEPTH, recursiveDepth)}." 
+ ) + null + } else { + val newRecordNames = + existingRecordNames + (avroSchema.getFullName -> (recursiveDepth + 1)) + val fields = avroSchema.getFields.asScala.map { f => + val schemaType = toSqlTypeHelper( + f.schema(), + newRecordNames, + useStableIdForUnionType, + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + null + } + else { + StructField(f.name, schemaType.dataType, schemaType.nullable) + } + }.filter(_ != null).toSeq - SchemaType(StructType(fields.toArray), nullable = false) + SchemaType(StructType(fields), nullable = false) + } case ARRAY => val schemaType = toSqlTypeHelper( avroSchema.getElementType, existingRecordNames, useStableIdForUnionType, - stableIdPrefixForUnionType) - SchemaType( - ArrayType(schemaType.dataType, containsNull = schemaType.nullable), - nullable = false) + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." + ) + null + } else { + SchemaType( + ArrayType(schemaType.dataType, containsNull = schemaType.nullable), + nullable = false) + } case MAP => val schemaType = toSqlTypeHelper(avroSchema.getValueType, - existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType) - SchemaType( - MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), - nullable = false) + existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." + ) + null + } else { + SchemaType( + MapType(StringType, schemaType.dataType, valueContainsNull = schemaType.nullable), + nullable = false) + } case UNION => if (avroSchema.getTypes.asScala.exists(_.getType == NULL)) { // In case of a union with null, eliminate it and make a recursive call val remainingUnionTypes = AvroUtils.nonNullUnionBranches(avroSchema) - if (remainingUnionTypes.size == 1) { - toSqlTypeHelper( - remainingUnionTypes.head, - existingRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType).copy(nullable = true) + val remainingSchema = + if (remainingUnionTypes.size == 1) { + remainingUnionTypes.head + } else { + Schema.createUnion(remainingUnionTypes.asJava) + } + val schemaType = toSqlTypeHelper( + remainingSchema, + existingRecordNames, + useStableIdForUnionType, + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + + if (schemaType == null) { + logInfo( + log"Dropping ${MDC(FIELD_NAME, avroSchema.getFullName)} of type " + + log"${MDC(FIELD_TYPE, avroSchema.getType.getName)} as it does not have any " + + log"fields left likely due to recursive depth limit." 
+ ) + null } else { - toSqlTypeHelper( - Schema.createUnion(remainingUnionTypes.asJava), - existingRecordNames, - useStableIdForUnionType, - stableIdPrefixForUnionType).copy(nullable = true) + schemaType.copy(nullable = true) } } else avroSchema.getTypes.asScala.map(_.getType).toSeq match { case Seq(t1) => toSqlTypeHelper(avroSchema.getTypes.get(0), - existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType) + existingRecordNames, useStableIdForUnionType, stableIdPrefixForUnionType, + recursiveFieldMaxDepth) case Seq(t1, t2) if Set(t1, t2) == Set(INT, LONG) => SchemaType(LongType, nullable = false) case Seq(t1, t2) if Set(t1, t2) == Set(FLOAT, DOUBLE) => @@ -201,29 +267,33 @@ object SchemaConverters { s, existingRecordNames, useStableIdForUnionType, - stableIdPrefixForUnionType) - - val fieldName = if (useStableIdForUnionType) { - // Avro's field name may be case sensitive, so field names for two named type - // could be "a" and "A" and we need to distinguish them. In this case, we throw - // an exception. - // Stable id prefix can be empty so the name of the field can be just the type. - val tempFieldName = s"${stableIdPrefixForUnionType}${s.getName}" - if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) { - throw new IncompatibleSchemaException( - "Cannot generate stable identifier for Avro union type due to name " + - s"conflict of type name ${s.getName}") - } - tempFieldName + stableIdPrefixForUnionType, + recursiveFieldMaxDepth) + if (schemaType == null) { + null } else { - s"member$i" - } + val fieldName = if (useStableIdForUnionType) { + // Avro's field name may be case sensitive, so field names for two named type + // could be "a" and "A" and we need to distinguish them. In this case, we throw + // an exception. + // Stable id prefix can be empty so the name of the field can be just the type. 
+ val tempFieldName = s"${stableIdPrefixForUnionType}${s.getName}" + if (!fieldNameSet.add(tempFieldName.toLowerCase(Locale.ROOT))) { + throw new IncompatibleSchemaException( + "Cannot generate stable identifier for Avro union type due to name " + + s"conflict of type name ${s.getName}") + } + tempFieldName + } else { + s"member$i" + } - // All fields are nullable because only one of them is set at a time - StructField(fieldName, schemaType.dataType, nullable = true) - } + // All fields are nullable because only one of them is set at a time + StructField(fieldName, schemaType.dataType, nullable = true) + } + }.filter(_ != null).toSeq - SchemaType(StructType(fields.toArray), nullable = false) + SchemaType(StructType(fields), nullable = false) } case other => throw new IncompatibleSchemaException(s"Unsupported type $other") diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala index 1083c99160724..a13faf3b51560 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala @@ -105,7 +105,8 @@ case class AvroPartitionReaderFactory( datetimeRebaseMode, avroFilters, options.useStableIdForUnionType, - options.stableIdPrefixForUnionType) + options.stableIdPrefixForUnionType, + options.recursiveFieldMaxDepth) override val stopPosition = partitionedFile.start + partitionedFile.length override def next(): Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala index 388347537a4d6..311eda3a1b6ae 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroCatalystDataConversionSuite.scala @@ -291,7 +291,8 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite RebaseSpec(LegacyBehaviorPolicy.CORRECTED), filters, false, - "") + "", + -1) val deserialized = deserializer.deserialize(data) expected match { case None => assert(deserialized == None) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala index 47faaf7662a50..a7f7abadcf485 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.functions.{col, lit, struct} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{BinaryType, StructType} +import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} class AvroFunctionsSuite extends QueryTest with SharedSparkSession { import testImplicits._ @@ -374,6 +374,37 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { } } + + test("roundtrip in to_avro and from_avro - recursive schema") { + val catalystSchema = + StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + 
StructField("Id", IntegerType))))))))) + + val avroSchema = s""" + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [ + | {"name": "Id", "type": "int"}, + | {"name": "Name", "type": ["null", "test_schema"]} + | ] + |} + """.stripMargin + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(2, Row(3, Row(4))), Row(1, null))), + catalystSchema).select(struct("Id", "Name").as("struct")) + + val avroStructDF = df.select(functions.to_avro($"struct", avroSchema).as("avro")) + checkAnswer(avroStructDF.select( + functions.from_avro($"avro", avroSchema, Map( + "recursiveFieldMaxDepth" -> "3").asJava)), df) + } + private def serialize(record: GenericRecord, avroSchema: String): Array[Byte] = { val schema = new Schema.Parser().parse(avroSchema) val datumWriter = new GenericDatumWriter[GenericRecord](schema) diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala index 9b3bb929a700d..c1ab96a63eb26 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala @@ -77,7 +77,8 @@ class AvroRowReaderSuite RebaseSpec(CORRECTED), new NoopFilters, false, - "") + "", + -1) override val stopPosition = fileSize override def hasNext: Boolean = hasNextRow diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala index cbcbc2e7e76a6..3643a95abe19c 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSerdeSuite.scala @@ -228,7 +228,8 @@ object AvroSerdeSuite { RebaseSpec(CORRECTED), new NoopFilters, false, - "") + "", + -1) } /** diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 14ed6c43e4c0f..be887bd5237b0 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -2220,7 +2220,8 @@ abstract class AvroSuite } } - private def checkSchemaWithRecursiveLoop(avroSchema: String): Unit = { + private def checkSchemaWithRecursiveLoop(avroSchema: String, recursiveFieldMaxDepth: Int): + Unit = { val message = intercept[IncompatibleSchemaException] { SchemaConverters.toSqlType(new Schema.Parser().parse(avroSchema), false, "") }.getMessage @@ -2229,7 +2230,79 @@ abstract class AvroSuite } test("Detect recursive loop") { - checkSchemaWithRecursiveLoop(""" + for (recursiveFieldMaxDepth <- Seq(-1, 0)) { + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, // each element has a long + | {"name": "next", "type": ["null", "LongList"]} // optional next element + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields": [ + | { + | "name": "value", + | "type": { + | "type": "record", + | "name": "foo", + | "fields": [ + | { + | "name": "parent", + | "type": "LongList" + | } + | ] + | } + | } + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | 
"fields" : [ + | {"name": "value", "type": "long"}, + | {"name": "array", "type": {"type": "array", "items": "LongList"}} + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + + checkSchemaWithRecursiveLoop( + """ + |{ + | "type": "record", + | "name": "LongList", + | "fields" : [ + | {"name": "value", "type": "long"}, + | {"name": "map", "type": {"type": "map", "values": "LongList"}} + | ] + |} + """.stripMargin, recursiveFieldMaxDepth) + } + } + + private def checkSparkSchemaEquals( + avroSchema: String, expectedSchema: StructType, recursiveFieldMaxDepth: Int): Unit = { + val sparkSchema = + SchemaConverters.toSqlType( + new Schema.Parser().parse(avroSchema), false, "", recursiveFieldMaxDepth).dataType + + assert(sparkSchema === expectedSchema) + } + + test("Translate recursive schema - union") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2238,9 +2311,57 @@ abstract class AvroSuite | {"name": "next", "type": ["null", "LongList"]} // optional next element | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = nonRecursiveFields.add("next", expectedSchema) + } + } + + test("Translate recursive schema - union - 2 non-null fields") { + val avroSchema = """ + |{ + | "type": "record", + | "name": "TreeNode", + | "fields": [ + | { + | "name": "name", + | "type": "string" + | }, + | { + | "name": "value", + | "type": [ + | "long" + | ] + | }, + | { + | "name": "children", + | "type": [ + | "null", + | { + | "type": "array", + | "items": "TreeNode" + | } + | ], + | "default": null + | } + | ] + |} + """.stripMargin + val nonRecursiveFields = new StructType().add("name", StringType, nullable = false) + .add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = nonRecursiveFields.add("children", + new ArrayType(expectedSchema, false), nullable = true) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - record") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2260,9 +2381,18 @@ abstract class AvroSuite | } | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", StructType(Seq()), nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = new StructType().add("value", + new StructType().add("parent", expectedSchema, nullable = false), nullable = false) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - array") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2271,9 +2401,18 @@ abstract class AvroSuite | {"name": "array", "type": {"type": "array", "items": "LongList"}} | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = + nonRecursiveFields.add("array", new ArrayType(expectedSchema, false), nullable = false) + } + } - checkSchemaWithRecursiveLoop(""" + test("Translate recursive schema - map") { + val avroSchema = """ |{ | "type": "record", | "name": "LongList", @@ -2282,7 +2421,70 @@ 
abstract class AvroSuite | {"name": "map", "type": {"type": "map", "values": "LongList"}} | ] |} - """.stripMargin) + """.stripMargin + val nonRecursiveFields = new StructType().add("value", LongType, nullable = false) + var expectedSchema = nonRecursiveFields + for (i <- 1 to 5) { + checkSparkSchemaEquals(avroSchema, expectedSchema, i) + expectedSchema = + nonRecursiveFields.add("map", + new MapType(StringType, expectedSchema, false), nullable = false) + } + } + + test("recursive schema integration test") { + val catalystSchema = + StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", StructType(Seq( + StructField("Id", IntegerType), + StructField("Name", NullType))))))))) + + val avroSchema = s""" + |{ + | "type" : "record", + | "name" : "test_schema", + | "fields" : [ + | {"name": "Id", "type": "int"}, + | {"name": "Name", "type": ["null", "test_schema"]} + | ] + |} + """.stripMargin + + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(2, Row(3, Row(4, null))), Row(1, null))), + catalystSchema) + + withTempPath { tempDir => + df.write.format("avro").save(tempDir.getPath) + + val exc = intercept[AnalysisException] { + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 16) + .load(tempDir.getPath) + } + assert(exc.getMessage.contains("Should not be greater than 15.")) + + checkAnswer( + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 10) + .load(tempDir.getPath), + df) + + checkAnswer( + spark.read + .format("avro") + .option("avroSchema", avroSchema) + .option("recursiveFieldMaxDepth", 1) + .load(tempDir.getPath), + df.select("Id")) + } } test("log a warning of ignoreExtension deprecation") { @@ -2777,7 +2979,7 @@ abstract class AvroSuite } test("SPARK-40667: validate Avro Options") { - assert(AvroOptions.getAllOptions.size == 11) + assert(AvroOptions.getAllOptions.size == 12) // Please add validation on any new Avro options here assert(AvroOptions.isValidOption("ignoreExtension")) assert(AvroOptions.isValidOption("mode")) @@ -2790,6 +2992,7 @@ abstract class AvroSuite assert(AvroOptions.isValidOption("datetimeRebaseMode")) assert(AvroOptions.isValidOption("enableStableIdentifiersForUnionType")) assert(AvroOptions.isValidOption("stableIdentifierPrefixForUnionType")) + assert(AvroOptions.isValidOption("recursiveFieldMaxDepth")) } test("SPARK-46633: read file with empty blocks") { diff --git a/docs/sql-data-sources-avro.md b/docs/sql-data-sources-avro.md index 3721f92d93266..c06e1fd46d2da 100644 --- a/docs/sql-data-sources-avro.md +++ b/docs/sql-data-sources-avro.md @@ -353,6 +353,13 @@ Data source options of Avro can be set via: read 4.0.0 + + recursiveFieldMaxDepth + -1 + If this option is specified to negative or is set to 0, recursive fields are not permitted. Setting it to 1 drops all recursive fields, 2 allows recursive fields to be recursed once, and 3 allows it to be recursed twice and so on, up to 15. Values larger than 15 are not allowed in order to avoid inadvertently creating very large schemas. If an avro message has depth beyond this limit, the Spark struct returned is truncated after the recursion limit. 
An example of usage can be found in section Handling circular references of Avro fields + read + 4.0.0 + ## Configuration @@ -628,3 +635,41 @@ You can also specify the whole output Avro schema with the option `avroSchema`, decimal + +## Handling circular references of Avro fields +In Avro, a circular reference occurs when the type of a field is defined in one of the parent records. This can cause issues when parsing the data, as it can result in infinite loops or other unexpected behavior. +To read Avro data with schema that has circular reference, users can use the `recursiveFieldMaxDepth` option to specify the maximum number of levels of recursion to allow when parsing the schema. By default, Spark Avro data source will not permit recursive fields by setting `recursiveFieldMaxDepth` to -1. However, you can set this option to 1 to 15 if needed. + +Setting `recursiveFieldMaxDepth` to 1 drops all recursive fields, setting it to 2 allows it to be recursed once, and setting it to 3 allows it to be recursed twice. A `recursiveFieldMaxDepth` value greater than 15 is not allowed, as it can lead to performance issues and even stack overflows. + +SQL Schema for the below Avro message will vary based on the value of `recursiveFieldMaxDepth`. + +
+
+This div is only used to make markdown editor/viewer happy and does not display on web + +```avro +
+
+{% highlight avro %}
+{
+  "type": "record",
+  "name": "Node",
+  "fields": [
+    {"name": "Id", "type": "int"},
+    {"name": "Next", "type": ["null", "Node"]}
+  ]
+}
+
+// The Avro schema defined above would be converted into Spark SQL columns with the following
+// structure based on the `recursiveFieldMaxDepth` value.
+
+1: struct<Id: int>
+2: struct<Id: int, Next: struct<Id: int>>
+3: struct<Id: int, Next: struct<Id: int, Next: struct<Id: int>>>
+
+{% endhighlight %}
+
+``` +
+
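For readers who want a concrete starting point, the sketch below shows how the new option is passed when reading Avro data with a recursive writer schema. This is an illustrative sketch only: the input path is a placeholder, the schema string simply mirrors the `Node` record above, and it assumes the `spark-avro` module is on the classpath.

```scala
// Minimal usage sketch (hypothetical path; schema mirrors the Node record above).
val recursiveAvroSchema =
  """{
    |  "type": "record",
    |  "name": "Node",
    |  "fields": [
    |    {"name": "Id", "type": "int"},
    |    {"name": "Next", "type": ["null", "Node"]}
    |  ]
    |}""".stripMargin

val df = spark.read
  .format("avro")
  .option("avroSchema", recursiveAvroSchema)
  .option("recursiveFieldMaxDepth", 3) // allowed range: 1 to 15
  .load("/path/to/nodes.avro")         // placeholder path

df.printSchema()
```

With `recursiveFieldMaxDepth` set to 3, the resulting schema nests `Next` twice, matching the `3:` line in the snippet above.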
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index e324d4e9d2edb..ad0e1d07bf93d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -4090,6 +4090,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } + def avroOptionsException(optionName: String, message: String): Throwable = { + new AnalysisException( + errorClass = "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", + messageParameters = Map("optionName" -> optionName, "message" -> message) + ) + } + def protobufNotLoadedSqlFunctionsUnusable(functionName: String): Throwable = { new AnalysisException( errorClass = "PROTOBUF_NOT_LOADED_SQL_FUNCTIONS_UNUSABLE", From 25d6b7a280f690c1a467f65143115cce846a732a Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 19 Sep 2024 07:46:18 +0800 Subject: [PATCH 005/250] [SPARK-49692][PYTHON][CONNECT] Refine the string representation of literal date and datetime ### What changes were proposed in this pull request? Refine the string representation of literal date and datetime ### Why are the changes needed? 1, we should not represent those literals with internal values; 2, the string representation should be consistent with PySpark Classic if possible (we cannot make sure the representations are always the same because we only hold an unresolved expression in connect, but we can try our best to do so) ### Does this PR introduce _any_ user-facing change? yes before: ``` In [3]: lit(datetime.date(2024, 7, 10)) Out[3]: Column<'19914'> In [4]: lit(datetime.datetime(2024, 7, 10, 1, 2, 3, 456)) Out[4]: Column<'1720544523000456'> ``` after: ``` In [3]: lit(datetime.date(2024, 7, 10)) Out[3]: Column<'2024-07-10'> In [4]: lit(datetime.datetime(2024, 7, 10, 1, 2, 3, 456)) Out[4]: Column<'2024-07-10 01:02:03.000456'> ``` ### How was this patch tested? added tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48137 from zhengruifeng/py_connect_lit_dt. 
Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/sql/connect/expressions.py | 16 ++++++++++++++-- python/pyspark/sql/tests/test_column.py | 9 +++++++++ 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index db1cd1c013be5..63128ef48e389 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -477,8 +477,20 @@ def to_plan(self, session: "SparkConnectClient") -> "proto.Expression": def __repr__(self) -> str: if self._value is None: return "NULL" - else: - return f"{self._value}" + elif isinstance(self._dataType, DateType): + dt = DateType().fromInternal(self._value) + if dt is not None and isinstance(dt, datetime.date): + return dt.strftime("%Y-%m-%d") + elif isinstance(self._dataType, TimestampType): + ts = TimestampType().fromInternal(self._value) + if ts is not None and isinstance(ts, datetime.datetime): + return ts.strftime("%Y-%m-%d %H:%M:%S.%f") + elif isinstance(self._dataType, TimestampNTZType): + ts = TimestampNTZType().fromInternal(self._value) + if ts is not None and isinstance(ts, datetime.datetime): + return ts.strftime("%Y-%m-%d %H:%M:%S.%f") + # TODO(SPARK-49693): Refine the string representation of timedelta + return f"{self._value}" class ColumnReference(Expression): diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 2bd66baaa2bfe..220ecd387f7ee 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -18,6 +18,8 @@ from enum import Enum from itertools import chain +import datetime + from pyspark.sql import Column, Row from pyspark.sql import functions as sf from pyspark.sql.types import StructType, StructField, IntegerType, LongType @@ -280,6 +282,13 @@ def test_expr_str_representation(self): when_cond = sf.when(expression, sf.lit(None)) self.assertEqual(str(when_cond), "Column<'CASE WHEN foo THEN NULL END'>") + def test_lit_time_representation(self): + dt = datetime.date(2021, 3, 4) + self.assertEqual(str(sf.lit(dt)), "Column<'2021-03-04'>") + + ts = datetime.datetime(2021, 3, 4, 12, 34, 56, 1234) + self.assertEqual(str(sf.lit(ts)), "Column<'2021-03-04 12:34:56.001234'>") + def test_enum_literals(self): class IntEnum(Enum): X = 1 From 669e63a34012404d8d864cd6294f799b672f6f9a Mon Sep 17 00:00:00 2001 From: Robert Dillitz Date: Thu, 19 Sep 2024 08:54:20 +0900 Subject: [PATCH 006/250] [SPARK-49673][CONNECT] Increase CONNECT_GRPC_ARROW_MAX_BATCH_SIZE to 0.7 * CONNECT_GRPC_MAX_MESSAGE_SIZE ### What changes were proposed in this pull request? Increases the default `maxBatchSize` from 4MiB * 0.7 to 128MiB (= CONNECT_GRPC_MAX_MESSAGE_SIZE) * 0.7. This makes better use of the allowed maximum message size. This limit is used when creating Arrow batches for the `SqlCommandResult` in the `SparkConnectPlanner` and for `ExecutePlanResponse.ArrowBatch` in `processAsArrowBatches`. This, for example, lets us return much larger `LocalRelations` in the `SqlCommandResult` (i.e., for the `SHOW PARTITIONS` command) while still staying within the GRPC message size limit. ### Why are the changes needed? There are `SqlCommandResults` that exceed 0.7 * 4MiB. ### Does this PR introduce _any_ user-facing change? Now support `SqlCommandResults` <= 0.7 * 128 MiB instead of only <= 0.7 * 4MiB and ExecutePlanResponses will now better use the limit of 128MiB. ### How was this patch tested? Existing tests. 
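For scale, this raises the effective per-batch cap from 0.7 * 4 MiB = 2.8 MiB to 0.7 * 128 MiB = 89.6 MiB.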
### Was this patch authored or co-authored using generative AI tooling? No. Closes #48122 from dillitz/increase-sql-command-batch-size. Authored-by: Robert Dillitz Signed-off-by: Hyukjin Kwon --- .../apache/spark/sql/ClientE2ETestSuite.scala | 23 +++++++++++++++++-- .../spark/sql/test/RemoteSparkSession.scala | 2 ++ .../spark/sql/connect/config/Connect.scala | 2 +- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 52cdbd47357f3..b47231948dc98 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -23,7 +23,7 @@ import java.util.Properties import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} -import scala.concurrent.duration.DurationInt +import scala.concurrent.duration.{DurationInt, FiniteDuration} import scala.jdk.CollectionConverters._ import org.apache.commons.io.FileUtils @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.catalyst.parser.ParseException -import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult} +import org.apache.spark.sql.connect.client.{RetryPolicy, SparkConnectClient, SparkResult} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.test.{ConnectFunSuite, IntegrationTestUtils, RemoteSparkSession, SQLHelper} @@ -1566,6 +1566,25 @@ class ClientE2ETestSuite val result = df.select(trim(col("col"), " ").as("trimmed_col")).collect() assert(result sameElements Array(Row("a"), Row("b"), Row("c"))) } + + test("SPARK-49673: new batch size, multiple batches") { + val maxBatchSize = spark.conf.get("spark.connect.grpc.arrow.maxBatchSize").dropRight(1).toInt + // Adjust client grpcMaxMessageSize to maxBatchSize (10MiB; set in RemoteSparkSession config) + val sparkWithLowerMaxMessageSize = SparkSession + .builder() + .client( + SparkConnectClient + .builder() + .userId("test") + .port(port) + .grpcMaxMessageSize(maxBatchSize) + .retryPolicy(RetryPolicy + .defaultPolicy() + .copy(maxRetries = Some(10), maxBackoff = Some(FiniteDuration(30, "s")))) + .build()) + .create() + assert(sparkWithLowerMaxMessageSize.range(maxBatchSize).collect().length == maxBatchSize) + } } private[sql] case class ClassData(a: String, b: Int) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala index e0de73e496d95..36aaa2cc7fbf6 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/test/RemoteSparkSession.scala @@ -124,6 +124,8 @@ object SparkConnectServerUtils { // to make the tests exercise reattach. 
"spark.connect.execute.reattachable.senderMaxStreamDuration=1s", "spark.connect.execute.reattachable.senderMaxStreamSize=123", + // Testing SPARK-49673, setting maxBatchSize to 10MiB + s"spark.connect.grpc.arrow.maxBatchSize=${10 * 1024 * 1024}", // Disable UI "spark.ui.enabled=false") Seq("--jars", catalystTestJar) ++ confs.flatMap(v => "--conf" :: v :: Nil) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index 92709ff29a1ca..b64637f7d2472 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -63,7 +63,7 @@ object Connect { "conservatively use 70% of it because the size is not accurate but estimated.") .version("3.4.0") .bytesConf(ByteUnit.BYTE) - .createWithDefault(4 * 1024 * 1024) + .createWithDefault(ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE) val CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE = buildStaticConf("spark.connect.grpc.maxInboundMessageSize") From 5c48806a2941070e23a81b4e7e4f3225fe341535 Mon Sep 17 00:00:00 2001 From: Changgyoo Park Date: Thu, 19 Sep 2024 09:08:59 +0900 Subject: [PATCH 007/250] [SPARK-49688][CONNECT][TESTS] Fix a sporadic `SparkConnectServiceSuite` failure ### What changes were proposed in this pull request? Add a short wait loop to ensure that the test pre-condition is met. To be specific, VerifyEvents.executeHolder is set asynchronously by MockSparkListener.onOtherEvent whereas the test assumes that VerifyEvents.executeHolder is always available. ### Why are the changes needed? For smoother development experience. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? SparkConnectServiceSuite. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48142 from changgyoopark-db/SPARK-49688. 
Authored-by: Changgyoo Park Signed-off-by: Hyukjin Kwon --- .../planner/SparkConnectServiceSuite.scala | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 579fdb47aef3c..62146f19328a8 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -871,10 +871,16 @@ class SparkConnectServiceSuite class VerifyEvents(val sparkContext: SparkContext) { val listener: MockSparkListener = new MockSparkListener() val listenerBus = sparkContext.listenerBus + val EVENT_WAIT_TIMEOUT = timeout(10.seconds) val LISTENER_BUS_TIMEOUT = 30000 def executeHolder: ExecuteHolder = { - assert(listener.executeHolder.isDefined) - listener.executeHolder.get + // An ExecuteHolder shall be set eventually through MockSparkListener + Eventually.eventually(EVENT_WAIT_TIMEOUT) { + assert( + listener.executeHolder.isDefined, + s"No events have been posted in $EVENT_WAIT_TIMEOUT") + listener.executeHolder.get + } } def onNext(v: proto.ExecutePlanResponse): Unit = { if (v.hasSchema) { @@ -891,8 +897,10 @@ class SparkConnectServiceSuite def onCompleted(producedRowCount: Option[Long] = None): Unit = { assert(executeHolder.eventsManager.getProducedRowCount == producedRowCount) // The eventsManager is closed asynchronously - Eventually.eventually(timeout(1.seconds)) { - assert(executeHolder.eventsManager.status == ExecuteStatus.Closed) + Eventually.eventually(EVENT_WAIT_TIMEOUT) { + assert( + executeHolder.eventsManager.status == ExecuteStatus.Closed, + s"Execution has not been completed in $EVENT_WAIT_TIMEOUT") } } def onCanceled(): Unit = { From db8010b4c8be6f1c50f35cbde3efa44cd5d45adf Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 18 Sep 2024 20:10:18 -0400 Subject: [PATCH 008/250] [SPARK-49568][CONNECT][SQL] Remove self type from Dataset ### What changes were proposed in this pull request? This PR removes the self type parameter from Dataset. This turned out to be a bit noisy. The self type is replaced by a combination of covariant return types and abstract types. Abstract types are used when a method takes a Dataset (or a KeyValueGroupedDataset) as an argument. ### Why are the changes needed? The self type made using the classes in sql/api a bit noisy. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48146 from hvanhovell/SPARK-49568. 
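To make the described pattern concrete, here is a minimal Scala sketch of the idea; the names `ApiDataset` and `ConnectDataset` are simplified placeholders, not the actual Spark classes:

```scala
// Shared API surface: no self-type parameter. An abstract type member stands in
// for the concrete Dataset type wherever a method must accept one as an argument.
trait ApiDataset[T] {
  type DS[U] <: ApiDataset[U]

  def filter(f: T => Boolean): ApiDataset[T]
  def unionAll(other: DS[T]): ApiDataset[T]
}

// Implementation side: fix the type member and narrow return types covariantly,
// so callers that work against the concrete class keep getting the concrete class.
class ConnectDataset[T] extends ApiDataset[T] {
  type DS[U] = ConnectDataset[U]

  override def filter(f: T => Boolean): ConnectDataset[T] = this
  override def unionAll(other: ConnectDataset[T]): ConnectDataset[T] = this
}
```

Code that programs against the shared trait sees only `ApiDataset`, while each implementation continues to accept and return its own concrete type, which is what the removed self type used to provide.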
Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../spark/sql/DataFrameNaFunctions.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 5 +- .../spark/sql/DataFrameStatFunctions.scala | 3 +- .../scala/org/apache/spark/sql/Dataset.scala | 5 +- .../spark/sql/KeyValueGroupedDataset.scala | 4 +- .../spark/sql/RelationalGroupedDataset.scala | 4 +- .../org/apache/spark/sql/SparkSession.scala | 2 +- .../apache/spark/sql/catalog/Catalog.scala | 3 +- .../sql/connect/ConnectConversions.scala | 51 +++ .../spark/sql/streaming/StreamingQuery.scala | 4 +- .../CheckConnectJvmClientCompatibility.scala | 1 + project/MimaExcludes.scala | 2 + project/SparkBuild.scala | 1 + .../org/apache/spark/sql/api/Catalog.scala | 58 ++-- .../spark/sql/api/DataFrameNaFunctions.scala | 65 ++-- .../spark/sql/api/DataFrameReader.scala | 51 +-- .../sql/api/DataFrameStatFunctions.scala | 22 +- .../org/apache/spark/sql/api/Dataset.scala | 299 +++++++++--------- .../sql/api/KeyValueGroupedDataset.scala | 109 +++---- .../sql/api/RelationalGroupedDataset.scala | 44 ++- .../apache/spark/sql/api/SparkSession.scala | 40 +-- .../apache/spark/sql/api/StreamingQuery.scala | 4 +- .../org/apache/spark/sql/functions.scala | 2 +- .../spark/sql/DataFrameNaFunctions.scala | 3 +- .../apache/spark/sql/DataFrameReader.scala | 5 +- .../spark/sql/DataFrameStatFunctions.scala | 3 +- .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../spark/sql/KeyValueGroupedDataset.scala | 3 +- .../spark/sql/RelationalGroupedDataset.scala | 5 +- .../org/apache/spark/sql/SparkSession.scala | 2 +- .../apache/spark/sql/catalog/Catalog.scala | 3 +- .../sql/classic/ClassicConversions.scala | 50 +++ .../spark/sql/streaming/StreamingQuery.scala | 4 +- 33 files changed, 500 insertions(+), 364 deletions(-) create mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index c06cbbc0cdb42..3777f82594aae 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -22,6 +22,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto.{NAReplace, Relation} import org.apache.spark.connect.proto.Expression.{Literal => GLiteral} import org.apache.spark.connect.proto.NAReplace.Replacement +import org.apache.spark.sql.connect.ConnectConversions._ /** * Functionality for working with missing data in `DataFrame`s. 
@@ -29,7 +30,7 @@ import org.apache.spark.connect.proto.NAReplace.Replacement * @since 3.4.0 */ final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation) - extends api.DataFrameNaFunctions[Dataset] { + extends api.DataFrameNaFunctions { import sparkSession.RichColumn override protected def drop(minNonNulls: Option[Int]): Dataset[Row] = diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index c3ee7030424eb..60bacd4e18ede 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -23,6 +23,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable import org.apache.spark.connect.proto.Parse.ParseFormat +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.types.StructType @@ -33,8 +34,8 @@ import org.apache.spark.sql.types.StructType * @since 3.4.0 */ @Stable -class DataFrameReader private[sql] (sparkSession: SparkSession) - extends api.DataFrameReader[Dataset] { +class DataFrameReader private[sql] (sparkSession: SparkSession) extends api.DataFrameReader { + type DS[U] = Dataset[U] /** @inheritdoc */ override def format(source: String): this.type = super.format(source) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 9f5ada0d7ec35..bb7cfa75a9ab9 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -22,6 +22,7 @@ import java.{lang => jl, util => ju} import org.apache.spark.connect.proto.{Relation, StatSampleBy} import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.functions.lit /** @@ -30,7 +31,7 @@ import org.apache.spark.sql.functions.lit * @since 3.4.0 */ final class DataFrameStatFunctions private[sql] (protected val df: DataFrame) - extends api.DataFrameStatFunctions[Dataset] { + extends api.DataFrameStatFunctions { private def root: Relation = df.plan.getRoot private val sparkSession: SparkSession = df.sparkSession diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 519193ebd9c74..161a0d9d265f0 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ import org.apache.spark.sql.catalyst.expressions.OrderUtils +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.SparkResult import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, 
StorageLevelProtoConverter} import org.apache.spark.sql.errors.DataTypeErrors.toSQLId @@ -134,8 +135,8 @@ class Dataset[T] private[sql] ( val sparkSession: SparkSession, @DeveloperApi val plan: proto.Plan, val encoder: Encoder[T]) - extends api.Dataset[T, Dataset] { - type RGD = RelationalGroupedDataset + extends api.Dataset[T] { + type DS[U] = Dataset[U] import sparkSession.RichColumn diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index aef7efb08a254..6bf2518901470 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -26,6 +26,7 @@ import org.apache.spark.api.java.function._ import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.UdfUtils import org.apache.spark.sql.expressions.SparkUserDefinedFunction import org.apache.spark.sql.functions.col @@ -40,8 +41,7 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode * * @since 3.5.0 */ -class KeyValueGroupedDataset[K, V] private[sql] () - extends api.KeyValueGroupedDataset[K, V, Dataset] { +class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDataset[K, V] { type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL] private def unsupported(): Nothing = throw new UnsupportedOperationException() diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index ea13635fc2eaa..14ceb3f4bb144 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.ConnectConversions._ /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -39,8 +40,7 @@ class RelationalGroupedDataset private[sql] ( groupType: proto.Aggregate.GroupType, pivot: Option[proto.Aggregate.Pivot] = None, groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None) - extends api.RelationalGroupedDataset[Dataset] { - type RGD = RelationalGroupedDataset + extends api.RelationalGroupedDataset { import df.sparkSession.RichColumn protected def toDF(aggExprs: Seq[Column]): DataFrame = { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index aa6258a14b811..04f8eeb5c6d46 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -69,7 +69,7 @@ import org.apache.spark.util.ArrayImplicits._ class SparkSession private[sql] ( private[sql] val client: SparkConnectClient, private val planIdGenerator: AtomicLong) - extends api.SparkSession[Dataset] + extends 
api.SparkSession with Logging { private[this] val allocator = new RootAllocator() diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 11a4a044d20e5..86b1dbe4754e6 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.catalog import java.util import org.apache.spark.sql.{api, DataFrame, Dataset} +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.types.StructType /** @inheritdoc */ -abstract class Catalog extends api.Catalog[Dataset] { +abstract class Catalog extends api.Catalog { /** @inheritdoc */ override def listDatabases(): Dataset[Database] diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala new file mode 100644 index 0000000000000..7d81f4ead7857 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect + +import scala.language.implicitConversions + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql._ + +/** + * Conversions from sql interfaces to the Connect specific implementation. + * + * This class is mainly used by the implementation. In the case of connect it should be extremely + * rare that a developer needs these classes. + * + * We provide both a trait and an object. The trait is useful in situations where an extension + * developer needs to use these conversions in a project covering multiple Spark versions. They + * can create a shim for these conversions, the Spark 4+ version of the shim implements this + * trait, and shims for older versions do not. 
+ */ +@DeveloperApi +trait ConnectConversions { + implicit def castToImpl(session: api.SparkSession): SparkSession = + session.asInstanceOf[SparkSession] + + implicit def castToImpl[T](ds: api.Dataset[T]): Dataset[T] = + ds.asInstanceOf[Dataset[T]] + + implicit def castToImpl(rgds: api.RelationalGroupedDataset): RelationalGroupedDataset = + rgds.asInstanceOf[RelationalGroupedDataset] + + implicit def castToImpl[K, V]( + kvds: api.KeyValueGroupedDataset[K, V]): KeyValueGroupedDataset[K, V] = + kvds.asInstanceOf[KeyValueGroupedDataset[K, V]] +} + +object ConnectConversions extends ConnectConversions diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 3b47269875f4a..29fbcc443deb9 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -26,10 +26,10 @@ import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.connect.proto.StreamingQueryCommand import org.apache.spark.connect.proto.StreamingQueryCommandResult import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance -import org.apache.spark.sql.{api, Dataset, SparkSession} +import org.apache.spark.sql.{api, SparkSession} /** @inheritdoc */ -trait StreamingQuery extends api.StreamingQuery[Dataset] { +trait StreamingQuery extends api.StreamingQuery { /** @inheritdoc */ override def sparkSession: SparkSession diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index abf03cfbc6722..16f6983efb187 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -158,6 +158,7 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.columnar.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.classic.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.jdbc.*"), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dfe7b14e2ec66..ece4504395f12 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -201,6 +201,8 @@ object MimaExcludes { ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.errors.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.classic.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.connect.*"), // DSv2 catalog and expression APIs are unstable yet. We should enable this back. 
ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.catalog.*"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.expressions.*"), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4a8214b2e20a3..d93a52985b772 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -1352,6 +1352,7 @@ trait SharedUnidocSettings { .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/kvstore"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalyst"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/connect/"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/classic/"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/catalog/v2/utils"))) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Catalog.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Catalog.scala index fbb665b7f1b1f..a0f51d30dc572 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Catalog.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Catalog.scala @@ -33,7 +33,7 @@ import org.apache.spark.storage.StorageLevel * @since 2.0.0 */ @Stable -abstract class Catalog[DS[U] <: Dataset[U, DS]] { +abstract class Catalog { /** * Returns the current database (namespace) in this session. @@ -54,7 +54,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * * @since 2.0.0 */ - def listDatabases(): DS[Database] + def listDatabases(): Dataset[Database] /** * Returns a list of databases (namespaces) which name match the specify pattern and available @@ -62,7 +62,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * * @since 3.5.0 */ - def listDatabases(pattern: String): DS[Database] + def listDatabases(pattern: String): Dataset[Database] /** * Returns a list of tables/views in the current database (namespace). This includes all @@ -70,7 +70,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * * @since 2.0.0 */ - def listTables(): DS[Table] + def listTables(): Dataset[Table] /** * Returns a list of tables/views in the specified database (namespace) (the name can be @@ -79,7 +79,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * @since 2.0.0 */ @throws[AnalysisException]("database does not exist") - def listTables(dbName: String): DS[Table] + def listTables(dbName: String): Dataset[Table] /** * Returns a list of tables/views in the specified database (namespace) which name match the @@ -88,7 +88,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * @since 3.5.0 */ @throws[AnalysisException]("database does not exist") - def listTables(dbName: String, pattern: String): DS[Table] + def listTables(dbName: String, pattern: String): Dataset[Table] /** * Returns a list of functions registered in the current database (namespace). 
This includes all @@ -96,7 +96,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * * @since 2.0.0 */ - def listFunctions(): DS[Function] + def listFunctions(): Dataset[Function] /** * Returns a list of functions registered in the specified database (namespace) (the name can be @@ -105,7 +105,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * @since 2.0.0 */ @throws[AnalysisException]("database does not exist") - def listFunctions(dbName: String): DS[Function] + def listFunctions(dbName: String): Dataset[Function] /** * Returns a list of functions registered in the specified database (namespace) which name match @@ -115,7 +115,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * @since 3.5.0 */ @throws[AnalysisException]("database does not exist") - def listFunctions(dbName: String, pattern: String): DS[Function] + def listFunctions(dbName: String, pattern: String): Dataset[Function] /** * Returns a list of columns for the given table/view or temporary view. @@ -127,7 +127,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * @since 2.0.0 */ @throws[AnalysisException]("table does not exist") - def listColumns(tableName: String): DS[Column] + def listColumns(tableName: String): Dataset[Column] /** * Returns a list of columns for the given table/view in the specified database under the Hive @@ -143,7 +143,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * @since 2.0.0 */ @throws[AnalysisException]("database or table does not exist") - def listColumns(dbName: String, tableName: String): DS[Column] + def listColumns(dbName: String, tableName: String): Dataset[Column] /** * Get the database (namespace) with the specified name (can be qualified with catalog). This @@ -280,7 +280,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") - def createExternalTable(tableName: String, path: String): DS[Row] = { + def createExternalTable(tableName: String, path: String): Dataset[Row] = { createTable(tableName, path) } @@ -293,7 +293,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * identifier is provided, it refers to a table in the current database. * @since 2.2.0 */ - def createTable(tableName: String, path: String): DS[Row] + def createTable(tableName: String, path: String): Dataset[Row] /** * Creates a table from the given path based on a data source and returns the corresponding @@ -305,7 +305,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * @since 2.0.0 */ @deprecated("use createTable instead.", "2.2.0") - def createExternalTable(tableName: String, path: String, source: String): DS[Row] = { + def createExternalTable(tableName: String, path: String, source: String): Dataset[Row] = { createTable(tableName, path, source) } @@ -318,7 +318,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * identifier is provided, it refers to a table in the current database. * @since 2.2.0 */ - def createTable(tableName: String, path: String, source: String): DS[Row] + def createTable(tableName: String, path: String, source: String): Dataset[Row] /** * Creates a table from the given path based on a data source and a set of options. 
Then, @@ -333,7 +333,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { def createExternalTable( tableName: String, source: String, - options: util.Map[String, String]): DS[Row] = { + options: util.Map[String, String]): Dataset[Row] = { createTable(tableName, source, options) } @@ -349,7 +349,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { def createTable( tableName: String, source: String, - options: util.Map[String, String]): DS[Row] = { + options: util.Map[String, String]): Dataset[Row] = { createTable(tableName, source, options.asScala.toMap) } @@ -366,7 +366,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { def createExternalTable( tableName: String, source: String, - options: Map[String, String]): DS[Row] = { + options: Map[String, String]): Dataset[Row] = { createTable(tableName, source, options) } @@ -379,7 +379,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * identifier is provided, it refers to a table in the current database. * @since 2.2.0 */ - def createTable(tableName: String, source: String, options: Map[String, String]): DS[Row] + def createTable(tableName: String, source: String, options: Map[String, String]): Dataset[Row] /** * Create a table from the given path based on a data source, a schema and a set of options. @@ -395,7 +395,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { tableName: String, source: String, schema: StructType, - options: util.Map[String, String]): DS[Row] = { + options: util.Map[String, String]): Dataset[Row] = { createTable(tableName, source, schema, options) } @@ -412,7 +412,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { tableName: String, source: String, description: String, - options: util.Map[String, String]): DS[Row] = { + options: util.Map[String, String]): Dataset[Row] = { createTable( tableName, source = source, @@ -433,7 +433,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { tableName: String, source: String, description: String, - options: Map[String, String]): DS[Row] + options: Map[String, String]): Dataset[Row] /** * Create a table based on the dataset in a data source, a schema and a set of options. Then, @@ -448,7 +448,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { tableName: String, source: String, schema: StructType, - options: util.Map[String, String]): DS[Row] = { + options: util.Map[String, String]): Dataset[Row] = { createTable(tableName, source, schema, options.asScala.toMap) } @@ -466,7 +466,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { tableName: String, source: String, schema: StructType, - options: Map[String, String]): DS[Row] = { + options: Map[String, String]): Dataset[Row] = { createTable(tableName, source, schema, options) } @@ -483,7 +483,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { tableName: String, source: String, schema: StructType, - options: Map[String, String]): DS[Row] + options: Map[String, String]): Dataset[Row] /** * Create a table based on the dataset in a data source, a schema and a set of options. 
Then, @@ -499,7 +499,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { source: String, schema: StructType, description: String, - options: util.Map[String, String]): DS[Row] = { + options: util.Map[String, String]): Dataset[Row] = { createTable( tableName, source = source, @@ -522,7 +522,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { source: String, schema: StructType, description: String, - options: Map[String, String]): DS[Row] + options: Map[String, String]): Dataset[Row] /** * Drops the local temporary view with the given view name in the catalog. If the view has been @@ -670,7 +670,7 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * * @since 3.4.0 */ - def listCatalogs(): DS[CatalogMetadata] + def listCatalogs(): Dataset[CatalogMetadata] /** * Returns a list of catalogs which name match the specify pattern and available in this @@ -678,5 +678,5 @@ abstract class Catalog[DS[U] <: Dataset[U, DS]] { * * @since 3.5.0 */ - def listCatalogs(pattern: String): DS[CatalogMetadata] + def listCatalogs(pattern: String): Dataset[CatalogMetadata] } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala index 12d3d41aa5546..ef6cc64c058a4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameNaFunctions.scala @@ -30,14 +30,14 @@ import org.apache.spark.util.ArrayImplicits._ * @since 1.3.1 */ @Stable -abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { +abstract class DataFrameNaFunctions { /** * Returns a new `DataFrame` that drops rows containing any null or NaN values. * * @since 1.3.1 */ - def drop(): DS[Row] = drop("any") + def drop(): Dataset[Row] = drop("any") /** * Returns a new `DataFrame` that drops rows containing null or NaN values. 
@@ -47,7 +47,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def drop(how: String): DS[Row] = drop(toMinNonNulls(how)) + def drop(how: String): Dataset[Row] = drop(toMinNonNulls(how)) /** * Returns a new `DataFrame` that drops rows containing any null or NaN values in the specified @@ -55,7 +55,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def drop(cols: Array[String]): DS[Row] = drop(cols.toImmutableArraySeq) + def drop(cols: Array[String]): Dataset[Row] = drop(cols.toImmutableArraySeq) /** * (Scala-specific) Returns a new `DataFrame` that drops rows containing any null or NaN values @@ -63,7 +63,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def drop(cols: Seq[String]): DS[Row] = drop(cols.size, cols) + def drop(cols: Seq[String]): Dataset[Row] = drop(cols.size, cols) /** * Returns a new `DataFrame` that drops rows containing null or NaN values in the specified @@ -74,7 +74,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def drop(how: String, cols: Array[String]): DS[Row] = drop(how, cols.toImmutableArraySeq) + def drop(how: String, cols: Array[String]): Dataset[Row] = drop(how, cols.toImmutableArraySeq) /** * (Scala-specific) Returns a new `DataFrame` that drops rows containing null or NaN values in @@ -85,7 +85,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def drop(how: String, cols: Seq[String]): DS[Row] = drop(toMinNonNulls(how), cols) + def drop(how: String, cols: Seq[String]): Dataset[Row] = drop(toMinNonNulls(how), cols) /** * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and @@ -93,7 +93,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def drop(minNonNulls: Int): DS[Row] = drop(Option(minNonNulls)) + def drop(minNonNulls: Int): Dataset[Row] = drop(Option(minNonNulls)) /** * Returns a new `DataFrame` that drops rows containing less than `minNonNulls` non-null and @@ -101,7 +101,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def drop(minNonNulls: Int, cols: Array[String]): DS[Row] = + def drop(minNonNulls: Int, cols: Array[String]): Dataset[Row] = drop(minNonNulls, cols.toImmutableArraySeq) /** @@ -110,7 +110,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def drop(minNonNulls: Int, cols: Seq[String]): DS[Row] = drop(Option(minNonNulls), cols) + def drop(minNonNulls: Int, cols: Seq[String]): Dataset[Row] = drop(Option(minNonNulls), cols) private def toMinNonNulls(how: String): Option[Int] = { how.toLowerCase(util.Locale.ROOT) match { @@ -120,29 +120,29 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { } } - protected def drop(minNonNulls: Option[Int]): DS[Row] + protected def drop(minNonNulls: Option[Int]): Dataset[Row] - protected def drop(minNonNulls: Option[Int], cols: Seq[String]): DS[Row] + protected def drop(minNonNulls: Option[Int], cols: Seq[String]): Dataset[Row] /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. * * @since 2.2.0 */ - def fill(value: Long): DS[Row] + def fill(value: Long): Dataset[Row] /** * Returns a new `DataFrame` that replaces null or NaN values in numeric columns with `value`. 
* @since 1.3.1 */ - def fill(value: Double): DS[Row] + def fill(value: Double): Dataset[Row] /** * Returns a new `DataFrame` that replaces null values in string columns with `value`. * * @since 1.3.1 */ - def fill(value: String): DS[Row] + def fill(value: String): Dataset[Row] /** * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a @@ -150,7 +150,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 2.2.0 */ - def fill(value: Long, cols: Array[String]): DS[Row] = fill(value, cols.toImmutableArraySeq) + def fill(value: Long, cols: Array[String]): Dataset[Row] = fill(value, cols.toImmutableArraySeq) /** * Returns a new `DataFrame` that replaces null or NaN values in specified numeric columns. If a @@ -158,7 +158,8 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def fill(value: Double, cols: Array[String]): DS[Row] = fill(value, cols.toImmutableArraySeq) + def fill(value: Double, cols: Array[String]): Dataset[Row] = + fill(value, cols.toImmutableArraySeq) /** * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified @@ -166,7 +167,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 2.2.0 */ - def fill(value: Long, cols: Seq[String]): DS[Row] + def fill(value: Long, cols: Seq[String]): Dataset[Row] /** * (Scala-specific) Returns a new `DataFrame` that replaces null or NaN values in specified @@ -174,7 +175,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def fill(value: Double, cols: Seq[String]): DS[Row] + def fill(value: Double, cols: Seq[String]): Dataset[Row] /** * Returns a new `DataFrame` that replaces null values in specified string columns. If a @@ -182,7 +183,8 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def fill(value: String, cols: Array[String]): DS[Row] = fill(value, cols.toImmutableArraySeq) + def fill(value: String, cols: Array[String]): Dataset[Row] = + fill(value, cols.toImmutableArraySeq) /** * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified string @@ -190,14 +192,14 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def fill(value: String, cols: Seq[String]): DS[Row] + def fill(value: String, cols: Seq[String]): Dataset[Row] /** * Returns a new `DataFrame` that replaces null values in boolean columns with `value`. * * @since 2.3.0 */ - def fill(value: Boolean): DS[Row] + def fill(value: Boolean): Dataset[Row] /** * (Scala-specific) Returns a new `DataFrame` that replaces null values in specified boolean @@ -205,7 +207,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 2.3.0 */ - def fill(value: Boolean, cols: Seq[String]): DS[Row] + def fill(value: Boolean, cols: Seq[String]): Dataset[Row] /** * Returns a new `DataFrame` that replaces null values in specified boolean columns. If a @@ -213,7 +215,8 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 2.3.0 */ - def fill(value: Boolean, cols: Array[String]): DS[Row] = fill(value, cols.toImmutableArraySeq) + def fill(value: Boolean, cols: Array[String]): Dataset[Row] = + fill(value, cols.toImmutableArraySeq) /** * Returns a new `DataFrame` that replaces null values. 
@@ -231,7 +234,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def fill(valueMap: util.Map[String, Any]): DS[Row] = fillMap(valueMap.asScala.toSeq) + def fill(valueMap: util.Map[String, Any]): Dataset[Row] = fillMap(valueMap.asScala.toSeq) /** * (Scala-specific) Returns a new `DataFrame` that replaces null values. @@ -251,9 +254,9 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def fill(valueMap: Map[String, Any]): DS[Row] = fillMap(valueMap.toSeq) + def fill(valueMap: Map[String, Any]): Dataset[Row] = fillMap(valueMap.toSeq) - protected def fillMap(values: Seq[(String, Any)]): DS[Row] + protected def fillMap(values: Seq[(String, Any)]): Dataset[Row] /** * Replaces values matching keys in `replacement` map with the corresponding values. @@ -280,7 +283,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def replace[T](col: String, replacement: util.Map[T, T]): DS[Row] = { + def replace[T](col: String, replacement: util.Map[T, T]): Dataset[Row] = { replace[T](col, replacement.asScala.toMap) } @@ -306,7 +309,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def replace[T](cols: Array[String], replacement: util.Map[T, T]): DS[Row] = { + def replace[T](cols: Array[String], replacement: util.Map[T, T]): Dataset[Row] = { replace(cols.toImmutableArraySeq, replacement.asScala.toMap) } @@ -333,7 +336,7 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def replace[T](col: String, replacement: Map[T, T]): DS[Row] + def replace[T](col: String, replacement: Map[T, T]): Dataset[Row] /** * (Scala-specific) Replaces values matching keys in `replacement` map. @@ -355,5 +358,5 @@ abstract class DataFrameNaFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.3.1 */ - def replace[T](cols: Seq[String], replacement: Map[T, T]): DS[Row] + def replace[T](cols: Seq[String], replacement: Map[T, T]): Dataset[Row] } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala index 6e6ab7b9d95a4..c101c52fd0662 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala @@ -34,7 +34,8 @@ import org.apache.spark.sql.types.StructType * @since 1.4.0 */ @Stable -abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { +abstract class DataFrameReader { + type DS[U] <: Dataset[U] /** * Specifies the input data source format. @@ -149,7 +150,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * * @since 1.4.0 */ - def load(): DS[Row] + def load(): Dataset[Row] /** * Loads input in as a `DataFrame`, for data sources that require a path (e.g. data backed by a @@ -157,7 +158,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * * @since 1.4.0 */ - def load(path: String): DS[Row] + def load(path: String): Dataset[Row] /** * Loads input in as a `DataFrame`, for data sources that support multiple paths. 
Only works if @@ -166,7 +167,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * @since 1.6.0 */ @scala.annotation.varargs - def load(paths: String*): DS[Row] + def load(paths: String*): Dataset[Row] /** * Construct a `DataFrame` representing the database table accessible via JDBC URL url named @@ -179,7 +180,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * * @since 1.4.0 */ - def jdbc(url: String, table: String, properties: util.Properties): DS[Row] = { + def jdbc(url: String, table: String, properties: util.Properties): Dataset[Row] = { assertNoSpecifiedSchema("jdbc") // properties should override settings in extraOptions. this.extraOptions ++= properties.asScala @@ -223,7 +224,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { lowerBound: Long, upperBound: Long, numPartitions: Int, - connectionProperties: util.Properties): DS[Row] = { + connectionProperties: util.Properties): Dataset[Row] = { // columnName, lowerBound, upperBound and numPartitions override settings in extraOptions. this.extraOptions ++= Map( "partitionColumn" -> columnName, @@ -260,7 +261,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { url: String, table: String, predicates: Array[String], - connectionProperties: util.Properties): DS[Row] + connectionProperties: util.Properties): Dataset[Row] /** * Loads a JSON file and returns the results as a `DataFrame`. @@ -269,7 +270,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * * @since 1.4.0 */ - def json(path: String): DS[Row] = { + def json(path: String): Dataset[Row] = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 json(Seq(path): _*) } @@ -290,7 +291,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * @since 2.0.0 */ @scala.annotation.varargs - def json(paths: String*): DS[Row] = { + def json(paths: String*): Dataset[Row] = { validateJsonSchema() format("json").load(paths: _*) } @@ -306,7 +307,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * input Dataset with one JSON object per record * @since 2.2.0 */ - def json(jsonDataset: DS[String]): DS[Row] + def json(jsonDataset: DS[String]): Dataset[Row] /** * Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other @@ -314,7 +315,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * * @since 2.0.0 */ - def csv(path: String): DS[Row] = { + def csv(path: String): Dataset[Row] = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 csv(Seq(path): _*) } @@ -340,7 +341,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * input Dataset with one CSV row per record * @since 2.2.0 */ - def csv(csvDataset: DS[String]): DS[Row] + def csv(csvDataset: DS[String]): Dataset[Row] /** * Loads CSV files and returns the result as a `DataFrame`. @@ -356,7 +357,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * @since 2.0.0 */ @scala.annotation.varargs - def csv(paths: String*): DS[Row] = format("csv").load(paths: _*) + def csv(paths: String*): Dataset[Row] = format("csv").load(paths: _*) /** * Loads a XML file and returns the result as a `DataFrame`. 
See the documentation on the other @@ -364,7 +365,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * * @since 4.0.0 */ - def xml(path: String): DS[Row] = { + def xml(path: String): Dataset[Row] = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 xml(Seq(path): _*) } @@ -383,7 +384,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * @since 4.0.0 */ @scala.annotation.varargs - def xml(paths: String*): DS[Row] = { + def xml(paths: String*): Dataset[Row] = { validateXmlSchema() format("xml").load(paths: _*) } @@ -398,7 +399,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * input Dataset with one XML object per record * @since 4.0.0 */ - def xml(xmlDataset: DS[String]): DS[Row] + def xml(xmlDataset: DS[String]): Dataset[Row] /** * Loads a Parquet file, returning the result as a `DataFrame`. See the documentation on the @@ -406,7 +407,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * * @since 2.0.0 */ - def parquet(path: String): DS[Row] = { + def parquet(path: String): Dataset[Row] = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 parquet(Seq(path): _*) } @@ -421,7 +422,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * @since 1.4.0 */ @scala.annotation.varargs - def parquet(paths: String*): DS[Row] = format("parquet").load(paths: _*) + def parquet(paths: String*): Dataset[Row] = format("parquet").load(paths: _*) /** * Loads an ORC file and returns the result as a `DataFrame`. @@ -430,7 +431,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * input path * @since 1.5.0 */ - def orc(path: String): DS[Row] = { + def orc(path: String): Dataset[Row] = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 orc(Seq(path): _*) } @@ -447,7 +448,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * @since 2.0.0 */ @scala.annotation.varargs - def orc(paths: String*): DS[Row] = format("orc").load(paths: _*) + def orc(paths: String*): Dataset[Row] = format("orc").load(paths: _*) /** * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch @@ -462,7 +463,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * database. Note that, the global temporary view database is also valid here. * @since 1.4.0 */ - def table(tableName: String): DS[Row] + def table(tableName: String): Dataset[Row] /** * Loads text files and returns a `DataFrame` whose schema starts with a string column named @@ -471,7 +472,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * * @since 2.0.0 */ - def text(path: String): DS[Row] = { + def text(path: String): Dataset[Row] = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 text(Seq(path): _*) } @@ -499,14 +500,14 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * @since 1.6.0 */ @scala.annotation.varargs - def text(paths: String*): DS[Row] = format("text").load(paths: _*) + def text(paths: String*): Dataset[Row] = format("text").load(paths: _*) /** * Loads text files and returns a [[Dataset]] of String. See the documentation on the other * overloaded `textFile()` method for more details. 
* @since 2.0.0 */ - def textFile(path: String): DS[String] = { + def textFile(path: String): Dataset[String] = { // This method ensures that calls that explicit need single argument works, see SPARK-16009 textFile(Seq(path): _*) } @@ -534,7 +535,7 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] { * @since 2.0.0 */ @scala.annotation.varargs - def textFile(paths: String*): DS[String] = { + def textFile(paths: String*): Dataset[String] = { assertNoSpecifiedSchema("textFile") text(paths: _*).select("value").as(StringEncoder) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala index fc1680231be5b..ae7c256b30ace 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameStatFunctions.scala @@ -34,8 +34,8 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch} * @since 1.4.0 */ @Stable -abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { - protected def df: DS[Row] +abstract class DataFrameStatFunctions { + protected def df: Dataset[Row] /** * Calculates the approximate quantiles of a numerical column of a DataFrame. @@ -202,7 +202,7 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.4.0 */ - def crosstab(col1: String, col2: String): DS[Row] + def crosstab(col1: String, col2: String): Dataset[Row] /** * Finding frequent items for columns, possibly with false positives. Using the frequent element @@ -246,7 +246,7 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * }}} * @since 1.4.0 */ - def freqItems(cols: Array[String], support: Double): DS[Row] = + def freqItems(cols: Array[String], support: Double): Dataset[Row] = freqItems(cols.toImmutableArraySeq, support) /** @@ -263,7 +263,7 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * A Local DataFrame with the Array of frequent items for each column. * @since 1.4.0 */ - def freqItems(cols: Array[String]): DS[Row] = freqItems(cols, 0.01) + def freqItems(cols: Array[String]): Dataset[Row] = freqItems(cols, 0.01) /** * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the @@ -307,7 +307,7 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.4.0 */ - def freqItems(cols: Seq[String], support: Double): DS[Row] + def freqItems(cols: Seq[String], support: Double): Dataset[Row] /** * (Scala-specific) Finding frequent items for columns, possibly with false positives. Using the @@ -324,7 +324,7 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * A Local DataFrame with the Array of frequent items for each column. * @since 1.4.0 */ - def freqItems(cols: Seq[String]): DS[Row] = freqItems(cols, 0.01) + def freqItems(cols: Seq[String]): Dataset[Row] = freqItems(cols, 0.01) /** * Returns a stratified sample without replacement based on the fraction given on each stratum. 
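The DataFrameStatFunctions methods above are reached through `Dataset.stat`; a sketch of `crosstab` and `freqItems` with hypothetical data:

```
import spark.implicits._

val visits = Seq(("alice", "home"), ("alice", "cart"), ("bob", "home"))
  .toDF("user", "page")

val contingency = visits.stat.crosstab("user", "page")     // pair-wise frequency table
val frequent    = visits.stat.freqItems(Seq("page"), 0.4)  // approximate, may contain false positives
contingency.show()
frequent.show()
```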
@@ -356,7 +356,7 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.5.0 */ - def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): DS[Row] = { + def sampleBy[T](col: String, fractions: Map[T, Double], seed: Long): Dataset[Row] = { sampleBy(Column(col), fractions, seed) } @@ -376,7 +376,7 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * * @since 1.5.0 */ - def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): DS[Row] = { + def sampleBy[T](col: String, fractions: ju.Map[T, jl.Double], seed: Long): Dataset[Row] = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) } @@ -413,7 +413,7 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * * @since 3.0.0 */ - def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): DS[Row] + def sampleBy[T](col: Column, fractions: Map[T, Double], seed: Long): Dataset[Row] /** * (Java-specific) Returns a stratified sample without replacement based on the fraction given @@ -432,7 +432,7 @@ abstract class DataFrameStatFunctions[DS[U] <: Dataset[U, DS]] { * a new `DataFrame` that represents the stratified sample * @since 3.0.0 */ - def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): DS[Row] = { + def sampleBy[T](col: Column, fractions: ju.Map[T, jl.Double], seed: Long): Dataset[Row] = { sampleBy(col, fractions.asScala.toMap.asInstanceOf[Map[T, Double]], seed) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala index fb8b6f2f483a1..284a69fe6ee3e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala @@ -119,10 +119,10 @@ import org.apache.spark.util.SparkClassUtils * @since 1.6.0 */ @Stable -abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { - type RGD <: RelationalGroupedDataset[DS] +abstract class Dataset[T] extends Serializable { + type DS[U] <: Dataset[U] - def sparkSession: SparkSession[DS] + def sparkSession: SparkSession val encoder: Encoder[T] @@ -136,7 +136,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { */ // This is declared with parentheses to prevent the Scala compiler from treating // `ds.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. - def toDF(): DS[Row] + def toDF(): Dataset[Row] /** * Returns a new Dataset where each record has been mapped on to the specified type. The method @@ -157,7 +157,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group basic * @since 1.6.0 */ - def as[U: Encoder]: DS[U] + def as[U: Encoder]: Dataset[U] /** * Returns a new DataFrame where each row is reconciled to match the specified schema. Spark @@ -175,7 +175,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group basic * @since 3.4.0 */ - def to(schema: StructType): DS[Row] + def to(schema: StructType): Dataset[Row] /** * Converts this strongly typed collection of data to generic `DataFrame` with columns renamed. @@ -191,7 +191,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def toDF(colNames: String*): DS[Row] + def toDF(colNames: String*): Dataset[Row] /** * Returns the schema of this Dataset. 
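A sketch of the stratified sampling entry point shown above (data, fractions and seed are hypothetical):

```
import org.apache.spark.sql.functions.col
import spark.implicits._

val labelled = Seq((0, "a"), (0, "b"), (1, "c"), (1, "d"), (1, "e")).toDF("label", "value")

// Keep roughly 50% of label 0 and 20% of label 1; the seed fixes the sample.
val sampled = labelled.stat.sampleBy(col("label"), Map(0 -> 0.5, 1 -> 0.2), seed = 42L)
sampled.show()
```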
@@ -312,7 +312,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group basic * @since 2.1.0 */ - def checkpoint(): DS[T] = checkpoint(eager = true, reliableCheckpoint = true) + def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true) /** * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the @@ -331,7 +331,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group basic * @since 2.1.0 */ - def checkpoint(eager: Boolean): DS[T] = checkpoint(eager = eager, reliableCheckpoint = true) + def checkpoint(eager: Boolean): Dataset[T] = + checkpoint(eager = eager, reliableCheckpoint = true) /** * Eagerly locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used @@ -342,7 +343,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group basic * @since 2.3.0 */ - def localCheckpoint(): DS[T] = checkpoint(eager = true, reliableCheckpoint = false) + def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false) /** * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to @@ -361,7 +362,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group basic * @since 2.3.0 */ - def localCheckpoint(eager: Boolean): DS[T] = + def localCheckpoint(eager: Boolean): Dataset[T] = checkpoint(eager = eager, reliableCheckpoint = false) /** @@ -373,7 +374,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * Whether to create a reliable checkpoint saved to files inside the checkpoint directory. If * false creates a local checkpoint using the caching subsystem */ - protected def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): DS[T] + protected def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] /** * Defines an event time watermark for this [[Dataset]]. A watermark tracks a point in time @@ -400,7 +401,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { */ // We only accept an existing column name, not a derived column here as a watermark that is // defined on a derived column cannot referenced elsewhere in the plan. - def withWatermark(eventTime: String, delayThreshold: String): DS[T] + def withWatermark(eventTime: String, delayThreshold: String): Dataset[T] /** * Displays the Dataset in a tabular form. Strings more than 20 characters will be truncated, @@ -551,7 +552,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 1.6.0 */ - def na: DataFrameNaFunctions[DS] + def na: DataFrameNaFunctions /** * Returns a [[DataFrameStatFunctions]] for working statistic functions support. @@ -563,7 +564,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 1.6.0 */ - def stat: DataFrameStatFunctions[DS] + def stat: DataFrameStatFunctions /** * Join with another `DataFrame`. @@ -575,7 +576,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_]): DS[Row] + def join(right: DS[_]): Dataset[Row] /** * Inner equi-join with another `DataFrame` using the given column. 
@@ -601,7 +602,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], usingColumn: String): DS[Row] = { + def join(right: DS[_], usingColumn: String): Dataset[Row] = { join(right, Seq(usingColumn)) } @@ -617,7 +618,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def join(right: DS[_], usingColumns: Array[String]): DS[Row] = { + def join(right: DS[_], usingColumns: Array[String]): Dataset[Row] = { join(right, usingColumns.toImmutableArraySeq) } @@ -645,7 +646,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], usingColumns: Seq[String]): DS[Row] = { + def join(right: DS[_], usingColumns: Seq[String]): Dataset[Row] = { join(right, usingColumns, "inner") } @@ -675,7 +676,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def join(right: DS[_], usingColumn: String, joinType: String): DS[Row] = { + def join(right: DS[_], usingColumn: String, joinType: String): Dataset[Row] = { join(right, Seq(usingColumn), joinType) } @@ -696,7 +697,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def join(right: DS[_], usingColumns: Array[String], joinType: String): DS[Row] = { + def join(right: DS[_], usingColumns: Array[String], joinType: String): Dataset[Row] = { join(right, usingColumns.toImmutableArraySeq, joinType) } @@ -726,7 +727,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], usingColumns: Seq[String], joinType: String): DS[Row] + def join(right: DS[_], usingColumns: Seq[String], joinType: String): Dataset[Row] /** * Inner join with another `DataFrame`, using the given join expression. @@ -740,7 +741,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], joinExprs: Column): DS[Row] = + def join(right: DS[_], joinExprs: Column): Dataset[Row] = join(right, joinExprs, "inner") /** @@ -770,7 +771,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def join(right: DS[_], joinExprs: Column, joinType: String): DS[Row] + def join(right: DS[_], joinExprs: Column, joinType: String): Dataset[Row] /** * Explicit cartesian join with another `DataFrame`. @@ -782,7 +783,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.1.0 */ - def crossJoin(right: DS[_]): DS[Row] + def crossJoin(right: DS[_]): Dataset[Row] /** * Joins this Dataset returning a `Tuple2` for each pair where `condition` evaluates to true. 
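The join overloads above differ only in how the keys and the join type are supplied; a sketch with hypothetical DataFrames:

```
import spark.implicits._

val left  = Seq((1, "alice"), (2, "bob")).toDF("id", "name")
val right = Seq((1, 100), (3, 300)).toDF("id", "amount")

val inner  = left.join(right, "id")                        // usingColumn, inner join
val outer  = left.join(right, Seq("id"), "left_outer")     // usingColumns plus joinType
val byExpr = left.join(right, left("id") === right("id"))  // joinExprs, keeps both id columns
val cross  = left.crossJoin(right)                         // explicit cartesian product
```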
@@ -806,7 +807,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def joinWith[U](other: DS[U], condition: Column, joinType: String): DS[(T, U)] + def joinWith[U](other: DS[U], condition: Column, joinType: String): Dataset[(T, U)] /** * Using inner equi-join to join this Dataset returning a `Tuple2` for each pair where @@ -819,11 +820,11 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def joinWith[U](other: DS[U], condition: Column): DS[(T, U)] = { + def joinWith[U](other: DS[U], condition: Column): Dataset[(T, U)] = { joinWith(other, condition, "inner") } - protected def sortInternal(global: Boolean, sortExprs: Seq[Column]): DS[T] + protected def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] /** * Returns a new Dataset with each partition sorted by the given expressions. @@ -834,7 +835,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def sortWithinPartitions(sortCol: String, sortCols: String*): DS[T] = { + def sortWithinPartitions(sortCol: String, sortCols: String*): Dataset[T] = { sortWithinPartitions((sortCol +: sortCols).map(Column(_)): _*) } @@ -847,7 +848,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def sortWithinPartitions(sortExprs: Column*): DS[T] = { + def sortWithinPartitions(sortExprs: Column*): Dataset[T] = { sortInternal(global = false, sortExprs) } @@ -864,7 +865,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def sort(sortCol: String, sortCols: String*): DS[T] = { + def sort(sortCol: String, sortCols: String*): Dataset[T] = { sort((sortCol +: sortCols).map(Column(_)): _*) } @@ -878,7 +879,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def sort(sortExprs: Column*): DS[T] = { + def sort(sortExprs: Column*): Dataset[T] = { sortInternal(global = true, sortExprs) } @@ -890,7 +891,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def orderBy(sortCol: String, sortCols: String*): DS[T] = sort(sortCol, sortCols: _*) + def orderBy(sortCol: String, sortCols: String*): Dataset[T] = sort(sortCol, sortCols: _*) /** * Returns a new Dataset sorted by the given expressions. This is an alias of the `sort` @@ -900,7 +901,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def orderBy(sortExprs: Column*): DS[T] = sort(sortExprs: _*) + def orderBy(sortExprs: Column*): Dataset[T] = sort(sortExprs: _*) /** * Specifies some hint on the current Dataset. As an example, the following code specifies that @@ -926,7 +927,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.2.0 */ @scala.annotation.varargs - def hint(name: String, parameters: Any*): DS[T] + def hint(name: String, parameters: Any*): Dataset[T] /** * Selects column based on the column name and returns it as a [[org.apache.spark.sql.Column]]. @@ -975,7 +976,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def as(alias: String): DS[T] + def as(alias: String): Dataset[T] /** * (Scala-specific) Returns a new Dataset with an alias set. 
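A sketch of the typed `joinWith` and the sort/hint family above (the case class and data are hypothetical):

```
import spark.implicits._

case class User(id: Int, name: String)

val users  = Seq(User(1, "alice"), User(2, "bob")).toDS()
val scores = Seq((1, 9.5), (2, 7.0)).toDF("id", "score")

// Typed join: each result row is a (User, Row) pair instead of a flattened Row.
val pairs = users.joinWith(scores, users("id") === scores("id"), "inner")

// sort/orderBy and hint all preserve the element type Dataset[User].
val ordered = users.orderBy("name").hint("broadcast")
```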
@@ -983,7 +984,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def as(alias: Symbol): DS[T] = as(alias.name) + def as(alias: Symbol): Dataset[T] = as(alias.name) /** * Returns a new Dataset with an alias set. Same as `as`. @@ -991,7 +992,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def alias(alias: String): DS[T] = as(alias) + def alias(alias: String): Dataset[T] = as(alias) /** * (Scala-specific) Returns a new Dataset with an alias set. Same as `as`. @@ -999,7 +1000,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def alias(alias: Symbol): DS[T] = as(alias) + def alias(alias: Symbol): Dataset[T] = as(alias) /** * Selects a set of column based expressions. @@ -1011,7 +1012,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def select(cols: Column*): DS[Row] + def select(cols: Column*): Dataset[Row] /** * Selects a set of columns. This is a variant of `select` that can only select existing columns @@ -1027,7 +1028,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def select(col: String, cols: String*): DS[Row] = select((col +: cols).map(Column(_)): _*) + def select(col: String, cols: String*): Dataset[Row] = select((col +: cols).map(Column(_)): _*) /** * Selects a set of SQL expressions. This is a variant of `select` that accepts SQL expressions. @@ -1042,7 +1043,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def selectExpr(exprs: String*): DS[Row] = select(exprs.map(functions.expr): _*) + def selectExpr(exprs: String*): Dataset[Row] = select(exprs.map(functions.expr): _*) /** * Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expression for @@ -1056,14 +1057,14 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def select[U1](c1: TypedColumn[T, U1]): DS[U1] + def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] /** * Internal helper function for building typed selects that return tuples. For simplicity and * code reuse, we do this without the help of the type system and then use helper functions that * cast appropriately for the user facing interface. 
*/ - protected def selectUntyped(columns: TypedColumn[_, _]*): DS[_] + protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] /** * Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expressions for @@ -1072,8 +1073,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): DS[(U1, U2)] = - selectUntyped(c1, c2).asInstanceOf[DS[(U1, U2)]] + def select[U1, U2](c1: TypedColumn[T, U1], c2: TypedColumn[T, U2]): Dataset[(U1, U2)] = + selectUntyped(c1, c2).asInstanceOf[Dataset[(U1, U2)]] /** * Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expressions for @@ -1085,8 +1086,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { def select[U1, U2, U3]( c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], - c3: TypedColumn[T, U3]): DS[(U1, U2, U3)] = - selectUntyped(c1, c2, c3).asInstanceOf[DS[(U1, U2, U3)]] + c3: TypedColumn[T, U3]): Dataset[(U1, U2, U3)] = + selectUntyped(c1, c2, c3).asInstanceOf[Dataset[(U1, U2, U3)]] /** * Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expressions for @@ -1099,8 +1100,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { c1: TypedColumn[T, U1], c2: TypedColumn[T, U2], c3: TypedColumn[T, U3], - c4: TypedColumn[T, U4]): DS[(U1, U2, U3, U4)] = - selectUntyped(c1, c2, c3, c4).asInstanceOf[DS[(U1, U2, U3, U4)]] + c4: TypedColumn[T, U4]): Dataset[(U1, U2, U3, U4)] = + selectUntyped(c1, c2, c3, c4).asInstanceOf[Dataset[(U1, U2, U3, U4)]] /** * Returns a new Dataset by computing the given [[org.apache.spark.sql.Column]] expressions for @@ -1114,8 +1115,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { c2: TypedColumn[T, U2], c3: TypedColumn[T, U3], c4: TypedColumn[T, U4], - c5: TypedColumn[T, U5]): DS[(U1, U2, U3, U4, U5)] = - selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[DS[(U1, U2, U3, U4, U5)]] + c5: TypedColumn[T, U5]): Dataset[(U1, U2, U3, U4, U5)] = + selectUntyped(c1, c2, c3, c4, c5).asInstanceOf[Dataset[(U1, U2, U3, U4, U5)]] /** * Filters rows using the given condition. @@ -1128,7 +1129,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def filter(condition: Column): DS[T] + def filter(condition: Column): Dataset[T] /** * Filters rows using the given SQL expression. @@ -1139,7 +1140,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def filter(conditionExpr: String): DS[T] = + def filter(conditionExpr: String): Dataset[T] = filter(functions.expr(conditionExpr)) /** @@ -1149,7 +1150,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def filter(func: T => Boolean): DS[T] + def filter(func: T => Boolean): Dataset[T] /** * (Java-specific) Returns a new Dataset that only contains elements where `func` returns @@ -1158,7 +1159,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def filter(func: FilterFunction[T]): DS[T] + def filter(func: FilterFunction[T]): Dataset[T] /** * Filters rows using the given condition. This is an alias for `filter`. 
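A sketch contrasting the untyped and typed projections and the filter overloads above (data is hypothetical):

```
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions.col
import spark.implicits._

val people = Seq(("alice", 29), ("bob", 17)).toDF("name", "age")

val untyped = people.select(col("name"), (col("age") + 1).as("age_next"))  // Dataset[Row]
val viaSql  = people.selectExpr("upper(name) AS name", "age * 2 AS doubled")
val adults  = people.filter(col("age") >= 18)                              // Column predicate
val typed: Dataset[(String, Int)] =
  people.as[(String, Int)].filter(_._2 >= 18)                              // T => Boolean overload
```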
@@ -1171,7 +1172,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def where(condition: Column): DS[T] = filter(condition) + def where(condition: Column): Dataset[T] = filter(condition) /** * Filters rows using the given SQL expression. @@ -1182,7 +1183,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def where(conditionExpr: String): DS[T] = filter(conditionExpr) + def where(conditionExpr: String): Dataset[T] = filter(conditionExpr) /** * Groups the Dataset using the specified columns, so we can run aggregation on them. See @@ -1203,7 +1204,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def groupBy(cols: Column*): RGD + def groupBy(cols: Column*): RelationalGroupedDataset /** * Groups the Dataset using the specified columns, so that we can run aggregation on them. See @@ -1227,7 +1228,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def groupBy(col1: String, cols: String*): RGD = groupBy((col1 +: cols).map(col): _*) + def groupBy(col1: String, cols: String*): RelationalGroupedDataset = groupBy( + (col1 +: cols).map(col): _*) /** * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we @@ -1249,7 +1251,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def rollup(cols: Column*): RGD + def rollup(cols: Column*): RelationalGroupedDataset /** * Create a multi-dimensional rollup for the current Dataset using the specified columns, so we @@ -1274,7 +1276,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def rollup(col1: String, cols: String*): RGD = rollup((col1 +: cols).map(col): _*) + def rollup(col1: String, cols: String*): RelationalGroupedDataset = rollup( + (col1 +: cols).map(col): _*) /** * Create a multi-dimensional cube for the current Dataset using the specified columns, so we @@ -1296,7 +1299,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def cube(cols: Column*): RGD + def cube(cols: Column*): RelationalGroupedDataset /** * Create a multi-dimensional cube for the current Dataset using the specified columns, so we @@ -1321,7 +1324,8 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def cube(col1: String, cols: String*): RGD = cube((col1 +: cols).map(col): _*) + def cube(col1: String, cols: String*): RelationalGroupedDataset = cube( + (col1 +: cols).map(col): _*) /** * Create multi-dimensional aggregation for the current Dataset using the specified grouping @@ -1343,7 +1347,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 4.0.0 */ @scala.annotation.varargs - def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RGD + def groupingSets(groupingSets: Seq[Seq[Column]], cols: Column*): RelationalGroupedDataset /** * (Scala-specific) Aggregates on the entire Dataset without groups. 
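groupBy/rollup/cube above return a RelationalGroupedDataset, which `agg` turns back into a `Dataset[Row]`; a sketch with hypothetical data:

```
import org.apache.spark.sql.functions.{avg, sum}
import spark.implicits._

val sales = Seq(("eu", "2024", 10.0), ("eu", "2025", 12.0), ("us", "2024", 20.0))
  .toDF("region", "year", "revenue")

val byRegion = sales.groupBy("region").agg(sum("revenue").as("total"))
val rolled   = sales.rollup("region", "year").agg(sum("revenue"))  // adds subtotal rows per level
val overall  = sales.agg(avg("revenue"), sum("revenue"))           // aggregate without grouping
```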
@@ -1356,7 +1360,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DS[Row] = { + def agg(aggExpr: (String, String), aggExprs: (String, String)*): Dataset[Row] = { groupBy().agg(aggExpr, aggExprs: _*) } @@ -1371,7 +1375,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def agg(exprs: Map[String, String]): DS[Row] = groupBy().agg(exprs) + def agg(exprs: Map[String, String]): Dataset[Row] = groupBy().agg(exprs) /** * (Java-specific) Aggregates on the entire Dataset without groups. @@ -1384,7 +1388,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def agg(exprs: util.Map[String, String]): DS[Row] = groupBy().agg(exprs) + def agg(exprs: util.Map[String, String]): Dataset[Row] = groupBy().agg(exprs) /** * Aggregates on the entire Dataset without groups. @@ -1398,7 +1402,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DS[Row] = groupBy().agg(expr, exprs: _*) + def agg(expr: Column, exprs: Column*): Dataset[Row] = groupBy().agg(expr, exprs: _*) /** * (Scala-specific) Reduces the elements of this Dataset using the specified binary function. @@ -1479,7 +1483,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { ids: Array[Column], values: Array[Column], variableColumnName: String, - valueColumnName: String): DS[Row] + valueColumnName: String): Dataset[Row] /** * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns @@ -1502,7 +1506,10 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def unpivot(ids: Array[Column], variableColumnName: String, valueColumnName: String): DS[Row] + def unpivot( + ids: Array[Column], + variableColumnName: String, + valueColumnName: String): Dataset[Row] /** * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns @@ -1526,7 +1533,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { ids: Array[Column], values: Array[Column], variableColumnName: String, - valueColumnName: String): DS[Row] = + valueColumnName: String): Dataset[Row] = unpivot(ids, values, variableColumnName, valueColumnName) /** @@ -1548,7 +1555,10 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def melt(ids: Array[Column], variableColumnName: String, valueColumnName: String): DS[Row] = + def melt( + ids: Array[Column], + variableColumnName: String, + valueColumnName: String): Dataset[Row] = unpivot(ids, variableColumnName, valueColumnName) /** @@ -1611,7 +1621,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 4.0.0 */ - def transpose(indexColumn: Column): DS[Row] + def transpose(indexColumn: Column): Dataset[Row] /** * Transposes a DataFrame, switching rows to columns. This function transforms the DataFrame @@ -1630,7 +1640,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 4.0.0 */ - def transpose(): DS[Row] + def transpose(): Dataset[Row] /** * Define (named) metrics to observe on the Dataset. 
This method returns an 'observed' Dataset @@ -1651,7 +1661,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 3.0.0 */ @scala.annotation.varargs - def observe(name: String, expr: Column, exprs: Column*): DS[T] + def observe(name: String, expr: Column, exprs: Column*): Dataset[T] /** * Observe (named) metrics through an `org.apache.spark.sql.Observation` instance. This method @@ -1674,7 +1684,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 3.3.0 */ @scala.annotation.varargs - def observe(observation: Observation, expr: Column, exprs: Column*): DS[T] + def observe(observation: Observation, expr: Column, exprs: Column*): Dataset[T] /** * Returns a new Dataset by taking the first `n` rows. The difference between this function and @@ -1684,7 +1694,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def limit(n: Int): DS[T] + def limit(n: Int): Dataset[T] /** * Returns a new Dataset by skipping the first `n` rows. @@ -1692,7 +1702,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 3.4.0 */ - def offset(n: Int): DS[T] + def offset(n: Int): Dataset[T] /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. @@ -1724,7 +1734,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def union(other: DS[T]): DS[T] + def union(other: DS[T]): Dataset[T] /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. This is @@ -1738,7 +1748,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def unionAll(other: DS[T]): DS[T] = union(other) + def unionAll(other: DS[T]): Dataset[T] = union(other) /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. @@ -1769,7 +1779,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.3.0 */ - def unionByName(other: DS[T]): DS[T] = unionByName(other, allowMissingColumns = false) + def unionByName(other: DS[T]): Dataset[T] = unionByName(other, allowMissingColumns = false) /** * Returns a new Dataset containing union of rows in this Dataset and another Dataset. @@ -1813,7 +1823,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 3.1.0 */ - def unionByName(other: DS[T], allowMissingColumns: Boolean): DS[T] + def unionByName(other: DS[T], allowMissingColumns: Boolean): Dataset[T] /** * Returns a new Dataset containing rows only in both this Dataset and another Dataset. This is @@ -1825,7 +1835,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def intersect(other: DS[T]): DS[T] + def intersect(other: DS[T]): Dataset[T] /** * Returns a new Dataset containing rows only in both this Dataset and another Dataset while @@ -1838,7 +1848,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.4.0 */ - def intersectAll(other: DS[T]): DS[T] + def intersectAll(other: DS[T]): Dataset[T] /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset. 
This is @@ -1850,7 +1860,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def except(other: DS[T]): DS[T] + def except(other: DS[T]): Dataset[T] /** * Returns a new Dataset containing rows in this Dataset but not in another Dataset while @@ -1863,7 +1873,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.4.0 */ - def exceptAll(other: DS[T]): DS[T] + def exceptAll(other: DS[T]): Dataset[T] /** * Returns a new [[Dataset]] by sampling a fraction of rows (without replacement), using a @@ -1879,7 +1889,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.3.0 */ - def sample(fraction: Double, seed: Long): DS[T] = { + def sample(fraction: Double, seed: Long): Dataset[T] = { sample(withReplacement = false, fraction = fraction, seed = seed) } @@ -1895,7 +1905,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.3.0 */ - def sample(fraction: Double): DS[T] = { + def sample(fraction: Double): Dataset[T] = { sample(withReplacement = false, fraction = fraction) } @@ -1914,7 +1924,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def sample(withReplacement: Boolean, fraction: Double, seed: Long): DS[T] + def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] /** * Returns a new [[Dataset]] by sampling a fraction of rows, using a random seed. @@ -1931,7 +1941,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def sample(withReplacement: Boolean, fraction: Double): DS[T] = { + def sample(withReplacement: Boolean, fraction: Double): Dataset[T] = { sample(withReplacement, fraction, SparkClassUtils.random.nextLong) } @@ -1948,7 +1958,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def randomSplit(weights: Array[Double], seed: Long): Array[_ <: DS[T]] + def randomSplit(weights: Array[Double], seed: Long): Array[_ <: Dataset[T]] /** * Returns a Java list that contains randomly split Dataset with the provided weights. @@ -1960,7 +1970,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def randomSplitAsList(weights: Array[Double], seed: Long): util.List[_ <: DS[T]] + def randomSplitAsList(weights: Array[Double], seed: Long): util.List[_ <: Dataset[T]] /** * Randomly splits this Dataset with the provided weights. 
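The set-like and sampling operators above all preserve the element type; a sketch with hypothetical data:

```
import spark.implicits._

val a = Seq(1, 2, 3, 4).toDS()
val b = Seq(3, 4, 5).toDS()

val both    = a.union(b)                 // resolves by position, keeps duplicates
val onlyInA = a.except(b)                // distinct rows of a that are not in b
val shared  = a.intersect(b)
val half    = a.sample(withReplacement = false, fraction = 0.5, seed = 7L)
val Array(train, test) = a.randomSplit(Array(0.8, 0.2), seed = 7L)
```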
@@ -1970,7 +1980,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def randomSplit(weights: Array[Double]): Array[_ <: DS[T]] + def randomSplit(weights: Array[Double]): Array[_ <: Dataset[T]] /** * (Scala-specific) Returns a new Dataset where each row has been expanded to zero or more rows @@ -1983,7 +1993,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * * {{{ * case class Book(title: String, words: String) - * val ds: DS[Book] + * val ds: Dataset[Book] * * val allWords = ds.select($"title", explode(split($"words", " ")).as("word")) * @@ -2000,7 +2010,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") - def explode[A <: Product: TypeTag](input: Column*)(f: Row => IterableOnce[A]): DS[Row] + def explode[A <: Product: TypeTag](input: Column*)(f: Row => IterableOnce[A]): Dataset[Row] /** * (Scala-specific) Returns a new Dataset where a single column has been expanded to zero or @@ -2026,7 +2036,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { */ @deprecated("use flatMap() or select() with functions.explode() instead", "2.0.0") def explode[A, B: TypeTag](inputColumn: String, outputColumn: String)( - f: A => IterableOnce[B]): DS[Row] + f: A => IterableOnce[B]): Dataset[Row] /** * Returns a new Dataset by adding a column or replacing the existing column that has the same @@ -2043,7 +2053,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def withColumn(colName: String, col: Column): DS[Row] = withColumns(Seq(colName), Seq(col)) + def withColumn(colName: String, col: Column): Dataset[Row] = withColumns(Seq(colName), Seq(col)) /** * (Scala-specific) Returns a new Dataset by adding columns or replacing the existing columns @@ -2055,7 +2065,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 3.3.0 */ - def withColumns(colsMap: Map[String, Column]): DS[Row] = { + def withColumns(colsMap: Map[String, Column]): Dataset[Row] = { val (colNames, newCols) = colsMap.toSeq.unzip withColumns(colNames, newCols) } @@ -2070,13 +2080,14 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 3.3.0 */ - def withColumns(colsMap: util.Map[String, Column]): DS[Row] = withColumns(colsMap.asScala.toMap) + def withColumns(colsMap: util.Map[String, Column]): Dataset[Row] = withColumns( + colsMap.asScala.toMap) /** * Returns a new Dataset by adding columns or replacing the existing columns that has the same * names. */ - protected def withColumns(colNames: Seq[String], cols: Seq[Column]): DS[Row] + protected def withColumns(colNames: Seq[String], cols: Seq[Column]): Dataset[Row] /** * Returns a new Dataset with a column renamed. 
This is a no-op if schema doesn't contain @@ -2085,7 +2096,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def withColumnRenamed(existingName: String, newName: String): DS[Row] = + def withColumnRenamed(existingName: String, newName: String): Dataset[Row] = withColumnsRenamed(Seq(existingName), Seq(newName)) /** @@ -2100,7 +2111,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 3.4.0 */ @throws[AnalysisException] - def withColumnsRenamed(colsMap: Map[String, String]): DS[Row] = { + def withColumnsRenamed(colsMap: Map[String, String]): Dataset[Row] = { val (colNames, newColNames) = colsMap.toSeq.unzip withColumnsRenamed(colNames, newColNames) } @@ -2114,10 +2125,10 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 3.4.0 */ - def withColumnsRenamed(colsMap: util.Map[String, String]): DS[Row] = + def withColumnsRenamed(colsMap: util.Map[String, String]): Dataset[Row] = withColumnsRenamed(colsMap.asScala.toMap) - protected def withColumnsRenamed(colNames: Seq[String], newColNames: Seq[String]): DS[Row] + protected def withColumnsRenamed(colNames: Seq[String], newColNames: Seq[String]): Dataset[Row] /** * Returns a new Dataset by updating an existing column with metadata. @@ -2125,7 +2136,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 3.3.0 */ - def withMetadata(columnName: String, metadata: Metadata): DS[Row] + def withMetadata(columnName: String, metadata: Metadata): Dataset[Row] /** * Returns a new Dataset with a column dropped. This is a no-op if schema doesn't contain column @@ -2198,7 +2209,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def drop(colName: String): DS[Row] = drop(colName :: Nil: _*) + def drop(colName: String): Dataset[Row] = drop(colName :: Nil: _*) /** * Returns a new Dataset with columns dropped. This is a no-op if schema doesn't contain column @@ -2211,7 +2222,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def drop(colNames: String*): DS[Row] + def drop(colNames: String*): Dataset[Row] /** * Returns a new Dataset with column dropped. @@ -2226,7 +2237,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group untypedrel * @since 2.0.0 */ - def drop(col: Column): DS[Row] = drop(col, Nil: _*) + def drop(col: Column): Dataset[Row] = drop(col, Nil: _*) /** * Returns a new Dataset with columns dropped. @@ -2238,7 +2249,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 3.4.0 */ @scala.annotation.varargs - def drop(col: Column, cols: Column*): DS[Row] + def drop(col: Column, cols: Column*): Dataset[Row] /** * Returns a new Dataset that contains only the unique rows from this Dataset. 
This is an alias @@ -2253,7 +2264,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def dropDuplicates(): DS[T] + def dropDuplicates(): Dataset[T] /** * (Scala-specific) Returns a new Dataset with duplicate rows removed, considering only the @@ -2268,7 +2279,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def dropDuplicates(colNames: Seq[String]): DS[T] + def dropDuplicates(colNames: Seq[String]): Dataset[T] /** * Returns a new Dataset with duplicate rows removed, considering only the subset of columns. @@ -2282,7 +2293,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def dropDuplicates(colNames: Array[String]): DS[T] = + def dropDuplicates(colNames: Array[String]): Dataset[T] = dropDuplicates(colNames.toImmutableArraySeq) /** @@ -2299,7 +2310,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def dropDuplicates(col1: String, cols: String*): DS[T] = { + def dropDuplicates(col1: String, cols: String*): Dataset[T] = { val colNames: Seq[String] = col1 +: cols dropDuplicates(colNames) } @@ -2321,7 +2332,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 3.5.0 */ - def dropDuplicatesWithinWatermark(): DS[T] + def dropDuplicatesWithinWatermark(): Dataset[T] /** * Returns a new Dataset with duplicates rows removed, considering only the subset of columns, @@ -2341,7 +2352,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 3.5.0 */ - def dropDuplicatesWithinWatermark(colNames: Seq[String]): DS[T] + def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] /** * Returns a new Dataset with duplicates rows removed, considering only the subset of columns, @@ -2361,7 +2372,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 3.5.0 */ - def dropDuplicatesWithinWatermark(colNames: Array[String]): DS[T] = { + def dropDuplicatesWithinWatermark(colNames: Array[String]): Dataset[T] = { dropDuplicatesWithinWatermark(colNames.toImmutableArraySeq) } @@ -2384,7 +2395,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 3.5.0 */ @scala.annotation.varargs - def dropDuplicatesWithinWatermark(col1: String, cols: String*): DS[T] = { + def dropDuplicatesWithinWatermark(col1: String, cols: String*): Dataset[T] = { val colNames: Seq[String] = col1 +: cols dropDuplicatesWithinWatermark(colNames) } @@ -2418,7 +2429,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 1.6.0 */ @scala.annotation.varargs - def describe(cols: String*): DS[Row] + def describe(cols: String*): Dataset[Row] /** * Computes specified statistics for numeric and string columns. Available statistics are:
    @@ -2488,7 +2499,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.3.0 */ @scala.annotation.varargs - def summary(statistics: String*): DS[Row] + def summary(statistics: String*): Dataset[Row] /** * Returns the first `n` rows. @@ -2520,7 +2531,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { /** * Concise syntax for chaining custom transformations. * {{{ - * def featurize(ds: DS[T]): DS[U] = ... + * def featurize(ds: Dataset[T]): Dataset[U] = ... * * ds * .transform(featurize) @@ -2530,7 +2541,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def transform[U](t: DS[T] => DS[U]): DS[U] = t(this.asInstanceOf[DS[T]]) + def transform[U](t: Dataset[T] => Dataset[U]): Dataset[U] = t(this.asInstanceOf[Dataset[T]]) /** * (Scala-specific) Returns a new Dataset that contains the result of applying `func` to each @@ -2539,7 +2550,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def map[U: Encoder](func: T => U): DS[U] + def map[U: Encoder](func: T => U): Dataset[U] /** * (Java-specific) Returns a new Dataset that contains the result of applying `func` to each @@ -2548,7 +2559,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def map[U](func: MapFunction[T, U], encoder: Encoder[U]): DS[U] + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] /** * (Scala-specific) Returns a new Dataset that contains the result of applying `func` to each @@ -2557,7 +2568,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def mapPartitions[U: Encoder](func: Iterator[T] => Iterator[U]): DS[U] + def mapPartitions[U: Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] /** * (Java-specific) Returns a new Dataset that contains the result of applying `f` to each @@ -2566,7 +2577,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): DS[U] = + def mapPartitions[U](f: MapPartitionsFunction[T, U], encoder: Encoder[U]): Dataset[U] = mapPartitions(ToScalaUDF(f))(encoder) /** @@ -2576,7 +2587,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def flatMap[U: Encoder](func: T => IterableOnce[U]): DS[U] = + def flatMap[U: Encoder](func: T => IterableOnce[U]): Dataset[U] = mapPartitions(UDFAdaptors.flatMapToMapPartitions[T, U](func)) /** @@ -2586,7 +2597,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): DS[U] = { + def flatMap[U](f: FlatMapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { mapPartitions(UDFAdaptors.flatMapToMapPartitions(f))(encoder) } @@ -2713,11 +2724,11 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def repartition(numPartitions: Int): DS[T] + def repartition(numPartitions: Int): Dataset[T] protected def repartitionByExpression( numPartitions: Option[Int], - partitionExprs: Seq[Column]): DS[T] + partitionExprs: Seq[Column]): Dataset[T] /** * Returns a new Dataset partitioned by the given partitioning expressions into `numPartitions`. 
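A sketch of `transform`, the typed `map`, and `repartition` from the hunks above (the helper function and data are hypothetical):

```
import org.apache.spark.sql.Dataset
import spark.implicits._

def onlyAdults(ds: Dataset[(String, Int)]): Dataset[(String, Int)] = ds.filter(_._2 >= 18)

val people = Seq(("alice", 29), ("bob", 17)).toDS()

val names: Dataset[String] = people
  .transform(onlyAdults)    // concise chaining of a custom transformation
  .map(_._1.toUpperCase)    // typed map, re-encoded via the implicit Encoder[String]
  .repartition(4)           // hash-partition into four partitions
```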
@@ -2729,7 +2740,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def repartition(numPartitions: Int, partitionExprs: Column*): DS[T] = { + def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { repartitionByExpression(Some(numPartitions), partitionExprs) } @@ -2744,11 +2755,13 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.0.0 */ @scala.annotation.varargs - def repartition(partitionExprs: Column*): DS[T] = { + def repartition(partitionExprs: Column*): Dataset[T] = { repartitionByExpression(None, partitionExprs) } - protected def repartitionByRange(numPartitions: Option[Int], partitionExprs: Seq[Column]): DS[T] + protected def repartitionByRange( + numPartitions: Option[Int], + partitionExprs: Seq[Column]): Dataset[T] /** * Returns a new Dataset partitioned by the given partitioning expressions into `numPartitions`. @@ -2766,7 +2779,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.3.0 */ @scala.annotation.varargs - def repartitionByRange(numPartitions: Int, partitionExprs: Column*): DS[T] = { + def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = { repartitionByRange(Some(numPartitions), partitionExprs) } @@ -2787,7 +2800,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @since 2.3.0 */ @scala.annotation.varargs - def repartitionByRange(partitionExprs: Column*): DS[T] = { + def repartitionByRange(partitionExprs: Column*): Dataset[T] = { repartitionByRange(None, partitionExprs) } @@ -2807,7 +2820,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 1.6.0 */ - def coalesce(numPartitions: Int): DS[T] + def coalesce(numPartitions: Int): Dataset[T] /** * Returns a new Dataset that contains only the unique rows from this Dataset. This is an alias @@ -2823,7 +2836,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group typedrel * @since 2.0.0 */ - def distinct(): DS[T] = dropDuplicates() + def distinct(): Dataset[T] = dropDuplicates() /** * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`). @@ -2831,7 +2844,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group basic * @since 1.6.0 */ - def persist(): DS[T] + def persist(): Dataset[T] /** * Persist this Dataset with the default storage level (`MEMORY_AND_DISK`). @@ -2839,7 +2852,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group basic * @since 1.6.0 */ - def cache(): DS[T] + def cache(): Dataset[T] /** * Persist this Dataset with the given storage level. @@ -2850,7 +2863,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group basic * @since 1.6.0 */ - def persist(newLevel: StorageLevel): DS[T] + def persist(newLevel: StorageLevel): Dataset[T] /** * Get the Dataset's current storage level, or StorageLevel.NONE if not persisted. @@ -2869,7 +2882,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group basic * @since 1.6.0 */ - def unpersist(blocking: Boolean): DS[T] + def unpersist(blocking: Boolean): Dataset[T] /** * Mark the Dataset as non-persistent, and remove all blocks for it from memory and disk. 
This @@ -2878,7 +2891,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * @group basic * @since 1.6.0 */ - def unpersist(): DS[T] + def unpersist(): Dataset[T] /** * Registers this Dataset as a temporary table using the given name. The lifetime of this @@ -3008,7 +3021,7 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable { * * @since 2.0.0 */ - def toJSON: DS[String] + def toJSON: Dataset[String] /** * Returns a best-effort snapshot of the files that compose this Dataset. This method simply diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala index 50dfbff81dd3e..81f999430a128 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/KeyValueGroupedDataset.scala @@ -30,8 +30,8 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode * * @since 2.0.0 */ -abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Serializable { - type KVDS[KY, VL] <: KeyValueGroupedDataset[KY, VL, DS] +abstract class KeyValueGroupedDataset[K, V] extends Serializable { + type KVDS[KL, VL] <: KeyValueGroupedDataset[KL, VL] /** * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the @@ -40,7 +40,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * * @since 1.6.0 */ - def keyAs[L: Encoder]: KVDS[L, V] + def keyAs[L: Encoder]: KeyValueGroupedDataset[L, V] /** * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to @@ -53,7 +53,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * * @since 2.1.0 */ - def mapValues[W: Encoder](func: V => W): KVDS[K, W] + def mapValues[W: Encoder](func: V => W): KeyValueGroupedDataset[K, W] /** * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied to @@ -68,7 +68,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * * @since 2.1.0 */ - def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KVDS[K, W] = { + def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = { mapValues(ToScalaUDF(func))(encoder) } @@ -78,7 +78,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * * @since 1.6.0 */ - def keys: DS[K] + def keys: Dataset[K] /** * (Scala-specific) Applies the given function to each group of data. 
For each unique group, the @@ -98,7 +98,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * * @since 1.6.0 */ - def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): DS[U] = { + def flatMapGroups[U: Encoder](f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { flatMapSortedGroups(Nil: _*)(f) } @@ -120,7 +120,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * * @since 1.6.0 */ - def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): DS[U] = { + def flatMapGroups[U](f: FlatMapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { flatMapGroups(ToScalaUDF(f))(encoder) } @@ -149,7 +149,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * @since 3.4.0 */ def flatMapSortedGroups[U: Encoder](sortExprs: Column*)( - f: (K, Iterator[V]) => IterableOnce[U]): DS[U] + f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] /** * (Java-specific) Applies the given function to each group of data. For each unique group, the @@ -178,7 +178,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser def flatMapSortedGroups[U]( SortExprs: Array[Column], f: FlatMapGroupsFunction[K, V, U], - encoder: Encoder[U]): DS[U] = { + encoder: Encoder[U]): Dataset[U] = { import org.apache.spark.util.ArrayImplicits._ flatMapSortedGroups(SortExprs.toImmutableArraySeq: _*)(ToScalaUDF(f))(encoder) } @@ -201,7 +201,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * * @since 1.6.0 */ - def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): DS[U] = { + def mapGroups[U: Encoder](f: (K, Iterator[V]) => U): Dataset[U] = { flatMapGroups(UDFAdaptors.mapGroupsToFlatMapGroups(f)) } @@ -223,7 +223,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * * @since 1.6.0 */ - def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): DS[U] = { + def mapGroups[U](f: MapGroupsFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = { mapGroups(ToScalaUDF(f))(encoder) } @@ -247,7 +247,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * @since 2.2.0 */ def mapGroupsWithState[S: Encoder, U: Encoder]( - func: (K, Iterator[V], GroupState[S]) => U): DS[U] + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] /** * (Scala-specific) Applies the given function to each group of data, while maintaining a @@ -271,7 +271,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * @since 2.2.0 */ def mapGroupsWithState[S: Encoder, U: Encoder](timeoutConf: GroupStateTimeout)( - func: (K, Iterator[V], GroupState[S]) => U): DS[U] + func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] /** * (Scala-specific) Applies the given function to each group of data, while maintaining a @@ -301,7 +301,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser */ def mapGroupsWithState[S: Encoder, U: Encoder]( timeoutConf: GroupStateTimeout, - initialState: KVDS[K, S])(func: (K, Iterator[V], GroupState[S]) => U): DS[U] + initialState: KVDS[K, S])(func: (K, Iterator[V], GroupState[S]) => U): Dataset[U] /** * (Java-specific) Applies the given function to each group of data, while maintaining a @@ -329,7 +329,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser def mapGroupsWithState[S, U]( func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], - outputEncoder: Encoder[U]): 
DS[U] = { + outputEncoder: Encoder[U]): Dataset[U] = { mapGroupsWithState[S, U](ToScalaUDF(func))(stateEncoder, outputEncoder) } @@ -362,7 +362,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser func: MapGroupsWithStateFunction[K, V, S, U], stateEncoder: Encoder[S], outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout): DS[U] = { + timeoutConf: GroupStateTimeout): Dataset[U] = { mapGroupsWithState[S, U](timeoutConf)(ToScalaUDF(func))(stateEncoder, outputEncoder) } @@ -400,7 +400,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout, - initialState: KVDS[K, S]): DS[U] = { + initialState: KVDS[K, S]): Dataset[U] = { val f = ToScalaUDF(func) mapGroupsWithState[S, U](timeoutConf, initialState)(f)(stateEncoder, outputEncoder) } @@ -430,7 +430,8 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser */ def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, - timeoutConf: GroupStateTimeout)(func: (K, Iterator[V], GroupState[S]) => Iterator[U]): DS[U] + timeoutConf: GroupStateTimeout)( + func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] /** * (Scala-specific) Applies the given function to each group of data, while maintaining a @@ -462,7 +463,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser def flatMapGroupsWithState[S: Encoder, U: Encoder]( outputMode: OutputMode, timeoutConf: GroupStateTimeout, - initialState: KVDS[K, S])(func: (K, Iterator[V], GroupState[S]) => Iterator[U]): DS[U] + initialState: KVDS[K, S])(func: (K, Iterator[V], GroupState[S]) => Iterator[U]): Dataset[U] /** * (Java-specific) Applies the given function to each group of data, while maintaining a @@ -496,7 +497,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser outputMode: OutputMode, stateEncoder: Encoder[S], outputEncoder: Encoder[U], - timeoutConf: GroupStateTimeout): DS[U] = { + timeoutConf: GroupStateTimeout): Dataset[U] = { val f = ToScalaUDF(func) flatMapGroupsWithState[S, U](outputMode, timeoutConf)(f)(stateEncoder, outputEncoder) } @@ -540,7 +541,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser stateEncoder: Encoder[S], outputEncoder: Encoder[U], timeoutConf: GroupStateTimeout, - initialState: KVDS[K, S]): DS[U] = { + initialState: KVDS[K, S]): Dataset[U] = { flatMapGroupsWithState[S, U](outputMode, timeoutConf, initialState)(ToScalaUDF(func))( stateEncoder, outputEncoder) @@ -568,7 +569,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], timeMode: TimeMode, - outputMode: OutputMode): DS[U] + outputMode: OutputMode): Dataset[U] /** * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state @@ -597,7 +598,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser private[sql] def transformWithState[U: Encoder]( statefulProcessor: StatefulProcessor[K, V, U], eventTimeColumnName: String, - outputMode: OutputMode): DS[U] + outputMode: OutputMode): Dataset[U] /** * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API @@ -624,7 +625,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser statefulProcessor: StatefulProcessor[K, V, U], 
timeMode: TimeMode, outputMode: OutputMode, - outputEncoder: Encoder[U]): DS[U] = { + outputEncoder: Encoder[U]): Dataset[U] = { transformWithState(statefulProcessor, timeMode, outputMode)(outputEncoder) } @@ -660,7 +661,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser statefulProcessor: StatefulProcessor[K, V, U], eventTimeColumnName: String, outputMode: OutputMode, - outputEncoder: Encoder[U]): DS[U] = { + outputEncoder: Encoder[U]): Dataset[U] = { transformWithState(statefulProcessor, eventTimeColumnName, outputMode)(outputEncoder) } @@ -689,7 +690,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], timeMode: TimeMode, outputMode: OutputMode, - initialState: KVDS[K, S]): DS[U] + initialState: KVDS[K, S]): Dataset[U] /** * (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state @@ -722,7 +723,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S], eventTimeColumnName: String, outputMode: OutputMode, - initialState: KVDS[K, S]): DS[U] + initialState: KVDS[K, S]): Dataset[U] /** * (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API @@ -756,7 +757,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser outputMode: OutputMode, initialState: KVDS[K, S], outputEncoder: Encoder[U], - initialStateEncoder: Encoder[S]): DS[U] = { + initialStateEncoder: Encoder[S]): Dataset[U] = { transformWithState(statefulProcessor, timeMode, outputMode, initialState)( outputEncoder, initialStateEncoder) @@ -798,7 +799,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser initialState: KVDS[K, S], eventTimeColumnName: String, outputEncoder: Encoder[U], - initialStateEncoder: Encoder[S]): DS[U] = { + initialStateEncoder: Encoder[S]): Dataset[U] = { transformWithState(statefulProcessor, eventTimeColumnName, outputMode, initialState)( outputEncoder, initialStateEncoder) @@ -811,7 +812,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * * @since 1.6.0 */ - def reduceGroups(f: (V, V) => V): DS[(K, V)] + def reduceGroups(f: (V, V) => V): Dataset[(K, V)] /** * (Java-specific) Reduces the elements of each group of data using the specified binary @@ -820,7 +821,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * * @since 1.6.0 */ - def reduceGroups(f: ReduceFunction[V]): DS[(K, V)] = { + def reduceGroups(f: ReduceFunction[V]): Dataset[(K, V)] = { reduceGroups(ToScalaUDF(f)) } @@ -829,7 +830,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * and code reuse, we do this without the help of the type system and then use helper functions * that cast appropriately for the user facing interface. 
*/ - protected def aggUntyped(columns: TypedColumn[_, _]*): DS[_] + protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] /** * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key and the @@ -837,8 +838,8 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * * @since 1.6.0 */ - def agg[U1](col1: TypedColumn[V, U1]): DS[(K, U1)] = - aggUntyped(col1).asInstanceOf[DS[(K, U1)]] + def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] = + aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and @@ -846,8 +847,8 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * * @since 1.6.0 */ - def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): DS[(K, U1, U2)] = - aggUntyped(col1, col2).asInstanceOf[DS[(K, U1, U2)]] + def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] = + aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and @@ -858,8 +859,8 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser def agg[U1, U2, U3]( col1: TypedColumn[V, U1], col2: TypedColumn[V, U2], - col3: TypedColumn[V, U3]): DS[(K, U1, U2, U3)] = - aggUntyped(col1, col2, col3).asInstanceOf[DS[(K, U1, U2, U3)]] + col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] = + aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and @@ -871,8 +872,8 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser col1: TypedColumn[V, U1], col2: TypedColumn[V, U2], col3: TypedColumn[V, U3], - col4: TypedColumn[V, U4]): DS[(K, U1, U2, U3, U4)] = - aggUntyped(col1, col2, col3, col4).asInstanceOf[DS[(K, U1, U2, U3, U4)]] + col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] = + aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and @@ -885,8 +886,8 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser col2: TypedColumn[V, U2], col3: TypedColumn[V, U3], col4: TypedColumn[V, U4], - col5: TypedColumn[V, U5]): DS[(K, U1, U2, U3, U4, U5)] = - aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[DS[(K, U1, U2, U3, U4, U5)]] + col5: TypedColumn[V, U5]): Dataset[(K, U1, U2, U3, U4, U5)] = + aggUntyped(col1, col2, col3, col4, col5).asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and @@ -900,9 +901,9 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser col3: TypedColumn[V, U3], col4: TypedColumn[V, U4], col5: TypedColumn[V, U5], - col6: TypedColumn[V, U6]): DS[(K, U1, U2, U3, U4, U5, U6)] = + col6: TypedColumn[V, U6]): Dataset[(K, U1, U2, U3, U4, U5, U6)] = aggUntyped(col1, col2, col3, col4, col5, col6) - .asInstanceOf[DS[(K, U1, U2, U3, U4, U5, U6)]] + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and @@ -917,9 +918,9 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser col4: TypedColumn[V, U4], col5: TypedColumn[V, U5], col6: TypedColumn[V, U6], - col7: 
TypedColumn[V, U7]): DS[(K, U1, U2, U3, U4, U5, U6, U7)] = + col7: TypedColumn[V, U7]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7)] = aggUntyped(col1, col2, col3, col4, col5, col6, col7) - .asInstanceOf[DS[(K, U1, U2, U3, U4, U5, U6, U7)]] + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7)]] /** * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key and @@ -935,9 +936,9 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser col5: TypedColumn[V, U5], col6: TypedColumn[V, U6], col7: TypedColumn[V, U7], - col8: TypedColumn[V, U8]): DS[(K, U1, U2, U3, U4, U5, U6, U7, U8)] = + col8: TypedColumn[V, U8]): Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)] = aggUntyped(col1, col2, col3, col4, col5, col6, col7, col8) - .asInstanceOf[DS[(K, U1, U2, U3, U4, U5, U6, U7, U8)]] + .asInstanceOf[Dataset[(K, U1, U2, U3, U4, U5, U6, U7, U8)]] /** * Returns a [[Dataset]] that contains a tuple with each key and the number of items present for @@ -945,7 +946,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * * @since 1.6.0 */ - def count(): DS[(K, Long)] = agg(cnt(lit(1)).as(PrimitiveLongEncoder)) + def count(): Dataset[(K, Long)] = agg(cnt(lit(1)).as(PrimitiveLongEncoder)) /** * (Scala-specific) Applies the given function to each cogrouped data. For each unique group, @@ -956,7 +957,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * @since 1.6.0 */ def cogroup[U, R: Encoder](other: KVDS[K, U])( - f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): DS[R] = { + f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] = { cogroupSorted(other)(Nil: _*)(Nil: _*)(f) } @@ -971,7 +972,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser def cogroup[U, R]( other: KVDS[K, U], f: CoGroupFunction[K, V, U, R], - encoder: Encoder[R]): DS[R] = { + encoder: Encoder[R]): Dataset[R] = { cogroup(other)(ToScalaUDF(f))(encoder) } @@ -991,7 +992,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser * @since 3.4.0 */ def cogroupSorted[U, R: Encoder](other: KVDS[K, U])(thisSortExprs: Column*)( - otherSortExprs: Column*)(f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): DS[R] + otherSortExprs: Column*)(f: (K, Iterator[V], Iterator[U]) => IterableOnce[R]): Dataset[R] /** * (Java-specific) Applies the given function to each sorted cogrouped data. 
For each unique @@ -1013,7 +1014,7 @@ abstract class KeyValueGroupedDataset[K, V, DS[U] <: Dataset[U, DS]] extends Ser thisSortExprs: Array[Column], otherSortExprs: Array[Column], f: CoGroupFunction[K, V, U, R], - encoder: Encoder[R]): DS[R] = { + encoder: Encoder[R]): Dataset[R] = { import org.apache.spark.util.ArrayImplicits._ cogroupSorted(other)(thisSortExprs.toImmutableArraySeq: _*)( otherSortExprs.toImmutableArraySeq: _*)(ToScalaUDF(f))(encoder) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala index 7dd5f46beb316..118b8f1ecd488 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/RelationalGroupedDataset.scala @@ -35,15 +35,13 @@ import org.apache.spark.sql.{functions, Column, Encoder, Row} * @since 2.0.0 */ @Stable -abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { - type RGD <: RelationalGroupedDataset[DS] - - protected def df: DS[Row] +abstract class RelationalGroupedDataset { + protected def df: Dataset[Row] /** * Create a aggregation based on the grouping column, the grouping type, and the aggregations. */ - protected def toDF(aggCols: Seq[Column]): DS[Row] + protected def toDF(aggCols: Seq[Column]): Dataset[Row] protected def selectNumericColumns(colNames: Seq[String]): Seq[Column] @@ -62,7 +60,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { private def aggregateNumericColumns( colNames: Seq[String], - function: Column => Column): DS[Row] = { + function: Column => Column): Dataset[Row] = { toDF(selectNumericColumns(colNames).map(function)) } @@ -72,7 +70,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * * @since 3.0.0 */ - def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T, DS] + def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] /** * (Scala-specific) Compute aggregates by specifying the column names and aggregate methods. The @@ -89,7 +87,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * * @since 1.3.0 */ - def agg(aggExpr: (String, String), aggExprs: (String, String)*): DS[Row] = + def agg(aggExpr: (String, String), aggExprs: (String, String)*): Dataset[Row] = toDF((aggExpr +: aggExprs).map(toAggCol)) /** @@ -107,7 +105,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * * @since 1.3.0 */ - def agg(exprs: Map[String, String]): DS[Row] = toDF(exprs.map(toAggCol).toSeq) + def agg(exprs: Map[String, String]): Dataset[Row] = toDF(exprs.map(toAggCol).toSeq) /** * (Java-specific) Compute aggregates by specifying a map from column name to aggregate methods. @@ -122,7 +120,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * * @since 1.3.0 */ - def agg(exprs: util.Map[String, String]): DS[Row] = { + def agg(exprs: util.Map[String, String]): Dataset[Row] = { agg(exprs.asScala.toMap) } @@ -158,7 +156,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * @since 1.3.0 */ @scala.annotation.varargs - def agg(expr: Column, exprs: Column*): DS[Row] = toDF(expr +: exprs) + def agg(expr: Column, exprs: Column*): Dataset[Row] = toDF(expr +: exprs) /** * Count the number of rows for each group. 
The resulting `DataFrame` will also contain the @@ -166,7 +164,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * * @since 1.3.0 */ - def count(): DS[Row] = toDF(functions.count(functions.lit(1)).as("count") :: Nil) + def count(): Dataset[Row] = toDF(functions.count(functions.lit(1)).as("count") :: Nil) /** * Compute the average value for each numeric columns for each group. This is an alias for @@ -176,7 +174,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * @since 1.3.0 */ @scala.annotation.varargs - def mean(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.avg) + def mean(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.avg) /** * Compute the max value for each numeric columns for each group. The resulting `DataFrame` will @@ -186,7 +184,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * @since 1.3.0 */ @scala.annotation.varargs - def max(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.max) + def max(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.max) /** * Compute the mean value for each numeric columns for each group. The resulting `DataFrame` @@ -196,7 +194,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * @since 1.3.0 */ @scala.annotation.varargs - def avg(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.avg) + def avg(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.avg) /** * Compute the min value for each numeric column for each group. The resulting `DataFrame` will @@ -206,7 +204,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * @since 1.3.0 */ @scala.annotation.varargs - def min(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.min) + def min(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.min) /** * Compute the sum for each numeric columns for each group. The resulting `DataFrame` will also @@ -216,7 +214,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * @since 1.3.0 */ @scala.annotation.varargs - def sum(colNames: String*): DS[Row] = aggregateNumericColumns(colNames, functions.sum) + def sum(colNames: String*): Dataset[Row] = aggregateNumericColumns(colNames, functions.sum) /** * Pivots a column of the current `DataFrame` and performs the specified aggregation. @@ -237,7 +235,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * Name of the column to pivot. * @since 1.6.0 */ - def pivot(pivotColumn: String): RGD = pivot(df.col(pivotColumn)) + def pivot(pivotColumn: String): RelationalGroupedDataset = pivot(df.col(pivotColumn)) /** * Pivots a column of the current `DataFrame` and performs the specified aggregation. There are @@ -271,7 +269,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * List of values that will be translated to columns in the output DataFrame. * @since 1.6.0 */ - def pivot(pivotColumn: String, values: Seq[Any]): RGD = + def pivot(pivotColumn: String, values: Seq[Any]): RelationalGroupedDataset = pivot(df.col(pivotColumn), values) /** @@ -299,7 +297,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * List of values that will be translated to columns in the output DataFrame. 
* @since 1.6.0 */ - def pivot(pivotColumn: String, values: util.List[Any]): RGD = + def pivot(pivotColumn: String, values: util.List[Any]): RelationalGroupedDataset = pivot(df.col(pivotColumn), values) /** @@ -316,7 +314,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * List of values that will be translated to columns in the output DataFrame. * @since 2.4.0 */ - def pivot(pivotColumn: Column, values: util.List[Any]): RGD = + def pivot(pivotColumn: Column, values: util.List[Any]): RelationalGroupedDataset = pivot(pivotColumn, values.asScala.toSeq) /** @@ -338,7 +336,7 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * he column to pivot. * @since 2.4.0 */ - def pivot(pivotColumn: Column): RGD + def pivot(pivotColumn: Column): RelationalGroupedDataset /** * Pivots a column of the current `DataFrame` and performs the specified aggregation. This is an @@ -358,5 +356,5 @@ abstract class RelationalGroupedDataset[DS[U] <: Dataset[U, DS]] { * List of values that will be translated to columns in the output DataFrame. * @since 2.4.0 */ - def pivot(pivotColumn: Column, values: Seq[Any]): RGD + def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala index 63d4a12e11839..41d16b16ab1c5 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala @@ -49,7 +49,7 @@ import org.apache.spark.sql.types.StructType * .getOrCreate() * }}} */ -abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with Closeable { +abstract class SparkSession extends Serializable with Closeable { /** * The version of Spark on which this application is running. @@ -103,7 +103,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * will initialize the metastore, which may take some time. * @since 2.0.0 */ - def newSession(): SparkSession[DS] + def newSession(): SparkSession /* --------------------------------- * | Methods for creating DataFrames | @@ -115,14 +115,14 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * @since 2.0.0 */ @transient - def emptyDataFrame: DS[Row] + def emptyDataFrame: Dataset[Row] /** * Creates a `DataFrame` from a local Seq of Product. * * @since 2.0.0 */ - def createDataFrame[A <: Product: TypeTag](data: Seq[A]): DS[Row] + def createDataFrame[A <: Product: TypeTag](data: Seq[A]): Dataset[Row] /** * :: DeveloperApi :: Creates a `DataFrame` from a `java.util.List` containing @@ -133,7 +133,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * @since 2.0.0 */ @DeveloperApi - def createDataFrame(rows: util.List[Row], schema: StructType): DS[Row] + def createDataFrame(rows: util.List[Row], schema: StructType): Dataset[Row] /** * Applies a schema to a List of Java Beans. 
@@ -143,7 +143,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * * @since 1.6.0 */ - def createDataFrame(data: util.List[_], beanClass: Class[_]): DS[Row] + def createDataFrame(data: util.List[_], beanClass: Class[_]): Dataset[Row] /* ------------------------------- * | Methods for creating DataSets | @@ -154,7 +154,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * * @since 2.0.0 */ - def emptyDataset[T: Encoder]: DS[T] + def emptyDataset[T: Encoder]: Dataset[T] /** * Creates a [[Dataset]] from a local Seq of data of a given type. This method requires an @@ -183,7 +183,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * * @since 2.0.0 */ - def createDataset[T: Encoder](data: Seq[T]): DS[T] + def createDataset[T: Encoder](data: Seq[T]): Dataset[T] /** * Creates a [[Dataset]] from a `java.util.List` of a given type. This method requires an @@ -200,7 +200,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * * @since 2.0.0 */ - def createDataset[T: Encoder](data: util.List[T]): DS[T] + def createDataset[T: Encoder](data: util.List[T]): Dataset[T] /** * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a @@ -208,7 +208,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * * @since 2.0.0 */ - def range(end: Long): DS[lang.Long] + def range(end: Long): Dataset[lang.Long] /** * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a @@ -216,7 +216,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * * @since 2.0.0 */ - def range(start: Long, end: Long): DS[lang.Long] + def range(start: Long, end: Long): Dataset[lang.Long] /** * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a @@ -224,7 +224,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * * @since 2.0.0 */ - def range(start: Long, end: Long, step: Long): DS[lang.Long] + def range(start: Long, end: Long, step: Long): Dataset[lang.Long] /** * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a @@ -232,7 +232,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * * @since 2.0.0 */ - def range(start: Long, end: Long, step: Long, numPartitions: Int): DS[lang.Long] + def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[lang.Long] /* ------------------------- * | Catalog-related methods | @@ -244,7 +244,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * * @since 2.0.0 */ - def catalog: Catalog[DS] + def catalog: Catalog /** * Returns the specified table/view as a `DataFrame`. If it's a table, it must support batch @@ -259,7 +259,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * database. Note that, the global temporary view database is also valid here. 
* @since 2.0.0 */ - def table(tableName: String): DS[Row] + def table(tableName: String): Dataset[Row] /* ----------------- * | Everything else | @@ -281,7 +281,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * @since 3.5.0 */ @Experimental - def sql(sqlText: String, args: Array[_]): DS[Row] + def sql(sqlText: String, args: Array[_]): Dataset[Row] /** * Executes a SQL query substituting named parameters by the given arguments, returning the @@ -299,7 +299,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * @since 3.4.0 */ @Experimental - def sql(sqlText: String, args: Map[String, Any]): DS[Row] + def sql(sqlText: String, args: Map[String, Any]): Dataset[Row] /** * Executes a SQL query substituting named parameters by the given arguments, returning the @@ -317,7 +317,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * @since 3.4.0 */ @Experimental - def sql(sqlText: String, args: util.Map[String, Any]): DS[Row] = { + def sql(sqlText: String, args: util.Map[String, Any]): Dataset[Row] = { sql(sqlText, args.asScala.toMap) } @@ -327,7 +327,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * * @since 2.0.0 */ - def sql(sqlText: String): DS[Row] = sql(sqlText, Map.empty[String, Any]) + def sql(sqlText: String): Dataset[Row] = sql(sqlText, Map.empty[String, Any]) /** * Add a single artifact to the current session. @@ -503,7 +503,7 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C * * @since 2.0.0 */ - def read: DataFrameReader[DS] + def read: DataFrameReader /** * Executes some code block and prints to stdout the time taken to execute the block. This is diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQuery.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQuery.scala index 16cd45339f051..0aeb3518facd8 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQuery.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQuery.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.streaming.{StreamingQueryException, StreamingQueryPr * @since 2.0.0 */ @Evolving -trait StreamingQuery[DS[U] <: Dataset[U, DS]] { +trait StreamingQuery { /** * Returns the user-specified name of the query, or null if not specified. This name can be @@ -62,7 +62,7 @@ trait StreamingQuery[DS[U] <: Dataset[U, DS]] { * * @since 2.0.0 */ - def sparkSession: SparkSession[DS] + def sparkSession: SparkSession /** * Returns `true` if this query is actively running. 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 86f8923f36b40..02669270c8acf 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -1714,7 +1714,7 @@ object functions { * @group normal_funcs * @since 1.5.0 */ - def broadcast[DS[U] <: api.Dataset[U, DS]](df: DS[_]): df.type = { + def broadcast[DS[U] <: api.Dataset[U]](df: DS[_]): df.type = { df.hint("broadcast").asInstanceOf[df.type] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 53640f513fc81..b356751083fc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -21,6 +21,7 @@ import java.{lang => jl} import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.ExpressionUtils.column @@ -33,7 +34,7 @@ import org.apache.spark.sql.types._ */ @Stable final class DataFrameNaFunctions private[sql](df: DataFrame) - extends api.DataFrameNaFunctions[Dataset] { + extends api.DataFrameNaFunctions { import df.sparkSession.RichColumn protected def drop(minNonNulls: Option[Int]): Dataset[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index f105a77cf253b..78cc65bb7a298 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} import org.apache.spark.sql.catalyst.util.FailureSafeParser import org.apache.spark.sql.catalyst.xml.{StaxXmlParser, XmlOptions} +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource @@ -54,7 +55,9 @@ import org.apache.spark.unsafe.types.UTF8String */ @Stable class DataFrameReader private[sql](sparkSession: SparkSession) - extends api.DataFrameReader[Dataset] { + extends api.DataFrameReader { + override type DS[U] = Dataset[U] + format(sparkSession.sessionState.conf.defaultDataSourceName) /** @inheritdoc */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index a5ab237bb7041..9f7180d8dfd6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -23,6 +23,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.execution.stat._ import org.apache.spark.sql.functions.col import org.apache.spark.util.ArrayImplicits._ @@ -34,7 +35,7 @@ import 
org.apache.spark.util.ArrayImplicits._ */ @Stable final class DataFrameStatFunctions private[sql](protected val df: DataFrame) - extends api.DataFrameStatFunctions[Dataset] { + extends api.DataFrameStatFunctions { /** @inheritdoc */ def approxQuantile( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index c147b6a56e024..61f9e6ff7c042 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -52,6 +52,7 @@ import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression @@ -215,7 +216,8 @@ private[sql] object Dataset { class Dataset[T] private[sql]( @DeveloperApi @Unstable @transient val queryExecution: QueryExecution, @DeveloperApi @Unstable @transient val encoder: Encoder[T]) - extends api.Dataset[T, Dataset] { + extends api.Dataset[T] { + type DS[U] = Dataset[U] type RGD = RelationalGroupedDataset @transient lazy val sparkSession: SparkSession = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index fcad1b721eaca..c645ba57e8f82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderF import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.internal.TypedAggUtils.{aggKeyColumn, withInputType} @@ -41,7 +42,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( @transient val queryExecution: QueryExecution, private val dataAttributes: Seq[Attribute], private val groupingAttributes: Seq[Attribute]) - extends api.KeyValueGroupedDataset[K, V, Dataset] { + extends api.KeyValueGroupedDataset[K, V] { type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL] private implicit def kEncoderImpl: Encoder[K] = kEncoder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index da4609135fd63..bd47a21a1e09b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.internal.ExpressionUtils.{column, generateAlias} @@ -52,8 +53,8 @@ class RelationalGroupedDataset protected[sql]( protected[sql] val df: DataFrame, private[sql] val groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) - extends api.RelationalGroupedDataset[Dataset] { - type RGD = RelationalGroupedDataset + extends api.RelationalGroupedDataset { + import RelationalGroupedDataset._ import df.sparkSession._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 720b77b0b9fe5..137dbaed9f00a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -96,7 +96,7 @@ class SparkSession private( @transient private[sql] val extensions: SparkSessionExtensions, @transient private[sql] val initialSessionOptions: Map[String, String], @transient private val parentManagedJobTags: Map[String, String]) - extends api.SparkSession[Dataset] with Logging { self => + extends api.SparkSession with Logging { self => // The call site where this SparkSession was constructed. private val creationSite: CallSite = Utils.getCallSite() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala index 661e43fe73cae..c39018ff06fca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.catalog import java.util import org.apache.spark.sql.{api, DataFrame, Dataset} +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.types.StructType /** @inheritdoc */ -abstract class Catalog extends api.Catalog[Dataset] { +abstract class Catalog extends api.Catalog { /** @inheritdoc */ override def listDatabases(): Dataset[Database] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala new file mode 100644 index 0000000000000..af91b57a6848b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.classic + +import scala.language.implicitConversions + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql._ + +/** + * Conversions from sql interfaces to the Classic specific implementation. 
+ * + * This class is mainly used by the implementation, but is also meant to be used by extension + * developers. + * + * We provide both a trait and an object. The trait is useful in situations where an extension + * developer needs to use these conversions in a project covering multiple Spark versions. They can + * create a shim for these conversions, the Spark 4+ version of the shim implements this trait, and + * shims for older versions do not. + */ +@DeveloperApi +trait ClassicConversions { + implicit def castToImpl(session: api.SparkSession): SparkSession = + session.asInstanceOf[SparkSession] + + implicit def castToImpl[T](ds: api.Dataset[T]): Dataset[T] = + ds.asInstanceOf[Dataset[T]] + + implicit def castToImpl(rgds: api.RelationalGroupedDataset): RelationalGroupedDataset = + rgds.asInstanceOf[RelationalGroupedDataset] + + implicit def castToImpl[K, V](kvds: api.KeyValueGroupedDataset[K, V]) + : KeyValueGroupedDataset[K, V] = kvds.asInstanceOf[KeyValueGroupedDataset[K, V]] +} + +object ClassicConversions extends ClassicConversions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala index 653e1df4af679..7cf92db59067c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala @@ -16,10 +16,10 @@ */ package org.apache.spark.sql.streaming -import org.apache.spark.sql.{api, Dataset, SparkSession} +import org.apache.spark.sql.{api, SparkSession} /** @inheritdoc */ -trait StreamingQuery extends api.StreamingQuery[Dataset] { +trait StreamingQuery extends api.StreamingQuery { /** @inheritdoc */ override def sparkSession: SparkSession } From 3b34891e5b9c2694b7ffdc265290e25847dc3437 Mon Sep 17 00:00:00 2001 From: Changgyoo Park Date: Thu, 19 Sep 2024 09:10:51 +0900 Subject: [PATCH 009/250] [SPARK-49684][CONNECT] Remove global locks from session and execution managers ### What changes were proposed in this pull request? Eliminate the use of global locks in the session and execution managers. Those locks residing in the streaming query manager cannot be easily removed because the tag and query maps seemingly need to be synchronised. ### Why are the changes needed? In order to achieve true scalability. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48131 from changgyoopark-db/SPARK-49684. 
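For illustration, a minimal sketch (not taken from this patch) of the lock-free, atomically guarded initialization pattern described above, in which an `AtomicReference` replaces a `synchronized` block around an `Option[ScheduledExecutorService]`. The class and method names are placeholders, and it uses the plain `get`/`compareAndExchange` variants rather than the acquire/release ones used in the actual change:

```scala
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
import java.util.concurrent.atomic.AtomicReference

// Placeholder class: starts a single maintenance thread on first use without a global lock.
class MaintenanceScheduler(intervalMs: Long) {
  private val scheduledExecutor = new AtomicReference[ScheduledExecutorService]()

  def schedulePeriodicChecks(task: Runnable): Unit = {
    if (scheduledExecutor.get() == null) {
      val executor = Executors.newSingleThreadScheduledExecutor()
      // Only the caller that wins the compare-and-exchange schedules the task;
      // a loser discards its speculatively created executor.
      if (scheduledExecutor.compareAndExchange(null, executor) == null) {
        executor.scheduleAtFixedRate(task, intervalMs, intervalMs, TimeUnit.MILLISECONDS)
      } else {
        executor.shutdown()
      }
    }
  }

  def shutdown(): Unit = {
    // Atomically detach the executor so concurrent callers cannot shut it down twice.
    val executor = scheduledExecutor.getAndSet(null)
    if (executor != null) {
      executor.shutdown()
    }
  }
}
```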
Authored-by: Changgyoo Park Signed-off-by: Hyukjin Kwon --- .../SparkConnectExecutionManager.scala | 59 ++++++++---------- .../service/SparkConnectSessionManager.scala | 60 ++++++++----------- .../SparkConnectStreamingQueryCache.scala | 22 +++---- 3 files changed, 61 insertions(+), 80 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala index 61b41f932199e..d66964b8d34bd 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.connect.service import java.util.UUID import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit} -import javax.annotation.concurrent.GuardedBy +import java.util.concurrent.atomic.{AtomicLong, AtomicReference} import scala.collection.mutable import scala.concurrent.duration.FiniteDuration @@ -66,7 +66,6 @@ private[connect] class SparkConnectExecutionManager() extends Logging { /** Concurrent hash table containing all the current executions. */ private val executions: ConcurrentMap[ExecuteKey, ExecuteHolder] = new ConcurrentHashMap[ExecuteKey, ExecuteHolder]() - private val executionsLock = new Object /** Graveyard of tombstones of executions that were abandoned and removed. */ private val abandonedTombstones = CacheBuilder @@ -74,13 +73,12 @@ private[connect] class SparkConnectExecutionManager() extends Logging { .maximumSize(SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_ABANDONED_TOMBSTONES_SIZE)) .build[ExecuteKey, ExecuteInfo]() - /** None if there are no executions. Otherwise, the time when the last execution was removed. */ - @GuardedBy("executionsLock") - private var lastExecutionTimeMs: Option[Long] = Some(System.currentTimeMillis()) + /** The time when the last execution was removed. */ + private var lastExecutionTimeMs: AtomicLong = new AtomicLong(System.currentTimeMillis()) /** Executor for the periodic maintenance */ - @GuardedBy("executionsLock") - private var scheduledExecutor: Option[ScheduledExecutorService] = None + private var scheduledExecutor: AtomicReference[ScheduledExecutorService] = + new AtomicReference[ScheduledExecutorService]() /** * Create a new ExecuteHolder and register it with this global manager and with its session. @@ -118,11 +116,6 @@ private[connect] class SparkConnectExecutionManager() extends Logging { sessionHolder.addExecuteHolder(executeHolder) - executionsLock.synchronized { - if (!executions.isEmpty()) { - lastExecutionTimeMs = None - } - } logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} is created.") schedulePeriodicChecks() // Starts the maintenance thread if it hasn't started. 
@@ -151,11 +144,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { executions.remove(key) executeHolder.sessionHolder.removeExecuteHolder(executeHolder.operationId) - executionsLock.synchronized { - if (executions.isEmpty) { - lastExecutionTimeMs = Some(System.currentTimeMillis()) - } - } + updateLastExecutionTime() logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.") @@ -197,7 +186,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { */ def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = { if (executions.isEmpty) { - Left(lastExecutionTimeMs.get) + Left(lastExecutionTimeMs.getAcquire()) } else { Right(executions.values().asScala.map(_.getExecuteInfo).toBuffer.toSeq) } @@ -212,22 +201,23 @@ private[connect] class SparkConnectExecutionManager() extends Logging { } private[connect] def shutdown(): Unit = { - executionsLock.synchronized { - scheduledExecutor.foreach { executor => - ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES)) - } - scheduledExecutor = None + val executor = scheduledExecutor.getAndSet(null) + if (executor != null) { + ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES)) } // note: this does not cleanly shut down the executions, but the server is shutting down. executions.clear() abandonedTombstones.invalidateAll() - executionsLock.synchronized { - if (lastExecutionTimeMs.isEmpty) { - lastExecutionTimeMs = Some(System.currentTimeMillis()) - } - } + updateLastExecutionTime() + } + + /** + * Updates the last execution time after the last execution has been removed. + */ + private def updateLastExecutionTime(): Unit = { + lastExecutionTimeMs.getAndUpdate(prev => prev.max(System.currentTimeMillis())) } /** @@ -235,16 +225,16 @@ private[connect] class SparkConnectExecutionManager() extends Logging { * for executions that have not been closed, but are left with no RPC attached to them, and * removes them after a timeout. */ - private def schedulePeriodicChecks(): Unit = executionsLock.synchronized { - scheduledExecutor match { - case Some(_) => // Already running. 
- case None => + private def schedulePeriodicChecks(): Unit = { + var executor = scheduledExecutor.getAcquire() + if (executor == null) { + executor = Executors.newSingleThreadScheduledExecutor() + if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) { val interval = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_MAINTENANCE_INTERVAL) logInfo( log"Starting thread for cleanup of abandoned executions every " + log"${MDC(LogKeys.INTERVAL, interval)} ms") - scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor()) - scheduledExecutor.get.scheduleAtFixedRate( + executor.scheduleAtFixedRate( () => { try { val timeout = SparkEnv.get.conf.get(CONNECT_EXECUTE_MANAGER_DETACHED_TIMEOUT) @@ -256,6 +246,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { interval, interval, TimeUnit.MILLISECONDS) + } } } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala index fec01813de6e2..4ca3a80bfb985 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.connect.service import java.util.UUID import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit} -import javax.annotation.concurrent.GuardedBy +import java.util.concurrent.atomic.AtomicReference import scala.collection.mutable import scala.concurrent.duration.FiniteDuration @@ -40,8 +40,6 @@ import org.apache.spark.util.ThreadUtils */ class SparkConnectSessionManager extends Logging { - private val sessionsLock = new Object - private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] = new ConcurrentHashMap[SessionKey, SessionHolder]() @@ -52,8 +50,8 @@ class SparkConnectSessionManager extends Logging { .build[SessionKey, SessionHolderInfo]() /** Executor for the periodic maintenance */ - @GuardedBy("sessionsLock") - private var scheduledExecutor: Option[ScheduledExecutorService] = None + private var scheduledExecutor: AtomicReference[ScheduledExecutorService] = + new AtomicReference[ScheduledExecutorService]() private def validateSessionId( key: SessionKey, @@ -75,8 +73,6 @@ class SparkConnectSessionManager extends Logging { val holder = getSession( key, Some(() => { - // Executed under sessionsState lock in getSession, to guard against concurrent removal - // and insertion into closedSessionsCache. validateSessionCreate(key) val holder = SessionHolder(key.userId, key.sessionId, newIsolatedSession()) holder.initializeSession() @@ -168,17 +164,14 @@ class SparkConnectSessionManager extends Logging { def closeSession(key: SessionKey): Unit = { val sessionHolder = removeSessionHolder(key) - // Rest of the cleanup outside sessionLock - the session cannot be accessed anymore by - // getOrCreateIsolatedSession. + // Rest of the cleanup: the session cannot be accessed anymore by getOrCreateIsolatedSession. 
sessionHolder.foreach(shutdownSessionHolder(_)) } private[connect] def shutdown(): Unit = { - sessionsLock.synchronized { - scheduledExecutor.foreach { executor => - ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES)) - } - scheduledExecutor = None + val executor = scheduledExecutor.getAndSet(null) + if (executor != null) { + ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES)) } // note: this does not cleanly shut down the sessions, but the server is shutting down. @@ -199,16 +192,16 @@ class SparkConnectSessionManager extends Logging { * * The checks are looking to remove sessions that expired. */ - private def schedulePeriodicChecks(): Unit = sessionsLock.synchronized { - scheduledExecutor match { - case Some(_) => // Already running. - case None => + private def schedulePeriodicChecks(): Unit = { + var executor = scheduledExecutor.getAcquire() + if (executor == null) { + executor = Executors.newSingleThreadScheduledExecutor() + if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) { val interval = SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_MAINTENANCE_INTERVAL) logInfo( log"Starting thread for cleanup of expired sessions every " + log"${MDC(INTERVAL, interval)} ms") - scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor()) - scheduledExecutor.get.scheduleAtFixedRate( + executor.scheduleAtFixedRate( () => { try { val defaultInactiveTimeoutMs = @@ -221,6 +214,7 @@ class SparkConnectSessionManager extends Logging { interval, interval, TimeUnit.MILLISECONDS) + } } } @@ -255,24 +249,18 @@ class SparkConnectSessionManager extends Logging { // .. and remove them. toRemove.foreach { sessionHolder => - // This doesn't use closeSession to be able to do the extra last chance check under lock. - val removedSession = { - // Last chance - check expiration time and remove under lock if expired. - val info = sessionHolder.getSessionHolderInfo - if (shouldExpire(info, System.currentTimeMillis())) { - logInfo( - log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " + - log"and will be closed.") - removeSessionHolder(info.key) - } else { - None + val info = sessionHolder.getSessionHolderInfo + if (shouldExpire(info, System.currentTimeMillis())) { + logInfo( + log"Found session ${MDC(SESSION_HOLD_INFO, info)} that expired " + + log"and will be closed.") + removeSessionHolder(info.key) + try { + shutdownSessionHolder(sessionHolder) + } catch { + case NonFatal(ex) => logWarning("Unexpected exception closing session", ex) } } - // do shutdown and cleanup outside of lock. 
- try removedSession.foreach(shutdownSessionHolder(_)) - catch { - case NonFatal(ex) => logWarning("Unexpected exception closing session", ex) - } } logInfo("Finished periodic run of SparkConnectSessionManager maintenance.") } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala index 03719ddd87419..8241672d5107b 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connect.service import java.util.concurrent.Executors import java.util.concurrent.ScheduledExecutorService import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -185,10 +186,10 @@ private[connect] class SparkConnectStreamingQueryCache( // Visible for testing. private[service] def shutdown(): Unit = queryCacheLock.synchronized { - scheduledExecutor.foreach { executor => + val executor = scheduledExecutor.getAndSet(null) + if (executor != null) { ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES)) } - scheduledExecutor = None } @GuardedBy("queryCacheLock") @@ -199,19 +200,19 @@ private[connect] class SparkConnectStreamingQueryCache( private val taggedQueries = new mutable.HashMap[String, mutable.ArrayBuffer[QueryCacheKey]] private val taggedQueriesLock = new Object - @GuardedBy("queryCacheLock") - private var scheduledExecutor: Option[ScheduledExecutorService] = None + private var scheduledExecutor: AtomicReference[ScheduledExecutorService] = + new AtomicReference[ScheduledExecutorService]() /** Schedules periodic checks if it is not already scheduled */ - private def schedulePeriodicChecks(): Unit = queryCacheLock.synchronized { - scheduledExecutor match { - case Some(_) => // Already running. - case None => + private def schedulePeriodicChecks(): Unit = { + var executor = scheduledExecutor.getAcquire() + if (executor == null) { + executor = Executors.newSingleThreadScheduledExecutor() + if (scheduledExecutor.compareAndExchangeRelease(null, executor) == null) { logInfo( log"Starting thread for polling streaming sessions " + log"every ${MDC(DURATION, sessionPollingPeriod.toMillis)}") - scheduledExecutor = Some(Executors.newSingleThreadScheduledExecutor()) - scheduledExecutor.get.scheduleAtFixedRate( + executor.scheduleAtFixedRate( () => { try periodicMaintenance() catch { @@ -221,6 +222,7 @@ private[connect] class SparkConnectStreamingQueryCache( sessionPollingPeriod.toMillis, sessionPollingPeriod.toMillis, TimeUnit.MILLISECONDS) + } } } From af45902d33c4d8e38a6427ac1d0c46fe057bb45a Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 18 Sep 2024 20:11:21 -0400 Subject: [PATCH 010/250] [SPARK-49422][CONNECT][SQL] Add groupByKey to sql/api ### What changes were proposed in this pull request? This PR adds `Dataset.groupByKey(..)` to the shared interface. I forgot to add in the previous PR. ### Why are the changes needed? The shared interface needs to support all functionality. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. 
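To make the addition concrete, a hedged usage sketch of the two `groupByKey` variants now exposed on the shared interface; the session setup, `local[*]` master, and sample data below are assumptions for a local run, not part of the patch:

```scala
import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}

object GroupByKeyExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("groupByKey-example")
      .getOrCreate()
    import spark.implicits._

    val words: Dataset[String] = Seq("spark", "scala", "spark").toDS()

    // Scala-specific variant: the grouping key is derived by a plain function.
    val scalaCounts = words.groupByKey((word: String) => word).count()

    // Java-friendly variant: a MapFunction plus an explicit Encoder, which the
    // shared interface forwards to the Scala variant.
    val javaCounts = words.groupByKey(
      new MapFunction[String, String] { override def call(value: String): String = value },
      Encoders.STRING).count()

    scalaCounts.show()
    javaCounts.show()
    spark.stop()
  }
}
```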
Closes #48147 from hvanhovell/SPARK-49422-follow-up. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../scala/org/apache/spark/sql/Dataset.scala | 24 ++----- .../org/apache/spark/sql/api/Dataset.scala | 22 ++++++ .../scala/org/apache/spark/sql/Dataset.scala | 68 +++---------------- 3 files changed, 39 insertions(+), 75 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 161a0d9d265f0..accfff9f2b073 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -524,27 +524,11 @@ class Dataset[T] private[sql] ( result(0) } - /** - * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given - * key `func`. - * - * @group typedrel - * @since 3.5.0 - */ + /** @inheritdoc */ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func) } - /** - * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given - * key `func`. - * - * @group typedrel - * @since 3.5.0 - */ - def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = - groupByKey(ToScalaUDF(func))(encoder) - /** @inheritdoc */ @scala.annotation.varargs def rollup(cols: Column*): RelationalGroupedDataset = { @@ -1480,4 +1464,10 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ @scala.annotation.varargs override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) + + /** @inheritdoc */ + override def groupByKey[K]( + func: MapFunction[T, K], + encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = + super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala index 284a69fe6ee3e..7a3d6b0e03877 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala @@ -1422,6 +1422,28 @@ abstract class Dataset[T] extends Serializable { */ def reduce(func: ReduceFunction[T]): T = reduce(ToScalaUDF(func)) + /** + * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given + * key `func`. + * + * @group typedrel + * @since 2.0.0 + */ + def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T, DS] + + /** + * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given + * key `func`. + * + * @group typedrel + * @since 2.0.0 + */ + def groupByKey[K]( + func: MapFunction[T, K], + encoder: Encoder[K]): KeyValueGroupedDataset[K, T, DS] = { + groupByKey(ToScalaUDF(func))(encoder) + } + /** * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns * set. 
This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 61f9e6ff7c042..ef628ca612b49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -62,7 +62,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable} import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions -import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf, ToScalaUDF} +import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf} import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.TypedAggUtils.withInputType import org.apache.spark.sql.streaming.DataStreamWriter @@ -865,24 +865,7 @@ class Dataset[T] private[sql]( Filter(condition.expr, logicalPlan) } - /** - * Groups the Dataset using the specified columns, so we can run aggregation on them. See - * [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns grouped by department. - * ds.groupBy($"department").avg() - * - * // Compute the max age and average salary, grouped by department and gender. - * ds.groupBy($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 2.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def groupBy(cols: Column*): RelationalGroupedDataset = { RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType) @@ -914,13 +897,7 @@ class Dataset[T] private[sql]( rdd.reduce(func) } - /** - * (Scala-specific) - * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. - * - * @group typedrel - * @since 2.0.0 - */ + /** @inheritdoc */ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val withGroupingKey = AppendColumns(func, logicalPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -933,16 +910,6 @@ class Dataset[T] private[sql]( withGroupingKey.newColumns) } - /** - * (Java-specific) - * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. - * - * @group typedrel - * @since 2.0.0 - */ - def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = - groupByKey(ToScalaUDF(func))(encoder) - /** @inheritdoc */ def unpivot( ids: Array[Column], @@ -1640,28 +1607,7 @@ class Dataset[T] private[sql]( new DataFrameWriterV2Impl[T](table, this) } - /** - * Merges a set of updates, insertions, and deletions based on a source table into - * a target table. 
- * - * Scala Examples: - * {{{ - * spark.table("source") - * .mergeInto("target", $"source.id" === $"target.id") - * .whenMatched($"salary" === 100) - * .delete() - * .whenNotMatched() - * .insertAll() - * .whenNotMatchedBySource($"salary" === 100) - * .update(Map( - * "salary" -> lit(200) - * )) - * .merge() - * }}} - * - * @group basic - * @since 4.0.0 - */ + /** @inheritdoc */ def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = { if (isStreaming) { logicalPlan.failAnalysis( @@ -2024,6 +1970,12 @@ class Dataset[T] private[sql]( @scala.annotation.varargs override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) + /** @inheritdoc */ + override def groupByKey[K]( + func: MapFunction[T, K], + encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = + super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] + //////////////////////////////////////////////////////////////////////////// // For Python API //////////////////////////////////////////////////////////////////////////// From 58d73fe8e7cbff9878539d31430f819eff9fc7a1 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 19 Sep 2024 09:16:23 +0900 Subject: [PATCH 011/250] Revert "[SPARK-49495][DOCS][FOLLOWUP] Enable GitHub Pages settings via .asf.yml" This reverts commit b86e5d2ab1fb17f8dcbb5b4d50f3361494270438. --- .asf.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.asf.yaml b/.asf.yaml index 91a5f9b2bb1a2..22042b355b2fa 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -31,8 +31,6 @@ github: merge: false squash: true rebase: true - ghp_branch: master - ghp_path: /docs/_site notifications: pullrequests: reviews@spark.apache.org From 376382711e200aa978008b25630cc54271fd419b Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 19 Sep 2024 09:16:28 +0900 Subject: [PATCH 012/250] Revert "[SPARK-49495][DOCS][FOLLOWUP] Fix Pandoc installation for GitHub Pages publication action" This reverts commit 7de71a2ec78d985c2a045f13c1275101b126cec4. --- .github/workflows/pages.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml index f10dadf315a1b..083620427c015 100644 --- a/.github/workflows/pages.yml +++ b/.github/workflows/pages.yml @@ -63,9 +63,9 @@ jobs: ruby-version: '3.3' bundler-cache: true - name: Install Pandoc - run: | - sudo apt-get update -y - sudo apt-get install pandoc + uses: pandoc/actions/setup@d6abb76f6c8a1a9a5e15a5190c96a02aabffd1ee + with: + version: 3.3 - name: Install dependencies for documentation generation run: | cd docs From 8861f0f9af3f397921ba1204cf4f76f4e20680bb Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 19 Sep 2024 09:16:33 +0900 Subject: [PATCH 013/250] Revert "[SPARK-49495][DOCS] Document and Feature Preview on the master branch via Live GitHub Pages Updates" This reverts commit b1807095bef9c6d98e60bdc2669c8af93bc68ad4. --- .github/workflows/pages.yml | 90 ------------------------------------- 1 file changed, 90 deletions(-) delete mode 100644 .github/workflows/pages.yml diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml deleted file mode 100644 index 083620427c015..0000000000000 --- a/.github/workflows/pages.yml +++ /dev/null @@ -1,90 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -name: GitHub Pages deployment - -on: - push: - branches: - - master - -concurrency: - group: 'docs preview' - cancel-in-progress: true - -jobs: - docs: - name: Build and deploy documentation - runs-on: ubuntu-latest - permissions: - id-token: write - pages: write - env: - SPARK_TESTING: 1 # Reduce some noise in the logs - RELEASE_VERSION: 'In-Progress' - steps: - - name: Checkout Spark repository - uses: actions/checkout@v4 - with: - repository: apache/spark - ref: 'master' - - name: Install Java 17 - uses: actions/setup-java@v4 - with: - distribution: zulu - java-version: 17 - - name: Install Python 3.9 - uses: actions/setup-python@v5 - with: - python-version: '3.9' - architecture: x64 - cache: 'pip' - - name: Install Python dependencies - run: pip install --upgrade -r dev/requirements.txt - - name: Install Ruby for documentation generation - uses: ruby/setup-ruby@v1 - with: - ruby-version: '3.3' - bundler-cache: true - - name: Install Pandoc - uses: pandoc/actions/setup@d6abb76f6c8a1a9a5e15a5190c96a02aabffd1ee - with: - version: 3.3 - - name: Install dependencies for documentation generation - run: | - cd docs - gem install bundler -v 2.4.22 -n /usr/local/bin - bundle install --retry=100 - - name: Run documentation build - run: | - sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml - sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml - sed -i".tmp3" "s/'facetFilters':.*$/'facetFilters': [\"version:$RELEASE_VERSION\"]/g" docs/_config.yml - sed -i".tmp4" 's/__version__: str = .*$/__version__: str = "'"$RELEASE_VERSION"'"/' python/pyspark/version.py - cd docs - SKIP_RDOC=1 bundle exec jekyll build - - name: Setup Pages - uses: actions/configure-pages@v5 - - name: Upload artifact - uses: actions/upload-pages-artifact@v3 - with: - path: 'docs/_site' - - name: Deploy to GitHub Pages - id: deployment - uses: actions/deploy-pages@v4 From f3c8d26eb0c3fd7f77950eb08c70bb2a9ab6493c Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 19 Sep 2024 10:36:03 +0900 Subject: [PATCH 014/250] Revert "[SPARK-49422][CONNECT][SQL] Add groupByKey to sql/api" This reverts commit af45902d33c4d8e38a6427ac1d0c46fe057bb45a. 
--- .../scala/org/apache/spark/sql/Dataset.scala | 24 +++++-- .../org/apache/spark/sql/api/Dataset.scala | 22 ------ .../scala/org/apache/spark/sql/Dataset.scala | 68 ++++++++++++++++--- 3 files changed, 75 insertions(+), 39 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index accfff9f2b073..161a0d9d265f0 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -524,11 +524,27 @@ class Dataset[T] private[sql] ( result(0) } - /** @inheritdoc */ + /** + * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given + * key `func`. + * + * @group typedrel + * @since 3.5.0 + */ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func) } + /** + * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given + * key `func`. + * + * @group typedrel + * @since 3.5.0 + */ + def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = + groupByKey(ToScalaUDF(func))(encoder) + /** @inheritdoc */ @scala.annotation.varargs def rollup(cols: Column*): RelationalGroupedDataset = { @@ -1464,10 +1480,4 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ @scala.annotation.varargs override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) - - /** @inheritdoc */ - override def groupByKey[K]( - func: MapFunction[T, K], - encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = - super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala index 7a3d6b0e03877..284a69fe6ee3e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala @@ -1422,28 +1422,6 @@ abstract class Dataset[T] extends Serializable { */ def reduce(func: ReduceFunction[T]): T = reduce(ToScalaUDF(func)) - /** - * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given - * key `func`. - * - * @group typedrel - * @since 2.0.0 - */ - def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T, DS] - - /** - * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given - * key `func`. - * - * @group typedrel - * @since 2.0.0 - */ - def groupByKey[K]( - func: MapFunction[T, K], - encoder: Encoder[K]): KeyValueGroupedDataset[K, T, DS] = { - groupByKey(ToScalaUDF(func))(encoder) - } - /** * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns * set. 
This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ef628ca612b49..61f9e6ff7c042 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -62,7 +62,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable} import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions -import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf} +import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf, ToScalaUDF} import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.TypedAggUtils.withInputType import org.apache.spark.sql.streaming.DataStreamWriter @@ -865,7 +865,24 @@ class Dataset[T] private[sql]( Filter(condition.expr, logicalPlan) } - /** @inheritdoc */ + /** + * Groups the Dataset using the specified columns, so we can run aggregation on them. See + * [[RelationalGroupedDataset]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * ds.groupBy($"department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * ds.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + * + * @group untypedrel + * @since 2.0.0 + */ @scala.annotation.varargs def groupBy(cols: Column*): RelationalGroupedDataset = { RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType) @@ -897,7 +914,13 @@ class Dataset[T] private[sql]( rdd.reduce(func) } - /** @inheritdoc */ + /** + * (Scala-specific) + * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. + * + * @group typedrel + * @since 2.0.0 + */ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val withGroupingKey = AppendColumns(func, logicalPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -910,6 +933,16 @@ class Dataset[T] private[sql]( withGroupingKey.newColumns) } + /** + * (Java-specific) + * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. + * + * @group typedrel + * @since 2.0.0 + */ + def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = + groupByKey(ToScalaUDF(func))(encoder) + /** @inheritdoc */ def unpivot( ids: Array[Column], @@ -1607,7 +1640,28 @@ class Dataset[T] private[sql]( new DataFrameWriterV2Impl[T](table, this) } - /** @inheritdoc */ + /** + * Merges a set of updates, insertions, and deletions based on a source table into + * a target table. 
+ * + * Scala Examples: + * {{{ + * spark.table("source") + * .mergeInto("target", $"source.id" === $"target.id") + * .whenMatched($"salary" === 100) + * .delete() + * .whenNotMatched() + * .insertAll() + * .whenNotMatchedBySource($"salary" === 100) + * .update(Map( + * "salary" -> lit(200) + * )) + * .merge() + * }}} + * + * @group basic + * @since 4.0.0 + */ def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = { if (isStreaming) { logicalPlan.failAnalysis( @@ -1970,12 +2024,6 @@ class Dataset[T] private[sql]( @scala.annotation.varargs override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) - /** @inheritdoc */ - override def groupByKey[K]( - func: MapFunction[T, K], - encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = - super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] - //////////////////////////////////////////////////////////////////////////// // For Python API //////////////////////////////////////////////////////////////////////////// From 3bdf146bbee58d207afaadc92024d9f6c4b941dd Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 19 Sep 2024 09:27:38 +0200 Subject: [PATCH 015/250] [SPARK-49611][SQL][FOLLOW-UP] Fix wrong results of collations() TVF ### What changes were proposed in this pull request? Fix of accent sensitive and case sensitive column results. ### Why are the changes needed? When initial PR was introduced, ICU collation listing ended up with different order of generating columns so results were wrong. ### Does this PR introduce _any_ user-facing change? No, as spark 4.0 was not released yet. ### How was this patch tested? Existing test in CollationSuite.scala, which was wrong in the first place. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48152 from mihailom-db/tvf-collations-followup. 
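For context, a minimal sketch of the kind of query this touches, assuming an active `SparkSession` named `spark`; the expected values mirror the corrected rows in the test diff below, where a `_CI` collation is accent sensitive but case insensitive and `_AI` is the reverse:

```
// collations() exposes ACCENT_SENSITIVITY and CASE_SENSITIVITY columns.
spark.sql(
  """SELECT NAME, ACCENT_SENSITIVITY, CASE_SENSITIVITY
    |FROM collations()
    |WHERE NAME IN ('UNICODE_AI', 'UNICODE_CI')""".stripMargin).show()
// Expected after this fix:
//   UNICODE_AI -> ACCENT_INSENSITIVE, CASE_SENSITIVE
//   UNICODE_CI -> ACCENT_SENSITIVE,   CASE_INSENSITIVE
```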
Authored-by: Mihailo Milosevic Signed-off-by: Max Gekk --- .../sql/catalyst/util/CollationFactory.java | 4 ++-- .../org/apache/spark/sql/CollationSuite.scala | 24 +++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 4b88e15e8ed72..87558971042e0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -773,8 +773,8 @@ protected CollationMeta buildCollationMeta() { ICULocaleMap.get(locale).getDisplayCountry(), VersionInfo.ICU_VERSION.toString(), COLLATION_PAD_ATTRIBUTE, - caseSensitivity == CaseSensitivity.CS, - accentSensitivity == AccentSensitivity.AS); + accentSensitivity == AccentSensitivity.AS, + caseSensitivity == CaseSensitivity.CS); } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index d5d18b1ab081c..73fd897e91f53 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1661,17 +1661,17 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Row("SYSTEM", "BUILTIN", "UNICODE", "", "", "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "UNICODE_AI", "", "", - "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), - Row("SYSTEM", "BUILTIN", "UNICODE_CI", "", "", "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "UNICODE_CI", "", "", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "UNICODE_CI_AI", "", "", "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "af", "Afrikaans", "", "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "af_AI", "Afrikaans", "", - "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), - Row("SYSTEM", "BUILTIN", "af_CI", "Afrikaans", "", "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "af_CI", "Afrikaans", "", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "af_CI_AI", "Afrikaans", "", "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"))) @@ -1683,9 +1683,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row("SYSTEM", "BUILTIN", "zh_Hant_HKG", "Chinese", "Hong Kong SAR China", "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "zh_Hant_HKG_AI", "Chinese", "Hong Kong SAR China", - "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), - Row("SYSTEM", "BUILTIN", "zh_Hant_HKG_CI", "Chinese", "Hong Kong SAR China", "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "zh_Hant_HKG_CI", "Chinese", "Hong Kong SAR China", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "zh_Hant_HKG_CI_AI", "Chinese", "Hong Kong SAR China", "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"))) @@ -1693,9 +1693,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row("SYSTEM", "BUILTIN", "zh_Hans_SGP", "Chinese", 
"Singapore", "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "zh_Hans_SGP_AI", "Chinese", "Singapore", - "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), - Row("SYSTEM", "BUILTIN", "zh_Hans_SGP_CI", "Chinese", "Singapore", "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "zh_Hans_SGP_CI", "Chinese", "Singapore", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "zh_Hans_SGP_CI_AI", "Chinese", "Singapore", "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"))) @@ -1704,17 +1704,17 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq(Row("SYSTEM", "BUILTIN", "en_USA", "English", "United States", "ACCENT_SENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "en_USA_AI", "English", "United States", - "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), - Row("SYSTEM", "BUILTIN", "en_USA_CI", "English", "United States", "ACCENT_INSENSITIVE", "CASE_SENSITIVE", "NO_PAD", "75.1.0.0"), + Row("SYSTEM", "BUILTIN", "en_USA_CI", "English", "United States", + "ACCENT_SENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"), Row("SYSTEM", "BUILTIN", "en_USA_CI_AI", "English", "United States", "ACCENT_INSENSITIVE", "CASE_INSENSITIVE", "NO_PAD", "75.1.0.0"))) checkAnswer(sql("SELECT NAME, LANGUAGE, ACCENT_SENSITIVITY, CASE_SENSITIVITY " + "FROM collations() WHERE COUNTRY = 'United States'"), Seq(Row("en_USA", "English", "ACCENT_SENSITIVE", "CASE_SENSITIVE"), - Row("en_USA_AI", "English", "ACCENT_SENSITIVE", "CASE_INSENSITIVE"), - Row("en_USA_CI", "English", "ACCENT_INSENSITIVE", "CASE_SENSITIVE"), + Row("en_USA_AI", "English", "ACCENT_INSENSITIVE", "CASE_SENSITIVE"), + Row("en_USA_CI", "English", "ACCENT_SENSITIVE", "CASE_INSENSITIVE"), Row("en_USA_CI_AI", "English", "ACCENT_INSENSITIVE", "CASE_INSENSITIVE"))) checkAnswer(sql("SELECT NAME FROM collations() WHERE ICU_VERSION is null"), From 492d1b14c0d19fa89b9ce9c0e48fc0e4c120b70c Mon Sep 17 00:00:00 2001 From: Anton Okolnychyi Date: Thu, 19 Sep 2024 11:09:40 +0200 Subject: [PATCH 016/250] [SPARK-48782][SQL] Add support for executing procedures in catalogs ### What changes were proposed in this pull request? This PR adds support for executing procedures in catalogs. ### Why are the changes needed? These changes are needed per [discussed and voted](https://lists.apache.org/thread/w586jr53fxwk4pt9m94b413xyjr1v25m) SPIP tracked in [SPARK-44167](https://issues.apache.org/jira/browse/SPARK-44167). ### Does this PR introduce _any_ user-facing change? Yes. This PR adds CALL commands. ### How was this patch tested? This PR comes with tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47943 from aokolnychyi/spark-48782. 
Authored-by: Anton Okolnychyi Signed-off-by: Wenchen Fan --- .../resources/error/error-conditions.json | 6 + docs/sql-ref-ansi-compliance.md | 1 + .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 1 + .../sql/catalyst/parser/SqlBaseParser.g4 | 5 + .../procedures/ProcedureParameter.java | 5 + .../catalog/procedures/UnboundProcedure.java | 6 + .../sql/catalyst/analysis/Analyzer.scala | 65 +- .../catalyst/analysis/AnsiTypeCoercion.scala | 1 + .../sql/catalyst/analysis/CheckAnalysis.scala | 8 + .../sql/catalyst/analysis/TypeCoercion.scala | 16 + .../spark/sql/catalyst/analysis/package.scala | 6 +- .../catalyst/analysis/v2ResolutionPlans.scala | 17 +- .../sql/catalyst/parser/AstBuilder.scala | 22 + .../logical/ExecutableDuringAnalysis.scala | 28 + .../plans/logical/FunctionBuilderBase.scala | 36 +- .../catalyst/plans/logical/MultiResult.scala | 30 + .../catalyst/plans/logical/v2Commands.scala | 67 +- .../sql/catalyst/rules/RuleIdCollection.scala | 1 + .../sql/catalyst/trees/TreePatterns.scala | 1 + .../catalog/CatalogV2Implicits.scala | 7 + .../sql/errors/QueryCompilationErrors.scala | 7 + .../connector/catalog/InMemoryCatalog.scala | 19 +- .../catalyst/analysis/InvokeProcedures.scala | 71 ++ .../spark/sql/execution/MultiResultExec.scala | 36 + .../spark/sql/execution/SparkStrategies.scala | 2 + .../sql/execution/command/commands.scala | 11 +- .../datasources/v2/DataSourceV2Strategy.scala | 6 +- .../datasources/v2/ExplainOnlySparkPlan.scala | 38 + .../internal/BaseSessionStateBuilder.scala | 3 +- .../sql-tests/results/ansi/keywords.sql.out | 2 + .../sql-tests/results/keywords.sql.out | 1 + .../spark/sql/connector/ProcedureSuite.scala | 654 ++++++++++++++++++ .../ThriftServerWithSparkContextSuite.scala | 2 +- .../sql/hive/HiveSessionStateBuilder.scala | 3 +- 34 files changed, 1162 insertions(+), 22 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MultiResult.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 6463cc2c12da7..72985de6631f0 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1456,6 +1456,12 @@ ], "sqlState" : "2203G" }, + "FAILED_TO_LOAD_ROUTINE" : { + "message" : [ + "Failed to load routine ." + ], + "sqlState" : "38000" + }, "FAILED_TO_PARSE_TOO_COMPLEX" : { "message" : [ "The statement, including potential SQL functions and referenced views, was too complex to parse.", diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index fff6906457f7d..12dff1e325c49 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -426,6 +426,7 @@ Below is a list of all the keywords in Spark SQL. 
|BY|non-reserved|non-reserved|reserved| |BYTE|non-reserved|non-reserved|non-reserved| |CACHE|non-reserved|non-reserved|non-reserved| +|CALL|reserved|non-reserved|reserved| |CALLED|non-reserved|non-reserved|non-reserved| |CASCADE|non-reserved|non-reserved|non-reserved| |CASE|reserved|non-reserved|reserved| diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index e704f9f58b964..de28041acd41f 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -146,6 +146,7 @@ BUCKETS: 'BUCKETS'; BY: 'BY'; BYTE: 'BYTE'; CACHE: 'CACHE'; +CALL: 'CALL'; CALLED: 'CALLED'; CASCADE: 'CASCADE'; CASE: 'CASE'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index f13dde773496a..e591a43b84d1a 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -298,6 +298,10 @@ statement LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN (OPTIONS options=propertyList)? #createIndex | DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference #dropIndex + | CALL identifierReference + LEFT_PAREN + (functionArgument (COMMA functionArgument)*)? + RIGHT_PAREN #call | unsupportedHiveNativeCommands .*? #failNativeCommand ; @@ -1851,6 +1855,7 @@ nonReserved | BY | BYTE | CACHE + | CALL | CALLED | CASCADE | CASE diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java index 90d531ae21892..18c76833c5879 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java @@ -32,6 +32,11 @@ */ @Evolving public interface ProcedureParameter { + /** + * A field metadata key that indicates whether an argument is passed by name. + */ + String BY_NAME_METADATA_KEY = "BY_NAME"; + /** * Creates a builder for an IN procedure parameter. * diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java index ee9a09055243b..1a91fd21bf07e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java @@ -35,6 +35,12 @@ public interface UnboundProcedure extends Procedure { * validate if the input types are compatible while binding or delegate that to Spark. Regardless, * Spark will always perform the final validation of the arguments and rearrange them as needed * based on {@link BoundProcedure#parameters() reported parameters}. + *
* <p>
    + * The provided {@code inputType} is based on the procedure arguments. If an argument is passed + * by name, its metadata will indicate this with {@link ProcedureParameter#BY_NAME_METADATA_KEY} + * set to {@code true}. In such cases, the field name will match the name of the target procedure + * parameter. If the argument is not named, {@link ProcedureParameter#BY_NAME_METADATA_KEY} will + * not be set and the name will be assigned randomly. * * @param inputType the input types to bind to * @return the bound procedure that is most suitable for the given input types diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0164af945ca28..9e5b1d1254c87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import scala.util.{Failure, Random, Success, Try} -import org.apache.spark.{SparkException, SparkUnsupportedOperationException} +import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ @@ -50,6 +50,7 @@ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition} import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction} +import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure} import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -310,6 +311,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ExtractGenerator :: ResolveGenerate :: ResolveFunctions :: + ResolveProcedures :: + BindProcedures :: ResolveTableSpec :: ResolveAliases :: ResolveSubquery :: @@ -2611,6 +2614,66 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } + /** + * A rule that resolves procedures. + */ + object ResolveProcedures extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( + _.containsPattern(UNRESOLVED_PROCEDURE), ruleId) { + case Call(UnresolvedProcedure(CatalogAndIdentifier(catalog, ident)), args, execute) => + val procedureCatalog = catalog.asProcedureCatalog + val procedure = load(procedureCatalog, ident) + Call(ResolvedProcedure(procedureCatalog, ident, procedure), args, execute) + } + + private def load(catalog: ProcedureCatalog, ident: Identifier): UnboundProcedure = { + try { + catalog.loadProcedure(ident) + } catch { + case e: Exception if !e.isInstanceOf[SparkThrowable] => + val nameParts = catalog.name +: ident.asMultipartIdentifier + throw QueryCompilationErrors.failedToLoadRoutineError(nameParts, e) + } + } + } + + /** + * A rule that binds procedures to the input types and rearranges arguments as needed. 
+ */ + object BindProcedures extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case Call(ResolvedProcedure(catalog, ident, unbound: UnboundProcedure), args, execute) + if args.forall(_.resolved) => + val inputType = extractInputType(args) + val bound = unbound.bind(inputType) + validateParameterModes(bound) + val rearrangedArgs = NamedParametersSupport.defaultRearrange(bound, args) + Call(ResolvedProcedure(catalog, ident, bound), rearrangedArgs, execute) + } + + private def extractInputType(args: Seq[Expression]): StructType = { + val fields = args.zipWithIndex.map { + case (NamedArgumentExpression(name, value), _) => + StructField(name, value.dataType, value.nullable, byNameMetadata) + case (arg, index) => + StructField(s"param$index", arg.dataType, arg.nullable) + } + StructType(fields) + } + + private def byNameMetadata: Metadata = { + new MetadataBuilder() + .putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, value = true) + .build() + } + + private def validateParameterModes(procedure: BoundProcedure): Unit = { + procedure.parameters.find(_.mode != ProcedureParameter.Mode.IN).foreach { param => + throw SparkException.internalError(s"Unsupported parameter mode: ${param.mode}") + } + } + } + /** * This rule resolves and rewrites subqueries inside expressions. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala index 17b1c4e249f57..3afe0ec8e9a7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala @@ -77,6 +77,7 @@ object AnsiTypeCoercion extends TypeCoercionBase { override def typeCoercionRules: List[Rule[LogicalPlan]] = UnpivotCoercion :: WidenSetOperationTypes :: + ProcedureArgumentCoercion :: new AnsiCombinedTypeCoercionRule( CollationTypeCasts :: InConversion :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 752ff49e1f90d..5a9d5cd87ecc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -676,6 +676,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB varName, c.defaultExpr.originalSQL) + case c: Call if c.resolved && c.bound && c.checkArgTypes().isFailure => + c.checkArgTypes() match { + case mismatch: TypeCheckResult.DataTypeMismatch => + c.dataTypeMismatch("CALL", mismatch) + case _ => + throw SparkException.internalError("Invalid input for procedure") + } + case _ => // Falls back to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 08c5b3531b4c8..5983346ff1e27 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.AlwaysProcess import 
org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation} @@ -202,6 +203,20 @@ abstract class TypeCoercionBase { } } + /** + * A type coercion rule that implicitly casts procedure arguments to expected types. + */ + object ProcedureArgumentCoercion extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case c @ Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args, _) if c.resolved => + val expectedDataTypes = procedure.parameters.map(_.dataType) + val coercedArgs = args.zip(expectedDataTypes).map { + case (arg, expectedType) => implicitCast(arg, expectedType).getOrElse(arg) + } + c.copy(args = coercedArgs) + } + } + /** * Widens the data types of the [[Unpivot]] values. */ @@ -838,6 +853,7 @@ object TypeCoercion extends TypeCoercionBase { override def typeCoercionRules: List[Rule[LogicalPlan]] = UnpivotCoercion :: WidenSetOperationTypes :: + ProcedureArgumentCoercion :: new CombinedTypeCoercionRule( CollationTypeCasts :: InConversion :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala index c0689eb121679..daab9e4d78bf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala @@ -67,9 +67,13 @@ package object analysis { } def dataTypeMismatch(expr: Expression, mismatch: DataTypeMismatch): Nothing = { + dataTypeMismatch(toSQLExpr(expr), mismatch) + } + + def dataTypeMismatch(sqlExpr: String, mismatch: DataTypeMismatch): Nothing = { throw new AnalysisException( errorClass = s"DATATYPE_MISMATCH.${mismatch.errorSubClass}", - messageParameters = mismatch.messageParameters + ("sqlExpr" -> toSQLExpr(expr)), + messageParameters = mismatch.messageParameters + ("sqlExpr" -> sqlExpr), origin = t.origin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala index ecdf40e87a894..dee78b8f03af4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala @@ -23,13 +23,14 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, LeafExpression, Unevaluable} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC} +import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC, UNRESOLVED_PROCEDURE} import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, Table, TableCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, ProcedureCatalog, Table, TableCatalog} 
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.catalog.functions.UnboundFunction +import org.apache.spark.sql.connector.catalog.procedures.Procedure import org.apache.spark.sql.types.{DataType, StructField} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ArrayImplicits._ @@ -135,6 +136,12 @@ case class UnresolvedFunctionName( case class UnresolvedIdentifier(nameParts: Seq[String], allowTemp: Boolean = false) extends UnresolvedLeafNode +/** + * A procedure identifier that should be resolved into [[ResolvedProcedure]]. + */ +case class UnresolvedProcedure(nameParts: Seq[String]) extends UnresolvedLeafNode { + final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_PROCEDURE) +} /** * A resolved leaf node whose statistics has no meaning. @@ -192,6 +199,12 @@ case class ResolvedFieldName(path: Seq[String], field: StructField) extends Fiel case class ResolvedFieldPosition(position: ColumnPosition) extends FieldPosition +case class ResolvedProcedure( + catalog: ProcedureCatalog, + ident: Identifier, + procedure: Procedure) extends LeafNodeWithoutStats { + override def output: Seq[Attribute] = Nil +} /** * A plan containing resolved persistent views. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index cb0e0e35c3704..52529bb4b789b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5697,6 +5697,28 @@ class AstBuilder extends DataTypeAstBuilder ctx.EXISTS != null) } + /** + * Creates a plan for invoking a procedure. + * + * For example: + * {{{ + * CALL multi_part_name(v1, v2, ...); + * CALL multi_part_name(v1, param2 => v2, ...); + * CALL multi_part_name(param1 => v1, param2 => v2, ...); + * }}} + */ + override def visitCall(ctx: CallContext): LogicalPlan = withOrigin(ctx) { + val procedure = withIdentClause(ctx.identifierReference, UnresolvedProcedure) + val args = ctx.functionArgument.asScala.map { + case expr if expr.namedArgumentExpression != null => + val namedExpr = expr.namedArgumentExpression + NamedArgumentExpression(namedExpr.key.getText, expression(namedExpr.value)) + case expr => + expression(expr) + }.toSeq + Call(procedure, args) + } + /** * Create a TimestampAdd expression. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala new file mode 100644 index 0000000000000..dc8dbf701f6a9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +/** + * A logical plan node that requires execution during analysis. + */ +trait ExecutableDuringAnalysis extends LogicalPlan { + /** + * Returns the logical plan node that should be used for EXPLAIN. + */ + def stageForExplain(): LogicalPlan +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala index 4701f4ea1e172..75b2fcd3a5f34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.{Expression, NamedArgumentExpression} +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns +import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.util.ArrayImplicits._ @@ -122,12 +124,32 @@ object NamedParametersSupport { functionSignature: FunctionSignature, args: Seq[Expression], functionName: String): Seq[Expression] = { - val parameters: Seq[InputParameter] = functionSignature.parameters + defaultRearrange(functionName, functionSignature.parameters, args) + } + + final def defaultRearrange(procedure: BoundProcedure, args: Seq[Expression]): Seq[Expression] = { + defaultRearrange( + procedure.name, + procedure.parameters.map(toInputParameter).toSeq, + args) + } + + private def toInputParameter(param: ProcedureParameter): InputParameter = { + val defaultValue = Option(param.defaultValueExpression).map { expr => + ResolveDefaultColumns.analyze(param.name, param.dataType, expr, "CALL") + } + InputParameter(param.name, defaultValue) + } + + private def defaultRearrange( + routineName: String, + parameters: Seq[InputParameter], + args: Seq[Expression]): Seq[Expression] = { if (parameters.dropWhile(_.default.isEmpty).exists(_.default.isEmpty)) { - throw QueryCompilationErrors.unexpectedRequiredParameter(functionName, parameters) + throw QueryCompilationErrors.unexpectedRequiredParameter(routineName, parameters) } - val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, functionName) + val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, routineName) val namedParameters: Seq[InputParameter] = parameters.drop(positionalArgs.size) // The following loop checks for the following: @@ -140,12 +162,12 @@ object NamedParametersSupport { namedArgs.foreach { namedArg => val parameterName = namedArg.key if (!parameterNamesSet.contains(parameterName)) { - throw QueryCompilationErrors.unrecognizedParameterName(functionName, namedArg.key, + throw QueryCompilationErrors.unrecognizedParameterName(routineName, namedArg.key, parameterNamesSet.toSeq) } if (positionalParametersSet.contains(parameterName)) { throw 
QueryCompilationErrors.positionalAndNamedArgumentDoubleReference( - functionName, namedArg.key) + routineName, namedArg.key) } } @@ -154,7 +176,7 @@ object NamedParametersSupport { val validParameterSizes = Array.range(parameters.count(_.default.isEmpty), parameters.size + 1).toImmutableArraySeq throw QueryCompilationErrors.wrongNumArgsError( - functionName, validParameterSizes, args.length) + routineName, validParameterSizes, args.length) } // This constructs a map from argument name to value for argument rearrangement. @@ -168,7 +190,7 @@ object NamedParametersSupport { namedArgMap.getOrElse( param.name, if (param.default.isEmpty) { - throw QueryCompilationErrors.requiredParameterNotFound(functionName, param.name, index) + throw QueryCompilationErrors.requiredParameterNotFound(routineName, param.name, index) } else { param.default.get } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MultiResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MultiResult.scala new file mode 100644 index 0000000000000..f249e5c87eba2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MultiResult.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class MultiResult(children: Seq[LogicalPlan]) extends LogicalPlan { + + override def output: Seq[Attribute] = children.lastOption.map(_.output).getOrElse(Nil) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[LogicalPlan]): MultiResult = { + copy(children = newChildren) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index fdd43404e1d98..b465e0e11612f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -19,17 +19,22 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, UnresolvedException, ViewSchemaMode} +import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedException, UnresolvedProcedure, ViewSchemaMode} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.FunctionResource import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, NamedExpression, UnaryExpression, Unevaluable, V2ExpressionUtils} import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema import org.apache.spark.sql.catalyst.trees.BinaryLike -import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, RowDeltaUtils, WriteDeltaProjections} +import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils, RowDeltaUtils, WriteDeltaProjections} +import org.apache.spark.sql.catalyst.util.TypeUtils.{ordinalNumber, toSQLExpr} import org.apache.spark.sql.connector.catalog._ +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper} +import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.write.{DeltaWrite, RowLevelOperation, RowLevelOperationTable, SupportsDelta, Write} +import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType} import org.apache.spark.util.ArrayImplicits._ @@ -1571,3 +1576,61 @@ case class SetVariable( override protected def withNewChildInternal(newChild: LogicalPlan): SetVariable = copy(sourceQuery = newChild) } + +/** + * The logical plan of the CALL statement. 
+ */ +case class Call( + procedure: LogicalPlan, + args: Seq[Expression], + execute: Boolean = true) + extends UnaryNode with ExecutableDuringAnalysis { + + override def output: Seq[Attribute] = Nil + + override def child: LogicalPlan = procedure + + def bound: Boolean = procedure match { + case ResolvedProcedure(_, _, _: BoundProcedure) => true + case _ => false + } + + def checkArgTypes(): TypeCheckResult = { + require(resolved && bound, "can check arg types only after resolution and binding") + + val params = procedure match { + case ResolvedProcedure(_, _, bound: BoundProcedure) => bound.parameters + } + require(args.length == params.length, "number of args and params must match after binding") + + args.zip(params).zipWithIndex.collectFirst { + case ((arg, param), idx) + if !DataType.equalsIgnoreCompatibleNullability(arg.dataType, param.dataType) => + DataTypeMismatch( + errorSubClass = "UNEXPECTED_INPUT_TYPE", + messageParameters = Map( + "paramIndex" -> ordinalNumber(idx), + "requiredType" -> toSQLType(param.dataType), + "inputSql" -> toSQLExpr(arg), + "inputType" -> toSQLType(arg.dataType))) + }.getOrElse(TypeCheckSuccess) + } + + override def simpleString(maxFields: Int): String = { + val name = procedure match { + case ResolvedProcedure(catalog, ident, _) => + s"${quoteIfNeeded(catalog.name)}.${ident.quoted}" + case UnresolvedProcedure(nameParts) => + nameParts.quoted + } + val argsString = truncatedString(args, ", ", maxFields) + s"Call $name($argsString)" + } + + override def stageForExplain(): Call = { + copy(execute = false) + } + + override protected def withNewChildInternal(newChild: LogicalPlan): Call = + copy(procedure = newChild) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index c70b43f0db173..b5556cbae7cd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -54,6 +54,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveDeserializer" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions" :: + "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveProcedures" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGenerate" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics" :: "org.apache.spark.sql.catalyst.analysis.ResolveHigherOrderFunctions" :: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 826ac52c2b817..0f1c98b53e0b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -157,6 +157,7 @@ object TreePattern extends Enumeration { // Unresolved Plan patterns (Alphabetically ordered) val UNRESOLVED_FUNC: Value = Value + val UNRESOLVED_PROCEDURE: Value = Value val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value val UNRESOLVED_TRANSPOSE: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 65bdae85be12a..282350dda67d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -126,6 +126,13 @@ private[sql] object CatalogV2Implicits { case _ => throw QueryCompilationErrors.missingCatalogAbilityError(plugin, "functions") } + + def asProcedureCatalog: ProcedureCatalog = plugin match { + case procedureCatalog: ProcedureCatalog => + procedureCatalog + case _ => + throw QueryCompilationErrors.missingCatalogAbilityError(plugin, "procedures") + } } implicit class NamespaceHelper(namespace: Array[String]) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index ad0e1d07bf93d..0b5255e95f073 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -853,6 +853,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat origin = origin) } + def failedToLoadRoutineError(nameParts: Seq[String], e: Exception): Throwable = { + new AnalysisException( + errorClass = "FAILED_TO_LOAD_ROUTINE", + messageParameters = Map("routineName" -> toSQLId(nameParts)), + cause = Some(e)) + } + def unresolvedRoutineError(name: FunctionIdentifier, searchPath: Seq[String]): Throwable = { new AnalysisException( errorClass = "UNRESOLVED_ROUTINE", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala index 8d8d2317f0986..411a88b8765f6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala @@ -24,10 +24,13 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, NoSuchNamespaceException} import org.apache.spark.sql.connector.catalog.functions.UnboundFunction +import org.apache.spark.sql.connector.catalog.procedures.UnboundProcedure -class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog { +class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog with ProcedureCatalog { protected val functions: util.Map[Identifier, UnboundFunction] = new ConcurrentHashMap[Identifier, UnboundFunction]() + protected val procedures: util.Map[Identifier, UnboundProcedure] = + new ConcurrentHashMap[Identifier, UnboundProcedure]() override protected def allNamespaces: Seq[Seq[String]] = { (tables.keySet.asScala.map(_.namespace.toSeq) ++ @@ -63,4 +66,18 @@ class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog { def clearFunctions(): Unit = { functions.clear() } + + override def loadProcedure(ident: Identifier): UnboundProcedure = { + val procedure = procedures.get(ident) + if (procedure == null) throw new RuntimeException("Procedure not found: " + ident) + procedure + } + + def createProcedure(ident: Identifier, procedure: UnboundProcedure): UnboundProcedure = { + procedures.put(ident, procedure) + } + + def clearProcedures(): Unit = { + procedures.clear() + } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala new file mode 100644 index 0000000000000..c7320d350a7ff --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import scala.jdk.CollectionConverters.IteratorHasAsScala + +import org.apache.spark.SparkException +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow} +import org.apache.spark.sql.catalyst.plans.logical.{Call, LocalRelation, LogicalPlan, MultiResult} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure +import org.apache.spark.sql.connector.read.{LocalScan, Scan} +import org.apache.spark.util.ArrayImplicits._ + +class InvokeProcedures(session: SparkSession) extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case c: Call if c.resolved && c.bound && c.execute && c.checkArgTypes().isSuccess => + session.sessionState.optimizer.execute(c) match { + case Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args, _) => + invoke(procedure, args) + case _ => + throw SparkException.internalError("Unexpected plan for optimized CALL statement") + } + } + + private def invoke(procedure: BoundProcedure, args: Seq[Expression]): LogicalPlan = { + val input = toInternalRow(args) + val scanIterator = procedure.call(input) + val relations = scanIterator.asScala.map(toRelation).toSeq + relations match { + case Nil => LocalRelation(Nil) + case Seq(relation) => relation + case _ => MultiResult(relations) + } + } + + private def toRelation(scan: Scan): LogicalPlan = scan match { + case s: LocalScan => + val attrs = DataTypeUtils.toAttributes(s.readSchema) + val data = s.rows.toImmutableArraySeq + LocalRelation(attrs, data) + case _ => + throw SparkException.internalError( + s"Only local scans are temporarily supported as procedure output: ${scan.getClass.getName}") + } + + private def toInternalRow(args: Seq[Expression]): InternalRow = { + require(args.forall(_.foldable), "args must be foldable") + val values = args.map(_.eval()).toArray + new GenericInternalRow(values) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala new file mode 100644 index 0000000000000..c2b12b053c927 --- 
/dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class MultiResultExec(children: Seq[SparkPlan]) extends SparkPlan { + + override def output: Seq[Attribute] = children.lastOption.map(_.output).getOrElse(Nil) + + override protected def doExecute(): RDD[InternalRow] = { + children.lastOption.map(_.execute()).getOrElse(sparkContext.emptyRDD) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[SparkPlan]): MultiResultExec = { + copy(children = newChildren) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6d940a30619fb..aee735e48fc5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -1041,6 +1041,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case WriteFiles(child, fileFormat, partitionColumns, bucket, options, staticPartitions) => WriteFilesExec(planLater(child), fileFormat, partitionColumns, bucket, options, staticPartitions) :: Nil + case MultiResult(children) => + MultiResultExec(children.map(planLater)) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index ea2736b2c1266..ea9d53190546e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, SupervisingCommand} +import org.apache.spark.sql.catalyst.plans.logical.{Command, ExecutableDuringAnalysis, LogicalPlan, SupervisingCommand} import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike} import org.apache.spark.sql.connector.ExternalCommandRunner import org.apache.spark.sql.execution.{CommandExecutionMode, ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode} @@ -165,14 +165,19 @@ case class ExplainCommand( // Run through the optimizer to generate the physical plan. 
override def run(sparkSession: SparkSession): Seq[Row] = try { - val outputString = sparkSession.sessionState.executePlan(logicalPlan, CommandExecutionMode.SKIP) - .explainString(mode) + val stagedLogicalPlan = stageForAnalysis(logicalPlan) + val qe = sparkSession.sessionState.executePlan(stagedLogicalPlan, CommandExecutionMode.SKIP) + val outputString = qe.explainString(mode) Seq(Row(outputString)) } catch { case NonFatal(cause) => ("Error occurred during query planning: \n" + cause.getMessage).split("\n") .map(Row(_)).toImmutableArraySeq } + private def stageForAnalysis(plan: LogicalPlan): LogicalPlan = plan transform { + case p: ExecutableDuringAnalysis => p.stageForExplain() + } + def withTransformedSupervisedPlan(transformer: LogicalPlan => LogicalPlan): LogicalPlan = copy(logicalPlan = transformer(logicalPlan)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index d7f46c32f99a0..76cd33b815edd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -32,8 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, - IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder} import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDeleteV2, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable} import org.apache.spark.sql.connector.catalog.index.SupportsIndex import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} @@ -554,6 +553,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat systemScope, pattern) :: Nil + case c: Call => + ExplainOnlySparkPlan(c) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala new file mode 100644 index 0000000000000..bbf56eaa71184 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.LeafLike +import org.apache.spark.sql.execution.SparkPlan + +case class ExplainOnlySparkPlan(toExplain: LogicalPlan) extends SparkPlan with LeafLike[SparkPlan] { + + override def output: Seq[Attribute] = Nil + + override def simpleString(maxFields: Int): String = { + toExplain.simpleString(maxFields) + } + + override protected def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index a2539828733fc..0d0258f11efb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.internal import org.apache.spark.annotation.Unstable import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} import org.apache.spark.sql.artifact.ArtifactManager -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -206,6 +206,7 @@ abstract class BaseSessionStateBuilder( ResolveWriteToStream +: new EvalSubqueriesForTimeTravel +: new ResolveTranspose(session) +: + new InvokeProcedures(session) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out index 6497a46c68ccd..7c694503056ab 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out @@ -32,6 +32,7 @@ BUCKETS false BY false BYTE false CACHE false +CALL true CALLED false CASCADE false CASE true @@ -378,6 +379,7 @@ ANY AS AUTHORIZATION BOTH +CALL CASE CAST CHECK diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index 0dfd62599afa6..2c16d961b1313 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -32,6 +32,7 @@ BUCKETS false BY false BYTE false CACHE false +CALL false CALLED false CASCADE false CASE false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala new file mode 100644 index 0000000000000..e39a1b7ea340a --- 
/dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala @@ -0,0 +1,654 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import java.util.Collections + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkNumberFormatException} +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId +import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog} +import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure} +import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter.Mode +import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter.Mode.{IN, INOUT, OUT} +import org.apache.spark.sql.connector.read.{LocalScan, Scan} +import org.apache.spark.sql.errors.DataTypeErrors.{toSQLType, toSQLValue} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String + +class ProcedureSuite extends QueryTest with SharedSparkSession with BeforeAndAfter { + + before { + spark.conf.set(s"spark.sql.catalog.cat", classOf[InMemoryCatalog].getName) + } + + after { + spark.sessionState.catalogManager.reset() + spark.sessionState.conf.unsetConf(s"spark.sql.catalog.cat") + } + + private def catalog: InMemoryCatalog = { + val catalog = spark.sessionState.catalogManager.catalog("cat") + catalog.asInstanceOf[InMemoryCatalog] + } + + test("position arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(5, 5)"), Row(10) :: Nil) + } + + test("named arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(in2 => 3, in1 => 5)"), Row(8) :: Nil) + } + + test("position and named arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(3, in2 => 1)"), Row(4) :: Nil) + } + + test("foldable expressions") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer(sql("CALL cat.ns.sum(1 + 1, in2 => 2)"), Row(4) :: Nil) + checkAnswer(sql("CALL cat.ns.sum(in2 => 1, in1 => 2 + 1)"), Row(4) :: Nil) + checkAnswer(sql("CALL cat.ns.sum((1 + 1) * 2, in2 => (2 + 1) / 3)"), Row(5) :: Nil) + } + + test("type coercion") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundLongSum) + checkAnswer(sql("CALL cat.ns.sum(1, 2)"), 
Row(3) :: Nil) + checkAnswer(sql("CALL cat.ns.sum(1L, 2)"), Row(3) :: Nil) + checkAnswer(sql("CALL cat.ns.sum(1, 2L)"), Row(3) :: Nil) + } + + test("multiple output rows") { + catalog.createProcedure(Identifier.of(Array("ns"), "complex"), UnboundComplexProcedure) + checkAnswer( + sql("CALL cat.ns.complex('X', 'Y', 3)"), + Row(1, "X1", "Y1") :: Row(2, "X2", "Y2") :: Row(3, "X3", "Y3") :: Nil) + } + + test("parameters with default values") { + catalog.createProcedure(Identifier.of(Array("ns"), "complex"), UnboundComplexProcedure) + checkAnswer(sql("CALL cat.ns.complex()"), Row(1, "A1", "B1") :: Nil) + checkAnswer(sql("CALL cat.ns.complex('X', 'Y')"), Row(1, "X1", "Y1") :: Nil) + } + + test("parameters with invalid default values") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundInvalidDefaultProcedure) + checkError( + exception = intercept[AnalysisException]( + sql("CALL cat.ns.sum()") + ), + condition = "INVALID_DEFAULT_VALUE.DATA_TYPE", + parameters = Map( + "statement" -> "CALL", + "colName" -> toSQLId("in2"), + "defaultValue" -> toSQLValue("B"), + "expectedType" -> toSQLType("INT"), + "actualType" -> toSQLType("STRING"))) + } + + test("IDENTIFIER") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer( + spark.sql("CALL IDENTIFIER(:p1)(1, 2)", Map("p1" -> "cat.ns.sum")), + Row(3) :: Nil) + } + + test("parameterized statements") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkAnswer( + spark.sql("CALL cat.ns.sum(?, ?)", Array(2, 3)), + Row(5) :: Nil) + } + + test("undefined procedure") { + checkError( + exception = intercept[AnalysisException]( + sql("CALL cat.non_exist(1, 2)") + ), + sqlState = Some("38000"), + condition = "FAILED_TO_LOAD_ROUTINE", + parameters = Map("routineName" -> "`cat`.`non_exist`") + ) + } + + test("non-procedure catalog") { + withSQLConf("spark.sql.catalog.testcat" -> classOf[BasicInMemoryTableCatalog].getName) { + checkError( + exception = intercept[AnalysisException]( + sql("CALL testcat.procedure(1, 2)") + ), + condition = "_LEGACY_ERROR_TEMP_1184", + parameters = Map("plugin" -> "testcat", "ability" -> "procedures") + ) + } + } + + test("too many arguments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException]( + sql("CALL cat.ns.sum(1, 2, 3)") + ), + condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + parameters = Map( + "functionName" -> toSQLId("sum"), + "expectedNum" -> "2", + "actualNum" -> "3", + "docroot" -> SPARK_DOC_ROOT)) + } + + test("custom default catalog") { + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "cat") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val df = sql("CALL ns.sum(1, 2)") + checkAnswer(df, Row(3) :: Nil) + } + } + + test("custom default catalog and namespace") { + withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "cat") { + catalog.createNamespace(Array("ns"), Collections.emptyMap) + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + sql("USE ns") + val df = sql("CALL sum(1, 2)") + checkAnswer(df, Row(3) :: Nil) + } + } + + test("required parameter not found") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum()") + }, + condition = "REQUIRED_PARAMETER_NOT_FOUND", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"), + "index" -> "0")) + } + + test("conflicting position and 
named parameter assignments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(1, in1 => 2)") + }, + condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"))) + } + + test("duplicate named parameter assignments") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(in1 => 1, in1 => 2)") + }, + condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"))) + } + + test("unknown parameter name") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(in1 => 1, in5 => 2)") + }, + condition = "UNRECOGNIZED_PARAMETER_NAME", + parameters = Map( + "routineName" -> toSQLId("sum"), + "argumentName" -> toSQLId("in5"), + "proposal" -> (toSQLId("in1") + " " + toSQLId("in2")))) + } + + test("position parameter after named parameter") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + checkError( + exception = intercept[AnalysisException] { + sql("CALL cat.ns.sum(in1 => 1, 2)") + }, + condition = "UNEXPECTED_POSITIONAL_ARGUMENT", + parameters = Map( + "routineName" -> toSQLId("sum"), + "parameterName" -> toSQLId("in1"))) + } + + test("invalid argument type") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val call = "CALL cat.ns.sum(1, TIMESTAMP '2016-11-15 20:54:00.000')" + checkError( + exception = intercept[AnalysisException] { + sql(call) + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "CALL", + "paramIndex" -> "second", + "inputSql" -> "\"TIMESTAMP '2016-11-15 20:54:00'\"", + "inputType" -> toSQLType("TIMESTAMP"), + "requiredType" -> toSQLType("INT")), + context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1)) + } + + test("malformed input to implicit cast") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val call = "CALL cat.ns.sum('A', 2)" + checkError( + exception = intercept[SparkNumberFormatException]( + sql(call) + ), + condition = "CAST_INVALID_INPUT", + parameters = Map( + "expression" -> toSQLValue("A"), + "sourceType" -> toSQLType("STRING"), + "targetType" -> toSQLType("INT")), + context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1)) + } + + test("required parameters after optional") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundInvalidSum) + val e = intercept[SparkException] { + sql("CALL cat.ns.sum(in2 => 1)") + } + assert(e.getMessage.contains("required arguments should come before optional arguments")) + } + + test("INOUT parameters are not supported") { + catalog.createProcedure(Identifier.of(Array("ns"), "procedure"), UnboundInoutProcedure) + val e = intercept[SparkException] { + sql("CALL cat.ns.procedure(1)") + } + assert(e.getMessage.contains(" Unsupported parameter mode: INOUT")) + } + + test("OUT parameters are not supported") { + catalog.createProcedure(Identifier.of(Array("ns"), "procedure"), UnboundOutProcedure) + val e = intercept[SparkException] { + sql("CALL cat.ns.procedure(1)") + } + assert(e.getMessage.contains("Unsupported parameter 
mode: OUT")) + } + + test("EXPLAIN") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundNonExecutableSum) + val explain1 = sql("EXPLAIN CALL cat.ns.sum(5, 5)").head().get(0) + assert(explain1.toString.contains("cat.ns.sum(5, 5)")) + val explain2 = sql("EXPLAIN EXTENDED CALL cat.ns.sum(10, 10)").head().get(0) + assert(explain2.toString.contains("cat.ns.sum(10, 10)")) + } + + test("void procedure") { + catalog.createProcedure(Identifier.of(Array("ns"), "proc"), UnboundVoidProcedure) + checkAnswer(sql("CALL cat.ns.proc('A', 'B')"), Nil) + } + + test("multi-result procedure") { + catalog.createProcedure(Identifier.of(Array("ns"), "proc"), UnboundMultiResultProcedure) + checkAnswer(sql("CALL cat.ns.proc()"), Row("last") :: Nil) + } + + test("invalid input to struct procedure") { + catalog.createProcedure(Identifier.of(Array("ns"), "proc"), UnboundStructProcedure) + val actualType = + StructType(Seq( + StructField("X", DataTypes.DateType, nullable = false), + StructField("Y", DataTypes.IntegerType, nullable = false))) + val expectedType = StructProcedure.parameters.head.dataType + val call = "CALL cat.ns.proc(named_struct('X', DATE '2011-11-11', 'Y', 2), 'VALUE')" + checkError( + exception = intercept[AnalysisException](sql(call)), + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "CALL", + "paramIndex" -> "first", + "inputSql" -> "\"named_struct(X, DATE '2011-11-11', Y, 2)\"", + "inputType" -> toSQLType(actualType), + "requiredType" -> toSQLType(expectedType)), + context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1)) + } + + test("save execution summary") { + withTable("summary") { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val result = sql("CALL cat.ns.sum(1, 2)") + result.write.saveAsTable("summary") + checkAnswer(spark.table("summary"), Row(3) :: Nil) + } + } + + object UnboundVoidProcedure extends UnboundProcedure { + override def name: String = "void" + override def description: String = "void procedure" + override def bind(inputType: StructType): BoundProcedure = VoidProcedure + } + + object VoidProcedure extends BoundProcedure { + override def name: String = "void" + + override def description: String = "void procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.StringType).build(), + ProcedureParameter.in("in2", DataTypes.StringType).build() + ) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + Collections.emptyIterator + } + } + + object UnboundMultiResultProcedure extends UnboundProcedure { + override def name: String = "multi" + override def description: String = "multi-result procedure" + override def bind(inputType: StructType): BoundProcedure = MultiResultProcedure + } + + object MultiResultProcedure extends BoundProcedure { + override def name: String = "multi" + + override def description: String = "multi-result procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array() + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val scans = java.util.Arrays.asList[Scan]( + Result( + new StructType().add("out", DataTypes.IntegerType), + Array(InternalRow(1))), + Result( + new StructType().add("out", DataTypes.StringType), + Array(InternalRow(UTF8String.fromString("last")))) + ) + scans.iterator() + } + } + + object UnboundNonExecutableSum extends 
UnboundProcedure { + override def name: String = "sum" + override def description: String = "sum integers" + override def bind(inputType: StructType): BoundProcedure = Sum + } + + object NonExecutableSum extends BoundProcedure { + override def name: String = "sum" + + override def description: String = "sum integers" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).build() + ) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundSum extends UnboundProcedure { + override def name: String = "sum" + override def description: String = "sum integers" + override def bind(inputType: StructType): BoundProcedure = Sum + } + + object Sum extends BoundProcedure { + override def name: String = "sum" + + override def description: String = "sum integers" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val in1 = input.getInt(0) + val in2 = input.getInt(1) + val result = Result(outputType, Array(InternalRow(in1 + in2))) + Collections.singleton[Scan](result).iterator() + } + } + + object UnboundLongSum extends UnboundProcedure { + override def name: String = "long_sum" + override def description: String = "sum longs" + override def bind(inputType: StructType): BoundProcedure = LongSum + } + + object LongSum extends BoundProcedure { + override def name: String = "long_sum" + + override def description: String = "sum longs" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.LongType).build(), + ProcedureParameter.in("in2", DataTypes.LongType).build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.LongType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val in1 = input.getLong(0) + val in2 = input.getLong(1) + val result = Result(outputType, Array(InternalRow(in1 + in2))) + Collections.singleton[Scan](result).iterator() + } + } + + object UnboundInvalidSum extends UnboundProcedure { + override def name: String = "invalid" + override def description: String = "sum integers" + override def bind(inputType: StructType): BoundProcedure = InvalidSum + } + + object InvalidSum extends BoundProcedure { + override def name: String = "invalid" + + override def description: String = "sum integers" + + override def isDeterministic: Boolean = false + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).defaultValue("1").build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundInvalidDefaultProcedure extends UnboundProcedure { + override def name: String = "sum" + override def description: String = "invalid default value procedure" + override def 
bind(inputType: StructType): BoundProcedure = InvalidDefaultProcedure + } + + object InvalidDefaultProcedure extends BoundProcedure { + override def name: String = "sum" + + override def description: String = "invalid default value procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.IntegerType).defaultValue("10").build(), + ProcedureParameter.in("in2", DataTypes.IntegerType).defaultValue("'B'").build() + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundComplexProcedure extends UnboundProcedure { + override def name: String = "complex" + override def description: String = "complex procedure" + override def bind(inputType: StructType): BoundProcedure = ComplexProcedure + } + + object ComplexProcedure extends BoundProcedure { + override def name: String = "complex" + + override def description: String = "complex procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter.in("in1", DataTypes.StringType).defaultValue("'A'").build(), + ProcedureParameter.in("in2", DataTypes.StringType).defaultValue("'B'").build(), + ProcedureParameter.in("in3", DataTypes.IntegerType).defaultValue("1 + 1 - 1").build() + ) + + def outputType: StructType = new StructType() + .add("out1", DataTypes.IntegerType) + .add("out2", DataTypes.StringType) + .add("out3", DataTypes.StringType) + + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + val in1 = input.getString(0) + val in2 = input.getString(1) + val in3 = input.getInt(2) + + val rows = (1 to in3).map { index => + val v1 = UTF8String.fromString(s"$in1$index") + val v2 = UTF8String.fromString(s"$in2$index") + InternalRow(index, v1, v2) + }.toArray + + val result = Result(outputType, rows) + Collections.singleton[Scan](result).iterator() + } + } + + object UnboundStructProcedure extends UnboundProcedure { + override def name: String = "struct_input" + override def description: String = "struct procedure" + override def bind(inputType: StructType): BoundProcedure = StructProcedure + } + + object StructProcedure extends BoundProcedure { + override def name: String = "struct_input" + + override def description: String = "struct procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + ProcedureParameter + .in( + "in1", + StructType(Seq( + StructField("nested1", DataTypes.IntegerType), + StructField("nested2", DataTypes.StringType)))) + .build(), + ProcedureParameter.in("in2", DataTypes.StringType).build() + ) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + Collections.emptyIterator + } + } + + object UnboundInoutProcedure extends UnboundProcedure { + override def name: String = "procedure" + override def description: String = "inout procedure" + override def bind(inputType: StructType): BoundProcedure = InoutProcedure + } + + object InoutProcedure extends BoundProcedure { + override def name: String = "procedure" + + override def description: String = "inout procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + CustomParameterImpl(INOUT, "in1", DataTypes.IntegerType) + ) + + def outputType: StructType = new 
StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + object UnboundOutProcedure extends UnboundProcedure { + override def name: String = "procedure" + override def description: String = "out procedure" + override def bind(inputType: StructType): BoundProcedure = OutProcedure + } + + object OutProcedure extends BoundProcedure { + override def name: String = "procedure" + + override def description: String = "out procedure" + + override def isDeterministic: Boolean = true + + override def parameters: Array[ProcedureParameter] = Array( + CustomParameterImpl(IN, "in1", DataTypes.IntegerType), + CustomParameterImpl(OUT, "out1", DataTypes.IntegerType) + ) + + def outputType: StructType = new StructType().add("out", DataTypes.IntegerType) + + override def call(input: InternalRow): java.util.Iterator[Scan] = { + throw new UnsupportedOperationException() + } + } + + case class Result(readSchema: StructType, rows: Array[InternalRow]) extends LocalScan + + case class CustomParameterImpl( + mode: Mode, + name: String, + dataType: DataType) extends ProcedureParameter { + override def defaultValueExpression: String = null + override def comment: String = null + } +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index 4bc4116a23da7..dcf3bd8c71731 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") // scalastyle:on line.size.limit } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 44c1ecd6902ce..dbeb8607facc2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, 
EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose} import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -95,6 +95,7 @@ class HiveSessionStateBuilder( new EvalSubqueriesForTimeTravel +: new DetermineTableStats(session) +: new ResolveTranspose(session) +: + new InvokeProcedures(session) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = From ac34f1de92c6f5cb53d799f00e550a0a204d9eb2 Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Thu, 19 Sep 2024 11:56:10 +0200 Subject: [PATCH 017/250] [SPARK-48280][SQL][FOLLOW-UP] Add expressions that are built via expressionBuilder to Expression Walker ### What changes were proposed in this pull request? Addition of new expressions to the expression walker. This PR also improves descriptions of methods in the Suite. ### Why are the changes needed? It was noticed while debugging that startsWith, endsWith and contains are not tested with this suite, even though these expressions represent the core of collation testing. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Test only. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48162 from mihailom-db/expressionwalkerfollowup. Authored-by: Mihailo Milosevic Signed-off-by: Wenchen Fan --- .../sql/CollationExpressionWalkerSuite.scala | 148 ++++++++++++++---- 1 file changed, 121 insertions(+), 27 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 2342722c0bb14..1d23774a51692 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.Timestamp import org.apache.spark.{SparkFunSuite, SparkRuntimeException} +import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.variant.ParseJson import org.apache.spark.sql.internal.SqlApiConf @@ -46,7 +47,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputEntry - List of all input entries that need to be generated * @param collationType - Flag defining collation type to use - * @return + * @return - List of data generated for expression instance creation */ def generateData( inputEntry: Seq[Any], @@ -54,23 +55,11 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi inputEntry.map(generateSingleEntry(_, collationType)) } - /** - * Helper function to generate single entry of data as a string. - * @param inputEntry - Single input entry that requires generation - * @param collationType - Flag defining collation type to use - * @return - */ - def generateDataAsStrings( - inputEntry: Seq[AbstractDataType], - collationType: CollationType): Seq[Any] = { - inputEntry.map(generateInputAsString(_, collationType)) - } - /** * Helper function to generate single entry of data.
* @param inputEntry - Single input entry that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - Single input entry data */ def generateSingleEntry( inputEntry: Any, @@ -100,7 +89,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputType - Single input literal type that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - Literal/Expression containing expression ready for evaluation */ def generateLiterals( inputType: AbstractDataType, @@ -116,6 +105,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } case BooleanType => Literal(true) case _: DatetimeType => Literal(Timestamp.valueOf("2009-07-30 12:58:59")) + case DecimalType => Literal((new Decimal).set(5)) case _: DecimalType => Literal((new Decimal).set(5)) case _: DoubleType => Literal(5.0) case IntegerType | NumericType | IntegralType => Literal(5) @@ -158,11 +148,15 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case MapType => val key = generateLiterals(StringTypeAnyCollation, collationType) val value = generateLiterals(StringTypeAnyCollation, collationType) - Literal.create(Map(key -> value)) + CreateMap(Seq(key, value)) case MapType(keyType, valueType, _) => val key = generateLiterals(keyType, collationType) val value = generateLiterals(valueType, collationType) - Literal.create(Map(key -> value)) + CreateMap(Seq(key, value)) + case AbstractMapType(keyType, valueType) => + val key = generateLiterals(keyType, collationType) + val value = generateLiterals(valueType, collationType) + CreateMap(Seq(key, value)) case StructType => CreateNamedStruct( Seq(Literal("start"), generateLiterals(StringTypeAnyCollation, collationType), @@ -174,7 +168,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputType - Single input type that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - String representation of a input ready for SQL query */ def generateInputAsString( inputType: AbstractDataType, @@ -189,6 +183,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi } case BooleanType => "True" case _: DatetimeType => "date'2016-04-08'" + case DecimalType => "5.0" case _: DecimalType => "5.0" case _: DoubleType => "5.0" case IntegerType | NumericType | IntegralType => "5" @@ -221,6 +216,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case MapType(keyType, valueType, _) => "map(" + generateInputAsString(keyType, collationType) + ", " + generateInputAsString(valueType, collationType) + ")" + case AbstractMapType(keyType, valueType) => + "map(" + generateInputAsString(keyType, collationType) + ", " + + generateInputAsString(valueType, collationType) + ")" case StructType => "named_struct( 'start', " + generateInputAsString(StringTypeAnyCollation, collationType) + ", 'end', " + generateInputAsString(StringTypeAnyCollation, collationType) + ")" @@ -234,7 +232,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * * @param inputType - Single input type that requires generation * @param collationType - Flag defining collation type to use - * @return + * @return - String representation for SQL query of a inputType */ def generateInputTypeAsStrings( inputType: AbstractDataType, @@ -244,6 +242,7 @@ class 
CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case BinaryType => "BINARY" case BooleanType => "BOOLEAN" case _: DatetimeType => "DATE" + case DecimalType => "DECIMAL(2, 1)" case _: DecimalType => "DECIMAL(2, 1)" case _: DoubleType => "DOUBLE" case IntegerType | NumericType | IntegralType => "INT" @@ -275,6 +274,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case MapType(keyType, valueType, _) => "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " + generateInputTypeAsStrings(valueType, collationType) + ">" + case AbstractMapType(keyType, valueType) => + "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " + + generateInputTypeAsStrings(valueType, collationType) + ">" case StructType => "struct hasStringType(elementType) case TypeCollection(typeCollection) => typeCollection.exists(hasStringType) - case StructType => true case StructType(fields) => fields.exists(sf => hasStringType(sf.dataType)) case _ => false } @@ -310,7 +311,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * Helper function to replace expected parameters with expected input types. * @param inputTypes - Input types generated by ExpectsInputType.inputTypes * @param params - Parameters that are read from expression info - * @return + * @return - List of parameters where Expressions are replaced with input types */ def replaceExpressions(inputTypes: Seq[AbstractDataType], params: Seq[Class[_]]): Seq[Any] = { (inputTypes, params) match { @@ -325,7 +326,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi /** * Helper method to extract relevant expressions that can be walked over. - * @return + * @return - (List of relevant expressions that expect input, List of expressions to skip) */ def extractRelevantExpressions(): (Array[ExpressionInfo], List[String]) = { var expressionCounter = 0 @@ -384,6 +385,47 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi (funInfos, toSkip) } + /** + * Helper method to extract relevant expressions that can be walked over but are built with + * expression builder. + * + * @return - (List of expressions that are relevant builders, List of expressions to skip) + */ + def extractRelevantBuilders(): (Array[ExpressionInfo], List[String]) = { + var builderExpressionCounter = 0 + val funInfos = spark.sessionState.functionRegistry.listFunction().map { funcId => + spark.sessionState.catalog.lookupFunctionInfo(funcId) + }.filter(funInfo => { + // make sure that there is a constructor. 
+ val cl = Utils.classForName(funInfo.getClassName) + cl.isAssignableFrom(classOf[ExpressionBuilder]) + }).filter(funInfo => { + builderExpressionCounter = builderExpressionCounter + 1 + val cl = Utils.classForName(funInfo.getClassName) + val method = cl.getMethod("build", + Utils.classForName("java.lang.String"), + Utils.classForName("scala.collection.Seq")) + var input: Seq[Expression] = Seq.empty + var i = 0 + for (_ <- 1 to 10) { + input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary) + try { + method.invoke(null, funInfo.getClassName, input).asInstanceOf[ExpectsInputTypes] + } + catch { + case _: Exception => i = i + 1 + } + } + if (i == 10) false + else true + }).toArray + + logInfo("Total number of expression that are built: " + builderExpressionCounter) + logInfo("Number of extracted expressions of relevance: " + funInfos.length) + + (funInfos, List()) + } + /** * Helper function to generate string of an expression suitable for execution. * @param expr - Expression that needs to be converted @@ -441,10 +483,36 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * 5) Otherwise, check if exceptions are the same */ test("SPARK-48280: Expression Walker for expression evaluation") { - val (funInfos, toSkip) = extractRelevantExpressions() + val (funInfosExpr, toSkip) = extractRelevantExpressions() + val (funInfosBuild, _) = extractRelevantBuilders() + val funInfos = funInfosExpr ++ funInfosBuild for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - val cl = Utils.classForName(f.getClassName) + val TempCl = Utils.classForName(f.getClassName) + val cl = if (TempCl.isAssignableFrom(classOf[ExpressionBuilder])) { + val clTemp = Utils.classForName(f.getClassName) + val method = clTemp.getMethod("build", + Utils.classForName("java.lang.String"), + Utils.classForName("scala.collection.Seq")) + val instance = { + var input: Seq[Expression] = Seq.empty + var result: Expression = null + for (_ <- 1 to 10) { + input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary) + try { + val tempResult = method.invoke(null, f.getClassName, input) + if (result == null) result = tempResult.asInstanceOf[Expression] + } + catch { + case _: Exception => + } + } + result + } + instance.getClass + } + else Utils.classForName(f.getClassName) + val headConstructor = cl.getConstructors .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1 val params = headConstructor.getParameters.map(p => p.getType) @@ -526,10 +594,36 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi * 5) Otherwise, check if exceptions are the same */ test("SPARK-48280: Expression Walker for codeGen generation") { - val (funInfos, toSkip) = extractRelevantExpressions() + val (funInfosExpr, toSkip) = extractRelevantExpressions() + val (funInfosBuild, _) = extractRelevantBuilders() + val funInfos = funInfosExpr ++ funInfosBuild for (f <- funInfos.filter(f => !toSkip.contains(f.getName))) { - val cl = Utils.classForName(f.getClassName) + val TempCl = Utils.classForName(f.getClassName) + val cl = if (TempCl.isAssignableFrom(classOf[ExpressionBuilder])) { + val clTemp = Utils.classForName(f.getClassName) + val method = clTemp.getMethod("build", + Utils.classForName("java.lang.String"), + Utils.classForName("scala.collection.Seq")) + val instance = { + var input: Seq[Expression] = Seq.empty + var result: Expression = null + for (_ <- 1 to 10) { + input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary) + try { + val 
tempResult = method.invoke(null, f.getClassName, input) + if (result == null) result = tempResult.asInstanceOf[Expression] + } + catch { + case _: Exception => + } + } + result + } + instance.getClass + } + else Utils.classForName(f.getClassName) + val headConstructor = cl.getConstructors .zip(cl.getConstructors.map(c => c.getParameters.length)).minBy(a => a._2)._1 val params = headConstructor.getParameters.map(p => p.getType) From a060c236d314bd2facc73ad26926b59401e5f7aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladan=20Vasi=C4=87?= Date: Thu, 19 Sep 2024 14:25:53 +0200 Subject: [PATCH 018/250] [SPARK-49667][SQL] Disallowed CS_AI collators with expressions that use StringSearch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? In this PR, I propose to disallow `CS_AI` collated strings in expressions that use `StringsSearch` in their implementation. These expressions are `trim`, `startswith`, `endswith`, `locate`, `instr`, `str_to_map`, `contains`, `replace`, `split_part` and `substring_index`. Currently, these expressions support all possible collations, however, they do not work properly with `CS_AI` collators. This is because there is no support for `CS_AI` search in the ICU's `StringSearch` class which is used to implement these expressions. Therefore, the expressions are not behaving correctly when used with `CS_AI` collators (e.g. currently `startswith('hOtEl' collate unicode_ai, 'Hotel' collate unicode_ai)` returns `true`). ### Why are the changes needed? Proposed changes are necessary in order to achieve correct behavior of the expressions mentioned above. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? This patch was tested by adding a test in the `CollationSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48121 from vladanvasi-db/vladanvasi-db/cs-ai-collations-expressions-disablement. Authored-by: Vladan Vasić Signed-off-by: Wenchen Fan --- .../sql/catalyst/util/CollationFactory.java | 12 + .../internal/types/AbstractStringType.scala | 9 + .../apache/spark/sql/types/StringType.scala | 3 + .../expressions/complexTypeCreator.scala | 4 +- .../expressions/stringExpressions.scala | 33 +- .../analyzer-results/collations.sql.out | 336 ++++++++++++++++ .../resources/sql-tests/inputs/collations.sql | 14 + .../sql-tests/results/collations.sql.out | 364 ++++++++++++++++++ .../sql/CollationSQLExpressionsSuite.scala | 24 ++ .../sql/CollationStringExpressionsSuite.scala | 251 ++++++++++++ 10 files changed, 1041 insertions(+), 9 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 87558971042e0..d5dbca7eb89bc 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -921,6 +921,18 @@ public static int collationNameToId(String collationName) throws SparkException return Collation.CollationSpec.collationNameToId(collationName); } + /** + * Returns whether the ICU collation is not Case Sensitive Accent Insensitive + * for the given collation id. + * This method is used in expressions which do not support CS_AI collations. 
+ */ + public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) { + return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity == + Collation.CollationSpecICU.CaseSensitivity.CS && + Collation.CollationSpecICU.fromCollationId(collationId).accentSensitivity == + Collation.CollationSpecICU.AccentSensitivity.AI; + } + public static void assertValidProvider(String provider) throws SparkException { if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) { Map params = Map.of( diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index 05d1701eff74d..dc4ee013fd189 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -51,3 +51,12 @@ case object StringTypeBinaryLcase extends AbstractStringType { case object StringTypeAnyCollation extends AbstractStringType { override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] } + +/** + * Use StringTypeNonCSAICollation for expressions supporting all possible collation types except + * CS_AI collation types. + */ +case object StringTypeNonCSAICollation extends AbstractStringType { + override private[sql] def acceptsType(other: DataType): Boolean = + other.isInstanceOf[StringType] && other.asInstanceOf[StringType].isNonCSAI +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index eba12c4ff4875..c2dd6cec7ba74 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -44,6 +44,9 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ private[sql] def supportsLowercaseEquality: Boolean = CollationFactory.fetchCollation(collationId).supportsLowercaseEquality + private[sql] def isNonCSAI: Boolean = + !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId) + private[sql] def isUTF8BinaryCollation: Boolean = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index ba1beab28d9a7..b8b47f2763f5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeNonCSAICollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -579,7 +579,7 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E override def third: Expression = keyValueDelim override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeNonCSAICollation, 
StringTypeNonCSAICollation, StringTypeNonCSAICollation) override def dataType: DataType = MapType(first.dataType, first.dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index e75df87994f0e..da6d786efb4e3 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO import org.apache.spark.sql.catalyst.util.{ArrayData, CharsetProvider, CollationFactory, CollationSupport, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeNonCSAICollation} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods @@ -609,6 +609,8 @@ case class Contains(left: Expression, right: Expression) extends StringPredicate defineCodeGen(ctx, ev, (c1, c2) => CollationSupport.Contains.genCode(c1, c2, collationId)) } + override def inputTypes : Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Contains = copy(left = newLeft, right = newRight) } @@ -650,6 +652,10 @@ case class StartsWith(left: Expression, right: Expression) extends StringPredica defineCodeGen(ctx, ev, (c1, c2) => CollationSupport.StartsWith.genCode(c1, c2, collationId)) } + + override def inputTypes : Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): StartsWith = copy(left = newLeft, right = newRight) } @@ -691,6 +697,10 @@ case class EndsWith(left: Expression, right: Expression) extends StringPredicate defineCodeGen(ctx, ev, (c1, c2) => CollationSupport.EndsWith.genCode(c1, c2, collationId)) } + + override def inputTypes : Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): EndsWith = copy(left = newLeft, right = newRight) } @@ -919,7 +929,7 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp override def dataType: DataType = srcExpr.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, StringTypeNonCSAICollation) override def first: Expression = srcExpr override def second: Expression = searchExpr override def third: Expression = replaceExpr @@ -1167,7 +1177,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac override def dataType: DataType = srcExpr.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, 
StringTypeNonCSAICollation) override def first: Expression = srcExpr override def second: Expression = matchingExpr override def third: Expression = replaceExpr @@ -1394,6 +1404,9 @@ case class StringTrim(srcStr: Expression, trimStr: Option[Expression] = None) override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = CollationSupport.StringTrim.exec(srcString, trimString, collationId) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy( srcStr = newChildren.head, @@ -1501,6 +1514,9 @@ case class StringTrimLeft(srcStr: Expression, trimStr: Option[Expression] = None override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = CollationSupport.StringTrimLeft.exec(srcString, trimString, collationId) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimLeft = copy( @@ -1561,6 +1577,9 @@ case class StringTrimRight(srcStr: Expression, trimStr: Option[Expression] = Non override def doEval(srcString: UTF8String, trimString: UTF8String): UTF8String = CollationSupport.StringTrimRight.exec(srcString, trimString, collationId) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) + override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): StringTrimRight = copy( @@ -1595,7 +1614,7 @@ case class StringInstr(str: Expression, substr: Expression) override def right: Expression = substr override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation) override def nullSafeEval(string: Any, sub: Any): Any = { CollationSupport.StringInstr. 
@@ -1643,7 +1662,7 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: override def dataType: DataType = strExpr.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType) override def first: Expression = strExpr override def second: Expression = delimExpr override def third: Expression = countExpr @@ -1701,7 +1720,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) override def nullable: Boolean = substr.nullable || str.nullable override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType) override def eval(input: InternalRow): Any = { val s = start.eval(input) @@ -3463,7 +3482,7 @@ case class SplitPart ( false) override def nodeName: String = "split_part" override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType) + Seq(StringTypeNonCSAICollation, StringTypeNonCSAICollation, IntegerType) def children: Seq[Expression] = Seq(str, delimiter, partNum) protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = { copy(str = newChildren.apply(0), delimiter = newChildren.apply(1), diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out index 83c9ebfef4b25..eed7fa73ab698 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/collations.sql.out @@ -436,6 +436,30 @@ Project [str_to_map(collate(text#x, utf8_binary), collate(pairDelim#x, utf8_bina +- Relation spark_catalog.default.t4[text#x,pairDelim#x,keyValueDelim#x] parquet +-- !query +select str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai) from t4 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(text, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"str_to_map(collate(text, unicode_ai), collate(pairDelim, unicode_ai), collate(keyValueDelim, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 106, + "fragment" : "str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai)" + } ] +} + + -- !query drop table t4 -- !query analysis @@ -820,6 +844,30 @@ Project [split_part(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, ut +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : 
"\"split_part(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + -- !query select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5 -- !query analysis @@ -883,6 +931,30 @@ Project [Contains(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"contains(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5 -- !query analysis @@ -946,6 +1018,30 @@ Project [substring_index(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase# +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"substring_index(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 88, + "fragment" : "substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + -- !query select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5 -- !query analysis @@ -1009,6 +1105,30 @@ Project [instr(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lc +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"instr(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5 -- !query 
analysis @@ -1135,6 +1255,30 @@ Project [StartsWith(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, ut +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"startswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 80, + "fragment" : "startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5 -- !query analysis @@ -1190,6 +1334,30 @@ Project [translate(cast(utf8_binary#x as string collate UTF8_LCASE), collate(SQL +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"utf8_binary\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"translate(utf8_binary, collate(SQL, unicode_ai), collate(12345, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai)" + } ] +} + + -- !query select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5 -- !query analysis @@ -1253,6 +1421,30 @@ Project [replace(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_ +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc') from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), abc)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 84, + "fragment" : "replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc')" + } ] +} + + -- !query select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5 -- !query analysis @@ -1316,6 +1508,30 @@ Project [EndsWith(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8 +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis 
+org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"endswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5 -- !query analysis @@ -2039,6 +2255,30 @@ Project [locate(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_l +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"locate(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 3)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 79, + "fragment" : "locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3)" + } ] +} + + -- !query select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5 -- !query analysis @@ -2102,6 +2342,30 @@ Project [trim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, utf +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 74, + "fragment" : "TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5 -- !query analysis @@ -2165,6 +2429,30 @@ Project [btrim(collate(utf8_binary#x, utf8_lcase), collate(utf8_lcase#x, utf8_lc +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_lcase, unicode_ai) FROM collate(utf8_binary, unicode_ai))\"" + 
}, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5 -- !query analysis @@ -2228,6 +2516,30 @@ Project [ltrim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, ut +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(LEADING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5 -- !query analysis @@ -2291,6 +2603,30 @@ Project [rtrim(collate(utf8_lcase#x, utf8_lcase), Some(collate(utf8_binary#x, ut +- Relation spark_catalog.default.t5[s#x,utf8_binary#x,utf8_lcase#x] parquet +-- !query +select RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(TRAILING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5 -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/collations.sql b/sql/core/src/test/resources/sql-tests/inputs/collations.sql index 183577b83971b..f3a42fd3e1f12 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/collations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/collations.sql @@ -99,6 +99,7 @@ insert into t4 values('a:1,b:2,c:3', ',', ':'); select str_to_map(text, pairDelim, keyValueDelim) from t4; select str_to_map(text collate utf8_binary, pairDelim collate utf8_lcase, keyValueDelim collate utf8_binary) from t4; select str_to_map(text collate utf8_binary, pairDelim collate utf8_binary, keyValueDelim collate utf8_binary) from t4; +select str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai) from t4; drop table t4; @@ -159,6 +160,7 @@ select split_part(s, utf8_binary, 1) from t5; select split_part(utf8_binary collate utf8_binary, s collate utf8_lcase, 1) from t5; select split_part(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5; select split_part(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5; +select split_part(utf8_binary collate unicode_ai, 
utf8_lcase collate unicode_ai, 2) from t5; select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5; select split_part(utf8_binary, 'a' collate utf8_lcase, 3), split_part(utf8_lcase, 'a' collate utf8_binary, 3) from t5; @@ -168,6 +170,7 @@ select contains(s, utf8_binary) from t5; select contains(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select contains(utf8_binary, utf8_lcase collate utf8_binary) from t5; select contains(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5; select contains(utf8_binary, 'AaAA' collate utf8_lcase), contains(utf8_lcase, 'AAa' collate utf8_binary) from t5; @@ -177,6 +180,7 @@ select substring_index(s, utf8_binary,1) from t5; select substring_index(utf8_binary collate utf8_binary, s collate utf8_lcase, 3) from t5; select substring_index(utf8_binary, utf8_lcase collate utf8_binary, 2) from t5; select substring_index(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 2) from t5; +select substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5; select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5; select substring_index(utf8_binary, 'AaAA' collate utf8_lcase, 2), substring_index(utf8_lcase, 'AAa' collate utf8_binary, 2) from t5; @@ -186,6 +190,7 @@ select instr(s, utf8_binary) from t5; select instr(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select instr(utf8_binary, utf8_lcase collate utf8_binary) from t5; select instr(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5; select instr(utf8_binary, 'AaAA' collate utf8_lcase), instr(utf8_lcase, 'AAa' collate utf8_binary) from t5; @@ -204,6 +209,7 @@ select startswith(s, utf8_binary) from t5; select startswith(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select startswith(utf8_binary, utf8_lcase collate utf8_binary) from t5; select startswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5; select startswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), startswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5; @@ -212,6 +218,7 @@ select translate(utf8_lcase, utf8_lcase, '12345') from t5; select translate(utf8_binary, utf8_lcase, '12345') from t5; select translate(utf8_binary, 'aBc' collate utf8_lcase, '12345' collate utf8_binary) from t5; select translate(utf8_binary, 'SQL' collate utf8_lcase, '12345' collate utf8_lcase) from t5; +select translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai) from t5; select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5; select translate(utf8_lcase, 'aBc' collate utf8_binary, '12345'), translate(utf8_binary, 'aBc' collate utf8_lcase, '12345') from t5; @@ -221,6 +228,7 @@ select replace(s, utf8_binary, 'abc') from t5; select replace(utf8_binary collate utf8_binary, s collate utf8_lcase, 'abc') from t5; select replace(utf8_binary, utf8_lcase collate utf8_binary, 'abc') from t5; select replace(utf8_binary collate 
utf8_lcase, utf8_lcase collate utf8_lcase, 'abc') from t5; +select replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc') from t5; select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5; select replace(utf8_binary, 'aaAaaAaA' collate utf8_lcase, 'abc'), replace(utf8_lcase, 'aaAaaAaA' collate utf8_binary, 'abc') from t5; @@ -230,6 +238,7 @@ select endswith(s, utf8_binary) from t5; select endswith(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select endswith(utf8_binary, utf8_lcase collate utf8_binary) from t5; select endswith(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5; select endswith(utf8_binary, 'aaAaaAaA' collate utf8_lcase), endswith(utf8_lcase, 'aaAaaAaA' collate utf8_binary) from t5; @@ -364,6 +373,7 @@ select locate(s, utf8_binary) from t5; select locate(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select locate(utf8_binary, utf8_lcase collate utf8_binary) from t5; select locate(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase, 3) from t5; +select locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3) from t5; select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5; select locate(utf8_binary, 'AaAA' collate utf8_lcase, 4), locate(utf8_lcase, 'AAa' collate utf8_binary, 4) from t5; @@ -373,6 +383,7 @@ select TRIM(s, utf8_binary) from t5; select TRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select TRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select TRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5; select TRIM('ABc' collate utf8_lcase, utf8_binary), TRIM('AAa' collate utf8_binary, utf8_lcase) from t5; -- StringTrimBoth @@ -381,6 +392,7 @@ select BTRIM(s, utf8_binary) from t5; select BTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select BTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select BTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5; select BTRIM('ABc' collate utf8_lcase, utf8_binary), BTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; -- StringTrimLeft @@ -389,6 +401,7 @@ select LTRIM(s, utf8_binary) from t5; select LTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select LTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select LTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5; select LTRIM('ABc' collate utf8_lcase, utf8_binary), LTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; -- StringTrimRight @@ -397,6 +410,7 @@ select RTRIM(s, utf8_binary) from t5; select RTRIM(utf8_binary collate utf8_binary, s collate utf8_lcase) from t5; select RTRIM(utf8_binary, utf8_lcase collate utf8_binary) from t5; select RTRIM(utf8_binary collate utf8_lcase, utf8_lcase collate utf8_lcase) from t5; +select RTRIM(utf8_binary 
collate unicode_ai, utf8_lcase collate unicode_ai) from t5; select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5; select RTRIM('ABc' collate utf8_lcase, utf8_binary), RTRIM('AAa' collate utf8_binary, utf8_lcase) from t5; diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out index ea5564aafe96f..5999bf20f6884 100644 --- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out @@ -480,6 +480,32 @@ struct +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(text, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"str_to_map(collate(text, unicode_ai), collate(pairDelim, unicode_ai), collate(keyValueDelim, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 106, + "fragment" : "str_to_map(text collate unicode_ai, pairDelim collate unicode_ai, keyValueDelim collate unicode_ai)" + } ] +} + + -- !query drop table t4 -- !query schema @@ -1021,6 +1047,32 @@ struct +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"split_part(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "split_part(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + -- !query select split_part(utf8_binary, 'a', 3), split_part(utf8_lcase, 'a', 3) from t5 -- !query schema @@ -1148,6 +1200,32 @@ true true +-- !query +select contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"contains(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "contains(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select contains(utf8_binary, 'a'), contains(utf8_lcase, 'a') from t5 -- !query schema @@ -1275,6 +1353,32 @@ kitten İo +-- !query +select substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + 
"requiredType" : "\"STRING\"", + "sqlExpr" : "\"substring_index(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 2)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 88, + "fragment" : "substring_index(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 2)" + } ] +} + + -- !query select substring_index(utf8_binary, 'a', 2), substring_index(utf8_lcase, 'a', 2) from t5 -- !query schema @@ -1402,6 +1506,32 @@ struct +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"instr(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "instr(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select instr(utf8_binary, 'a'), instr(utf8_lcase, 'a') from t5 -- !query schema @@ -1656,6 +1786,32 @@ true true +-- !query +select startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"startswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 80, + "fragment" : "startswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select startswith(utf8_binary, 'aaAaaAaA'), startswith(utf8_lcase, 'aaAaaAaA') from t5 -- !query schema @@ -1763,6 +1919,32 @@ kitten İo +-- !query +select translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"utf8_binary\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"translate(utf8_binary, collate(SQL, unicode_ai), collate(12345, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 83, + "fragment" : "translate(utf8_binary, 'SQL' collate unicode_ai, '12345' collate unicode_ai)" + } ] +} + + -- !query select translate(utf8_lcase, 'aaAaaAaA', '12345'), translate(utf8_binary, 'aaAaaAaA', '12345') from t5 -- !query schema @@ -1890,6 +2072,32 @@ bbabcbabcabcbabc kitten +-- !query +select replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc') from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + 
"inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"replace(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), abc)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 84, + "fragment" : "replace(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 'abc')" + } ] +} + + -- !query select replace(utf8_binary, 'aaAaaAaA', 'abc'), replace(utf8_lcase, 'aaAaaAaA', 'abc') from t5 -- !query schema @@ -2017,6 +2225,32 @@ true true +-- !query +select endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"endswith(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 78, + "fragment" : "endswith(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select endswith(utf8_binary, 'aaAaaAaA'), endswith(utf8_lcase, 'aaAaaAaA') from t5 -- !query schema @@ -3570,6 +3804,32 @@ struct +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"locate(collate(utf8_binary, unicode_ai), collate(utf8_lcase, unicode_ai), 3)\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 79, + "fragment" : "locate(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai, 3)" + } ] +} + + -- !query select locate(utf8_binary, 'a'), locate(utf8_lcase, 'a') from t5 -- !query schema @@ -3685,6 +3945,32 @@ QL sitTing +-- !query +select TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 74, + "fragment" : "TRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select TRIM('ABc', utf8_binary), TRIM('ABc', utf8_lcase) from t5 -- !query schema @@ -3812,6 +4098,32 @@ park İ +-- !query +select BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_binary, 
unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(BOTH collate(utf8_lcase, unicode_ai) FROM collate(utf8_binary, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "BTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select BTRIM('ABc', utf8_binary), BTRIM('ABc', utf8_lcase) from t5 -- !query schema @@ -3927,6 +4239,32 @@ QL sitTing +-- !query +select LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(LEADING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "LTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select LTRIM('ABc', utf8_binary), LTRIM('ABc', utf8_lcase) from t5 -- !query schema @@ -4042,6 +4380,32 @@ SQL sitTing +-- !query +select RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai) from t5 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + "sqlState" : "42K09", + "messageParameters" : { + "inputSql" : "\"collate(utf8_lcase, unicode_ai)\"", + "inputType" : "\"STRING COLLATE UNICODE_AI\"", + "paramIndex" : "first", + "requiredType" : "\"STRING\"", + "sqlExpr" : "\"TRIM(TRAILING collate(utf8_binary, unicode_ai) FROM collate(utf8_lcase, unicode_ai))\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 8, + "stopIndex" : 75, + "fragment" : "RTRIM(utf8_binary collate unicode_ai, utf8_lcase collate unicode_ai)" + } ] +} + + -- !query select RTRIM('ABc', utf8_binary), RTRIM('ABc', utf8_lcase) from t5 -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index f8cd840ecdbb9..941d5cd31db40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -982,6 +982,7 @@ class CollationSQLExpressionsSuite StringToMapTestCase("1/AX2/BX3/C", "x", "/", "UNICODE_CI", Map("1" -> "A", "2" -> "B", "3" -> "C")) ) + val unsupportedTestCase = StringToMapTestCase("a:1,b:2,c:3", "?", "?", "UNICODE_AI", null) testCases.foreach(t => { // Unit test. val text = Literal.create(t.text, StringType(t.collation)) @@ -996,6 +997,29 @@ class CollationSQLExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(dataType)) } }) + // Test unsupported collation. 
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select str_to_map('${unsupportedTestCase.text}', '${unsupportedTestCase.pairDelim}', " + + s"'${unsupportedTestCase.keyValueDelim}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"str_to_map('a:1,b:2,c:3' collate UNICODE_AI, " + + "'?' collate UNICODE_AI, '?' collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'a:1,b:2,c:3' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext( + fragment = "str_to_map('a:1,b:2,c:3', '?', '?')", + start = 7, + stop = 41)) + } } test("Support RaiseError misc expression with collation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala index 6804411d470b9..fe9872ddaf575 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala @@ -98,6 +98,7 @@ class CollationStringExpressionsSuite SplitPartTestCase("1a2", "A", 2, "UTF8_LCASE", "2"), SplitPartTestCase("1a2", "A", 2, "UNICODE_CI", "2") ) + val unsupportedTestCase = SplitPartTestCase("1a2", "a", 2, "UNICODE_AI", "2") testCases.foreach(t => { // Unit test. val str = Literal.create(t.str, StringType(t.collation)) @@ -111,6 +112,26 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select split_part('${unsupportedTestCase.str}', '${unsupportedTestCase.delimiter}', " + + s"${unsupportedTestCase.partNum})" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"split_part('1a2' collate UNICODE_AI, 'a' collate UNICODE_AI, 2)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'1a2' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "split_part('1a2', 'a', 2)", start = 7, stop = 31) + ) + } } test("Support `StringSplitSQL` string expression with collation") { @@ -166,6 +187,7 @@ class CollationStringExpressionsSuite ContainsTestCase("abcde", "FGH", "UTF8_LCASE", false), ContainsTestCase("abcde", "BCD", "UNICODE_CI", true) ) + val unsupportedTestCase = ContainsTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { // Unit test. val left = Literal.create(t.left, StringType(t.collation)) @@ -178,6 +200,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) } }) + // Test unsupported collation. 
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select contains('${unsupportedTestCase.left}', '${unsupportedTestCase.right}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"contains('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "contains('abcde', 'A')", start = 7, stop = 28) + ) + } } test("Support `SubstringIndex` expression with collation") { @@ -194,6 +235,7 @@ class CollationStringExpressionsSuite SubstringIndexTestCase("aaaaaaaaaa", "aa", 2, "UNICODE", "a"), SubstringIndexTestCase("wwwmapacheMorg", "M", -2, "UNICODE_CI", "apacheMorg") ) + val unsupportedTestCase = SubstringIndexTestCase("abacde", "a", 2, "UNICODE_AI", "cde") testCases.foreach(t => { // Unit test. val strExpr = Literal.create(t.strExpr, StringType(t.collation)) @@ -207,6 +249,29 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select substring_index('${unsupportedTestCase.strExpr}', " + + s"'${unsupportedTestCase.delimExpr}', ${unsupportedTestCase.countExpr})" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"substring_index('abacde' collate UNICODE_AI, " + + "'a' collate UNICODE_AI, 2)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'abacde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext( + fragment = "substring_index('abacde', 'a', 2)", + start = 7, + stop = 39)) + } } test("Support `StringInStr` string expression with collation") { @@ -219,6 +284,7 @@ class CollationStringExpressionsSuite StringInStrTestCase("test大千世界X大千世界", "界x", "UNICODE_CI", 8), StringInStrTestCase("abİo12", "i̇o", "UNICODE_CI", 3) ) + val unsupportedTestCase = StringInStrTestCase("a", "abcde", "UNICODE_AI", 0) testCases.foreach(t => { // Unit test. val str = Literal.create(t.str, StringType(t.collation)) @@ -231,6 +297,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) } }) + // Test unsupported collation. 
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select instr('${unsupportedTestCase.str}', '${unsupportedTestCase.substr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"instr('a' collate UNICODE_AI, 'abcde' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'a' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "instr('a', 'abcde')", start = 7, stop = 25) + ) + } } test("Support `FindInSet` string expression with collation") { @@ -264,6 +349,7 @@ class CollationStringExpressionsSuite StartsWithTestCase("abcde", "FGH", "UTF8_LCASE", false), StartsWithTestCase("abcde", "ABC", "UNICODE_CI", true) ) + val unsupportedTestCase = StartsWithTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { // Unit test. val left = Literal.create(t.left, StringType(t.collation)) @@ -276,6 +362,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select startswith('${unsupportedTestCase.left}', '${unsupportedTestCase.right}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"startswith('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "startswith('abcde', 'A')", start = 7, stop = 30) + ) + } } test("Support `StringTranslate` string expression with collation") { @@ -291,6 +396,7 @@ class CollationStringExpressionsSuite StringTranslateTestCase("Translate", "Rn", "\u0000\u0000", "UNICODE", "Traslate"), StringTranslateTestCase("Translate", "Rn", "1234", "UNICODE_CI", "T1a2slate") ) + val unsupportedTestCase = StringTranslateTestCase("ABC", "AB", "12", "UNICODE_AI", "12C") testCases.foreach(t => { // Unit test. val srcExpr = Literal.create(t.srcExpr, StringType(t.collation)) @@ -304,6 +410,27 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. 
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select translate('${unsupportedTestCase.srcExpr}', " + + s"'${unsupportedTestCase.matchingExpr}', '${unsupportedTestCase.replaceExpr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"translate('ABC' collate UNICODE_AI, 'AB' collate UNICODE_AI, " + + "'12' collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'ABC' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "translate('ABC', 'AB', '12')", start = 7, stop = 34) + ) + } } test("Support `StringReplace` string expression with collation") { @@ -321,6 +448,7 @@ class CollationStringExpressionsSuite StringReplaceTestCase("abi̇o12i̇o", "İo", "yy", "UNICODE_CI", "abyy12yy"), StringReplaceTestCase("abİo12i̇o", "i̇o", "xx", "UNICODE_CI", "abxx12xx") ) + val unsupportedTestCase = StringReplaceTestCase("abcde", "A", "B", "UNICODE_AI", "abcde") testCases.foreach(t => { // Unit test. val srcExpr = Literal.create(t.srcExpr, StringType(t.collation)) @@ -334,6 +462,27 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select replace('${unsupportedTestCase.srcExpr}', '${unsupportedTestCase.searchExpr}', " + + s"'${unsupportedTestCase.replaceExpr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"replace('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI, " + + "'B' collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "replace('abcde', 'A', 'B')", start = 7, stop = 32) + ) + } } test("Support `EndsWith` string expression with collation") { @@ -344,6 +493,7 @@ class CollationStringExpressionsSuite EndsWithTestCase("abcde", "FGH", "UTF8_LCASE", false), EndsWithTestCase("abcde", "CDE", "UNICODE_CI", true) ) + val unsupportedTestCase = EndsWithTestCase("abcde", "A", "UNICODE_AI", false) testCases.foreach(t => { // Unit test. val left = Literal.create(t.left, StringType(t.collation)) @@ -355,6 +505,25 @@ class CollationStringExpressionsSuite checkAnswer(sql(query), Row(t.result)) assert(sql(query).schema.fields.head.dataType.sameType(BooleanType)) } + // Test unsupported collation. 
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select endswith('${unsupportedTestCase.left}', '${unsupportedTestCase.right}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"endswith('abcde' collate UNICODE_AI, 'A' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'abcde' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "endswith('abcde', 'A')", start = 7, stop = 28) + ) + } }) } @@ -1097,6 +1266,7 @@ class CollationStringExpressionsSuite StringLocateTestCase("aa", "Aaads", 0, "UNICODE_CI", 0), StringLocateTestCase("界x", "test大千世界X大千世界", 1, "UNICODE_CI", 8) ) + val unsupportedTestCase = StringLocateTestCase("aa", "Aaads", 0, "UNICODE_AI", 1) testCases.foreach(t => { // Unit test. val substr = Literal.create(t.substr, StringType(t.collation)) @@ -1110,6 +1280,26 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(IntegerType)) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val query = + s"select locate('${unsupportedTestCase.substr}', '${unsupportedTestCase.str}', " + + s"${unsupportedTestCase.start})" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"locate('aa' collate UNICODE_AI, 'Aaads' collate UNICODE_AI, 0)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'aa' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "locate('aa', 'Aaads', 0)", start = 7, stop = 30) + ) + } } test("Support `StringTrimLeft` string expression with collation") { @@ -1124,6 +1314,7 @@ class CollationStringExpressionsSuite StringTrimLeftTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), StringTrimLeftTestCase(" asd ", None, "UNICODE_CI", "asd ") ) + val unsupportedTestCase = StringTrimLeftTestCase("xxasdxx", Some("x"), "UNICODE_AI", null) testCases.foreach(t => { // Unit test. val srcStr = Literal.create(t.srcStr, StringType(t.collation)) @@ -1137,6 +1328,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. 
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val trimString = s"'${unsupportedTestCase.trimStr.get}', " + val query = s"select ltrim($trimString'${unsupportedTestCase.srcStr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"TRIM(LEADING 'x' collate UNICODE_AI FROM 'xxasdxx' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'xxasdxx' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "ltrim('x', 'xxasdxx')", start = 7, stop = 27) + ) + } } test("Support `StringTrimRight` string expression with collation") { @@ -1151,6 +1361,7 @@ class CollationStringExpressionsSuite StringTrimRightTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), StringTrimRightTestCase(" asd ", None, "UNICODE_CI", " asd") ) + val unsupportedTestCase = StringTrimRightTestCase("xxasdxx", Some("x"), "UNICODE_AI", "xxasd") testCases.foreach(t => { // Unit test. val srcStr = Literal.create(t.srcStr, StringType(t.collation)) @@ -1164,6 +1375,26 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. + withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val trimString = s"'${unsupportedTestCase.trimStr.get}', " + val query = s"select rtrim($trimString'${unsupportedTestCase.srcStr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> ("\"TRIM(TRAILING 'x' collate UNICODE_AI FROM 'xxasdxx'" + + " collate UNICODE_AI)\""), + "paramIndex" -> "first", + "inputSql" -> "\"'xxasdxx' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "rtrim('x', 'xxasdxx')", start = 7, stop = 27) + ) + } } test("Support `StringTrim` string expression with collation") { @@ -1178,6 +1409,7 @@ class CollationStringExpressionsSuite StringTrimTestCase("xxasdxx", Some("y"), "UNICODE", "xxasdxx"), StringTrimTestCase(" asd ", None, "UNICODE_CI", "asd") ) + val unsupportedTestCase = StringTrimTestCase("xxasdxx", Some("x"), "UNICODE_AI", "asd") testCases.foreach(t => { // Unit test. val srcStr = Literal.create(t.srcStr, StringType(t.collation)) @@ -1191,6 +1423,25 @@ class CollationStringExpressionsSuite assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.collation))) } }) + // Test unsupported collation. 
+ withSQLConf(SQLConf.DEFAULT_COLLATION.key -> unsupportedTestCase.collation) { + val trimString = s"'${unsupportedTestCase.trimStr.get}', " + val query = s"select trim($trimString'${unsupportedTestCase.srcStr}')" + checkError( + exception = intercept[AnalysisException] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "sqlExpr" -> "\"TRIM(BOTH 'x' collate UNICODE_AI FROM 'xxasdxx' collate UNICODE_AI)\"", + "paramIndex" -> "first", + "inputSql" -> "\"'xxasdxx' collate UNICODE_AI\"", + "inputType" -> "\"STRING COLLATE UNICODE_AI\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = "trim('x', 'xxasdxx')", start = 7, stop = 26) + ) + } } test("Support `StringTrimBoth` string expression with collation") { From 4068fbcc0de59154db9bdeb1296bd24059db9f42 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 19 Sep 2024 21:00:57 +0800 Subject: [PATCH 019/250] [SPARK-49717][SQL][TESTS] Function parity test ignore private[xxx] functions ### What changes were proposed in this pull request? Function parity test ignore private functions ### Why are the changes needed? existing test is based on `java.lang.reflect.Modifier` which cannot properly handle `private[xxx]` ### Does this PR introduce _any_ user-facing change? no, test only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #48163 from zhengruifeng/df_func_test. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .../apache/spark/sql/DataFrameFunctionsSuite.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index f16171940df21..0842b92e5d53c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql -import java.lang.reflect.Modifier import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import scala.reflect.runtime.universe.runtimeMirror import scala.util.Random import org.apache.spark.{QueryContextType, SPARK_DOC_ROOT, SparkException, SparkRuntimeException} @@ -82,7 +82,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { "bucket", "days", "hours", "months", "years", // Datasource v2 partition transformations "product", // Discussed in https://github.com/apache/spark/pull/30745 "unwrap_udt", - "collect_top_k", "timestamp_add", "timestamp_diff" ) @@ -92,10 +91,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { val word_pattern = """\w*""" // Set of DataFrame functions in org.apache.spark.sql.functions - val dataFrameFunctions = functions.getClass - .getDeclaredMethods - .filter(m => Modifier.isPublic(m.getModifiers)) - .map(_.getName) + val dataFrameFunctions = runtimeMirror(getClass.getClassLoader) + .reflect(functions) + .symbol + .typeSignature + .decls + .filter(s => s.isMethod && s.isPublic) + .map(_.name.toString) .toSet .filter(_.matches(word_pattern)) .diff(excludedDataFrameFunctions) From 398457af59875120ea8b3ed44468a51597e6a441 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 19 Sep 2024 09:02:34 -0400 Subject: [PATCH 020/250] [SPARK-49422][CONNECT][SQL] Add groupByKey to sql/api ### What changes were proposed in this pull 
request? This PR adds `Dataset.groupByKey(..)` to the shared interface. I forgot to add in the previous PR. ### Why are the changes needed? The shared interface needs to support all functionality. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48147 from hvanhovell/SPARK-49422-follow-up. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../scala/org/apache/spark/sql/Dataset.scala | 24 ++----- .../org/apache/spark/sql/api/Dataset.scala | 22 ++++++ .../scala/org/apache/spark/sql/Dataset.scala | 68 +++---------------- 3 files changed, 39 insertions(+), 75 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 161a0d9d265f0..accfff9f2b073 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -524,27 +524,11 @@ class Dataset[T] private[sql] ( result(0) } - /** - * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given - * key `func`. - * - * @group typedrel - * @since 3.5.0 - */ + /** @inheritdoc */ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func) } - /** - * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given - * key `func`. - * - * @group typedrel - * @since 3.5.0 - */ - def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = - groupByKey(ToScalaUDF(func))(encoder) - /** @inheritdoc */ @scala.annotation.varargs def rollup(cols: Column*): RelationalGroupedDataset = { @@ -1480,4 +1464,10 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ @scala.annotation.varargs override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) + + /** @inheritdoc */ + override def groupByKey[K]( + func: MapFunction[T, K], + encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = + super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala index 284a69fe6ee3e..6eef034aa5157 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala @@ -1422,6 +1422,28 @@ abstract class Dataset[T] extends Serializable { */ def reduce(func: ReduceFunction[T]): T = reduce(ToScalaUDF(func)) + /** + * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given + * key `func`. + * + * @group typedrel + * @since 2.0.0 + */ + def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] + + /** + * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given + * key `func`. + * + * @group typedrel + * @since 2.0.0 + */ + def groupByKey[K]( + func: MapFunction[T, K], + encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = { + groupByKey(ToScalaUDF(func))(encoder) + } + /** * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns * set. 
This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 61f9e6ff7c042..ef628ca612b49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -62,7 +62,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable} import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions -import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf, ToScalaUDF} +import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf} import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.TypedAggUtils.withInputType import org.apache.spark.sql.streaming.DataStreamWriter @@ -865,24 +865,7 @@ class Dataset[T] private[sql]( Filter(condition.expr, logicalPlan) } - /** - * Groups the Dataset using the specified columns, so we can run aggregation on them. See - * [[RelationalGroupedDataset]] for all the available aggregate functions. - * - * {{{ - * // Compute the average for all numeric columns grouped by department. - * ds.groupBy($"department").avg() - * - * // Compute the max age and average salary, grouped by department and gender. - * ds.groupBy($"department", $"gender").agg(Map( - * "salary" -> "avg", - * "age" -> "max" - * )) - * }}} - * - * @group untypedrel - * @since 2.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs def groupBy(cols: Column*): RelationalGroupedDataset = { RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType) @@ -914,13 +897,7 @@ class Dataset[T] private[sql]( rdd.reduce(func) } - /** - * (Scala-specific) - * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. - * - * @group typedrel - * @since 2.0.0 - */ + /** @inheritdoc */ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { val withGroupingKey = AppendColumns(func, logicalPlan) val executed = sparkSession.sessionState.executePlan(withGroupingKey) @@ -933,16 +910,6 @@ class Dataset[T] private[sql]( withGroupingKey.newColumns) } - /** - * (Java-specific) - * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`. - * - * @group typedrel - * @since 2.0.0 - */ - def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = - groupByKey(ToScalaUDF(func))(encoder) - /** @inheritdoc */ def unpivot( ids: Array[Column], @@ -1640,28 +1607,7 @@ class Dataset[T] private[sql]( new DataFrameWriterV2Impl[T](table, this) } - /** - * Merges a set of updates, insertions, and deletions based on a source table into - * a target table. 
- * - * Scala Examples: - * {{{ - * spark.table("source") - * .mergeInto("target", $"source.id" === $"target.id") - * .whenMatched($"salary" === 100) - * .delete() - * .whenNotMatched() - * .insertAll() - * .whenNotMatchedBySource($"salary" === 100) - * .update(Map( - * "salary" -> lit(200) - * )) - * .merge() - * }}} - * - * @group basic - * @since 4.0.0 - */ + /** @inheritdoc */ def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = { if (isStreaming) { logicalPlan.failAnalysis( @@ -2024,6 +1970,12 @@ class Dataset[T] private[sql]( @scala.annotation.varargs override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*) + /** @inheritdoc */ + override def groupByKey[K]( + func: MapFunction[T, K], + encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = + super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] + //////////////////////////////////////////////////////////////////////////// // For Python API //////////////////////////////////////////////////////////////////////////// From 94dca78c128ff3d1571326629b4100ee092afb54 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 19 Sep 2024 21:10:52 +0800 Subject: [PATCH 021/250] [SPARK-49693][PYTHON][CONNECT] Refine the string representation of `timedelta` ### What changes were proposed in this pull request? Refine the string representation of `timedelta`, by following the ISO format. Note that the used units in JVM side (`Duration`) and Pandas are different. ### Why are the changes needed? We should not leak the raw data ### Does this PR introduce _any_ user-facing change? yes PySpark Classic: ``` In [1]: from pyspark.sql import functions as sf In [2]: import datetime In [3]: sf.lit(datetime.timedelta(1, 1)) Out[3]: Column<'PT24H1S'> ``` PySpark Connect (before): ``` In [1]: from pyspark.sql import functions as sf In [2]: import datetime In [3]: sf.lit(datetime.timedelta(1, 1)) Out[3]: Column<'86401000000'> ``` PySpark Connect (after): ``` In [1]: from pyspark.sql import functions as sf In [2]: import datetime In [3]: sf.lit(datetime.timedelta(1, 1)) Out[3]: Column<'P1DT0H0M1S'> ``` ### How was this patch tested? added test ### Was this patch authored or co-authored using generative AI tooling? no Closes #48159 from zhengruifeng/pc_lit_delta. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/sql/connect/expressions.py | 12 +++++++++++- python/pyspark/sql/tests/test_column.py | 23 ++++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 63128ef48e389..0b5512b61925c 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -489,7 +489,17 @@ def __repr__(self) -> str: ts = TimestampNTZType().fromInternal(self._value) if ts is not None and isinstance(ts, datetime.datetime): return ts.strftime("%Y-%m-%d %H:%M:%S.%f") - # TODO(SPARK-49693): Refine the string representation of timedelta + elif isinstance(self._dataType, DayTimeIntervalType): + delta = DayTimeIntervalType().fromInternal(self._value) + if delta is not None and isinstance(delta, datetime.timedelta): + import pandas as pd + + # Note: timedelta itself does not provide isoformat method. + # Both Pandas and java.time.Duration provide it, but the format + # is sightly different: + # java.time.Duration only applies HOURS, MINUTES, SECONDS units, + # while Pandas applies all supported units. 
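+ # e.g. timedelta(days=1, seconds=1) is rendered by Pandas as 'P1DT0H0M1S',
+ # while java.time.Duration would render it as 'PT24H1S'.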
+ return pd.Timedelta(delta).isoformat() # type: ignore[attr-defined] return f"{self._value}" diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 220ecd387f7ee..1972dd2804d98 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -19,12 +19,13 @@ from enum import Enum from itertools import chain import datetime +import unittest from pyspark.sql import Column, Row from pyspark.sql import functions as sf from pyspark.sql.types import StructType, StructField, IntegerType, LongType from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError -from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, pandas_requirement_message class ColumnTestsMixin: @@ -289,6 +290,26 @@ def test_lit_time_representation(self): ts = datetime.datetime(2021, 3, 4, 12, 34, 56, 1234) self.assertEqual(str(sf.lit(ts)), "Column<'2021-03-04 12:34:56.001234'>") + @unittest.skipIf(not have_pandas, pandas_requirement_message) + def test_lit_delta_representation(self): + for delta in [ + datetime.timedelta(days=1), + datetime.timedelta(hours=2), + datetime.timedelta(minutes=3), + datetime.timedelta(seconds=4), + datetime.timedelta(microseconds=5), + datetime.timedelta(days=2, hours=21, microseconds=908), + datetime.timedelta(days=1, minutes=-3, microseconds=-1001), + datetime.timedelta(days=1, hours=2, minutes=3, seconds=4, microseconds=5), + ]: + import pandas as pd + + # Column<'PT69H0.000908S'> or Column<'P2DT21H0M0.000908S'> + s = str(sf.lit(delta)) + + # Parse the ISO string representation and compare + self.assertTrue(pd.Timedelta(s[8:-2]).to_pytimedelta() == delta) + def test_enum_literals(self): class IntEnum(Enum): X = 1 From f0fb0c89ec29b587569d68a824c4ce7543721c06 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 19 Sep 2024 10:06:45 -0700 Subject: [PATCH 022/250] [SPARK-49719][SQL] Make `UUID` and `SHUFFLE` accept integer `seed` ### What changes were proposed in this pull request? Make `UUID` and `SHUFFLE` accept integer `seed` ### Why are the changes needed? In most cases, `seed` accept both int and long, but `UUID` and `SHUFFLE` only accept long seed ```py In [1]: spark.sql("SELECT RAND(1L), RAND(1), SHUFFLE(array(1, 20, 3, 5), 1L), UUID(1L)").show() +------------------+------------------+---------------------------+--------------------+ | rand(1)| rand(1)|shuffle(array(1, 20, 3, 5))| uuid()| +------------------+------------------+---------------------------+--------------------+ |0.6363787615254752|0.6363787615254752| [20, 1, 3, 5]|1ced31d7-59ef-4bb...| +------------------+------------------+---------------------------+--------------------+ In [2]: spark.sql("SELECT UUID(1)").show() ... AnalysisException: [INVALID_PARAMETER_VALUE.LONG] The value of parameter(s) `seed` in `UUID` is invalid: expects a long literal, but got "1". SQLSTATE: 22023; line 1 pos 7 ... In [3]: spark.sql("SELECT SHUFFLE(array(1, 20, 3, 5), 1)").show() ... AnalysisException: [INVALID_PARAMETER_VALUE.LONG] The value of parameter(s) `seed` in `shuffle` is invalid: expects a long literal, but got "1". SQLSTATE: 22023; line 1 pos 7 ... ``` ### Does this PR introduce _any_ user-facing change? 
yes after this fix: ```py In [2]: spark.sql("SELECT SHUFFLE(array(1, 20, 3, 5), 1L), SHUFFLE(array(1, 20, 3, 5), 1), UUID(1L), UUID(1)").show() +---------------------------+---------------------------+--------------------+--------------------+ |shuffle(array(1, 20, 3, 5))|shuffle(array(1, 20, 3, 5))| uuid()| uuid()| +---------------------------+---------------------------+--------------------+--------------------+ | [20, 1, 3, 5]| [20, 1, 3, 5]|1ced31d7-59ef-4bb...|1ced31d7-59ef-4bb...| +---------------------------+---------------------------+--------------------+--------------------+ ``` ### How was this patch tested? added tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48166 from zhengruifeng/int_seed. Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- .../sql/catalyst/expressions/randomExpressions.scala | 1 + .../catalyst/expressions/CollectionExpressionsSuite.scala | 8 ++++++++ .../sql/catalyst/expressions/MiscExpressionsSuite.scala | 7 +++++++ 3 files changed, 16 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index ea9ca451c2cb1..f329f8346b0de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -81,6 +81,7 @@ trait ExpressionWithRandomSeed extends Expression { private[catalyst] object ExpressionWithRandomSeed { def expressionToSeed(e: Expression, source: String): Option[Long] = e match { + case IntegerLiteral(seed) => Some(seed) case LongLiteral(seed) => Some(seed) case Literal(null, _) => None case _ => throw QueryCompilationErrors.invalidRandomSeedParameter(source, e) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index e9de59b3ec48c..55148978fa005 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -2293,6 +2293,14 @@ class CollectionExpressionsSuite evaluateWithMutableProjection(Shuffle(ai0, seed2))) assert(evaluateWithUnsafeProjection(Shuffle(ai0, seed1)) !== evaluateWithUnsafeProjection(Shuffle(ai0, seed2))) + + val seed3 = Literal.create(r.nextInt()) + assert(evaluateWithoutCodegen(new Shuffle(ai0, seed3)) === + evaluateWithoutCodegen(new Shuffle(ai0, seed3))) + assert(evaluateWithMutableProjection(new Shuffle(ai0, seed3)) === + evaluateWithMutableProjection(new Shuffle(ai0, seed3))) + assert(evaluateWithUnsafeProjection(new Shuffle(ai0, seed3)) === + evaluateWithUnsafeProjection(new Shuffle(ai0, seed3))) } test("Array Except") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala index 1f37886f44258..40e6fe1a90a63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscExpressionsSuite.scala @@ -71,6 +71,13 @@ class MiscExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper { evaluateWithMutableProjection(Uuid(seed2))) assert(evaluateWithUnsafeProjection(Uuid(seed1)) !== evaluateWithUnsafeProjection(Uuid(seed2))) + + val seed3 = Literal.create(r.nextInt()) + assert(evaluateWithoutCodegen(new Uuid(seed3)) === evaluateWithoutCodegen(new Uuid(seed3))) + assert(evaluateWithMutableProjection(new Uuid(seed3)) === + evaluateWithMutableProjection(new Uuid(seed3))) + assert(evaluateWithUnsafeProjection(new Uuid(seed3)) === + evaluateWithUnsafeProjection(new Uuid(seed3))) } test("PrintToStderr") { From 92cad2abd54e775259dc36d2f90242460d72a174 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Thu, 19 Sep 2024 10:09:36 -0700 Subject: [PATCH 023/250] [SPARK-49716][PS][DOCS][TESTS] Fix documentation and add test of barh plot MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? - Update the documentation for barh plot to clarify the difference between axis interpretation in Plotly and Matplotlib. - Test multiple columns as value axis. The parameter difference is demonstrated as below. ```py >>> df = ps.DataFrame({'lab': ['A', 'B', 'C'], 'val': [10, 30, 20]}) >>> df.plot.barh(x='val', y='lab').show() # plot1 >>> ps.set_option('plotting.backend', 'matplotlib') >>> import matplotlib.pyplot as plt >>> df.plot.barh(x='lab', y='val') >>> plt.show() # plot2 ``` plot1 ![newplot (5)](https://github.com/user-attachments/assets/f1b6fabe-9509-41bb-8cfb-0733f65f1643) plot2 ![Figure_1](https://github.com/user-attachments/assets/10e1b65f-6116-4490-9956-29e1fbf0c053) ### Why are the changes needed? The barh plot’s x and y axis behavior differs between Plotly and Matplotlib, which may confuse users. The updated documentation and tests help ensure clarity and prevent misinterpretation. ### Does this PR introduce _any_ user-facing change? No. Doc change only. ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48161 from xinrong-meng/ps_barh. Authored-by: Xinrong Meng Signed-off-by: Dongjoon Hyun --- python/pyspark/pandas/plot/core.py | 13 ++++++++++--- .../pandas/tests/plot/test_frame_plot_plotly.py | 5 +++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index 7630ecc398954..429e97ecf07bb 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -756,10 +756,10 @@ def barh(self, x=None, y=None, **kwargs): Parameters ---------- - x : label or position, default DataFrame.index - Column to be used for categories. - y : label or position, default All numeric columns in dataframe + x : label or position, default All numeric columns in dataframe Columns to be plotted from the DataFrame. + y : label or position, default DataFrame.index + Column to be used for categories. **kwds Keyword arguments to pass on to :meth:`pyspark.pandas.DataFrame.plot` or :meth:`pyspark.pandas.Series.plot`. @@ -770,6 +770,13 @@ def barh(self, x=None, y=None, **kwargs): Return an custom object when ``backend!=plotly``. Return an ndarray when ``subplots=True`` (matplotlib-only). + Notes + ----- + In Plotly and Matplotlib, the interpretation of `x` and `y` for `barh` plots differs. + In Plotly, `x` refers to the values and `y` refers to the categories. + In Matplotlib, `x` refers to the categories and `y` refers to the values. + Ensure correct axis labeling based on the backend used. 
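+ For example, `df.plot.barh(x='val', y='lab')` with the Plotly backend
+ corresponds to `df.plot.barh(x='lab', y='val')` with the Matplotlib backend.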
+ See Also -------- plotly.express.bar : Plot a vertical bar plot using plotly. diff --git a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py index 37469db2c8f51..8d197649aaebe 100644 --- a/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/pandas/tests/plot/test_frame_plot_plotly.py @@ -105,9 +105,10 @@ def check_barh_plot_with_x_y(pdf, psdf, x, y): self.assertEqual(pdf.plot.barh(x=x, y=y), psdf.plot.barh(x=x, y=y)) # this is testing plot with specified x and y - pdf1 = pd.DataFrame({"lab": ["A", "B", "C"], "val": [10, 30, 20]}) + pdf1 = pd.DataFrame({"lab": ["A", "B", "C"], "val": [10, 30, 20], "val2": [1.1, 2.2, 3.3]}) psdf1 = ps.from_pandas(pdf1) - check_barh_plot_with_x_y(pdf1, psdf1, x="lab", y="val") + check_barh_plot_with_x_y(pdf1, psdf1, x="val", y="lab") + check_barh_plot_with_x_y(pdf1, psdf1, x=["val", "val2"], y="lab") def test_barh_plot(self): def check_barh_plot(pdf, psdf): From 6d1815eceea2003de2e3602f0f64e8188e8288d8 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 19 Sep 2024 12:31:48 -0700 Subject: [PATCH 024/250] [SPARK-49718][PS] Switch `Scatter` plot to sampled data ### What changes were proposed in this pull request? Switch `Scatter` plot to sampled data ### Why are the changes needed? when the data distribution has relationship with the order, the first n rows will not be representative of the whole dataset for example: ``` import pandas as pd import numpy as np import pyspark.pandas as ps # ps.set_option("plotting.max_rows", 10000) np.random.seed(123) pdf = pd.DataFrame(np.random.randn(10000, 4), columns=list('ABCD')).sort_values("A") psdf = ps.DataFrame(pdf) psdf.plot.scatter(x='B', y='A') ``` all 10k datapoints: ![image](https://github.com/user-attachments/assets/72cf7e97-ad10-41e0-a8a6-351747d5285f) before (first 1k datapoints): ![image](https://github.com/user-attachments/assets/1ed50d2c-7772-4579-a84c-6062542d9367) after (sampled 1k datapoints): ![image](https://github.com/user-attachments/assets/6c684cba-4119-4c38-8228-2bedcdeb9e59) ### Does this PR introduce _any_ user-facing change? yes ### How was this patch tested? ci and manually test ### Was this patch authored or co-authored using generative AI tooling? no Closes #48164 from zhengruifeng/ps_scatter_sampling. Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- python/pyspark/pandas/plot/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index 429e97ecf07bb..6f036b7669246 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -479,7 +479,7 @@ class PandasOnSparkPlotAccessor(PandasObject): "pie": TopNPlotBase().get_top_n, "bar": TopNPlotBase().get_top_n, "barh": TopNPlotBase().get_top_n, - "scatter": TopNPlotBase().get_top_n, + "scatter": SampledPlotBase().get_sampled, "area": SampledPlotBase().get_sampled, "line": SampledPlotBase().get_sampled, } From 04455797bfb3631b13b41cfa5d2604db3bf8acc2 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 19 Sep 2024 12:32:30 -0700 Subject: [PATCH 025/250] [SPARK-49720][PYTHON][INFRA] Add a script to clean up PySpark temp files ### What changes were proposed in this pull request? Add a script to clean up PySpark temp files ### Why are the changes needed? Sometimes I encounter weird issues due to the out-dated `pyspark.zip` file, and removing it can result in expected behavior. So I think we can add such a script. 
### Does this PR introduce _any_ user-facing change? no, dev-only ### How was this patch tested? manually test ### Was this patch authored or co-authored using generative AI tooling? no Closes #48167 from zhengruifeng/py_infra_cleanup. Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- dev/py-cleanup | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100755 dev/py-cleanup diff --git a/dev/py-cleanup b/dev/py-cleanup new file mode 100755 index 0000000000000..6a2edd1040171 --- /dev/null +++ b/dev/py-cleanup @@ -0,0 +1,31 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility for temporary files cleanup in 'python'. +# usage: ./dev/py-cleanup + +set -ex + +SPARK_HOME="$(cd "`dirname $0`"/..; pwd)" +cd "$SPARK_HOME" + +rm -rf python/target +rm -rf python/lib/pyspark.zip +rm -rf python/docs/build +rm -rf python/docs/source/reference/*/api From ca726c10925a3677bf057f65ecf415e608c63cd5 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 19 Sep 2024 17:16:25 -0700 Subject: [PATCH 026/250] [SPARK-49721][BUILD] Upgrade `protobuf-java` to 3.25.5 ### What changes were proposed in this pull request? This PR aims to upgrade `protobuf-java` to 3.25.5. ### Why are the changes needed? To bring the latest bug fixes. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48170 Closes #48171 from dongjoon-hyun/SPARK-49721. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- pom.xml | 2 +- project/SparkBuild.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pom.xml b/pom.xml index 694ea31e6f377..ddabc82d2ad13 100644 --- a/pom.xml +++ b/pom.xml @@ -124,7 +124,7 @@ 3.4.0 - 3.25.4 + 3.25.5 3.11.4 ${hadoop.version} 3.9.2 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index d93a52985b772..2f390cb70baa8 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -89,7 +89,7 @@ object BuildCommons { // Google Protobuf version used for generating the protobuf. // SPARK-41247: needs to be consistent with `protobuf.version` in `pom.xml`. - val protoVersion = "3.25.4" + val protoVersion = "3.25.5" // GRPC version used for Spark Connect. val grpcVersion = "1.62.2" } From a5ac80af8e94afe56105c265a94d02ef878e1de9 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 20 Sep 2024 08:29:48 +0800 Subject: [PATCH 027/250] [SPARK-49713][PYTHON][CONNECT] Make function `count_min_sketch` accept number arguments ### What changes were proposed in this pull request? 
1, Make function `count_min_sketch` accept number arguments; 2, Make argument `seed` optional; 3, fix the type hints of `eps/confidence/seed` from `ColumnOrName` to `Column`, because they require a foldable value and actually do not accept column name: ``` In [3]: from pyspark.sql import functions as sf In [4]: df = spark.range(10000).withColumn("seed", sf.lit(1).cast("int")) In [5]: df.select(sf.hex(sf.count_min_sketch("id", sf.lit(0.5), sf.lit(0.5), "seed"))) ... AnalysisException: [DATATYPE_MISMATCH.NON_FOLDABLE_INPUT] Cannot resolve "count_min_sketch(id, 0.5, 0.5, seed)" due to data type mismatch: the input `seed` should be a foldable "INT" expression; however, got "seed". SQLSTATE: 42K09; 'Aggregate [unresolvedalias('hex(count_min_sketch(id#1L, 0.5, 0.5, seed#2, 0, 0)))] +- Project [id#1L, cast(1 as int) AS seed#2] +- Range (0, 10000, step=1, splits=Some(12)) ... ``` ### Why are the changes needed? 1, seed is optional in other similar functions; 2, existing type hint is `ColumnOrName` which is misleading since column name is not actually supported ### Does this PR introduce _any_ user-facing change? yes, it support number arguments ### How was this patch tested? updated doctests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48157 from zhengruifeng/py_fix_count_min_sketch. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .../pyspark/sql/connect/functions/builtin.py | 10 +-- python/pyspark/sql/functions/builtin.py | 71 +++++++++++++++---- .../org/apache/spark/sql/functions.scala | 12 ++++ 3 files changed, 77 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 2870d9c408b6b..7fed175cbc8ea 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -71,6 +71,7 @@ StringType, ) from pyspark.sql.utils import enum_to_value as _enum_to_value +from pyspark.util import JVM_INT_MAX # The implementation of pandas_udf is embedded in pyspark.sql.function.pandas_udf # for code reuse. @@ -1126,11 +1127,12 @@ def grouping_id(*cols: "ColumnOrName") -> Column: def count_min_sketch( col: "ColumnOrName", - eps: "ColumnOrName", - confidence: "ColumnOrName", - seed: "ColumnOrName", + eps: Union[Column, float], + confidence: Union[Column, float], + seed: Optional[Union[Column, int]] = None, ) -> Column: - return _invoke_function_over_columns("count_min_sketch", col, eps, confidence, seed) + _seed = lit(random.randint(0, JVM_INT_MAX)) if seed is None else lit(seed) + return _invoke_function_over_columns("count_min_sketch", col, lit(eps), lit(confidence), _seed) count_min_sketch.__doc__ = pysparkfuncs.count_min_sketch.__doc__ diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index c0730b193bc72..5f8d1c21a24f1 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -6015,9 +6015,9 @@ def grouping_id(*cols: "ColumnOrName") -> Column: @_try_remote_functions def count_min_sketch( col: "ColumnOrName", - eps: "ColumnOrName", - confidence: "ColumnOrName", - seed: "ColumnOrName", + eps: Union[Column, float], + confidence: Union[Column, float], + seed: Optional[Union[Column, int]] = None, ) -> Column: """ Returns a count-min sketch of a column with the given esp, confidence and seed. @@ -6031,13 +6031,24 @@ def count_min_sketch( ---------- col : :class:`~pyspark.sql.Column` or str target column to compute on. 
- eps : :class:`~pyspark.sql.Column` or str + eps : :class:`~pyspark.sql.Column` or float relative error, must be positive - confidence : :class:`~pyspark.sql.Column` or str + + .. versionchanged:: 4.0.0 + `eps` now accepts float value. + + confidence : :class:`~pyspark.sql.Column` or float confidence, must be positive and less than 1.0 - seed : :class:`~pyspark.sql.Column` or str + + .. versionchanged:: 4.0.0 + `confidence` now accepts float value. + + seed : :class:`~pyspark.sql.Column` or int, optional random seed + .. versionchanged:: 4.0.0 + `seed` now accepts int value. + Returns ------- :class:`~pyspark.sql.Column` @@ -6045,12 +6056,48 @@ def count_min_sketch( Examples -------- - >>> df = spark.createDataFrame([[1], [2], [1]], ['data']) - >>> df = df.agg(count_min_sketch(df.data, lit(0.5), lit(0.5), lit(1)).alias('sketch')) - >>> df.select(hex(df.sketch).alias('r')).collect() - [Row(r='0000000100000000000000030000000100000004000000005D8D6AB90000000000000000000000000000000200000000000000010000000000000000')] - """ - return _invoke_function_over_columns("count_min_sketch", col, eps, confidence, seed) + Example 1: Using columns as arguments + + >>> from pyspark.sql import functions as sf + >>> spark.range(100).select( + ... sf.hex(sf.count_min_sketch(sf.col("id"), sf.lit(3.0), sf.lit(0.1), sf.lit(1))) + ... ).show(truncate=False) + +------------------------------------------------------------------------+ + |hex(count_min_sketch(id, 3.0, 0.1, 1)) | + +------------------------------------------------------------------------+ + |0000000100000000000000640000000100000001000000005D8D6AB90000000000000064| + +------------------------------------------------------------------------+ + + Example 2: Using numbers as arguments + + >>> from pyspark.sql import functions as sf + >>> spark.range(100).select( + ... sf.hex(sf.count_min_sketch("id", 1.0, 0.3, 2)) + ... ).show(truncate=False) + +----------------------------------------------------------------------------------------+ + |hex(count_min_sketch(id, 1.0, 0.3, 2)) | + +----------------------------------------------------------------------------------------+ + |0000000100000000000000640000000100000002000000005D96391C00000000000000320000000000000032| + +----------------------------------------------------------------------------------------+ + + Example 3: Using a random seed + + >>> from pyspark.sql import functions as sf + >>> spark.range(100).select( + ... sf.hex(sf.count_min_sketch("id", sf.lit(1.5), 0.6)) + ... 
).show(truncate=False) # doctest: +SKIP + +----------------------------------------------------------------------------------------------------------------------------------------+ + |hex(count_min_sketch(id, 1.5, 0.6, 2120704260)) | + +----------------------------------------------------------------------------------------------------------------------------------------+ + |0000000100000000000000640000000200000002000000005ADECCEE00000000153EBE090000000000000033000000000000003100000000000000320000000000000032| + +----------------------------------------------------------------------------------------------------------------------------------------+ + """ # noqa: E501 + _eps = lit(eps) + _conf = lit(confidence) + if seed is None: + return _invoke_function_over_columns("count_min_sketch", col, _eps, _conf) + else: + return _invoke_function_over_columns("count_min_sketch", col, _eps, _conf, lit(seed)) @_try_remote_functions diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 02669270c8acf..0662b8f2b271f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -389,6 +389,18 @@ object functions { def count_min_sketch(e: Column, eps: Column, confidence: Column, seed: Column): Column = Column.fn("count_min_sketch", e, eps, confidence, seed) + /** + * Returns a count-min sketch of a column with the given esp, confidence and seed. The result is + * an array of bytes, which can be deserialized to a `CountMinSketch` before usage. Count-min + * sketch is a probabilistic data structure used for cardinality estimation using sub-linear + * space. + * + * @group agg_funcs + * @since 4.0.0 + */ + def count_min_sketch(e: Column, eps: Column, confidence: Column): Column = + count_min_sketch(e, eps, confidence, lit(SparkClassUtils.random.nextInt)) + private[spark] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column = Column.internalFn("collect_top_k", e, lit(num), lit(reverse)) From d4665fa1df716305acb49912d41c396b39343c93 Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Fri, 20 Sep 2024 14:11:14 +0900 Subject: [PATCH 028/250] [SPARK-49677][SS] Ensure that changelog files are written on commit and forceSnapshot flag is also reset ### What changes were proposed in this pull request? Ensure that changelog files are written on commit and forceSnapshot flag is also reset ### Why are the changes needed? Without these changes, we are not writing the changelog files per batch and we are also trying to upload full snapshot each time since the flag is not being reset correctly ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added unit tests Before: ``` [info] Run completed in 3 seconds, 438 milliseconds. [info] Total number of tests run: 1 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 0, failed 1, canceled 0, ignored 0, pending 0 [info] *** 1 TEST FAILED *** ``` After: ``` [info] Run completed in 4 seconds, 155 milliseconds. [info] Total number of tests run: 1 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 1, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #48125 from anishshri-db/task/SPARK-49677. 
Authored-by: Anish Shrigondekar Signed-off-by: Jungtaek Lim --- .../execution/streaming/state/RocksDB.scala | 16 ++++---- .../streaming/state/RocksDBSuite.scala | 41 +++++++++++++++++++ 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 81e80629092a0..4a2aac43b3331 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -646,15 +646,15 @@ class RocksDB( // is enabled. if (shouldForceSnapshot.get()) { uploadSnapshot() + shouldForceSnapshot.set(false) + } + + // ensure that changelog files are always written + try { + assert(changelogWriter.isDefined) + changelogWriter.foreach(_.commit()) + } finally { changelogWriter = None - changelogWriter.foreach(_.abort()) - } else { - try { - assert(changelogWriter.isDefined) - changelogWriter.foreach(_.commit()) - } finally { - changelogWriter = None - } } } else { assert(changelogWriter.isEmpty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 691f18451af22..608a22a284b6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -811,6 +811,47 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared } } + testWithChangelogCheckpointingEnabled("RocksDB: ensure that changelog files are written " + + "and snapshots uploaded optionally with changelog format v2") { + withTempDir { dir => + val remoteDir = Utils.createTempDir().toString + val conf = dbConf.copy(minDeltasForSnapshot = 5, compactOnCommit = false) + new File(remoteDir).delete() // to make sure that the directory gets created + withDB(remoteDir, conf = conf, useColumnFamilies = true) { db => + db.createColFamilyIfAbsent("test") + db.load(0) + db.put("a", "1") + db.put("b", "2") + db.commit() + assert(changelogVersionsPresent(remoteDir) == Seq(1)) + assert(snapshotVersionsPresent(remoteDir) == Seq(1)) + + db.load(1) + db.put("a", "3") + db.put("c", "4") + db.commit() + + assert(changelogVersionsPresent(remoteDir) == Seq(1, 2)) + assert(snapshotVersionsPresent(remoteDir) == Seq(1)) + + db.removeColFamilyIfExists("test") + db.load(2) + db.remove("a") + db.put("d", "5") + db.commit() + assert(changelogVersionsPresent(remoteDir) == Seq(1, 2, 3)) + assert(snapshotVersionsPresent(remoteDir) == Seq(1, 3)) + + db.load(3) + db.put("e", "6") + db.remove("b") + db.commit() + assert(changelogVersionsPresent(remoteDir) == Seq(1, 2, 3, 4)) + assert(snapshotVersionsPresent(remoteDir) == Seq(1, 3)) + } + } + } + test("RocksDB: ensure merge operation correctness") { withTempDir { dir => val remoteDir = Utils.createTempDir().toString From 6352c12f607bc092c33f1f29174d6699f8312380 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 20 Sep 2024 15:29:08 +0900 Subject: [PATCH 029/250] [MINOR][INFRA] Disable 'pages build and deployment' action ### What changes were proposed in this pull request? 
Disable https://github.com/apache/spark/actions/runs/10951008649/ via: > adding a .nojekyll file to the root of your source branch will bypass the Jekyll build process and deploy the content directly. https://docs.github.com/en/pages/quickstart ### Why are the changes needed? restore ci ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? no ### Was this patch authored or co-authored using generative AI tooling? no Closes #48176 from yaooqinn/action. Authored-by: Kent Yao Signed-off-by: Hyukjin Kwon --- .nojekyll | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 .nojekyll diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 0000000000000..e69de29bb2d1d From c009cd061c4923955a1e7ec9bf6c045f93d27ef7 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Fri, 20 Sep 2024 09:16:04 +0200 Subject: [PATCH 030/250] [SPARK-49392][SQL][FOLLOWUP] Catch errors when failing to write to external data source ### What changes were proposed in this pull request? Change `sqlState` to KD010. ### Why are the changes needed? Necessary modification for the Databricks error class space. ### Does this PR introduce _any_ user-facing change? Yes, the new error message is now updated to KD010. ### How was this patch tested? Existing tests (updated). ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48165 from uros-db/external-data-source-fix. Authored-by: Uros Bojanic Signed-off-by: Max Gekk --- common/utils/src/main/resources/error/error-conditions.json | 2 +- common/utils/src/main/resources/error/error-states.json | 2 +- .../apache/spark/sql/errors/QueryCompilationErrorsSuite.scala | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 72985de6631f0..e83202d9e5ee3 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1055,7 +1055,7 @@ "message" : [ "Encountered error when saving to external data source." ], - "sqlState" : "KD00F" + "sqlState" : "KD010" }, "DATA_SOURCE_NOT_EXIST" : { "message" : [ diff --git a/common/utils/src/main/resources/error/error-states.json b/common/utils/src/main/resources/error/error-states.json index edba6e1d43216..87811fef9836e 100644 --- a/common/utils/src/main/resources/error/error-states.json +++ b/common/utils/src/main/resources/error/error-states.json @@ -7417,7 +7417,7 @@ "standard": "N", "usedBy": ["Databricks"] }, - "KD00F": { + "KD010": { "description": "external data source failure", "origin": "Databricks", "standard": "N", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 370c118de9a93..832e1873af6a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -941,7 +941,7 @@ class QueryCompilationErrorsSuite cmd.run(spark) }, condition = "DATA_SOURCE_EXTERNAL_ERROR", - sqlState = "KD00F", + sqlState = "KD010", parameters = Map.empty ) } From b37863d2327131c670fe791576a907bcb5243cd6 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 20 Sep 2024 16:40:36 +0900 Subject: [PATCH 031/250] [MINOR][FOLLOWUP] Fix rat check for .nojekyll ### What changes were proposed in this pull request? 
Fix rat check for .nojekyll ### Why are the changes needed? CI fix ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? dev/check-license Ignored 1 lines in your exclusion files as comments or empty lines. RAT checks passed. ### Was this patch authored or co-authored using generative AI tooling? no Closes #48178 from yaooqinn/f. Authored-by: Kent Yao Signed-off-by: Hyukjin Kwon --- dev/.rat-excludes | 1 + 1 file changed, 1 insertion(+) diff --git a/dev/.rat-excludes b/dev/.rat-excludes index f38fd7e2012a5..b82cb7078c9f3 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -140,3 +140,4 @@ ui-test/package.json ui-test/package-lock.json core/src/main/resources/org/apache/spark/ui/static/package.json .*\.har +.nojekyll From 46b0210edb4ef8490ee4bbc4a40baf202a531b33 Mon Sep 17 00:00:00 2001 From: Nick Young Date: Fri, 20 Sep 2024 18:05:28 +0900 Subject: [PATCH 032/250] [SPARK-49699][SS] Disable PruneFilters for streaming workloads ### What changes were proposed in this pull request? The PR proposes to disable PruneFilters if the predicate of the filter is evaluated to `null` / `false` and the filter (and subtree) is streaming. ### Why are the changes needed? PruneFilters replaces the `null` / `false` filter with an empty relation, which means the subtree of the filter is also lost. The optimization does not care about whichever operator is in the subtree, hence some important operators like stateful operator, watermark node, observe node could be lost. The filter could be evaluated to `null` / `false` selectively among microbatches in various reasons (one simple example is the modification of the query during restart), which means stateful operator might not be available for batch N and be available for batch N + 1. For this case, streaming query will fail as batch N + 1 cannot load the state from batch N, and it's not recoverable in most cases. See new tests in StreamingQueryOptimizationCorrectnessSuite for details. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? UT. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48149 from n-young-db/n-young-db/disable-streaming-prune-filters. Lead-authored-by: Nick Young Co-authored-by: Jungtaek Lim Signed-off-by: Jungtaek Lim --- .../sql/catalyst/optimizer/Optimizer.scala | 7 +- .../apache/spark/sql/internal/SQLConf.scala | 9 +++ .../PropagateEmptyRelationSuite.scala | 27 ++++++-- .../optimizer/PruneFiltersSuite.scala | 34 ++++++++++ .../sql/execution/streaming/OffsetSeq.scala | 7 +- ...ingQueryOptimizationCorrectnessSuite.scala | 64 ++++++++++++++++++- 6 files changed, 137 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6ceeeb9bfdf38..8e14537c6a5b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1723,15 +1723,18 @@ object EliminateSorts extends Rule[LogicalPlan] { * 3) by eliminating the always-true conditions given the constraints on the child's output. 
*/ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { + private def shouldApply(child: LogicalPlan): Boolean = + SQLConf.get.getConf(SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN) || !child.isStreaming + def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsPattern(FILTER), ruleId) { // If the filter condition always evaluate to true, remove the filter. case Filter(Literal(true, BooleanType), child) => child // If the filter condition always evaluate to null or false, // replace the input with an empty relation. - case Filter(Literal(null, _), child) => + case Filter(Literal(null, _), child) if shouldApply(child) => LocalRelation(child.output, data = Seq.empty, isStreaming = child.isStreaming) - case Filter(Literal(false, BooleanType), child) => + case Filter(Literal(false, BooleanType), child) if shouldApply(child) => LocalRelation(child.output, data = Seq.empty, isStreaming = child.isStreaming) // If any deterministic condition is guaranteed to be true given the constraints on the child's // output, remove the condition diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 094fb8f050bc8..2eaafde52228b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3827,6 +3827,15 @@ object SQLConf { .intConf .createWithDefault(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) + val PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN = + buildConf("spark.databricks.sql.optimizer.pruneFiltersCanPruneStreamingSubplan") + .internal() + .doc("Allow PruneFilters to remove streaming subplans when we encounter a false filter. 
" + + "This flag is to restore prior buggy behavior for broken pipelines.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala index 5aeb27f7ee6b4..451236162343b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelationSuite.scala @@ -27,12 +27,13 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{Expand, Filter, LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegerType, MetadataBuilder} -class PropagateEmptyRelationSuite extends PlanTest { +class PropagateEmptyRelationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("PropagateEmptyRelation", Once, + Batch("PropagateEmptyRelation", FixedPoint(1), CombineUnions, ReplaceDistinctWithAggregate, ReplaceExceptWithAntiJoin, @@ -45,7 +46,7 @@ class PropagateEmptyRelationSuite extends PlanTest { object OptimizeWithoutPropagateEmptyRelation extends RuleExecutor[LogicalPlan] { val batches = - Batch("OptimizeWithoutPropagateEmptyRelation", Once, + Batch("OptimizeWithoutPropagateEmptyRelation", FixedPoint(1), CombineUnions, ReplaceDistinctWithAggregate, ReplaceExceptWithAntiJoin, @@ -216,10 +217,24 @@ class PropagateEmptyRelationSuite extends PlanTest { .where($"a" =!= 200) .orderBy($"a".asc) - val optimized = Optimize.execute(query.analyze) - val correctAnswer = LocalRelation(output, isStreaming = true) + withSQLConf( + SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true") { + val optimized = Optimize.execute(query.analyze) + val correctAnswer = LocalRelation(output, isStreaming = true) + comparePlans(optimized, correctAnswer) + } - comparePlans(optimized, correctAnswer) + withSQLConf( + SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "false") { + val optimized = Optimize.execute(query.analyze) + val correctAnswer = relation + .where(false) + .where($"a" > 1) + .select($"a") + .where($"a" =!= 200) + .orderBy($"a".asc).analyze + comparePlans(optimized, correctAnswer) + } } test("SPARK-47305 correctly tag isStreaming when propagating empty relation " + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index b81a57f4f8cd5..66ded338340f3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -174,4 +174,38 @@ class PruneFiltersSuite extends PlanTest { testRelation.where(!$"a".attr.in(1, 3, 5) && $"a".attr === 7 && $"b".attr === 1) .where(Rand(10) > 0.1 && Rand(10) < 1.1).analyze) } + + test("Streaming relation is not lost under true filter") { + Seq("true", "false").foreach(x => withSQLConf( + SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> x) { + val streamingRelation = + LocalRelation(Seq($"a".int, 
$"b".int, $"c".int), Nil, isStreaming = true) + val originalQuery = streamingRelation.where(10 > 5).select($"a").analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = streamingRelation.select($"a").analyze + comparePlans(optimized, correctAnswer) + }) + } + + test("Streaming relation is not lost under false filter") { + withSQLConf( + SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true") { + val streamingRelation = + LocalRelation(Seq($"a".int, $"b".int, $"c".int), Nil, isStreaming = true) + val originalQuery = streamingRelation.where(10 < 5).select($"a").analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = streamingRelation.select($"a").analyze + comparePlans(optimized, correctAnswer) + } + + withSQLConf( + SQLConf.PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "false") { + val streamingRelation = + LocalRelation(Seq($"a".int, $"b".int, $"c".int), Nil, isStreaming = true) + val originalQuery = streamingRelation.where(10 < 5).select($"a").analyze + val optimized = Optimize.execute(originalQuery) + val correctAnswer = streamingRelation.where(10 < 5).select($"a").analyze + comparePlans(optimized, correctAnswer) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index d5facc245e72f..e1e5b3a7ef88e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -101,7 +101,9 @@ object OffsetSeqMetadata extends Logging { SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY, FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION, STREAMING_JOIN_STATE_FORMAT_VERSION, STATE_STORE_COMPRESSION_CODEC, - STATE_STORE_ROCKSDB_FORMAT_VERSION, STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION) + STATE_STORE_ROCKSDB_FORMAT_VERSION, STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION, + PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN + ) /** * Default values of relevant configurations that are used for backward compatibility. 
@@ -122,7 +124,8 @@ object OffsetSeqMetadata extends Logging { STREAMING_JOIN_STATE_FORMAT_VERSION.key -> SymmetricHashJoinStateManager.legacyVersion.toString, STATE_STORE_COMPRESSION_CODEC.key -> CompressionCodec.LZ4, - STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "false" + STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION.key -> "false", + PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN.key -> "true" ) def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala index 782badaef924f..f651bfb7f3c72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryOptimizationCorrectnessSuite.scala @@ -21,7 +21,7 @@ import java.sql.Timestamp import org.apache.spark.sql.Row import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.functions.{expr, lit, window} +import org.apache.spark.sql.functions.{count, expr, lit, timestamp_seconds, window} import org.apache.spark.sql.internal.SQLConf /** @@ -524,4 +524,66 @@ class StreamingQueryOptimizationCorrectnessSuite extends StreamTest { doTest(numExpectedStatefulOperatorsForOneEmptySource = 1) } } + + test("SPARK-49699: observe node is not pruned out from PruneFilters") { + val input1 = MemoryStream[Int] + val df = input1.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .observe("observation", count(lit(1)).as("rows")) + // Enforce PruneFilters to come into play and prune subtree. We could do the same + // with the reproducer of SPARK-48267, but let's just be simpler. + .filter(expr("false")) + + testStream(df)( + AddData(input1, 1, 2, 3), + CheckNewAnswer(), + Execute { qe => + val observeRow = qe.lastExecution.observedMetrics.get("observation") + assert(observeRow.get.getAs[Long]("rows") == 3L) + } + ) + } + + test("SPARK-49699: watermark node is not pruned out from PruneFilters") { + // NOTE: The test actually passes without SPARK-49699, because of the trickiness of + // filter pushdown and PruneFilters. Unlike observe node, the `false` filter is pushed down + // below to watermark node, hence PruneFilters rule does not prune out watermark node even + // before SPARK-49699. Propagate empty relation does not also propagate emptiness into + // watermark node, so the node is retained. The test is added for preventing regression. + + val input1 = MemoryStream[Int] + val df = input1.toDF() + .withColumn("eventTime", timestamp_seconds($"value")) + .withWatermark("eventTime", "0 second") + // Enforce PruneFilter to come into play and prune subtree. We could do the same + // with the reproducer of SPARK-48267, but let's just be simpler. + .filter(expr("false")) + + testStream(df)( + AddData(input1, 1, 2, 3), + CheckNewAnswer(), + Execute { qe => + // If the watermark node is pruned out, this would be null. + assert(qe.lastProgress.eventTime.get("watermark") != null) + } + ) + } + + test("SPARK-49699: stateful operator node is not pruned out from PruneFilters") { + val input1 = MemoryStream[Int] + val df = input1.toDF() + .groupBy("value") + .count() + // Enforce PruneFilter to come into play and prune subtree. We could do the same + // with the reproducer of SPARK-48267, but let's just be simpler. 
+ .filter(expr("false")) + + testStream(df, OutputMode.Complete())( + AddData(input1, 1, 2, 3), + CheckNewAnswer(), + Execute { qe => + assert(qe.lastProgress.stateOperators.length == 1) + } + ) + } } From 4d97574425e603a7c6ac42a419747922bb1f83f9 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 20 Sep 2024 15:15:14 +0200 Subject: [PATCH 033/250] [SPARK-49733][SQL][DOCS] Delete `ExpressionInfo[between]` from `gen-sql-api-docs.py` to avoid duplication ### What changes were proposed in this pull request? The pr aims to delete `ExpressionInfo[between]` from `gen-sql-api-docs.py` to avoid duplication. ### Why are the changes needed? - In the following doc, `between` is repeatedly displayed `twice` https://spark.apache.org/docs/preview/api/sql/index.html#between image After the pr: image - After https://github.com/apache/spark/pull/44299, the expression 'between' has been added to `Spark 4.0`. ### Does this PR introduce _any_ user-facing change? Yes, only for docs. ### How was this patch tested? Manually check. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48183 from panbingkun/SPARK-49733. Authored-by: panbingkun Signed-off-by: Max Gekk --- .../spark/sql/catalyst/expressions/Between.scala | 2 +- sql/gen-sql-api-docs.py | 13 ------------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala index de1122da646b7..deec1ab51ad98 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.internal.SQLConf * lower - Lower bound of the between check. * upper - Upper bound of the between check. """, - since = "4.0.0", + since = "1.0.0", group = "conditional_funcs") case class Between private(input: Expression, lower: Expression, upper: Expression, replacement: Expression) extends RuntimeReplaceable with InheritAnalysisRules { diff --git a/sql/gen-sql-api-docs.py b/sql/gen-sql-api-docs.py index 17631a7352a02..3d19da01b3938 100644 --- a/sql/gen-sql-api-docs.py +++ b/sql/gen-sql-api-docs.py @@ -69,19 +69,6 @@ note="", since="1.0.0", deprecated=""), - ExpressionInfo( - className="", - name="between", - usage="expr1 [NOT] BETWEEN expr2 AND expr3 - " + - "evaluate if `expr1` is [not] in between `expr2` and `expr3`.", - arguments="", - examples="\n Examples:\n " + - "> SELECT col1 FROM VALUES 1, 3, 5, 7 WHERE col1 BETWEEN 2 AND 5;\n " + - " 3\n " + - " 5", - note="", - since="1.0.0", - deprecated=""), ExpressionInfo( className="", name="case", From bb8294c649909702e9086203b2726c6f51971c9c Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 20 Sep 2024 16:09:00 +0200 Subject: [PATCH 034/250] [SPARK-49729][SQL][DOCS] Forcefully check `usage` and correct the non-standard writing of 4 expressions ### What changes were proposed in this pull request? The pr aims to - forcefully check `usage` - correct the non-standard writing of 4 expressions (`shiftleft`, `shiftright`, `shiftrightunsigned`, `between`) ### Why are the changes needed? 
1.When some expressions have non-standard `usage` writing, corresponding explanations may be omitted in our documentation, such as `shiftleft` https://spark.apache.org/docs/preview/sql-ref-functions-builtin.html - Before (Note: It looks very weird to only appear in `examples` and not in the `Conditional Functions` catalog) image - After image 2.When there is an `non-standard` writing format, it fails directly in GA and can be corrected in a timely manner to avoid omissions. Refer to `Manually check` below. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Pass GA. - Manually check: ```python The usage of between is not standardized, please correct it. Refer to: `AesDecrypt` ------------------------------------------------ Jekyll 4.3.3 Please append `--trace` to the `build` command for any additional information or backtrace. ------------------------------------------------ /Users/panbingkun/Developer/spark/spark-community/docs/_plugins/build_api_docs.rb:184:in `build_sql_docs': SQL doc generation failed (RuntimeError) from /Users/panbingkun/Developer/spark/spark-community/docs/_plugins/build_api_docs.rb:225:in `' from :37:in `require' from :37:in `require' from /Users/panbingkun/Developer/spark/spark-community/docs/.local_ruby_bundle/ruby/3.3.0/gems/jekyll-4.3.3/lib/jekyll/external.rb:57:in `block in require_with_graceful_fail' from /Users/panbingkun/Developer/spark/spark-community/docs/.local_ruby_bundle/ruby/3.3.0/gems/jekyll-4.3.3/lib/jekyll/external.rb:55:in `each' from /Users/panbingkun/Developer/spark/spark-community/docs/.local_ruby_bundle/ruby/3.3.0/gems/jekyll-4.3.3/lib/jekyll/external.rb:55:in `require_with_graceful_fail' from /Users/panbingkun/Developer/spark/spark-community/docs/.local_ruby_bundle/ruby/3.3.0/gems/jekyll-4.3.3/lib/jekyll/plugin_manager.rb:96:in `block in require_plugin_files' ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48179 from panbingkun/SPARK-49729. 
Authored-by: panbingkun Signed-off-by: Max Gekk --- .../spark/sql/catalyst/expressions/Between.scala | 2 +- .../sql/catalyst/expressions/mathExpressions.scala | 6 +++--- sql/gen-sql-functions-docs.py | 12 +++++++++++- .../spark/sql/hive/execution/SQLQuerySuite.scala | 2 +- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala index deec1ab51ad98..c226e48c6be5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Between.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.internal.SQLConf // scalastyle:off line.size.limit @ExpressionDescription( - usage = "Usage: input [NOT] BETWEEN lower AND upper - evaluate if `input` is [not] in between `lower` and `upper`", + usage = "input [NOT] _FUNC_ lower AND upper - evaluate if `input` is [not] in between `lower` and `upper`", examples = """ Examples: > SELECT 0.5 _FUNC_ 0.1 AND 1.0; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 00274a16b888b..ddba820414ae4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1293,7 +1293,7 @@ sealed trait BitShiftOperation * @param right number of bits to left shift. */ @ExpressionDescription( - usage = "base << exp - Bitwise left shift.", + usage = "base _FUNC_ exp - Bitwise left shift.", examples = """ Examples: > SELECT shiftleft(2, 1); @@ -1322,7 +1322,7 @@ case class ShiftLeft(left: Expression, right: Expression) extends BitShiftOperat * @param right number of bits to right shift. */ @ExpressionDescription( - usage = "base >> expr - Bitwise (signed) right shift.", + usage = "base _FUNC_ expr - Bitwise (signed) right shift.", examples = """ Examples: > SELECT shiftright(4, 1); @@ -1350,7 +1350,7 @@ case class ShiftRight(left: Expression, right: Expression) extends BitShiftOpera * @param right the number of bits to right shift. */ @ExpressionDescription( - usage = "base >>> expr - Bitwise unsigned right shift.", + usage = "base _FUNC_ expr - Bitwise unsigned right shift.", examples = """ Examples: > SELECT shiftrightunsigned(4, 1); diff --git a/sql/gen-sql-functions-docs.py b/sql/gen-sql-functions-docs.py index bb813cffb0128..4be9966747d1f 100644 --- a/sql/gen-sql-functions-docs.py +++ b/sql/gen-sql-functions-docs.py @@ -39,6 +39,10 @@ } +def _print_red(text): + print('\033[31m' + text + '\033[0m') + + def _list_grouped_function_infos(jvm): """ Returns a list of function information grouped by each group value via JVM. @@ -126,7 +130,13 @@ def _make_pretty_usage(infos): func_name = "\\" + func_name elif (info.name == "when"): func_name = "CASE WHEN" - usages = iter(re.split(r"(.*%s.*) - " % func_name, info.usage.strip())[1:]) + expr_usages = re.split(r"(.*%s.*) - " % func_name, info.usage.strip()) + if len(expr_usages) <= 1: + _print_red("\nThe `usage` of %s is not standardized, please correct it. 
" + "Refer to: `AesDecrypt`" % (func_name)) + os._exit(-1) + usages = iter(expr_usages[1:]) + for (sig, description) in zip(usages, usages): result.append(" ") result.append(" %s" % sig) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 594c097de2c7d..14051034a588e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -246,7 +246,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi checkKeywordsExist(sql("describe function `between`"), "Function: between", - "Usage: input [NOT] BETWEEN lower AND upper - " + + "input [NOT] between lower AND upper - " + "evaluate if `input` is [not] in between `lower` and `upper`") checkKeywordsExist(sql("describe function `case`"), From 3d8c078ddefe3bb74fc78ffc9391a067156c8499 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 20 Sep 2024 08:44:14 -0700 Subject: [PATCH 035/250] [SPARK-49704][BUILD] Upgrade `commons-io` to 2.17.0 ### What changes were proposed in this pull request? This PR aims to upgrade `commons-io` from `2.16.1` to `2.17.0`. ### Why are the changes needed? The full release notes: https://commons.apache.org/proper/commons-io/changes-report.html#a2.17.0 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48154 from panbingkun/SPARK-49704. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 2 +- pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 9871cc0bca04f..419625f48fa11 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -44,7 +44,7 @@ commons-compiler/3.1.9//commons-compiler-3.1.9.jar commons-compress/1.27.1//commons-compress-1.27.1.jar commons-crypto/1.1.0//commons-crypto-1.1.0.jar commons-dbcp/1.4//commons-dbcp-1.4.jar -commons-io/2.16.1//commons-io-2.16.1.jar +commons-io/2.17.0//commons-io-2.17.0.jar commons-lang/2.6//commons-lang-2.6.jar commons-lang3/3.17.0//commons-lang3-3.17.0.jar commons-math3/3.6.1//commons-math3-3.6.1.jar diff --git a/pom.xml b/pom.xml index ddabc82d2ad13..b7c87beec0f92 100644 --- a/pom.xml +++ b/pom.xml @@ -187,7 +187,7 @@ 3.0.3 1.17.1 1.27.1 - 2.16.1 + 2.17.0 2.6 From 22a7edce0a7c70d6c1a5dcf995c6c723f0c3352b Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Fri, 20 Sep 2024 08:53:52 -0700 Subject: [PATCH 036/250] [SPARK-49531][PYTHON][CONNECT] Support line plot with plotly backend ### What changes were proposed in this pull request? Support line plot with plotly backend on both Spark Connect and Spark classic. ### Why are the changes needed? While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations, such as line plots, by leveraging libraries like Plotly. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments. See more at [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) in progress. 
Part of https://issues.apache.org/jira/browse/SPARK-49530. ### Does this PR introduce _any_ user-facing change? Yes. ```python >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] >>> columns = ["category", "int_val", "float_val"] >>> sdf = spark.createDataFrame(data, columns) >>> sdf.show() +--------+-------+---------+ |category|int_val|float_val| +--------+-------+---------+ | A| 10| 1.5| | B| 30| 2.5| | C| 20| 3.5| +--------+-------+---------+ >>> f = sdf.plot(kind="line", x="category", y="int_val") >>> f.show() # see below >>> g = sdf.plot.line(x="category", y=["int_val", "float_val"]) >>> g.show() # see below ``` `f.show()`: ![newplot](https://github.com/user-attachments/assets/ebd50bbc-0dd1-437f-ae0c-0b4de8f3c722) `g.show()`: ![newplot (1)](https://github.com/user-attachments/assets/46d28840-a147-428f-8d88-d424aa76ad06) ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48139 from xinrong-meng/plot_line_w_dep. Authored-by: Xinrong Meng Signed-off-by: Dongjoon Hyun --- .github/workflows/build_python_connect.yml | 2 +- dev/requirements.txt | 2 +- dev/sparktestsupport/modules.py | 4 + .../docs/source/getting_started/install.rst | 1 + python/packaging/classic/setup.py | 1 + python/packaging/connect/setup.py | 2 + python/pyspark/errors/error-conditions.json | 5 + python/pyspark/sql/classic/dataframe.py | 9 ++ python/pyspark/sql/connect/dataframe.py | 8 ++ python/pyspark/sql/dataframe.py | 28 ++++ python/pyspark/sql/plot/__init__.py | 21 +++ python/pyspark/sql/plot/core.py | 135 ++++++++++++++++++ python/pyspark/sql/plot/plotly.py | 30 ++++ .../tests/connect/test_parity_frame_plot.py | 36 +++++ .../connect/test_parity_frame_plot_plotly.py | 36 +++++ python/pyspark/sql/tests/plot/__init__.py | 16 +++ .../pyspark/sql/tests/plot/test_frame_plot.py | 80 +++++++++++ .../sql/tests/plot/test_frame_plot_plotly.py | 64 +++++++++ python/pyspark/sql/utils.py | 17 +++ python/pyspark/testing/sqlutils.py | 7 + .../apache/spark/sql/internal/SQLConf.scala | 27 ++++ 21 files changed, 529 insertions(+), 2 deletions(-) create mode 100644 python/pyspark/sql/plot/__init__.py create mode 100644 python/pyspark/sql/plot/core.py create mode 100644 python/pyspark/sql/plot/plotly.py create mode 100644 python/pyspark/sql/tests/connect/test_parity_frame_plot.py create mode 100644 python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py create mode 100644 python/pyspark/sql/tests/plot/__init__.py create mode 100644 python/pyspark/sql/tests/plot/test_frame_plot.py create mode 100644 python/pyspark/sql/tests/plot/test_frame_plot_plotly.py diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml index 3ac1a0117e41b..f668d813ef26e 100644 --- a/.github/workflows/build_python_connect.yml +++ b/.github/workflows/build_python_connect.yml @@ -71,7 +71,7 @@ jobs: python packaging/connect/setup.py sdist cd dist pip install pyspark*connect-*.tar.gz - pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting + pip install 'six==1.16.0' 'pandas<=2.2.2' scipy 'plotly>=4.8' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' torch torchvision torcheval deepspeed unittest-xml-reporting 'plotly>=4.8' - name: Run tests env: SPARK_TESTING: 1 diff --git 
a/dev/requirements.txt b/dev/requirements.txt index 5486c98ab8f8f..cafc73405aaa8 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -7,7 +7,7 @@ pyarrow>=10.0.0 six==1.16.0 pandas>=2.0.0 scipy -plotly +plotly>=4.8 mlflow>=2.3.1 scikit-learn matplotlib diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 34fbb8450d544..b9a4bed715f67 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -548,6 +548,8 @@ def __hash__(self): "pyspark.sql.tests.test_udtf", "pyspark.sql.tests.test_utils", "pyspark.sql.tests.test_resources", + "pyspark.sql.tests.plot.test_frame_plot", + "pyspark.sql.tests.plot.test_frame_plot_plotly", ], ) @@ -1051,6 +1053,8 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map", "pyspark.sql.tests.connect.test_parity_python_datasource", "pyspark.sql.tests.connect.test_parity_python_streaming_datasource", + "pyspark.sql.tests.connect.test_parity_frame_plot", + "pyspark.sql.tests.connect.test_parity_frame_plot_plotly", "pyspark.sql.tests.connect.test_utils", "pyspark.sql.tests.connect.client.test_artifact", "pyspark.sql.tests.connect.client.test_artifact_localcluster", diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst index 549656bea103e..88c0a8c26cc94 100644 --- a/python/docs/source/getting_started/install.rst +++ b/python/docs/source/getting_started/install.rst @@ -183,6 +183,7 @@ Package Supported version Note Additional libraries that enhance functionality but are not included in the installation packages: - **memory-profiler**: Used for PySpark UDF memory profiling, ``spark.profile.show(...)`` and ``spark.sql.pyspark.udf.profiler``. +- **plotly**: Used for PySpark plotting, ``DataFrame.plot``. Note that PySpark requires Java 17 or later with ``JAVA_HOME`` properly set and refer to |downloading|_. diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py index 79b74483f00dd..17cca326d0241 100755 --- a/python/packaging/classic/setup.py +++ b/python/packaging/classic/setup.py @@ -288,6 +288,7 @@ def run(self): "pyspark.sql.connect.streaming.worker", "pyspark.sql.functions", "pyspark.sql.pandas", + "pyspark.sql.plot", "pyspark.sql.protobuf", "pyspark.sql.streaming", "pyspark.sql.worker", diff --git a/python/packaging/connect/setup.py b/python/packaging/connect/setup.py index ab166c79747df..6ae16e9a9ad3a 100755 --- a/python/packaging/connect/setup.py +++ b/python/packaging/connect/setup.py @@ -77,6 +77,7 @@ "pyspark.sql.tests.connect.client", "pyspark.sql.tests.connect.shell", "pyspark.sql.tests.pandas", + "pyspark.sql.tests.plot", "pyspark.sql.tests.streaming", "pyspark.ml.tests.connect", "pyspark.pandas.tests", @@ -161,6 +162,7 @@ "pyspark.sql.connect.streaming.worker", "pyspark.sql.functions", "pyspark.sql.pandas", + "pyspark.sql.plot", "pyspark.sql.protobuf", "pyspark.sql.streaming", "pyspark.sql.worker", diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 4061d024a83cd..92aeb15e21d1b 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -1088,6 +1088,11 @@ "Function `` should use only POSITIONAL or POSITIONAL OR KEYWORD arguments." ] }, + "UNSUPPORTED_PLOT_BACKEND": { + "message": [ + "`` is not supported, it should be one of the values from " + ] + }, "UNSUPPORTED_SIGNATURE": { "message": [ "Unsupported signature: ." 
diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index 91b9591625904..a2778cbc32c4c 100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -73,6 +73,11 @@ from pyspark.sql.pandas.conversion import PandasConversionMixin from pyspark.sql.pandas.map_ops import PandasMapOpsMixin +try: + from pyspark.sql.plot import PySparkPlotAccessor +except ImportError: + PySparkPlotAccessor = None # type: ignore + if TYPE_CHECKING: from py4j.java_gateway import JavaObject import pyarrow as pa @@ -1862,6 +1867,10 @@ def executionInfo(self) -> Optional["ExecutionInfo"]: messageParameters={"member": "queryExecution"}, ) + @property + def plot(self) -> PySparkPlotAccessor: + return PySparkPlotAccessor(self) + class DataFrameNaFunctions(ParentDataFrameNaFunctions): def __init__(self, df: ParentDataFrame): diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 768abd655d497..59d79decf6690 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -86,6 +86,10 @@ from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema from pyspark.sql.pandas.functions import _validate_pandas_udf # type: ignore[attr-defined] +try: + from pyspark.sql.plot import PySparkPlotAccessor +except ImportError: + PySparkPlotAccessor = None # type: ignore if TYPE_CHECKING: from pyspark.sql.connect._typing import ( @@ -2239,6 +2243,10 @@ def rdd(self) -> "RDD[Row]": def executionInfo(self) -> Optional["ExecutionInfo"]: return self._execution_info + @property + def plot(self) -> PySparkPlotAccessor: + return PySparkPlotAccessor(self) + class DataFrameNaFunctions(ParentDataFrameNaFunctions): def __init__(self, df: ParentDataFrame): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index ef35b73332572..2179a844b1e5e 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -43,6 +43,7 @@ from pyspark.sql.types import StructType, Row from pyspark.sql.utils import dispatch_df_method + if TYPE_CHECKING: from py4j.java_gateway import JavaObject import pyarrow as pa @@ -65,6 +66,7 @@ ArrowMapIterFunction, DataFrameLike as PandasDataFrameLike, ) + from pyspark.sql.plot import PySparkPlotAccessor from pyspark.sql.metrics import ExecutionInfo @@ -6394,6 +6396,32 @@ def executionInfo(self) -> Optional["ExecutionInfo"]: """ ... + @property + def plot(self) -> "PySparkPlotAccessor": + """ + Returns a :class:`PySparkPlotAccessor` for plotting functions. + + .. versionadded:: 4.0.0 + + Returns + ------- + :class:`PySparkPlotAccessor` + + Notes + ----- + This API is experimental. + + Examples + -------- + >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + >>> columns = ["category", "int_val", "float_val"] + >>> df = spark.createDataFrame(data, columns) + >>> type(df.plot) + + >>> df.plot.line(x="category", y=["int_val", "float_val"]) # doctest: +SKIP + """ + ... + class DataFrameNaFunctions: """Functionality for working with missing data in :class:`DataFrame`. diff --git a/python/pyspark/sql/plot/__init__.py b/python/pyspark/sql/plot/__init__.py new file mode 100644 index 0000000000000..6da07061b2a09 --- /dev/null +++ b/python/pyspark/sql/plot/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This package includes the plotting APIs for PySpark DataFrame. +""" +from pyspark.sql.plot.core import * # noqa: F403, F401 diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py new file mode 100644 index 0000000000000..392ef73b38845 --- /dev/null +++ b/python/pyspark/sql/plot/core.py @@ -0,0 +1,135 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, TYPE_CHECKING, Optional, Union +from types import ModuleType +from pyspark.errors import PySparkRuntimeError, PySparkValueError +from pyspark.sql.utils import require_minimum_plotly_version + + +if TYPE_CHECKING: + from pyspark.sql import DataFrame + import pandas as pd + from plotly.graph_objs import Figure + + +class PySparkTopNPlotBase: + def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame": + from pyspark.sql import SparkSession + + session = SparkSession.getActiveSession() + if session is None: + raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict()) + + max_rows = int( + session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type] + ) + pdf = sdf.limit(max_rows + 1).toPandas() + + self.partial = False + if len(pdf) > max_rows: + self.partial = True + pdf = pdf.iloc[:max_rows] + + return pdf + + +class PySparkSampledPlotBase: + def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame": + from pyspark.sql import SparkSession + + session = SparkSession.getActiveSession() + if session is None: + raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict()) + + sample_ratio = session.conf.get("spark.sql.pyspark.plotting.sample_ratio") + max_rows = int( + session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type] + ) + + if sample_ratio is None: + fraction = 1 / (sdf.count() / max_rows) + fraction = min(1.0, fraction) + else: + fraction = float(sample_ratio) + + sampled_sdf = sdf.sample(fraction=fraction) + pdf = sampled_sdf.toPandas() + + return pdf + + +class PySparkPlotAccessor: + plot_data_map = { + "line": PySparkSampledPlotBase().get_sampled, + } + _backends = {} # type: ignore[var-annotated] + + def __init__(self, data: "DataFrame"): + self.data = data + + def __call__( + self, kind: 
str = "line", backend: Optional[str] = None, **kwargs: Any + ) -> "Figure": + plot_backend = PySparkPlotAccessor._get_plot_backend(backend) + + return plot_backend.plot_pyspark(self.data, kind=kind, **kwargs) + + @staticmethod + def _get_plot_backend(backend: Optional[str] = None) -> ModuleType: + backend = backend or "plotly" + + if backend in PySparkPlotAccessor._backends: + return PySparkPlotAccessor._backends[backend] + + if backend == "plotly": + require_minimum_plotly_version() + else: + raise PySparkValueError( + errorClass="UNSUPPORTED_PLOT_BACKEND", + messageParameters={"backend": backend, "supported_backends": ", ".join(["plotly"])}, + ) + from pyspark.sql.plot import plotly as module + + return module + + def line(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure": + """ + Plot DataFrame as lines. + + Parameters + ---------- + x : str + Name of column to use for the horizontal axis. + y : str or list of str + Name(s) of the column(s) to use for the vertical axis. Multiple columns can be plotted. + **kwargs : optional + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + >>> columns = ["category", "int_val", "float_val"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.line(x="category", y="int_val") # doctest: +SKIP + >>> df.plot.line(x="category", y=["int_val", "float_val"]) # doctest: +SKIP + """ + return self(kind="line", x=x, y=y, **kwargs) diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py new file mode 100644 index 0000000000000..5efc19476057f --- /dev/null +++ b/python/pyspark/sql/plot/plotly.py @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import TYPE_CHECKING, Any + +from pyspark.sql.plot import PySparkPlotAccessor + +if TYPE_CHECKING: + from pyspark.sql import DataFrame + from plotly.graph_objs import Figure + + +def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure": + import plotly + + return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs) diff --git a/python/pyspark/sql/tests/connect/test_parity_frame_plot.py b/python/pyspark/sql/tests/connect/test_parity_frame_plot.py new file mode 100644 index 0000000000000..c69e438bf7eb0 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_frame_plot.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.sql.tests.plot.test_frame_plot import DataFramePlotTestsMixin + + +class FramePlotParityTests(DataFramePlotTestsMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_frame_plot import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py b/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py new file mode 100644 index 0000000000000..78508fe533379 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_frame_plot_plotly.py @@ -0,0 +1,36 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.sql.tests.plot.test_frame_plot_plotly import DataFramePlotPlotlyTestsMixin + + +class FramePlotPlotlyParityTests(DataFramePlotPlotlyTestsMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.connect.test_parity_frame_plot_plotly import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/plot/__init__.py b/python/pyspark/sql/tests/plot/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/sql/tests/plot/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/sql/tests/plot/test_frame_plot.py b/python/pyspark/sql/tests/plot/test_frame_plot.py new file mode 100644 index 0000000000000..f753b5ab3db72 --- /dev/null +++ b/python/pyspark/sql/tests/plot/test_frame_plot.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +from pyspark.errors import PySparkValueError +from pyspark.sql import Row +from pyspark.sql.plot import PySparkSampledPlotBase, PySparkTopNPlotBase +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message + + +@unittest.skipIf(not have_plotly, plotly_requirement_message) +class DataFramePlotTestsMixin: + def test_backend(self): + accessor = self.spark.range(2).plot + backend = accessor._get_plot_backend() + self.assertEqual(backend.__name__, "pyspark.sql.plot.plotly") + + with self.assertRaises(PySparkValueError) as pe: + accessor._get_plot_backend("matplotlib") + + self.check_error( + exception=pe.exception, + errorClass="UNSUPPORTED_PLOT_BACKEND", + messageParameters={"backend": "matplotlib", "supported_backends": "plotly"}, + ) + + def test_topn_max_rows(self): + try: + self.spark.conf.set("spark.sql.pyspark.plotting.max_rows", "1000") + sdf = self.spark.range(2500) + pdf = PySparkTopNPlotBase().get_top_n(sdf) + self.assertEqual(len(pdf), 1000) + finally: + self.spark.conf.unset("spark.sql.pyspark.plotting.max_rows") + + def test_sampled_plot_with_ratio(self): + try: + self.spark.conf.set("spark.sql.pyspark.plotting.sample_ratio", "0.5") + data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2500)] + sdf = self.spark.createDataFrame(data) + pdf = PySparkSampledPlotBase().get_sampled(sdf) + self.assertEqual(round(len(pdf) / 2500, 1), 0.5) + finally: + self.spark.conf.unset("spark.sql.pyspark.plotting.sample_ratio") + + def test_sampled_plot_with_max_rows(self): + data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2000)] + sdf = self.spark.createDataFrame(data) + pdf = PySparkSampledPlotBase().get_sampled(sdf) + self.assertEqual(round(len(pdf) / 2000, 1), 0.5) + + +class DataFramePlotTests(DataFramePlotTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.plot.test_frame_plot import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + 
testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py new file mode 100644 index 0000000000000..72a3ed267d192 --- /dev/null +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -0,0 +1,64 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import pyspark.sql.plot # noqa: F401 +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message + + +@unittest.skipIf(not have_plotly, plotly_requirement_message) +class DataFramePlotPlotlyTestsMixin: + @property + def sdf(self): + data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + columns = ["category", "int_val", "float_val"] + return self.spark.createDataFrame(data, columns) + + def _check_fig_data(self, fig_data, expected_x, expected_y, expected_name=""): + self.assertEqual(fig_data["mode"], "lines") + self.assertEqual(fig_data["type"], "scatter") + self.assertEqual(fig_data["xaxis"], "x") + self.assertEqual(list(fig_data["x"]), expected_x) + self.assertEqual(fig_data["yaxis"], "y") + self.assertEqual(list(fig_data["y"]), expected_y) + self.assertEqual(fig_data["name"], expected_name) + + def test_line_plot(self): + # single column as vertical axis + fig = self.sdf.plot(kind="line", x="category", y="int_val") + self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20]) + + # multiple columns as vertical axis + fig = self.sdf.plot.line(x="category", y=["int_val", "float_val"]) + self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val") + self._check_fig_data(fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val") + + +class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.plot.test_frame_plot_plotly import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index 11b91612419a3..5d9ec92cbc830 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -41,6 +41,7 @@ PythonException, UnknownException, SparkUpgradeException, + PySparkImportError, PySparkNotImplementedError, PySparkRuntimeError, ) @@ -115,6 +116,22 @@ def require_test_compiled() -> None: ) +def require_minimum_plotly_version() -> None: + """Raise ImportError if plotly is not installed""" + minimum_plotly_version = "4.8" + + try: + import plotly # noqa: F401 + except ImportError as error: + raise PySparkImportError( + errorClass="PACKAGE_NOT_INSTALLED", + 
messageParameters={ + "package_name": "plotly", + "minimum_version": str(minimum_plotly_version), + }, + ) from error + + class ForeachBatchFunction: """ This is the Python implementation of Java interface 'ForeachBatchFunction'. This wraps diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py index 9f07c44c084cf..00ad40e68bd7c 100644 --- a/python/pyspark/testing/sqlutils.py +++ b/python/pyspark/testing/sqlutils.py @@ -48,6 +48,13 @@ except Exception as e: test_not_compiled_message = str(e) +plotly_requirement_message = None +try: + import plotly +except ImportError as e: + plotly_requirement_message = str(e) +have_plotly = plotly_requirement_message is None + from pyspark.sql import SparkSession from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType, Row from pyspark.testing.utils import ReusedPySparkTestCase, PySparkErrorTestUtils diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2eaafde52228b..6c3e9bac1cfe5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3169,6 +3169,29 @@ object SQLConf { .version("4.0.0") .fallbackConf(Python.PYTHON_WORKER_FAULTHANLDER_ENABLED) + val PYSPARK_PLOT_MAX_ROWS = + buildConf("spark.sql.pyspark.plotting.max_rows") + .doc( + "The visual limit on top-n-based plots. If set to 1000, the first 1000 data points " + + "will be used for plotting.") + .version("4.0.0") + .intConf + .createWithDefault(1000) + + val PYSPARK_PLOT_SAMPLE_RATIO = + buildConf("spark.sql.pyspark.plotting.sample_ratio") + .doc( + "The proportion of data that will be plotted for sample-based plots. It is determined " + + "based on spark.sql.pyspark.plotting.max_rows if not explicitly set." + ) + .version("4.0.0") + .doubleConf + .checkValue( + ratio => ratio >= 0.0 && ratio <= 1.0, + "The value should be between 0.0 and 1.0 inclusive." + ) + .createOptional + val ARROW_SPARKR_EXECUTION_ENABLED = buildConf("spark.sql.execution.arrow.sparkr.enabled") .doc("When true, make use of Apache Arrow for columnar data transfers in SparkR. " + @@ -5873,6 +5896,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def pythonUDFWorkerFaulthandlerEnabled: Boolean = getConf(PYTHON_UDF_WORKER_FAULTHANLDER_ENABLED) + def pysparkPlotMaxRows: Int = getConf(PYSPARK_PLOT_MAX_ROWS) + + def pysparkPlotSampleRatio: Option[Double] = getConf(PYSPARK_PLOT_SAMPLE_RATIO) + def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED) def arrowPySparkFallbackEnabled: Boolean = getConf(ARROW_PYSPARK_FALLBACK_ENABLED) From f3785fadec3089fa60d85fa3c98ae9c6ada807a4 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Fri, 20 Sep 2024 19:12:05 +0200 Subject: [PATCH 037/250] [SPARK-49737][SQL] Disable bucketing on collated columns in complex types ### What changes were proposed in this pull request? To disable bucketing on collated string types in complex types (structs, arrays and maps). ### Why are the changes needed? #45260 introduces the logic to disabled bucketing for collated columns, but forgot to address complex types which have collated strings inside. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48186 from stefankandic/fixBucketing. 
Authored-by: Stefan Kandic Signed-off-by: Max Gekk --- .../datasources/BucketingUtils.scala | 8 +++---- .../org/apache/spark/sql/CollationSuite.scala | 23 ++++++++++++++----- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala index 4fa1e0c1f2c58..fd47feef25d57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning -import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.util.SchemaUtils object BucketingUtils { // The file name of bucketed data should have 3 parts: @@ -53,10 +54,7 @@ object BucketingUtils { bucketIdGenerator(mutableInternalRow).getInt(0) } - def canBucketOn(dataType: DataType): Boolean = dataType match { - case st: StringType => st.supportsBinaryOrdering - case other => true - } + def canBucketOn(dataType: DataType): Boolean = !SchemaUtils.hasNonUTF8BinaryCollation(dataType) def bucketIdToString(id: Int): String = f"_$id%05d" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 73fd897e91f53..632b9305feb57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -162,9 +162,14 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { withTable(tableName) { sql( s""" - |CREATE TABLE $tableName - |(id INT, c1 STRING COLLATE UNICODE, c2 string) - |USING parquet + |CREATE TABLE $tableName ( + | id INT, + | c1 STRING COLLATE UNICODE, + | c2 STRING, + | struct_col STRUCT, + | array_col ARRAY, + | map_col MAP + |) USING parquet |CLUSTERED BY (${bucketColumns.mkString(",")}) |INTO 4 BUCKETS""".stripMargin ) @@ -175,14 +180,20 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { createTable("c2") createTable("id", "c2") - Seq(Seq("c1"), Seq("c1", "id"), Seq("c1", "c2")).foreach { bucketColumns => + val failBucketingColumns = Seq( + Seq("c1"), Seq("c1", "id"), Seq("c1", "c2"), + Seq("struct_col"), Seq("array_col"), Seq("map_col") + ) + + failBucketingColumns.foreach { bucketColumns => checkError( exception = intercept[AnalysisException] { createTable(bucketColumns: _*) }, condition = "INVALID_BUCKET_COLUMN_DATA_TYPE", - parameters = Map("type" -> "\"STRING COLLATE UNICODE\"") - ); + parameters = Map("type" -> ".*STRING COLLATE UNICODE.*"), + matchPVals = true + ) } } From f76a9b1135e748649bdb9a2104360f0dc533cc1f Mon Sep 17 00:00:00 2001 From: viktorluc-db Date: Fri, 20 Sep 2024 22:47:30 +0200 Subject: [PATCH 038/250] [SPARK-49738][SQL] Endswith bug fix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Bugfix in "endswith" string predicate. Also fixed the same type of the bug in `CollationAwareUTF8String.java` in method `lowercaseMatchLengthFrom`. ### Why are the changes needed? 
Expression `select endswith('İo' collate utf8_lcase, 'İo' collate utf8_lcase)` returns `false` but should return `true`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added tests in CollationSupportSuite. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48187 from viktorluc-db/matchBugFix. Authored-by: viktorluc-db Signed-off-by: Max Gekk --- .../spark/sql/catalyst/util/CollationAwareUTF8String.java | 4 ++-- .../org/apache/spark/unsafe/types/CollationSupportSuite.java | 4 ++++ .../src/test/resources/sql-tests/results/collations.sql.out | 4 ++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index 5ed3048fb72b3..fb610a5d96f17 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -109,7 +109,7 @@ private static int lowercaseMatchLengthFrom( } // Compare the characters in the target and pattern strings. int matchLength = 0, codePointBuffer = -1, targetCodePoint, patternCodePoint; - while (targetIterator.hasNext() && patternIterator.hasNext()) { + while ((targetIterator.hasNext() || codePointBuffer != -1) && patternIterator.hasNext()) { if (codePointBuffer != -1) { targetCodePoint = codePointBuffer; codePointBuffer = -1; @@ -211,7 +211,7 @@ private static int lowercaseMatchLengthUntil( } // Compare the characters in the target and pattern strings. int matchLength = 0, codePointBuffer = -1, targetCodePoint, patternCodePoint; - while (targetIterator.hasNext() && patternIterator.hasNext()) { + while ((targetIterator.hasNext() || codePointBuffer != -1) && patternIterator.hasNext()) { if (codePointBuffer != -1) { targetCodePoint = codePointBuffer; codePointBuffer = -1; diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 5719303a0dce8..a445cde52ad57 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -629,6 +629,8 @@ public void testStartsWith() throws SparkException { assertStartsWith("İonic", "Io", "UTF8_LCASE", false); assertStartsWith("İonic", "i\u0307o", "UTF8_LCASE", true); assertStartsWith("İonic", "İo", "UTF8_LCASE", true); + assertStartsWith("oİ", "oİ", "UTF8_LCASE", true); + assertStartsWith("oİ", "oi̇", "UTF8_LCASE", true); // Conditional case mapping (e.g. Greek sigmas). assertStartsWith("σ", "σ", "UTF8_BINARY", true); assertStartsWith("σ", "ς", "UTF8_BINARY", false); @@ -880,6 +882,8 @@ public void testEndsWith() throws SparkException { assertEndsWith("the İo", "Io", "UTF8_LCASE", false); assertEndsWith("the İo", "i\u0307o", "UTF8_LCASE", true); assertEndsWith("the İo", "İo", "UTF8_LCASE", true); + assertEndsWith("İo", "İo", "UTF8_LCASE", true); + assertEndsWith("İo", "i̇o", "UTF8_LCASE", true); // Conditional case mapping (e.g. Greek sigmas). 
assertEndsWith("σ", "σ", "UTF8_BINARY", true); assertEndsWith("σ", "ς", "UTF8_BINARY", false); diff --git a/sql/core/src/test/resources/sql-tests/results/collations.sql.out b/sql/core/src/test/resources/sql-tests/results/collations.sql.out index 5999bf20f6884..9d29a46e5a0ef 100644 --- a/sql/core/src/test/resources/sql-tests/results/collations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/collations.sql.out @@ -2213,8 +2213,8 @@ struct Date: Fri, 20 Sep 2024 15:34:17 -0700 Subject: [PATCH 039/250] [SPARK-49557][SQL] Add SQL pipe syntax for the WHERE operator ### What changes were proposed in this pull request? This PR adds SQL pipe syntax support for the WHERE operator. For example: ``` CREATE TABLE t(x INT, y STRING) USING CSV; INSERT INTO t VALUES (0, 'abc'), (1, 'def'); CREATE TABLE other(a INT, b INT) USING JSON; INSERT INTO other VALUES (1, 1), (1, 2), (2, 4); TABLE t |> WHERE x + LENGTH(y) < 4; 0 abc TABLE t |> WHERE (SELECT ANY_VALUE(a) FROM other WHERE x = a LIMIT 1) = 1 1 def TABLE t |> WHERE SUM(x) = 1 Error: aggregate functions are not allowed in the pipe operator |> WHERE clause ``` ### Why are the changes needed? The SQL pipe operator syntax will let users compose queries in a more flexible fashion. ### Does this PR introduce _any_ user-facing change? Yes, see above. ### How was this patch tested? This PR adds a few unit test cases, but mostly relies on golden file test coverage. I did this to make sure the answers are correct as this feature is implemented and also so we can look at the analyzer output plans to ensure they look right as well. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48091 from dtenedor/pipe-where. Authored-by: Daniel Tenedorio Signed-off-by: Gengliang Wang --- .../sql/catalyst/parser/SqlBaseParser.g4 | 1 + .../sql/catalyst/parser/AstBuilder.scala | 15 +- .../analyzer-results/pipe-operators.sql.out | 272 ++++++++++++++++++ .../sql-tests/inputs/pipe-operators.sql | 94 +++++- .../sql-tests/results/pipe-operators.sql.out | 268 +++++++++++++++++ .../sql/execution/SparkSqlParserSuite.scala | 12 +- 6 files changed, 658 insertions(+), 4 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index e591a43b84d1a..094f7f5315b80 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1492,6 +1492,7 @@ version operatorPipeRightSide : selectClause + | whereClause ; // When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 52529bb4b789b..674005caaf1b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5876,7 +5876,20 @@ class AstBuilder extends DataTypeAstBuilder windowClause = null, relation = left, isPipeOperatorSelect = true) - }.get + }.getOrElse(Option(ctx.whereClause).map { c => + // Add a table subquery boundary between the new filter and the input plan if one does not + // already exist. This helps the analyzer behave as if we had added the WHERE clause after a + // table subquery containing the input plan. 
+ val withSubqueryAlias = left match { + case s: SubqueryAlias => + s + case u: UnresolvedRelation => + u + case _ => + SubqueryAlias(SubqueryAlias.generateSubqueryName(), left) + } + withWhereClause(c, withSubqueryAlias) + }.get) } /** diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out index ab0635fef048b..c44ce153a2f41 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out @@ -255,6 +255,55 @@ Distinct +- Relation spark_catalog.default.t[x#x,y#x] csv +-- !query +table t +|> select * +-- !query analysis +Project [x#x, y#x] ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select * except (y) +-- !query analysis +Project [x#x] ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select /*+ repartition(3) */ * +-- !query analysis +Repartition 3, true ++- Project [x#x, y#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select /*+ repartition(3) */ distinct x +-- !query analysis +Repartition 3, true ++- Distinct + +- Project [x#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select /*+ repartition(3) */ all x +-- !query analysis +Repartition 3, true ++- Project [x#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + -- !query table t |> select sum(x) as result @@ -297,6 +346,229 @@ org.apache.spark.sql.AnalysisException } +-- !query +table t +|> where true +-- !query analysis +Filter true ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where x + length(y) < 4 +-- !query analysis +Filter ((x#x + length(y#x)) < 4) ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where x + length(y) < 4 +|> where x + length(y) < 3 +-- !query analysis +Filter ((x#x + length(y#x)) < 3) ++- SubqueryAlias __auto_generated_subquery_name + +- Filter ((x#x + length(y#x)) < 4) + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where x = 1 +-- !query analysis +Filter (x#x = 1) ++- SubqueryAlias __auto_generated_subquery_name + +- Aggregate [x#x], [x#x, sum(length(y#x)) AS sum_len#xL] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where t.x = 1 +-- !query analysis +Filter (x#x = 1) ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where spark_catalog.default.t.x = 1 +-- !query analysis +Filter (x#x = 1) ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select col from st) +|> where col.i1 = 1 +-- !query analysis +Filter (col#x.i1 = 1) ++- SubqueryAlias __auto_generated_subquery_name + +- Project [col#x] + +- SubqueryAlias spark_catalog.default.st + +- Relation spark_catalog.default.st[x#x,col#x] parquet + + +-- !query +table st +|> where st.col.i1 = 2 +-- !query analysis +Filter (col#x.i1 = 
2) ++- SubqueryAlias spark_catalog.default.st + +- Relation spark_catalog.default.st[x#x,col#x] parquet + + +-- !query +table t +|> where exists (select a from other where x = a limit 1) +-- !query analysis +Filter exists#x [x#x] +: +- GlobalLimit 1 +: +- LocalLimit 1 +: +- Project [a#x] +: +- Filter (outer(x#x) = a#x) +: +- SubqueryAlias spark_catalog.default.other +: +- Relation spark_catalog.default.other[a#x,b#x] json ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where (select any_value(a) from other where x = a limit 1) = 1 +-- !query analysis +Filter (scalar-subquery#x [x#x] = 1) +: +- GlobalLimit 1 +: +- LocalLimit 1 +: +- Aggregate [any_value(a#x, false) AS any_value(a)#x] +: +- Filter (outer(x#x) = a#x) +: +- SubqueryAlias spark_catalog.default.other +: +- Relation spark_catalog.default.other[a#x,b#x] json ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> where sum(x) = 1 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"(sum(x) = 1)\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 27, + "fragment" : "table t\n|> where sum(x) = 1" + } ] +} + + +-- !query +table t +|> where y = 'abc' or length(y) + sum(x) = 1 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"((y = abc) OR ((length(y) + sum(x)) = 1))\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 52, + "fragment" : "table t\n|> where y = 'abc' or length(y) + sum(x) = 1" + } ] +} + + +-- !query +table t +|> where first_value(x) over (partition by y) = 1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +select * from t where first_value(x) over (partition by y) = 1 +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +table t +|> select x, length(y) as z +|> where x + length(y) < 4 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `z`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 57, + "stopIndex" : 57, + "fragment" : "y" + } ] +} + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where sum(length(y)) = 3 +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `sum_len`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 77, + "stopIndex" : 77, + "fragment" : "y" + } ] +} + + -- !query drop table t -- !query analysis diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql index 7d0966e7f2095..49a72137ee047 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql @@ -12,7 +12,7 @@ drop table if exists st; create table st(x int, col struct) using parquet; insert into st values (1, (2, 3)); --- Selection operators: positive tests. +-- SELECT operators: positive tests. --------------------------------------- -- Selecting a constant. @@ -85,7 +85,24 @@ table t table t |> select distinct x, y; --- Selection operators: negative tests. +-- SELECT * is supported. +table t +|> select *; + +table t +|> select * except (y); + +-- Hints are supported. +table t +|> select /*+ repartition(3) */ *; + +table t +|> select /*+ repartition(3) */ distinct x; + +table t +|> select /*+ repartition(3) */ all x; + +-- SELECT operators: negative tests. --------------------------------------- -- Aggregate functions are not allowed in the pipe operator SELECT list. @@ -95,6 +112,79 @@ table t table t |> select y, length(y) + sum(x) as result; +-- WHERE operators: positive tests. +----------------------------------- + +-- Filtering with a constant predicate. +table t +|> where true; + +-- Filtering with a predicate based on attributes from the input relation. +table t +|> where x + length(y) < 4; + +-- Two consecutive filters are allowed. +table t +|> where x + length(y) < 4 +|> where x + length(y) < 3; + +-- It is possible to use the WHERE operator instead of the HAVING clause when processing the result +-- of aggregations. For example, this WHERE operator is equivalent to the normal SQL "HAVING x = 1". +(select x, sum(length(y)) as sum_len from t group by x) +|> where x = 1; + +-- Filtering by referring to the table or table subquery alias. +table t +|> where t.x = 1; + +table t +|> where spark_catalog.default.t.x = 1; + +-- Filtering using struct fields. +(select col from st) +|> where col.i1 = 1; + +table st +|> where st.col.i1 = 2; + +-- Expression subqueries in the WHERE clause. +table t +|> where exists (select a from other where x = a limit 1); + +-- Aggregations are allowed within expression subqueries in the pipe operator WHERE clause as long +-- no aggregate functions exist in the top-level expression predicate. +table t +|> where (select any_value(a) from other where x = a limit 1) = 1; + +-- WHERE operators: negative tests. +----------------------------------- + +-- Aggregate functions are not allowed in the top-level WHERE predicate. +-- (Note: to implement this behavior, perform the aggregation first separately and then add a +-- pipe-operator WHERE clause referring to the result of aggregate expression(s) therein). +table t +|> where sum(x) = 1; + +table t +|> where y = 'abc' or length(y) + sum(x) = 1; + +-- Window functions are not allowed in the WHERE clause (pipe operators or otherwise). +table t +|> where first_value(x) over (partition by y) = 1; + +select * from t where first_value(x) over (partition by y) = 1; + +-- Pipe operators may only refer to attributes produced as output from the directly-preceding +-- pipe operator, not from earlier ones. +table t +|> select x, length(y) as z +|> where x + length(y) < 4; + +-- If the WHERE clause wants to filter rows produced by an aggregation, it is not valid to try to +-- refer to the aggregate functions directly; it is necessary to use aliases instead. 
+(select x, sum(length(y)) as sum_len from t group by x) +|> where sum(length(y)) = 3; + -- Cleanup. ----------- drop table t; diff --git a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out index 7e0b7912105c2..38436b0941034 100644 --- a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out @@ -238,6 +238,56 @@ struct 1 def +-- !query +table t +|> select * +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> select * except (y) +-- !query schema +struct +-- !query output +0 +1 + + +-- !query +table t +|> select /*+ repartition(3) */ * +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> select /*+ repartition(3) */ distinct x +-- !query schema +struct +-- !query output +0 +1 + + +-- !query +table t +|> select /*+ repartition(3) */ all x +-- !query schema +struct +-- !query output +0 +1 + + -- !query table t |> select sum(x) as result @@ -284,6 +334,224 @@ org.apache.spark.sql.AnalysisException } +-- !query +table t +|> where true +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> where x + length(y) < 4 +-- !query schema +struct +-- !query output +0 abc + + +-- !query +table t +|> where x + length(y) < 4 +|> where x + length(y) < 3 +-- !query schema +struct +-- !query output + + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where x = 1 +-- !query schema +struct +-- !query output +1 3 + + +-- !query +table t +|> where t.x = 1 +-- !query schema +struct +-- !query output +1 def + + +-- !query +table t +|> where spark_catalog.default.t.x = 1 +-- !query schema +struct +-- !query output +1 def + + +-- !query +(select col from st) +|> where col.i1 = 1 +-- !query schema +struct> +-- !query output + + + +-- !query +table st +|> where st.col.i1 = 2 +-- !query schema +struct> +-- !query output +1 {"i1":2,"i2":3} + + +-- !query +table t +|> where exists (select a from other where x = a limit 1) +-- !query schema +struct +-- !query output +1 def + + +-- !query +table t +|> where (select any_value(a) from other where x = a limit 1) = 1 +-- !query schema +struct +-- !query output +1 def + + +-- !query +table t +|> where sum(x) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"(sum(x) = 1)\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 27, + "fragment" : "table t\n|> where sum(x) = 1" + } ] +} + + +-- !query +table t +|> where y = 'abc' or length(y) + sum(x) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_WHERE_CONDITION", + "sqlState" : "42903", + "messageParameters" : { + "condition" : "\"((y = abc) OR ((length(y) + sum(x)) = 1))\"", + "expressionList" : "sum(spark_catalog.default.t.x)" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 52, + "fragment" : "table t\n|> where y = 'abc' or length(y) + sum(x) = 1" + } ] +} + + +-- !query +table t +|> where first_value(x) over (partition by y) = 1 +-- !query schema +struct<> +-- !query output 
+org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +select * from t where first_value(x) over (partition by y) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_1034", + "messageParameters" : { + "clauseName" : "WHERE" + } +} + + +-- !query +table t +|> select x, length(y) as z +|> where x + length(y) < 4 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `z`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 57, + "stopIndex" : 57, + "fragment" : "y" + } ] +} + + +-- !query +(select x, sum(length(y)) as sum_len from t group by x) +|> where sum(length(y)) = 3 +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`y`", + "proposal" : "`x`, `sum_len`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 77, + "stopIndex" : 77, + "fragment" : "y" + } ] +} + + -- !query drop table t -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index a80444feb68ae..ab949c5a21e44 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, Un import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Concat, GreaterThan, Literal, NullsFirst, SortOrder, UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, WindowSpecReference} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, PROJECT, UNRESOLVED_RELATION} +import org.apache.spark.sql.catalyst.trees.TreePattern.{FILTER, LOCAL_RELATION, PROJECT, UNRESOLVED_RELATION} import org.apache.spark.sql.connector.catalog.TableCatalog import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, RefreshResource} @@ -895,6 +895,16 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { checkPipeSelect("TABLE t |> SELECT 1 AS X") checkPipeSelect("TABLE t |> SELECT 1 AS X, 2 AS Y |> SELECT X + Y AS Z") checkPipeSelect("VALUES (0), (1) tab(col) |> SELECT col * 2 AS result") + // Basic WHERE operators. 
+ def checkPipeWhere(query: String): Unit = { + val plan: LogicalPlan = parser.parsePlan(query) + assert(plan.containsPattern(FILTER)) + assert(plan.containsAnyPattern(UNRESOLVED_RELATION, LOCAL_RELATION)) + } + checkPipeWhere("TABLE t |> WHERE X = 1") + checkPipeWhere("TABLE t |> SELECT X, LENGTH(Y) AS Z |> WHERE X + LENGTH(Y) < 4") + checkPipeWhere("TABLE t |> WHERE X = 1 AND Y = 2 |> WHERE X + Y = 3") + checkPipeWhere("VALUES (0), (1) tab(col) |> WHERE col < 1") } } } From 70bd606cc865c3d27808eacad85fcf878c23e3a1 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Sat, 21 Sep 2024 11:50:30 +0900 Subject: [PATCH 040/250] [SPARK-49641][DOCS] Include `table_funcs` and `variant_funcs` in the built-in function list doc ### What changes were proposed in this pull request? The pr aims to include `table_funcs` and `variant_funcs` in the built-in function list doc. ### Why are the changes needed? I found that some functions were not involved in our docs, such as `sql_keywords()`, `variant_explode`, etc. Let's include them to improve the user experience for end-users. ### Does this PR introduce _any_ user-facing change? Yes, only for sql api docs. ### How was this patch tested? - Pass GA - Manually check. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48106 from panbingkun/SPARK-49641. Authored-by: panbingkun Signed-off-by: Hyukjin Kwon --- docs/sql-ref-functions-builtin.md | 10 ++++++++++ .../plans/logical/basicLogicalOperators.scala | 15 +++++++++++---- .../spark/sql/api/python/PythonSQLUtils.scala | 7 +++++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 5 ++++- sql/gen-sql-functions-docs.py | 1 + 5 files changed, 31 insertions(+), 7 deletions(-) diff --git a/docs/sql-ref-functions-builtin.md b/docs/sql-ref-functions-builtin.md index c5f4e44dec0d9..b6572609a34b8 100644 --- a/docs/sql-ref-functions-builtin.md +++ b/docs/sql-ref-functions-builtin.md @@ -116,3 +116,13 @@ license: | {% include_api_gen generated-generator-funcs-table.html %} #### Examples {% include_api_gen generated-generator-funcs-examples.html %} + +### Table Functions +{% include_api_gen generated-table-funcs-table.html %} +#### Examples +{% include_api_gen generated-table-funcs-examples.html %} + +### Variant Functions +{% include_api_gen generated-variant-funcs-table.html %} +#### Examples +{% include_api_gen generated-variant-funcs-examples.html %} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 926027df4c74b..90af6333b2e0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -992,12 +992,18 @@ object Range { castAndEval[Int](expression, IntegerType, paramIndex, paramName) } +// scalastyle:off line.size.limit @ExpressionDescription( usage = """ - _FUNC_(start: long, end: long, step: long, numSlices: integer) - _FUNC_(start: long, end: long, step: long) - _FUNC_(start: long, end: long) - _FUNC_(end: long)""", + _FUNC_(start[, end[, step[, numSlices]]]) / _FUNC_(end) - Returns a table of values within a specified range. + """, + arguments = """ + Arguments: + * start - An optional BIGINT literal defaulted to 0, marking the first value generated. + * end - A BIGINT literal marking endpoint (exclusive) of the number generation. 
+ * step - An optional BIGINT literal defaulted to 1, specifying the increment used when generating values. + * numParts - An optional INTEGER literal specifying how the production of rows is spread across partitions. + """, examples = """ Examples: > SELECT * FROM _FUNC_(1); @@ -1023,6 +1029,7 @@ object Range { """, since = "2.0.0", group = "table_funcs") +// scalastyle:on line.size.limit case class Range( start: Long, end: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 93082740cca64..bc270e6ac64ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -29,7 +29,7 @@ import org.apache.spark.internal.LogKeys.CLASS_LOADER import org.apache.spark.security.SocketAuthServer import org.apache.spark.sql.{internal, Column, DataFrame, Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -69,7 +69,10 @@ private[sql] object PythonSQLUtils extends Logging { // This is needed when generating SQL documentation for built-in functions. def listBuiltinFunctionInfos(): Array[ExpressionInfo] = { - FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)).toArray + (FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)) ++ + TableFunctionRegistry.functionSet.flatMap( + f => TableFunctionRegistry.builtin.lookupFunction(f))). + groupBy(_.getName).map(v => v._2.head).toArray } private def listAllSQLConfigs(): Seq[(String, String, String, String)] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index ce88f7dc475d6..8176d02dbd02d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -111,10 +111,13 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } test("SPARK-34678: describe functions for table-valued functions") { + sql("describe function range").show(false) checkKeywordsExist(sql("describe function range"), "Function: range", "Class: org.apache.spark.sql.catalyst.plans.logical.Range", - "range(end: long)" + "range(start[, end[, step[, numSlices]]])", + "range(end)", + "Returns a table of values within a specified range." ) } diff --git a/sql/gen-sql-functions-docs.py b/sql/gen-sql-functions-docs.py index 4be9966747d1f..a1facbaaf7e3b 100644 --- a/sql/gen-sql-functions-docs.py +++ b/sql/gen-sql-functions-docs.py @@ -36,6 +36,7 @@ "bitwise_funcs", "conversion_funcs", "csv_funcs", "xml_funcs", "lambda_funcs", "collection_funcs", "url_funcs", "hash_funcs", "struct_funcs", + "table_funcs", "variant_funcs" } From f235bab24761d8049e3d74411c19ddf3e3b5a697 Mon Sep 17 00:00:00 2001 From: Harsh Motwani Date: Sat, 21 Sep 2024 12:27:05 +0900 Subject: [PATCH 041/250] [SPARK-49451][FOLLOW-UP] Add support for duplicate keys in from_json(_, 'variant') ### What changes were proposed in this pull request? 
This PR adds support for duplicate key support in the `from_json(_, 'variant')` query pattern. Duplicate key support [has been introduced](https://github.com/apache/spark/pull/47920) in `parse_json`, json scans and the `from_json` expressions with nested schemas but this code path was not updated. ### Why are the changes needed? This change makes the behavior of `from_json(_, 'variant')` consistent with every other variant construction expression. ### Does this PR introduce _any_ user-facing change? It potentially allows users to use the `from_json(, 'variant')` expression on json inputs with duplicate keys depending on a config. ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48177 from harshmotw-db/harshmotw-db/master. Authored-by: Harsh Motwani Signed-off-by: Hyukjin Kwon --- .../expressions/jsonExpressions.scala | 12 +++++-- .../function_from_json.explain | 2 +- .../function_from_json_orphaned.explain | 2 +- ...unction_from_json_with_json_schema.explain | 2 +- .../analyzer-results/ansi/date.sql.out | 2 +- .../analyzer-results/ansi/interval.sql.out | 6 ++-- .../ansi/parse-schema-string.sql.out | 4 +-- .../analyzer-results/ansi/timestamp.sql.out | 2 +- .../sql-tests/analyzer-results/date.sql.out | 2 +- .../analyzer-results/datetime-legacy.sql.out | 4 +-- .../analyzer-results/interval.sql.out | 6 ++-- .../analyzer-results/json-functions.sql.out | 34 +++++++++---------- .../parse-schema-string.sql.out | 4 +-- .../sql-session-variables.sql.out | 2 +- .../subexp-elimination.sql.out | 10 +++--- .../analyzer-results/timestamp.sql.out | 2 +- .../timestampNTZ/timestamp-ansi.sql.out | 2 +- .../timestampNTZ/timestamp.sql.out | 2 +- .../native/stringCastAndExpressions.sql.out | 2 +- .../spark/sql/VariantEndToEndSuite.scala | 32 +++++++++++++++++ 20 files changed, 87 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 574a61cf9c903..2037eb22fede6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -632,7 +632,8 @@ case class JsonToStructs( schema: DataType, options: Map[String, String], child: Expression, - timeZoneId: Option[String] = None) + timeZoneId: Option[String] = None, + variantAllowDuplicateKeys: Boolean = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS)) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback @@ -719,7 +720,8 @@ case class JsonToStructs( override def nullSafeEval(json: Any): Any = nullableSchema match { case _: VariantType => - VariantExpressionEvalUtils.parseJson(json.asInstanceOf[UTF8String]) + VariantExpressionEvalUtils.parseJson(json.asInstanceOf[UTF8String], + allowDuplicateKeys = variantAllowDuplicateKeys) case _ => converter(parser.parse(json.asInstanceOf[UTF8String])) } @@ -737,6 +739,12 @@ case class JsonToStructs( copy(child = newChild) } +object JsonToStructs { + def unapply( + j: JsonToStructs): Option[(DataType, Map[String, String], Expression, Option[String])] = + Some((j.schema, j.options, j.child, j.timeZoneId)) +} + /** * Converts a [[StructType]], [[ArrayType]] or [[MapType]] to a JSON output string. 
*/ diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain index 1219f11d4696e..8d1d122d156ff 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json.explain @@ -1,2 +1,2 @@ -Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles)) AS from_json(g)#0] +Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain index 1219f11d4696e..8d1d122d156ff 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_orphaned.explain @@ -1,2 +1,2 @@ -Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles)) AS from_json(g)#0] +Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain index 1219f11d4696e..8d1d122d156ff 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_from_json_with_json_schema.explain @@ -1,2 +1,2 @@ -Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles)) AS from_json(g)#0] +Project [from_json(StructField(id,LongType,true), StructField(a,IntegerType,true), StructField(b,DoubleType,true), g#0, Some(America/Los_Angeles), false) AS from_json(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out index fd927b99c6456..0e4d2d4e99e26 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/date.sql.out @@ -736,7 +736,7 @@ Project [to_date(26/October/2015, Some(dd/MMMMM/yyyy), Some(America/Los_Angeles) -- !query select from_json('{"d":"26/October/2015"}', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"d":"26/October/2015"})#x] +Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles), 
false) AS from_json({"d":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out index 472c9b1df064a..b0d128c4cab69 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/interval.sql.out @@ -2108,7 +2108,7 @@ SELECT to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)), from_csv(to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)), 'a interval year, b interval month') -- !query analysis -Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x] +Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x] +- OneRowRelation @@ -2119,7 +2119,7 @@ SELECT to_json(map('a', interval 100 day 130 minute)), from_json(to_json(map('a', interval 100 day 130 minute)), 'a interval day to minute') -- !query analysis -Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS 
to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x] +Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x] +- OneRowRelation @@ -2130,7 +2130,7 @@ SELECT to_json(map('a', interval 32 year 10 month)), from_json(to_json(map('a', interval 32 year 10 month)), 'a interval year to month') -- !query analysis -Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x] +Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out index 45fc3bd03a782..ae8e47ed3665c 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/parse-schema-string.sql.out @@ -16,12 +16,12 @@ Project [from_csv(StructField(cube,IntegerType,true), 1, Some(America/Los_Angele -- !query select from_json('{"create":1}', 'create INT') -- !query analysis -Project [from_json(StructField(create,IntegerType,true), {"create":1}, Some(America/Los_Angeles)) AS from_json({"create":1})#x] +Project [from_json(StructField(create,IntegerType,true), {"create":1}, Some(America/Los_Angeles), false) AS from_json({"create":1})#x] +- OneRowRelation -- !query select from_json('{"cube":1}', 'cube INT') -- !query analysis -Project [from_json(StructField(cube,IntegerType,true), {"cube":1}, 
Some(America/Los_Angeles)) AS from_json({"cube":1})#x] +Project [from_json(StructField(cube,IntegerType,true), {"cube":1}, Some(America/Los_Angeles), false) AS from_json({"cube":1})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out index bf34490d657e3..560974d28c545 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/ansi/timestamp.sql.out @@ -730,7 +730,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out index 48137e06467e8..88c7d7b4e7d72 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/date.sql.out @@ -811,7 +811,7 @@ Project [to_date(26/October/2015, Some(dd/MMMMM/yyyy), Some(America/Los_Angeles) -- !query select from_json('{"d":"26/October/2015"}', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"d":"26/October/2015"})#x] +Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"d":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out index 1e49f4df8267a..4221db822d024 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/datetime-legacy.sql.out @@ -811,7 +811,7 @@ Project [to_date(26/October/2015, Some(dd/MMMMM/yyyy), Some(America/Los_Angeles) -- !query select from_json('{"d":"26/October/2015"}', 'd Date', map('dateFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"d":"26/October/2015"})#x] +Project [from_json(StructField(d,DateType,true), (dateFormat,dd/MMMMM/yyyy), {"d":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"d":"26/October/2015"})#x] +- OneRowRelation @@ -1833,7 +1833,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project 
[from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out index 3db38d482b26d..efa149509751d 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out @@ -2108,7 +2108,7 @@ SELECT to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)), from_csv(to_csv(named_struct('a', interval 32 year, 'b', interval 10 month)), 'a interval year, b interval month') -- !query analysis -Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x] +Project [from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false) AS from_json({"a":"1 days"})#x, from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None) AS from_csv(1, 1)#x, to_json(from_json(StructField(a,CalendarIntervalType,true), {"a":"1 days"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1 days"}))#x, to_csv(from_csv(StructField(a,IntegerType,true), StructField(b,YearMonthIntervalType(0,0),true), 1, 1, Some(America/Los_Angeles), None), Some(America/Los_Angeles)) AS to_csv(from_csv(1, 1))#x, to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)) AS to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH))#x, from_csv(StructField(a,YearMonthIntervalType(0,0),true), StructField(b,YearMonthIntervalType(1,1),true), to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), None) AS from_csv(to_csv(named_struct(a, INTERVAL '32' YEAR, b, INTERVAL '10' MONTH)))#x] +- OneRowRelation @@ -2119,7 +2119,7 @@ SELECT to_json(map('a', interval 100 day 130 minute)), from_json(to_json(map('a', interval 100 day 130 minute)), 'a interval day to minute') -- !query analysis -Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS 
to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x] +Project [from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,DayTimeIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE))#x, from_json(StructField(a,DayTimeIntervalType(0,2),true), to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '100 02:10' DAY TO MINUTE)))#x] +- OneRowRelation @@ -2130,7 +2130,7 @@ SELECT to_json(map('a', interval 32 year 10 month)), from_json(to_json(map('a', interval 32 year 10 month)), 'a interval year to month') -- !query analysis -Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles)) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x] +Project [from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false) AS from_json({"a":"1"})#x, to_json(from_json(StructField(a,YearMonthIntervalType(0,0),true), {"a":"1"}, Some(America/Los_Angeles), false), Some(America/Los_Angeles)) AS to_json(from_json({"a":"1"}))#x, to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)) AS to_json(map(a, INTERVAL '32-10' YEAR TO MONTH))#x, from_json(StructField(a,YearMonthIntervalType(0,1),true), to_json(map(a, INTERVAL '32-10' YEAR TO MONTH), Some(America/Los_Angeles)), Some(America/Los_Angeles), false) AS from_json(to_json(map(a, INTERVAL '32-10' YEAR TO MONTH)))#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out index 0d7c6b2056231..fef9d0c5b6250 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/json-functions.sql.out @@ -118,14 +118,14 @@ org.apache.spark.sql.AnalysisException -- !query select from_json('{"a":1}', 'a INT') -- !query analysis -Project [from_json(StructField(a,IntegerType,true), {"a":1}, Some(America/Los_Angeles)) AS from_json({"a":1})#x] +Project [from_json(StructField(a,IntegerType,true), {"a":1}, Some(America/Los_Angeles), false) AS from_json({"a":1})#x] +- OneRowRelation -- !query select from_json('{"time":"26/08/2015"}', 'time Timestamp', map('timestampFormat', 'dd/MM/yyyy')) -- !query analysis -Project 
[from_json(StructField(time,TimestampType,true), (timestampFormat,dd/MM/yyyy), {"time":"26/08/2015"}, Some(America/Los_Angeles)) AS from_json({"time":"26/08/2015"})#x] +Project [from_json(StructField(time,TimestampType,true), (timestampFormat,dd/MM/yyyy), {"time":"26/08/2015"}, Some(America/Los_Angeles), false) AS from_json({"time":"26/08/2015"})#x] +- OneRowRelation @@ -279,14 +279,14 @@ DropTempViewCommand jsonTable -- !query select from_json('{"a":1, "b":2}', 'map') -- !query analysis -Project [from_json(MapType(StringType,IntegerType,true), {"a":1, "b":2}, Some(America/Los_Angeles)) AS entries#x] +Project [from_json(MapType(StringType,IntegerType,true), {"a":1, "b":2}, Some(America/Los_Angeles), false) AS entries#x] +- OneRowRelation -- !query select from_json('{"a":1, "b":"2"}', 'struct') -- !query analysis -Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), {"a":1, "b":"2"}, Some(America/Los_Angeles)) AS from_json({"a":1, "b":"2"})#x] +Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), {"a":1, "b":"2"}, Some(America/Los_Angeles), false) AS from_json({"a":1, "b":"2"})#x] +- OneRowRelation @@ -300,70 +300,70 @@ Project [schema_of_json({"c1":0, "c2":[1]}) AS schema_of_json({"c1":0, "c2":[1]} -- !query select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}')) -- !query analysis -Project [from_json(StructField(c1,ArrayType(LongType,true),true), {"c1":[1, 2, 3]}, Some(America/Los_Angeles)) AS from_json({"c1":[1, 2, 3]})#x] +Project [from_json(StructField(c1,ArrayType(LongType,true),true), {"c1":[1, 2, 3]}, Some(America/Los_Angeles), false) AS from_json({"c1":[1, 2, 3]})#x] +- OneRowRelation -- !query select from_json('[1, 2, 3]', 'array') -- !query analysis -Project [from_json(ArrayType(IntegerType,true), [1, 2, 3], Some(America/Los_Angeles)) AS from_json([1, 2, 3])#x] +Project [from_json(ArrayType(IntegerType,true), [1, 2, 3], Some(America/Los_Angeles), false) AS from_json([1, 2, 3])#x] +- OneRowRelation -- !query select from_json('[1, "2", 3]', 'array') -- !query analysis -Project [from_json(ArrayType(IntegerType,true), [1, "2", 3], Some(America/Los_Angeles)) AS from_json([1, "2", 3])#x] +Project [from_json(ArrayType(IntegerType,true), [1, "2", 3], Some(America/Los_Angeles), false) AS from_json([1, "2", 3])#x] +- OneRowRelation -- !query select from_json('[1, 2, null]', 'array') -- !query analysis -Project [from_json(ArrayType(IntegerType,true), [1, 2, null], Some(America/Los_Angeles)) AS from_json([1, 2, null])#x] +Project [from_json(ArrayType(IntegerType,true), [1, 2, null], Some(America/Los_Angeles), false) AS from_json([1, 2, null])#x] +- OneRowRelation -- !query select from_json('[{"a": 1}, {"a":2}]', 'array>') -- !query analysis -Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [{"a": 1}, {"a":2}], Some(America/Los_Angeles)) AS from_json([{"a": 1}, {"a":2}])#x] +Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [{"a": 1}, {"a":2}], Some(America/Los_Angeles), false) AS from_json([{"a": 1}, {"a":2}])#x] +- OneRowRelation -- !query select from_json('{"a": 1}', 'array>') -- !query analysis -Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), {"a": 1}, Some(America/Los_Angeles)) AS from_json({"a": 1})#x] +Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), {"a": 1}, Some(America/Los_Angeles), false) AS from_json({"a": 1})#x] +- OneRowRelation -- !query select from_json('[null, {"a":2}]', 'array>') 
-- !query analysis -Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [null, {"a":2}], Some(America/Los_Angeles)) AS from_json([null, {"a":2}])#x] +Project [from_json(ArrayType(StructType(StructField(a,IntegerType,true)),true), [null, {"a":2}], Some(America/Los_Angeles), false) AS from_json([null, {"a":2}])#x] +- OneRowRelation -- !query select from_json('[{"a": 1}, {"b":2}]', 'array>') -- !query analysis -Project [from_json(ArrayType(MapType(StringType,IntegerType,true),true), [{"a": 1}, {"b":2}], Some(America/Los_Angeles)) AS from_json([{"a": 1}, {"b":2}])#x] +Project [from_json(ArrayType(MapType(StringType,IntegerType,true),true), [{"a": 1}, {"b":2}], Some(America/Los_Angeles), false) AS from_json([{"a": 1}, {"b":2}])#x] +- OneRowRelation -- !query select from_json('[{"a": 1}, 2]', 'array>') -- !query analysis -Project [from_json(ArrayType(MapType(StringType,IntegerType,true),true), [{"a": 1}, 2], Some(America/Los_Angeles)) AS from_json([{"a": 1}, 2])#x] +Project [from_json(ArrayType(MapType(StringType,IntegerType,true),true), [{"a": 1}, 2], Some(America/Los_Angeles), false) AS from_json([{"a": 1}, 2])#x] +- OneRowRelation -- !query select from_json('{"d": "2012-12-15", "t": "2012-12-15 15:15:15"}', 'd date, t timestamp') -- !query analysis -Project [from_json(StructField(d,DateType,true), StructField(t,TimestampType,true), {"d": "2012-12-15", "t": "2012-12-15 15:15:15"}, Some(America/Los_Angeles)) AS from_json({"d": "2012-12-15", "t": "2012-12-15 15:15:15"})#x] +Project [from_json(StructField(d,DateType,true), StructField(t,TimestampType,true), {"d": "2012-12-15", "t": "2012-12-15 15:15:15"}, Some(America/Los_Angeles), false) AS from_json({"d": "2012-12-15", "t": "2012-12-15 15:15:15"})#x] +- OneRowRelation @@ -373,7 +373,7 @@ select from_json( 'd date, t timestamp', map('dateFormat', 'MM/dd yyyy', 'timestampFormat', 'MM/dd yyyy HH:mm:ss')) -- !query analysis -Project [from_json(StructField(d,DateType,true), StructField(t,TimestampType,true), (dateFormat,MM/dd yyyy), (timestampFormat,MM/dd yyyy HH:mm:ss), {"d": "12/15 2012", "t": "12/15 2012 15:15:15"}, Some(America/Los_Angeles)) AS from_json({"d": "12/15 2012", "t": "12/15 2012 15:15:15"})#x] +Project [from_json(StructField(d,DateType,true), StructField(t,TimestampType,true), (dateFormat,MM/dd yyyy), (timestampFormat,MM/dd yyyy HH:mm:ss), {"d": "12/15 2012", "t": "12/15 2012 15:15:15"}, Some(America/Los_Angeles), false) AS from_json({"d": "12/15 2012", "t": "12/15 2012 15:15:15"})#x] +- OneRowRelation @@ -383,7 +383,7 @@ select from_json( 'd date', map('dateFormat', 'MM-dd')) -- !query analysis -Project [from_json(StructField(d,DateType,true), (dateFormat,MM-dd), {"d": "02-29"}, Some(America/Los_Angeles)) AS from_json({"d": "02-29"})#x] +Project [from_json(StructField(d,DateType,true), (dateFormat,MM-dd), {"d": "02-29"}, Some(America/Los_Angeles), false) AS from_json({"d": "02-29"})#x] +- OneRowRelation @@ -393,7 +393,7 @@ select from_json( 't timestamp', map('timestampFormat', 'MM-dd')) -- !query analysis -Project [from_json(StructField(t,TimestampType,true), (timestampFormat,MM-dd), {"t": "02-29"}, Some(America/Los_Angeles)) AS from_json({"t": "02-29"})#x] +Project [from_json(StructField(t,TimestampType,true), (timestampFormat,MM-dd), {"t": "02-29"}, Some(America/Los_Angeles), false) AS from_json({"t": "02-29"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/parse-schema-string.sql.out 
b/sql/core/src/test/resources/sql-tests/analyzer-results/parse-schema-string.sql.out index 45fc3bd03a782..ae8e47ed3665c 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/parse-schema-string.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/parse-schema-string.sql.out @@ -16,12 +16,12 @@ Project [from_csv(StructField(cube,IntegerType,true), 1, Some(America/Los_Angele -- !query select from_json('{"create":1}', 'create INT') -- !query analysis -Project [from_json(StructField(create,IntegerType,true), {"create":1}, Some(America/Los_Angeles)) AS from_json({"create":1})#x] +Project [from_json(StructField(create,IntegerType,true), {"create":1}, Some(America/Los_Angeles), false) AS from_json({"create":1})#x] +- OneRowRelation -- !query select from_json('{"cube":1}', 'cube INT') -- !query analysis -Project [from_json(StructField(cube,IntegerType,true), {"cube":1}, Some(America/Los_Angeles)) AS from_json({"cube":1})#x] +Project [from_json(StructField(cube,IntegerType,true), {"cube":1}, Some(America/Los_Angeles), false) AS from_json({"cube":1})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out index a4e40f08b4463..02e7c39ae83fd 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out @@ -2147,7 +2147,7 @@ CreateVariable defaultvalueexpression(cast(a INT as string), 'a INT'), true -- !query SELECT from_json('{"a": 1}', var1) -- !query analysis -Project [from_json(StructField(a,IntegerType,true), {"a": 1}, Some(America/Los_Angeles)) AS from_json({"a": 1})#x] +Project [from_json(StructField(a,IntegerType,true), {"a": 1}, Some(America/Los_Angeles), false) AS from_json({"a": 1})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subexp-elimination.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subexp-elimination.sql.out index 94073f2751b3e..754b05bfa6fed 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subexp-elimination.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subexp-elimination.sql.out @@ -15,7 +15,7 @@ AS testData(a, b), false, true, LocalTempView, UNSUPPORTED, true -- !query SELECT from_json(a, 'struct').a, from_json(a, 'struct').b, from_json(b, 'array>')[0].a, from_json(b, 'array>')[0].b FROM testData -- !query analysis -Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a AS from_json(a).a#x, from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).b AS from_json(a).b#x, from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].a AS from_json(b)[0].a#x, from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b AS from_json(b)[0].b#x] +Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a AS from_json(a).a#x, from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).b AS from_json(a).b#x, 
from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].a AS from_json(b)[0].a#x, from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b AS from_json(b)[0].b#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] @@ -27,7 +27,7 @@ Project [from_json(StructField(a,IntegerType,true), StructField(b,StringType,tru -- !query SELECT if(from_json(a, 'struct').a > 1, from_json(b, 'array>')[0].a, from_json(b, 'array>')[0].a + 1) FROM testData -- !query analysis -Project [if ((from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 1)) from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].a else (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].a + 1) AS (IF((from_json(a).a > 1), from_json(b)[0].a, (from_json(b)[0].a + 1)))#x] +Project [if ((from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 1)) from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].a else (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].a + 1) AS (IF((from_json(a).a > 1), from_json(b)[0].a, (from_json(b)[0].a + 1)))#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] @@ -39,7 +39,7 @@ Project [if ((from_json(StructField(a,IntegerType,true), StructField(b,StringTyp -- !query SELECT if(isnull(from_json(a, 'struct').a), from_json(b, 'array>')[0].b + 1, from_json(b, 'array>')[0].b) FROM testData -- !query analysis -Project [if (isnull(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a)) (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b + 1) else from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b AS (IF((from_json(a).a IS NULL), (from_json(b)[0].b + 1), from_json(b)[0].b))#x] +Project [if (isnull(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a)) (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b + 1) else from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b AS (IF((from_json(a).a IS NULL), (from_json(b)[0].b + 1), from_json(b)[0].b))#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] @@ -51,7 +51,7 @@ Project [if (isnull(from_json(StructField(a,IntegerType,true), StructField(b,Str -- !query SELECT case when from_json(a, 'struct').a > 5 then from_json(a, 'struct').b when from_json(a, 'struct').a > 4 then from_json(a, 'struct').b + 1 else from_json(a, 'struct').b + 2 end 
FROM testData -- !query analysis -Project [CASE WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 5) THEN from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).b WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 4) THEN cast((cast(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).b as double) + cast(1 as double)) as string) ELSE cast((cast(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).b as double) + cast(2 as double)) as string) END AS CASE WHEN (from_json(a).a > 5) THEN from_json(a).b WHEN (from_json(a).a > 4) THEN (from_json(a).b + 1) ELSE (from_json(a).b + 2) END#x] +Project [CASE WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 5) THEN from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).b WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 4) THEN cast((cast(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).b as double) + cast(1 as double)) as string) ELSE cast((cast(from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).b as double) + cast(2 as double)) as string) END AS CASE WHEN (from_json(a).a > 5) THEN from_json(a).b WHEN (from_json(a).a > 4) THEN (from_json(a).b + 1) ELSE (from_json(a).b + 2) END#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] @@ -63,7 +63,7 @@ Project [CASE WHEN (from_json(StructField(a,IntegerType,true), StructField(b,Str -- !query SELECT case when from_json(a, 'struct').a > 5 then from_json(b, 'array>')[0].b when from_json(a, 'struct').a > 4 then from_json(b, 'array>')[0].b + 1 else from_json(b, 'array>')[0].b + 2 end FROM testData -- !query analysis -Project [CASE WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 5) THEN from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles)).a > 4) THEN (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b + 1) ELSE (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles))[0].b + 2) END AS CASE WHEN (from_json(a).a > 5) THEN from_json(b)[0].b WHEN (from_json(a).a > 4) THEN (from_json(b)[0].b + 1) ELSE (from_json(b)[0].b + 2) END#x] +Project [CASE WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 5) THEN from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b WHEN (from_json(StructField(a,IntegerType,true), StructField(b,StringType,true), a#x, Some(America/Los_Angeles), false).a > 4) THEN 
(from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b + 1) ELSE (from_json(ArrayType(StructType(StructField(a,IntegerType,true),StructField(b,IntegerType,true)),true), b#x, Some(America/Los_Angeles), false)[0].b + 2) END AS CASE WHEN (from_json(a).a > 5) THEN from_json(b)[0].b WHEN (from_json(a).a > 4) THEN (from_json(b)[0].b + 1) ELSE (from_json(b)[0].b + 2) END#x] +- SubqueryAlias testdata +- View (`testData`, [a#x, b#x]) +- Project [cast(a#x as string) AS a#x, cast(b#x as string) AS b#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/timestamp.sql.out index 6ca35b8b141dc..dcfd783b648f8 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/timestamp.sql.out @@ -802,7 +802,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project [from_json(StructField(t,TimestampType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp-ansi.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp-ansi.sql.out index e50c860270563..ec227afc87fe1 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp-ansi.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp-ansi.sql.out @@ -745,7 +745,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(t,TimestampNTZType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project [from_json(StructField(t,TimestampNTZType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp.sql.out index 098abfb3852cf..7475f837250d5 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/timestampNTZ/timestamp.sql.out @@ -805,7 +805,7 @@ Project [unix_timestamp(22 05 2020 Friday, dd MM yyyy EEEEE, Some(America/Los_An -- !query select from_json('{"t":"26/October/2015"}', 't Timestamp', map('timestampFormat', 'dd/MMMMM/yyyy')) -- !query analysis -Project [from_json(StructField(t,TimestampNTZType,true), (timestampFormat,dd/MMMMM/yyyy), {"t":"26/October/2015"}, Some(America/Los_Angeles)) AS from_json({"t":"26/October/2015"})#x] +Project [from_json(StructField(t,TimestampNTZType,true), (timestampFormat,dd/MMMMM/yyyy), 
{"t":"26/October/2015"}, Some(America/Los_Angeles), false) AS from_json({"t":"26/October/2015"})#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/typeCoercion/native/stringCastAndExpressions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/typeCoercion/native/stringCastAndExpressions.sql.out index 009e91f7ffacf..22e60d0606382 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/typeCoercion/native/stringCastAndExpressions.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/typeCoercion/native/stringCastAndExpressions.sql.out @@ -370,7 +370,7 @@ Project [c0#x] -- !query select from_json(a, 'a INT') from t -- !query analysis -Project [from_json(StructField(a,IntegerType,true), a#x, Some(America/Los_Angeles)) AS from_json(a)#x] +Project [from_json(StructField(a,IntegerType,true), a#x, Some(America/Los_Angeles), false) AS from_json(a)#x] +- SubqueryAlias t +- View (`t`, [a#x]) +- Project [cast(a#x as string) AS a#x] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala index 3224baf42f3e5..19d4ac23709b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql +import org.apache.spark.SparkThrowable import org.apache.spark.sql.QueryTest.sameRows import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} @@ -28,6 +29,7 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarArray import org.apache.spark.types.variant.VariantBuilder +import org.apache.spark.types.variant.VariantUtil._ import org.apache.spark.unsafe.types.VariantVal class VariantEndToEndSuite extends QueryTest with SharedSparkSession { @@ -37,8 +39,10 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { def check(input: String, output: String = null): Unit = { val df = Seq(input).toDF("v") val variantDF = df.select(to_json(parse_json(col("v")))) + val variantDF2 = df.select(to_json(from_json(col("v"), VariantType))) val expected = if (output != null) output else input checkAnswer(variantDF, Seq(Row(expected))) + checkAnswer(variantDF2, Seq(Row(expected))) } check("null") @@ -339,4 +343,32 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { } } } + + test("from_json(_, 'variant') with duplicate keys") { + val json: String = """{"a": 1, "b": 2, "c": "3", "a": 4}""" + withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "true") { + val df = Seq(json).toDF("j") + .selectExpr("from_json(j,'variant')") + val actual = df.collect().head(0).asInstanceOf[VariantVal] + val expectedValue: Array[Byte] = Array(objectHeader(false, 1, 1), + /* size */ 3, + /* id list */ 0, 1, 2, + /* offset list */ 4, 0, 2, 6, + /* field data */ primitiveHeader(INT1), 2, shortStrHeader(1), '3', + primitiveHeader(INT1), 4) + val expectedMetadata: Array[Byte] = Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c') + assert(actual === new VariantVal(expectedValue, expectedMetadata)) + } + withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "false") { + val df = Seq(json).toDF("j") + .selectExpr("from_json(j,'variant')") + checkError( + exception = intercept[SparkThrowable] { + df.collect() + }, + condition = 
"MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + parameters = Map("badRecord" -> json, "failFastMode" -> "FAILFAST") + ) + } + } } From fc8b94544163cd1988053a3f8eb8b4770fbbb55b Mon Sep 17 00:00:00 2001 From: Ziqi Liu Date: Sat, 21 Sep 2024 12:37:57 +0900 Subject: [PATCH 042/250] [SPARK-49460][SQL] Followup: fix potential NPE risk ### What changes were proposed in this pull request? Fixed potential NPE risk in `EmptyRelationExec.logical` ### Why are the changes needed? This is a follow up for https://github.com/apache/spark/pull/47931, I've checked other callsites of `EmptyRelationExec.logical`, which we can not assure it's driver-only. So we should fix those potential risks. ### Does this PR introduce _any_ user-facing change? NO ### How was this patch tested? Existing UT ### Was this patch authored or co-authored using generative AI tooling? NO Closes #48191 from liuzqt/SPARK-49460. Authored-by: Ziqi Liu Signed-off-by: Hyukjin Kwon --- .../spark/sql/execution/EmptyRelationExec.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala index 8a544de7567e8..a0c3d7b51c2c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/EmptyRelationExec.scala @@ -71,13 +71,15 @@ case class EmptyRelationExec(@transient logical: LogicalPlan) extends LeafExecNo maxFields, printNodeId, indent) - lastChildren.add(true) - logical.generateTreeString( - depth + 1, lastChildren, append, verbose, "", false, maxFields, printNodeId, indent) - lastChildren.remove(lastChildren.size() - 1) + Option(logical).foreach { _ => + lastChildren.add(true) + logical.generateTreeString( + depth + 1, lastChildren, append, verbose, "", false, maxFields, printNodeId, indent) + lastChildren.remove(lastChildren.size() - 1) + } } override def doCanonicalize(): SparkPlan = { - this.copy(logical = LocalRelation(logical.output).canonicalized) + this.copy(logical = LocalRelation(output).canonicalized) } } From 0b05b1aa72ced85b49c7230a493bd3200bcc786a Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Sat, 21 Sep 2024 09:20:23 +0200 Subject: [PATCH 043/250] [SPARK-48782][SQL][TESTS][FOLLOW-UP] Enable ANSI for malformed input test in ProcedureSuite ### What changes were proposed in this pull request? This PR is a followup of https://github.com/apache/spark/pull/47943 that enables ANSI for malformed input test in ProcedureSuite. ### Why are the changes needed? The specific test fails with ANSI mode disabled https://github.com/apache/spark/actions/runs/10951615244/job/30408963913 ``` - malformed input to implicit cast *** FAILED *** (4 milliseconds) Expected exception org.apache.spark.SparkNumberFormatException to be thrown, but no exception was thrown (ProcedureSuite.scala:264) org.scalatest.exceptions.TestFailedException: at org.scalatest.Assertions.newAssertionFailedException(Assertions.scala:472) at org.scalatest.Assertions.newAssertionFailedException$(Assertions.scala:471) at org.scalatest.funsuite.AnyFunSuite.newAssertionFailedException(AnyFunSuite.scala:1564) ... ``` The test depends on `sum`'s failure so this PR simply enables ANSI mode for that specific test. ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? Manually ran with ANSI mode off. 
### Was this patch authored or co-authored using generative AI tooling? No. Closes #48193 from HyukjinKwon/SPARK-48782-followup. Authored-by: Hyukjin Kwon Signed-off-by: Max Gekk --- .../spark/sql/connector/ProcedureSuite.scala | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala index e39a1b7ea340a..c8faf5a874f5f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala @@ -258,18 +258,20 @@ class ProcedureSuite extends QueryTest with SharedSparkSession with BeforeAndAft } test("malformed input to implicit cast") { - catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) - val call = "CALL cat.ns.sum('A', 2)" - checkError( - exception = intercept[SparkNumberFormatException]( - sql(call) - ), - condition = "CAST_INVALID_INPUT", - parameters = Map( - "expression" -> toSQLValue("A"), - "sourceType" -> toSQLType("STRING"), - "targetType" -> toSQLType("INT")), - context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1)) + withSQLConf(SQLConf.ANSI_ENABLED.key -> true.toString) { + catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum) + val call = "CALL cat.ns.sum('A', 2)" + checkError( + exception = intercept[SparkNumberFormatException]( + sql(call) + ), + condition = "CAST_INVALID_INPUT", + parameters = Map( + "expression" -> toSQLValue("A"), + "sourceType" -> toSQLType("STRING"), + "targetType" -> toSQLType("INT")), + context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1)) + } } test("required parameters after optional") { From bbbc05cbf971e931a1defc54b9924060dcdf55ca Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Sat, 21 Sep 2024 19:39:27 +0800 Subject: [PATCH 044/250] [SPARK-49495][DOCS] Document and Feature Preview on the master branch via Live GitHub Pages Updates ### What changes were proposed in this pull request? This pull request introduces functionalities that enable 'Document and Feature Preview on the master branch via Live GitHub Pages Updates'. ### Why are the changes needed? retore 8861f0f9af3f397921ba1204cf4f76f4e20680bb 376382711e200aa978008b25630cc54271fd419b 58d73fe8e7cbff9878539d31430f819eff9fc7a1 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? https://github.com/yaooqinn/spark/actions/runs/10952355999 ### Was this patch authored or co-authored using generative AI tooling? no Closes #48175 from yaooqinn/SPARK-49495. Authored-by: Kent Yao Signed-off-by: Kent Yao --- .asf.yaml | 2 + .github/workflows/pages.yml | 92 +++++++++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 .github/workflows/pages.yml diff --git a/.asf.yaml b/.asf.yaml index 22042b355b2fa..3935a525ff3c4 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -31,6 +31,8 @@ github: merge: false squash: true rebase: true + ghp_branch: master + ghp_path: /docs notifications: pullrequests: reviews@spark.apache.org diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml new file mode 100644 index 0000000000000..b3f1cad8d947f --- /dev/null +++ b/.github/workflows/pages.yml @@ -0,0 +1,92 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +name: GitHub Pages deployment + +on: + push: + branches: + - master + +concurrency: + group: 'docs preview' + cancel-in-progress: false + +jobs: + docs: + name: Build and deploy documentation + runs-on: ubuntu-latest + permissions: + id-token: write + pages: write + environment: + name: github-pages # https://github.com/actions/deploy-pages/issues/271 + env: + SPARK_TESTING: 1 # Reduce some noise in the logs + RELEASE_VERSION: 'In-Progress' + steps: + - name: Checkout Spark repository + uses: actions/checkout@v4 + with: + repository: apache/spark + ref: 'master' + - name: Install Java 17 + uses: actions/setup-java@v4 + with: + distribution: zulu + java-version: 17 + - name: Install Python 3.9 + uses: actions/setup-python@v5 + with: + python-version: '3.9' + architecture: x64 + cache: 'pip' + - name: Install Python dependencies + run: pip install --upgrade -r dev/requirements.txt + - name: Install Ruby for documentation generation + uses: ruby/setup-ruby@v1 + with: + ruby-version: '3.3' + bundler-cache: true + - name: Install Pandoc + run: | + sudo apt-get update -y + sudo apt-get install pandoc + - name: Install dependencies for documentation generation + run: | + cd docs + gem install bundler -v 2.4.22 -n /usr/local/bin + bundle install --retry=100 + - name: Run documentation build + run: | + sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml + sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml + sed -i".tmp3" "s/'facetFilters':.*$/'facetFilters': [\"version:$RELEASE_VERSION\"]/g" docs/_config.yml + sed -i".tmp4" 's/__version__: str = .*$/__version__: str = "'"$RELEASE_VERSION"'"/' python/pyspark/version.py + cd docs + SKIP_RDOC=1 bundle exec jekyll build + - name: Setup Pages + uses: actions/configure-pages@v5 + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 + with: + path: 'docs/_site' + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 From 19906468d145a52a0f039e49fa54c558767805b2 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Sat, 21 Sep 2024 20:44:08 +0800 Subject: [PATCH 045/250] Revert "[SPARK-49495][DOCS] Document and Feature Preview on the master branch via Live GitHub Pages Updates" This reverts commit bbbc05cbf971e931a1defc54b9924060dcdf55ca. 
--- .asf.yaml | 2 - .github/workflows/pages.yml | 92 ------------------------------------- 2 files changed, 94 deletions(-) delete mode 100644 .github/workflows/pages.yml diff --git a/.asf.yaml b/.asf.yaml index 3935a525ff3c4..22042b355b2fa 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -31,8 +31,6 @@ github: merge: false squash: true rebase: true - ghp_branch: master - ghp_path: /docs notifications: pullrequests: reviews@spark.apache.org diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml deleted file mode 100644 index b3f1cad8d947f..0000000000000 --- a/.github/workflows/pages.yml +++ /dev/null @@ -1,92 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -name: GitHub Pages deployment - -on: - push: - branches: - - master - -concurrency: - group: 'docs preview' - cancel-in-progress: false - -jobs: - docs: - name: Build and deploy documentation - runs-on: ubuntu-latest - permissions: - id-token: write - pages: write - environment: - name: github-pages # https://github.com/actions/deploy-pages/issues/271 - env: - SPARK_TESTING: 1 # Reduce some noise in the logs - RELEASE_VERSION: 'In-Progress' - steps: - - name: Checkout Spark repository - uses: actions/checkout@v4 - with: - repository: apache/spark - ref: 'master' - - name: Install Java 17 - uses: actions/setup-java@v4 - with: - distribution: zulu - java-version: 17 - - name: Install Python 3.9 - uses: actions/setup-python@v5 - with: - python-version: '3.9' - architecture: x64 - cache: 'pip' - - name: Install Python dependencies - run: pip install --upgrade -r dev/requirements.txt - - name: Install Ruby for documentation generation - uses: ruby/setup-ruby@v1 - with: - ruby-version: '3.3' - bundler-cache: true - - name: Install Pandoc - run: | - sudo apt-get update -y - sudo apt-get install pandoc - - name: Install dependencies for documentation generation - run: | - cd docs - gem install bundler -v 2.4.22 -n /usr/local/bin - bundle install --retry=100 - - name: Run documentation build - run: | - sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml - sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml - sed -i".tmp3" "s/'facetFilters':.*$/'facetFilters': [\"version:$RELEASE_VERSION\"]/g" docs/_config.yml - sed -i".tmp4" 's/__version__: str = .*$/__version__: str = "'"$RELEASE_VERSION"'"/' python/pyspark/version.py - cd docs - SKIP_RDOC=1 bundle exec jekyll build - - name: Setup Pages - uses: actions/configure-pages@v5 - - name: Upload artifact - uses: actions/upload-pages-artifact@v3 - with: - path: 'docs/_site' - - name: Deploy to GitHub Pages - id: deployment - uses: actions/deploy-pages@v4 From 4f640e2485d24088345b3f2d894c696ef29e2923 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Sat, 21 Sep 
2024 23:18:31 +0800 Subject: [PATCH 046/250] [SPARK-49495][DOCS] Document and Feature Preview on the master branch via Live GitHub Pages Updates ### What changes were proposed in this pull request? This pull request introduces functionalities that enable 'Document and Feature Preview on the master branch via Live GitHub Pages Updates'. ### Why are the changes needed? retore 8861f0f9af3f397921ba1204cf4f76f4e20680bb 376382711e200aa978008b25630cc54271fd419b 58d73fe8e7cbff9878539d31430f819eff9fc7a1 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? https://github.com/yaooqinn/spark/actions/runs/10952355999 ### Was this patch authored or co-authored using generative AI tooling? no Closes #48175 from yaooqinn/SPARK-49495. Authored-by: Kent Yao Signed-off-by: Kent Yao --- .asf.yaml | 2 + .github/workflows/pages.yml | 97 +++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+) create mode 100644 .github/workflows/pages.yml diff --git a/.asf.yaml b/.asf.yaml index 22042b355b2fa..3935a525ff3c4 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -31,6 +31,8 @@ github: merge: false squash: true rebase: true + ghp_branch: master + ghp_path: /docs notifications: pullrequests: reviews@spark.apache.org diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml new file mode 100644 index 0000000000000..8faeb0557fbfb --- /dev/null +++ b/.github/workflows/pages.yml @@ -0,0 +1,97 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +name: GitHub Pages deployment + +on: + push: + branches: + - master + +concurrency: + group: 'docs preview' + cancel-in-progress: false + +jobs: + docs: + name: Build and deploy documentation + runs-on: ubuntu-latest + permissions: + id-token: write + pages: write + environment: + name: github-pages # https://github.com/actions/deploy-pages/issues/271 + env: + SPARK_TESTING: 1 # Reduce some noise in the logs + RELEASE_VERSION: 'In-Progress' + steps: + - name: Checkout Spark repository + uses: actions/checkout@v4 + with: + repository: apache/spark + ref: 'master' + - name: Install Java 17 + uses: actions/setup-java@v4 + with: + distribution: zulu + java-version: 17 + - name: Install Python 3.9 + uses: actions/setup-python@v5 + with: + python-version: '3.9' + architecture: x64 + cache: 'pip' + - name: Install Python dependencies + run: | + pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ + ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow 'pandas==2.2.2' 'plotly>=4.8' 'docutils<0.18.0' \ + 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ + 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ + 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' + - name: Install Ruby for documentation generation + uses: ruby/setup-ruby@v1 + with: + ruby-version: '3.3' + bundler-cache: true + - name: Install Pandoc + run: | + sudo apt-get update -y + sudo apt-get install pandoc + - name: Install dependencies for documentation generation + run: | + cd docs + gem install bundler -v 2.4.22 -n /usr/local/bin + bundle install --retry=100 + - name: Run documentation build + run: | + sed -i".tmp1" 's/SPARK_VERSION:.*$/SPARK_VERSION: '"$RELEASE_VERSION"'/g' docs/_config.yml + sed -i".tmp2" 's/SPARK_VERSION_SHORT:.*$/SPARK_VERSION_SHORT: '"$RELEASE_VERSION"'/g' docs/_config.yml + sed -i".tmp3" "s/'facetFilters':.*$/'facetFilters': [\"version:$RELEASE_VERSION\"]/g" docs/_config.yml + sed -i".tmp4" 's/__version__: str = .*$/__version__: str = "'"$RELEASE_VERSION"'"/' python/pyspark/version.py + cd docs + SKIP_RDOC=1 bundle exec jekyll build + - name: Setup Pages + uses: actions/configure-pages@v5 + - name: Upload artifact + uses: actions/upload-pages-artifact@v3 + with: + path: 'docs/_site' + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 From b6420969b5df2ba1f542e020c5773d1d107734e9 Mon Sep 17 00:00:00 2001 From: Tim Lee Date: Sun, 22 Sep 2024 14:40:41 +0900 Subject: [PATCH 047/250] [SPARK-49741][DOCS] Add `spark.shuffle.accurateBlockSkewedFactor` to config docs page ### What changes were proposed in this pull request? `spark.shuffle.accurateBlockSkewedFactor` was added in Spark 3.3.0 in https://issues.apache.org/jira/browse/SPARK-36967 and is a useful shuffle configuration to prevent issues where `HighlyCompressedMapStatus` wrongly estimates the shuffle block sizes when the block size distribution is skewed, which can cause the shuffle reducer to fetch too much data and OOM. This PR adds this config to the Spark config docs page to make it discoverable. ### Why are the changes needed? To make this useful config discoverable by users and make them able to resolve shuffle fetch OOM issues themselves. 
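For anyone acting on the new documentation, a minimal sketch of wiring the two factors together when building a session (the application name and the 5.0 value are illustrative, not defaults asserted by this patch; the same pair can equally be passed with `--conf` on `spark-submit`):

```
import org.apache.spark.sql.SparkSession

// Sketch only: record skewed shuffle block sizes accurately and keep the factor
// aligned with AQE skew-join detection, as the new docs recommend.
// spark.shuffle.accurateBlockSkewedFactor defaults to -1.0, which disables the feature.
val spark = SparkSession.builder()
  .appName("skew-aware-shuffle-example")
  .config("spark.shuffle.accurateBlockSkewedFactor", "5.0")
  .config("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5.0")
  .getOrCreate()
```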
### Does this PR introduce _any_ user-facing change? Yes, this is a documentation fix. Before this PR there's no `spark.sql.adaptive.skewJoin.skewedPartitionFactor` in the `Shuffle Behavior` section on [the Configurations page](https://spark.apache.org/docs/latest/configuration.html) and now there is. ### How was this patch tested? On the IDE: image Updated: image ### Was this patch authored or co-authored using generative AI tooling? No Closes #48189 from timlee0119/add-accurate-block-skewed-factor-to-doc. Authored-by: Tim Lee Signed-off-by: Hyukjin Kwon --- .../org/apache/spark/internal/config/package.scala | 1 - docs/configuration.md | 13 +++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 47019c04aada2..c5646d2956aeb 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1386,7 +1386,6 @@ package object config { private[spark] val SHUFFLE_ACCURATE_BLOCK_SKEWED_FACTOR = ConfigBuilder("spark.shuffle.accurateBlockSkewedFactor") - .internal() .doc("A shuffle block is considered as skewed and will be accurately recorded in " + "HighlyCompressedMapStatus if its size is larger than this factor multiplying " + "the median shuffle block size or SHUFFLE_ACCURATE_BLOCK_THRESHOLD. It is " + diff --git a/docs/configuration.md b/docs/configuration.md index 73d57b687ca2a..3c83ed92c1280 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1232,6 +1232,19 @@ Apart from these, the following properties are also available, and may be useful 2.2.1 + + spark.shuffle.accurateBlockSkewedFactor + -1.0 + + A shuffle block is considered as skewed and will be accurately recorded in + HighlyCompressedMapStatus if its size is larger than this factor multiplying + the median shuffle block size or spark.shuffle.accurateBlockThreshold. It is + recommended to set this parameter to be the same as + spark.sql.adaptive.skewJoin.skewedPartitionFactor. Set to -1.0 to disable this + feature by default. + + 3.3.0 + spark.shuffle.registration.timeout 5000 From 067f8f188eb22f9abe39eee0d70ad1ef73f4f644 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Sun, 22 Sep 2024 14:41:25 +0900 Subject: [PATCH 048/250] [SPARK-48355][SQL][TESTS][FOLLOWUP] Enable a SQL Scripting test in ANSI and non-ANSI modes ### What changes were proposed in this pull request? In the PR, I propose to enable the test which https://github.com/apache/spark/pull/48115 turned off, and run in the ANSI and non-ANSI modes. ### Why are the changes needed? To make this test stable, and don't depend on the default setting for ANSI mode. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? By running the modified test locally: ``` $ PYSPARK_PYTHON=python3 build/sbt "sql/testOnly org.apache.spark.sql.scripting.SqlScriptingInterpreterSuite" ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48194 from MaxGekk/enable-sqlscript-test-ansi. 
Authored-by: Max Gekk Signed-off-by: Hyukjin Kwon --- .../SqlScriptingInterpreterSuite.scala | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index bc2adec5be3d5..ac190eb48d1f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.{SparkException, SparkNumberFormatException} import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row} import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.exceptions.SqlScriptingException +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession /** @@ -701,8 +702,7 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { verifySqlScriptResult(commands, expected) } - // This is disabled because it fails in non-ANSI mode - ignore("simple case mismatched types") { + test("simple case mismatched types") { val commands = """ |BEGIN @@ -712,18 +712,26 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { | END CASE; |END |""".stripMargin - - checkError( - exception = intercept[SparkNumberFormatException] ( - runSqlScript(commands) - ), - condition = "CAST_INVALID_INPUT", - parameters = Map( - "expression" -> "'one'", - "sourceType" -> "\"STRING\"", - "targetType" -> "\"BIGINT\""), - context = ExpectedContext(fragment = "\"one\"", start = 23, stop = 27) - ) + withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { + checkError( + exception = intercept[SparkNumberFormatException]( + runSqlScript(commands) + ), + condition = "CAST_INVALID_INPUT", + parameters = Map( + "expression" -> "'one'", + "sourceType" -> "\"STRING\"", + "targetType" -> "\"BIGINT\""), + context = ExpectedContext(fragment = "\"one\"", start = 23, stop = 27)) + } + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + checkError( + exception = intercept[SqlScriptingException]( + runSqlScript(commands) + ), + condition = "BOOLEAN_STATEMENT_WITH_EMPTY_ROW", + parameters = Map("invalidStatement" -> "\"ONE\"")) + } } test("simple case compare with null") { From 719b57a32e0f36e7c425137014df2b83b7c4b029 Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Sun, 22 Sep 2024 14:29:24 -0700 Subject: [PATCH 049/250] [SPARK-49752][YARN] Remove workaround for YARN-3350 ### What changes were proposed in this pull request? Remove the logic of forcibly setting the log level to WARN for `org.apache.hadoop.yarn.util.RackResolver`. ### Why are the changes needed? The removed code was introduced in SPARK-5393 as a workaround for YARN-3350, which is already fixed on the YARN 2.8.0/3.0.0. ### Does this PR introduce _any_ user-facing change? Yes, previously, the log level of RackResolver is hardcoded as WARN even if the user explicitly sets it to DEBUG. ### How was this patch tested? Review. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48201 from pan3793/SPARK-49752. 
Authored-by: Cheng Pan Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/deploy/yarn/SparkRackResolver.scala | 9 --------- 1 file changed, 9 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala index 618f0dc8a4daa..d6e814f5c30a5 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/SparkRackResolver.scala @@ -25,9 +25,6 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.CommonConfigurationKeysPublic import org.apache.hadoop.net._ import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.yarn.util.RackResolver -import org.apache.logging.log4j.{Level, LogManager} -import org.apache.logging.log4j.core.Logger import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.NODE_LOCATION @@ -39,12 +36,6 @@ import org.apache.spark.internal.LogKeys.NODE_LOCATION */ private[spark] class SparkRackResolver(conf: Configuration) extends Logging { - // RackResolver logs an INFO message whenever it resolves a rack, which is way too often. - val logger = LogManager.getLogger(classOf[RackResolver]) - if (logger.getLevel != Level.WARN) { - logger.asInstanceOf[Logger].setLevel(Level.WARN) - } - private val dnsToSwitchMapping: DNSToSwitchMapping = { val dnsToSwitchMappingClass = conf.getClass(CommonConfigurationKeysPublic.NET_TOPOLOGY_NODE_SWITCH_MAPPING_IMPL_KEY, From 0eeb61fb64e0c499610c7b9a84f9e41e923251e8 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 23 Sep 2024 10:46:08 +0800 Subject: [PATCH 050/250] [SPARK-49734][PYTHON] Add `seed` argument for function `shuffle` ### What changes were proposed in this pull request? 1, Add `seed` argument for function `shuffle`; 2, Rewrite and enable the doctest by specify the seed and control the partitioning; ### Why are the changes needed? feature parity, seed is support in SQL side ### Does this PR introduce _any_ user-facing change? yes, new argument ### How was this patch tested? updated doctest ### Was this patch authored or co-authored using generative AI tooling? no Closes #48184 from zhengruifeng/py_func_shuffle. 
Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- .../pyspark/sql/connect/functions/builtin.py | 10 +-- python/pyspark/sql/functions/builtin.py | 69 ++++++++++--------- .../org/apache/spark/sql/functions.scala | 13 +++- 3 files changed, 53 insertions(+), 39 deletions(-) diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 7fed175cbc8ea..2a39bc6bfddda 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -65,7 +65,6 @@ from pyspark.sql.types import ( _from_numpy_type, DataType, - LongType, StructType, ArrayType, StringType, @@ -2206,12 +2205,9 @@ def schema_of_xml(xml: Union[str, Column], options: Optional[Mapping[str, str]] schema_of_xml.__doc__ = pysparkfuncs.schema_of_xml.__doc__ -def shuffle(col: "ColumnOrName") -> Column: - return _invoke_function( - "shuffle", - _to_col(col), - LiteralExpression(random.randint(0, sys.maxsize), LongType()), - ) +def shuffle(col: "ColumnOrName", seed: Optional[Union[Column, int]] = None) -> Column: + _seed = lit(random.randint(0, sys.maxsize)) if seed is None else lit(seed) + return _invoke_function("shuffle", _to_col(col), _seed) shuffle.__doc__ = pysparkfuncs.shuffle.__doc__ diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 5f8d1c21a24f1..2d5dbb5946050 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -17723,7 +17723,7 @@ def array_sort( @_try_remote_functions -def shuffle(col: "ColumnOrName") -> Column: +def shuffle(col: "ColumnOrName", seed: Optional[Union[Column, int]] = None) -> Column: """ Array function: Generates a random permutation of the given array. @@ -17736,6 +17736,10 @@ def shuffle(col: "ColumnOrName") -> Column: ---------- col : :class:`~pyspark.sql.Column` or str The name of the column or expression to be shuffled. + seed : :class:`~pyspark.sql.Column` or int, optional + Seed value for the random generator. + + .. 
versionadded:: 4.0.0 Returns ------- @@ -17752,48 +17756,51 @@ def shuffle(col: "ColumnOrName") -> Column: Example 1: Shuffling a simple array >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([([1, 20, 3, 5],)], ['data']) - >>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP - +-------------+ - |shuffle(data)| - +-------------+ - |[1, 3, 20, 5]| - +-------------+ + >>> df = spark.sql("SELECT ARRAY(1, 20, 3, 5) AS data") + >>> df.select("*", sf.shuffle(df.data, sf.lit(123))).show() + +-------------+-------------+ + | data|shuffle(data)| + +-------------+-------------+ + |[1, 20, 3, 5]|[5, 1, 20, 3]| + +-------------+-------------+ Example 2: Shuffling an array with null values >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([([1, 20, None, 3],)], ['data']) - >>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP - +----------------+ - | shuffle(data)| - +----------------+ - |[20, 3, NULL, 1]| - +----------------+ + >>> df = spark.sql("SELECT ARRAY(1, 20, NULL, 5) AS data") + >>> df.select("*", sf.shuffle(sf.col("data"), 234)).show() + +----------------+----------------+ + | data| shuffle(data)| + +----------------+----------------+ + |[1, 20, NULL, 5]|[NULL, 5, 20, 1]| + +----------------+----------------+ Example 3: Shuffling an array with duplicate values >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([([1, 2, 2, 3, 3, 3],)], ['data']) - >>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP - +------------------+ - | shuffle(data)| - +------------------+ - |[3, 2, 1, 3, 2, 3]| - +------------------+ + >>> df = spark.sql("SELECT ARRAY(1, 2, 2, 3, 3, 3) AS data") + >>> df.select("*", sf.shuffle("data", 345)).show() + +------------------+------------------+ + | data| shuffle(data)| + +------------------+------------------+ + |[1, 2, 2, 3, 3, 3]|[2, 3, 3, 1, 2, 3]| + +------------------+------------------+ - Example 4: Shuffling an array with different types of elements + Example 4: Shuffling an array with random seed >>> import pyspark.sql.functions as sf - >>> df = spark.createDataFrame([(['a', 'b', 'c', 1, 2, 3],)], ['data']) - >>> df.select(sf.shuffle(df.data)).show() # doctest: +SKIP - +------------------+ - | shuffle(data)| - +------------------+ - |[1, c, 2, a, b, 3]| - +------------------+ + >>> df = spark.sql("SELECT ARRAY(1, 2, 2, 3, 3, 3) AS data") + >>> df.select("*", sf.shuffle("data")).show() # doctest: +SKIP + +------------------+------------------+ + | data| shuffle(data)| + +------------------+------------------+ + |[1, 2, 2, 3, 3, 3]|[3, 3, 2, 3, 2, 1]| + +------------------+------------------+ """ - return _invoke_function_over_columns("shuffle", col) + if seed is not None: + return _invoke_function_over_columns("shuffle", col, lit(seed)) + else: + return _invoke_function_over_columns("shuffle", col) @_try_remote_functions diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 0662b8f2b271f..d9bceabe88f8f 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -7252,7 +7252,18 @@ object functions { * @group array_funcs * @since 2.4.0 */ - def shuffle(e: Column): Column = Column.fn("shuffle", e, lit(SparkClassUtils.random.nextLong)) + def shuffle(e: Column): Column = shuffle(e, lit(SparkClassUtils.random.nextLong)) + + /** + * Returns a random permutation of the given array. 
+ * + * @note + * The function is non-deterministic. + * + * @group array_funcs + * @since 4.0.0 + */ + def shuffle(e: Column, seed: Column): Column = Column.fn("shuffle", e, seed) /** * Returns a reversed string or an array with reverse order of elements. From 3c81f076ab9c72514cfc8372edd16e6da7c151d6 Mon Sep 17 00:00:00 2001 From: Andrey Gubichev Date: Mon, 23 Sep 2024 10:58:13 +0800 Subject: [PATCH 051/250] [SPARK-49653][SQL] Single join for correlated scalar subqueries ### What changes were proposed in this pull request? A single join is a left outer join that checks that there is at most 1 build row for every probe row. This PR adds a single join implementation to support correlated scalar subqueries where the optimizer can't guarantee that only 1 row is returned from them, e.g.: select *, (select t1.x from t1 where t1.y >= t_outer.y) from t_outer. -- this subquery is rewritten as a single join that makes sure there is at most 1 matching build row for every probe row. It will issue a Spark runtime error otherwise. Design doc: https://docs.google.com/document/d/1NTsvtBTB9XvvyRvH62QzWIZuw4hXktALUG1fBP7ha1Q/edit The optimizer introduces a single join in cases that previously returned incorrect results (or were unsupported). Only hash-based implementations are supported; the optimizer makes sure we don't plan a single join as a sort-merge join. ### Why are the changes needed? Expands our subquery coverage. ### Does this PR introduce _any_ user-facing change? Yes, previously unsupported scalar subqueries should now work. ### How was this patch tested? Unit tests for the single join operator. Query tests for the subqueries. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48145 from agubichev/single_join. Authored-by: Andrey Gubichev Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 36 ++-- .../sql/catalyst/expressions/subquery.scala | 22 +- .../sql/catalyst/optimizer/Optimizer.scala | 9 +- .../sql/catalyst/optimizer/expressions.scala | 4 +- .../spark/sql/catalyst/optimizer/joins.scala | 8 +- .../sql/catalyst/optimizer/subquery.scala | 50 ++++- .../spark/sql/catalyst/plans/joinTypes.scala | 4 + .../plans/logical/basicLogicalOperators.scala | 10 +- .../sql/errors/QueryExecutionErrors.scala | 6 + .../apache/spark/sql/internal/SQLConf.scala | 9 + .../spark/sql/execution/SparkStrategies.scala | 11 +- .../adaptive/PlanAdaptiveSubqueries.scala | 2 +- .../joins/BroadcastNestedLoopJoinExec.scala | 44 +++- .../spark/sql/execution/joins/HashJoin.scala | 29 ++- .../sql/execution/joins/ShuffledJoin.scala | 6 +- .../scalar-subquery-group-by.sql.out | 111 ++++++++-- .../scalar-subquery-predicate.sql.out | 18 ++ .../scalar-subquery-group-by.sql | 11 +- .../scalar-subquery-predicate.sql | 3 + .../scalar-subquery-group-by.sql.out | 83 +++++-- .../scalar-subquery-predicate.sql.out | 8 + .../spark/sql/LateralColumnAliasSuite.scala | 11 +- .../org/apache/spark/sql/SubquerySuite.scala | 44 ++-- .../sql/execution/joins/SingleJoinSuite.scala | 204 ++++++++++++++++++ 25 files changed, 613 insertions(+), 132 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SingleJoinSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9e5b1d1254c87..b2e9115dd512f 100644 ---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2716,7 +2716,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor */ private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): LogicalPlan = { plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), ruleId) { - case s @ ScalarSubquery(sub, _, exprId, _, _, _) if !sub.resolved => + case s @ ScalarSubquery(sub, _, exprId, _, _, _, _) if !sub.resolved => resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId)) case e @ Exists(sub, _, exprId, _, _) if !sub.resolved => resolveSubQuery(e, outer)(Exists(_, _, exprId)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 5a9d5cd87ecc7..b600f455f16ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -952,19 +952,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB messageParameters = Map.empty) } - // SPARK-18504/SPARK-18814: Block cases where GROUP BY columns - // are not part of the correlated columns. - - // Collect the inner query expressions that are guaranteed to have a single value for each - // outer row. See comment on getCorrelatedEquivalentInnerExpressions. - val correlatedEquivalentExprs = getCorrelatedEquivalentInnerExpressions(query) - // Grouping expressions, except outer refs and constant expressions - grouping by an - // outer ref or a constant is always ok - val groupByExprs = - ExpressionSet(agg.groupingExpressions.filter(x => !x.isInstanceOf[OuterReference] && - x.references.nonEmpty)) - val nonEquivalentGroupByExprs = groupByExprs -- correlatedEquivalentExprs - + val nonEquivalentGroupByExprs = nonEquivalentGroupbyCols(query, agg) val invalidCols = if (!SQLConf.get.getConf( SQLConf.LEGACY_SCALAR_SUBQUERY_ALLOW_GROUP_BY_NON_EQUALITY_CORRELATED_PREDICATE)) { nonEquivalentGroupByExprs @@ -1044,7 +1032,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB checkOuterReference(plan, expr) expr match { - case ScalarSubquery(query, outerAttrs, _, _, _, _) => + case ScalarSubquery(query, outerAttrs, _, _, _, _, _) => // Scalar subquery must return one column as output. if (query.output.size != 1) { throw QueryCompilationErrors.subqueryReturnMoreThanOneColumn(query.output.size, @@ -1052,15 +1040,17 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB } if (outerAttrs.nonEmpty) { - cleanQueryInScalarSubquery(query) match { - case a: Aggregate => checkAggregateInScalarSubquery(outerAttrs, query, a) - case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(outerAttrs, query, a) - case p: LogicalPlan if p.maxRows.exists(_ <= 1) => // Ok - case other => - expr.failAnalysis( - errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." 
+ - "MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY", - messageParameters = Map.empty) + if (!SQLConf.get.getConf(SQLConf.SCALAR_SUBQUERY_USE_SINGLE_JOIN)) { + cleanQueryInScalarSubquery(query) match { + case a: Aggregate => checkAggregateInScalarSubquery(outerAttrs, query, a) + case Filter(_, a: Aggregate) => checkAggregateInScalarSubquery(outerAttrs, query, a) + case p: LogicalPlan if p.maxRows.exists(_ <= 1) => // Ok + case other => + expr.failAnalysis( + errorClass = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + "MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY", + messageParameters = Map.empty) + } } // Only certain operators are allowed to host subquery expression containing diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 174d32c73fc01..0c8253659dd56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -358,6 +358,20 @@ object SubExprUtils extends PredicateHelper { case _ => ExpressionSet().empty } } + + // Returns grouping expressions of 'aggNode' of a scalar subquery that do not have equivalent + // columns in the outer query (bound by equality predicates like 'col = outer(c)'). + // We use it to analyze whether a scalar subquery is guaranteed to return at most 1 row. + def nonEquivalentGroupbyCols(query: LogicalPlan, aggNode: Aggregate): ExpressionSet = { + val correlatedEquivalentExprs = getCorrelatedEquivalentInnerExpressions(query) + // Grouping expressions, except outer refs and constant expressions - grouping by an + // outer ref or a constant is always ok + val groupByExprs = + ExpressionSet(aggNode.groupingExpressions.filter(x => !x.isInstanceOf[OuterReference] && + x.references.nonEmpty)) + val nonEquivalentGroupByExprs = groupByExprs -- correlatedEquivalentExprs + nonEquivalentGroupByExprs + } } /** @@ -371,6 +385,11 @@ object SubExprUtils extends PredicateHelper { * case the subquery yields no row at all on empty input to the GROUP BY, which evaluates to NULL. * It is set in PullupCorrelatedPredicates to true/false, before it is set its value is None. * See constructLeftJoins in RewriteCorrelatedScalarSubquery for more details. + * + * 'needSingleJoin' is set to true if we can't guarantee that the correlated scalar subquery + * returns at most 1 row. For such subqueries we use a modification of an outer join called + * LeftSingle join. This value is set in PullupCorrelatedPredicates and used in + * RewriteCorrelatedScalarSubquery. 
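To make the behavior described in this new `needSingleJoin` documentation concrete, here is a minimal sketch of the user-visible effect. The temp views, column names, and data are invented for the example; the error condition matches the `SCALAR_SUBQUERY_TOO_MANY_ROWS` results in the golden files and test suites updated later in this patch.

```
// Sketch only, assuming a Spark session named `spark` with this patch applied.
import org.apache.spark.SparkRuntimeException

spark.sql("CREATE OR REPLACE TEMP VIEW x AS SELECT * FROM VALUES (1, 1), (2, 2) AS t(x1, x2)")
spark.sql("CREATE OR REPLACE TEMP VIEW y AS SELECT * FROM VALUES (1, 10), (1, 20) AS t(y1, y2)")

// The correlated scalar subquery is not guaranteed to return at most one row
// per outer row, so it is rewritten into a LeftSingle join. The probe row with
// x1 = 1 matches two build rows, so the query fails at runtime rather than
// being rejected at analysis time as before.
try {
  spark.sql("SELECT *, (SELECT y2 FROM y WHERE y1 = x1) FROM x").collect()
} catch {
  case e: SparkRuntimeException =>
    // The reported error condition is SCALAR_SUBQUERY_TOO_MANY_ROWS.
    println(e.getMessage)
}
```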
*/ case class ScalarSubquery( plan: LogicalPlan, @@ -378,7 +397,8 @@ case class ScalarSubquery( exprId: ExprId = NamedExpression.newExprId, joinCond: Seq[Expression] = Seq.empty, hint: Option[HintInfo] = None, - mayHaveCountBug: Option[Boolean] = None) + mayHaveCountBug: Option[Boolean] = None, + needSingleJoin: Option[Boolean] = None) extends SubqueryExpression(plan, outerAttrs, exprId, joinCond, hint) with Unevaluable { override def dataType: DataType = { if (!plan.schema.fields.nonEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 8e14537c6a5b4..7fc12f7d1fc16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -338,7 +338,7 @@ abstract class Optimizer(catalogManager: CatalogManager) case d: DynamicPruningSubquery => d case s @ ScalarSubquery( PhysicalOperation(projections, predicates, a @ Aggregate(group, _, child)), - _, _, _, _, mayHaveCountBug) + _, _, _, _, mayHaveCountBug, _) if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) && mayHaveCountBug.nonEmpty && mayHaveCountBug.get => // This is a subquery with an aggregate that may suffer from a COUNT bug. @@ -1988,7 +1988,8 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { } private def canPushThrough(joinType: JoinType): Boolean = joinType match { - case _: InnerLike | LeftSemi | RightOuter | LeftOuter | LeftAnti | ExistenceJoin(_) => true + case _: InnerLike | LeftSemi | RightOuter | LeftOuter | LeftSingle | + LeftAnti | ExistenceJoin(_) => true case _ => false } @@ -2028,7 +2029,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { (leftFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) - case LeftOuter | LeftExistence(_) => + case LeftOuter | LeftSingle | LeftExistence(_) => // push down the left side only `where` condition val newLeft = leftFilterConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) @@ -2074,6 +2075,8 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) Join(newLeft, newRight, joinType, newJoinCond, hint) + // Do not move join predicates of a single join. + case LeftSingle => j case other => throw SparkException.internalError(s"Unexpected join type: $other") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 3cdde622d51f7..1601d798283c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -90,7 +90,7 @@ object ConstantFolding extends Rule[LogicalPlan] { } // Don't replace ScalarSubquery if its plan is an aggregate that may suffer from a COUNT bug. 
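The LeftSingle rewrite threaded through these optimizer changes is gated by a new internal flag, `spark.sql.optimizer.scalarSubqueryUseSingleJoin`, added to SQLConf further down in this patch and enabled by default. A brief sketch of falling back to the legacy behavior, which the updated `.sql` test inputs also exercise via `SET`:

```
// Sketch: with the flag off, a correlated scalar subquery that may return more
// than one row (such as the one in the previous sketch) is rejected at analysis
// time again, e.g. with MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY, instead of
// being planned as a LeftSingle join.
spark.conf.set("spark.sql.optimizer.scalarSubqueryUseSingleJoin", "false")

// Restore the new default behavior.
spark.conf.set("spark.sql.optimizer.scalarSubqueryUseSingleJoin", "true")
```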
- case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug) + case s @ ScalarSubquery(_, _, _, _, _, mayHaveCountBug, _) if conf.getConf(SQLConf.DECORRELATE_SUBQUERY_PREVENT_CONSTANT_FOLDING_FOR_COUNT_BUG) && mayHaveCountBug.nonEmpty && mayHaveCountBug.get => s @@ -1007,7 +1007,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { replaceFoldable(j.withNewChildren(newChildren).asInstanceOf[Join], foldableMap) val missDerivedAttrsSet: AttributeSet = AttributeSet(newJoin.joinType match { case _: InnerLike | LeftExistence(_) => Nil - case LeftOuter => newJoin.right.output + case LeftOuter | LeftSingle => newJoin.right.output case RightOuter => newJoin.left.output case FullOuter => newJoin.left.output ++ newJoin.right.output case _ => Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 9fc4873c248b5..6802adaa2ea24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -339,8 +339,8 @@ trait JoinSelectionHelper extends Logging { ) } - def getBroadcastNestedLoopJoinBuildSide(hint: JoinHint): Option[BuildSide] = { - if (hintToNotBroadcastAndReplicateLeft(hint)) { + def getBroadcastNestedLoopJoinBuildSide(hint: JoinHint, joinType: JoinType): Option[BuildSide] = { + if (hintToNotBroadcastAndReplicateLeft(hint) || joinType == LeftSingle) { Some(BuildRight) } else if (hintToNotBroadcastAndReplicateRight(hint)) { Some(BuildLeft) @@ -375,7 +375,7 @@ trait JoinSelectionHelper extends Logging { def canBuildBroadcastRight(joinType: JoinType): Boolean = { joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true + case _: InnerLike | LeftOuter | LeftSingle | LeftSemi | LeftAnti | _: ExistenceJoin => true case _ => false } } @@ -389,7 +389,7 @@ trait JoinSelectionHelper extends Logging { def canBuildShuffledHashJoinRight(joinType: JoinType): Boolean = { joinType match { - case _: InnerLike | LeftOuter | FullOuter | RightOuter | + case _: InnerLike | LeftOuter | LeftSingle | FullOuter | RightOuter | LeftSemi | LeftAnti | _: ExistenceJoin => true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 1239a5dde1302..d9795cf338279 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -456,6 +456,31 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper (newPlan, newCond) } + // Returns true if 'query' is guaranteed to return at most 1 row. 
+ private def guaranteedToReturnOneRow(query: LogicalPlan): Boolean = { + if (query.maxRows.exists(_ <= 1)) { + return true + } + val aggNode = query match { + case havingPart@Filter(_, aggPart: Aggregate) => Some(aggPart) + case aggPart: Aggregate => Some(aggPart) + // LIMIT 1 is handled above, this is for all other types of LIMITs + case Limit(_, aggPart: Aggregate) => Some(aggPart) + case Project(_, aggPart: Aggregate) => Some(aggPart) + case _: LogicalPlan => None + } + if (!aggNode.isDefined) { + return false + } + val aggregates = aggNode.get.expressions.flatMap(_.collect { + case a: AggregateExpression => a + }) + if (aggregates.isEmpty) { + return false + } + nonEquivalentGroupbyCols(query, aggNode.get).isEmpty + } + private def rewriteSubQueries(plan: LogicalPlan): LogicalPlan = { /** * This function is used as a aid to enforce idempotency of pullUpCorrelatedPredicate rule. @@ -481,7 +506,8 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper } plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) { - case ScalarSubquery(sub, children, exprId, conditions, hint, mayHaveCountBugOld) + case ScalarSubquery(sub, children, exprId, conditions, hint, + mayHaveCountBugOld, needSingleJoinOld) if children.nonEmpty => def mayHaveCountBugAgg(a: Aggregate): Boolean = { @@ -527,8 +553,13 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper val (topPart, havingNode, aggNode) = splitSubquery(sub) (aggNode.isDefined && aggNode.get.groupingExpressions.isEmpty) } + val needSingleJoin = if (needSingleJoinOld.isDefined) { + needSingleJoinOld.get + } else { + conf.getConf(SQLConf.SCALAR_SUBQUERY_USE_SINGLE_JOIN) && !guaranteedToReturnOneRow(sub) + } ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, conditions), - hint, Some(mayHaveCountBug)) + hint, Some(mayHaveCountBug), Some(needSingleJoin)) case Exists(sub, children, exprId, conditions, hint) if children.nonEmpty => val (newPlan, newCond) = if (SQLConf.get.decorrelateInnerQueryEnabledForExistsIn) { decorrelate(sub, plan, handleCountBug = true) @@ -786,7 +817,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = { val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]() val newChild = subqueries.foldLeft(child) { - case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug)) => + case (currentChild, ScalarSubquery(sub, _, _, conditions, subHint, mayHaveCountBug, + needSingleJoin)) => val query = DecorrelateInnerQuery.rewriteDomainJoins(currentChild, sub, conditions) val origOutput = query.output.head // The subquery appears on the right side of the join, hence add its hint to the right @@ -794,9 +826,13 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe val joinHint = JoinHint(None, subHint) val resultWithZeroTups = evalSubqueryOnZeroTups(query) + val joinType = needSingleJoin match { + case Some(true) => LeftSingle + case _ => LeftOuter + } lazy val planWithoutCountBug = Project( currentChild.output :+ origOutput, - Join(currentChild, query, LeftOuter, conditions.reduceOption(And), joinHint)) + Join(currentChild, query, joinType, conditions.reduceOption(And), joinHint)) if (Utils.isTesting) { assert(mayHaveCountBug.isDefined) @@ -845,7 +881,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe currentChild.output :+ subqueryResultExpr, 
Join(currentChild, Project(query.output :+ alwaysTrueExpr, query), - LeftOuter, conditions.reduceOption(And), joinHint)) + joinType, conditions.reduceOption(And), joinHint)) } else { // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. @@ -877,7 +913,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe currentChild.output :+ caseExpr, Join(currentChild, Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), - LeftOuter, conditions.reduceOption(And), joinHint)) + joinType, conditions.reduceOption(And), joinHint)) } } } @@ -1028,7 +1064,7 @@ object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] { case p: LogicalPlan => p.transformExpressionsUpWithPruning( _.containsPattern(SCALAR_SUBQUERY)) { - case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _) + case s @ ScalarSubquery(OneRowSubquery(p @ Project(_, _: OneRowRelation)), _, _, _, _, _, _) if !hasCorrelatedSubquery(s.plan) && s.joinCond.isEmpty => assert(p.projectList.size == 1) stripOuterReferences(p.projectList).head diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index d9da255eccc9d..41bba99673a2b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -95,6 +95,10 @@ case object LeftAnti extends JoinType { override def sql: String = "LEFT ANTI" } +case object LeftSingle extends JoinType { + override def sql: String = "LEFT SINGLE" +} + case class ExistenceJoin(exists: Attribute) extends JoinType { override def sql: String = { // This join type is only used in the end of optimizer and physical plans, we will not diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 90af6333b2e0b..7c549a32aca0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -559,12 +559,12 @@ case class Join( override def maxRows: Option[Long] = { joinType match { - case Inner | Cross | FullOuter | LeftOuter | RightOuter + case Inner | Cross | FullOuter | LeftOuter | RightOuter | LeftSingle if left.maxRows.isDefined && right.maxRows.isDefined => val leftMaxRows = BigInt(left.maxRows.get) val rightMaxRows = BigInt(right.maxRows.get) val minRows = joinType match { - case LeftOuter => leftMaxRows + case LeftOuter | LeftSingle => leftMaxRows case RightOuter => rightMaxRows case FullOuter => leftMaxRows + rightMaxRows case _ => BigInt(0) @@ -590,7 +590,7 @@ case class Join( left.output :+ j.exists case LeftExistence(_) => left.output - case LeftOuter => + case LeftOuter | LeftSingle => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output @@ -627,7 +627,7 @@ case class Join( left.constraints.union(right.constraints) case LeftExistence(_) => left.constraints - case LeftOuter => + case LeftOuter | LeftSingle => left.constraints case RightOuter => right.constraints @@ -659,7 +659,7 @@ case class Join( var patterns = Seq(JOIN) joinType match { case _: InnerLike => patterns = patterns :+ INNER_LIKE_JOIN - case 
LeftOuter | FullOuter | RightOuter => patterns = patterns :+ OUTER_JOIN + case LeftOuter | FullOuter | RightOuter | LeftSingle => patterns = patterns :+ OUTER_JOIN case LeftSemiOrAnti(_) => patterns = patterns :+ LEFT_SEMI_OR_ANTI_JOIN case NaturalJoin(_) | UsingJoin(_, _) => patterns = patterns :+ NATURAL_LIKE_JOIN case _ => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 4bc071155012b..4a23e9766fc5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2477,6 +2477,12 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE summary = getSummary(context)) } + def scalarSubqueryReturnsMultipleRows(): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "SCALAR_SUBQUERY_TOO_MANY_ROWS", + messageParameters = Map.empty) + } + def comparatorReturnsNull(firstValue: String, secondValue: String): Throwable = { new SparkException( errorClass = "COMPARATOR_RETURNS_NULL", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6c3e9bac1cfe5..4d0930212b373 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -5090,6 +5090,15 @@ object SQLConf { .booleanConf .createWithDefault(true) + val SCALAR_SUBQUERY_USE_SINGLE_JOIN = + buildConf("spark.sql.optimizer.scalarSubqueryUseSingleJoin") + .internal() + .doc("When set to true, use LEFT_SINGLE join for correlated scalar subqueries where " + + "optimizer can't prove that only 1 row will be returned") + .version("4.0.0") + .booleanConf + .createWithDefault(true) + val ALLOW_SUBQUERY_EXPRESSIONS_IN_LAMBDAS_AND_HIGHER_ORDER_FUNCTIONS = buildConf("spark.sql.analyzer.allowSubqueryExpressionsInLambdasOrHigherOrderFunctions") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index aee735e48fc5c..53c335c1eced6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -269,8 +269,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + def canMerge(joinType: JoinType): Boolean = joinType match { + case LeftSingle => false + case _ => true + } + def createSortMergeJoin() = { - if (RowOrdering.isOrderable(leftKeys)) { + if (canMerge(joinType) && RowOrdering.isOrderable(leftKeys)) { Some(Seq(joins.SortMergeJoinExec( leftKeys, rightKeys, joinType, nonEquiCond, planLater(left), planLater(right)))) } else { @@ -297,7 +302,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // This join could be very slow or OOM // Build the smaller side unless the join requires a particular build side // (e.g. 
NO_BROADCAST_AND_REPLICATION hint) - val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint) + val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint, joinType) val buildSide = requiredBuildSide.getOrElse(getSmallerSide(left, right)) Seq(joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, j.condition)) @@ -390,7 +395,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // This join could be very slow or OOM // Build the desired side unless the join requires a particular build side // (e.g. NO_BROADCAST_AND_REPLICATION hint) - val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint) + val requiredBuildSide = getBroadcastNestedLoopJoinBuildSide(hint, joinType) val buildSide = requiredBuildSide.getOrElse(desiredBuildSide) Seq(joins.BroadcastNestedLoopJoinExec( planLater(left), planLater(right), buildSide, joinType, condition)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala index df4d895867586..5f2638655c37c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala @@ -30,7 +30,7 @@ case class PlanAdaptiveSubqueries( def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressionsWithPruning( _.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) { - case expressions.ScalarSubquery(_, _, exprId, _, _, _) => + case expressions.ScalarSubquery(_, _, exprId, _, _, _, _) => val subquery = SubqueryExec.createForScalarSubquery( s"subquery#${exprId.id}", subqueryMap(exprId.id)) execution.ScalarSubquery(subquery, exprId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala index 6dd41aca3a5e1..a7292ee1f8fa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoinExec.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ArrayImplicits._ @@ -63,13 +64,15 @@ case class BroadcastNestedLoopJoinExec( override def outputPartitioning: Partitioning = (joinType, buildSide) match { case (_: InnerLike, _) | (LeftOuter, BuildRight) | (RightOuter, BuildLeft) | - (LeftSemi, BuildRight) | (LeftAnti, BuildRight) => streamed.outputPartitioning + (LeftSingle, BuildRight) | (LeftSemi, BuildRight) | (LeftAnti, BuildRight) => + streamed.outputPartitioning case _ => super.outputPartitioning } override def outputOrdering: Seq[SortOrder] = (joinType, buildSide) match { case (_: InnerLike, _) | (LeftOuter, BuildRight) | (RightOuter, BuildLeft) | - (LeftSemi, BuildRight) | (LeftAnti, BuildRight) => streamed.outputOrdering + (LeftSingle, BuildRight) | (LeftSemi, BuildRight) | (LeftAnti, 
BuildRight) => + streamed.outputOrdering case _ => Nil } @@ -87,7 +90,7 @@ case class BroadcastNestedLoopJoinExec( joinType match { case _: InnerLike => left.output ++ right.output - case LeftOuter => + case LeftOuter | LeftSingle => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output @@ -135,8 +138,14 @@ case class BroadcastNestedLoopJoinExec( * * LeftOuter with BuildRight * RightOuter with BuildLeft + * LeftSingle with BuildRight + * + * For the (LeftSingle, BuildRight) case we pass 'singleJoin' flag that + * makes sure there is at most 1 matching build row per every probe tuple. */ - private def outerJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + private def outerJoin( + relation: Broadcast[Array[InternalRow]], + singleJoin: Boolean = false): RDD[InternalRow] = { streamed.execute().mapPartitionsInternal { streamedIter => val buildRows = relation.value val joinedRow = new JoinedRow @@ -167,6 +176,9 @@ case class BroadcastNestedLoopJoinExec( resultRow = joinedRow(streamRow, buildRows(nextIndex)) nextIndex += 1 if (boundCondition(resultRow)) { + if (foundMatch && singleJoin) { + throw QueryExecutionErrors.scalarSubqueryReturnsMultipleRows(); + } foundMatch = true return true } @@ -382,12 +394,18 @@ case class BroadcastNestedLoopJoinExec( innerJoin(broadcastedRelation) case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => outerJoin(broadcastedRelation) + case (LeftSingle, BuildRight) => + outerJoin(broadcastedRelation, singleJoin = true) case (LeftSemi, _) => leftExistenceJoin(broadcastedRelation, exists = true) case (LeftAnti, _) => leftExistenceJoin(broadcastedRelation, exists = false) case (_: ExistenceJoin, _) => existenceJoin(broadcastedRelation) + case (LeftSingle, BuildLeft) => + throw new IllegalArgumentException( + s"BroadcastNestedLoopJoin should not use the left side as build when " + + s"executing a LeftSingle join") case _ => /** * LeftOuter with BuildLeft @@ -410,7 +428,7 @@ case class BroadcastNestedLoopJoinExec( override def supportCodegen: Boolean = (joinType, buildSide) match { case (_: InnerLike, _) | (LeftOuter, BuildRight) | (RightOuter, BuildLeft) | - (LeftSemi | LeftAnti, BuildRight) => true + (LeftSemi | LeftAnti, BuildRight) | (LeftSingle, BuildRight) => true case _ => false } @@ -428,6 +446,7 @@ case class BroadcastNestedLoopJoinExec( (joinType, buildSide) match { case (_: InnerLike, _) => codegenInner(ctx, input) case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => codegenOuter(ctx, input) + case (LeftSingle, BuildRight) => codegenOuter(ctx, input) case (LeftSemi, BuildRight) => codegenLeftExistence(ctx, input, exists = true) case (LeftAnti, BuildRight) => codegenLeftExistence(ctx, input, exists = false) case _ => @@ -473,7 +492,9 @@ case class BroadcastNestedLoopJoinExec( """.stripMargin } - private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { + private def codegenOuter( + ctx: CodegenContext, + input: Seq[ExprCode]): String = { val (buildRowArray, buildRowArrayTerm) = prepareBroadcast(ctx) val (buildRow, checkCondition, _) = getJoinCondition(ctx, input, streamed, broadcast) val buildVars = genOneSideJoinVars(ctx, buildRow, broadcast, setDefaultValue = true) @@ -494,12 +515,23 @@ case class BroadcastNestedLoopJoinExec( |${consume(ctx, resultVars)} """.stripMargin } else { + // For LeftSingle joins, generate the check on the number of matches. 
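Before the generated-code version of this check below, here is a plain-Scala sketch of the invariant that all four LeftSingle paths (the interpreted and codegen variants of both BroadcastNestedLoopJoinExec and HashJoin) enforce per probe row. It is a conceptual illustration, not the actual Spark code; the helper name and exception type are invented, and the real implementation reports the failure through QueryExecutionErrors.scalarSubqueryReturnsMultipleRows().

```
// Conceptual sketch: at most one build row may match a given probe row.
// `matchingBuildRows` stands for the build rows that already passed the join
// condition for that probe row.
def matchAtMostOneBuildRow[T](matchingBuildRows: Iterator[T]): Option[T] = {
  var matched: Option[T] = None
  while (matchingBuildRows.hasNext) {
    val buildRow = matchingBuildRows.next()
    if (matched.isDefined) {
      // A second match means the correlated scalar subquery yields more than
      // one row for this probe row.
      throw new IllegalStateException("SCALAR_SUBQUERY_TOO_MANY_ROWS")
    }
    matched = Some(buildRow)
  }
  // No match keeps outer-join semantics: the probe row is padded with nulls.
  matched
}
```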
+ val evaluateSingleCheck = if (joinType == LeftSingle) { + s""" + |if ($foundMatch) { + | throw QueryExecutionErrors.scalarSubqueryReturnsMultipleRows(); + |} + |""".stripMargin + } else { + "" + } s""" |boolean $foundMatch = false; |for (int $arrayIndex = 0; $arrayIndex < $buildRowArrayTerm.length; $arrayIndex++) { | UnsafeRow $buildRow = (UnsafeRow) $buildRowArrayTerm[$arrayIndex]; | boolean $shouldOutputRow = false; | $checkCondition { + | $evaluateSingleCheck | $shouldOutputRow = true; | $foundMatch = true; | } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 5d59a48d544a0..ce7d48babc91e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, RowIterator} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.{BooleanType, IntegralType, LongType} @@ -52,7 +53,7 @@ trait HashJoin extends JoinCodegenSupport { joinType match { case _: InnerLike => left.output ++ right.output - case LeftOuter => + case LeftOuter | LeftSingle => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output @@ -75,7 +76,7 @@ trait HashJoin extends JoinCodegenSupport { } case BuildRight => joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => + case _: InnerLike | LeftOuter | LeftSingle | LeftSemi | LeftAnti | _: ExistenceJoin => left.outputPartitioning case x => throw new IllegalArgumentException( @@ -93,7 +94,7 @@ trait HashJoin extends JoinCodegenSupport { } case BuildRight => joinType match { - case _: InnerLike | LeftOuter | LeftSemi | LeftAnti | _: ExistenceJoin => + case _: InnerLike | LeftOuter | LeftSingle | LeftSemi | LeftAnti | _: ExistenceJoin => left.outputOrdering case x => throw new IllegalArgumentException( @@ -191,7 +192,8 @@ trait HashJoin extends JoinCodegenSupport { private def outerJoin( streamedIter: Iterator[InternalRow], - hashedRelation: HashedRelation): Iterator[InternalRow] = { + hashedRelation: HashedRelation, + singleJoin: Boolean = false): Iterator[InternalRow] = { val joinedRow = new JoinedRow() val keyGenerator = streamSideKeyGenerator() val nullRow = new GenericInternalRow(buildPlan.output.length) @@ -218,6 +220,9 @@ trait HashJoin extends JoinCodegenSupport { while (buildIter != null && buildIter.hasNext) { val nextBuildRow = buildIter.next() if (boundCondition(joinedRow.withRight(nextBuildRow))) { + if (found && singleJoin) { + throw QueryExecutionErrors.scalarSubqueryReturnsMultipleRows(); + } found = true return true } @@ -329,6 +334,8 @@ trait HashJoin extends JoinCodegenSupport { innerJoin(streamedIter, hashed) case LeftOuter | RightOuter => outerJoin(streamedIter, hashed) + case LeftSingle => + outerJoin(streamedIter, hashed, singleJoin = true) case LeftSemi => semiJoin(streamedIter, hashed) case LeftAnti => @@ -354,7 +361,7 @@ trait HashJoin extends JoinCodegenSupport { override def doConsume(ctx: 
CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { joinType match { case _: InnerLike => codegenInner(ctx, input) - case LeftOuter | RightOuter => codegenOuter(ctx, input) + case LeftOuter | RightOuter | LeftSingle => codegenOuter(ctx, input) case LeftSemi => codegenSemi(ctx, input) case LeftAnti => codegenAnti(ctx, input) case _: ExistenceJoin => codegenExistence(ctx, input) @@ -492,6 +499,17 @@ trait HashJoin extends JoinCodegenSupport { val matches = ctx.freshName("matches") val iteratorCls = classOf[Iterator[UnsafeRow]].getName val found = ctx.freshName("found") + // For LeftSingle joins generate the check on the number of build rows that match every + // probe row. Return an error for >1 matches. + val evaluateSingleCheck = if (joinType == LeftSingle) { + s""" + |if ($found) { + | throw QueryExecutionErrors.scalarSubqueryReturnsMultipleRows(); + |} + |""".stripMargin + } else { + "" + } s""" |// generate join key for stream side @@ -505,6 +523,7 @@ trait HashJoin extends JoinCodegenSupport { | (UnsafeRow) $matches.next() : null; | ${checkCondition.trim} | if ($conditionPassed) { + | $evaluateSingleCheck | $found = true; | $numOutput.add(1); | ${consume(ctx, resultVars)} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala index 7c4628c8576c5..60e5a7769a503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledJoin.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftExistence, LeftOuter, RightOuter} +import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, FullOuter, InnerLike, LeftExistence, LeftOuter, LeftSingle, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning, PartitioningCollection, UnknownPartitioning, UnspecifiedDistribution} /** @@ -47,7 +47,7 @@ trait ShuffledJoin extends JoinCodegenSupport { override def outputPartitioning: Partitioning = joinType match { case _: InnerLike => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) - case LeftOuter => left.outputPartitioning + case LeftOuter | LeftSingle => left.outputPartitioning case RightOuter => right.outputPartitioning case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) case LeftExistence(_) => left.outputPartitioning @@ -60,7 +60,7 @@ trait ShuffledJoin extends JoinCodegenSupport { joinType match { case _: InnerLike => left.output ++ right.output - case LeftOuter => + case LeftOuter | LeftSingle => left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out index bea91e09b0053..01de7beda551d 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out @@ -142,6 +142,12 @@ Project [x1#x, x2#x, scalar-subquery#x 
[x1#x && x2#x] AS scalarsubquery(x1, x2)# +- LocalRelation [col1#x, col2#x] +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false +-- !query analysis +SetCommand (spark.sql.optimizer.scalarSubqueryUseSingleJoin,Some(false)) + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query analysis @@ -202,24 +208,83 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = true +-- !query analysis +SetCommand (spark.sql.optimizer.scalarSubqueryUseSingleJoin,Some(true)) + + +-- !query +select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 +-- !query analysis +Project [x1#x, x2#x] ++- Filter (scalar-subquery#x [x1#x] = cast(1 as bigint)) + : +- Aggregate [y1#x], [count(1) AS count(1)#xL] + : +- Filter (y1#x > outer(x1#x)) + : +- SubqueryAlias y + : +- View (`y`, [y1#x, y2#x]) + : +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] + : +- LocalRelation [col1#x, col2#x] + +- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x +-- !query analysis +Project [x1#x, x2#x, scalar-subquery#x [x1#x] AS scalarsubquery(x1)#xL] +: +- Aggregate [y1#x], [count(1) AS count(1)#xL] +: +- Filter ((y1#x + y2#x) = outer(x1#x)) +: +- SubqueryAlias y +: +- View (`y`, [y1#x, y2#x]) +: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] + + +-- !query +select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x +-- !query analysis +Project [x1#x, x2#x, scalar-subquery#x [x1#x && x1#x] AS scalarsubquery(x1, x1)#xL] +: +- Aggregate [y2#x], [count(1) AS count(1)#xL] +: +- Filter ((outer(x1#x) = y1#x) AND ((y2#x + 10) = (outer(x1#x) + 1))) +: +- SubqueryAlias y +: +- View (`y`, [y1#x, y2#x]) +: +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] + + -- !query select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x -- !query analysis -org.apache.spark.sql.catalyst.ExtendedAnalysisException -{ - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", - "sqlState" : "0A000", - "messageParameters" : { - "value" : "y1" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 11, - "stopIndex" : 106, - "fragment" : "(select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1)" - } ] -} +Project [x1#x, x2#x, scalar-subquery#x [x1#x] AS scalarsubquery(x1)#xL] +: +- Aggregate [y1#x], [count(1) AS count(1)#xL] +: +- SubqueryAlias sub +: +- Union false, false +: :- Project [y1#x, y2#x] +: : +- Filter (y1#x = outer(x1#x)) +: : +- SubqueryAlias y +: : +- View (`y`, [y1#x, y2#x]) +: : +- Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: : +- LocalRelation [col1#x, col2#x] +: +- Project [y1#x, y2#x] +: +- SubqueryAlias y +: +- View (`y`, [y1#x, y2#x]) +: +- 
Project [cast(col1#x as int) AS y1#x, cast(col2#x as int) AS y2#x] +: +- LocalRelation [col1#x, col2#x] ++- SubqueryAlias x + +- View (`x`, [x1#x, x2#x]) + +- Project [cast(col1#x as int) AS x1#x, cast(col2#x as int) AS x2#x] + +- LocalRelation [col1#x, col2#x] -- !query @@ -227,17 +292,17 @@ select *, (select count(*) from y left join (select * from z where z1 = x1) sub -- !query analysis org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED", "sqlState" : "0A000", "messageParameters" : { - "value" : "z1" + "treeNode" : "Filter (z1#x = outer(x1#x))\n+- SubqueryAlias z\n +- View (`z`, [z1#x, z2#x])\n +- Project [cast(col1#x as int) AS z1#x, cast(col2#x as int) AS z2#x]\n +- LocalRelation [col1#x, col2#x]\n" }, "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 11, - "stopIndex" : 103, - "fragment" : "(select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1)" + "startIndex" : 46, + "stopIndex" : 74, + "fragment" : "select * from z where z1 = x1" } ] } @@ -248,6 +313,12 @@ set spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate = SetCommand (spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate,Some(true)) +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false +-- !query analysis +SetCommand (spark.sql.optimizer.scalarSubqueryUseSingleJoin,Some(false)) + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index e3ce85fe5d209..4ff0222d6e965 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -1748,3 +1748,21 @@ Project [t1a#x, t1b#x, t1c#x] +- View (`t1`, [t1a#x, t1b#x, t1c#x]) +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x as int) AS t1b#x, cast(col3#x as int) AS t1c#x] +- LocalRelation [col1#x, col2#x, col3#x] + + +-- !query +SELECT * FROM t0 WHERE t0a = (SELECT distinct(t1c) FROM t1 WHERE t1a = t0a) +-- !query analysis +Project [t0a#x, t0b#x] ++- Filter (t0a#x = scalar-subquery#x [t0a#x]) + : +- Distinct + : +- Project [t1c#x] + : +- Filter (t1a#x = outer(t0a#x)) + : +- SubqueryAlias t1 + : +- View (`t1`, [t1a#x, t1b#x, t1c#x]) + : +- Project [cast(col1#x as int) AS t1a#x, cast(col2#x as int) AS t1b#x, cast(col3#x as int) AS t1c#x] + : +- LocalRelation [col1#x, col2#x, col3#x] + +- SubqueryAlias t0 + +- View (`t0`, [t0a#x, t0b#x]) + +- Project [cast(col1#x as int) AS t0a#x, cast(col2#x as int) AS t0b#x] + +- LocalRelation [col1#x, col2#x] diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql index db7cdc97614cb..a23083e9e0e4d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql +++ 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-group-by.sql @@ -22,16 +22,25 @@ select *, (select count(*) from y where x1 = y1 and cast(y2 as double) = x1 + 1 select *, (select count(*) from y where y2 + 1 = x1 + x2 group by y2 + 1) from x; --- Illegal queries +-- Illegal queries (single join disabled) +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false; select * from x where (select count(*) from y where y1 > x1 group by y1) = 1; select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x; select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x; +-- Same queries, with LeftSingle join +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = true; +select * from x where (select count(*) from y where y1 > x1 group by y1) = 1; +select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x; +select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x; + + -- Certain other operators like OUTER JOIN or UNION between the correlating filter and the group-by also can cause the scalar subquery to return multiple values and hence make the query illegal. select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x; select *, (select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1) from x; -- The correlation below the join is unsupported in Spark anyway, but when we do support it this query should still be disallowed. -- Test legacy behavior conf set spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate = true; +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false; select * from x where (select count(*) from y where y1 > x1 group by y1) = 1; reset spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql index 2823888e6e438..81e0c5f98d82b 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-predicate.sql @@ -529,3 +529,6 @@ FROM t1 WHERE (SELECT max(t2c) FROM t2 WHERE t1b = t2b ) between 1 and 2; + + +SELECT * FROM t0 WHERE t0a = (SELECT distinct(t1c) FROM t1 WHERE t1a = t0a); diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out index 41cba1f43745f..56932edd4e545 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-group-by.sql.out @@ -112,6 +112,14 @@ struct 2 2 NULL +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false +-- !query schema +struct +-- !query output +spark.sql.optimizer.scalarSubqueryUseSingleJoin false + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query schema @@ -178,25 +186,56 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = true +-- !query schema +struct +-- !query output 
+spark.sql.optimizer.scalarSubqueryUseSingleJoin true + + +-- !query +select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkRuntimeException +{ + "errorClass" : "SCALAR_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" +} + + +-- !query +select *, (select count(*) from y where y1 + y2 = x1 group by y1) from x +-- !query schema +struct<> +-- !query output +org.apache.spark.SparkRuntimeException +{ + "errorClass" : "SCALAR_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" +} + + +-- !query +select *, (select count(*) from y where x1 = y1 and y2 + 10 = x1 + 1 group by y2) from x +-- !query schema +struct +-- !query output +1 1 NULL +2 2 NULL + + -- !query select *, (select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1) from x -- !query schema struct<> -- !query output -org.apache.spark.sql.catalyst.ExtendedAnalysisException +org.apache.spark.SparkRuntimeException { - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", - "sqlState" : "0A000", - "messageParameters" : { - "value" : "y1" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 11, - "stopIndex" : 106, - "fragment" : "(select count(*) from (select * from y where y1 = x1 union all select * from y) sub group by y1)" - } ] + "errorClass" : "SCALAR_SUBQUERY_TOO_MANY_ROWS", + "sqlState" : "21000" } @@ -207,17 +246,17 @@ struct<> -- !query output org.apache.spark.sql.catalyst.ExtendedAnalysisException { - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_CORRELATED_COLUMNS_IN_GROUP_BY", + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED", "sqlState" : "0A000", "messageParameters" : { - "value" : "z1" + "treeNode" : "Filter (z1#x = outer(x1#x))\n+- SubqueryAlias z\n +- View (`z`, [z1#x, z2#x])\n +- Project [cast(col1#x as int) AS z1#x, cast(col2#x as int) AS z2#x]\n +- LocalRelation [col1#x, col2#x]\n" }, "queryContext" : [ { "objectType" : "", "objectName" : "", - "startIndex" : 11, - "stopIndex" : 103, - "fragment" : "(select count(*) from y left join (select * from z where z1 = x1) sub on y2 = z2 group by z1)" + "startIndex" : 46, + "stopIndex" : 74, + "fragment" : "select * from z where z1 = x1" } ] } @@ -230,6 +269,14 @@ struct spark.sql.legacy.scalarSubqueryAllowGroupByNonEqualityCorrelatedPredicate true +-- !query +set spark.sql.optimizer.scalarSubqueryUseSingleJoin = false +-- !query schema +struct +-- !query output +spark.sql.optimizer.scalarSubqueryUseSingleJoin false + + -- !query select * from x where (select count(*) from y where y1 > x1 group by y1) = 1 -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out index a02f0c70be6da..2460c2452ea56 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-predicate.sql.out @@ -906,3 +906,11 @@ WHERE (SELECT max(t2c) struct -- !query output + + +-- !query +SELECT * FROM t0 WHERE t0a = (SELECT distinct(t1c) FROM t1 WHERE t1a = t0a) +-- !query schema +struct +-- !query output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index 9afba65183974..a892cd4db02b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import org.scalactic.source.Position import org.scalatest.Tag +import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, ExpressionSet} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.Aggregate @@ -554,7 +555,15 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { | FROM (SELECT dept * 2.0 AS id, id + 1 AS id2 FROM $testTable)) > 5 |ORDER BY id |""".stripMargin - withLCAOff { intercept[AnalysisException] { sql(query4) } } + withLCAOff { + val exception = intercept[SparkRuntimeException] { + sql(query4).collect() + } + checkError( + exception, + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS" + ) + } withLCAOn { val analyzedPlan = sql(query4).queryExecution.analyzed assert(!analyzedPlan.containsPattern(OUTER_REFERENCE)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 23c4d51983bb4..6e160b4407ca8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import scala.collection.mutable.ArrayBuffer +import org.apache.spark.SparkRuntimeException import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LogicalPlan, Project, Sort, Union} @@ -527,43 +528,30 @@ class SubquerySuite extends QueryTest test("SPARK-18504 extra GROUP BY column in correlated scalar subquery is not permitted") { withTempView("v") { Seq((1, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("v") - - val exception = intercept[AnalysisException] { - sql("select (select sum(-1) from v t2 where t1.c2 = t2.c1 group by t2.c2) sum from v t1") + val exception = intercept[SparkRuntimeException] { + sql("select (select sum(-1) from v t2 where t1.c2 = t2.c1 group by t2.c2) sum from v t1"). + collect() } checkError( exception, - condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + - "NON_CORRELATED_COLUMNS_IN_GROUP_BY", - parameters = Map("value" -> "c2"), - sqlState = None, - context = ExpectedContext( - fragment = "(select sum(-1) from v t2 where t1.c2 = t2.c1 group by t2.c2)", - start = 7, stop = 67)) } + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS" + ) + } } test("non-aggregated correlated scalar subquery") { - val exception1 = intercept[AnalysisException] { - sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1") + val exception1 = intercept[SparkRuntimeException] { + sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1").collect() } checkError( exception1, - condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." 
+ - "MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY", - parameters = Map.empty, - context = ExpectedContext( - fragment = "(select b from l l2 where l2.a = l1.a)", start = 10, stop = 47)) - val exception2 = intercept[AnalysisException] { - sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1") - } - checkErrorMatchPVals( - exception2, - condition = "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + - "MUST_AGGREGATE_CORRELATED_SCALAR_SUBQUERY", - parameters = Map.empty[String, String], - sqlState = None, - context = ExpectedContext( - fragment = "(select b from l l2 where l2.a = l1.a group by 1)", start = 10, stop = 58)) + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS" + ) + checkAnswer( + sql("select a, (select b from l l2 where l2.a = l1.a group by 1) sum_b from l l1"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: + Row(null, null) :: Row(null, null) :: Row(6, null) :: Nil + ) } test("non-equal correlated scalar subquery") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SingleJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SingleJoinSuite.scala new file mode 100644 index 0000000000000..a318769af6871 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SingleJoinSuite.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.SparkRuntimeException +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.BuildRight +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, Project} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} + +class SingleJoinSuite extends SparkPlanTest with SharedSparkSession { + import testImplicits.toRichColumn + + private val EnsureRequirements = new EnsureRequirements() + + private lazy val left = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + // (a > c && a != 6) + + private lazy val right = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(4, 2.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val singleConditionEQ = EqualTo(left.col("a").expr, right.col("c").expr) + + private lazy val nonEqualityCond = And(GreaterThan(left.col("a").expr, right.col("c").expr), + Not(EqualTo(left.col("a").expr, Literal(6)))) + + + + private def testSingleJoin( + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: => Option[Expression], + expectedAnswer: Seq[Row], + expectError: Boolean = false): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, + Inner, condition, JoinHint.NONE) + ExtractEquiJoinKeys.unapply(join) + } + + def checkSingleJoinError(planFunction: (SparkPlan, SparkPlan) => SparkPlan): Unit = { + val outputPlan = planFunction(leftRows.queryExecution.sparkPlan, + rightRows.queryExecution.sparkPlan) + checkError( + exception = intercept[SparkRuntimeException] { + SparkPlanTest.executePlan(outputPlan, spark.sqlContext) + }, + condition = "SCALAR_SUBQUERY_TOO_MANY_ROWS", + parameters = Map.empty + ) + } + + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin") { _ => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) => + val planFunction = (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply(BroadcastHashJoinExec( + leftKeys, rightKeys, LeftSingle, BuildRight, boundCondition, left, right)) + if (expectError) { + checkSingleJoinError(planFunction) + } else { + checkAnswer2(leftRows, rightRows, planFunction, + expectedAnswer, + sortAnswers = true) + } + } + } + testWithWholeStageCodegenOnAndOff(s"$testName using ShuffledHashJoin") { _ => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) => + val planFunction = (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply( + ShuffledHashJoinExec( + leftKeys, rightKeys, LeftSingle, BuildRight, boundCondition, left, right)) + if (expectError) { + checkSingleJoinError(planFunction) + } else { + checkAnswer2(leftRows, rightRows, planFunction, + expectedAnswer, + sortAnswers = 
true) + } + } + } + + testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastNestedLoopJoin") { _ => + val planFunction = (left: SparkPlan, right: SparkPlan) => + EnsureRequirements.apply( + BroadcastNestedLoopJoinExec(left, right, BuildRight, LeftSingle, condition)) + if (expectError) { + checkSingleJoinError(planFunction) + } else { + checkAnswer2(leftRows, rightRows, planFunction, + expectedAnswer, + sortAnswers = true) + } + } + } + + testSingleJoin( + "test single condition (equal) for a left single join", + left, + Project(Seq(right.col("c").expr.asInstanceOf[NamedExpression]), right.logicalPlan), + Some(singleConditionEQ), + Seq(Row(1, 2.0, null), + Row(1, 2.0, null), + Row(2, 1.0, 2), + Row(2, 1.0, 2), + Row(3, 3.0, 3), + Row(6, null, 6), + Row(null, 5.0, null), + Row(null, null, null))) + + testSingleJoin( + "test single condition (equal) for a left single join -- multiple matches", + left, + Project(Seq(right.col("d").expr.asInstanceOf[NamedExpression]), right.logicalPlan), + Some(EqualTo(left.col("b").expr, right.col("d").expr)), + Seq.empty, true) + + testSingleJoin( + "test non-equality for a left single join", + left, + Project(Seq(right.col("c").expr.asInstanceOf[NamedExpression]), right.logicalPlan), + Some(nonEqualityCond), + Seq(Row(1, 2.0, null), + Row(1, 2.0, null), + Row(2, 1.0, null), + Row(2, 1.0, null), + Row(3, 3.0, 2), + Row(6, null, null), + Row(null, 5.0, null), + Row(null, null, null))) + + testSingleJoin( + "test non-equality for a left single join -- multiple matches", + left, + Project(Seq(right.col("c").expr.asInstanceOf[NamedExpression]), right.logicalPlan), + Some(GreaterThan(left.col("a").expr, right.col("c").expr)), + Seq.empty, expectError = true) + + private lazy val emptyFrame = spark.createDataFrame( + spark.sparkContext.emptyRDD[Row], new StructType().add("c", IntegerType).add("d", DoubleType)) + + testSingleJoin( + "empty inner (right) side", + left, + Project(Seq(emptyFrame.col("c").expr.asInstanceOf[NamedExpression]), emptyFrame.logicalPlan), + Some(GreaterThan(left.col("a").expr, emptyFrame.col("c").expr)), + Seq(Row(1, 2.0, null), + Row(1, 2.0, null), + Row(2, 1.0, null), + Row(2, 1.0, null), + Row(3, 3.0, null), + Row(6, null, null), + Row(null, 5.0, null), + Row(null, null, null))) + + testSingleJoin( + "empty outer (left) side", + Project(Seq(emptyFrame.col("c").expr.asInstanceOf[NamedExpression]), emptyFrame.logicalPlan), + right, + Some(EqualTo(emptyFrame.col("c").expr, right.col("c").expr)), + Seq.empty) +} From d2e8c1cb60e34a1c7e92374c07d682aa5ca79145 Mon Sep 17 00:00:00 2001 From: Julek Sompolski Date: Mon, 23 Sep 2024 12:39:02 +0900 Subject: [PATCH 052/250] [SPARK-48195][CORE] Save and reuse RDD/Broadcast created by SparkPlan ### What changes were proposed in this pull request? Save the RDD created by doExecute, instead of creating a new one in execute each time. Currently, many types of SparkPlans already save the RDD they create. For example, shuffle just saves `lazy val inputRDD: RDD[InternalRow] = child.execute()`. This creates inconsistencies when an action (e.g. a repeated `df.collect()`) is executed on a DataFrame twice: * The SparkPlan will be reused, since the same `df.queryExecution.executedPlan` will be used. * Any non-result stage will be reused, as the shuffle operators will just have their `inputRDD` reused. * However, for the result stage, `execute()` will call `doExecute()` again, and the logic of generating the actual execution RDD will be reexecuted for the result stage.
This means that, for the result stage for example, WSCG code gen will generate and compile new code, and create a new RDD out of it. Generation of execution RDDs is also often influenced by config: for example, staying with WSCG, it depends on configs like `spark.sql.codegen.hugeMethodLimit` or `spark.sql.codegen.methodSplitThreshold`. The fact that upon re-execution this will be evaluated anew for the result stage, but not for earlier stages, creates inconsistencies in which config changes are visible. By saving the result of `doExecute` and reusing the RDD in `execute`, we make sure that the work of creating that RDD is not duplicated, and it is more consistent that all RDDs of the plan are reused, the same as with the `executedPlan`. Note that while the results of earlier shuffle stages are also reused, the result stage still does get executed again, as its results are not saved and made available for reuse in the BlockManager. We also add a `LazyTry` utility instead of using `lazy val`, to deal with shortcomings of Scala's `lazy val`. ### Why are the changes needed? Resolves subtle inconsistencies coming from object reuse vs. recreating objects from scratch. ### Does this PR introduce _any_ user-facing change? Subtle changes caused by the RDD being reused, e.g. when a config change might be picked up. However, it makes things more consistent. Spark 4.0.0 might be a good candidate for making such a change. ### How was this patch tested? Existing SQL execution tests validate that the change in SparkPlan works. Tests were added for the new `LazyTry` utility. ### Was this patch authored or co-authored using generative AI tooling? Generated-by: Github Copilot (trivial code completion suggestions) Closes #48037 from juliuszsompolski/SPARK-48195-rdd. Lead-authored-by: Julek Sompolski Co-authored-by: Hyukjin Kwon Co-authored-by: Wenchen Fan Signed-off-by: Hyukjin Kwon --- .../scala/org/apache/spark/util/LazyTry.scala | 70 ++++++++ .../scala/org/apache/spark/util/Utils.scala | 80 ++++++++++ .../org/apache/spark/util/LazyTrySuite.scala | 151 ++++++++++++++++++ .../org/apache/spark/util/UtilsSuite.scala | 112 ++++++++++++- .../sql/execution/CollectMetricsExec.scala | 5 + .../spark/sql/execution/SparkPlan.scala | 21 ++- .../columnar/InMemoryTableScanExec.scala | 82 +++++----- .../exchange/ShuffleExchangeExec.scala | 13 +- 8 files changed, 475 insertions(+), 59 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/LazyTry.scala create mode 100644 core/src/test/scala/org/apache/spark/util/LazyTrySuite.scala diff --git a/core/src/main/scala/org/apache/spark/util/LazyTry.scala b/core/src/main/scala/org/apache/spark/util/LazyTry.scala new file mode 100644 index 0000000000000..7edc08672c26b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/LazyTry.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.util.Try + +/** + * Wrapper utility for a lazy val, with two differences compared to scala behavior: + * + * 1. Non-retrying in case of failure. This wrapper stores the exception in a Try, and will re-throw + * it on the access to `get`. + * In scala, when a `lazy val` field initialization throws an exception, the field remains + * uninitialized, and initialization will be re-attempted on the next access. This also can lead + * to performance issues, needlessly computing something towards a failure, and also can lead to + * duplicated side effects. + * + * 2. Resolving locking issues. + * In scala, when a `lazy val` field is initialized, it grabs the synchronized lock on the + * enclosing object instance. This can lead both to performance issues, and deadlocks. + * For example: + * a) Thread 1 entered a synchronized method, grabbing a coarse lock on the parent object. + * b) Thread 2 get spawned off, and tries to initialize a lazy value on the same parent object + * This causes scala to also try to grab a lock on the parent object. + * c) If thread 1 waits for thread 2 to join, a deadlock occurs. + * This wrapper will only grab a lock on the wrapper itself, and not the parent object. + * + * @param initialize The block of code to initialize the lazy value. + * @tparam T type of the lazy value. + */ +private[spark] class LazyTry[T](initialize: => T) extends Serializable { + private lazy val tryT: Try[T] = Utils.doTryWithCallerStacktrace { initialize } + + /** + * Get the lazy value. If the initialization block threw an exception, it will be re-thrown here. + * The exception will be re-thrown with the current caller's stacktrace. + * An exception with stack trace from when the exception was first thrown can be accessed with + * ``` + * ex.getSuppressed.find { e => + * e.getMessage == org.apache.spark.util.Utils.TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE + * } + * ``` + */ + def get: T = Utils.getTryWithCallerStacktrace(tryT) +} + +private[spark] object LazyTry { + /** + * Create a new LazyTry instance. + * + * @param initialize The block of code to initialize the lazy value. + * @tparam T type of the lazy value. + * @return a new LazyTry instance. + */ + def apply[T](initialize: => T): LazyTry[T] = new LazyTry(initialize) +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index d8392cd8043de..52213f36a2cd1 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1351,6 +1351,86 @@ private[spark] object Utils } } + val TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE = + "Full stacktrace of original doTryWithCallerStacktrace caller" + + val TRY_WITH_CALLER_STACKTRACE_TRY_STACKTRACE = + "Stacktrace under doTryWithCallerStacktrace" + + /** + * Use Try with stacktrace substitution for the caller retrieving the error. + * + * Normally in case of failure, the exception would have the stacktrace of the caller that + * originally called doTryWithCallerStacktrace. However, we want to replace the part above + * this function with the stacktrace of the caller who calls getTryWithCallerStacktrace. + * So here we save the part of the stacktrace below doTryWithCallerStacktrace, and + * getTryWithCallerStacktrace will stitch it with the new stack trace of the caller. 
+ * The full original stack trace is kept in ex.getSuppressed. + * + * @param f Code block to be wrapped in Try + * @return Try with Success or Failure of the code block. Use with getTryWithCallerStacktrace. + */ + def doTryWithCallerStacktrace[T](f: => T): Try[T] = { + val t = Try { + f + } + t match { + case Failure(ex) => + // Note: we remove the common suffix instead of e.g. finding the call to this function, to + // account for recursive calls with multiple doTryWithCallerStacktrace on the stack trace. + val origStackTrace = ex.getStackTrace + val currentStackTrace = Thread.currentThread().getStackTrace + val commonSuffixLen = origStackTrace.reverse.zip(currentStackTrace.reverse).takeWhile { + case (exElem, currentElem) => exElem == currentElem + }.length + val belowEx = new Exception(TRY_WITH_CALLER_STACKTRACE_TRY_STACKTRACE) + belowEx.setStackTrace(origStackTrace.dropRight(commonSuffixLen)) + ex.addSuppressed(belowEx) + + // keep the full original stack trace in a suppressed exception. + val fullEx = new Exception(TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE) + fullEx.setStackTrace(origStackTrace) + ex.addSuppressed(fullEx) + case Success(_) => // nothing + } + t + } + + /** + * Retrieve the result of Try that was created by doTryWithCallerStacktrace. + * + * In case of failure, the resulting exception has a stack trace that combines the stack trace + * below the original doTryWithCallerStacktrace which triggered it, with the caller stack trace + * of the current caller of getTryWithCallerStacktrace. + * + * Full stack trace of the original doTryWithCallerStacktrace caller can be retrieved with + * ``` + * ex.getSuppressed.find { e => + * e.getMessage == Utils.TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE + * } + * ``` + * + * + * @param t Try from doTryWithCallerStacktrace + * @return Result of the Try or rethrows the failure exception with modified stacktrace. + */ + def getTryWithCallerStacktrace[T](t: Try[T]): T = t match { + case Failure(ex) => + val belowStacktrace = ex.getSuppressed.find { e => + // added in doTryWithCallerStacktrace + e.getMessage == TRY_WITH_CALLER_STACKTRACE_TRY_STACKTRACE + }.getOrElse { + // If we don't have the expected stacktrace information, just rethrow + throw ex + }.getStackTrace + // We are modifying and throwing the original exception. It would be better if we could + // return a copy, but we can't easily clone it and preserve. If this is accessed from + // multiple threads that then look at the stack trace, this could break. + ex.setStackTrace(belowStacktrace ++ Thread.currentThread().getStackTrace.drop(1)) + throw ex + case Success(s) => s + } + // A regular expression to match classes of the internal Spark API's // that we want to skip when finding the call site of a method. private val SPARK_CORE_CLASS_REGEX = diff --git a/core/src/test/scala/org/apache/spark/util/LazyTrySuite.scala b/core/src/test/scala/org/apache/spark/util/LazyTrySuite.scala new file mode 100644 index 0000000000000..79c07f8fbfead --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/LazyTrySuite.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import org.apache.spark.SparkFunSuite + +class LazyTrySuite extends SparkFunSuite{ + test("LazyTry should initialize only once") { + var count = 0 + val lazyVal = LazyTry { + count += 1 + count + } + assert(count == 0) + assert(lazyVal.get == 1) + assert(count == 1) + assert(lazyVal.get == 1) + assert(count == 1) + } + + test("LazyTry should re-throw exceptions") { + val lazyVal = LazyTry { + throw new RuntimeException("test") + } + intercept[RuntimeException] { + lazyVal.get + } + intercept[RuntimeException] { + lazyVal.get + } + } + + test("LazyTry should re-throw exceptions with current caller stack-trace") { + val fileName = Thread.currentThread().getStackTrace()(1).getFileName + val lineNo = Thread.currentThread().getStackTrace()(1).getLineNumber + val lazyVal = LazyTry { + throw new RuntimeException("test") + } + + val e1 = intercept[RuntimeException] { + lazyVal.get // lineNo + 6 + } + assert(e1.getStackTrace + .exists(elem => elem.getFileName == fileName && elem.getLineNumber == lineNo + 6)) + + val e2 = intercept[RuntimeException] { + lazyVal.get // lineNo + 12 + } + assert(e2.getStackTrace + .exists(elem => elem.getFileName == fileName && elem.getLineNumber == lineNo + 12)) + } + + test("LazyTry does not lock containing object") { + class LazyContainer() { + @volatile var aSet = 0 + + val a: LazyTry[Int] = LazyTry { + aSet = 1 + aSet + } + + val b: LazyTry[Int] = LazyTry { + val t = new Thread(new Runnable { + override def run(): Unit = { + assert(a.get == 1) + } + }) + t.start() + t.join() + aSet + } + } + val container = new LazyContainer() + // Nothing is lazy initialized yet + assert(container.aSet == 0) + // This will not deadlock, thread t will initialize a, and update aSet + assert(container.b.get == 1) + assert(container.aSet == 1) + } + + // Scala lazy val tests are added to test for potential changes in the semantics of scala lazy val + + test("Scala lazy val initializing multiple times on error") { + class LazyValError() { + var counter = 0 + lazy val a = { + counter += 1 + throw new RuntimeException("test") + } + } + val lazyValError = new LazyValError() + intercept[RuntimeException] { + lazyValError.a + } + assert(lazyValError.counter == 1) + intercept[RuntimeException] { + lazyValError.a + } + assert(lazyValError.counter == 2) + } + + test("Scala lazy val locking containing object and deadlocking") { + // Note: this will change in scala 3, with different lazy vals not deadlocking with each other. + // https://docs.scala-lang.org/scala3/reference/changed-features/lazy-vals-init.html + class LazyValContainer() { + @volatile var aSet = 0 + @volatile var t: Thread = _ + + lazy val a = { + aSet = 1 + aSet + } + + lazy val b = { + t = new Thread(new Runnable { + override def run(): Unit = { + assert(a == 1) + } + }) + t.start() + t.join(1000) + aSet + } + } + val container = new LazyValContainer() + // Nothing is lazy initialized yet + assert(container.aSet == 0) + // This will deadlock, because b will take monitor on LazyValContainer, and then thread t + // will wait on that monitor, not able to initialize a. 
+ // b will therefore see aSet == 0. + assert(container.b == 0) + // However, after b finishes initializing, the monitor will be released, and then thread t + // will finish initializing a, and set aSet to 1. + container.t.join() + assert(container.aSet == 1) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 4fe6fcf17f49f..a694e08def89c 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -28,7 +28,7 @@ import java.util.concurrent.TimeUnit import java.util.zip.GZIPOutputStream import scala.collection.mutable.ListBuffer -import scala.util.Random +import scala.util.{Random, Try} import com.google.common.io.Files import org.apache.commons.io.IOUtils @@ -1523,6 +1523,116 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties { conf.set(SERIALIZER, "org.apache.spark.serializer.JavaSerializer") assert(Utils.isPushBasedShuffleEnabled(conf, isDriver = true) === false) } + + + private def throwException(): String = { + throw new Exception("test") + } + + private def callDoTry(): Try[String] = { + Utils.doTryWithCallerStacktrace { + throwException() + } + } + + private def callGetTry(t: Try[String]): String = { + Utils.getTryWithCallerStacktrace(t) + } + + private def callGetTryAgain(t: Try[String]): String = { + Utils.getTryWithCallerStacktrace(t) + } + + test("doTryWithCallerStacktrace and getTryWithCallerStacktrace") { + val t = callDoTry() + + val e1 = intercept[Exception] { + callGetTry(t) + } + // Uncomment for manual inspection + // e1.printStackTrace() + // Example: + // java.lang.Exception: test + // at org.apache.spark.util.UtilsSuite.throwException(UtilsSuite.scala:1640) + // at org.apache.spark.util.UtilsSuite.$anonfun$callDoTry$1(UtilsSuite.scala:1645) + // at scala.util.Try$.apply(Try.scala:213) + // at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala:1586) + // at org.apache.spark.util.Utils$.getTryWithCallerStacktrace(Utils.scala:1639) + // at org.apache.spark.util.UtilsSuite.callGetTry(UtilsSuite.scala:1650) + // at org.apache.spark.util.UtilsSuite.$anonfun$new$165(UtilsSuite.scala:1661) + // <- callGetTry is seen as calling getTryWithCallerStacktrace + + val st1 = e1.getStackTrace + // throwException should be on the stack trace + assert(st1.exists(_.getMethodName == "throwException")) + // callDoTry shouldn't be on the stack trace, but callGetTry should be. + assert(!st1.exists(_.getMethodName == "callDoTry")) + assert(st1.exists(_.getMethodName == "callGetTry")) + + // The original stack trace with callDoTry should be in the suppressed exceptions. + // Example: + // scalastyle:off line.size.limit + // Suppressed: java.lang.Exception: Full stacktrace of original doTryWithCallerStacktrace caller + // at org.apache.spark.util.UtilsSuite.throwException(UtilsSuite.scala:1640) + // at org.apache.spark.util.UtilsSuite.$anonfun$callDoTry$1(UtilsSuite.scala:1645) + // at scala.util.Try$.apply(Try.scala:213) + // at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala:1586) + // at org.apache.spark.util.UtilsSuite.callDoTry(UtilsSuite.scala:1645) + // at org.apache.spark.util.UtilsSuite.$anonfun$new$165(UtilsSuite.scala:1658) + // ... 
56 more + // scalastyle:on line.size.limit + val origSt = e1.getSuppressed.find( + _.getMessage == Utils.TRY_WITH_CALLER_STACKTRACE_FULL_STACKTRACE) + assert(origSt.isDefined) + assert(origSt.get.getStackTrace.exists(_.getMethodName == "throwException")) + assert(origSt.get.getStackTrace.exists(_.getMethodName == "callDoTry")) + + // The stack trace under Try should be in the suppressed exceptions. + // Example: + // Suppressed: java.lang.Exception: Stacktrace under doTryWithCallerStacktrace + // at org.apache.spark.util.UtilsSuite.throwException(UtilsSuite.scala: 1640) + // at org.apache.spark.util.UtilsSuite.$anonfun$callDoTry$1(UtilsSuite.scala: 1645) + // at scala.util.Try$.apply(Try.scala: 213) + // at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala: 1586) + val trySt = e1.getSuppressed.find( + _.getMessage == Utils.TRY_WITH_CALLER_STACKTRACE_TRY_STACKTRACE) + assert(trySt.isDefined) + // calls under callDoTry should be present. + assert(trySt.get.getStackTrace.exists(_.getMethodName == "throwException")) + // callDoTry should be removed. + assert(!trySt.get.getStackTrace.exists(_.getMethodName == "callDoTry")) + + val e2 = intercept[Exception] { + callGetTryAgain(t) + } + // Uncomment for manual inspection + // e2.printStackTrace() + // Example: + // java.lang.Exception: test + // at org.apache.spark.util.UtilsSuite.throwException(UtilsSuite.scala:1640) + // at org.apache.spark.util.UtilsSuite.$anonfun$callDoTry$1(UtilsSuite.scala:1645) + // at scala.util.Try$.apply(Try.scala:213) + // at org.apache.spark.util.Utils$.doTryWithCallerStacktrace(Utils.scala:1586) + // at org.apache.spark.util.Utils$.getTryWithCallerStacktrace(Utils.scala:1639) + // at org.apache.spark.util.UtilsSuite.callGetTryAgain(UtilsSuite.scala:1654) + // at org.apache.spark.util.UtilsSuite.$anonfun$new$165(UtilsSuite.scala:1711) + // <- callGetTryAgain is seen as calling getTryWithCallerStacktrace + + val st2 = e2.getStackTrace + // throwException should be on the stack trace + assert(st2.exists(_.getMethodName == "throwException")) + // callDoTry shouldn't be on the stack trace, but callGetTryAgain should be. + assert(!st2.exists(_.getMethodName == "callDoTry")) + assert(st2.exists(_.getMethodName == "callGetTryAgain")) + // callGetTry that we called before shouldn't be on the stack trace. + assert(!st2.exists(_.getMethodName == "callGetTry")) + + // Unfortunately, this utility is not able to clone the exception, but modifies it in place, + // so now e1 is also pointing to "callGetTryAgain" instead of "callGetTry". 
+ val st1Again = e1.getStackTrace + assert(st1Again.exists(_.getMethodName == "callGetTryAgain")) + assert(!st1Again.exists(_.getMethodName == "callGetTry")) + } } private class SimpleExtension diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala index dc918e51d0550..2115e21f81d71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala @@ -60,6 +60,11 @@ case class CollectMetricsExec( override def outputOrdering: Seq[SortOrder] = child.outputOrdering + override def resetMetrics(): Unit = { + accumulator.reset() + super.resetMetrics() + } + override protected def doExecute(): RDD[InternalRow] = { val collector = accumulator collector.reset() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 7bc770a0c9e33..fb3ec3ad41812 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.datasources.WriteFilesSpec import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.spark.util.NextIterator +import org.apache.spark.util.{LazyTry, NextIterator} import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} object SparkPlan { @@ -182,6 +182,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies sort order for each partition requirements on the input data for this operator. */ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) + @transient + private val executeRDD = LazyTry { + doExecute() + } + /** * Returns the result of this query as an RDD[InternalRow] by delegating to `doExecute` after * preparations. 
@@ -192,7 +197,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ if (isCanonicalizedPlan) { throw SparkException.internalError("A canonicalized plan is not supposed to be executed.") } - doExecute() + executeRDD.get + } + + private val executeBroadcastBcast = LazyTry { + doExecuteBroadcast() } /** @@ -205,7 +214,11 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ if (isCanonicalizedPlan) { throw SparkException.internalError("A canonicalized plan is not supposed to be executed.") } - doExecuteBroadcast() + executeBroadcastBcast.get.asInstanceOf[broadcast.Broadcast[T]] + } + + private val executeColumnarRDD = LazyTry { + doExecuteColumnar() } /** @@ -219,7 +232,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ if (isCanonicalizedPlan) { throw SparkException.internalError("A canonicalized plan is not supposed to be executed.") } - doExecuteColumnar() + executeColumnarRDD.get } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index cfcfd282e5480..cbd60804b27e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -99,48 +99,6 @@ case class InMemoryTableScanExec( relation.cacheBuilder.serializer.supportsColumnarOutput(relation.schema) } - private lazy val columnarInputRDD: RDD[ColumnarBatch] = { - val numOutputRows = longMetric("numOutputRows") - val buffers = filteredCachedBatches() - relation.cacheBuilder.serializer.convertCachedBatchToColumnarBatch( - buffers, - relation.output, - attributes, - conf).map { cb => - numOutputRows += cb.numRows() - cb - } - } - - private lazy val inputRDD: RDD[InternalRow] = { - if (enableAccumulatorsForTest) { - readPartitions.setValue(0) - readBatches.setValue(0) - } - - val numOutputRows = longMetric("numOutputRows") - // Using these variables here to avoid serialization of entire objects (if referenced - // directly) within the map Partitions closure. - val relOutput = relation.output - val serializer = relation.cacheBuilder.serializer - - // update SQL metrics - val withMetrics = - filteredCachedBatches().mapPartitionsInternal { iter => - if (enableAccumulatorsForTest && iter.hasNext) { - readPartitions.add(1) - } - iter.map { batch => - if (enableAccumulatorsForTest) { - readBatches.add(1) - } - numOutputRows += batch.numRows - batch - } - } - serializer.convertCachedBatchToInternalRow(withMetrics, relOutput, attributes, conf) - } - override def output: Seq[Attribute] = attributes private def cachedPlan = relation.cachedPlan match { @@ -191,11 +149,47 @@ case class InMemoryTableScanExec( } protected override def doExecute(): RDD[InternalRow] = { - inputRDD + // Resulting RDD is cached and reused by SparkPlan.executeRDD + if (enableAccumulatorsForTest) { + readPartitions.setValue(0) + readBatches.setValue(0) + } + + val numOutputRows = longMetric("numOutputRows") + // Using these variables here to avoid serialization of entire objects (if referenced + // directly) within the map Partitions closure. 
+ val relOutput = relation.output + val serializer = relation.cacheBuilder.serializer + + // update SQL metrics + val withMetrics = + filteredCachedBatches().mapPartitionsInternal { iter => + if (enableAccumulatorsForTest && iter.hasNext) { + readPartitions.add(1) + } + iter.map { batch => + if (enableAccumulatorsForTest) { + readBatches.add(1) + } + numOutputRows += batch.numRows + batch + } + } + serializer.convertCachedBatchToInternalRow(withMetrics, relOutput, attributes, conf) } protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { - columnarInputRDD + // Resulting RDD is cached and reused by SparkPlan.executeColumnarRDD + val numOutputRows = longMetric("numOutputRows") + val buffers = filteredCachedBatches() + relation.cacheBuilder.serializer.convertCachedBatchToColumnarBatch( + buffers, + relation.output, + attributes, + conf).map { cb => + numOutputRows += cb.numRows() + cb + } } override def isMaterialized: Boolean = relation.cacheBuilder.isCachedColumnBuffersLoaded diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 90f00a5035e15..ae11229cd516e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -249,17 +249,10 @@ case class ShuffleExchangeExec( dep } - /** - * Caches the created ShuffleRowRDD so we can reuse that. - */ - private var cachedShuffleRDD: ShuffledRowRDD = null - protected override def doExecute(): RDD[InternalRow] = { - // Returns the same ShuffleRowRDD if this plan is used by multiple plans. - if (cachedShuffleRDD == null) { - cachedShuffleRDD = new ShuffledRowRDD(shuffleDependency, readMetrics) - } - cachedShuffleRDD + // The ShuffleRowRDD will be cached in SparkPlan.executeRDD and reused if this plan is used by + // multiple plans. + new ShuffledRowRDD(shuffleDependency, readMetrics) } override protected def withNewChildInternal(newChild: SparkPlan): ShuffleExchangeExec = From 44ec70f5103fc5674497373ac5c23e8145ae5660 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 23 Sep 2024 18:28:19 +0800 Subject: [PATCH 053/250] [SPARK-49626][PYTHON][CONNECT] Support horizontal and vertical bar plots ### What changes were proposed in this pull request? Support horizontal and vertical bar plots with plotly backend on both Spark Connect and Spark classic. ### Why are the changes needed? While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments. See more at [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) in progress. Part of https://issues.apache.org/jira/browse/SPARK-49530. ### Does this PR introduce _any_ user-facing change? Yes. 
```python >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] >>> columns = ["category", "int_val", "float_val"] >>> sdf = spark.createDataFrame(data, columns) >>> sdf.show() +--------+-------+---------+ |category|int_val|float_val| +--------+-------+---------+ | A| 10| 1.5| | B| 30| 2.5| | C| 20| 3.5| +--------+-------+---------+ >>> f = sdf.plot(kind="bar", x="category", y=["int_val", "float_val"]) >>> f.show() # see below >>> g = sdf.plot.barh(x=["int_val", "float_val"], y="category") >>> g.show() # see below ``` `f.show()`: ![newplot (4)](https://github.com/user-attachments/assets/0df9ee86-fb48-4796-b6c3-aaf2879217aa) `g.show()`: ![newplot (3)](https://github.com/user-attachments/assets/f39b01c3-66e6-464b-b2e8-badebb39bc67) ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48100 from xinrong-meng/plot_bar. Authored-by: Xinrong Meng Signed-off-by: Xinrong Meng --- python/pyspark/sql/plot/core.py | 79 +++++++++++++++++++ .../sql/tests/plot/test_frame_plot_plotly.py | 44 +++++++++-- 2 files changed, 117 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index 392ef73b38845..ed22d02370ca6 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -75,6 +75,8 @@ def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame": class PySparkPlotAccessor: plot_data_map = { + "bar": PySparkTopNPlotBase().get_top_n, + "barh": PySparkTopNPlotBase().get_top_n, "line": PySparkSampledPlotBase().get_sampled, } _backends = {} # type: ignore[var-annotated] @@ -133,3 +135,80 @@ def line(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure": >>> df.plot.line(x="category", y=["int_val", "float_val"]) # doctest: +SKIP """ return self(kind="line", x=x, y=y, **kwargs) + + def bar(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure": + """ + Vertical bar plot. + + A bar plot is a plot that presents categorical data with rectangular bars with lengths + proportional to the values that they represent. A bar plot shows comparisons among + discrete categories. One axis of the plot shows the specific categories being compared, + and the other axis represents a measured value. + + Parameters + ---------- + x : str + Name of column to use for the horizontal axis. + y : str or list of str + Name(s) of the column(s) to use for the vertical axis. + Multiple columns can be plotted. + **kwargs : optional + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + >>> columns = ["category", "int_val", "float_val"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.bar(x="category", y="int_val") # doctest: +SKIP + >>> df.plot.bar(x="category", y=["int_val", "float_val"]) # doctest: +SKIP + """ + return self(kind="bar", x=x, y=y, **kwargs) + + def barh(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure": + """ + Make a horizontal bar plot. + + A horizontal bar plot is a plot that presents quantitative data with + rectangular bars with lengths proportional to the values that they + represent. A bar plot shows comparisons among discrete categories. One + axis of the plot shows the specific categories being compared, and the + other axis represents a measured value. + + Parameters + ---------- + x : str or list of str + Name(s) of the column(s) to use for the horizontal axis. 
+ Multiple columns can be plotted. + y : str or list of str + Name(s) of the column(s) to use for the vertical axis. + Multiple columns can be plotted. + **kwargs : optional + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Notes + ----- + In Plotly and Matplotlib, the interpretation of `x` and `y` for `barh` plots differs. + In Plotly, `x` refers to the values and `y` refers to the categories. + In Matplotlib, `x` refers to the categories and `y` refers to the values. + Ensure correct axis labeling based on the backend used. + + Examples + -------- + >>> data = [("A", 10, 1.5), ("B", 30, 2.5), ("C", 20, 3.5)] + >>> columns = ["category", "int_val", "float_val"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.barh(x="int_val", y="category") # doctest: +SKIP + >>> df.plot.barh( + ... x=["int_val", "float_val"], y="category" + ... ) # doctest: +SKIP + """ + return self(kind="barh", x=x, y=y, **kwargs) diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py index 72a3ed267d192..1c52c93a23d3a 100644 --- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -28,9 +28,16 @@ def sdf(self): columns = ["category", "int_val", "float_val"] return self.spark.createDataFrame(data, columns) - def _check_fig_data(self, fig_data, expected_x, expected_y, expected_name=""): - self.assertEqual(fig_data["mode"], "lines") - self.assertEqual(fig_data["type"], "scatter") + def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=""): + if kind == "line": + self.assertEqual(fig_data["mode"], "lines") + self.assertEqual(fig_data["type"], "scatter") + elif kind == "bar": + self.assertEqual(fig_data["type"], "bar") + elif kind == "barh": + self.assertEqual(fig_data["type"], "bar") + self.assertEqual(fig_data["orientation"], "h") + self.assertEqual(fig_data["xaxis"], "x") self.assertEqual(list(fig_data["x"]), expected_x) self.assertEqual(fig_data["yaxis"], "y") @@ -40,12 +47,37 @@ def _check_fig_data(self, fig_data, expected_x, expected_y, expected_name=""): def test_line_plot(self): # single column as vertical axis fig = self.sdf.plot(kind="line", x="category", y="int_val") - self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20]) + self._check_fig_data("line", fig["data"][0], ["A", "B", "C"], [10, 30, 20]) # multiple columns as vertical axis fig = self.sdf.plot.line(x="category", y=["int_val", "float_val"]) - self._check_fig_data(fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val") - self._check_fig_data(fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val") + self._check_fig_data("line", fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val") + self._check_fig_data("line", fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val") + + def test_bar_plot(self): + # single column as vertical axis + fig = self.sdf.plot(kind="bar", x="category", y="int_val") + self._check_fig_data("bar", fig["data"][0], ["A", "B", "C"], [10, 30, 20]) + + # multiple columns as vertical axis + fig = self.sdf.plot.bar(x="category", y=["int_val", "float_val"]) + self._check_fig_data("bar", fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val") + self._check_fig_data("bar", fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val") + + def test_barh_plot(self): + # single column as vertical axis + fig = self.sdf.plot(kind="barh", x="category", y="int_val") + self._check_fig_data("barh", 
fig["data"][0], ["A", "B", "C"], [10, 30, 20]) + + # multiple columns as vertical axis + fig = self.sdf.plot.barh(x="category", y=["int_val", "float_val"]) + self._check_fig_data("barh", fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val") + self._check_fig_data("barh", fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val") + + # multiple columns as horizontal axis + fig = self.sdf.plot.barh(x=["int_val", "float_val"], y="category") + self._check_fig_data("barh", fig["data"][0], [10, 30, 20], ["A", "B", "C"], "int_val") + self._check_fig_data("barh", fig["data"][1], [1.5, 2.5, 3.5], ["A", "B", "C"], "float_val") class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase): From e1637e3fbe0a7ee6492cfc909ef13fc1fe0534d1 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Mon, 23 Sep 2024 19:51:21 +0800 Subject: [PATCH 054/250] [SPARK-48712][SQL][FOLLOWUP] Check whether input is valid utf-8 string or not before entering fast path ### What changes were proposed in this pull request? Check whether input is valid utf-8 string or not before entering fast path ### Why are the changes needed? Avoid behavior change on a corner case where users provide invalid UTF-8 strings for UTF-8 encoding ### Does this PR introduce _any_ user-facing change? no, this is a followup to avoid potential breaking change ### How was this patch tested? existing tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48203 from yaooqinn/SPARK-48712. Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../expressions/stringExpressions.scala | 5 ++--- .../expressions/StringExpressionsSuite.scala | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index da6d786efb4e3..786c3968be0fe 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -3039,10 +3039,9 @@ object Encode { legacyCharsets: Boolean, legacyErrorAction: Boolean): Array[Byte] = { val toCharset = charset.toString - if (input.numBytes == 0 || "UTF-8".equalsIgnoreCase(toCharset)) { - return input.getBytes - } + if ("UTF-8".equalsIgnoreCase(toCharset) && input.isValid) return input.getBytes val encoder = CharsetProvider.newEncoder(toCharset, legacyCharsets, legacyErrorAction) + if (input.numBytes == 0) return input.getBytes try { val bb = encoder.encode(CharBuffer.wrap(input.toString)) JavaUtils.bufferToArray(bb) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 29b878230472d..9b454ba764f92 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -26,9 +26,12 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.util.CharsetProvider +import 
org.apache.spark.sql.errors.QueryExecutionErrors.toSQLId import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.types.StringTypeAnyCollation import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -2076,4 +2079,22 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) ) } + + test("SPARK-48712: Check whether input is valid utf-8 string or not before entering fast path") { + val str = UTF8String.fromBytes(Array[Byte](-1, -2, -3, -4)) + assert(!str.isValid, "please use a string that is not valid UTF-8 for testing") + val expected = Array[Byte](-17, -65, -67, -17, -65, -67, -17, -65, -67, -17, -65, -67) + val bytes = Encode.encode(str, UTF8String.fromString("UTF-8"), false, false) + assert(bytes === expected) + checkEvaluation(Encode(Literal(str), Literal("UTF-8")), expected) + checkEvaluation(Encode(Literal(UTF8String.EMPTY_UTF8), Literal("UTF-8")), Array.emptyByteArray) + checkErrorInExpression[SparkIllegalArgumentException]( + Encode(Literal(UTF8String.EMPTY_UTF8), Literal("UTF-12345")), + condition = "INVALID_PARAMETER_VALUE.CHARSET", + parameters = Map( + "charset" -> "UTF-12345", + "functionName" -> toSQLId("encode"), + "parameter" -> toSQLId("charset"), + "charsets" -> CharsetProvider.VALID_CHARSETS.mkString(", "))) + } } From fec1562b0ea03ff42d2468ea8ff7cbbc569336d8 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 23 Sep 2024 20:03:14 +0800 Subject: [PATCH 055/250] [SPARK-49755][CONNECT] Remove special casing for avro functions in Connect ### What changes were proposed in this pull request? apply the built-in registered functions ### Why are the changes needed? code simplification ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? updated tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48209 from zhengruifeng/connect_avro. 
Authored-by: Ruifeng Zheng Signed-off-by: yangjie01 --- .../expressions/toFromAvroSqlFunctions.scala | 3 ++ .../from_avro_with_options.explain | 2 +- .../from_avro_without_options.explain | 2 +- .../to_avro_with_schema.explain | 2 +- .../to_avro_without_schema.explain | 2 +- sql/connect/server/pom.xml | 2 +- .../connect/planner/SparkConnectPlanner.scala | 47 +------------------ 7 files changed, 9 insertions(+), 51 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromAvroSqlFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromAvroSqlFunctions.scala index 58bddafac0882..457f469e0f687 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromAvroSqlFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/toFromAvroSqlFunctions.scala @@ -61,6 +61,9 @@ case class FromAvro(child: Expression, jsonFormatSchema: Expression, options: Ex override def second: Expression = jsonFormatSchema override def third: Expression = options + def this(child: Expression, jsonFormatSchema: Expression) = + this(child, jsonFormatSchema, Literal.create(null)) + override def withNewChildrenInternal( newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = { copy(child = newFirst, jsonFormatSchema = newSecond, options = newThird) diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_with_options.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_with_options.explain index 1ef91ef8c36ac..f08c804d3b88a 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_with_options.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_with_options.explain @@ -1,2 +1,2 @@ -Project [from_avro(bytes#0, {"type": "int", "name": "id"}, (mode,FAILFAST), (compression,zstandard)) AS from_avro(bytes)#0] +Project [from_avro(bytes#0, {"type": "int", "name": "id"}, (mode,FAILFAST), (compression,zstandard)) AS from_avro(bytes, {"type": "int", "name": "id"}, map(mode, FAILFAST, compression, zstandard))#0] +- LocalRelation , [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_without_options.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_without_options.explain index 8fca0b5341694..6fe4a8babc689 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_without_options.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/from_avro_without_options.explain @@ -1,2 +1,2 @@ -Project [from_avro(bytes#0, {"type": "string", "name": "name"}) AS from_avro(bytes)#0] +Project [from_avro(bytes#0, {"type": "string", "name": "name"}) AS from_avro(bytes, {"type": "string", "name": "name"}, NULL)#0] +- LocalRelation , [id#0L, bytes#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_with_schema.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_with_schema.explain index cd2dc984e3ffa..8ba9248f844c7 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_with_schema.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_with_schema.explain @@ -1,2 +1,2 @@ -Project [to_avro(a#0, Some({"type": "int", "name": "id"})) AS to_avro(a)#0] +Project [to_avro(a#0, Some({"type": "int", 
"name": "id"})) AS to_avro(a, {"type": "int", "name": "id"})#0] +- LocalRelation , [id#0L, a#0, b#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_without_schema.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_without_schema.explain index a5371c70ac78a..b2947334945e3 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_without_schema.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/to_avro_without_schema.explain @@ -1,2 +1,2 @@ -Project [to_avro(id#0L, None) AS to_avro(id)#0] +Project [to_avro(id#0L, None) AS to_avro(id, NULL)#0] +- LocalRelation , [id#0L, a#0, b#0] diff --git a/sql/connect/server/pom.xml b/sql/connect/server/pom.xml index 3350c4261e9da..12e3ed9030437 100644 --- a/sql/connect/server/pom.xml +++ b/sql/connect/server/pom.xml @@ -105,7 +105,7 @@ org.apache.spark spark-avro_${scala.binary.version} ${project.version} - provided + test org.apache.spark diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 33c9edb1cd21a..231e54ff77d29 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -44,7 +44,6 @@ import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID} import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest} import org.apache.spark.sql.{Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, Row, SparkSession} -import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro} import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker} import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedTranspose} import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder} @@ -1523,8 +1522,7 @@ class SparkConnectPlanner( case proto.Expression.ExprTypeCase.UNRESOLVED_ATTRIBUTE => transformUnresolvedAttribute(exp.getUnresolvedAttribute) case proto.Expression.ExprTypeCase.UNRESOLVED_FUNCTION => - transformUnregisteredFunction(exp.getUnresolvedFunction) - .getOrElse(transformUnresolvedFunction(exp.getUnresolvedFunction)) + transformUnresolvedFunction(exp.getUnresolvedFunction) case proto.Expression.ExprTypeCase.ALIAS => transformAlias(exp.getAlias) case proto.Expression.ExprTypeCase.EXPRESSION_STRING => transformExpressionString(exp.getExpressionString) @@ -1844,49 +1842,6 @@ class SparkConnectPlanner( UnresolvedNamedLambdaVariable(variable.getNamePartsList.asScala.toSeq) } - /** - * For some reason, not all functions are registered in 'FunctionRegistry'. For a unregistered - * function, we can still wrap it under the proto 'UnresolvedFunction', and then resolve it in - * this method. 
- */ - private def transformUnregisteredFunction( - fun: proto.Expression.UnresolvedFunction): Option[Expression] = { - fun.getFunctionName match { - // Avro-specific functions - case "from_avro" if Seq(2, 3).contains(fun.getArgumentsCount) => - val children = fun.getArgumentsList.asScala.map(transformExpression) - val jsonFormatSchema = extractString(children(1), "jsonFormatSchema") - var options = Map.empty[String, String] - if (fun.getArgumentsCount == 3) { - options = extractMapData(children(2), "Options") - } - Some(AvroDataToCatalyst(children.head, jsonFormatSchema, options)) - - case "to_avro" if Seq(1, 2).contains(fun.getArgumentsCount) => - val children = fun.getArgumentsList.asScala.map(transformExpression) - var jsonFormatSchema = Option.empty[String] - if (fun.getArgumentsCount == 2) { - jsonFormatSchema = Some(extractString(children(1), "jsonFormatSchema")) - } - Some(CatalystDataToAvro(children.head, jsonFormatSchema)) - - case _ => None - } - } - - private def extractString(expr: Expression, field: String): String = expr match { - case Literal(s, StringType) if s != null => s.toString - case other => throw InvalidPlanInput(s"$field should be a literal string, but got $other") - } - - @scala.annotation.tailrec - private def extractMapData(expr: Expression, field: String): Map[String, String] = expr match { - case map: CreateMap => ExprUtils.convertToMapData(map) - case UnresolvedFunction(Seq("map"), args, _, _, _, _, _) => - extractMapData(CreateMap(args), field) - case other => throw InvalidPlanInput(s"$field should be created by map, but got $other") - } - private def transformAlias(alias: proto.Expression.Alias): NamedExpression = { if (alias.getNameCount == 1) { val metadata = if (alias.hasMetadata() && alias.getMetadata.nonEmpty) { From 3b5c1d6baeb239c75c182513b3fad37d532d9f9f Mon Sep 17 00:00:00 2001 From: Nemanja Boric Date: Mon, 23 Sep 2024 11:22:07 -0400 Subject: [PATCH 056/250] [SPARK-49747][CONNECT] Migrate connect/ files to structured logging ### What changes were proposed in this pull request? We are moving one missing piece in SparkConnect to MDC-based logging. ### Why are the changes needed? As part of the greater migration to structured logging. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Compilation/existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48195 from nemanja-boric-databricks/mdc-connect. 
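For context, the MDC-based pattern this migration applies is sketched below. The surrounding class and method are hypothetical and only illustrate the call site; the `log"..."` interpolator with `MDC(LogKeys.EXCEPTION, e)` is the part that mirrors the actual diff.

```
import org.apache.spark.internal.{Logging, LogKeys, MDC}

// Hypothetical runner used only to illustrate the logging call.
class ExampleRunner extends Logging {
  def run(work: => Unit): Unit = {
    try {
      work
    } catch {
      case e: Throwable =>
        // Before: logDebug(s"Exception in execute: $e")
        // After: the exception travels under a typed MDC key, so downstream log
        // tooling can query it as a structured field instead of parsing free text.
        logDebug(log"Exception in execute: ${MDC(LogKeys.EXCEPTION, e)}")
    }
  }
}
```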
Authored-by: Nemanja Boric Signed-off-by: Herman van Hovell --- .../spark/sql/connect/execution/ExecuteThreadRunner.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index fe43edb5c6218..e75654e2c384f 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -27,7 +27,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.SparkSQLException import org.apache.spark.connect.proto -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.connect.common.ProtoUtils import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService} @@ -113,7 +113,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends } catch { // Need to catch throwable instead of NonFatal, because e.g. InterruptedException is fatal. case e: Throwable => - logDebug(s"Exception in execute: $e") + logDebug(log"Exception in execute: ${MDC(LogKeys.EXCEPTION, e)}") // Always cancel all remaining execution after error. executeHolder.sessionHolder.session.sparkContext.cancelJobsWithTag( executeHolder.jobTag, @@ -298,7 +298,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends ProtoUtils.abbreviate(request, maxLevel = 8).toString) } catch { case NonFatal(e) => - logWarning("Fail to extract debug information", e) + logWarning(log"Fail to extract debug information: ${MDC(LogKeys.EXCEPTION, e)}") "UNKNOWN" } } From 1086256a81f16127563cdf9a6d0b7ef1e413f17a Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 23 Sep 2024 19:10:44 -0400 Subject: [PATCH 057/250] [SPARK-49415][CONNECT][SQL] Move SQLImplicits to sql/api ### What changes were proposed in this pull request? This PR largely moves SQLImplicits and DatasetHolder to sql/api. ### Why are the changes needed? We are creating a unified Scala interface for Classic and Connect. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48151 from hvanhovell/SPARK-49415. 
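To make the goal concrete, the sketch below shows the kind of user code that should now behave identically against the Classic and the Connect `SparkSession`, since both `implicits` objects extend the shared `api.SQLImplicits`. The session construction and the sample data are illustrative only.

```
import org.apache.spark.sql.SparkSession

// Assumes a running Spark environment; for Connect the builder would point at a remote.
val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

val ds = Seq(1, 2, 3).toDS()                     // localSeqToDatasetHolder + toDS()
val df = Seq(("a", 1), ("b", 2)).toDF("k", "v")  // DatasetHolder.toDF(colNames: _*)
val filtered = df.filter($"v" > 1)               // StringToColumn interpolator
```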
Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../org/apache/spark/sql/DatasetHolder.scala | 41 --- .../org/apache/spark/sql/SQLImplicits.scala | 283 +---------------- .../org/apache/spark/sql/SparkSession.scala | 15 +- project/MimaExcludes.scala | 12 + .../org/apache/spark/sql/DatasetHolder.scala | 11 +- .../apache/spark/sql/api/SQLImplicits.scala | 300 ++++++++++++++++++ .../apache/spark/sql/api/SparkSession.scala | 13 + .../org/apache/spark/sql/SQLContext.scala | 4 +- .../org/apache/spark/sql/SQLImplicits.scala | 248 +-------------- .../org/apache/spark/sql/SparkSession.scala | 15 +- .../sql/expressions/scalalang/typed.scala | 5 - .../apache/spark/sql/test/SQLTestData.scala | 2 +- .../apache/spark/sql/test/SQLTestUtils.scala | 2 +- 13 files changed, 348 insertions(+), 603 deletions(-) delete mode 100644 connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DatasetHolder.scala rename sql/{core => api}/src/main/scala/org/apache/spark/sql/DatasetHolder.scala (79%) create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DatasetHolder.scala deleted file mode 100644 index 66f591bf1fb99..0000000000000 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql - -/** - * A container for a [[Dataset]], used for implicit conversions in Scala. - * - * To use this, import implicit conversions in SQL: - * {{{ - * val spark: SparkSession = ... - * import spark.implicits._ - * }}} - * - * @since 3.4.0 - */ -case class DatasetHolder[T] private[sql] (private val ds: Dataset[T]) { - - // This is declared with parentheses to prevent the Scala compiler from treating - // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset. - def toDS(): Dataset[T] = ds - - // This is declared with parentheses to prevent the Scala compiler from treating - // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. 
- def toDF(): DataFrame = ds.toDF() - - def toDF(colNames: String*): DataFrame = ds.toDF(colNames: _*) -} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 7799d395d5c6a..4690253da808b 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -16,283 +16,8 @@ */ package org.apache.spark.sql -import scala.collection.Map -import scala.language.implicitConversions -import scala.reflect.classTag -import scala.reflect.runtime.universe.TypeTag - -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ - -/** - * A collection of implicit methods for converting names and Symbols into [[Column]]s, and for - * converting common Scala objects into [[Dataset]]s. - * - * @since 3.4.0 - */ -abstract class SQLImplicits private[sql] (session: SparkSession) extends LowPrioritySQLImplicits { - - /** - * Converts $"col name" into a [[Column]]. - * - * @since 3.4.0 - */ - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } - - /** - * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. - * @since 3.4.0 - */ - implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - - /** @since 3.4.0 */ - implicit val newIntEncoder: Encoder[Int] = PrimitiveIntEncoder - - /** @since 3.4.0 */ - implicit val newLongEncoder: Encoder[Long] = PrimitiveLongEncoder - - /** @since 3.4.0 */ - implicit val newDoubleEncoder: Encoder[Double] = PrimitiveDoubleEncoder - - /** @since 3.4.0 */ - implicit val newFloatEncoder: Encoder[Float] = PrimitiveFloatEncoder - - /** @since 3.4.0 */ - implicit val newByteEncoder: Encoder[Byte] = PrimitiveByteEncoder - - /** @since 3.4.0 */ - implicit val newShortEncoder: Encoder[Short] = PrimitiveShortEncoder - - /** @since 3.4.0 */ - implicit val newBooleanEncoder: Encoder[Boolean] = PrimitiveBooleanEncoder - - /** @since 3.4.0 */ - implicit val newStringEncoder: Encoder[String] = StringEncoder - - /** @since 3.4.0 */ - implicit val newJavaDecimalEncoder: Encoder[java.math.BigDecimal] = - AgnosticEncoders.DEFAULT_JAVA_DECIMAL_ENCODER - - /** @since 3.4.0 */ - implicit val newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] = - AgnosticEncoders.DEFAULT_SCALA_DECIMAL_ENCODER - - /** @since 3.4.0 */ - implicit val newDateEncoder: Encoder[java.sql.Date] = AgnosticEncoders.STRICT_DATE_ENCODER - - /** @since 3.4.0 */ - implicit val newLocalDateEncoder: Encoder[java.time.LocalDate] = - AgnosticEncoders.STRICT_LOCAL_DATE_ENCODER - - /** @since 3.4.0 */ - implicit val newLocalDateTimeEncoder: Encoder[java.time.LocalDateTime] = - AgnosticEncoders.LocalDateTimeEncoder - - /** @since 3.4.0 */ - implicit val newTimeStampEncoder: Encoder[java.sql.Timestamp] = - AgnosticEncoders.STRICT_TIMESTAMP_ENCODER - - /** @since 3.4.0 */ - implicit val newInstantEncoder: Encoder[java.time.Instant] = - AgnosticEncoders.STRICT_INSTANT_ENCODER - - /** @since 3.4.0 */ - implicit val newDurationEncoder: Encoder[java.time.Duration] = DayTimeIntervalEncoder - - /** @since 3.4.0 */ - implicit val newPeriodEncoder: Encoder[java.time.Period] = YearMonthIntervalEncoder - - /** @since 3.4.0 */ - implicit def 
newJavaEnumEncoder[A <: java.lang.Enum[_]: TypeTag]: Encoder[A] = { - ScalaReflection.encoderFor[A] - } - - // Boxed primitives - - /** @since 3.4.0 */ - implicit val newBoxedIntEncoder: Encoder[java.lang.Integer] = BoxedIntEncoder - - /** @since 3.4.0 */ - implicit val newBoxedLongEncoder: Encoder[java.lang.Long] = BoxedLongEncoder - - /** @since 3.4.0 */ - implicit val newBoxedDoubleEncoder: Encoder[java.lang.Double] = BoxedDoubleEncoder - - /** @since 3.4.0 */ - implicit val newBoxedFloatEncoder: Encoder[java.lang.Float] = BoxedFloatEncoder - - /** @since 3.4.0 */ - implicit val newBoxedByteEncoder: Encoder[java.lang.Byte] = BoxedByteEncoder - - /** @since 3.4.0 */ - implicit val newBoxedShortEncoder: Encoder[java.lang.Short] = BoxedShortEncoder - - /** @since 3.4.0 */ - implicit val newBoxedBooleanEncoder: Encoder[java.lang.Boolean] = BoxedBooleanEncoder - - // Seqs - private def newSeqEncoder[E](elementEncoder: AgnosticEncoder[E]): AgnosticEncoder[Seq[E]] = { - IterableEncoder( - classTag[Seq[E]], - elementEncoder, - elementEncoder.nullable, - elementEncoder.lenientSerialization) - } - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newIntSeqEncoder: Encoder[Seq[Int]] = newSeqEncoder(PrimitiveIntEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newLongSeqEncoder: Encoder[Seq[Long]] = newSeqEncoder(PrimitiveLongEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newDoubleSeqEncoder: Encoder[Seq[Double]] = newSeqEncoder(PrimitiveDoubleEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newFloatSeqEncoder: Encoder[Seq[Float]] = newSeqEncoder(PrimitiveFloatEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newByteSeqEncoder: Encoder[Seq[Byte]] = newSeqEncoder(PrimitiveByteEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newShortSeqEncoder: Encoder[Seq[Short]] = newSeqEncoder(PrimitiveShortEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newBooleanSeqEncoder: Encoder[Seq[Boolean]] = newSeqEncoder(PrimitiveBooleanEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - val newStringSeqEncoder: Encoder[Seq[String]] = newSeqEncoder(StringEncoder) - - /** - * @since 3.4.0 - * @deprecated - * use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newProductSeqEncoder[A <: Product: TypeTag]: Encoder[Seq[A]] = - newSeqEncoder(ScalaReflection.encoderFor[A]) - - /** @since 3.4.0 */ - implicit def newSequenceEncoder[T <: Seq[_]: TypeTag]: Encoder[T] = - ScalaReflection.encoderFor[T] - - // Maps - /** @since 3.4.0 */ - implicit def newMapEncoder[T <: Map[_, _]: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T] - - /** - * Notice that we serialize `Set` to Catalyst array. The set property is only kept when - * manipulating the domain objects. The serialization format doesn't keep the set property. 
When - * we have a Catalyst array which contains duplicated elements and convert it to - * `Dataset[Set[T]]` by using the encoder, the elements will be de-duplicated. - * - * @since 3.4.0 - */ - implicit def newSetEncoder[T <: Set[_]: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T] - - // Arrays - private def newArrayEncoder[E]( - elementEncoder: AgnosticEncoder[E]): AgnosticEncoder[Array[E]] = { - ArrayEncoder(elementEncoder, elementEncoder.nullable) - } - - /** @since 3.4.0 */ - implicit val newIntArrayEncoder: Encoder[Array[Int]] = newArrayEncoder(PrimitiveIntEncoder) - - /** @since 3.4.0 */ - implicit val newLongArrayEncoder: Encoder[Array[Long]] = newArrayEncoder(PrimitiveLongEncoder) - - /** @since 3.4.0 */ - implicit val newDoubleArrayEncoder: Encoder[Array[Double]] = - newArrayEncoder(PrimitiveDoubleEncoder) - - /** @since 3.4.0 */ - implicit val newFloatArrayEncoder: Encoder[Array[Float]] = newArrayEncoder( - PrimitiveFloatEncoder) - - /** @since 3.4.0 */ - implicit val newByteArrayEncoder: Encoder[Array[Byte]] = BinaryEncoder - - /** @since 3.4.0 */ - implicit val newShortArrayEncoder: Encoder[Array[Short]] = newArrayEncoder( - PrimitiveShortEncoder) - - /** @since 3.4.0 */ - implicit val newBooleanArrayEncoder: Encoder[Array[Boolean]] = - newArrayEncoder(PrimitiveBooleanEncoder) - - /** @since 3.4.0 */ - implicit val newStringArrayEncoder: Encoder[Array[String]] = newArrayEncoder(StringEncoder) - - /** @since 3.4.0 */ - implicit def newProductArrayEncoder[A <: Product: TypeTag]: Encoder[Array[A]] = { - newArrayEncoder(ScalaReflection.encoderFor[A]) - } - - /** - * Creates a [[Dataset]] from a local Seq. - * @since 3.4.0 - */ - implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T] = { - DatasetHolder(session.createDataset(s)) - } -} - -/** - * Lower priority implicit methods for converting Scala objects into [[Dataset]]s. Conflicting - * implicits are placed here to disambiguate resolution. - * - * Reasons for including specific implicits: newProductEncoder - to disambiguate for `List`s which - * are both `Seq` and `Product` - */ -trait LowPrioritySQLImplicits { - - /** @since 3.4.0 */ - implicit def newProductEncoder[T <: Product: TypeTag]: Encoder[T] = - ScalaReflection.encoderFor[T] +/** @inheritdoc */ +abstract class SQLImplicits private[sql] (override val session: SparkSession) + extends api.SQLImplicits { + type DS[U] = Dataset[U] } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 04f8eeb5c6d46..0663f0186888e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -252,19 +252,8 @@ class SparkSession private[sql] ( lazy val udf: UDFRegistration = new UDFRegistration(this) // scalastyle:off - // Disable style checker so "implicits" object can start with lowercase i - /** - * (Scala-specific) Implicit methods available in Scala for converting common names and Symbols - * into [[Column]]s, and for converting common Scala objects into DataFrame`s. 
- * - * {{{ - * val sparkSession = SparkSession.builder.getOrCreate() - * import sparkSession.implicits._ - * }}} - * - * @since 3.4.0 - */ - object implicits extends SQLImplicits(this) with Serializable + /** @inheritdoc */ + object implicits extends SQLImplicits(this) // scalastyle:on /** @inheritdoc */ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index ece4504395f12..972438d0757a7 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -183,6 +183,18 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminatedEvent"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryStatus"), + + // SPARK-49415: Shared SQLImplicits. + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DatasetHolder"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DatasetHolder$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.LowPrioritySQLImplicits"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SQLContext$implicits$"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SQLImplicits"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLImplicits.StringToColumn"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.this"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLImplicits$StringToColumn"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SparkSession$implicits$"), + ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.SQLImplicits.session"), ) // Default exclude rules diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala b/sql/api/src/main/scala/org/apache/spark/sql/DatasetHolder.scala similarity index 79% rename from sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala rename to sql/api/src/main/scala/org/apache/spark/sql/DatasetHolder.scala index 1c4ffefb897ea..dd7e8e81a088c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DatasetHolder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/DatasetHolder.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql import org.apache.spark.annotation.Stable +import org.apache.spark.sql.api.Dataset /** - * A container for a [[Dataset]], used for implicit conversions in Scala. + * A container for a [[org.apache.spark.sql.api.Dataset]], used for implicit conversions in Scala. * * To use this, import implicit conversions in SQL: * {{{ @@ -31,15 +32,15 @@ import org.apache.spark.annotation.Stable * @since 1.6.0 */ @Stable -case class DatasetHolder[T] private[sql](private val ds: Dataset[T]) { +class DatasetHolder[T, DS[U] <: Dataset[U]](ds: DS[T]) { // This is declared with parentheses to prevent the Scala compiler from treating // `rdd.toDS("1")` as invoking this toDS and then apply on the returned Dataset. - def toDS(): Dataset[T] = ds + def toDS(): DS[T] = ds // This is declared with parentheses to prevent the Scala compiler from treating // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame. 
- def toDF(): DataFrame = ds.toDF() + def toDF(): DS[Row] = ds.toDF().asInstanceOf[DS[Row]] - def toDF(colNames: String*): DataFrame = ds.toDF(colNames : _*) + def toDF(colNames: String*): DS[Row] = ds.toDF(colNames: _*).asInstanceOf[DS[Row]] } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala new file mode 100644 index 0000000000000..f6b44e168390a --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala @@ -0,0 +1,300 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.api + +import scala.collection.Map +import scala.language.implicitConversions +import scala.reflect.classTag +import scala.reflect.runtime.universe.TypeTag + +import _root_.java + +import org.apache.spark.sql.{ColumnName, DatasetHolder, Encoder, Encoders} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, DEFAULT_SCALA_DECIMAL_ENCODER, IterableEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, StringEncoder} + +/** + * A collection of implicit methods for converting common Scala objects into + * [[org.apache.spark.sql.api.Dataset]]s. + * + * @since 1.6.0 + */ +abstract class SQLImplicits extends LowPrioritySQLImplicits with Serializable { + type DS[U] <: Dataset[U] + + protected def session: SparkSession + + /** + * Converts $"col name" into a [[org.apache.spark.sql.Column]]. 
+ * + * @since 2.0.0 + */ + implicit class StringToColumn(val sc: StringContext) { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args: _*)) + } + } + + // Primitives + + /** @since 1.6.0 */ + implicit def newIntEncoder: Encoder[Int] = Encoders.scalaInt + + /** @since 1.6.0 */ + implicit def newLongEncoder: Encoder[Long] = Encoders.scalaLong + + /** @since 1.6.0 */ + implicit def newDoubleEncoder: Encoder[Double] = Encoders.scalaDouble + + /** @since 1.6.0 */ + implicit def newFloatEncoder: Encoder[Float] = Encoders.scalaFloat + + /** @since 1.6.0 */ + implicit def newByteEncoder: Encoder[Byte] = Encoders.scalaByte + + /** @since 1.6.0 */ + implicit def newShortEncoder: Encoder[Short] = Encoders.scalaShort + + /** @since 1.6.0 */ + implicit def newBooleanEncoder: Encoder[Boolean] = Encoders.scalaBoolean + + /** @since 1.6.0 */ + implicit def newStringEncoder: Encoder[String] = Encoders.STRING + + /** @since 2.2.0 */ + implicit def newJavaDecimalEncoder: Encoder[java.math.BigDecimal] = Encoders.DECIMAL + + /** @since 2.2.0 */ + implicit def newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] = + DEFAULT_SCALA_DECIMAL_ENCODER + + /** @since 2.2.0 */ + implicit def newDateEncoder: Encoder[java.sql.Date] = Encoders.DATE + + /** @since 3.0.0 */ + implicit def newLocalDateEncoder: Encoder[java.time.LocalDate] = Encoders.LOCALDATE + + /** @since 3.4.0 */ + implicit def newLocalDateTimeEncoder: Encoder[java.time.LocalDateTime] = Encoders.LOCALDATETIME + + /** @since 2.2.0 */ + implicit def newTimeStampEncoder: Encoder[java.sql.Timestamp] = Encoders.TIMESTAMP + + /** @since 3.0.0 */ + implicit def newInstantEncoder: Encoder[java.time.Instant] = Encoders.INSTANT + + /** @since 3.2.0 */ + implicit def newDurationEncoder: Encoder[java.time.Duration] = Encoders.DURATION + + /** @since 3.2.0 */ + implicit def newPeriodEncoder: Encoder[java.time.Period] = Encoders.PERIOD + + /** @since 3.2.0 */ + implicit def newJavaEnumEncoder[A <: java.lang.Enum[_]: TypeTag]: Encoder[A] = + ScalaReflection.encoderFor[A] + + // Boxed primitives + + /** @since 2.0.0 */ + implicit def newBoxedIntEncoder: Encoder[java.lang.Integer] = Encoders.INT + + /** @since 2.0.0 */ + implicit def newBoxedLongEncoder: Encoder[java.lang.Long] = Encoders.LONG + + /** @since 2.0.0 */ + implicit def newBoxedDoubleEncoder: Encoder[java.lang.Double] = Encoders.DOUBLE + + /** @since 2.0.0 */ + implicit def newBoxedFloatEncoder: Encoder[java.lang.Float] = Encoders.FLOAT + + /** @since 2.0.0 */ + implicit def newBoxedByteEncoder: Encoder[java.lang.Byte] = Encoders.BYTE + + /** @since 2.0.0 */ + implicit def newBoxedShortEncoder: Encoder[java.lang.Short] = Encoders.SHORT + + /** @since 2.0.0 */ + implicit def newBoxedBooleanEncoder: Encoder[java.lang.Boolean] = Encoders.BOOLEAN + + // Seqs + private def newSeqEncoder[E](elementEncoder: AgnosticEncoder[E]): AgnosticEncoder[Seq[E]] = { + IterableEncoder( + classTag[Seq[E]], + elementEncoder, + elementEncoder.nullable, + elementEncoder.lenientSerialization) + } + + /** + * @since 1.6.1 + * @deprecated + * use [[newSequenceEncoder]] + */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") + val newIntSeqEncoder: Encoder[Seq[Int]] = newSeqEncoder(PrimitiveIntEncoder) + + /** + * @since 1.6.1 + * @deprecated + * use [[newSequenceEncoder]] + */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") + val newLongSeqEncoder: Encoder[Seq[Long]] = newSeqEncoder(PrimitiveLongEncoder) + + /** + * @since 1.6.1 + * @deprecated + * use [[newSequenceEncoder]] + */ + @deprecated("Use 
newSequenceEncoder instead", "2.2.0") + val newDoubleSeqEncoder: Encoder[Seq[Double]] = newSeqEncoder(PrimitiveDoubleEncoder) + + /** + * @since 1.6.1 + * @deprecated + * use [[newSequenceEncoder]] + */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") + val newFloatSeqEncoder: Encoder[Seq[Float]] = newSeqEncoder(PrimitiveFloatEncoder) + + /** + * @since 1.6.1 + * @deprecated + * use [[newSequenceEncoder]] + */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") + val newByteSeqEncoder: Encoder[Seq[Byte]] = newSeqEncoder(PrimitiveByteEncoder) + + /** + * @since 1.6.1 + * @deprecated + * use [[newSequenceEncoder]] + */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") + val newShortSeqEncoder: Encoder[Seq[Short]] = newSeqEncoder(PrimitiveShortEncoder) + + /** + * @since 1.6.1 + * @deprecated + * use [[newSequenceEncoder]] + */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") + val newBooleanSeqEncoder: Encoder[Seq[Boolean]] = newSeqEncoder(PrimitiveBooleanEncoder) + + /** + * @since 1.6.1 + * @deprecated + * use [[newSequenceEncoder]] + */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") + val newStringSeqEncoder: Encoder[Seq[String]] = newSeqEncoder(StringEncoder) + + /** + * @since 1.6.1 + * @deprecated + * use [[newSequenceEncoder]] + */ + @deprecated("Use newSequenceEncoder instead", "2.2.0") + def newProductSeqEncoder[A <: Product: TypeTag]: Encoder[Seq[A]] = + newSeqEncoder(ScalaReflection.encoderFor[A]) + + /** @since 2.2.0 */ + implicit def newSequenceEncoder[T <: Seq[_]: TypeTag]: Encoder[T] = + ScalaReflection.encoderFor[T] + + // Maps + /** @since 2.3.0 */ + implicit def newMapEncoder[T <: Map[_, _]: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T] + + /** + * Notice that we serialize `Set` to Catalyst array. The set property is only kept when + * manipulating the domain objects. The serialization format doesn't keep the set property. When + * we have a Catalyst array which contains duplicated elements and convert it to + * `Dataset[Set[T]]` by using the encoder, the elements will be de-duplicated. 
+ * + * @since 2.3.0 + */ + implicit def newSetEncoder[T <: Set[_]: TypeTag]: Encoder[T] = ScalaReflection.encoderFor[T] + + // Arrays + private def newArrayEncoder[E]( + elementEncoder: AgnosticEncoder[E]): AgnosticEncoder[Array[E]] = { + ArrayEncoder(elementEncoder, elementEncoder.nullable) + } + + /** @since 1.6.1 */ + implicit val newIntArrayEncoder: Encoder[Array[Int]] = newArrayEncoder(PrimitiveIntEncoder) + + /** @since 1.6.1 */ + implicit val newLongArrayEncoder: Encoder[Array[Long]] = newArrayEncoder(PrimitiveLongEncoder) + + /** @since 1.6.1 */ + implicit val newDoubleArrayEncoder: Encoder[Array[Double]] = + newArrayEncoder(PrimitiveDoubleEncoder) + + /** @since 1.6.1 */ + implicit val newFloatArrayEncoder: Encoder[Array[Float]] = + newArrayEncoder(PrimitiveFloatEncoder) + + /** @since 1.6.1 */ + implicit val newByteArrayEncoder: Encoder[Array[Byte]] = Encoders.BINARY + + /** @since 1.6.1 */ + implicit val newShortArrayEncoder: Encoder[Array[Short]] = + newArrayEncoder(PrimitiveShortEncoder) + + /** @since 1.6.1 */ + implicit val newBooleanArrayEncoder: Encoder[Array[Boolean]] = + newArrayEncoder(PrimitiveBooleanEncoder) + + /** @since 1.6.1 */ + implicit val newStringArrayEncoder: Encoder[Array[String]] = + newArrayEncoder(StringEncoder) + + /** @since 1.6.1 */ + implicit def newProductArrayEncoder[A <: Product: TypeTag]: Encoder[Array[A]] = + newArrayEncoder(ScalaReflection.encoderFor[A]) + + /** + * Creates a [[Dataset]] from a local Seq. + * @since 1.6.0 + */ + implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T, DS] = { + new DatasetHolder(session.createDataset(s).asInstanceOf[DS[T]]) + } + + /** + * An implicit conversion that turns a Scala `Symbol` into a [[org.apache.spark.sql.Column]]. + * @since 1.3.0 + */ + implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) +} + +/** + * Lower priority implicit methods for converting Scala objects into + * [[org.apache.spark.sql.api.Dataset]]s. Conflicting implicits are placed here to disambiguate + * resolution. + * + * Reasons for including specific implicits: newProductEncoder - to disambiguate for `List`s which + * are both `Seq` and `Product` + */ +trait LowPrioritySQLImplicits { + + /** @since 1.6.0 */ + implicit def newProductEncoder[T <: Product: TypeTag]: Encoder[T] = Encoders.product[T] +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala index 41d16b16ab1c5..2623db4060ee6 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala @@ -505,6 +505,19 @@ abstract class SparkSession extends Serializable with Closeable { */ def read: DataFrameReader + /** + * (Scala-specific) Implicit methods available in Scala for converting common Scala objects into + * `DataFrame`s. + * + * {{{ + * val sparkSession = SparkSession.builder.getOrCreate() + * import sparkSession.implicits._ + * }}} + * + * @since 2.0.0 + */ + val implicits: SQLImplicits + /** * Executes some code block and prints to stdout the time taken to execute the block. This is * available in Scala only and is used primarily for interactive testing and debugging. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index ffcc0b923f2cb..636899a7acb06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -251,8 +251,8 @@ class SQLContext private[sql](val sparkSession: SparkSession) * @group basic * @since 1.3.0 */ - object implicits extends SQLImplicits with Serializable { - protected override def session: SparkSession = self.sparkSession + object implicits extends SQLImplicits { + override protected def session: SparkSession = sparkSession } // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index a657836aafbea..1bc7e3ee98e76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,259 +17,21 @@ package org.apache.spark.sql -import scala.collection.Map import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -/** - * A collection of implicit methods for converting common Scala objects into [[Dataset]]s. - * - * @since 1.6.0 - */ -abstract class SQLImplicits extends LowPrioritySQLImplicits { +/** @inheritdoc */ +abstract class SQLImplicits extends api.SQLImplicits { + type DS[U] = Dataset[U] protected def session: SparkSession - /** - * Converts $"col name" into a [[Column]]. - * - * @since 2.0.0 - */ - implicit class StringToColumn(val sc: StringContext) { - def $(args: Any*): ColumnName = { - new ColumnName(sc.s(args: _*)) - } - } - - // Primitives - - /** @since 1.6.0 */ - implicit def newIntEncoder: Encoder[Int] = Encoders.scalaInt - - /** @since 1.6.0 */ - implicit def newLongEncoder: Encoder[Long] = Encoders.scalaLong - - /** @since 1.6.0 */ - implicit def newDoubleEncoder: Encoder[Double] = Encoders.scalaDouble - - /** @since 1.6.0 */ - implicit def newFloatEncoder: Encoder[Float] = Encoders.scalaFloat - - /** @since 1.6.0 */ - implicit def newByteEncoder: Encoder[Byte] = Encoders.scalaByte - - /** @since 1.6.0 */ - implicit def newShortEncoder: Encoder[Short] = Encoders.scalaShort - - /** @since 1.6.0 */ - implicit def newBooleanEncoder: Encoder[Boolean] = Encoders.scalaBoolean - - /** @since 1.6.0 */ - implicit def newStringEncoder: Encoder[String] = Encoders.STRING - - /** @since 2.2.0 */ - implicit def newJavaDecimalEncoder: Encoder[java.math.BigDecimal] = Encoders.DECIMAL - - /** @since 2.2.0 */ - implicit def newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] = ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newDateEncoder: Encoder[java.sql.Date] = Encoders.DATE - - /** @since 3.0.0 */ - implicit def newLocalDateEncoder: Encoder[java.time.LocalDate] = Encoders.LOCALDATE - - /** @since 3.4.0 */ - implicit def newLocalDateTimeEncoder: Encoder[java.time.LocalDateTime] = Encoders.LOCALDATETIME - - /** @since 2.2.0 */ - implicit def newTimeStampEncoder: Encoder[java.sql.Timestamp] = Encoders.TIMESTAMP - - /** @since 3.0.0 */ - implicit def newInstantEncoder: Encoder[java.time.Instant] = Encoders.INSTANT - - /** @since 3.2.0 */ - implicit def newDurationEncoder: Encoder[java.time.Duration] = Encoders.DURATION - - /** @since 3.2.0 */ - implicit def newPeriodEncoder: Encoder[java.time.Period] = Encoders.PERIOD - - /** @since 
3.2.0 */ - implicit def newJavaEnumEncoder[A <: java.lang.Enum[_] : TypeTag]: Encoder[A] = - ExpressionEncoder() - - // Boxed primitives - - /** @since 2.0.0 */ - implicit def newBoxedIntEncoder: Encoder[java.lang.Integer] = Encoders.INT - - /** @since 2.0.0 */ - implicit def newBoxedLongEncoder: Encoder[java.lang.Long] = Encoders.LONG - - /** @since 2.0.0 */ - implicit def newBoxedDoubleEncoder: Encoder[java.lang.Double] = Encoders.DOUBLE - - /** @since 2.0.0 */ - implicit def newBoxedFloatEncoder: Encoder[java.lang.Float] = Encoders.FLOAT - - /** @since 2.0.0 */ - implicit def newBoxedByteEncoder: Encoder[java.lang.Byte] = Encoders.BYTE - - /** @since 2.0.0 */ - implicit def newBoxedShortEncoder: Encoder[java.lang.Short] = Encoders.SHORT - - /** @since 2.0.0 */ - implicit def newBoxedBooleanEncoder: Encoder[java.lang.Boolean] = Encoders.BOOLEAN - - // Seqs - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder() - - /** - * @since 1.6.1 - * @deprecated use [[newSequenceEncoder]] - */ - @deprecated("Use newSequenceEncoder instead", "2.2.0") - def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder() - - /** @since 2.2.0 */ - implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder() - - // Maps - /** @since 2.3.0 */ - implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder() - - /** - * Notice that we serialize `Set` to Catalyst array. The set property is only kept when - * manipulating the domain objects. The serialization format doesn't keep the set property. - * When we have a Catalyst array which contains duplicated elements and convert it to - * `Dataset[Set[T]]` by using the encoder, the elements will be de-duplicated. 
- * - * @since 2.3.0 - */ - implicit def newSetEncoder[T <: Set[_] : TypeTag]: Encoder[T] = ExpressionEncoder() - - // Arrays - - /** @since 1.6.1 */ - implicit def newIntArrayEncoder: Encoder[Array[Int]] = ExpressionEncoder() - - /** @since 1.6.1 */ - implicit def newLongArrayEncoder: Encoder[Array[Long]] = ExpressionEncoder() - - /** @since 1.6.1 */ - implicit def newDoubleArrayEncoder: Encoder[Array[Double]] = ExpressionEncoder() - - /** @since 1.6.1 */ - implicit def newFloatArrayEncoder: Encoder[Array[Float]] = ExpressionEncoder() - - /** @since 1.6.1 */ - implicit def newByteArrayEncoder: Encoder[Array[Byte]] = Encoders.BINARY - - /** @since 1.6.1 */ - implicit def newShortArrayEncoder: Encoder[Array[Short]] = ExpressionEncoder() - - /** @since 1.6.1 */ - implicit def newBooleanArrayEncoder: Encoder[Array[Boolean]] = ExpressionEncoder() - - /** @since 1.6.1 */ - implicit def newStringArrayEncoder: Encoder[Array[String]] = ExpressionEncoder() - - /** @since 1.6.1 */ - implicit def newProductArrayEncoder[A <: Product : TypeTag]: Encoder[Array[A]] = - ExpressionEncoder() - /** * Creates a [[Dataset]] from an RDD. * * @since 1.6.0 */ - implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = { - DatasetHolder(session.createDataset(rdd)) - } - - /** - * Creates a [[Dataset]] from a local Seq. - * @since 1.6.0 - */ - implicit def localSeqToDatasetHolder[T : Encoder](s: Seq[T]): DatasetHolder[T] = { - DatasetHolder(session.createDataset(s)) - } - - /** - * An implicit conversion that turns a Scala `Symbol` into a [[Column]]. - * @since 1.3.0 - */ - implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) - -} - -/** - * Lower priority implicit methods for converting Scala objects into [[Dataset]]s. - * Conflicting implicits are placed here to disambiguate resolution. - * - * Reasons for including specific implicits: - * newProductEncoder - to disambiguate for `List`s which are both `Seq` and `Product` - */ -trait LowPrioritySQLImplicits { - /** @since 1.6.0 */ - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T] - + implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T, Dataset] = + new DatasetHolder(session.createDataset(rdd)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 137dbaed9f00a..938df206b9792 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -752,19 +752,8 @@ class SparkSession private( // scalastyle:off // Disable style checker so "implicits" object can start with lowercase i - /** - * (Scala-specific) Implicit methods available in Scala for converting - * common Scala objects into `DataFrame`s. 
- * - * {{{ - * val sparkSession = SparkSession.builder.getOrCreate() - * import sparkSession.implicits._ - * }}} - * - * @since 2.0.0 - */ - object implicits extends SQLImplicits with Serializable { - protected override def session: SparkSession = SparkSession.this + object implicits extends SQLImplicits { + override protected def session: SparkSession = self } // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala index 6277f8b459248..8d17edd42442e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala @@ -39,11 +39,6 @@ object typed { // For example, avg in the Scala version returns Scala primitive Double, whose bytecode // signature is just a java.lang.Object; avg in the Java version returns java.lang.Double. - // TODO: This is pretty hacky. Maybe we should have an object for implicit encoders. - private val implicits = new SQLImplicits { - override protected def session: SparkSession = null - } - /** * Average aggregate function. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index d7c00b68828c4..90432dea3a017 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -35,7 +35,7 @@ private[sql] trait SQLTestData { self => // Helper object to import SQL implicits without a concrete SparkSession private object internalImplicits extends SQLImplicits { - protected override def session: SparkSession = self.spark + override protected def session: SparkSession = self.spark } import internalImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 54d6840eb5775..fe5a0f8ee257a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -240,7 +240,7 @@ private[sql] trait SQLTestUtilsBase * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { - protected override def session: SparkSession = self.spark + override protected def session: SparkSession = self.spark implicit def toRichColumn(c: Column): SparkSession#RichColumn = session.RichColumn(c) } From 94d288e08f2b9b98c2e74a8dcced86b163c1637a Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 24 Sep 2024 08:51:07 +0900 Subject: [PATCH 058/250] [MINOR][PYTHON][DOCS] Fix the docstring of `to_timestamp` ### What changes were proposed in this pull request? Fix the docstring of `to_timestamp` ### Why are the changes needed? `try_to_timestamp` is used in the examples of `to_timestamp` ### Does this PR introduce _any_ user-facing change? doc changes ### How was this patch tested? updated doctests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48207 from zhengruifeng/py_doc_nit_tots. 
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/functions/builtin.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 2d5dbb5946050..2688f9daa23a4 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -9091,15 +9091,19 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: :class:`~pyspark.sql.Column` timestamp value as :class:`pyspark.sql.types.TimestampType` type. + See Also + -------- + :meth:`pyspark.sql.functions.try_to_timestamp` + Examples -------- Example 1: Convert string to a timestamp >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(sf.try_to_timestamp(df.t).alias('dt')).show() + >>> df.select(sf.to_timestamp(df.t)).show() +-------------------+ - | dt| + | to_timestamp(t)| +-------------------+ |1997-02-28 10:30:00| +-------------------+ @@ -9108,12 +9112,12 @@ def to_timestamp(col: "ColumnOrName", format: Optional[str] = None) -> Column: >>> import pyspark.sql.functions as sf >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t']) - >>> df.select(sf.try_to_timestamp(df.t, sf.lit('yyyy-MM-dd HH:mm:ss')).alias('dt')).show() - +-------------------+ - | dt| - +-------------------+ - |1997-02-28 10:30:00| - +-------------------+ + >>> df.select(sf.to_timestamp(df.t, 'yyyy-MM-dd HH:mm:ss')).show() + +------------------------------------+ + |to_timestamp(t, yyyy-MM-dd HH:mm:ss)| + +------------------------------------+ + | 1997-02-28 10:30:00| + +------------------------------------+ """ from pyspark.sql.classic.column import _to_java_column @@ -9139,6 +9143,10 @@ def try_to_timestamp(col: "ColumnOrName", format: Optional["ColumnOrName"] = Non format: str, optional format to use to convert timestamp values. + See Also + -------- + :meth:`pyspark.sql.functions.to_timestamp` + Examples -------- Example 1: Convert string to a timestamp From 742265ebb742f9520ca06717be57c6aa2e594191 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 23 Sep 2024 22:51:54 -0400 Subject: [PATCH 059/250] [SPARK-49429][CONNECT][SQL] Add Shared DataStreamWriter interface ### What changes were proposed in this pull request? This PR adds a shared DataStreamWriter to sql. ### Why are the changes needed? We are creating a unified Scala interface for sql. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48212 from hvanhovell/SPARK-49429. 
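As context for the API surface being shared, a rough sketch of typical `DataStreamWriter` usage is shown below. The rate source, console sink, and checkpoint path are illustrative choices, not something this change prescribes.

```
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder().getOrCreate()

// writeStream hands back the DataStreamWriter whose interface now lives in sql/api.
val query = spark.readStream
  .format("rate")                                  // illustrative built-in source
  .load()
  .writeStream
  .queryName("rate_to_console")
  .format("console")
  .outputMode("append")
  .trigger(Trigger.ProcessingTime("10 seconds"))
  .option("checkpointLocation", "/tmp/checkpoints/rate_to_console")
  .start()

query.awaitTermination()
```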
Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../scala/org/apache/spark/sql/Dataset.scala | 7 +- .../sql/streaming/DataStreamWriter.scala | 252 +++---------- .../spark/sql/api/DataStreamWriter.scala | 193 ++++++++++ .../org/apache/spark/sql/api/Dataset.scala | 8 + .../scala/org/apache/spark/sql/Dataset.scala | 7 +- .../sql/streaming/DataStreamWriter.scala | 343 +++++------------- .../sql/streaming/StreamingQueryManager.scala | 6 +- 7 files changed, 340 insertions(+), 476 deletions(-) create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/api/DataStreamWriter.scala diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index accfff9f2b073..d2877ccaf06c9 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1035,12 +1035,7 @@ class Dataset[T] private[sql] ( new MergeIntoWriterImpl[T](table, this, condition) } - /** - * Interface for saving the content of the streaming Dataset out into external storage. - * - * @group basic - * @since 3.5.0 - */ + /** @inheritdoc */ def writeStream: DataStreamWriter[T] = { new DataStreamWriter[T](this) } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index c8c714047788b..9fcc31e562682 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -29,9 +29,8 @@ import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Command import org.apache.spark.connect.proto.WriteStreamOperationStart -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, ForeachWriter} -import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, UdfUtils} +import org.apache.spark.sql.{api, Dataset, ForeachWriter} +import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket} import org.apache.spark.sql.execution.streaming.AvailableNowTrigger import org.apache.spark.sql.execution.streaming.ContinuousTrigger import org.apache.spark.sql.execution.streaming.OneTimeTrigger @@ -47,63 +46,23 @@ import org.apache.spark.util.SparkSerDeUtils * @since 3.5.0 */ @Evolving -final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { +final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends api.DataStreamWriter[T] { + override type DS[U] = Dataset[U] - /** - * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
    • - * `OutputMode.Append()`: only the new rows in the streaming DataFrame/Dataset will be written - * to the sink.
    • `OutputMode.Complete()`: all the rows in the streaming - * DataFrame/Dataset will be written to the sink every time there are some updates.
    • - * `OutputMode.Update()`: only the rows that were updated in the streaming DataFrame/Dataset - * will be written to the sink every time there are some updates. If the query doesn't contain - * aggregations, it will be equivalent to `OutputMode.Append()` mode.
    - * - * @since 3.5.0 - */ - def outputMode(outputMode: OutputMode): DataStreamWriter[T] = { + /** @inheritdoc */ + def outputMode(outputMode: OutputMode): this.type = { sinkBuilder.setOutputMode(outputMode.toString.toLowerCase(Locale.ROOT)) this } - /** - * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
    • - * `append`: only the new rows in the streaming DataFrame/Dataset will be written to the - * sink.
    • `complete`: all the rows in the streaming DataFrame/Dataset will be written - * to the sink every time there are some updates.
    • `update`: only the rows that were - * updated in the streaming DataFrame/Dataset will be written to the sink every time there are - * some updates. If the query doesn't contain aggregations, it will be equivalent to `append` - * mode.
    - * - * @since 3.5.0 - */ - def outputMode(outputMode: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def outputMode(outputMode: String): this.type = { sinkBuilder.setOutputMode(outputMode) this } - /** - * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will - * run the query as fast as possible. - * - * Scala Example: - * {{{ - * df.writeStream.trigger(ProcessingTime("10 seconds")) - * - * import scala.concurrent.duration._ - * df.writeStream.trigger(ProcessingTime(10.seconds)) - * }}} - * - * Java Example: - * {{{ - * df.writeStream().trigger(ProcessingTime.create("10 seconds")) - * - * import java.util.concurrent.TimeUnit - * df.writeStream().trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) - * }}} - * - * @since 3.5.0 - */ - def trigger(trigger: Trigger): DataStreamWriter[T] = { + /** @inheritdoc */ + def trigger(trigger: Trigger): this.type = { trigger match { case ProcessingTimeTrigger(intervalMs) => sinkBuilder.setProcessingTimeInterval(s"$intervalMs milliseconds") @@ -117,123 +76,54 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { this } - /** - * Specifies the name of the [[StreamingQuery]] that can be started with `start()`. This name - * must be unique among all the currently active queries in the associated SQLContext. - * - * @since 3.5.0 - */ - def queryName(queryName: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def queryName(queryName: String): this.type = { sinkBuilder.setQueryName(queryName) this } - /** - * Specifies the underlying output data source. - * - * @since 3.5.0 - */ - def format(source: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def format(source: String): this.type = { sinkBuilder.setFormat(source) this } - /** - * Partitions the output by the given columns on the file system. If specified, the output is - * laid out on the file system similar to Hive's partitioning scheme. As an example, when we - * partition a dataset by year and then month, the directory layout would look like: - * - *
    • year=2016/month=01/
    • year=2016/month=02/
    - * - * Partitioning is one of the most widely used techniques to optimize physical data layout. It - * provides a coarse-grained index for skipping unnecessary data reads when queries have - * predicates on the partitioned columns. In order for partitioning to work well, the number of - * distinct values in each column should typically be less than tens of thousands. - * - * @since 3.5.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def partitionBy(colNames: String*): DataStreamWriter[T] = { + def partitionBy(colNames: String*): this.type = { sinkBuilder.clearPartitioningColumnNames() sinkBuilder.addAllPartitioningColumnNames(colNames.asJava) this } - /** - * Clusters the output by the given columns. If specified, the output is laid out such that - * records with similar values on the clustering column are grouped together in the same file. - * - * Clustering improves query efficiency by allowing queries with predicates on the clustering - * columns to skip unnecessary data. Unlike partitioning, clustering can be used on very high - * cardinality columns. - * - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def clusterBy(colNames: String*): DataStreamWriter[T] = { + def clusterBy(colNames: String*): this.type = { sinkBuilder.clearClusteringColumnNames() sinkBuilder.addAllClusteringColumnNames(colNames.asJava) this } - /** - * Adds an output option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def option(key: String, value: String): this.type = { sinkBuilder.putOptions(key, value) this } - /** - * Adds an output option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Boolean): DataStreamWriter[T] = option(key, value.toString) - - /** - * Adds an output option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Long): DataStreamWriter[T] = option(key, value.toString) - - /** - * Adds an output option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Double): DataStreamWriter[T] = option(key, value.toString) - - /** - * (Scala-specific) Adds output options for the underlying data source. - * - * @since 3.5.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamWriter[T] = { + /** @inheritdoc */ + def options(options: scala.collection.Map[String, String]): this.type = { this.options(options.asJava) this } - /** - * Adds output options for the underlying data source. - * - * @since 3.5.0 - */ - def options(options: java.util.Map[String, String]): DataStreamWriter[T] = { + /** @inheritdoc */ + def options(options: java.util.Map[String, String]): this.type = { sinkBuilder.putAllOptions(options) this } - /** - * Sets the output of the streaming query to be processed using the provided writer object. - * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and - * semantics. 
- * @since 3.5.0 - */ - def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { + /** @inheritdoc */ + def foreach(writer: ForeachWriter[T]): this.type = { val serialized = SparkSerDeUtils.serialize(ForeachWriterPacket(writer, ds.agnosticEncoder)) val scalaWriterBuilder = proto.ScalarScalaUDF .newBuilder() @@ -242,21 +132,9 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { this } - /** - * :: Experimental :: - * - * (Scala-specific) Sets the output of the streaming query to be processed using the provided - * function. This is supported only in the micro-batch execution modes (that is, when the - * trigger is not continuous). In every micro-batch, the provided function will be called in - * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. The - * batchId can be used to deduplicate and transactionally write the output (that is, the - * provided Dataset) to external systems. The output Dataset is guaranteed to be exactly the - * same for the same batchId (assuming all operations are deterministic in the query). - * - * @since 3.5.0 - */ + /** @inheritdoc */ @Evolving - def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { + def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = { val serializedFn = SparkSerDeUtils.serialize(function) sinkBuilder.getForeachBatchBuilder.getScalaFunctionBuilder .setPayload(ByteString.copyFrom(serializedFn)) @@ -265,48 +143,13 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { this } - /** - * :: Experimental :: - * - * (Java-specific) Sets the output of the streaming query to be processed using the provided - * function. This is supported only in the micro-batch execution modes (that is, when the - * trigger is not continuous). In every micro-batch, the provided function will be called in - * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. The - * batchId can be used to deduplicate and transactionally write the output (that is, the - * provided Dataset) to external systems. The output Dataset is guaranteed to be exactly the - * same for the same batchId (assuming all operations are deterministic in the query). - * - * @since 3.5.0 - */ - @Evolving - def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): DataStreamWriter[T] = { - foreachBatch(UdfUtils.foreachBatchFuncToScalaFunc(function)) - } - - /** - * Starts the execution of the streaming query, which will continually output results to the - * given path as new data arrives. The returned [[StreamingQuery]] object can be used to - * interact with the stream. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def start(path: String): StreamingQuery = { sinkBuilder.setPath(path) start() } - /** - * Starts the execution of the streaming query, which will continually output results to the - * given path as new data arrives. The returned [[StreamingQuery]] object can be used to - * interact with the stream. 
Throws a `TimeoutException` if the following conditions are met: - * - Another run of the same streaming query, that is a streaming query sharing the same - * checkpoint location, is already active on the same Spark Driver - * - The SQL configuration `spark.sql.streaming.stopActiveRunOnRestart` is enabled - * - The active run cannot be stopped within the timeout controlled by the SQL configuration - * `spark.sql.streaming.stopTimeout` - * - * @since 3.5.0 - */ + /** @inheritdoc */ @throws[TimeoutException] def start(): StreamingQuery = { val startCmd = Command @@ -323,22 +166,7 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { RemoteStreamingQuery.fromStartCommandResponse(ds.sparkSession, resp) } - /** - * Starts the execution of the streaming query, which will continually output results to the - * given table as new data arrives. The returned [[StreamingQuery]] object can be used to - * interact with the stream. - * - * For v1 table, partitioning columns provided by `partitionBy` will be respected no matter the - * table exists or not. A new table will be created if the table not exists. - * - * For v2 table, `partitionBy` will be ignored if the table already exists. `partitionBy` will - * be respected only if the v2 table does not exist. Besides, the v2 table created by this API - * lacks some functionalities (e.g., customized properties, options, and serde info). If you - * need them, please create the v2 table manually before the execution to avoid creating a table - * with incomplete information. - * - * @since 3.5.0 - */ + /** @inheritdoc */ @Evolving @throws[TimeoutException] def toTable(tableName: String): StreamingQuery = { @@ -346,6 +174,24 @@ final class DataStreamWriter[T] private[sql] (ds: Dataset[T]) extends Logging { start() } + /////////////////////////////////////////////////////////////////////////////////////// + // Covariant Overrides + /////////////////////////////////////////////////////////////////////////////////////// + + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + @Evolving + override def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): this.type = + super.foreachBatch(function) + private val sinkBuilder = WriteStreamOperationStart .newBuilder() .setInput(ds.plan.getRoot) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataStreamWriter.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/DataStreamWriter.scala new file mode 100644 index 0000000000000..7762708e9520c --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataStreamWriter.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.api + +import _root_.java +import _root_.java.util.concurrent.TimeoutException + +import org.apache.spark.annotation.Evolving +import org.apache.spark.api.java.function.VoidFunction2 +import org.apache.spark.sql.{ForeachWriter, WriteConfigMethods} +import org.apache.spark.sql.streaming.{OutputMode, Trigger} + +/** + * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, + * key-value stores, etc). Use `Dataset.writeStream` to access this. + * + * @since 2.0.0 + */ +@Evolving +abstract class DataStreamWriter[T] extends WriteConfigMethods[DataStreamWriter[T]] { + type DS[U] <: Dataset[U] + + /** + * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
<ul> <li>
+ * `OutputMode.Append()`: only the new rows in the streaming DataFrame/Dataset will be written
+ * to the sink.</li> <li> `OutputMode.Complete()`: all the rows in the streaming
+ * DataFrame/Dataset will be written to the sink every time there are some updates.</li> <li>
+ * `OutputMode.Update()`: only the rows that were updated in the streaming DataFrame/Dataset
+ * will be written to the sink every time there are some updates. If the query doesn't contain
+ * aggregations, it will be equivalent to `OutputMode.Append()` mode.</li> </ul>
    + * + * @since 2.0.0 + */ + def outputMode(outputMode: OutputMode): this.type + + /** + * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
<ul> <li>
+ * `append`: only the new rows in the streaming DataFrame/Dataset will be written to the
+ * sink.</li> <li> `complete`: all the rows in the streaming DataFrame/Dataset will be written
+ * to the sink every time there are some updates.</li> <li> `update`: only the rows that were
+ * updated in the streaming DataFrame/Dataset will be written to the sink every time there are
+ * some updates. If the query doesn't contain aggregations, it will be equivalent to `append`
+ * mode.</li> </ul>
    + * + * @since 2.0.0 + */ + def outputMode(outputMode: String): this.type + + /** + * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will + * run the query as fast as possible. + * + * Scala Example: + * {{{ + * df.writeStream.trigger(ProcessingTime("10 seconds")) + * + * import scala.concurrent.duration._ + * df.writeStream.trigger(ProcessingTime(10.seconds)) + * }}} + * + * Java Example: + * {{{ + * df.writeStream().trigger(ProcessingTime.create("10 seconds")) + * + * import java.util.concurrent.TimeUnit + * df.writeStream().trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * }}} + * + * @since 2.0.0 + */ + def trigger(trigger: Trigger): this.type + + /** + * Specifies the name of the [[org.apache.spark.sql.api.StreamingQuery]] that can be started + * with `start()`. This name must be unique among all the currently active queries in the + * associated SparkSession. + * + * @since 2.0.0 + */ + def queryName(queryName: String): this.type + + /** + * Sets the output of the streaming query to be processed using the provided writer object. + * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and + * semantics. + * + * @since 2.0.0 + */ + def foreach(writer: ForeachWriter[T]): this.type + + /** + * :: Experimental :: + * + * (Scala-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only in the micro-batch execution modes (that is, when the + * trigger is not continuous). In every micro-batch, the provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. The + * batchId can be used to deduplicate and transactionally write the output (that is, the + * provided Dataset) to external systems. The output Dataset is guaranteed to be exactly the + * same for the same batchId (assuming all operations are deterministic in the query). + * + * @since 2.4.0 + */ + @Evolving + def foreachBatch(function: (DS[T], Long) => Unit): this.type + + /** + * :: Experimental :: + * + * (Java-specific) Sets the output of the streaming query to be processed using the provided + * function. This is supported only in the micro-batch execution modes (that is, when the + * trigger is not continuous). In every micro-batch, the provided function will be called in + * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. The + * batchId can be used to deduplicate and transactionally write the output (that is, the + * provided Dataset) to external systems. The output Dataset is guaranteed to be exactly the + * same for the same batchId (assuming all operations are deterministic in the query). + * + * @since 2.4.0 + */ + @Evolving + def foreachBatch(function: VoidFunction2[DS[T], java.lang.Long]): this.type = { + foreachBatch((batchDs: DS[T], batchId: Long) => function.call(batchDs, batchId)) + } + + /** + * Starts the execution of the streaming query, which will continually output results to the + * given path as new data arrives. The returned [[org.apache.spark.sql.api.StreamingQuery]] + * object can be used to interact with the stream. + * + * @since 2.0.0 + */ + def start(path: String): StreamingQuery + + /** + * Starts the execution of the streaming query, which will continually output results to the + * given path as new data arrives. The returned [[org.apache.spark.sql.api.StreamingQuery]] + * object can be used to interact with the stream. 
Throws a `TimeoutException` if the following + * conditions are met: + * - Another run of the same streaming query, that is a streaming query sharing the same + * checkpoint location, is already active on the same Spark Driver + * - The SQL configuration `spark.sql.streaming.stopActiveRunOnRestart` is enabled + * - The active run cannot be stopped within the timeout controlled by the SQL configuration + * `spark.sql.streaming.stopTimeout` + * + * @since 2.0.0 + */ + @throws[TimeoutException] + def start(): StreamingQuery + + /** + * Starts the execution of the streaming query, which will continually output results to the + * given table as new data arrives. The returned [[org.apache.spark.sql.api.StreamingQuery]] + * object can be used to interact with the stream. + * + * For v1 table, partitioning columns provided by `partitionBy` will be respected no matter the + * table exists or not. A new table will be created if the table not exists. + * + * For v2 table, `partitionBy` will be ignored if the table already exists. `partitionBy` will + * be respected only if the v2 table does not exist. Besides, the v2 table created by this API + * lacks some functionalities (e.g., customized properties, options, and serde info). If you + * need them, please create the v2 table manually before the execution to avoid creating a table + * with incomplete information. + * + * @since 3.1.0 + */ + @Evolving + @throws[TimeoutException] + def toTable(tableName: String): StreamingQuery + + /////////////////////////////////////////////////////////////////////////////////////// + // Covariant Overrides + /////////////////////////////////////////////////////////////////////////////////////// + override def option(key: String, value: Boolean): this.type = + super.option(key, value).asInstanceOf[this.type] + override def option(key: String, value: Long): this.type = + super.option(key, value).asInstanceOf[this.type] + override def option(key: String, value: Double): this.type = + super.option(key, value).asInstanceOf[this.type] +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala index 6eef034aa5157..06a6148a7c188 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala @@ -3017,6 +3017,14 @@ abstract class Dataset[T] extends Serializable { */ def mergeInto(table: String, condition: Column): MergeIntoWriter[T] + /** + * Interface for saving the content of the streaming Dataset out into external storage. + * + * @group basic + * @since 2.0.0 + */ + def writeStream: DataStreamWriter[T] + /** * Create a write configuration builder for v2 sources. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ef628ca612b49..80ec70a7864c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1618,12 +1618,7 @@ class Dataset[T] private[sql]( new MergeIntoWriterImpl[T](table, this, condition) } - /** - * Interface for saving the content of the streaming Dataset out into external storage. 
- * - * @group basic - * @since 2.0.0 - */ + /** @inheritdoc */ def writeStream: DataStreamWriter[T] = { if (!isStreaming) { logicalPlan.failAnalysis( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index ab4d350c1e68c..b0233d2c51b75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -55,253 +55,101 @@ import org.apache.spark.util.Utils * @since 2.0.0 */ @Evolving -final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { - import DataStreamWriter._ +final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends api.DataStreamWriter[T] { + type DS[U] = Dataset[U] - private val df = ds.toDF() - - /** - * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. - *
<ul>
- * <li> `OutputMode.Append()`: only the new rows in the streaming DataFrame/Dataset will be
- * written to the sink.</li>
- * <li> `OutputMode.Complete()`: all the rows in the streaming DataFrame/Dataset will be written
- * to the sink every time there are some updates.</li>
- * <li> `OutputMode.Update()`: only the rows that were updated in the streaming
- * DataFrame/Dataset will be written to the sink every time there are some updates.
- * If the query doesn't contain aggregations, it will be equivalent to
- * `OutputMode.Append()` mode.</li>
- * </ul>
    - * - * @since 2.0.0 - */ - def outputMode(outputMode: OutputMode): DataStreamWriter[T] = { + /** @inheritdoc */ + def outputMode(outputMode: OutputMode): this.type = { this.outputMode = outputMode this } - /** - * Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink. - *
<ul>
- * <li> `append`: only the new rows in the streaming DataFrame/Dataset will be written to
- * the sink.</li>
- * <li> `complete`: all the rows in the streaming DataFrame/Dataset will be written to the sink
- * every time there are some updates.</li>
- * <li> `update`: only the rows that were updated in the streaming DataFrame/Dataset will
- * be written to the sink every time there are some updates. If the query doesn't
- * contain aggregations, it will be equivalent to `append` mode.</li>
- * </ul>
    - * - * @since 2.0.0 - */ - def outputMode(outputMode: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def outputMode(outputMode: String): this.type = { this.outputMode = InternalOutputModes(outputMode) this } - /** - * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will run - * the query as fast as possible. - * - * Scala Example: - * {{{ - * df.writeStream.trigger(ProcessingTime("10 seconds")) - * - * import scala.concurrent.duration._ - * df.writeStream.trigger(ProcessingTime(10.seconds)) - * }}} - * - * Java Example: - * {{{ - * df.writeStream().trigger(ProcessingTime.create("10 seconds")) - * - * import java.util.concurrent.TimeUnit - * df.writeStream().trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) - * }}} - * - * @since 2.0.0 - */ - def trigger(trigger: Trigger): DataStreamWriter[T] = { + /** @inheritdoc */ + def trigger(trigger: Trigger): this.type = { this.trigger = trigger this } - /** - * Specifies the name of the [[StreamingQuery]] that can be started with `start()`. - * This name must be unique among all the currently active queries in the associated SQLContext. - * - * @since 2.0.0 - */ - def queryName(queryName: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def queryName(queryName: String): this.type = { this.extraOptions += ("queryName" -> queryName) this } - /** - * Specifies the underlying output data source. - * - * @since 2.0.0 - */ - def format(source: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def format(source: String): this.type = { this.source = source this } - /** - * Partitions the output by the given columns on the file system. If specified, the output is - * laid out on the file system similar to Hive's partitioning scheme. As an example, when we - * partition a dataset by year and then month, the directory layout would look like: - * - *
<ul>
- * <li> year=2016/month=01/</li>
- * <li> year=2016/month=02/</li>
- * </ul>
    - * - * Partitioning is one of the most widely used techniques to optimize physical data layout. - * It provides a coarse-grained index for skipping unnecessary data reads when queries have - * predicates on the partitioned columns. In order for partitioning to work well, the number - * of distinct values in each column should typically be less than tens of thousands. - * - * @since 2.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def partitionBy(colNames: String*): DataStreamWriter[T] = { + def partitionBy(colNames: String*): this.type = { this.partitioningColumns = Option(colNames) validatePartitioningAndClustering() this } - /** - * Clusters the output by the given columns. If specified, the output is laid out such that - * records with similar values on the clustering column are grouped together in the same file. - * - * Clustering improves query efficiency by allowing queries with predicates on the clustering - * columns to skip unnecessary data. Unlike partitioning, clustering can be used on very high - * cardinality columns. - * - * @since 4.0.0 - */ + /** @inheritdoc */ @scala.annotation.varargs - def clusterBy(colNames: String*): DataStreamWriter[T] = { + def clusterBy(colNames: String*): this.type = { this.clusteringColumns = Option(colNames) validatePartitioningAndClustering() this } - /** - * Adds an output option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: String): DataStreamWriter[T] = { + /** @inheritdoc */ + def option(key: String, value: String): this.type = { this.extraOptions += (key -> value) this } - /** - * Adds an output option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: Boolean): DataStreamWriter[T] = option(key, value.toString) - - /** - * Adds an output option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: Long): DataStreamWriter[T] = option(key, value.toString) - - /** - * Adds an output option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: Double): DataStreamWriter[T] = option(key, value.toString) - - /** - * (Scala-specific) Adds output options for the underlying data source. - * - * @since 2.0.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamWriter[T] = { + /** @inheritdoc */ + def options(options: scala.collection.Map[String, String]): this.type = { this.extraOptions ++= options this } - /** - * Adds output options for the underlying data source. - * - * @since 2.0.0 - */ - def options(options: java.util.Map[String, String]): DataStreamWriter[T] = { + /** @inheritdoc */ + def options(options: java.util.Map[String, String]): this.type = { this.options(options.asScala) this } - /** - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. The returned [[StreamingQuery]] object can be used to interact with - * the stream. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def start(path: String): StreamingQuery = { - if (!df.sparkSession.sessionState.conf.legacyPathOptionBehavior && + if (!ds.sparkSession.sessionState.conf.legacyPathOptionBehavior && extraOptions.contains("path")) { throw QueryCompilationErrors.setPathOptionAndCallWithPathParameterError("start") } startInternal(Some(path)) } - /** - * Starts the execution of the streaming query, which will continually output results to the given - * path as new data arrives. 
The returned [[StreamingQuery]] object can be used to interact with - * the stream. Throws a `TimeoutException` if the following conditions are met: - * - Another run of the same streaming query, that is a streaming query - * sharing the same checkpoint location, is already active on the same - * Spark Driver - * - The SQL configuration `spark.sql.streaming.stopActiveRunOnRestart` - * is enabled - * - The active run cannot be stopped within the timeout controlled by - * the SQL configuration `spark.sql.streaming.stopTimeout` - * - * @since 2.0.0 - */ + /** @inheritdoc */ @throws[TimeoutException] def start(): StreamingQuery = startInternal(None) - /** - * Starts the execution of the streaming query, which will continually output results to the given - * table as new data arrives. The returned [[StreamingQuery]] object can be used to interact with - * the stream. - * - * For v1 table, partitioning columns provided by `partitionBy` will be respected no matter the - * table exists or not. A new table will be created if the table not exists. - * - * For v2 table, `partitionBy` will be ignored if the table already exists. `partitionBy` will be - * respected only if the v2 table does not exist. Besides, the v2 table created by this API lacks - * some functionalities (e.g., customized properties, options, and serde info). If you need them, - * please create the v2 table manually before the execution to avoid creating a table with - * incomplete information. - * - * @since 3.1.0 - */ + /** @inheritdoc */ @Evolving @throws[TimeoutException] def toTable(tableName: String): StreamingQuery = { - this.tableName = tableName - import df.sparkSession.sessionState.analyzer.CatalogAndIdentifier + import ds.sparkSession.sessionState.analyzer.CatalogAndIdentifier import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ - val parser = df.sparkSession.sessionState.sqlParser + val parser = ds.sparkSession.sessionState.sqlParser val originalMultipartIdentifier = parser.parseMultipartIdentifier(tableName) val CatalogAndIdentifier(catalog, identifier) = originalMultipartIdentifier // Currently we don't create a logical streaming writer node in logical plan, so cannot rely // on analyzer to resolve it. Directly lookup only for temp view to provide clearer message. // TODO (SPARK-27484): we should add the writing node before the plan is analyzed. 
- if (df.sparkSession.sessionState.catalog.isTempView(originalMultipartIdentifier)) { + if (ds.sparkSession.sessionState.catalog.isTempView(originalMultipartIdentifier)) { throw QueryCompilationErrors.tempViewNotSupportStreamingWriteError(tableName) } @@ -327,14 +175,14 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { extraOptions.get("path"), None, None, - false) + external = false) val cmd = CreateTable( UnresolvedIdentifier(originalMultipartIdentifier), - df.schema.asNullable.map(ColumnDefinition.fromV1Column(_, parser)), + ds.schema.asNullable.map(ColumnDefinition.fromV1Column(_, parser)), partitioningOrClusteringTransform, tableSpec, ignoreIfExists = false) - Dataset.ofRows(df.sparkSession, cmd) + Dataset.ofRows(ds.sparkSession, cmd) } val tableInstance = catalog.asTableCatalog.loadTable(identifier) @@ -371,34 +219,34 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { throw QueryCompilationErrors.cannotOperateOnHiveDataSourceFilesError("write") } - if (source == SOURCE_NAME_MEMORY) { - assertNotPartitioned(SOURCE_NAME_MEMORY) + if (source == DataStreamWriter.SOURCE_NAME_MEMORY) { + assertNotPartitioned(DataStreamWriter.SOURCE_NAME_MEMORY) if (extraOptions.get("queryName").isEmpty) { throw QueryCompilationErrors.queryNameNotSpecifiedForMemorySinkError() } val sink = new MemorySink() - val resultDf = Dataset.ofRows(df.sparkSession, - MemoryPlan(sink, DataTypeUtils.toAttributes(df.schema))) + val resultDf = Dataset.ofRows(ds.sparkSession, + MemoryPlan(sink, DataTypeUtils.toAttributes(ds.schema))) val recoverFromCheckpoint = outputMode == OutputMode.Complete() val query = startQuery(sink, extraOptions, recoverFromCheckpoint = recoverFromCheckpoint, catalogTable = catalogTable) resultDf.createOrReplaceTempView(query.name) query - } else if (source == SOURCE_NAME_FOREACH) { - assertNotPartitioned(SOURCE_NAME_FOREACH) + } else if (source == DataStreamWriter.SOURCE_NAME_FOREACH) { + assertNotPartitioned(DataStreamWriter.SOURCE_NAME_FOREACH) val sink = ForeachWriterTable[Any](foreachWriter, foreachWriterEncoder) startQuery(sink, extraOptions, catalogTable = catalogTable) - } else if (source == SOURCE_NAME_FOREACH_BATCH) { - assertNotPartitioned(SOURCE_NAME_FOREACH_BATCH) + } else if (source == DataStreamWriter.SOURCE_NAME_FOREACH_BATCH) { + assertNotPartitioned(DataStreamWriter.SOURCE_NAME_FOREACH_BATCH) if (trigger.isInstanceOf[ContinuousTrigger]) { throw QueryCompilationErrors.sourceNotSupportedWithContinuousTriggerError(source) } val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc) startQuery(sink, extraOptions, catalogTable = catalogTable) } else { - val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + val cls = DataSource.lookupDataSource(source, ds.sparkSession.sessionState.conf) val disabledSources = - Utils.stringToSeq(df.sparkSession.sessionState.conf.disabledV2StreamingWriters) + Utils.stringToSeq(ds.sparkSession.sessionState.conf.disabledV2StreamingWriters) val useV1Source = disabledSources.contains(cls.getCanonicalName) || // file source v2 does not support streaming yet. 
classOf[FileDataSourceV2].isAssignableFrom(cls) @@ -412,7 +260,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) { val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - source = provider, conf = df.sparkSession.sessionState.conf) + source = provider, conf = ds.sparkSession.sessionState.conf) val finalOptions = sessionOptions.filter { case (k, _) => !optionsWithPath.contains(k) } ++ optionsWithPath.originalMap val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) @@ -420,7 +268,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { // to `getTable`. This is for avoiding schema inference, which can be very expensive. // If the query schema is not compatible with the existing data, the behavior is undefined. val outputSchema = if (provider.supportsExternalMetadata()) { - Some(df.schema) + Some(ds.schema) } else { None } @@ -450,12 +298,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { recoverFromCheckpoint: Boolean = true, catalogAndIdent: Option[(TableCatalog, Identifier)] = None, catalogTable: Option[CatalogTable] = None): StreamingQuery = { - val useTempCheckpointLocation = SOURCES_ALLOW_ONE_TIME_QUERY.contains(source) + val useTempCheckpointLocation = DataStreamWriter.SOURCES_ALLOW_ONE_TIME_QUERY.contains(source) - df.sparkSession.sessionState.streamingQueryManager.startQuery( + ds.sparkSession.sessionState.streamingQueryManager.startQuery( newOptions.get("queryName"), newOptions.get("checkpointLocation"), - df, + ds, newOptions.originalMap, sink, outputMode, @@ -480,26 +328,21 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { case None => optionsWithoutClusteringKey } val ds = DataSource( - df.sparkSession, + this.ds.sparkSession, className = source, options = optionsWithClusteringColumns, partitionColumns = normalizedParCols.getOrElse(Nil)) ds.createSink(outputMode) } - /** - * Sets the output of the streaming query to be processed using the provided writer object. - * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and - * semantics. - * @since 2.0.0 - */ - def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = { + /** @inheritdoc */ + def foreach(writer: ForeachWriter[T]): this.type = { foreachImplementation(writer.asInstanceOf[ForeachWriter[Any]]) } private[sql] def foreachImplementation(writer: ForeachWriter[Any], - encoder: Option[ExpressionEncoder[Any]] = None): DataStreamWriter[T] = { - this.source = SOURCE_NAME_FOREACH + encoder: Option[ExpressionEncoder[Any]] = None): this.type = { + this.source = DataStreamWriter.SOURCE_NAME_FOREACH this.foreachWriter = if (writer != null) { ds.sparkSession.sparkContext.clean(writer) } else { @@ -509,47 +352,15 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { this } - /** - * :: Experimental :: - * - * (Scala-specific) Sets the output of the streaming query to be processed using the provided - * function. This is supported only in the micro-batch execution modes (that is, when the - * trigger is not continuous). In every micro-batch, the provided function will be called in - * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. - * The batchId can be used to deduplicate and transactionally write the output - * (that is, the provided Dataset) to external systems. 
The output Dataset is guaranteed - * to be exactly the same for the same batchId (assuming all operations are deterministic - * in the query). - * - * @since 2.4.0 - */ + /** @inheritdoc */ @Evolving - def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = { - this.source = SOURCE_NAME_FOREACH_BATCH + def foreachBatch(function: (Dataset[T], Long) => Unit): this.type = { + this.source = DataStreamWriter.SOURCE_NAME_FOREACH_BATCH if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null") this.foreachBatchWriter = function this } - /** - * :: Experimental :: - * - * (Java-specific) Sets the output of the streaming query to be processed using the provided - * function. This is supported only in the micro-batch execution modes (that is, when the - * trigger is not continuous). In every micro-batch, the provided function will be called in - * every micro-batch with (i) the output rows as a Dataset and (ii) the batch identifier. - * The batchId can be used to deduplicate and transactionally write the output - * (that is, the provided Dataset) to external systems. The output Dataset is guaranteed - * to be exactly the same for the same batchId (assuming all operations are deterministic - * in the query). - * - * @since 2.4.0 - */ - @Evolving - def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): DataStreamWriter[T] = { - foreachBatch((batchDs: Dataset[T], batchId: Long) => function.call(batchDs, batchId)) - } - private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => cols.map(normalize(_, "Partition")) } @@ -564,8 +375,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { * need to care about case sensitivity afterwards. */ private def normalize(columnName: String, columnType: String): String = { - val validColumnNames = df.logicalPlan.output.map(_.name) - validColumnNames.find(df.sparkSession.sessionState.analyzer.resolver(_, columnName)) + val validColumnNames = ds.logicalPlan.output.map(_.name) + validColumnNames.find(ds.sparkSession.sessionState.analyzer.resolver(_, columnName)) .getOrElse(throw QueryCompilationErrors.columnNotFoundInExistingColumnsError( columnType, columnName, validColumnNames)) } @@ -584,12 +395,28 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } /////////////////////////////////////////////////////////////////////////////////////// - // Builder pattern config options + // Covariant Overrides /////////////////////////////////////////////////////////////////////////////////////// - private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + @Evolving + override def foreachBatch(function: VoidFunction2[Dataset[T], java.lang.Long]): this.type = + super.foreachBatch(function) + + /////////////////////////////////////////////////////////////////////////////////////// + // Builder pattern config options + /////////////////////////////////////////////////////////////////////////////////////// - private var tableName: String = null + private var source: String = ds.sparkSession.sessionState.conf.defaultDataSourceName private var outputMode: OutputMode = 
OutputMode.Append @@ -597,12 +424,12 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { private var extraOptions = CaseInsensitiveMap[String](Map.empty) - private var foreachWriter: ForeachWriter[Any] = null + private var foreachWriter: ForeachWriter[Any] = _ private var foreachWriterEncoder: ExpressionEncoder[Any] = ds.exprEnc.asInstanceOf[ExpressionEncoder[Any]] - private var foreachBatchWriter: (Dataset[T], Long) => Unit = null + private var foreachBatchWriter: (Dataset[T], Long) => Unit = _ private var partitioningColumns: Option[Seq[String]] = None @@ -610,14 +437,14 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } object DataStreamWriter { - val SOURCE_NAME_MEMORY = "memory" - val SOURCE_NAME_FOREACH = "foreach" - val SOURCE_NAME_FOREACH_BATCH = "foreachBatch" - val SOURCE_NAME_CONSOLE = "console" - val SOURCE_NAME_TABLE = "table" - val SOURCE_NAME_NOOP = "noop" + val SOURCE_NAME_MEMORY: String = "memory" + val SOURCE_NAME_FOREACH: String = "foreach" + val SOURCE_NAME_FOREACH_BATCH: String = "foreachBatch" + val SOURCE_NAME_CONSOLE: String = "console" + val SOURCE_NAME_TABLE: String = "table" + val SOURCE_NAME_NOOP: String = "noop" // these writer sources are also used for one-time query, hence allow temp checkpoint location - val SOURCES_ALLOW_ONE_TIME_QUERY = Seq(SOURCE_NAME_MEMORY, SOURCE_NAME_FOREACH, + val SOURCES_ALLOW_ONE_TIME_QUERY: Seq[String] = Seq(SOURCE_NAME_MEMORY, SOURCE_NAME_FOREACH, SOURCE_NAME_FOREACH_BATCH, SOURCE_NAME_CONSOLE, SOURCE_NAME_NOOP) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 3ab6d02f6b515..9d6fd2e28dea4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -27,7 +27,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Evolving import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{CLASS_NAME, QUERY_ID, RUN_ID} -import org.apache.spark.sql.{DataFrame, SparkSession} +import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.streaming.{WriteToStream, WriteToStreamStatement} import org.apache.spark.sql.connector.catalog.{Identifier, SupportsWrite, Table, TableCatalog} @@ -241,7 +241,7 @@ class StreamingQueryManager private[sql] ( private def createQuery( userSpecifiedName: Option[String], userSpecifiedCheckpointLocation: Option[String], - df: DataFrame, + df: Dataset[_], extraOptions: Map[String, String], sink: Table, outputMode: OutputMode, @@ -322,7 +322,7 @@ class StreamingQueryManager private[sql] ( private[sql] def startQuery( userSpecifiedName: Option[String], userSpecifiedCheckpointLocation: Option[String], - df: DataFrame, + df: Dataset[_], extraOptions: Map[String, String], sink: Table, outputMode: OutputMode, From 35e5d290deee9cf2a913571407e2257217e0e9e2 Mon Sep 17 00:00:00 2001 From: Chris Nauroth Date: Mon, 23 Sep 2024 21:35:32 -0700 Subject: [PATCH 060/250] [SPARK-49760][YARN] Correct handling of `SPARK_USER` env variable override in app master ### What changes were proposed in this pull request? This patch corrects handling of a user-supplied `SPARK_USER` environment variable in the YARN app master. Currently, the user-supplied value gets appended to the default, like a classpath entry. 
The patch fixes it by using only the user-supplied value. ### Why are the changes needed? Overriding the `SPARK_USER` environment variable in the YARN app master with configuration property `spark.yarn.appMasterEnv.SPARK_USER` currently results in an incorrect value. `Client#setupLaunchEnv` first sets a default in the environment map using the Hadoop user. After that, `YarnSparkHadoopUtil.addPathToEnvironment` sees the existing value in the map and interprets the user-supplied value as needing to be appended like a classpath entry. The end result is the Hadoop user appended with the classpath delimiter and user-supplied value, e.g. `cnauroth:overrideuser`. ### Does this PR introduce _any_ user-facing change? Yes, the app master now uses the user-supplied `SPARK_USER` if specified. (The default is still the Hadoop user.) ### How was this patch tested? * Existing unit tests pass. * Added new unit tests covering default and overridden `SPARK_USER` for the app master. The override test fails without this patch, and then passes after the patch is applied. * Manually tested in a live YARN cluster as shown below. Manual testing used the `DFSReadWriteTest` job with overrides of `SPARK_USER`: ``` spark-submit \ --deploy-mode cluster \ --files all-lines.txt \ --class org.apache.spark.examples.DFSReadWriteTest \ --conf spark.yarn.appMasterEnv.SPARK_USER=sparkuser_appMaster \ --conf spark.driverEnv.SPARK_USER=sparkuser_driver \ --conf spark.executorEnv.SPARK_USER=sparkuser_executor \ /usr/lib/spark/examples/jars/spark-examples.jar \ all-lines.txt /tmp/DFSReadWriteTest ``` Before the patch, we can see the app master's `SPARK_USER` mishandled by looking at the `_SUCCESS` file in HDFS: ``` hdfs dfs -ls -R /tmp/DFSReadWriteTest drwxr-xr-x - cnauroth:sparkuser_appMaster hadoop 0 2024-09-20 23:35 /tmp/DFSReadWriteTest/dfs_read_write_test -rw-r--r-- 1 cnauroth:sparkuser_appMaster hadoop 0 2024-09-20 23:35 /tmp/DFSReadWriteTest/dfs_read_write_test/_SUCCESS -rw-r--r-- 1 sparkuser_executor hadoop 2295080 2024-09-20 23:35 /tmp/DFSReadWriteTest/dfs_read_write_test/part-00000 -rw-r--r-- 1 sparkuser_executor hadoop 2288718 2024-09-20 23:35 /tmp/DFSReadWriteTest/dfs_read_write_test/part-00001 ``` After the patch, we can see it working correctly: ``` hdfs dfs -ls -R /tmp/DFSReadWriteTest drwxr-xr-x - sparkuser_appMaster hadoop 0 2024-09-23 17:13 /tmp/DFSReadWriteTest/dfs_read_write_test -rw-r--r-- 1 sparkuser_appMaster hadoop 0 2024-09-23 17:13 /tmp/DFSReadWriteTest/dfs_read_write_test/_SUCCESS -rw-r--r-- 1 sparkuser_executor hadoop 2295080 2024-09-23 17:13 /tmp/DFSReadWriteTest/dfs_read_write_test/part-00000 -rw-r--r-- 1 sparkuser_executor hadoop 2288718 2024-09-23 17:13 /tmp/DFSReadWriteTest/dfs_read_write_test/part-00001 ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48214 from cnauroth/SPARK-49760. 
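For completeness, the same override can also be applied programmatically from PySpark instead of on the `spark-submit` command line. The sketch below is illustrative only: the two property names are the ones discussed above, while the application name and override values are arbitrary, and the final print assumes YARN cluster mode (where the driver runs inside the app master container).

```
import os

from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("spark-user-override-sketch")  # arbitrary name
    # Override SPARK_USER in the YARN application master environment.
    .config("spark.yarn.appMasterEnv.SPARK_USER", "overrideuser")
    # Executors keep a separate override, mirroring the manual test above.
    .config("spark.executorEnv.SPARK_USER", "sparkuser_executor")
    .getOrCreate()
)

# In cluster mode the driver shares the app master container, so with the fix this
# should print exactly "overrideuser" rather than "<hadoop-user>:overrideuser".
print(os.environ.get("SPARK_USER"))
```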
Authored-by: Chris Nauroth Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/deploy/yarn/Client.scala | 7 +++++-- .../apache/spark/deploy/yarn/ClientSuite.scala | 16 ++++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index b2c4d97bc7b07..8b621e82afe28 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -960,14 +960,13 @@ private[spark] class Client( /** * Set up the environment for launching our ApplicationMaster container. */ - private def setupLaunchEnv( + private[yarn] def setupLaunchEnv( stagingDirPath: Path, pySparkArchives: Seq[String]): HashMap[String, String] = { logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() populateClasspath(args, hadoopConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH)) env("SPARK_YARN_STAGING_DIR") = stagingDirPath.toString - env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() env("SPARK_PREFER_IPV6") = Utils.preferIPv6.toString // Pick up any environment variables for the AM provided through spark.yarn.appMasterEnv.* @@ -977,6 +976,10 @@ private[spark] class Client( .map { case (k, v) => (k.substring(amEnvPrefix.length), v) } .foreach { case (k, v) => YarnSparkHadoopUtil.addPathToEnvironment(env, k, v) } + if (!env.contains("SPARK_USER")) { + env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() + } + // If pyFiles contains any .py files, we need to add LOCALIZED_PYTHON_DIR to the PYTHONPATH // of the container processes too. Add all non-.py files directly to PYTHONPATH. 
// diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 78e84690900e1..93d6cc474d20f 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -29,6 +29,7 @@ import scala.jdk.CollectionConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.hadoop.mapreduce.MRJobConfig +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.protocolrecords.{GetNewApplicationResponse, SubmitApplicationRequest} import org.apache.hadoop.yarn.api.records._ @@ -739,6 +740,21 @@ class ClientSuite extends SparkFunSuite } } + test("SPARK-49760: default app master SPARK_USER") { + val sparkConf = new SparkConf() + val client = createClient(sparkConf) + val env = client.setupLaunchEnv(new Path("/staging/dir/path"), Seq()) + env("SPARK_USER") should be (UserGroupInformation.getCurrentUser().getShortUserName()) + } + + test("SPARK-49760: override app master SPARK_USER") { + val sparkConf = new SparkConf() + .set("spark.yarn.appMasterEnv.SPARK_USER", "overrideuser") + val client = createClient(sparkConf) + val env = client.setupLaunchEnv(new Path("/staging/dir/path"), Seq()) + env("SPARK_USER") should be ("overrideuser") + } + private val matching = Seq( ("files URI match test1", "file:///file1", "file:///file2"), ("files URI match test2", "file:///c:file1", "file://c:file2"), From 64ea50e87c70aea6b22a66ec1a0c98ae29a5dd81 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 24 Sep 2024 13:36:22 +0900 Subject: [PATCH 061/250] [SPARK-49607][PYTHON] Update the sampling approach for sampled based plots ### What changes were proposed in this pull request? 1, Update the sampling approach for sampled based plots 2, Eliminate "spark.sql.pyspark.plotting.sample_ratio" config ### Why are the changes needed? 1, to be consistent with the PS plotting; 2, the "spark.sql.pyspark.plotting.sample_ratio" config is not friendly to large scale data: the plotting backend cannot render large number of data points efficiently, and it is hard for users to set an appropriate sample ratio; ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #48218 from zhengruifeng/py_plot_sampling. 
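Independent of the plotting code itself, the new sampling approach can be sketched directly against the DataFrame API. The snippet below is a simplified illustration of the idea described above (it omits the id-based reshuffle used to keep a stable row order); the helper column name, the input data, and the 1,000-row cap are stand-ins, not the real internals.

```
from pyspark.sql import Observation, SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
sdf = spark.range(1_000_000).withColumn("value", F.rand())

max_rows = 1000  # stand-in for spark.sql.pyspark.plotting.max_rows
observation = Observation("sampling sketch")

sampled = (
    sdf.observe(observation, F.count(F.lit(1)).alias("count"))
    .select("*", F.rand().alias("__rand__"))
    .sort("__rand__")        # random order over the whole dataset
    .limit(max_rows + 1)     # keep at most max_rows (+1 to detect truncation)
    .drop("__rand__")
)

pdf = sampled.toPandas()
if len(pdf) > max_rows:
    # The observed full count gives the effective sampling fraction.
    fraction = max_rows / observation.get["count"]
    pdf = pdf[:max_rows]
else:
    fraction = 1.0
print(len(pdf), fraction)
```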
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/plot/core.py | 36 ++++++++++++++----- .../pyspark/sql/tests/plot/test_frame_plot.py | 14 +------- .../apache/spark/sql/internal/SQLConf.scala | 16 --------- 3 files changed, 28 insertions(+), 38 deletions(-) diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index ed22d02370ca6..eb00b8a04f977 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -50,27 +50,45 @@ def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame": class PySparkSampledPlotBase: def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame": - from pyspark.sql import SparkSession + from pyspark.sql import SparkSession, Observation, functions as F session = SparkSession.getActiveSession() if session is None: raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict()) - sample_ratio = session.conf.get("spark.sql.pyspark.plotting.sample_ratio") max_rows = int( session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type] ) - if sample_ratio is None: - fraction = 1 / (sdf.count() / max_rows) - fraction = min(1.0, fraction) - else: - fraction = float(sample_ratio) + observation = Observation("pyspark plotting") - sampled_sdf = sdf.sample(fraction=fraction) + rand_col_name = "__pyspark_plotting_sampled_plot_base_rand__" + id_col_name = "__pyspark_plotting_sampled_plot_base_id__" + + sampled_sdf = ( + sdf.observe(observation, F.count(F.lit(1)).alias("count")) + .select( + "*", + F.rand().alias(rand_col_name), + F.monotonically_increasing_id().alias(id_col_name), + ) + .sort(rand_col_name) + .limit(max_rows + 1) + .coalesce(1) + .sortWithinPartitions(id_col_name) + .drop(rand_col_name, id_col_name) + ) pdf = sampled_sdf.toPandas() - return pdf + if len(pdf) > max_rows: + try: + self.fraction = float(max_rows) / observation.get["count"] + except Exception: + pass + return pdf[:max_rows] + else: + self.fraction = 1.0 + return pdf class PySparkPlotAccessor: diff --git a/python/pyspark/sql/tests/plot/test_frame_plot.py b/python/pyspark/sql/tests/plot/test_frame_plot.py index f753b5ab3db72..2a6971e896292 100644 --- a/python/pyspark/sql/tests/plot/test_frame_plot.py +++ b/python/pyspark/sql/tests/plot/test_frame_plot.py @@ -39,23 +39,11 @@ def test_backend(self): ) def test_topn_max_rows(self): - try: + with self.sql_conf({"spark.sql.pyspark.plotting.max_rows": "1000"}): self.spark.conf.set("spark.sql.pyspark.plotting.max_rows", "1000") sdf = self.spark.range(2500) pdf = PySparkTopNPlotBase().get_top_n(sdf) self.assertEqual(len(pdf), 1000) - finally: - self.spark.conf.unset("spark.sql.pyspark.plotting.max_rows") - - def test_sampled_plot_with_ratio(self): - try: - self.spark.conf.set("spark.sql.pyspark.plotting.sample_ratio", "0.5") - data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2500)] - sdf = self.spark.createDataFrame(data) - pdf = PySparkSampledPlotBase().get_sampled(sdf) - self.assertEqual(round(len(pdf) / 2500, 1), 0.5) - finally: - self.spark.conf.unset("spark.sql.pyspark.plotting.sample_ratio") def test_sampled_plot_with_max_rows(self): data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2000)] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4d0930212b373..9d51afd064d10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3178,20 +3178,6 @@ object SQLConf { .intConf .createWithDefault(1000) - val PYSPARK_PLOT_SAMPLE_RATIO = - buildConf("spark.sql.pyspark.plotting.sample_ratio") - .doc( - "The proportion of data that will be plotted for sample-based plots. It is determined " + - "based on spark.sql.pyspark.plotting.max_rows if not explicitly set." - ) - .version("4.0.0") - .doubleConf - .checkValue( - ratio => ratio >= 0.0 && ratio <= 1.0, - "The value should be between 0.0 and 1.0 inclusive." - ) - .createOptional - val ARROW_SPARKR_EXECUTION_ENABLED = buildConf("spark.sql.execution.arrow.sparkr.enabled") .doc("When true, make use of Apache Arrow for columnar data transfers in SparkR. " + @@ -5907,8 +5893,6 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def pysparkPlotMaxRows: Int = getConf(PYSPARK_PLOT_MAX_ROWS) - def pysparkPlotSampleRatio: Option[Double] = getConf(PYSPARK_PLOT_SAMPLE_RATIO) - def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED) def arrowPySparkFallbackEnabled: Boolean = getConf(ARROW_PYSPARK_FALLBACK_ENABLED) From 438a6e7782ece23492928cfbb2d01e14104dfd9a Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Mon, 23 Sep 2024 21:39:27 -0700 Subject: [PATCH 062/250] [SPARK-49753][BUILD] Upgrade ZSTD-JNI to 1.5.6-6 ### What changes were proposed in this pull request? The pr aims to upgrade `zstd-jni` from `1.5.6-5` to `1.5.6-6`. ### Why are the changes needed? The new version allow including compression level when training a dictionary: https://github.com/luben/zstd-jni/commit/3ca26eed6c84fb09c382854ead527188e643e206#diff-bd5c0f62db7cb85cac88c7b6cfad1c0e5e2f433ba45097761654829627b7a31c All changes in the new version are as follows: - https://github.com/luben/zstd-jni/compare/v1.5.6-5...v1.5.6-6 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48204 from LuciferYang/zstd-jni-1.5.6-6. 
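On the Spark side, zstd-jni backs the built-in `zstd` I/O codec that these benchmarks measure. The following sketch simply opts a PySpark session into that codec; the configuration keys are standard Spark core settings, while the chosen level and app name are arbitrary examples rather than recommendations.

```
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("zstd-codec-sketch")  # arbitrary name
    # Compress internal data (shuffle outputs, broadcasts, spills) with zstd.
    .config("spark.io.compression.codec", "zstd")
    # Higher levels trade CPU for ratio; compare the level 1/2/3 rows above.
    .config("spark.io.compression.zstd.level", "3")
    # Matches the "with buffer pool" variants in the benchmark.
    .config("spark.io.compression.zstd.bufferPool.enabled", "true")
    .getOrCreate()
)

spark.range(1000).repartition(8).count()  # any shuffle exercises the codec
```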
Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: Dongjoon Hyun --- .../ZStandardBenchmark-jdk21-results.txt | 56 +++++++++---------- .../benchmarks/ZStandardBenchmark-results.txt | 56 +++++++++---------- dev/deps/spark-deps-hadoop-3-hive-2.3 | 2 +- pom.xml | 2 +- 4 files changed, 58 insertions(+), 58 deletions(-) diff --git a/core/benchmarks/ZStandardBenchmark-jdk21-results.txt b/core/benchmarks/ZStandardBenchmark-jdk21-results.txt index b3bffea826e5f..f6bd681451d5e 100644 --- a/core/benchmarks/ZStandardBenchmark-jdk21-results.txt +++ b/core/benchmarks/ZStandardBenchmark-jdk21-results.txt @@ -2,48 +2,48 @@ Benchmark ZStandardCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 657 670 14 0.0 65699.2 1.0X -Compression 10000 times at level 2 without buffer pool 697 697 1 0.0 69673.4 0.9X -Compression 10000 times at level 3 without buffer pool 799 802 3 0.0 79855.2 0.8X -Compression 10000 times at level 1 with buffer pool 593 595 1 0.0 59326.9 1.1X -Compression 10000 times at level 2 with buffer pool 622 624 3 0.0 62194.1 1.1X -Compression 10000 times at level 3 with buffer pool 732 733 1 0.0 73178.6 0.9X +Compression 10000 times at level 1 without buffer pool 659 676 16 0.0 65860.7 1.0X +Compression 10000 times at level 2 without buffer pool 721 723 2 0.0 72135.5 0.9X +Compression 10000 times at level 3 without buffer pool 815 816 1 0.0 81500.6 0.8X +Compression 10000 times at level 1 with buffer pool 608 609 0 0.0 60846.6 1.1X +Compression 10000 times at level 2 with buffer pool 645 647 3 0.0 64476.3 1.0X +Compression 10000 times at level 3 with buffer pool 746 746 1 0.0 74584.0 0.9X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 813 820 11 0.0 81273.2 1.0X -Decompression 10000 times from level 2 without buffer pool 810 813 3 0.0 80986.2 1.0X -Decompression 10000 times from level 3 without buffer pool 812 813 2 0.0 81183.1 1.0X -Decompression 10000 times from level 1 with buffer pool 746 747 2 0.0 74568.7 1.1X -Decompression 10000 times from level 2 with buffer pool 744 746 2 0.0 74414.5 1.1X -Decompression 10000 times from level 3 with buffer pool 745 746 1 0.0 74538.6 1.1X +Decompression 10000 times from level 1 without buffer pool 828 829 1 0.0 82822.6 1.0X +Decompression 10000 times from level 2 without buffer pool 829 829 1 0.0 82900.7 1.0X +Decompression 10000 times from level 3 without buffer pool 828 833 8 0.0 82784.4 1.0X +Decompression 10000 times from level 1 with buffer pool 758 760 2 0.0 75756.5 1.1X +Decompression 10000 times from level 2 with buffer pool 758 758 1 0.0 75772.3 1.1X +Decompression 10000 times from level 3 with buffer pool 
759 759 0 0.0 75852.7 1.1X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 3: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 48 49 1 0.0 374256.1 1.0X -Parallel Compression with 1 workers 34 36 3 0.0 267557.3 1.4X -Parallel Compression with 2 workers 34 38 2 0.0 263684.3 1.4X -Parallel Compression with 4 workers 37 39 2 0.0 289956.1 1.3X -Parallel Compression with 8 workers 39 41 1 0.0 306975.2 1.2X -Parallel Compression with 16 workers 44 45 1 0.0 340992.0 1.1X +Parallel Compression with 0 workers 58 59 1 0.0 452489.9 1.0X +Parallel Compression with 1 workers 42 45 4 0.0 330066.0 1.4X +Parallel Compression with 2 workers 40 42 1 0.0 312560.3 1.4X +Parallel Compression with 4 workers 40 42 2 0.0 308802.7 1.5X +Parallel Compression with 8 workers 41 45 3 0.0 321331.3 1.4X +Parallel Compression with 16 workers 44 45 1 0.0 343311.5 1.3X -OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 21.0.4+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 9: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 156 158 1 0.0 1220760.5 1.0X -Parallel Compression with 1 workers 191 192 2 0.0 1495168.2 0.8X -Parallel Compression with 2 workers 111 117 5 0.0 864459.9 1.4X -Parallel Compression with 4 workers 106 109 2 0.0 831025.5 1.5X -Parallel Compression with 8 workers 112 115 2 0.0 875732.7 1.4X -Parallel Compression with 16 workers 110 114 2 0.0 858160.9 1.4X +Parallel Compression with 0 workers 158 160 2 0.0 1234257.6 1.0X +Parallel Compression with 1 workers 193 194 1 0.0 1507686.4 0.8X +Parallel Compression with 2 workers 113 127 11 0.0 881068.0 1.4X +Parallel Compression with 4 workers 109 111 2 0.0 849241.3 1.5X +Parallel Compression with 8 workers 111 115 3 0.0 869455.2 1.4X +Parallel Compression with 16 workers 113 116 2 0.0 881832.5 1.4X diff --git a/core/benchmarks/ZStandardBenchmark-results.txt b/core/benchmarks/ZStandardBenchmark-results.txt index b230f825fecac..136f0333590cc 100644 --- a/core/benchmarks/ZStandardBenchmark-results.txt +++ b/core/benchmarks/ZStandardBenchmark-results.txt @@ -2,48 +2,48 @@ Benchmark ZStandardCompressionCodec ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------- -Compression 10000 times at level 1 without buffer pool 638 638 0 0.0 63765.0 1.0X -Compression 10000 times at level 2 without buffer pool 675 676 1 0.0 67529.4 0.9X -Compression 10000 times at level 3 without buffer pool 775 783 11 0.0 77531.6 0.8X -Compression 10000 times at level 1 with buffer pool 572 573 1 0.0 57223.2 1.1X -Compression 10000 times at level 2 with buffer pool 603 605 1 0.0 
60323.7 1.1X -Compression 10000 times at level 3 with buffer pool 720 727 6 0.0 71980.9 0.9X +Compression 10000 times at level 1 without buffer pool 257 259 2 0.0 25704.2 1.0X +Compression 10000 times at level 2 without buffer pool 674 676 2 0.0 67396.3 0.4X +Compression 10000 times at level 3 without buffer pool 775 787 11 0.0 77497.9 0.3X +Compression 10000 times at level 1 with buffer pool 573 574 0 0.0 57347.3 0.4X +Compression 10000 times at level 2 with buffer pool 602 603 2 0.0 60162.8 0.4X +Compression 10000 times at level 3 with buffer pool 722 725 3 0.0 72247.3 0.4X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Benchmark ZStandardCompressionCodec: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------------------------ -Decompression 10000 times from level 1 without buffer pool 584 585 1 0.0 58381.0 1.0X -Decompression 10000 times from level 2 without buffer pool 585 585 0 0.0 58465.9 1.0X -Decompression 10000 times from level 3 without buffer pool 585 586 1 0.0 58499.5 1.0X -Decompression 10000 times from level 1 with buffer pool 534 534 0 0.0 53375.7 1.1X -Decompression 10000 times from level 2 with buffer pool 533 533 0 0.0 53312.3 1.1X -Decompression 10000 times from level 3 with buffer pool 533 533 1 0.0 53255.1 1.1X +Decompression 10000 times from level 1 without buffer pool 176 177 1 0.1 17641.2 1.0X +Decompression 10000 times from level 2 without buffer pool 176 178 1 0.1 17628.9 1.0X +Decompression 10000 times from level 3 without buffer pool 175 176 0 0.1 17506.1 1.0X +Decompression 10000 times from level 1 with buffer pool 151 152 1 0.1 15051.5 1.2X +Decompression 10000 times from level 2 with buffer pool 150 151 1 0.1 14998.0 1.2X +Decompression 10000 times from level 3 with buffer pool 150 151 0 0.1 15019.4 1.2X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 3: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Parallel Compression with 0 workers 46 48 1 0.0 360483.5 1.0X -Parallel Compression with 1 workers 34 36 2 0.0 265816.1 1.4X -Parallel Compression with 2 workers 33 36 2 0.0 254525.8 1.4X -Parallel Compression with 4 workers 34 37 1 0.0 266270.8 1.4X -Parallel Compression with 8 workers 37 39 1 0.0 289289.2 1.2X -Parallel Compression with 16 workers 41 43 1 0.0 320243.3 1.1X +Parallel Compression with 0 workers 57 57 0 0.0 444425.2 1.0X +Parallel Compression with 1 workers 42 44 3 0.0 325107.6 1.4X +Parallel Compression with 2 workers 38 39 2 0.0 294840.0 1.5X +Parallel Compression with 4 workers 36 37 1 0.0 282143.1 1.6X +Parallel Compression with 8 workers 39 40 1 0.0 303793.6 1.5X +Parallel Compression with 16 workers 41 43 1 0.0 324165.5 1.4X -OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.5.0-1025-azure +OpenJDK 64-Bit Server VM 17.0.12+7-LTS on Linux 6.8.0-1014-azure AMD EPYC 7763 64-Core Processor Parallel Compression at level 9: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ 
-Parallel Compression with 0 workers 154 156 2 0.0 1205934.0 1.0X -Parallel Compression with 1 workers 191 194 4 0.0 1495729.9 0.8X -Parallel Compression with 2 workers 110 114 5 0.0 859158.9 1.4X -Parallel Compression with 4 workers 105 108 3 0.0 822932.2 1.5X -Parallel Compression with 8 workers 109 113 2 0.0 851560.0 1.4X -Parallel Compression with 16 workers 111 115 2 0.0 870695.9 1.4X +Parallel Compression with 0 workers 156 158 1 0.0 1220298.8 1.0X +Parallel Compression with 1 workers 188 189 1 0.0 1467911.4 0.8X +Parallel Compression with 2 workers 111 118 7 0.0 866985.2 1.4X +Parallel Compression with 4 workers 106 109 2 0.0 827592.1 1.5X +Parallel Compression with 8 workers 114 116 2 0.0 888419.5 1.4X +Parallel Compression with 16 workers 111 115 2 0.0 868463.5 1.4X diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 419625f48fa11..88526995293f5 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -280,4 +280,4 @@ xz/1.10//xz-1.10.jar zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar zookeeper-jute/3.9.2//zookeeper-jute-3.9.2.jar zookeeper/3.9.2//zookeeper-3.9.2.jar -zstd-jni/1.5.6-5//zstd-jni-1.5.6-5.jar +zstd-jni/1.5.6-6//zstd-jni-1.5.6-6.jar diff --git a/pom.xml b/pom.xml index b7c87beec0f92..131e754da8157 100644 --- a/pom.xml +++ b/pom.xml @@ -835,7 +835,7 @@ com.github.luben zstd-jni - 1.5.6-5 + 1.5.6-6 com.clearspring.analytics From 6bdd151d57759d73870f20780fc54ab2aa250409 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Tue, 24 Sep 2024 15:40:38 +0800 Subject: [PATCH 063/250] [SPARK-49694][PYTHON][CONNECT] Support scatter plots ### What changes were proposed in this pull request? Support scatter plots with plotly backend on both Spark Connect and Spark classic. ### Why are the changes needed? While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments. See more at [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) in progress. Part of https://issues.apache.org/jira/browse/SPARK-49530. ### Does this PR introduce _any_ user-facing change? Yes. Scatter plots are supported as shown below. ```py >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)] >>> columns = ["length", "width", "species"] >>> sdf = spark.createDataFrame(data, columns) >>> fig = sdf.plot(kind="scatter", x="length", y="width") # or fig = sdf.plot.scatter(x="length", y="width") >>> fig.show() ``` ![newplot (6)](https://github.com/user-attachments/assets/deef452b-74d1-4f6d-b1ae-60722f3c2b17) ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48219 from xinrong-meng/plot_scatter. 
Authored-by: Xinrong Meng Signed-off-by: Xinrong Meng --- python/pyspark/sql/plot/core.py | 34 +++++++++++++++++++ .../sql/tests/plot/test_frame_plot_plotly.py | 19 +++++++++++ 2 files changed, 53 insertions(+) diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index eb00b8a04f977..0a3a0101e1898 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -96,6 +96,7 @@ class PySparkPlotAccessor: "bar": PySparkTopNPlotBase().get_top_n, "barh": PySparkTopNPlotBase().get_top_n, "line": PySparkSampledPlotBase().get_sampled, + "scatter": PySparkSampledPlotBase().get_sampled, } _backends = {} # type: ignore[var-annotated] @@ -230,3 +231,36 @@ def barh(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure": ... ) # doctest: +SKIP """ return self(kind="barh", x=x, y=y, **kwargs) + + def scatter(self, x: str, y: str, **kwargs: Any) -> "Figure": + """ + Create a scatter plot with varying marker point size and color. + + The coordinates of each point are defined by two dataframe columns and + filled circles are used to represent each point. This kind of plot is + useful to see complex correlations between two variables. Points could + be for instance natural 2D coordinates like longitude and latitude in + a map or, in general, any pair of metrics that can be plotted against + each other. + + Parameters + ---------- + x : str + Name of column to use as horizontal coordinates for each point. + y : str or list of str + Name of column to use as vertical coordinates for each point. + **kwargs: Optional + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)] + >>> columns = ['length', 'width', 'species'] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.scatter(x='length', y='width') # doctest: +SKIP + """ + return self(kind="scatter", x=x, y=y, **kwargs) diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py index 1c52c93a23d3a..ccfe1a75424e0 100644 --- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -28,6 +28,12 @@ def sdf(self): columns = ["category", "int_val", "float_val"] return self.spark.createDataFrame(data, columns) + @property + def sdf2(self): + data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)] + columns = ["length", "width", "species"] + return self.spark.createDataFrame(data, columns) + def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=""): if kind == "line": self.assertEqual(fig_data["mode"], "lines") @@ -37,6 +43,9 @@ def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name= elif kind == "barh": self.assertEqual(fig_data["type"], "bar") self.assertEqual(fig_data["orientation"], "h") + elif kind == "scatter": + self.assertEqual(fig_data["type"], "scatter") + self.assertEqual(fig_data["orientation"], "v") self.assertEqual(fig_data["xaxis"], "x") self.assertEqual(list(fig_data["x"]), expected_x) @@ -79,6 +88,16 @@ def test_barh_plot(self): self._check_fig_data("barh", fig["data"][0], [10, 30, 20], ["A", "B", "C"], "int_val") self._check_fig_data("barh", fig["data"][1], [1.5, 2.5, 3.5], ["A", "B", "C"], "float_val") + def test_scatter_plot(self): + fig = self.sdf2.plot(kind="scatter", x="length", y="width") + self._check_fig_data( + "scatter", 
fig["data"][0], [5.1, 4.9, 7.0, 6.4, 5.9], [3.5, 3.0, 3.2, 3.2, 3.0] + ) + fig = self.sdf2.plot.scatter(x="width", y="length") + self._check_fig_data( + "scatter", fig["data"][0], [3.5, 3.0, 3.2, 3.2, 3.0], [5.1, 4.9, 7.0, 6.4, 5.9] + ) + class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase): pass From 982028ea7fc61d7aa84756aa46860ebb49bfe9d1 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Tue, 24 Sep 2024 17:21:59 +0900 Subject: [PATCH 064/250] [SPARK-49609][PYTHON][CONNECT] Add API compatibility check between Classic and Connect ### What changes were proposed in this pull request? This PR proposes to add API compatibility check between Classic and Connect. This PR also includes updating both APIs to the same signature. ### Why are the changes needed? APIs supported on both Spark Connect and Spark Classic should guarantee the same signature, such as argument and return types. For example, test would fail when the signature of API is mismatched: ``` Signature mismatch in Column method 'dropFields' Classic: (self, *fieldNames: str) -> pyspark.sql.column.Column Connect: (self, *fieldNames: 'ColumnOrName') -> pyspark.sql.column.Column pyspark.sql.column.Column> != pyspark.sql.column.Column> Expected : pyspark.sql.column.Column> Actual : pyspark.sql.column.Column> ``` ### Does this PR introduce _any_ user-facing change? No, it is a test to prevent future API behavior inconsistencies between Classic and Connect. ### How was this patch tested? Added UTs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48085 from itholic/SPARK-49609. Authored-by: Haejoon Lee Signed-off-by: Haejoon Lee --- dev/sparktestsupport/modules.py | 1 + python/pyspark/sql/classic/dataframe.py | 6 +- python/pyspark/sql/connect/dataframe.py | 26 ++- python/pyspark/sql/session.py | 3 +- .../sql/tests/test_connect_compatibility.py | 188 ++++++++++++++++++ 5 files changed, 209 insertions(+), 15 deletions(-) create mode 100644 python/pyspark/sql/tests/test_connect_compatibility.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index b9a4bed715f67..eda6b063350e5 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -550,6 +550,7 @@ def __hash__(self): "pyspark.sql.tests.test_resources", "pyspark.sql.tests.plot.test_frame_plot", "pyspark.sql.tests.plot.test_frame_plot_plotly", + "pyspark.sql.tests.test_connect_compatibility", ], ) diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index a2778cbc32c4c..23484fcf0051f 100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -1068,7 +1068,7 @@ def selectExpr(self, *expr: Union[str, List[str]]) -> ParentDataFrame: jdf = self._jdf.selectExpr(self._jseq(expr)) return DataFrame(jdf, self.sparkSession) - def filter(self, condition: "ColumnOrName") -> ParentDataFrame: + def filter(self, condition: Union[Column, str]) -> ParentDataFrame: if isinstance(condition, str): jdf = self._jdf.filter(condition) elif isinstance(condition, Column): @@ -1809,10 +1809,10 @@ def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ign def drop_duplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame: return self.dropDuplicates(subset) - def writeTo(self, table: str) -> DataFrameWriterV2: + def writeTo(self, table: str) -> "DataFrameWriterV2": return DataFrameWriterV2(self, table) - def mergeInto(self, table: str, condition: Column) -> MergeIntoWriter: + def 
mergeInto(self, table: str, condition: Column) -> "MergeIntoWriter": return MergeIntoWriter(self, table, condition) def pandas_api( diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 59d79decf6690..cb37af8868aad 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -535,7 +535,7 @@ def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": def groupby(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData": ... - def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> GroupedData: + def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] @@ -570,7 +570,7 @@ def rollup(self, *cols: "ColumnOrName") -> "GroupedData": def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData": ... - def rollup(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc] + def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ignore[misc] _cols: List[Column] = [] for c in cols: if isinstance(c, Column): @@ -731,8 +731,8 @@ def _convert_col(df: ParentDataFrame, col: "ColumnOrName") -> Column: session=self._session, ) - def limit(self, n: int) -> ParentDataFrame: - res = DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session) + def limit(self, num: int) -> ParentDataFrame: + res = DataFrame(plan.Limit(child=self._plan, limit=num), session=self._session) res._cached_schema = self._cached_schema return res @@ -931,7 +931,11 @@ def _show_string( )._to_table() return table[0][0].as_py() - def withColumns(self, colsMap: Dict[str, Column]) -> ParentDataFrame: + def withColumns(self, *colsMap: Dict[str, Column]) -> ParentDataFrame: + # Below code is to help enable kwargs in future. 
+ assert len(colsMap) == 1 + colsMap = colsMap[0] # type: ignore[assignment] + if not isinstance(colsMap, dict): raise PySparkTypeError( errorClass="NOT_DICT", @@ -1256,7 +1260,7 @@ def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame: res._cached_schema = self._merge_cached_schema(other) return res - def where(self, condition: Union[Column, str]) -> ParentDataFrame: + def where(self, condition: "ColumnOrName") -> ParentDataFrame: if not isinstance(condition, (str, Column)): raise PySparkTypeError( errorClass="NOT_COLUMN_OR_STR", @@ -2193,7 +2197,7 @@ def cb(ei: "ExecutionInfo") -> None: return DataFrameWriterV2(self._plan, self._session, table, cb) - def mergeInto(self, table: str, condition: Column) -> MergeIntoWriter: + def mergeInto(self, table: str, condition: Column) -> "MergeIntoWriter": def cb(ei: "ExecutionInfo") -> None: self._execution_info = ei @@ -2201,10 +2205,10 @@ def cb(ei: "ExecutionInfo") -> None: self._plan, self._session, table, condition, cb # type: ignore[arg-type] ) - def offset(self, n: int) -> ParentDataFrame: - return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session) + def offset(self, num: int) -> ParentDataFrame: + return DataFrame(plan.Offset(child=self._plan, offset=num), session=self._session) - def checkpoint(self, eager: bool = True) -> "DataFrame": + def checkpoint(self, eager: bool = True) -> ParentDataFrame: cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager) _, properties, self._execution_info = self._session.client.execute_command( cmd.command(self._session.client) @@ -2214,7 +2218,7 @@ def checkpoint(self, eager: bool = True) -> "DataFrame": assert isinstance(checkpointed._plan, plan.CachedRemoteRelation) return checkpointed - def localCheckpoint(self, eager: bool = True) -> "DataFrame": + def localCheckpoint(self, eager: bool = True) -> ParentDataFrame: cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager) _, properties, self._execution_info = self._session.client.execute_command( cmd.command(self._session.client) diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index b513d8d4111b9..96344efba2d2a 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -77,6 +77,7 @@ from pyspark.sql.udf import UDFRegistration from pyspark.sql.udtf import UDTFRegistration from pyspark.sql.datasource import DataSourceRegistration + from pyspark.sql.dataframe import DataFrame as ParentDataFrame # Running MyPy type checks will always require pandas and # other dependencies so importing here is fine. @@ -1641,7 +1642,7 @@ def prepare(obj: Any) -> Any: def sql( self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None, **kwargs: Any - ) -> DataFrame: + ) -> "ParentDataFrame": """Returns a :class:`DataFrame` representing the result of the given query. When ``kwargs`` is specified, this method formats the given string by using the Python standard formatter. The method binds named parameters to SQL literals or diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py new file mode 100644 index 0000000000000..ca1f828ef4d78 --- /dev/null +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -0,0 +1,188 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import inspect + +from pyspark.testing.sqlutils import ReusedSQLTestCase +from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame +from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame +from pyspark.sql.classic.column import Column as ClassicColumn +from pyspark.sql.connect.column import Column as ConnectColumn +from pyspark.sql.session import SparkSession as ClassicSparkSession +from pyspark.sql.connect.session import SparkSession as ConnectSparkSession + + +class ConnectCompatibilityTestsMixin: + def get_public_methods(self, cls): + """Get public methods of a class.""" + return { + name: method + for name, method in inspect.getmembers(cls, predicate=inspect.isfunction) + if not name.startswith("_") + } + + def get_public_properties(self, cls): + """Get public properties of a class.""" + return { + name: member + for name, member in inspect.getmembers(cls) + if isinstance(member, property) and not name.startswith("_") + } + + def test_signature_comparison_between_classic_and_connect(self): + def compare_method_signatures(classic_cls, connect_cls, cls_name): + """Compare method signatures between classic and connect classes.""" + classic_methods = self.get_public_methods(classic_cls) + connect_methods = self.get_public_methods(connect_cls) + + common_methods = set(classic_methods.keys()) & set(connect_methods.keys()) + + for method in common_methods: + classic_signature = inspect.signature(classic_methods[method]) + connect_signature = inspect.signature(connect_methods[method]) + + # createDataFrame cannot be the same since RDD is not supported from Spark Connect + if not method == "createDataFrame": + self.assertEqual( + classic_signature, + connect_signature, + f"Signature mismatch in {cls_name} method '{method}'\n" + f"Classic: {classic_signature}\n" + f"Connect: {connect_signature}", + ) + + # DataFrame API signature comparison + compare_method_signatures(ClassicDataFrame, ConnectDataFrame, "DataFrame") + + # Column API signature comparison + compare_method_signatures(ClassicColumn, ConnectColumn, "Column") + + # SparkSession API signature comparison + compare_method_signatures(ClassicSparkSession, ConnectSparkSession, "SparkSession") + + def test_property_comparison_between_classic_and_connect(self): + def compare_property_lists(classic_cls, connect_cls, cls_name, expected_missing_properties): + """Compare properties between classic and connect classes.""" + classic_properties = self.get_public_properties(classic_cls) + connect_properties = self.get_public_properties(connect_cls) + + # Identify missing properties + classic_only_properties = set(classic_properties.keys()) - set( + connect_properties.keys() + ) + + # Compare the actual missing properties with the expected ones + self.assertEqual( + classic_only_properties, + expected_missing_properties, + f"{cls_name}: Unexpected missing properties in Connect: {classic_only_properties}", + ) + + # Expected missing 
properties for DataFrame + expected_missing_properties_for_dataframe = {"sql_ctx", "isStreaming"} + + # DataFrame properties comparison + compare_property_lists( + ClassicDataFrame, + ConnectDataFrame, + "DataFrame", + expected_missing_properties_for_dataframe, + ) + + # Expected missing properties for Column (if any, replace with actual values) + expected_missing_properties_for_column = set() + + # Column properties comparison + compare_property_lists( + ClassicColumn, ConnectColumn, "Column", expected_missing_properties_for_column + ) + + # Expected missing properties for SparkSession + expected_missing_properties_for_spark_session = {"sparkContext", "version"} + + # SparkSession properties comparison + compare_property_lists( + ClassicSparkSession, + ConnectSparkSession, + "SparkSession", + expected_missing_properties_for_spark_session, + ) + + def test_missing_methods(self): + def check_missing_methods(classic_cls, connect_cls, cls_name, expected_missing_methods): + """Check for expected missing methods between classic and connect classes.""" + classic_methods = self.get_public_methods(classic_cls) + connect_methods = self.get_public_methods(connect_cls) + + # Identify missing methods + classic_only_methods = set(classic_methods.keys()) - set(connect_methods.keys()) + + # Compare the actual missing methods with the expected ones + self.assertEqual( + classic_only_methods, + expected_missing_methods, + f"{cls_name}: Unexpected missing methods in Connect: {classic_only_methods}", + ) + + # Expected missing methods for DataFrame + expected_missing_methods_for_dataframe = { + "inputFiles", + "isLocal", + "semanticHash", + "isEmpty", + } + + # DataFrame missing method check + check_missing_methods( + ClassicDataFrame, ConnectDataFrame, "DataFrame", expected_missing_methods_for_dataframe + ) + + # Expected missing methods for Column (if any, replace with actual values) + expected_missing_methods_for_column = set() + + # Column missing method check + check_missing_methods( + ClassicColumn, ConnectColumn, "Column", expected_missing_methods_for_column + ) + + # Expected missing methods for SparkSession (if any, replace with actual values) + expected_missing_methods_for_spark_session = {"newSession"} + + # SparkSession missing method check + check_missing_methods( + ClassicSparkSession, + ConnectSparkSession, + "SparkSession", + expected_missing_methods_for_spark_session, + ) + + +class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.test_connect_compatibility import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) From 73d6bd7c35b599690d40efe306eea0774f272ba8 Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Tue, 24 Sep 2024 19:06:12 +0900 Subject: [PATCH 065/250] [SPARK-49630][SS] Add flatten option to process collection types with state data source reader ### What changes were proposed in this pull request? Add flatten option to process collection types with state data source reader ### Why are the changes needed? Changes are needed to process entries row-by-row in case we don't have enough memory to fit these collections inside a single row ### Does this PR introduce _any_ user-facing change? 
Yes Users can provide the following query option: ``` val stateReaderDf = spark.read .format("statestore") .option(StateSourceOptions.PATH, ) .option(StateSourceOptions.STATE_VAR_NAME, ) .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, ) .load() ``` ### How was this patch tested? Added unit tests ``` [info] Run completed in 1 minute, 10 seconds. [info] Total number of tests run: 12 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 12, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #48110 from anishshri-db/task/SPARK-49630. Authored-by: Anish Shrigondekar Signed-off-by: Jungtaek Lim --- .../v2/state/StateDataSource.scala | 19 +- .../v2/state/StatePartitionReader.scala | 55 +--- .../v2/state/utils/SchemaUtil.scala | 264 ++++++++++++------ .../v2/state/StateDataSourceReadSuite.scala | 19 ++ ...ateDataSourceTransformWithStateSuite.scala | 88 +++++- .../streaming/TransformWithStateSuite.scala | 8 +- 6 files changed, 320 insertions(+), 133 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 50b90641d309b..429464ea5438d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -303,13 +303,15 @@ case class StateSourceOptions( readChangeFeed: Boolean, fromSnapshotOptions: Option[FromSnapshotOptions], readChangeFeedOptions: Option[ReadChangeFeedOptions], - stateVarName: Option[String]) { + stateVarName: Option[String], + flattenCollectionTypes: Boolean) { def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE) override def toString: String = { var desc = s"StateSourceOptions(checkpointLocation=$resolvedCpLocation, batchId=$batchId, " + s"operatorId=$operatorId, storeName=$storeName, joinSide=$joinSide, " + - s"stateVarName=${stateVarName.getOrElse("None")}" + s"stateVarName=${stateVarName.getOrElse("None")}, +" + + s"flattenCollectionTypes=$flattenCollectionTypes" if (fromSnapshotOptions.isDefined) { desc += s", snapshotStartBatchId=${fromSnapshotOptions.get.snapshotStartBatchId}" desc += s", snapshotPartitionId=${fromSnapshotOptions.get.snapshotPartitionId}" @@ -334,6 +336,7 @@ object StateSourceOptions extends DataSourceOptions { val CHANGE_START_BATCH_ID = newOption("changeStartBatchId") val CHANGE_END_BATCH_ID = newOption("changeEndBatchId") val STATE_VAR_NAME = newOption("stateVarName") + val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes") object JoinSideValues extends Enumeration { type JoinSideValues = Value @@ -374,6 +377,15 @@ object StateSourceOptions extends DataSourceOptions { val stateVarName = Option(options.get(STATE_VAR_NAME)) .map(_.trim) + val flattenCollectionTypes = try { + Option(options.get(FLATTEN_COLLECTION_TYPES)) + .map(_.toBoolean).getOrElse(true) + } catch { + case _: IllegalArgumentException => + throw StateDataSourceErrors.invalidOptionValue(FLATTEN_COLLECTION_TYPES, + "Boolean value is expected") + } + val joinSide = try { Option(options.get(JOIN_SIDE)) .map(JoinSideValues.withName).getOrElse(JoinSideValues.none) @@ -477,7 +489,8 @@ object StateSourceOptions extends DataSourceOptions { StateSourceOptions( resolvedCpLocation, batchId.get, operatorId, storeName, joinSide, - 
readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, stateVarName) + readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, stateVarName, + flattenCollectionTypes) } private def resolvedCheckpointLocation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index 24166a46bbd39..ae12b18c1f627 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.v2.state import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} -import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo} @@ -75,9 +74,11 @@ abstract class StatePartitionReaderBase( StructType(Array(StructField("__dummy__", NullType))) protected val keySchema = { - if (!SchemaUtil.isMapStateVariable(stateVariableInfoOpt)) { + if (SchemaUtil.checkVariableType(stateVariableInfoOpt, StateVariableType.MapState)) { + SchemaUtil.getCompositeKeySchema(schema, partition.sourceOptions) + } else { SchemaUtil.getSchemaAsDataType(schema, "key").asInstanceOf[StructType] - } else SchemaUtil.getCompositeKeySchema(schema) + } } protected val valueSchema = if (stateVariableInfoOpt.isDefined) { @@ -98,12 +99,8 @@ abstract class StatePartitionReaderBase( false } - val useMultipleValuesPerKey = if (stateVariableInfoOpt.isDefined && - stateVariableInfoOpt.get.stateVariableType == StateVariableType.ListState) { - true - } else { - false - } + val useMultipleValuesPerKey = SchemaUtil.checkVariableType(stateVariableInfoOpt, + StateVariableType.ListState) val provider = StateStoreProvider.createAndInit( stateStoreProviderId, keySchema, valueSchema, keyStateEncoderSpec, @@ -149,7 +146,7 @@ abstract class StatePartitionReaderBase( /** * An implementation of [[StatePartitionReaderBase]] for the normal mode of State Data - * Source. It reads the the state at a particular batchId. + * Source. It reads the state at a particular batchId. 
*/ class StatePartitionReader( storeConf: StateStoreConf, @@ -181,41 +178,17 @@ class StatePartitionReader( override lazy val iter: Iterator[InternalRow] = { val stateVarName = stateVariableInfoOpt .map(_.stateName).getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME) - if (SchemaUtil.isMapStateVariable(stateVariableInfoOpt)) { - SchemaUtil.unifyMapStateRowPair( - store.iterator(stateVarName), keySchema, partition.partition) + + if (stateVariableInfoOpt.isDefined) { + val stateVariableInfo = stateVariableInfoOpt.get + val stateVarType = stateVariableInfo.stateVariableType + SchemaUtil.processStateEntries(stateVarType, stateVarName, store, + keySchema, partition.partition, partition.sourceOptions) } else { store .iterator(stateVarName) .map { pair => - stateVariableInfoOpt match { - case Some(stateVarInfo) => - val stateVarType = stateVarInfo.stateVariableType - - stateVarType match { - case StateVariableType.ValueState => - SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition) - - case StateVariableType.ListState => - val key = pair.key - val result = store.valuesIterator(key, stateVarName) - var unsafeRowArr: Seq[UnsafeRow] = Seq.empty - result.foreach { entry => - unsafeRowArr = unsafeRowArr :+ entry.copy() - } - // convert the list of values to array type - val arrData = new GenericArrayData(unsafeRowArr.toArray) - SchemaUtil.unifyStateRowPairWithMultipleValues((pair.key, arrData), - partition.partition) - - case _ => - throw new IllegalStateException( - s"Unsupported state variable type: $stateVarType") - } - - case None => - SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition) - } + SchemaUtil.unifyStateRowPair((pair.key, pair.value), partition.partition) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala index 88ea06d598e56..dc0d6af951143 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2.state.utils import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import org.apache.spark.sql.AnalysisException @@ -24,9 +25,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.execution.datasources.v2.state.{StateDataSourceErrors, StateSourceOptions} +import org.apache.spark.sql.execution.streaming.{StateVariableType, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.StateVariableType._ -import org.apache.spark.sql.execution.streaming.TransformWithStateVariableInfo -import org.apache.spark.sql.execution.streaming.state.{StateStoreColFamilySchema, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming.state.{ReadStateStore, StateStoreColFamilySchema, UnsafeRowPair} import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, LongType, MapType, StringType, StructType} import org.apache.spark.util.ArrayImplicits._ @@ -58,7 +59,7 @@ object SchemaUtil { } else if (transformWithStateVariableInfoOpt.isDefined) { require(stateStoreColFamilySchemaOpt.isDefined) 
generateSchemaForStateVar(transformWithStateVariableInfoOpt.get, - stateStoreColFamilySchemaOpt.get) + stateStoreColFamilySchemaOpt.get, sourceOptions) } else { new StructType() .add("key", keySchema) @@ -101,7 +102,8 @@ object SchemaUtil { def unifyMapStateRowPair( stateRows: Iterator[UnsafeRowPair], compositeKeySchema: StructType, - partitionId: Int): Iterator[InternalRow] = { + partitionId: Int, + stateSourceOptions: StateSourceOptions): Iterator[InternalRow] = { val groupingKeySchema = SchemaUtil.getSchemaAsDataType( compositeKeySchema, "key" ).asInstanceOf[StructType] @@ -130,61 +132,84 @@ object SchemaUtil { row } - // All of the rows with the same grouping key were co-located and were - // grouped together consecutively. - new Iterator[InternalRow] { - var curGroupingKey: UnsafeRow = _ - var curStateRowPair: UnsafeRowPair = _ - val curMap = mutable.Map.empty[Any, Any] - - override def hasNext: Boolean = - stateRows.hasNext || !curMap.isEmpty - - override def next(): InternalRow = { - var foundNewGroupingKey = false - while (stateRows.hasNext && !foundNewGroupingKey) { - curStateRowPair = stateRows.next() - if (curGroupingKey == null) { - // First time in the iterator - // Need to make a copy because we need to keep the - // value across function calls - curGroupingKey = curStateRowPair.key - .get(0, groupingKeySchema).asInstanceOf[UnsafeRow].copy() - appendKVPairToMap(curMap, curStateRowPair) - } else { - val curPairGroupingKey = - curStateRowPair.key.get(0, groupingKeySchema) - if (curPairGroupingKey == curGroupingKey) { + def createFlattenedRow( + groupingKey: UnsafeRow, + userMapKey: UnsafeRow, + userMapValue: UnsafeRow, + partitionId: Int): GenericInternalRow = { + val row = new GenericInternalRow(4) + row.update(0, groupingKey) + row.update(1, userMapKey) + row.update(2, userMapValue) + row.update(3, partitionId) + row + } + + if (stateSourceOptions.flattenCollectionTypes) { + stateRows + .map { pair => + val groupingKey = pair.key.get(0, groupingKeySchema).asInstanceOf[UnsafeRow] + val userMapKey = pair.key.get(1, userKeySchema).asInstanceOf[UnsafeRow] + val userMapValue = pair.value + createFlattenedRow(groupingKey, userMapKey, userMapValue, partitionId) + } + } else { + // All of the rows with the same grouping key were co-located and were + // grouped together consecutively. 
+ new Iterator[InternalRow] { + var curGroupingKey: UnsafeRow = _ + var curStateRowPair: UnsafeRowPair = _ + val curMap = mutable.Map.empty[Any, Any] + + override def hasNext: Boolean = + stateRows.hasNext || !curMap.isEmpty + + override def next(): InternalRow = { + var foundNewGroupingKey = false + while (stateRows.hasNext && !foundNewGroupingKey) { + curStateRowPair = stateRows.next() + if (curGroupingKey == null) { + // First time in the iterator + // Need to make a copy because we need to keep the + // value across function calls + curGroupingKey = curStateRowPair.key + .get(0, groupingKeySchema).asInstanceOf[UnsafeRow].copy() appendKVPairToMap(curMap, curStateRowPair) } else { - // find a different grouping key, exit loop and return a row - foundNewGroupingKey = true + val curPairGroupingKey = + curStateRowPair.key.get(0, groupingKeySchema) + if (curPairGroupingKey == curGroupingKey) { + appendKVPairToMap(curMap, curStateRowPair) + } else { + // find a different grouping key, exit loop and return a row + foundNewGroupingKey = true + } } } - } - if (foundNewGroupingKey) { - // found a different grouping key - val row = createDataRow(curGroupingKey, curMap) - // update vars - curGroupingKey = - curStateRowPair.key.get(0, groupingKeySchema) - .asInstanceOf[UnsafeRow].copy() - // empty the map, append current row - curMap.clear() - appendKVPairToMap(curMap, curStateRowPair) - // return map value of previous grouping key - row - } else { - if (curMap.isEmpty) { - throw new NoSuchElementException("Please check if the iterator hasNext(); Likely " + - "user is trying to get element from an exhausted iterator.") - } - else { - // reach the end of the state rows + if (foundNewGroupingKey) { + // found a different grouping key val row = createDataRow(curGroupingKey, curMap) - // clear the map to end the iterator + // update vars + curGroupingKey = + curStateRowPair.key.get(0, groupingKeySchema) + .asInstanceOf[UnsafeRow].copy() + // empty the map, append current row curMap.clear() + appendKVPairToMap(curMap, curStateRowPair) + // return map value of previous grouping key row + } else { + if (curMap.isEmpty) { + throw new NoSuchElementException("Please check if the iterator hasNext(); Likely " + + "user is trying to get element from an exhausted iterator.") + } + else { + // reach the end of the state rows + val row = createDataRow(curGroupingKey, curMap) + // clear the map to end the iterator + curMap.clear() + row + } } } } @@ -200,9 +225,11 @@ object SchemaUtil { "change_type" -> classOf[StringType], "key" -> classOf[StructType], "value" -> classOf[StructType], - "single_value" -> classOf[StructType], + "list_element" -> classOf[StructType], "list_value" -> classOf[ArrayType], "map_value" -> classOf[MapType], + "user_map_key" -> classOf[StructType], + "user_map_value" -> classOf[StructType], "partition_id" -> classOf[IntegerType]) val expectedFieldNames = if (sourceOptions.readChangeFeed) { @@ -213,13 +240,21 @@ object SchemaUtil { stateVarType match { case ValueState => - Seq("key", "single_value", "partition_id") + Seq("key", "value", "partition_id") case ListState => - Seq("key", "list_value", "partition_id") + if (sourceOptions.flattenCollectionTypes) { + Seq("key", "list_element", "partition_id") + } else { + Seq("key", "list_value", "partition_id") + } case MapState => - Seq("key", "map_value", "partition_id") + if (sourceOptions.flattenCollectionTypes) { + Seq("key", "user_map_key", "user_map_value", "partition_id") + } else { + Seq("key", "map_value", "partition_id") + } case _ => 
throw StateDataSourceErrors @@ -241,21 +276,29 @@ object SchemaUtil { private def generateSchemaForStateVar( stateVarInfo: TransformWithStateVariableInfo, - stateStoreColFamilySchema: StateStoreColFamilySchema): StructType = { + stateStoreColFamilySchema: StateStoreColFamilySchema, + stateSourceOptions: StateSourceOptions): StructType = { val stateVarType = stateVarInfo.stateVariableType stateVarType match { case ValueState => new StructType() .add("key", stateStoreColFamilySchema.keySchema) - .add("single_value", stateStoreColFamilySchema.valueSchema) + .add("value", stateStoreColFamilySchema.valueSchema) .add("partition_id", IntegerType) case ListState => - new StructType() - .add("key", stateStoreColFamilySchema.keySchema) - .add("list_value", ArrayType(stateStoreColFamilySchema.valueSchema)) - .add("partition_id", IntegerType) + if (stateSourceOptions.flattenCollectionTypes) { + new StructType() + .add("key", stateStoreColFamilySchema.keySchema) + .add("list_element", stateStoreColFamilySchema.valueSchema) + .add("partition_id", IntegerType) + } else { + new StructType() + .add("key", stateStoreColFamilySchema.keySchema) + .add("list_value", ArrayType(stateStoreColFamilySchema.valueSchema)) + .add("partition_id", IntegerType) + } case MapState => val groupingKeySchema = SchemaUtil.getSchemaAsDataType( @@ -266,43 +309,47 @@ object SchemaUtil { valueType = stateStoreColFamilySchema.valueSchema ) - new StructType() - .add("key", groupingKeySchema) - .add("map_value", valueMapSchema) - .add("partition_id", IntegerType) + if (stateSourceOptions.flattenCollectionTypes) { + new StructType() + .add("key", groupingKeySchema) + .add("user_map_key", userKeySchema) + .add("user_map_value", stateStoreColFamilySchema.valueSchema) + .add("partition_id", IntegerType) + } else { + new StructType() + .add("key", groupingKeySchema) + .add("map_value", valueMapSchema) + .add("partition_id", IntegerType) + } case _ => throw StateDataSourceErrors.internalError(s"Unsupported state variable type $stateVarType") } } - /** - * Helper functions for map state data source reader. - * - * Map state variables are stored in RocksDB state store has the schema of - * `TransformWithStateKeyValueRowSchemaUtils.getCompositeKeySchema()`; - * But for state store reader, we need to return in format of: - * "key": groupingKey, "map_value": Map(userKey -> value). - * - * The following functions help to translate between two schema. 
- */ - def isMapStateVariable( - stateVariableInfoOpt: Option[TransformWithStateVariableInfo]): Boolean = { + def checkVariableType( + stateVariableInfoOpt: Option[TransformWithStateVariableInfo], + varType: StateVariableType): Boolean = { stateVariableInfoOpt.isDefined && - stateVariableInfoOpt.get.stateVariableType == MapState + stateVariableInfoOpt.get.stateVariableType == varType } /** * Given key-value schema generated from `generateSchemaForStateVar()`, * returns the compositeKey schema that key is stored in the state store */ - def getCompositeKeySchema(schema: StructType): StructType = { + def getCompositeKeySchema( + schema: StructType, + stateSourceOptions: StateSourceOptions): StructType = { val groupingKeySchema = SchemaUtil.getSchemaAsDataType( schema, "key").asInstanceOf[StructType] val userKeySchema = try { - Option( - SchemaUtil.getSchemaAsDataType(schema, "map_value").asInstanceOf[MapType] + if (stateSourceOptions.flattenCollectionTypes) { + Option(SchemaUtil.getSchemaAsDataType(schema, "user_map_key").asInstanceOf[StructType]) + } else { + Option(SchemaUtil.getSchemaAsDataType(schema, "map_value").asInstanceOf[MapType] .keyType.asInstanceOf[StructType]) + } } catch { case NonFatal(e) => throw StateDataSourceErrors.internalError(s"No such field named as 'map_value' " + @@ -312,4 +359,57 @@ object SchemaUtil { .add("key", groupingKeySchema) .add("userKey", userKeySchema.get) } + + def processStateEntries( + stateVarType: StateVariableType, + stateVarName: String, + store: ReadStateStore, + compositeKeySchema: StructType, + partitionId: Int, + stateSourceOptions: StateSourceOptions): Iterator[InternalRow] = { + stateVarType match { + case StateVariableType.ValueState => + store + .iterator(stateVarName) + .map { pair => + unifyStateRowPair((pair.key, pair.value), partitionId) + } + + case StateVariableType.ListState => + if (stateSourceOptions.flattenCollectionTypes) { + store + .iterator(stateVarName) + .flatMap { pair => + val key = pair.key + val result = store.valuesIterator(key, stateVarName) + result.map { entry => + SchemaUtil.unifyStateRowPair((key, entry), partitionId) + } + } + } else { + store + .iterator(stateVarName) + .map { pair => + val key = pair.key + val result = store.valuesIterator(key, stateVarName) + val unsafeRowArr = ArrayBuffer[UnsafeRow]() + result.foreach { entry => + unsafeRowArr += entry.copy() + } + // convert the list of values to array type + val arrData = new GenericArrayData(unsafeRowArr.toArray) + // convert the list of values to a single row + SchemaUtil.unifyStateRowPairWithMultipleValues((key, arrData), partitionId) + } + } + + case StateVariableType.MapState => + unifyMapStateRowPair(store.iterator(stateVarName), + compositeKeySchema, partitionId, stateSourceOptions) + + case _ => + throw new IllegalStateException( + s"Unsupported state variable type: $stateVarType") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index af07707569500..8707facc4c126 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -287,6 +287,25 @@ class StateDataSourceNegativeTestSuite extends StateDataSourceTestBase { matchPVals = true) } } + + test("ERROR: trying to specify non boolean value for 
" + + "flattenCollectionTypes") { + withTempDir { tempDir => + runDropDuplicatesQuery(tempDir.getAbsolutePath) + + val exc = intercept[StateDataSourceInvalidOptionValue] { + spark.read.format("statestore") + // trick to bypass getting the last committed batch before validating operator ID + .option(StateSourceOptions.BATCH_ID, 0) + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, "test") + .load(tempDir.getAbsolutePath) + } + checkError(exc, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", Some("42616"), + Map("optionName" -> StateSourceOptions.FLATTEN_COLLECTION_TYPES, + "message" -> ".*"), + matchPVals = true) + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala index 61091fde35e79..69df86fd5f746 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala @@ -159,7 +159,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest val resultDf = stateReaderDf.selectExpr( "key.value AS groupingKey", - "single_value.id AS valueId", "single_value.name AS valueName", + "value.id AS valueId", "value.name AS valueName", "partition_id") checkAnswer(resultDf, @@ -222,7 +222,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest .load() val resultDf = stateReaderDf.selectExpr( - "key.value", "single_value.value", "single_value.ttlExpirationMs", "partition_id") + "key.value", "value.value", "value.ttlExpirationMs", "partition_id") var count = 0L resultDf.collect().foreach { row => @@ -235,7 +235,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest val answerDf = stateReaderDf.selectExpr( "key.value AS groupingKey", - "single_value.value.value AS valueId", "partition_id") + "value.value.value AS valueId", "partition_id") checkAnswer(answerDf, Seq(Row("a", 1L, 0), Row("b", 1L, 1))) @@ -290,10 +290,12 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest StopStream ) + // Verify that the state can be read in flattened/non-flattened modes val stateReaderDf = spark.read .format("statestore") .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) .option(StateSourceOptions.STATE_VAR_NAME, "groupsList") + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, false) .load() val listStateDf = stateReaderDf @@ -307,6 +309,19 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest checkAnswer(listStateDf, Seq(Row("session1", "group1"), Row("session1", "group2"), Row("session1", "group4"), Row("session2", "group1"), Row("session3", "group7"))) + + val flattenedReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "groupsList") + .load() + + val resultDf = flattenedReaderDf.selectExpr( + "key.value AS groupingKey", + "list_element.value AS valueList") + checkAnswer(resultDf, + Seq(Row("session1", "group1"), Row("session1", "group2"), Row("session1", "group4"), + Row("session2", "group1"), Row("session3", "group7"))) } } } @@ -338,10 +353,12 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest StopStream ) + // Verify that the state can be read in 
flattened/non-flattened modes val stateReaderDf = spark.read .format("statestore") .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) .option(StateSourceOptions.STATE_VAR_NAME, "groupsListWithTTL") + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, false) .load() val listStateDf = stateReaderDf @@ -368,6 +385,31 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest checkAnswer(valuesDf, Seq(Row("session1", "group1"), Row("session1", "group2"), Row("session1", "group4"), Row("session2", "group1"), Row("session3", "group7"))) + + val flattenedStateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "groupsListWithTTL") + .load() + + val flattenedResultDf = flattenedStateReaderDf + .selectExpr("list_element.ttlExpirationMs AS ttlExpirationMs") + var flattenedCount = 0L + flattenedResultDf.collect().foreach { row => + flattenedCount = flattenedCount + 1 + assert(row.getLong(0) > 0) + } + + // verify that 5 state rows are present + assert(flattenedCount === 5) + + val outputDf = flattenedStateReaderDf + .selectExpr("key.value AS groupingKey", + "list_element.value.value AS groupId") + + checkAnswer(outputDf, + Seq(Row("session1", "group1"), Row("session1", "group2"), Row("session1", "group4"), + Row("session2", "group1"), Row("session3", "group7"))) } } } @@ -397,10 +439,12 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest StopStream ) + // Verify that the state can be read in flattened/non-flattened modes val stateReaderDf = spark.read .format("statestore") .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) .option(StateSourceOptions.STATE_VAR_NAME, "sessionState") + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, false) .load() val resultDf = stateReaderDf.selectExpr( @@ -413,6 +457,24 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest Row("k2", Map(Row("v2") -> Row("3")))) ) + + val flattenedStateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "sessionState") + .load() + + val outputDf = flattenedStateReaderDf + .selectExpr("key.value AS groupingKey", + "user_map_key.value AS mapKey", + "user_map_value.value AS mapValue") + + checkAnswer(outputDf, + Seq( + Row("k1", "v1", "10"), + Row("k1", "v2", "5"), + Row("k2", "v2", "3")) + ) } } } @@ -463,10 +525,12 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest StopStream ) + // Verify that the state can be read in flattened/non-flattened modes val stateReaderDf = spark.read .format("statestore") .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) .option(StateSourceOptions.STATE_VAR_NAME, "mapState") + .option(StateSourceOptions.FLATTEN_COLLECTION_TYPES, false) .load() val resultDf = stateReaderDf.selectExpr( @@ -478,6 +542,24 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest Map(Row("key2") -> Row(Row(2), 61000L), Row("key1") -> Row(Row(1), 61000L)))) ) + + val flattenedStateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "mapState") + .load() + + val outputDf = flattenedStateReaderDf + .selectExpr("key.value AS groupingKey", + "user_map_key.value AS mapKey", + "user_map_value.value.value AS mapValue", + "user_map_value.ttlExpirationMs AS ttlTimestamp") + + checkAnswer(outputDf, + Seq( + 
Row("k1", "key1", 1, 61000L), + Row("k1", "key2", 2, 61000L)) + ) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index d0e255bb30499..0c02fbf97820b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -1623,7 +1623,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest val batch1AnsDf = batch1Df.selectExpr( "key.value AS groupingKey", - "single_value.value AS valueId") + "value.value AS valueId") checkAnswer(batch1AnsDf, Seq(Row("a", 2L))) @@ -1636,7 +1636,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest val batch3AnsDf = batch3Df.selectExpr( "key.value AS groupingKey", - "single_value.value AS valueId") + "value.value AS valueId") checkAnswer(batch3AnsDf, Seq(Row("a", 1L))) } } @@ -1731,7 +1731,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest val countStateAnsDf = countStateDf.selectExpr( "key.value AS groupingKey", - "single_value.value AS valueId") + "value.value AS valueId") checkAnswer(countStateAnsDf, Seq(Row("a", 5L))) val mostRecentDf = spark.read @@ -1743,7 +1743,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest val mostRecentAnsDf = mostRecentDf.selectExpr( "key.value AS groupingKey", - "single_value.value") + "value.value") checkAnswer(mostRecentAnsDf, Seq(Row("a", "str1"))) } } From dedf5aa91827f32736ce5dae2eb123ba4e244c3b Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Tue, 24 Sep 2024 07:40:58 -0700 Subject: [PATCH 066/250] [SPARK-49750][DOC] Mention delegation token support in K8s mode ### What changes were proposed in this pull request? Update docs to mention delegation token support in K8s mode. ### Why are the changes needed? The delegation token support in K8s mode has been implemented since 3.0.0 via SPARK-23257. ### Does this PR introduce _any_ user-facing change? Yes, docs are updated. ### How was this patch tested? Review. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48199 from pan3793/SPARK-49750. Authored-by: Cheng Pan Signed-off-by: Dongjoon Hyun --- docs/security.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/security.md b/docs/security.md index b97abfeacf240..c7d3fd5f8c36f 100644 --- a/docs/security.md +++ b/docs/security.md @@ -947,7 +947,7 @@ mechanism (see `java.util.ServiceLoader`). Implementations of `org.apache.spark.security.HadoopDelegationTokenProvider` can be made available to Spark by listing their names in the corresponding file in the jar's `META-INF/services` directory. -Delegation token support is currently only supported in YARN mode. Consult the +Delegation token support is currently only supported in YARN and Kubernetes mode. Consult the deployment-specific page for more information. The following options provides finer-grained control for this feature: From 55d0233d19cc52bee91a9619057d9b6f33165a0a Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Tue, 24 Sep 2024 07:48:23 -0700 Subject: [PATCH 067/250] [SPARK-49713][PYTHON][FOLLOWUP] Make function `count_min_sketch` accept long seed ### What changes were proposed in this pull request? Make function `count_min_sketch` accept long seed ### Why are the changes needed? 
existing implementation only accepts int seed, which is inconsistent with other `ExpressionWithRandomSeed`: ```py In [3]: >>> from pyspark.sql import functions as sf ...: >>> spark.range(100).select( ...: ... sf.hex(sf.count_min_sketch("id", sf.lit(1.5), 0.6, 1111111111111111111)) ...: ... ).show(truncate=False) ... AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "count_min_sketch(id, 1.5, 0.6, 1111111111111111111)" due to data type mismatch: The 4th parameter requires the "INT" type, however "1111111111111111111" has the type "BIGINT". SQLSTATE: 42K09; 'Aggregate [unresolvedalias('hex(count_min_sketch(id#64L, 1.5, 0.6, 1111111111111111111, 0, 0)))] +- Range (0, 100, step=1, splits=Some(12)) ... ``` ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? added doctest ### Was this patch authored or co-authored using generative AI tooling? no Closes #48223 from zhengruifeng/count_min_sk_long_seed. Authored-by: Ruifeng Zheng Signed-off-by: Dongjoon Hyun --- python/pyspark/sql/connect/functions/builtin.py | 3 +-- python/pyspark/sql/functions/builtin.py | 14 +++++++++++++- .../scala/org/apache/spark/sql/functions.scala | 2 +- .../expressions/aggregate/CountMinSketchAgg.scala | 8 ++++++-- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 2a39bc6bfddda..6953230f5b42e 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -70,7 +70,6 @@ StringType, ) from pyspark.sql.utils import enum_to_value as _enum_to_value -from pyspark.util import JVM_INT_MAX # The implementation of pandas_udf is embedded in pyspark.sql.function.pandas_udf # for code reuse. @@ -1130,7 +1129,7 @@ def count_min_sketch( confidence: Union[Column, float], seed: Optional[Union[Column, int]] = None, ) -> Column: - _seed = lit(random.randint(0, JVM_INT_MAX)) if seed is None else lit(seed) + _seed = lit(random.randint(0, sys.maxsize)) if seed is None else lit(seed) return _invoke_function_over_columns("count_min_sketch", col, lit(eps), lit(confidence), _seed) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 2688f9daa23a4..09a286fe7c94e 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -6080,7 +6080,19 @@ def count_min_sketch( |0000000100000000000000640000000100000002000000005D96391C00000000000000320000000000000032| +----------------------------------------------------------------------------------------+ - Example 3: Using a random seed + Example 3: Using a long seed + + >>> from pyspark.sql import functions as sf + >>> spark.range(100).select( + ... sf.hex(sf.count_min_sketch("id", sf.lit(1.5), 0.2, 1111111111111111111)) + ... 
).show(truncate=False) + +----------------------------------------------------------------------------------------+ + |hex(count_min_sketch(id, 1.5, 0.2, 1111111111111111111)) | + +----------------------------------------------------------------------------------------+ + |00000001000000000000006400000001000000020000000044078BA100000000000000320000000000000032| + +----------------------------------------------------------------------------------------+ + + Example 4: Using a random seed >>> from pyspark.sql import functions as sf >>> spark.range(100).select( diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index d9bceabe88f8f..ab69789c75f50 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -399,7 +399,7 @@ object functions { * @since 4.0.0 */ def count_min_sketch(e: Column, eps: Column, confidence: Column): Column = - count_min_sketch(e, eps, confidence, lit(SparkClassUtils.random.nextInt)) + count_min_sketch(e, eps, confidence, lit(SparkClassUtils.random.nextLong)) private[spark] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column = Column.internalFn("collect_top_k", e, lit(num), lit(reverse)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala index c26c4a9bdfea3..f0a27677628dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CountMinSketchAgg.scala @@ -63,7 +63,10 @@ case class CountMinSketchAgg( // Mark as lazy so that they are not evaluated during tree transformation. private lazy val eps: Double = epsExpression.eval().asInstanceOf[Double] private lazy val confidence: Double = confidenceExpression.eval().asInstanceOf[Double] - private lazy val seed: Int = seedExpression.eval().asInstanceOf[Int] + private lazy val seed: Int = seedExpression.eval() match { + case i: Int => i + case l: Long => l.toInt + } override def checkInputDataTypes(): TypeCheckResult = { val defaultCheck = super.checkInputDataTypes() @@ -168,7 +171,8 @@ case class CountMinSketchAgg( copy(inputAggBufferOffset = newInputAggBufferOffset) override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(IntegralType, StringType, BinaryType), DoubleType, DoubleType, IntegerType) + Seq(TypeCollection(IntegralType, StringType, BinaryType), DoubleType, DoubleType, + TypeCollection(IntegerType, LongType)) } override def nullable: Boolean = false From afe8bf945e1ad72fcb0ec4ec35b169e54169f5f1 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Wed, 25 Sep 2024 08:53:09 +0900 Subject: [PATCH 068/250] [SPARK-49771][PYTHON] Improve Pandas Scalar Iter UDF error when output rows exceed input rows ### What changes were proposed in this pull request? This PR changes the `assert` error into a user-facing PySpark error when the pandas_iter UDF has more output rows than input rows. ### Why are the changes needed? To make the error message more user-friendly. After the PR, the error will be `pyspark.errors.exceptions.base.PySparkRuntimeError: [PANDAS_UDF_OUTPUT_EXCEEDS_INPUT_ROWS] The Pandas SCALAR_ITER UDF outputs more rows than input rows.` ### Does this PR introduce _any_ user-facing change? 
No ### How was this patch tested? Existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48231 from allisonwang-db/spark-49771-pd-iter-err. Authored-by: allisonwang-db Signed-off-by: Hyukjin Kwon --- python/pyspark/errors/error-conditions.json | 5 +++++ python/pyspark/worker.py | 9 +++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 92aeb15e21d1b..115ad658e32f5 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -802,6 +802,11 @@ " >= must be installed; however, it was not found." ] }, + "PANDAS_UDF_OUTPUT_EXCEEDS_INPUT_ROWS" : { + "message": [ + "The Pandas SCALAR_ITER UDF outputs more rows than input rows." + ] + }, "PIPE_FUNCTION_EXITED": { "message": [ "Pipe function `` exited with error code ." diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b8263769c28a9..eedf5d1fd5996 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -1565,14 +1565,15 @@ def map_batch(batch): num_output_rows = 0 for result_batch, result_type in result_iter: num_output_rows += len(result_batch) - # This assert is for Scalar Iterator UDF to fail fast. + # This check is for Scalar Iterator UDF to fail fast. # The length of the entire input can only be explicitly known # by consuming the input iterator in user side. Therefore, # it's very unlikely the output length is higher than # input length. - assert ( - is_map_pandas_iter or is_map_arrow_iter or num_output_rows <= num_input_rows - ), "Pandas SCALAR_ITER UDF outputted more rows than input rows." + if is_scalar_iter and num_output_rows > num_input_rows: + raise PySparkRuntimeError( + errorClass="PANDAS_UDF_OUTPUT_EXCEEDS_INPUT_ROWS", messageParameters={} + ) yield (result_batch, result_type) if is_scalar_iter: From 0a7b98532fd2cf3a251aa258886c1e78779e9594 Mon Sep 17 00:00:00 2001 From: Changgyoo Park Date: Wed, 25 Sep 2024 08:57:55 +0900 Subject: [PATCH 069/250] [SPARK-49585][CONNECT] Replace executions map in SessionHolder with operationID set ### What changes were proposed in this pull request? SessionHolder has no reason to store ExecuteHolder directly as SparkConnectExecutionManager has a global map of ExecuteHolder. This PR replaces the map in SessionHolder with a set of operation IDs which is only used when interrupting executions within the session. ### Why are the changes needed? Save memory, and simplify the code by making SparkConnectExecutionManager the single source of ExecuteHolders. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48071 from changgyoopark-db/SPARK-49585. 
Authored-by: Changgyoo Park Signed-off-by: Hyukjin Kwon --- .../sql/connect/service/SessionHolder.scala | 62 +++++++++---------- .../SparkConnectExecutionManager.scala | 4 +- .../SparkConnectReattachExecuteHandler.scala | 33 +++++----- .../SparkConnectReleaseExecuteHandler.scala | 4 +- .../planner/SparkConnectServiceSuite.scala | 6 +- .../SparkConnectSessionHolderSuite.scala | 9 +++ 6 files changed, 67 insertions(+), 51 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index e56d66da3050d..5dced7acfb0d2 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.connect.service import java.nio.file.Path import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit} -import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} @@ -40,6 +39,7 @@ import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper +import org.apache.spark.sql.connect.service.ExecuteKey import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC} import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.util.{SystemClock, Utils} @@ -91,8 +91,8 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio // Setting it to -1 indicated forever. @volatile private var customInactiveTimeoutMs: Option[Long] = None - private val executions: ConcurrentMap[String, ExecuteHolder] = - new ConcurrentHashMap[String, ExecuteHolder]() + private val operationIds: ConcurrentMap[String, Boolean] = + new ConcurrentHashMap[String, Boolean]() // The cache that maps an error id to a throwable. The throwable in cache is independent to // each other. @@ -138,12 +138,11 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio } /** - * Add ExecuteHolder to this session. + * Add an operation ID to this session. * - * Called only by SparkConnectExecutionManager under executionsLock. + * Called only by SparkConnectExecutionManager when a new execution is started. */ - @GuardedBy("SparkConnectService.executionManager.executionsLock") - private[service] def addExecuteHolder(executeHolder: ExecuteHolder): Unit = { + private[service] def addOperationId(operationId: String): Unit = { if (closedTimeMs.isDefined) { // Do not accept new executions if the session is closing. throw new SparkSQLException( @@ -151,26 +150,20 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio messageParameters = Map("handle" -> sessionId)) } - val oldExecute = executions.putIfAbsent(executeHolder.operationId, executeHolder) - if (oldExecute != null) { - // the existence of this should alrady be checked by SparkConnectExecutionManager - throw new IllegalStateException( - s"ExecuteHolder with opId=${executeHolder.operationId} already exists!") + val alreadyExists = operationIds.putIfAbsent(operationId, true) + if (alreadyExists) { + // The existence of it should have been checked by SparkConnectExecutionManager. 
+ throw new IllegalStateException(s"ExecuteHolder with opId=${operationId} already exists!") } } /** - * Remove ExecuteHolder from this session. + * Remove an operation ID from this session. * - * Called only by SparkConnectExecutionManager under executionsLock. + * Called only by SparkConnectExecutionManager when an execution is ended. */ - @GuardedBy("SparkConnectService.executionManager.executionsLock") - private[service] def removeExecuteHolder(operationId: String): Unit = { - executions.remove(operationId) - } - - private[connect] def executeHolder(operationId: String): Option[ExecuteHolder] = { - Option(executions.get(operationId)) + private[service] def removeOperationId(operationId: String): Unit = { + operationIds.remove(operationId) } /** @@ -182,9 +175,12 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio val interruptedIds = new mutable.ArrayBuffer[String]() val operationsIds = SparkConnectService.streamingSessionManager.cleanupRunningQueries(this, blocking = false) - executions.asScala.values.foreach { execute => - if (execute.interrupt()) { - interruptedIds += execute.operationId + operationIds.asScala.foreach { case (operationId, _) => + val executeKey = ExecuteKey(userId, sessionId, operationId) + SparkConnectService.executionManager.getExecuteHolder(executeKey).foreach { executeHolder => + if (executeHolder.interrupt()) { + interruptedIds += operationId + } } } interruptedIds.toSeq ++ operationsIds @@ -199,10 +195,13 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio val interruptedIds = new mutable.ArrayBuffer[String]() val queries = SparkConnectService.streamingSessionManager.getTaggedQuery(tag, session) queries.foreach(q => Future(q.query.stop())(ExecutionContext.global)) - executions.asScala.values.foreach { execute => - if (execute.sparkSessionTags.contains(tag)) { - if (execute.interrupt()) { - interruptedIds += execute.operationId + operationIds.asScala.foreach { case (operationId, _) => + val executeKey = ExecuteKey(userId, sessionId, operationId) + SparkConnectService.executionManager.getExecuteHolder(executeKey).foreach { executeHolder => + if (executeHolder.sparkSessionTags.contains(tag)) { + if (executeHolder.interrupt()) { + interruptedIds += operationId + } } } } @@ -216,9 +215,10 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio */ private[service] def interruptOperation(operationId: String): Seq[String] = { val interruptedIds = new mutable.ArrayBuffer[String]() - Option(executions.get(operationId)).foreach { execute => - if (execute.interrupt()) { - interruptedIds += execute.operationId + val executeKey = ExecuteKey(userId, sessionId, operationId) + SparkConnectService.executionManager.getExecuteHolder(executeKey).foreach { executeHolder => + if (executeHolder.interrupt()) { + interruptedIds += operationId } } interruptedIds.toSeq diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala index d66964b8d34bd..d9eb5438c3886 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala @@ -114,7 +114,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { new ExecuteHolder(executeKey, request, sessionHolder) 
}) - sessionHolder.addExecuteHolder(executeHolder) + sessionHolder.addOperationId(executeHolder.operationId) logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} is created.") @@ -142,7 +142,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { // Remove the execution from the map *after* putting it in abandonedTombstones. executions.remove(key) - executeHolder.sessionHolder.removeExecuteHolder(executeHolder.operationId) + executeHolder.sessionHolder.removeOperationId(executeHolder.operationId) updateLastExecutionTime() diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReattachExecuteHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReattachExecuteHandler.scala index 534937f84eaee..a2696311bd843 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReattachExecuteHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReattachExecuteHandler.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkSQLException import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.execution.ExecuteGrpcResponseSender +import org.apache.spark.sql.connect.service.ExecuteKey class SparkConnectReattachExecuteHandler( responseObserver: StreamObserver[proto.ExecutePlanResponse]) @@ -38,22 +39,24 @@ class SparkConnectReattachExecuteHandler( SessionKey(v.getUserContext.getUserId, v.getSessionId), previousSessionId) - val executeHolder = sessionHolder.executeHolder(v.getOperationId).getOrElse { - if (SparkConnectService.executionManager - .getAbandonedTombstone( - ExecuteKey(v.getUserContext.getUserId, v.getSessionId, v.getOperationId)) - .isDefined) { - logDebug(s"Reattach operation abandoned: ${v.getOperationId}") - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.OPERATION_ABANDONED", - messageParameters = Map("handle" -> v.getOperationId)) - } else { - logDebug(s"Reattach operation not found: ${v.getOperationId}") - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.OPERATION_NOT_FOUND", - messageParameters = Map("handle" -> v.getOperationId)) + val executeKey = ExecuteKey(sessionHolder.userId, sessionHolder.sessionId, v.getOperationId) + val executeHolder = + SparkConnectService.executionManager.getExecuteHolder(executeKey).getOrElse { + if (SparkConnectService.executionManager + .getAbandonedTombstone( + ExecuteKey(v.getUserContext.getUserId, v.getSessionId, v.getOperationId)) + .isDefined) { + logDebug(s"Reattach operation abandoned: ${v.getOperationId}") + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.OPERATION_ABANDONED", + messageParameters = Map("handle" -> v.getOperationId)) + } else { + logDebug(s"Reattach operation not found: ${v.getOperationId}") + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.OPERATION_NOT_FOUND", + messageParameters = Map("handle" -> v.getOperationId)) + } } - } if (!executeHolder.reattachable) { logWarning(s"Reattach to not reattachable operation.") throw new SparkSQLException( diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala index a2dbf3b2eec9f..6beba13d55156 100644 --- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectReleaseExecuteHandler.scala @@ -22,6 +22,7 @@ import io.grpc.stub.StreamObserver import org.apache.spark.SparkSQLException import org.apache.spark.connect.proto import org.apache.spark.internal.Logging +import org.apache.spark.sql.connect.service.ExecuteKey class SparkConnectReleaseExecuteHandler( responseObserver: StreamObserver[proto.ReleaseExecuteResponse]) @@ -42,8 +43,9 @@ class SparkConnectReleaseExecuteHandler( // ReleaseExecute arrived after it was abandoned and timed out. // An asynchronous ReleastUntil operation may also arrive after ReleaseAll. // Because of that, make it noop and not fail if the ExecuteHolder is no longer there. + val executeKey = ExecuteKey(sessionHolder.userId, sessionHolder.sessionId, v.getOperationId) val executeHolderOption = - sessionHolder.executeHolder(v.getOperationId).foreach { executeHolder => + SparkConnectService.executionManager.getExecuteHolder(executeKey).foreach { executeHolder => if (!executeHolder.reattachable) { throw new SparkSQLException( errorClass = "INVALID_CURSOR.NOT_REATTACHABLE", diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 62146f19328a8..d6d137e6d91aa 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -48,7 +48,7 @@ import org.apache.spark.sql.connect.dsl.MockRemoteSession import org.apache.spark.sql.connect.dsl.expressions._ import org.apache.spark.sql.connect.dsl.plans._ import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry -import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteStatus, SessionStatus, SparkConnectAnalyzeHandler, SparkConnectService, SparkListenerConnectOperationStarted} +import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteKey, ExecuteStatus, SessionStatus, SparkConnectAnalyzeHandler, SparkConnectService, SparkListenerConnectOperationStarted} import org.apache.spark.sql.connector.catalog.InMemoryPartitionTableCatalog import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.sql.test.SharedSparkSession @@ -926,7 +926,9 @@ class SparkConnectServiceSuite semaphoreStarted.release() val sessionHolder = SparkConnectService.getOrCreateIsolatedSession(e.userId, e.sessionId, None) - executeHolder = sessionHolder.executeHolder(e.operationId) + val executeKey = + ExecuteKey(sessionHolder.userId, sessionHolder.sessionId, e.operationId) + executeHolder = SparkConnectService.executionManager.getExecuteHolder(executeKey) case _ => } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala index ed2f60afb0096..21f84291a2f07 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala @@ -413,4 +413,13 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession { 
planner.transformRelation(query, cachePlan = true) assertPlanCache(sessionHolder, Some(Set())) } + + test("Test duplicate operation IDs") { + val sessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark) + sessionHolder.addOperationId("DUMMY") + val ex = intercept[IllegalStateException] { + sessionHolder.addOperationId("DUMMY") + } + assert(ex.getMessage.contains("already exists")) + } } From 29ed2729492a7af3445b436cf589883e56dd9aee Mon Sep 17 00:00:00 2001 From: Changgyoo Park Date: Wed, 25 Sep 2024 08:58:33 +0900 Subject: [PATCH 070/250] [SPARK-49688][CONNECT] Fix a data race between interrupt and execute plan ### What changes were proposed in this pull request? Get rid of the complicated "promise"-based completion callback mechanism, and introduce a lock-free state machine. The gist is, - A thread can only be interrupted when it is in a certain state: started. - A successful interruption means the interrupted thread must call the completion callback. - Interruption after completion or before starting is prohibited without relying on a mutex. ### Why are the changes needed? Execution can be interrupted before started, thus causing the "closed" message to be omitted. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? SparkConnectServiceSuite. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48208 from changgyoopark-db/SPARK-49688. Authored-by: Changgyoo Park Signed-off-by: Hyukjin Kwon --- .../execution/ExecuteThreadRunner.scala | 224 ++++++++++-------- .../sql/connect/service/ExecuteHolder.scala | 35 +-- .../service/SparkConnectServiceE2ESuite.scala | 2 +- 3 files changed, 151 insertions(+), 110 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala index e75654e2c384f..61be2bc4eb994 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.connect.execution -import scala.concurrent.{ExecutionContext, Promise} +import java.util.concurrent.atomic.AtomicInteger + import scala.jdk.CollectionConverters._ -import scala.util.Try import scala.util.control.NonFatal import com.google.protobuf.Message @@ -32,7 +32,7 @@ import org.apache.spark.sql.connect.common.ProtoUtils import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, SparkConnectService} import org.apache.spark.sql.connect.utils.ErrorUtils -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.Utils /** * This class launches the actual execution in an execution thread. The execution pushes the @@ -40,68 +40,70 @@ import org.apache.spark.util.{ThreadUtils, Utils} */ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends Logging { - private val promise: Promise[Unit] = Promise[Unit]() + /** The thread state. */ + private val state: AtomicInteger = new AtomicInteger(ThreadState.notStarted) // The newly created thread will inherit all InheritableThreadLocals used by Spark, // e.g. SparkContext.localProperties. If considering implementing a thread-pool, // forwarding of thread locals needs to be taken into account. 
- private val executionThread: ExecutionThread = new ExecutionThread(promise) - - private var started: Boolean = false - - private var interrupted: Boolean = false - - private var completed: Boolean = false - - private val lock = new Object - - /** Launches the execution in a background thread, returns immediately. */ - private[connect] def start(): Unit = { - lock.synchronized { - assert(!started) - // Do not start if already interrupted. - if (!interrupted) { - executionThread.start() - started = true - } - } - } + private val executionThread: ExecutionThread = new ExecutionThread() /** - * Register a callback that gets executed after completion/interruption of the execution thread. + * Launches the execution in a background thread, returns immediately. This method is expected + * to be invoked only once for an ExecuteHolder. */ - private[connect] def processOnCompletion(callback: Try[Unit] => Unit): Unit = { - promise.future.onComplete(callback)(ExecuteThreadRunner.namedExecutionContext) + private[connect] def start(): Unit = { + val currentState = state.getAcquire() + if (currentState == ThreadState.notStarted) { + executionThread.start() + } else { + // This assertion does not hold if it is called more than once. + assert(currentState == ThreadState.interrupted) + } } /** - * Interrupt the executing thread. + * Interrupts the execution thread if the execution has been interrupted by this method call. + * * @return - * true if it was not interrupted before, false if it was already interrupted or completed. + * true if the thread is running and interrupted. */ private[connect] def interrupt(): Boolean = { - lock.synchronized { - if (!started && !interrupted) { - // execution thread hasn't started yet, and will not be started. - // handle the interrupted error here directly. - interrupted = true - ErrorUtils.handleError( - "execute", - executeHolder.responseObserver, - executeHolder.sessionHolder.userId, - executeHolder.sessionHolder.sessionId, - Some(executeHolder.eventsManager), - interrupted)(new SparkSQLException("OPERATION_CANCELED", Map.empty)) - true - } else if (!interrupted && !completed) { - // checking completed prevents sending interrupt onError after onCompleted - interrupted = true - executionThread.interrupt() - true + var currentState = state.getAcquire() + while (currentState == ThreadState.notStarted || currentState == ThreadState.started) { + val newState = if (currentState == ThreadState.notStarted) { + ThreadState.interrupted } else { - false + ThreadState.startedInterrupted + } + + val prevState = state.compareAndExchangeRelease(currentState, newState) + if (prevState == currentState) { + if (prevState == ThreadState.notStarted) { + // The execution thread has not been started, or will immediately return because the state + // transition happens at the beginning of executeInternal. + try { + ErrorUtils.handleError( + "execute", + executeHolder.responseObserver, + executeHolder.sessionHolder.userId, + executeHolder.sessionHolder.sessionId, + Some(executeHolder.eventsManager), + true)(new SparkSQLException("OPERATION_CANCELED", Map.empty)) + } finally { + executeHolder.cleanup() + } + } else { + // Interrupt execution. + executionThread.interrupt() + } + return true } + currentState = prevState } + + // Already interrupted, completed, or not started. 
+ false } private def execute(): Unit = { @@ -118,15 +120,8 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends executeHolder.sessionHolder.session.sparkContext.cancelJobsWithTag( executeHolder.jobTag, s"A job with the same tag ${executeHolder.jobTag} has failed.") - // Rely on an internal interrupted flag, because Thread.interrupted() could be cleared, - // and different exceptions like InterruptedException, ClosedByInterruptException etc. - // could be thrown. - if (interrupted) { - throw new SparkSQLException("OPERATION_CANCELED", Map.empty) - } else { - // Rethrown the original error. - throw e - } + // Rethrow the original error. + throw e } finally { executeHolder.sessionHolder.session.sparkContext.removeJobTag(executeHolder.jobTag) SparkConnectService.executionListener.foreach(_.removeJobTag(executeHolder.jobTag)) @@ -139,23 +134,50 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends } } } catch { - ErrorUtils.handleError( - "execute", - executeHolder.responseObserver, - executeHolder.sessionHolder.userId, - executeHolder.sessionHolder.sessionId, - Some(executeHolder.eventsManager), - interrupted) + case e: Throwable if state.getAcquire() != ThreadState.startedInterrupted => + ErrorUtils.handleError( + "execute", + executeHolder.responseObserver, + executeHolder.sessionHolder.userId, + executeHolder.sessionHolder.sessionId, + Some(executeHolder.eventsManager), + false)(e) + } finally { + // Make sure to transition to completed in order to prevent the thread from being interrupted + // afterwards. + var currentState = state.getAcquire() + while (currentState == ThreadState.started || + currentState == ThreadState.startedInterrupted) { + val interrupted = currentState == ThreadState.startedInterrupted + val prevState = state.compareAndExchangeRelease(currentState, ThreadState.completed) + if (prevState == currentState) { + if (interrupted) { + try { + ErrorUtils.handleError( + "execute", + executeHolder.responseObserver, + executeHolder.sessionHolder.userId, + executeHolder.sessionHolder.sessionId, + Some(executeHolder.eventsManager), + true)(new SparkSQLException("OPERATION_CANCELED", Map.empty)) + } finally { + executeHolder.cleanup() + } + } + return + } + currentState = prevState + } } } // Inner executeInternal is wrapped by execute() for error handling. - private def executeInternal() = { - // synchronized - check if already got interrupted while starting. - lock.synchronized { - if (interrupted) { - throw new InterruptedException() - } + private def executeInternal(): Unit = { + val prevState = state.compareAndExchangeRelease(ThreadState.notStarted, ThreadState.started) + if (prevState != ThreadState.notStarted) { + // Silently return, expecting that the caller would handle the interruption. + assert(prevState == ThreadState.interrupted) + return } // `withSession` ensures that session-specific artifacts (such as JARs and class files) are @@ -226,17 +248,14 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends observedMetrics ++ accumulatedInPython)) } - lock.synchronized { - // Synchronized before sending ResultComplete, and up until completing the result stream - // to prevent a situation in which a client of reattachable execution receives - // ResultComplete, and proceeds to send ReleaseExecute, and that triggers an interrupt - // before it finishes. 
- - if (interrupted) { - // check if it got interrupted at the very last moment - throw new InterruptedException() - } - completed = true // no longer interruptible + // State transition should be atomic to prevent a situation in which a client of reattachable + // execution receives ResultComplete, and proceeds to send ReleaseExecute, and that triggers + // an interrupt before it finishes. Failing to transition to completed means that the thread + // was interrupted, and that will be checked at the end of the execution. + if (state.compareAndExchangeRelease( + ThreadState.started, + ThreadState.completed) == ThreadState.started) { + // Now, the execution cannot be interrupted. // If the request starts a long running iterator (e.g. StreamingQueryListener needs // a long-running iterator to continuously stream back events, it runs in a separate @@ -311,21 +330,36 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends .build() } - private class ExecutionThread(onCompletionPromise: Promise[Unit]) + private class ExecutionThread() extends Thread(s"SparkConnectExecuteThread_opId=${executeHolder.operationId}") { - override def run(): Unit = { - try { - execute() - onCompletionPromise.success(()) - } catch { - case NonFatal(e) => - onCompletionPromise.failure(e) - } - } + override def run(): Unit = execute() } } -private[connect] object ExecuteThreadRunner { - private implicit val namedExecutionContext: ExecutionContext = ExecutionContext - .fromExecutor(ThreadUtils.newDaemonSingleThreadExecutor("SparkConnectExecuteThreadCallback")) +/** + * Defines possible execution thread states. + * + * The state transitions as follows. + * - notStarted -> interrupted. + * - notStarted -> started -> startedInterrupted -> completed. + * - notStarted -> started -> completed. + * + * The thread can only be interrupted if the thread is in the startedInterrupted state. + */ +private object ThreadState { + + /** The thread has not started: transition to interrupted or started. */ + val notStarted: Int = 0 + + /** Execution was interrupted: terminal state. */ + val interrupted: Int = 1 + + /** The thread has started: transition to startedInterrupted or completed. */ + val started: Int = 2 + + /** The thread has started and execution was interrupted: transition to completed. */ + val startedInterrupted: Int = 3 + + /** Execution was completed: terminal state. */ + val completed: Int = 4 } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala index dc349c3e33251..821ddb2c85d58 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.connect.service +import java.util.concurrent.atomic.AtomicBoolean + import scala.collection.mutable import scala.jdk.CollectionConverters._ @@ -104,8 +106,8 @@ private[connect] class ExecuteHolder( : mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]] = new mutable.ArrayBuffer[ExecuteGrpcResponseSender[proto.ExecutePlanResponse]]() - /** For testing. Whether the async completion callback is called. */ - @volatile private[connect] var completionCallbackCalled: Boolean = false + /** Indicates whether the cleanup method was called. 
*/ + private[connect] val completionCallbackCalled: AtomicBoolean = new AtomicBoolean(false) /** * Start the execution. The execution is started in a background thread in ExecuteThreadRunner. @@ -227,16 +229,7 @@ private[connect] class ExecuteHolder( def close(): Unit = synchronized { if (closedTimeMs.isEmpty) { // interrupt execution, if still running. - runner.interrupt() - // Do not wait for the execution to finish, clean up resources immediately. - runner.processOnCompletion { _ => - completionCallbackCalled = true - // The execution may not immediately get interrupted, clean up any remaining resources when - // it does. - responseObserver.removeAll() - // post closed to UI - eventsManager.postClosed() - } + val interrupted = runner.interrupt() // interrupt any attached grpcResponseSenders grpcResponseSenders.foreach(_.interrupt()) // if there were still any grpcResponseSenders, register detach time @@ -244,12 +237,26 @@ private[connect] class ExecuteHolder( lastAttachedRpcTimeMs = Some(System.currentTimeMillis()) grpcResponseSenders.clear() } - // remove all cached responses from observer - responseObserver.removeAll() + if (!interrupted) { + cleanup() + } closedTimeMs = Some(System.currentTimeMillis()) } } + /** + * A piece of code that is called only once when this execute holder is closed or the + * interrupted execution thread is terminated. + */ + private[connect] def cleanup(): Unit = { + if (completionCallbackCalled.compareAndSet(false, true)) { + // Remove all cached responses from the observer. + responseObserver.removeAll() + // Post "closed" to UI. + eventsManager.postClosed() + } + } + /** * Spark Connect tags are also added as SparkContext job tags, but to make the tag unique, they * need to be combined with userId and sessionId. diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala index cb0bd8f771ebc..f86298a8b5b98 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala @@ -109,7 +109,7 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest { } // Check the async execute cleanup get called Eventually.eventually(timeout(eventuallyTimeout)) { - assert(executeHolder1.completionCallbackCalled) + assert(executeHolder1.completionCallbackCalled.get()) } } } From 5fb0ff9e10b1df266732466790264fd63f159446 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 24 Sep 2024 21:22:42 -0400 Subject: [PATCH 071/250] [SPARK-49282][CONNECT][SQL] Create a shared SparkSessionBuilder interface ### What changes were proposed in this pull request? This PR adds a shared SparkSessionBuilder interface. It also adds a SparkSessionCompanion interface which is mean should be implemented by all SparkSession companions (a.k.a. `object SparkSession`. This is currently the entry point for session building, in the future we will also add the management of active/default sessions. Finally we add a companion for api.SparkSession. This will bind the implementation that is currently located in `org.apache.spark.sql`. This makes it possible to exclusively work with the interface, instead of selecting an implementation upfront. ### Why are the changes needed? We are creating a shared Scala SQL interface. Building a session is part of this interface. 
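As a rough illustration of what the shared builder enables (not part of this patch; the object name, app name, and config key below are made up for the example), application code can target only the `sql-api` interfaces and let `api.SparkSession.builder()` bind reflectively to whichever concrete companion, Classic or Connect, is on the classpath. The sketch mirrors the new `SparkSessionBuilderImplementationBindingSuite`:

```scala
import org.apache.spark.sql.api.{SparkSession, SparkSessionBuilder}
import org.apache.spark.sql.functions.sum

object SharedBuilderSketch {
  def main(args: Array[String]): Unit = {
    // Compiles against sql-api only; the concrete SparkSession implementation is
    // resolved at runtime through the reflective companion lookup added below.
    val builder: SparkSessionBuilder = SparkSession.builder()
      .appName("shared-builder-sketch")           // sets spark.app.name in Classic, a no-op in Connect
      .config("spark.sql.shuffle.partitions", 4L) // runtime conf, accepted by both bindings
    val session: SparkSession = builder.getOrCreate()
    try {
      import session.implicits._
      // Same application code regardless of which implementation the companion bound to.
      assert(session.range(10).agg(sum("id")).as[Long].head() == 45L)
    } finally {
      session.stop()
    }
  }
}
```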
### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. I have added tests for the implementation binding. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48229 from hvanhovell/SPARK-49282. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- connector/connect/client/jvm/pom.xml | 7 + .../org/apache/spark/sql/SparkSession.scala | 96 +++------- ...ionBuilderImplementationBindingSuite.scala | 33 ++++ project/MimaExcludes.scala | 46 +++-- .../apache/spark/sql/api/SparkSession.scala | 171 ++++++++++++++++- ...ionBuilderImplementationBindingSuite.scala | 38 ++++ sql/core/pom.xml | 7 + .../org/apache/spark/sql/SparkSession.scala | 179 ++++++------------ ...ionBuilderImplementationBindingSuite.scala | 26 +++ 9 files changed, 388 insertions(+), 215 deletions(-) create mode 100644 connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala create mode 100644 sql/api/src/test/scala/org/apache/spark/sql/api/SparkSessionBuilderImplementationBindingSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml index be358f317481e..e117a0a7451cb 100644 --- a/connector/connect/client/jvm/pom.xml +++ b/connector/connect/client/jvm/pom.xml @@ -88,6 +88,13 @@ scalacheck_${scala.binary.version} test + + org.apache.spark + spark-sql-api_${scala.binary.version} + ${project.version} + tests + test + org.apache.spark spark-common-utils_${scala.binary.version} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 0663f0186888e..5313369a2c987 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -509,7 +509,7 @@ class SparkSession private[sql] ( // The minimal builder needed to create a spark session. // TODO: implements all methods mentioned in the scaladoc of [[SparkSession]] -object SparkSession extends Logging { +object SparkSession extends api.SparkSessionCompanion with Logging { private val MAX_CACHED_SESSIONS = 100 private val planIdGenerator = new AtomicLong private var server: Option[Process] = None @@ -618,15 +618,15 @@ object SparkSession extends Logging { */ def builder(): Builder = new Builder() - class Builder() extends Logging { + class Builder() extends api.SparkSessionBuilder { // Initialize the connection string of the Spark Connect client builder from SPARK_REMOTE // by default, if it exists. The connection string can be overridden using // the remote() function, as it takes precedence over the SPARK_REMOTE environment variable. 
private val builder = SparkConnectClient.builder().loadFromEnvironment() private var client: SparkConnectClient = _ - private[this] val options = new scala.collection.mutable.HashMap[String, String] - def remote(connectionString: String): Builder = { + /** @inheritdoc */ + def remote(connectionString: String): this.type = { builder.connectionString(connectionString) this } @@ -638,93 +638,45 @@ object SparkSession extends Logging { * * @since 3.5.0 */ - def interceptor(interceptor: ClientInterceptor): Builder = { + def interceptor(interceptor: ClientInterceptor): this.type = { builder.interceptor(interceptor) this } - private[sql] def client(client: SparkConnectClient): Builder = { + private[sql] def client(client: SparkConnectClient): this.type = { this.client = client this } - /** - * Sets a config option. Options set using this method are automatically propagated to the - * Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(key: String, value: String): Builder = synchronized { - options += key -> value - this - } + /** @inheritdoc */ + override def config(key: String, value: String): this.type = super.config(key, value) - /** - * Sets a config option. Options set using this method are automatically propagated to the - * Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(key: String, value: Long): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: Long): this.type = super.config(key, value) - /** - * Sets a config option. Options set using this method are automatically propagated to the - * Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(key: String, value: Double): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: Double): this.type = super.config(key, value) - /** - * Sets a config option. Options set using this method are automatically propagated to the - * Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(key: String, value: Boolean): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: Boolean): this.type = super.config(key, value) - /** - * Sets a config a map of options. Options set using this method are automatically propagated - * to the Spark Connect session. Only runtime options are supported. - * - * @since 3.5.0 - */ - def config(map: Map[String, Any]): Builder = synchronized { - map.foreach { kv: (String, Any) => - { - options += kv._1 -> kv._2.toString - } - } - this - } + /** @inheritdoc */ + override def config(map: Map[String, Any]): this.type = super.config(map) - /** - * Sets a config option. Options set using this method are automatically propagated to both - * `SparkConf` and SparkSession's own configuration. 
- * - * @since 3.5.0 - */ - def config(map: java.util.Map[String, Any]): Builder = synchronized { - config(map.asScala.toMap) - } + /** @inheritdoc */ + override def config(map: java.util.Map[String, Any]): this.type = super.config(map) + /** @inheritdoc */ @deprecated("enableHiveSupport does not work in Spark Connect") - def enableHiveSupport(): Builder = this + override def enableHiveSupport(): this.type = this + /** @inheritdoc */ @deprecated("master does not work in Spark Connect, please use remote instead") - def master(master: String): Builder = this + override def master(master: String): this.type = this + /** @inheritdoc */ @deprecated("appName does not work in Spark Connect") - def appName(name: String): Builder = this + override def appName(name: String): this.type = this private def tryCreateSessionFromClient(): Option[SparkSession] = { if (client != null && client.isSessionValid) { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala new file mode 100644 index 0000000000000..ed930882ac2fd --- /dev/null +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.sql.api.SparkSessionBuilder +import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession} + +/** + * Make sure the api.SparkSessionBuilder binds to Connect implementation. + */ +class SparkSessionBuilderImplementationBindingSuite + extends ConnectFunSuite + with api.SparkSessionBuilderImplementationBindingSuite + with RemoteSparkSession { + override protected def configure(builder: SparkSessionBuilder): builder.type = { + // We need to set this configuration because the port used by the server is random. + builder.remote(s"sc://localhost:$serverPort") + } +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 972438d0757a7..9a89ebb4797c9 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -125,26 +125,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.Observation$"), - // SPARK-49414: Remove Logging from DataFrameReader. 
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.DataFrameReader"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logName"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.log"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logInfo"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logDebug"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logTrace"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logWarning"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logError"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logInfo"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logDebug"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logTrace"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logWarning"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.logError"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.isTraceEnabled"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeLogIfNecessary"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeLogIfNecessary"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeLogIfNecessary$default$2"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.DataFrameReader.initializeForcefully"), - // SPARK-49425: Create a shared DataFrameWriter interface. 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameWriter"), @@ -195,7 +175,11 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLImplicits$StringToColumn"), ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SparkSession$implicits$"), ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.sql.SQLImplicits.session"), - ) + + // SPARK-49282: Shared SparkSessionBuilder + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SparkSession$Builder"), + ) ++ loggingExcludes("org.apache.spark.sql.DataFrameReader") ++ + loggingExcludes("org.apache.spark.sql.SparkSession#Builder") // Default exclude rules lazy val defaultExcludes = Seq( @@ -236,6 +220,26 @@ object MimaExcludes { } ) + private def loggingExcludes(fqn: String) = { + Seq( + ProblemFilters.exclude[MissingTypesProblem](fqn), + missingMethod(fqn, "logName"), + missingMethod(fqn, "log"), + missingMethod(fqn, "logInfo"), + missingMethod(fqn, "logDebug"), + missingMethod(fqn, "logTrace"), + missingMethod(fqn, "logWarning"), + missingMethod(fqn, "logError"), + missingMethod(fqn, "isTraceEnabled"), + missingMethod(fqn, "initializeLogIfNecessary"), + missingMethod(fqn, "initializeLogIfNecessary$default$2"), + missingMethod(fqn, "initializeForcefully")) + } + + private def missingMethod(names: String*) = { + ProblemFilters.exclude[DirectMissingMethodProblem](names.mkString(".")) + } + def excludes(version: String): Seq[Problem => Boolean] = version match { case v if v.startsWith("4.0") => v40excludes case _ => Seq() diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala index 2623db4060ee6..2295c153cd51c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala @@ -25,9 +25,10 @@ import _root_.java.lang import _root_.java.net.URI import _root_.java.util -import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable} import org.apache.spark.sql.{Encoder, Row, RuntimeConfig} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SparkClassUtils /** * The entry point to programming Spark with the Dataset and DataFrame API. @@ -541,3 +542,171 @@ abstract class SparkSession extends Serializable with Closeable { */ def stop(): Unit = close() } + +object SparkSession extends SparkSessionCompanion { + private[this] val companion: SparkSessionCompanion = { + val cls = SparkClassUtils.classForName("org.apache.spark.sql.SparkSession") + val mirror = scala.reflect.runtime.currentMirror + val module = mirror.classSymbol(cls).companion.asModule + mirror.reflectModule(module).instance.asInstanceOf[SparkSessionCompanion] + } + + /** @inheritdoc */ + override def builder(): SparkSessionBuilder = companion.builder() +} + +/** + * Companion of a [[SparkSession]]. + */ +private[sql] abstract class SparkSessionCompanion { + + /** + * Creates a [[SparkSessionBuilder]] for constructing a [[SparkSession]]. + * + * @since 2.0.0 + */ + def builder(): SparkSessionBuilder +} + +/** + * Builder for [[SparkSession]]. + */ +@Stable +abstract class SparkSessionBuilder { + protected val options = new scala.collection.mutable.HashMap[String, String] + + /** + * Sets a name for the application, which will be shown in the Spark web UI. 
If no application + * name is set, a randomly generated name will be used. + * + * @since 2.0.0 + */ + def appName(name: String): this.type = config("spark.app.name", name) + + /** + * Sets the Spark master URL to connect to, such as "local" to run locally, "local[4]" to run + * locally with 4 cores, or "spark://master:7077" to run on a Spark standalone cluster. + * + * @note + * this is only supported in Classic. + * @since 2.0.0 + */ + def master(master: String): this.type = config("spark.master", master) + + /** + * Enables Hive support, including connectivity to a persistent Hive metastore, support for Hive + * serdes, and Hive user-defined functions. + * + * @note + * this is only supported in Classic. + * @since 2.0.0 + */ + def enableHiveSupport(): this.type = config("spark.sql.catalogImplementation", "hive") + + /** + * Sets the Spark Connect remote URL. + * + * @note + * this is only supported in Connect. + * @since 3.5.0 + */ + def remote(connectionString: String): this.type + + /** + * Sets a config option. Options set using this method are automatically propagated to both + * `SparkConf` and SparkSession's own configuration. + * + * @note + * this is only supported in Connect mode. + * @since 2.0.0 + */ + def config(key: String, value: String): this.type = synchronized { + options += key -> value + this + } + + /** + * Sets a config option. Options set using this method are automatically propagated to both + * `SparkConf` and SparkSession's own configuration. + * + * @since 2.0.0 + */ + def config(key: String, value: Long): this.type = synchronized { + options += key -> value.toString + this + } + + /** + * Sets a config option. Options set using this method are automatically propagated to both + * `SparkConf` and SparkSession's own configuration. + * + * @since 2.0.0 + */ + def config(key: String, value: Double): this.type = synchronized { + options += key -> value.toString + this + } + + /** + * Sets a config option. Options set using this method are automatically propagated to both + * `SparkConf` and SparkSession's own configuration. + * + * @since 2.0.0 + */ + def config(key: String, value: Boolean): this.type = synchronized { + options += key -> value.toString + this + } + + /** + * Sets a config option. Options set using this method are automatically propagated to both + * `SparkConf` and SparkSession's own configuration. + * + * @since 3.4.0 + */ + def config(map: Map[String, Any]): this.type = synchronized { + map.foreach { kv: (String, Any) => + { + options += kv._1 -> kv._2.toString + } + } + this + } + + /** + * Sets a config option. Options set using this method are automatically propagated to both + * `SparkConf` and SparkSession's own configuration. + * + * @since 3.4.0 + */ + def config(map: util.Map[String, Any]): this.type = synchronized { + config(map.asScala.toMap) + } + + /** + * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new one based on + * the options set in this builder. + * + * This method first checks whether there is a valid thread-local SparkSession, and if yes, + * return that one. It then checks whether there is a valid global default SparkSession, and if + * yes, return that one. If no valid global default SparkSession exists, the method creates a + * new SparkSession and assigns the newly created SparkSession as the global default. + * + * In case an existing SparkSession is returned, the non-static config options specified in this + * builder will be applied to the existing SparkSession. 
+ * + * @since 2.0.0 + */ + def getOrCreate(): SparkSession + + /** + * Create a new [[SparkSession]]. + * + * This will always return a newly created session. + * + * This method will update the default and/or active session if they are not set. + * + * @since 3.5.0 + */ + def create(): SparkSession +} diff --git a/sql/api/src/test/scala/org/apache/spark/sql/api/SparkSessionBuilderImplementationBindingSuite.scala b/sql/api/src/test/scala/org/apache/spark/sql/api/SparkSessionBuilderImplementationBindingSuite.scala new file mode 100644 index 0000000000000..84b6b85f639a3 --- /dev/null +++ b/sql/api/src/test/scala/org/apache/spark/sql/api/SparkSessionBuilderImplementationBindingSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.api + +// scalastyle:off funsuite +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.spark.sql.functions.sum + +/** + * Test suite for SparkSession implementation binding. + */ +trait SparkSessionBuilderImplementationBindingSuite extends AnyFunSuite with BeforeAndAfterAll { +// scalastyle:on + protected def configure(builder: SparkSessionBuilder): builder.type = builder + + test("range") { + val session = configure(SparkSession.builder()).getOrCreate() + import session.implicits._ + val df = session.range(10).agg(sum("id")).as[Long] + assert(df.head() == 45) + } +} diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 9eb5decb3b515..4352c44a4feda 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -73,6 +73,13 @@ test-jar test + + org.apache.spark + spark-sql-api_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-tags_${scala.binary.version} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 938df206b9792..fe139d629eb24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -853,129 +853,64 @@ class SparkSession private( @Stable -object SparkSession extends Logging { +object SparkSession extends api.SparkSessionCompanion with Logging { /** * Builder for [[SparkSession]]. 
*/ @Stable - class Builder extends Logging { - - private[this] val options = new scala.collection.mutable.HashMap[String, String] + class Builder extends api.SparkSessionBuilder { private[this] val extensions = new SparkSessionExtensions private[this] var userSuppliedContext: Option[SparkContext] = None - private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized { + private[spark] def sparkContext(sparkContext: SparkContext): this.type = synchronized { userSuppliedContext = Option(sparkContext) this } - /** - * Sets a name for the application, which will be shown in the Spark web UI. - * If no application name is set, a randomly generated name will be used. - * - * @since 2.0.0 - */ - def appName(name: String): Builder = config("spark.app.name", name) + /** @inheritdoc */ + override def remote(connectionString: String): this.type = this - /** - * Sets a config option. Options set using this method are automatically propagated to - * both `SparkConf` and SparkSession's own configuration. - * - * @since 2.0.0 - */ - def config(key: String, value: String): Builder = synchronized { - options += key -> value - this - } + /** @inheritdoc */ + override def appName(name: String): this.type = super.appName(name) - /** - * Sets a config option. Options set using this method are automatically propagated to - * both `SparkConf` and SparkSession's own configuration. - * - * @since 2.0.0 - */ - def config(key: String, value: Long): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: String): this.type = super.config(key, value) - /** - * Sets a config option. Options set using this method are automatically propagated to - * both `SparkConf` and SparkSession's own configuration. - * - * @since 2.0.0 - */ - def config(key: String, value: Double): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: Long): this.type = super.config(key, value) - /** - * Sets a config option. Options set using this method are automatically propagated to - * both `SparkConf` and SparkSession's own configuration. - * - * @since 2.0.0 - */ - def config(key: String, value: Boolean): Builder = synchronized { - options += key -> value.toString - this - } + /** @inheritdoc */ + override def config(key: String, value: Double): this.type = super.config(key, value) - /** - * Sets a config option. Options set using this method are automatically propagated to - * both `SparkConf` and SparkSession's own configuration. - * - * @since 3.4.0 - */ - def config(map: Map[String, Any]): Builder = synchronized { - map.foreach { - kv: (String, Any) => { - options += kv._1 -> kv._2.toString - } - } - this - } + /** @inheritdoc */ + override def config(key: String, value: Boolean): this.type = super.config(key, value) - /** - * Sets a config option. Options set using this method are automatically propagated to - * both `SparkConf` and SparkSession's own configuration. - * - * @since 3.4.0 - */ - def config(map: java.util.Map[String, Any]): Builder = synchronized { - config(map.asScala.toMap) - } + /** @inheritdoc */ + override def config(map: Map[String, Any]): this.type = super.config(map) + + /** @inheritdoc */ + override def config(map: java.util.Map[String, Any]): this.type = super.config(map) /** * Sets a list of config options based on the given `SparkConf`. 
* * @since 2.0.0 */ - def config(conf: SparkConf): Builder = synchronized { + def config(conf: SparkConf): this.type = synchronized { conf.getAll.foreach { case (k, v) => options += k -> v } this } - /** - * Sets the Spark master URL to connect to, such as "local" to run locally, "local[4]" to - * run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone cluster. - * - * @since 2.0.0 - */ - def master(master: String): Builder = config("spark.master", master) + /** @inheritdoc */ + override def master(master: String): this.type = super.master(master) - /** - * Enables Hive support, including connectivity to a persistent Hive metastore, support for - * Hive serdes, and Hive user-defined functions. - * - * @since 2.0.0 - */ - def enableHiveSupport(): Builder = synchronized { + /** @inheritdoc */ + override def enableHiveSupport(): this.type = synchronized { if (hiveClassesArePresent) { - config(CATALOG_IMPLEMENTATION.key, "hive") + super.enableHiveSupport() } else { throw new IllegalArgumentException( "Unable to instantiate SparkSession with Hive support because " + @@ -989,27 +924,12 @@ object SparkSession extends Logging { * * @since 2.2.0 */ - def withExtensions(f: SparkSessionExtensions => Unit): Builder = synchronized { + def withExtensions(f: SparkSessionExtensions => Unit): this.type = synchronized { f(extensions) this } - /** - * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new - * one based on the options set in this builder. - * - * This method first checks whether there is a valid thread-local SparkSession, - * and if yes, return that one. It then checks whether there is a valid global - * default SparkSession, and if yes, return that one. If no valid global default - * SparkSession exists, the method creates a new SparkSession and assigns the - * newly created SparkSession as the global default. - * - * In case an existing SparkSession is returned, the non-static config options specified in - * this builder will be applied to the existing SparkSession. - * - * @since 2.0.0 - */ - def getOrCreate(): SparkSession = synchronized { + private def build(forceCreate: Boolean): SparkSession = synchronized { val sparkConf = new SparkConf() options.foreach { case (k, v) => sparkConf.set(k, v) } @@ -1017,20 +937,28 @@ object SparkSession extends Logging { assertOnDriver() } + def clearSessionIfDead(session: SparkSession): SparkSession = { + if ((session ne null) && !session.sparkContext.isStopped) { + session + } else { + null + } + } + // Get the session from current thread's active session. - var session = activeThreadSession.get() - if ((session ne null) && !session.sparkContext.isStopped) { - applyModifiableSettings(session, new java.util.HashMap[String, String](options.asJava)) - return session + val active = clearSessionIfDead(activeThreadSession.get()) + if (!forceCreate && (active ne null)) { + applyModifiableSettings(active, new java.util.HashMap[String, String](options.asJava)) + return active } // Global synchronization so we will only set the default session once. SparkSession.synchronized { // If the current thread does not have an active session, get it from the global session. 
- session = defaultSession.get() - if ((session ne null) && !session.sparkContext.isStopped) { - applyModifiableSettings(session, new java.util.HashMap[String, String](options.asJava)) - return session + val default = clearSessionIfDead(defaultSession.get()) + if (!forceCreate && (default ne null)) { + applyModifiableSettings(default, new java.util.HashMap[String, String](options.asJava)) + return default } // No active nor global default session. Create a new one. @@ -1047,19 +975,28 @@ object SparkSession extends Logging { loadExtensions(extensions) applyExtensions(sparkContext, extensions) - session = new SparkSession(sparkContext, + val session = new SparkSession(sparkContext, existingSharedState = None, parentSessionState = None, extensions, initialSessionOptions = options.toMap, parentManagedJobTags = Map.empty) - setDefaultSession(session) - setActiveSession(session) + if (default eq null) { + setDefaultSession(session) + } + if (active eq null) { + setActiveSession(session) + } registerContextListener(sparkContext) + session } - - return session } + + /** @inheritdoc */ + def getOrCreate(): SparkSession = build(forceCreate = false) + + /** @inheritdoc */ + def create(): SparkSession = build(forceCreate = true) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala new file mode 100644 index 0000000000000..c4fd16ca5ce59 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBuilderImplementationBindingSuite.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSparkSession + +/** + * Make sure the api.SparkSessionBuilder binds to Classic implementation. + */ +class SparkSessionBuilderImplementationBindingSuite + extends SharedSparkSession + with api.SparkSessionBuilderImplementationBindingSuite From 0c234bb1a68c8f419471182d394145c9d48fb3a5 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 24 Sep 2024 21:24:38 -0400 Subject: [PATCH 072/250] [SPARK-49369][CONNECT][SQL] Add implicit Column conversions ### What changes were proposed in this pull request? This introduces an implicit conversion for the Column companion object that allows a user/developer to create a Column from a catalyst Expression (for Classic) or a proto Expression (Builder) (for Connect). This mostly recreates they had before we refactored the Column API. This comes at the price of adding the an import. ### Why are the changes needed? Improved upgrade experience for Developers and User who create their own Column's from expressions. ### Does this PR introduce _any_ user-facing change? No. 
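For reference, the intended usage looks roughly like this (a sketch, not code taken from this patch; it assumes the new `ClassicConversions` import added below is in scope and uses catalyst's `Literal` only as an example expression — the Connect client gets an equivalent `ConnectConversions` import for proto expressions):

```
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.classic.ClassicConversions._

// With ColumnConstructorExt in scope, Column(...) accepts a catalyst
// Expression again, as it did before the Column API refactor.
val one: Column = Column(Literal(1))
```
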
### How was this patch tested? I added it to a couple of places in the code and it works. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48020 from hvanhovell/SPARK-49369. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../sql/connect/ConnectConversions.scala | 40 ++++++++++++++++++- .../scala/org/apache/spark/sql/package.scala | 27 ------------- .../spark/sql/PlanGenerationTestSuite.scala | 1 + .../spark/sql/DataFrameNaFunctions.scala | 17 ++++---- .../scala/org/apache/spark/sql/Dataset.scala | 29 +++++++++----- .../sql/classic/ClassicConversions.scala | 11 ++++- .../sql/internal/RuntimeConfigImpl.scala | 2 +- 7 files changed, 76 insertions(+), 51 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala index 7d81f4ead7857..0344152be86e6 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.connect import scala.language.implicitConversions import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.connect.proto import org.apache.spark.sql._ +import org.apache.spark.sql.internal.ProtoColumnNode /** * Conversions from sql interfaces to the Connect specific implementation. * - * This class is mainly used by the implementation. In the case of connect it should be extremely - * rare that a developer needs these classes. + * This class is mainly used by the implementation. It is also meant to be used by extension + * developers. * * We provide both a trait and an object. The trait is useful in situations where an extension * developer needs to use these conversions in a project covering multiple Spark versions. They @@ -46,6 +48,40 @@ trait ConnectConversions { implicit def castToImpl[K, V]( kvds: api.KeyValueGroupedDataset[K, V]): KeyValueGroupedDataset[K, V] = kvds.asInstanceOf[KeyValueGroupedDataset[K, V]] + + /** + * Create a [[Column]] from a [[proto.Expression]] + * + * This method is meant to be used by Connect plugins. We do not guarantee any compatibility + * between (minor) versions. + */ + @DeveloperApi + def column(expr: proto.Expression): Column = { + Column(ProtoColumnNode(expr)) + } + + /** + * Create a [[Column]] using a function that manipulates an [[proto.Expression.Builder]]. + * + * This method is meant to be used by Connect plugins. We do not guarantee any compatibility + * between (minor) versions. + */ + @DeveloperApi + def column(f: proto.Expression.Builder => Unit): Column = { + val builder = proto.Expression.newBuilder() + f(builder) + column(builder.build()) + } + + /** + * Implicit helper that makes it easy to construct a Column from an Expression or an Expression + * builder. This allows developers to create a Column in the same way as in earlier versions of + * Spark (before 4.0). 
+ */ + @DeveloperApi + implicit class ColumnConstructorExt(val c: Column.type) { + def apply(e: proto.Expression): Column = column(e) + } } object ConnectConversions extends ConnectConversions diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala index 154f2b0405fcd..556b472283a37 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala @@ -17,10 +17,7 @@ package org.apache.spark -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder -import org.apache.spark.sql.internal.ProtoColumnNode package object sql { type DataFrame = Dataset[Row] @@ -28,28 +25,4 @@ package object sql { private[sql] def encoderFor[E: Encoder]: AgnosticEncoder[E] = { implicitly[Encoder[E]].asInstanceOf[AgnosticEncoder[E]] } - - /** - * Create a [[Column]] from a [[proto.Expression]] - * - * This method is meant to be used by Connect plugins. We do not guarantee any compatility - * between (minor) versions. - */ - @DeveloperApi - def column(expr: proto.Expression): Column = { - Column(ProtoColumnNode(expr)) - } - - /** - * Creat a [[Column]] using a function that manipulates an [[proto.Expression.Builder]]. - * - * This method is meant to be used by Connect plugins. We do not guarantee any compatility - * between (minor) versions. - */ - @DeveloperApi - def column(f: proto.Expression.Builder => Unit): Column = { - val builder = proto.Expression.newBuilder() - f(builder) - column(builder.build()) - } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 315f80e13eff7..c557b54732797 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.avro.{functions => avroFn} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index b356751083fc1..53e12f58edd69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.functions._ -import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.types._ /** @@ -122,7 +121,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) (attr.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType))) { replaceCol(attr, replacementMap) } else { - column(attr) + 
Column(attr) } } df.select(projections : _*) @@ -131,7 +130,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) protected def fillMap(values: Seq[(String, Any)]): DataFrame = { // Error handling val attrToValue = AttributeMap(values.map { case (colName, replaceValue) => - // Check column name exists + // Check Column name exists val attr = df.resolve(colName) match { case a: Attribute => a case _ => throw QueryExecutionErrors.nestedFieldUnsupportedError(colName) @@ -155,7 +154,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) case v: jl.Integer => fillCol[Integer](attr, v) case v: jl.Boolean => fillCol[Boolean](attr, v.booleanValue()) case v: String => fillCol[String](attr, v) - }.getOrElse(column(attr)) + }.getOrElse(Column(attr)) } df.select(projections : _*) } @@ -165,7 +164,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) * with `replacement`. */ private def fillCol[T](attr: Attribute, replacement: T): Column = { - fillCol(attr.dataType, attr.name, column(attr), replacement) + fillCol(attr.dataType, attr.name, Column(attr), replacement) } /** @@ -192,7 +191,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) val branches = replacementMap.flatMap { case (source, target) => Seq(Literal(source), buildExpr(target)) }.toSeq - column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name) + Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name) } private def convertToDouble(v: Any): Double = v match { @@ -219,7 +218,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) // Filtering condition: // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. val predicate = AtLeastNNonNulls(minNonNulls.getOrElse(cols.size), cols) - df.filter(column(predicate)) + df.filter(Column(predicate)) } private[sql] def fillValue(value: Any, cols: Option[Seq[String]]): DataFrame = { @@ -255,9 +254,9 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) } // Only fill if the column is part of the cols list. 
if (typeMatches && cols.exists(_.semanticEquals(col))) { - fillCol(col.dataType, col.name, column(col), value) + fillCol(col.dataType, col.name, Column(col), value) } else { - column(col) + Column(col) } } df.select(projections : _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 80ec70a7864c3..18fc5787a1583 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -63,7 +63,6 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, Data import org.apache.spark.sql.execution.python.EvaluatePython import org.apache.spark.sql.execution.stat.StatFunctions import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf} -import org.apache.spark.sql.internal.ExpressionUtils.column import org.apache.spark.sql.internal.TypedAggUtils.withInputType import org.apache.spark.sql.streaming.DataStreamWriter import org.apache.spark.sql.types._ @@ -303,7 +302,7 @@ class Dataset[T] private[sql]( truncate: Int): Seq[Seq[String]] = { val newDf = commandResultOptimized.toDF() val castCols = newDf.logicalPlan.output.map { col => - column(ToPrettyString(col)) + Column(ToPrettyString(col)) } val data = newDf.select(castCols: _*).take(numRows + 1) @@ -505,7 +504,7 @@ class Dataset[T] private[sql]( s"New column names (${colNames.size}): " + colNames.mkString(", ")) val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) => - column(oldAttribute).as(newName) + Column(oldAttribute).as(newName) } select(newCols : _*) } @@ -760,18 +759,18 @@ class Dataset[T] private[sql]( /** @inheritdoc */ def col(colName: String): Column = colName match { case "*" => - column(ResolvedStar(queryExecution.analyzed.output)) + Column(ResolvedStar(queryExecution.analyzed.output)) case _ => if (sparkSession.sessionState.conf.supportQuotedRegexColumnName) { colRegex(colName) } else { - column(addDataFrameIdToCol(resolve(colName))) + Column(addDataFrameIdToCol(resolve(colName))) } } /** @inheritdoc */ def metadataColumn(colName: String): Column = - column(queryExecution.analyzed.getMetadataAttributeByName(colName)) + Column(queryExecution.analyzed.getMetadataAttributeByName(colName)) // Attach the dataset id and column position to the column reference, so that we can detect // ambiguous self-join correctly. See the rule `DetectAmbiguousSelfJoin`. 
@@ -797,11 +796,11 @@ class Dataset[T] private[sql]( val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis colName match { case ParserUtils.escapedIdentifier(columnNameRegex) => - column(UnresolvedRegex(columnNameRegex, None, caseSensitive)) + Column(UnresolvedRegex(columnNameRegex, None, caseSensitive)) case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) => - column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive)) + Column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive)) case _ => - column(addDataFrameIdToCol(resolve(colName))) + Column(addDataFrameIdToCol(resolve(colName))) } } @@ -1194,7 +1193,7 @@ class Dataset[T] private[sql]( resolver(field.name, colName) } match { case Some((colName: String, col: Column)) => col.as(colName) - case _ => column(field) + case _ => Column(field) } } @@ -1264,7 +1263,7 @@ class Dataset[T] private[sql]( val allColumns = queryExecution.analyzed.output val remainingCols = allColumns.filter { attribute => colNames.forall(n => !resolver(attribute.name, n)) - }.map(attribute => column(attribute)) + }.map(attribute => Column(attribute)) if (remainingCols.size == allColumns.size) { toDF() } else { @@ -1975,6 +1974,14 @@ class Dataset[T] private[sql]( // For Python API //////////////////////////////////////////////////////////////////////////// + /** + * It adds a new long column with the name `name` that increases one by one. + * This is for 'distributed-sequence' default index in pandas API on Spark. + */ + private[sql] def withSequenceColumn(name: String) = { + select(Column(DistributedSequenceID()).alias(name), col("*")) + } + /** * Converts a JavaRDD to a PythonRDD. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala index af91b57a6848b..8c3223fa72f55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala @@ -20,11 +20,13 @@ import scala.language.implicitConversions import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.internal.ExpressionUtils /** * Conversions from sql interfaces to the Classic specific implementation. * - * This class is mainly used by the implementation, but is also meant to be used by extension + * This class is mainly used by the implementation. It is also meant to be used by extension * developers. * * We provide both a trait and an object. The trait is useful in situations where an extension @@ -45,6 +47,13 @@ trait ClassicConversions { implicit def castToImpl[K, V](kvds: api.KeyValueGroupedDataset[K, V]) : KeyValueGroupedDataset[K, V] = kvds.asInstanceOf[KeyValueGroupedDataset[K, V]] + + /** + * Helper that makes it easy to construct a Column from an Expression. 
+ */ + implicit class ColumnConstructorExt(val c: Column.type) { + def apply(e: Expression): Column = ExpressionUtils.column(e) + } } object ClassicConversions extends ClassicConversions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala index ca439cdb89958..f25ca387db299 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala @@ -84,7 +84,7 @@ class RuntimeConfigImpl private[sql](val sqlConf: SQLConf = new SQLConf) extends sqlConf.contains(key) } - private def requireNonStaticConf(key: String): Unit = { + private[sql] def requireNonStaticConf(key: String): Unit = { if (SQLConf.isStaticConfigKey(key)) { throw QueryCompilationErrors.cannotModifyValueOfStaticConfigError(key) } From 828b1f94734af8a629e80b1ec2d7f25326c69411 Mon Sep 17 00:00:00 2001 From: bogao007 Date: Wed, 25 Sep 2024 11:05:15 +0900 Subject: [PATCH 073/250] [SPARK-49463] Support ListState for TransformWithStateInPandas ### What changes were proposed in this pull request? Support ListState for TransformWithStateInPandas ### Why are the changes needed? Adding new functionality for TransformWithStateInPandas ### Does this PR introduce _any_ user-facing change? Yes ### How was this patch tested? Added new unit tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #47933 from bogao007/list-state. Authored-by: bogao007 Signed-off-by: Jungtaek Lim --- python/pyspark/sql/pandas/types.py | 36 + .../pyspark/sql/streaming/StateMessage_pb2.py | 71 +- .../sql/streaming/StateMessage_pb2.pyi | 313 +- .../sql/streaming/list_state_client.py | 187 + .../sql/streaming/stateful_processor.py | 87 +- .../stateful_processor_api_client.py | 45 +- .../sql/streaming/value_state_client.py | 8 +- .../test_pandas_transform_with_state.py | 63 + .../apache/spark/sql/internal/SQLConf.scala | 11 + .../execution/streaming/StateMessage.proto | 27 + .../streaming/state/StateMessage.java | 4942 +++++++++++++++-- ...ansformWithStateInPandasDeserializer.scala | 60 + ...ansformWithStateInPandasPythonRunner.scala | 3 +- ...ransformWithStateInPandasStateServer.scala | 205 +- ...ormWithStateInPandasStateServerSuite.scala | 157 +- 15 files changed, 5570 insertions(+), 645 deletions(-) create mode 100644 python/pyspark/sql/streaming/list_state_client.py create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasDeserializer.scala diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py index 53c72304adfaa..57e46901013fe 100644 --- a/python/pyspark/sql/pandas/types.py +++ b/python/pyspark/sql/pandas/types.py @@ -53,12 +53,17 @@ ) from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError from pyspark.loose_version import LooseVersion +from pyspark.sql.utils import has_numpy + +if has_numpy: + import numpy as np if TYPE_CHECKING: import pandas as pd import pyarrow as pa from pyspark.sql.pandas._typing import SeriesLike as PandasSeriesLike + from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike def to_arrow_type( @@ -1344,3 +1349,34 @@ def _deduplicate_field_names(dt: DataType) -> DataType: ) else: return dt + + +def _to_numpy_type(type: DataType) -> Optional["np.dtype"]: + """Convert Spark data type to NumPy type.""" + import numpy as np + + if type == ByteType(): + return 
np.dtype("int8") + elif type == ShortType(): + return np.dtype("int16") + elif type == IntegerType(): + return np.dtype("int32") + elif type == LongType(): + return np.dtype("int64") + elif type == FloatType(): + return np.dtype("float32") + elif type == DoubleType(): + return np.dtype("float64") + return None + + +def convert_pandas_using_numpy_type( + df: "PandasDataFrameLike", schema: StructType +) -> "PandasDataFrameLike": + for field in schema.fields: + if isinstance( + field.dataType, (ByteType, ShortType, LongType, FloatType, DoubleType, IntegerType) + ): + np_type = _to_numpy_type(field.dataType) + df[field.name] = df[field.name].astype(np_type) + return df diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.py b/python/pyspark/sql/streaming/StateMessage_pb2.py index a22f004fd3048..e75d0394ea0f5 100644 --- a/python/pyspark/sql/streaming/StateMessage_pb2.py +++ b/python/pyspark/sql/streaming/StateMessage_pb2.py @@ -16,12 +16,14 @@ # # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE # source: StateMessage.proto +# Protobuf Python Version: 5.27.3 """Generated protocol buffer code.""" -from google.protobuf.internal import builder as _builder from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports) @@ -29,45 +31,54 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"\x89\x03\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"z\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"}\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t\x12\x46\n\x03ttl\x18\x03 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfig"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 
\x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState"\x1f\n\tTTLConfig\x12\x12\n\ndurationMs\x18\x01 \x01(\x05*K\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\n\n\x06\x43LOSED\x10\x03\x62\x06proto3' # noqa: E501 + b'\n\x12StateMessage.proto\x12.org.apache.spark.sql.execution.streaming.state"\xe9\x02\n\x0cStateRequest\x12\x0f\n\x07version\x18\x01 \x01(\x05\x12\x66\n\x15statefulProcessorCall\x18\x02 \x01(\x0b\x32\x45.org.apache.spark.sql.execution.streaming.state.StatefulProcessorCallH\x00\x12\x64\n\x14stateVariableRequest\x18\x03 \x01(\x0b\x32\x44.org.apache.spark.sql.execution.streaming.state.StateVariableRequestH\x00\x12p\n\x1aimplicitGroupingKeyRequest\x18\x04 \x01(\x0b\x32J.org.apache.spark.sql.execution.streaming.state.ImplicitGroupingKeyRequestH\x00\x42\x08\n\x06method"H\n\rStateResponse\x12\x12\n\nstatusCode\x18\x01 \x01(\x05\x12\x14\n\x0c\x65rrorMessage\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\x0c"\x89\x03\n\x15StatefulProcessorCall\x12X\n\x0esetHandleState\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetHandleStateH\x00\x12Y\n\rgetValueState\x18\x02 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12X\n\x0cgetListState\x18\x03 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x12W\n\x0bgetMapState\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.StateCallCommandH\x00\x42\x08\n\x06method"\xd2\x01\n\x14StateVariableRequest\x12X\n\x0evalueStateCall\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.ValueStateCallH\x00\x12V\n\rlistStateCall\x18\x02 \x01(\x0b\x32=.org.apache.spark.sql.execution.streaming.state.ListStateCallH\x00\x42\x08\n\x06method"\xe0\x01\n\x1aImplicitGroupingKeyRequest\x12X\n\x0esetImplicitKey\x18\x01 \x01(\x0b\x32>.org.apache.spark.sql.execution.streaming.state.SetImplicitKeyH\x00\x12^\n\x11removeImplicitKey\x18\x02 \x01(\x0b\x32\x41.org.apache.spark.sql.execution.streaming.state.RemoveImplicitKeyH\x00\x42\x08\n\x06method"}\n\x10StateCallCommand\x12\x11\n\tstateName\x18\x01 \x01(\t\x12\x0e\n\x06schema\x18\x02 \x01(\t\x12\x46\n\x03ttl\x18\x03 \x01(\x0b\x32\x39.org.apache.spark.sql.execution.streaming.state.TTLConfig"\xe1\x02\n\x0eValueStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12\x42\n\x03get\x18\x03 \x01(\x0b\x32\x33.org.apache.spark.sql.execution.streaming.state.GetH\x00\x12\\\n\x10valueStateUpdate\x18\x04 \x01(\x0b\x32@.org.apache.spark.sql.execution.streaming.state.ValueStateUpdateH\x00\x12\x46\n\x05\x63lear\x18\x05 
\x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x90\x04\n\rListStateCall\x12\x11\n\tstateName\x18\x01 \x01(\t\x12H\n\x06\x65xists\x18\x02 \x01(\x0b\x32\x36.org.apache.spark.sql.execution.streaming.state.ExistsH\x00\x12T\n\x0clistStateGet\x18\x03 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStateGetH\x00\x12T\n\x0clistStatePut\x18\x04 \x01(\x0b\x32<.org.apache.spark.sql.execution.streaming.state.ListStatePutH\x00\x12R\n\x0b\x61ppendValue\x18\x05 \x01(\x0b\x32;.org.apache.spark.sql.execution.streaming.state.AppendValueH\x00\x12P\n\nappendList\x18\x06 \x01(\x0b\x32:.org.apache.spark.sql.execution.streaming.state.AppendListH\x00\x12\x46\n\x05\x63lear\x18\x07 \x01(\x0b\x32\x35.org.apache.spark.sql.execution.streaming.state.ClearH\x00\x42\x08\n\x06method"\x1d\n\x0eSetImplicitKey\x12\x0b\n\x03key\x18\x01 \x01(\x0c"\x13\n\x11RemoveImplicitKey"\x08\n\x06\x45xists"\x05\n\x03Get"!\n\x10ValueStateUpdate\x12\r\n\x05value\x18\x01 \x01(\x0c"\x07\n\x05\x43lear""\n\x0cListStateGet\x12\x12\n\niteratorId\x18\x01 \x01(\t"\x0e\n\x0cListStatePut"\x1c\n\x0b\x41ppendValue\x12\r\n\x05value\x18\x01 \x01(\x0c"\x0c\n\nAppendList"\\\n\x0eSetHandleState\x12J\n\x05state\x18\x01 \x01(\x0e\x32;.org.apache.spark.sql.execution.streaming.state.HandleState"\x1f\n\tTTLConfig\x12\x12\n\ndurationMs\x18\x01 \x01(\x05*K\n\x0bHandleState\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bINITIALIZED\x10\x01\x12\x12\n\x0e\x44\x41TA_PROCESSED\x10\x02\x12\n\n\x06\x43LOSED\x10\x03\x62\x06proto3' # noqa: E501 ) _globals = globals() - _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "StateMessage_pb2", _globals) if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._options = None - _globals["_HANDLESTATE"]._serialized_start = 1978 - _globals["_HANDLESTATE"]._serialized_end = 2053 + DESCRIPTOR._loaded_options = None + _globals["_HANDLESTATE"]._serialized_start = 2694 + _globals["_HANDLESTATE"]._serialized_end = 2769 _globals["_STATEREQUEST"]._serialized_start = 71 _globals["_STATEREQUEST"]._serialized_end = 432 _globals["_STATERESPONSE"]._serialized_start = 434 _globals["_STATERESPONSE"]._serialized_end = 506 _globals["_STATEFULPROCESSORCALL"]._serialized_start = 509 _globals["_STATEFULPROCESSORCALL"]._serialized_end = 902 - _globals["_STATEVARIABLEREQUEST"]._serialized_start = 904 - _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1026 - _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1029 - _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1253 - _globals["_STATECALLCOMMAND"]._serialized_start = 1255 - _globals["_STATECALLCOMMAND"]._serialized_end = 1380 - _globals["_VALUESTATECALL"]._serialized_start = 1383 - _globals["_VALUESTATECALL"]._serialized_end = 1736 - _globals["_SETIMPLICITKEY"]._serialized_start = 1738 - _globals["_SETIMPLICITKEY"]._serialized_end = 1767 - _globals["_REMOVEIMPLICITKEY"]._serialized_start = 1769 - _globals["_REMOVEIMPLICITKEY"]._serialized_end = 1788 - _globals["_EXISTS"]._serialized_start = 1790 - _globals["_EXISTS"]._serialized_end = 1798 - _globals["_GET"]._serialized_start = 1800 - _globals["_GET"]._serialized_end = 1805 - _globals["_VALUESTATEUPDATE"]._serialized_start = 1807 - _globals["_VALUESTATEUPDATE"]._serialized_end = 1840 - _globals["_CLEAR"]._serialized_start = 1842 - _globals["_CLEAR"]._serialized_end = 1849 - _globals["_SETHANDLESTATE"]._serialized_start = 1851 - _globals["_SETHANDLESTATE"]._serialized_end = 1943 - 
_globals["_TTLCONFIG"]._serialized_start = 1945 - _globals["_TTLCONFIG"]._serialized_end = 1976 + _globals["_STATEVARIABLEREQUEST"]._serialized_start = 905 + _globals["_STATEVARIABLEREQUEST"]._serialized_end = 1115 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_start = 1118 + _globals["_IMPLICITGROUPINGKEYREQUEST"]._serialized_end = 1342 + _globals["_STATECALLCOMMAND"]._serialized_start = 1344 + _globals["_STATECALLCOMMAND"]._serialized_end = 1469 + _globals["_VALUESTATECALL"]._serialized_start = 1472 + _globals["_VALUESTATECALL"]._serialized_end = 1825 + _globals["_LISTSTATECALL"]._serialized_start = 1828 + _globals["_LISTSTATECALL"]._serialized_end = 2356 + _globals["_SETIMPLICITKEY"]._serialized_start = 2358 + _globals["_SETIMPLICITKEY"]._serialized_end = 2387 + _globals["_REMOVEIMPLICITKEY"]._serialized_start = 2389 + _globals["_REMOVEIMPLICITKEY"]._serialized_end = 2408 + _globals["_EXISTS"]._serialized_start = 2410 + _globals["_EXISTS"]._serialized_end = 2418 + _globals["_GET"]._serialized_start = 2420 + _globals["_GET"]._serialized_end = 2425 + _globals["_VALUESTATEUPDATE"]._serialized_start = 2427 + _globals["_VALUESTATEUPDATE"]._serialized_end = 2460 + _globals["_CLEAR"]._serialized_start = 2462 + _globals["_CLEAR"]._serialized_end = 2469 + _globals["_LISTSTATEGET"]._serialized_start = 2471 + _globals["_LISTSTATEGET"]._serialized_end = 2505 + _globals["_LISTSTATEPUT"]._serialized_start = 2507 + _globals["_LISTSTATEPUT"]._serialized_end = 2521 + _globals["_APPENDVALUE"]._serialized_start = 2523 + _globals["_APPENDVALUE"]._serialized_end = 2551 + _globals["_APPENDLIST"]._serialized_start = 2553 + _globals["_APPENDLIST"]._serialized_end = 2565 + _globals["_SETHANDLESTATE"]._serialized_start = 2567 + _globals["_SETHANDLESTATE"]._serialized_end = 2659 + _globals["_TTLCONFIG"]._serialized_start = 2661 + _globals["_TTLCONFIG"]._serialized_end = 2692 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/streaming/StateMessage_pb2.pyi b/python/pyspark/sql/streaming/StateMessage_pb2.pyi index 1ab48a27c8f87..b1f5f0f7d2a1e 100644 --- a/python/pyspark/sql/streaming/StateMessage_pb2.pyi +++ b/python/pyspark/sql/streaming/StateMessage_pb2.pyi @@ -13,167 +13,238 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# + from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message -from typing import ClassVar, Mapping, Optional, Union +from typing import ( + ClassVar as _ClassVar, + Mapping as _Mapping, + Optional as _Optional, + Union as _Union, +) -CLOSED: HandleState -CREATED: HandleState -DATA_PROCESSED: HandleState DESCRIPTOR: _descriptor.FileDescriptor -INITIALIZED: HandleState -class Clear(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class Exists(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class Get(_message.Message): - __slots__ = () - def __init__(self) -> None: ... 
- -class ImplicitGroupingKeyRequest(_message.Message): - __slots__ = ["removeImplicitKey", "setImplicitKey"] - REMOVEIMPLICITKEY_FIELD_NUMBER: ClassVar[int] - SETIMPLICITKEY_FIELD_NUMBER: ClassVar[int] - removeImplicitKey: RemoveImplicitKey - setImplicitKey: SetImplicitKey - def __init__( - self, - setImplicitKey: Optional[Union[SetImplicitKey, Mapping]] = ..., - removeImplicitKey: Optional[Union[RemoveImplicitKey, Mapping]] = ..., - ) -> None: ... - -class RemoveImplicitKey(_message.Message): +class HandleState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): __slots__ = () - def __init__(self) -> None: ... - -class SetHandleState(_message.Message): - __slots__ = ["state"] - STATE_FIELD_NUMBER: ClassVar[int] - state: HandleState - def __init__(self, state: Optional[Union[HandleState, str]] = ...) -> None: ... - -class SetImplicitKey(_message.Message): - __slots__ = ["key"] - KEY_FIELD_NUMBER: ClassVar[int] - key: bytes - def __init__(self, key: Optional[bytes] = ...) -> None: ... + CREATED: _ClassVar[HandleState] + INITIALIZED: _ClassVar[HandleState] + DATA_PROCESSED: _ClassVar[HandleState] + CLOSED: _ClassVar[HandleState] -class StateCallCommand(_message.Message): - __slots__ = ["schema", "stateName", "ttl"] - SCHEMA_FIELD_NUMBER: ClassVar[int] - STATENAME_FIELD_NUMBER: ClassVar[int] - TTL_FIELD_NUMBER: ClassVar[int] - schema: str - stateName: str - ttl: TTLConfig - def __init__( - self, - stateName: Optional[str] = ..., - schema: Optional[str] = ..., - ttl: Optional[Union[TTLConfig, Mapping]] = ..., - ) -> None: ... +CREATED: HandleState +INITIALIZED: HandleState +DATA_PROCESSED: HandleState +CLOSED: HandleState class StateRequest(_message.Message): - __slots__ = [ - "implicitGroupingKeyRequest", - "stateVariableRequest", - "statefulProcessorCall", + __slots__ = ( "version", - ] - IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: ClassVar[int] - STATEFULPROCESSORCALL_FIELD_NUMBER: ClassVar[int] - STATEVARIABLEREQUEST_FIELD_NUMBER: ClassVar[int] - VERSION_FIELD_NUMBER: ClassVar[int] - implicitGroupingKeyRequest: ImplicitGroupingKeyRequest - stateVariableRequest: StateVariableRequest - statefulProcessorCall: StatefulProcessorCall + "statefulProcessorCall", + "stateVariableRequest", + "implicitGroupingKeyRequest", + ) + VERSION_FIELD_NUMBER: _ClassVar[int] + STATEFULPROCESSORCALL_FIELD_NUMBER: _ClassVar[int] + STATEVARIABLEREQUEST_FIELD_NUMBER: _ClassVar[int] + IMPLICITGROUPINGKEYREQUEST_FIELD_NUMBER: _ClassVar[int] version: int + statefulProcessorCall: StatefulProcessorCall + stateVariableRequest: StateVariableRequest + implicitGroupingKeyRequest: ImplicitGroupingKeyRequest def __init__( self, - version: Optional[int] = ..., - statefulProcessorCall: Optional[Union[StatefulProcessorCall, Mapping]] = ..., - stateVariableRequest: Optional[Union[StateVariableRequest, Mapping]] = ..., - implicitGroupingKeyRequest: Optional[Union[ImplicitGroupingKeyRequest, Mapping]] = ..., + version: _Optional[int] = ..., + statefulProcessorCall: _Optional[_Union[StatefulProcessorCall, _Mapping]] = ..., + stateVariableRequest: _Optional[_Union[StateVariableRequest, _Mapping]] = ..., + implicitGroupingKeyRequest: _Optional[_Union[ImplicitGroupingKeyRequest, _Mapping]] = ..., ) -> None: ... 
class StateResponse(_message.Message): - __slots__ = ["errorMessage", "statusCode", "value"] - ERRORMESSAGE_FIELD_NUMBER: ClassVar[int] - STATUSCODE_FIELD_NUMBER: ClassVar[int] - VALUE_FIELD_NUMBER: ClassVar[int] - errorMessage: str + __slots__ = ("statusCode", "errorMessage", "value") + STATUSCODE_FIELD_NUMBER: _ClassVar[int] + ERRORMESSAGE_FIELD_NUMBER: _ClassVar[int] + VALUE_FIELD_NUMBER: _ClassVar[int] statusCode: int + errorMessage: str value: bytes def __init__( self, - statusCode: Optional[int] = ..., - errorMessage: Optional[str] = ..., - value: Optional[bytes] = ..., + statusCode: _Optional[int] = ..., + errorMessage: _Optional[str] = ..., + value: _Optional[bytes] = ..., ) -> None: ... -class StateVariableRequest(_message.Message): - __slots__ = ["valueStateCall"] - VALUESTATECALL_FIELD_NUMBER: ClassVar[int] - valueStateCall: ValueStateCall - def __init__(self, valueStateCall: Optional[Union[ValueStateCall, Mapping]] = ...) -> None: ... - class StatefulProcessorCall(_message.Message): - __slots__ = ["getListState", "getMapState", "getValueState", "setHandleState"] - GETLISTSTATE_FIELD_NUMBER: ClassVar[int] - GETMAPSTATE_FIELD_NUMBER: ClassVar[int] - GETVALUESTATE_FIELD_NUMBER: ClassVar[int] - SETHANDLESTATE_FIELD_NUMBER: ClassVar[int] + __slots__ = ("setHandleState", "getValueState", "getListState", "getMapState") + SETHANDLESTATE_FIELD_NUMBER: _ClassVar[int] + GETVALUESTATE_FIELD_NUMBER: _ClassVar[int] + GETLISTSTATE_FIELD_NUMBER: _ClassVar[int] + GETMAPSTATE_FIELD_NUMBER: _ClassVar[int] + setHandleState: SetHandleState + getValueState: StateCallCommand getListState: StateCallCommand getMapState: StateCallCommand - getValueState: StateCallCommand - setHandleState: SetHandleState def __init__( self, - setHandleState: Optional[Union[SetHandleState, Mapping]] = ..., - getValueState: Optional[Union[StateCallCommand, Mapping]] = ..., - getListState: Optional[Union[StateCallCommand, Mapping]] = ..., - getMapState: Optional[Union[StateCallCommand, Mapping]] = ..., + setHandleState: _Optional[_Union[SetHandleState, _Mapping]] = ..., + getValueState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., + getListState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., + getMapState: _Optional[_Union[StateCallCommand, _Mapping]] = ..., ) -> None: ... -class TTLConfig(_message.Message): - __slots__ = ["durationMs"] - DURATIONMS_FIELD_NUMBER: ClassVar[int] - durationMs: int - def __init__(self, durationMs: Optional[int] = ...) -> None: ... +class StateVariableRequest(_message.Message): + __slots__ = ("valueStateCall", "listStateCall") + VALUESTATECALL_FIELD_NUMBER: _ClassVar[int] + LISTSTATECALL_FIELD_NUMBER: _ClassVar[int] + valueStateCall: ValueStateCall + listStateCall: ListStateCall + def __init__( + self, + valueStateCall: _Optional[_Union[ValueStateCall, _Mapping]] = ..., + listStateCall: _Optional[_Union[ListStateCall, _Mapping]] = ..., + ) -> None: ... + +class ImplicitGroupingKeyRequest(_message.Message): + __slots__ = ("setImplicitKey", "removeImplicitKey") + SETIMPLICITKEY_FIELD_NUMBER: _ClassVar[int] + REMOVEIMPLICITKEY_FIELD_NUMBER: _ClassVar[int] + setImplicitKey: SetImplicitKey + removeImplicitKey: RemoveImplicitKey + def __init__( + self, + setImplicitKey: _Optional[_Union[SetImplicitKey, _Mapping]] = ..., + removeImplicitKey: _Optional[_Union[RemoveImplicitKey, _Mapping]] = ..., + ) -> None: ... 
+ +class StateCallCommand(_message.Message): + __slots__ = ("stateName", "schema", "ttl") + STATENAME_FIELD_NUMBER: _ClassVar[int] + SCHEMA_FIELD_NUMBER: _ClassVar[int] + TTL_FIELD_NUMBER: _ClassVar[int] + stateName: str + schema: str + ttl: TTLConfig + def __init__( + self, + stateName: _Optional[str] = ..., + schema: _Optional[str] = ..., + ttl: _Optional[_Union[TTLConfig, _Mapping]] = ..., + ) -> None: ... class ValueStateCall(_message.Message): - __slots__ = ["clear", "exists", "get", "stateName", "valueStateUpdate"] - CLEAR_FIELD_NUMBER: ClassVar[int] - EXISTS_FIELD_NUMBER: ClassVar[int] - GET_FIELD_NUMBER: ClassVar[int] - STATENAME_FIELD_NUMBER: ClassVar[int] - VALUESTATEUPDATE_FIELD_NUMBER: ClassVar[int] - clear: Clear + __slots__ = ("stateName", "exists", "get", "valueStateUpdate", "clear") + STATENAME_FIELD_NUMBER: _ClassVar[int] + EXISTS_FIELD_NUMBER: _ClassVar[int] + GET_FIELD_NUMBER: _ClassVar[int] + VALUESTATEUPDATE_FIELD_NUMBER: _ClassVar[int] + CLEAR_FIELD_NUMBER: _ClassVar[int] + stateName: str exists: Exists get: Get - stateName: str valueStateUpdate: ValueStateUpdate + clear: Clear + def __init__( + self, + stateName: _Optional[str] = ..., + exists: _Optional[_Union[Exists, _Mapping]] = ..., + get: _Optional[_Union[Get, _Mapping]] = ..., + valueStateUpdate: _Optional[_Union[ValueStateUpdate, _Mapping]] = ..., + clear: _Optional[_Union[Clear, _Mapping]] = ..., + ) -> None: ... + +class ListStateCall(_message.Message): + __slots__ = ( + "stateName", + "exists", + "listStateGet", + "listStatePut", + "appendValue", + "appendList", + "clear", + ) + STATENAME_FIELD_NUMBER: _ClassVar[int] + EXISTS_FIELD_NUMBER: _ClassVar[int] + LISTSTATEGET_FIELD_NUMBER: _ClassVar[int] + LISTSTATEPUT_FIELD_NUMBER: _ClassVar[int] + APPENDVALUE_FIELD_NUMBER: _ClassVar[int] + APPENDLIST_FIELD_NUMBER: _ClassVar[int] + CLEAR_FIELD_NUMBER: _ClassVar[int] + stateName: str + exists: Exists + listStateGet: ListStateGet + listStatePut: ListStatePut + appendValue: AppendValue + appendList: AppendList + clear: Clear def __init__( self, - stateName: Optional[str] = ..., - exists: Optional[Union[Exists, Mapping]] = ..., - get: Optional[Union[Get, Mapping]] = ..., - valueStateUpdate: Optional[Union[ValueStateUpdate, Mapping]] = ..., - clear: Optional[Union[Clear, Mapping]] = ..., + stateName: _Optional[str] = ..., + exists: _Optional[_Union[Exists, _Mapping]] = ..., + listStateGet: _Optional[_Union[ListStateGet, _Mapping]] = ..., + listStatePut: _Optional[_Union[ListStatePut, _Mapping]] = ..., + appendValue: _Optional[_Union[AppendValue, _Mapping]] = ..., + appendList: _Optional[_Union[AppendList, _Mapping]] = ..., + clear: _Optional[_Union[Clear, _Mapping]] = ..., ) -> None: ... +class SetImplicitKey(_message.Message): + __slots__ = ("key",) + KEY_FIELD_NUMBER: _ClassVar[int] + key: bytes + def __init__(self, key: _Optional[bytes] = ...) -> None: ... + +class RemoveImplicitKey(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class Exists(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class Get(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + class ValueStateUpdate(_message.Message): - __slots__ = ["value"] - VALUE_FIELD_NUMBER: ClassVar[int] + __slots__ = ("value",) + VALUE_FIELD_NUMBER: _ClassVar[int] value: bytes - def __init__(self, value: Optional[bytes] = ...) -> None: ... + def __init__(self, value: _Optional[bytes] = ...) -> None: ... 
-class HandleState(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): +class Clear(_message.Message): __slots__ = () + def __init__(self) -> None: ... + +class ListStateGet(_message.Message): + __slots__ = ("iteratorId",) + ITERATORID_FIELD_NUMBER: _ClassVar[int] + iteratorId: str + def __init__(self, iteratorId: _Optional[str] = ...) -> None: ... + +class ListStatePut(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class AppendValue(_message.Message): + __slots__ = ("value",) + VALUE_FIELD_NUMBER: _ClassVar[int] + value: bytes + def __init__(self, value: _Optional[bytes] = ...) -> None: ... + +class AppendList(_message.Message): + __slots__ = () + def __init__(self) -> None: ... + +class SetHandleState(_message.Message): + __slots__ = ("state",) + STATE_FIELD_NUMBER: _ClassVar[int] + state: HandleState + def __init__(self, state: _Optional[_Union[HandleState, str]] = ...) -> None: ... + +class TTLConfig(_message.Message): + __slots__ = ("durationMs",) + DURATIONMS_FIELD_NUMBER: _ClassVar[int] + durationMs: int + def __init__(self, durationMs: _Optional[int] = ...) -> None: ... diff --git a/python/pyspark/sql/streaming/list_state_client.py b/python/pyspark/sql/streaming/list_state_client.py new file mode 100644 index 0000000000000..93306eca425eb --- /dev/null +++ b/python/pyspark/sql/streaming/list_state_client.py @@ -0,0 +1,187 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Dict, Iterator, List, Union, cast, Tuple + +from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient +from pyspark.sql.types import StructType, TYPE_CHECKING, _parse_datatype_string +from pyspark.errors import PySparkRuntimeError +import uuid + +if TYPE_CHECKING: + from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike + +__all__ = ["ListStateClient"] + + +class ListStateClient: + def __init__(self, stateful_processor_api_client: StatefulProcessorApiClient) -> None: + self._stateful_processor_api_client = stateful_processor_api_client + # A dictionary to store the mapping between list state name and a tuple of pandas DataFrame + # and the index of the last row that was read. 
+        self.pandas_df_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
+
+    def exists(self, state_name: str) -> bool:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        exists_call = stateMessage.Exists()
+        list_state_call = stateMessage.ListStateCall(stateName=state_name, exists=exists_call)
+        state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status == 0:
+            return True
+        elif status == 2:
+            # Status code 2 means the state variable does not have a value.
+            return False
+        else:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(
+                f"Error checking list state exists: " f"{response_message[1]}"
+            )
+
+    def get(self, state_name: str, iterator_id: str) -> Tuple:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        if iterator_id in self.pandas_df_dict:
+            # If the iterator already has a cached batch, continue from the saved position.
+            pandas_df, index = self.pandas_df_dict[iterator_id]
+        else:
+            # Otherwise, fetch the next Arrow batch of the list state from the server.
+            get_call = stateMessage.ListStateGet(iteratorId=iterator_id)
+            list_state_call = stateMessage.ListStateCall(
+                stateName=state_name, listStateGet=get_call
+            )
+            state_variable_request = stateMessage.StateVariableRequest(
+                listStateCall=list_state_call
+            )
+            message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+            self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+            response_message = self._stateful_processor_api_client._receive_proto_message()
+            status = response_message[0]
+            if status == 0:
+                iterator = self._stateful_processor_api_client._read_arrow_state()
+                batch = next(iterator)
+                pandas_df = batch.to_pandas()
+                index = 0
+            else:
+                raise StopIteration()
+
+        new_index = index + 1
+        if new_index < len(pandas_df):
+            # Update the index in the dictionary.
+            self.pandas_df_dict[iterator_id] = (pandas_df, new_index)
+        else:
+            # If the index is at the end of the DataFrame, remove the state from the dictionary.
+            self.pandas_df_dict.pop(iterator_id, None)
+        pandas_row = pandas_df.iloc[index]
+        return tuple(pandas_row)
+
+    def append_value(self, state_name: str, schema: Union[StructType, str], value: Tuple) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        if isinstance(schema, str):
+            schema = cast(StructType, _parse_datatype_string(schema))
+        bytes = self._stateful_processor_api_client._serialize_to_bytes(schema, value)
+        append_value_call = stateMessage.AppendValue(value=bytes)
+        list_state_call = stateMessage.ListStateCall(
+            stateName=state_name, appendValue=append_value_call
+        )
+        state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error appending value to list state: " f"{response_message[1]}")
+
+    def append_list(
+        self, state_name: str, schema: Union[StructType, str], values: List[Tuple]
+    ) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        if isinstance(schema, str):
+            schema = cast(StructType, _parse_datatype_string(schema))
+        append_list_call = stateMessage.AppendList()
+        list_state_call = stateMessage.ListStateCall(
+            stateName=state_name, appendList=append_list_call
+        )
+        state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+
+        self._stateful_processor_api_client._send_arrow_state(schema, values)
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error appending list to list state: " f"{response_message[1]}")
+
+    def put(self, state_name: str, schema: Union[StructType, str], values: List[Tuple]) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        if isinstance(schema, str):
+            schema = cast(StructType, _parse_datatype_string(schema))
+        put_call = stateMessage.ListStatePut()
+        list_state_call = stateMessage.ListStateCall(stateName=state_name, listStatePut=put_call)
+        state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+
+        self._stateful_processor_api_client._send_arrow_state(schema, values)
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error updating list state: " f"{response_message[1]}")
+
+    def clear(self, state_name: str) -> None:
+        import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+
+        clear_call = stateMessage.Clear()
+        list_state_call = stateMessage.ListStateCall(stateName=state_name, clear=clear_call)
+        state_variable_request = stateMessage.StateVariableRequest(listStateCall=list_state_call)
+        message = stateMessage.StateRequest(stateVariableRequest=state_variable_request)
+
+        self._stateful_processor_api_client._send_proto_message(message.SerializeToString())
+        response_message = self._stateful_processor_api_client._receive_proto_message()
+        status = response_message[0]
+        if status != 0:
+            # TODO(SPARK-49233): Classify user facing errors.
+            raise PySparkRuntimeError(f"Error clearing list state: " f"{response_message[1]}")
+
+
+class ListStateIterator:
+    def __init__(self, list_state_client: ListStateClient, state_name: str):
+        self.list_state_client = list_state_client
+        self.state_name = state_name
+        # Generate a unique identifier for the iterator to make sure iterators from the same
+        # list state do not interfere with each other.
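+        # The id is included in the ListStateGet requests sent to the state server and is used
+        # as the key into ListStateClient.pandas_df_dict, so each iterator keeps its own read
+        # position.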
+ self.iterator_id = str(uuid.uuid4()) + + def __iter__(self) -> Iterator[Tuple]: + return self + + def __next__(self) -> Tuple: + return self.list_state_client.get(self.state_name, self.iterator_id) diff --git a/python/pyspark/sql/streaming/stateful_processor.py b/python/pyspark/sql/streaming/stateful_processor.py index 9045c81e287cd..0011b62132ade 100644 --- a/python/pyspark/sql/streaming/stateful_processor.py +++ b/python/pyspark/sql/streaming/stateful_processor.py @@ -16,12 +16,12 @@ # from abc import ABC, abstractmethod -from typing import Any, TYPE_CHECKING, Iterator, Optional, Union, cast +from typing import Any, List, TYPE_CHECKING, Iterator, Optional, Union, Tuple -from pyspark.sql import Row from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient +from pyspark.sql.streaming.list_state_client import ListStateClient, ListStateIterator from pyspark.sql.streaming.value_state_client import ValueStateClient -from pyspark.sql.types import StructType, _create_row, _parse_datatype_string +from pyspark.sql.types import StructType if TYPE_CHECKING: from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike @@ -50,19 +50,11 @@ def exists(self) -> bool: """ return self._value_state_client.exists(self._state_name) - def get(self) -> Optional[Row]: + def get(self) -> Optional[Tuple]: """ Get the state value if it exists. Returns None if the state variable does not have a value. """ - value = self._value_state_client.get(self._state_name) - if value is None: - return None - schema = self.schema - if isinstance(schema, str): - schema = cast(StructType, _parse_datatype_string(schema)) - # Create the Row using the values and schema fields - row = _create_row(schema.fieldNames(), value) - return row + return self._value_state_client.get(self._state_name) def update(self, new_value: Any) -> None: """ @@ -77,6 +69,58 @@ def clear(self) -> None: self._value_state_client.clear(self._state_name) +class ListState: + """ + Class used for arbitrary stateful operations with transformWithState to capture list value + state. + + .. versionadded:: 4.0.0 + """ + + def __init__( + self, list_state_client: ListStateClient, state_name: str, schema: Union[StructType, str] + ) -> None: + self._list_state_client = list_state_client + self._state_name = state_name + self.schema = schema + + def exists(self) -> bool: + """ + Whether list state exists or not. + """ + return self._list_state_client.exists(self._state_name) + + def get(self) -> Iterator[Tuple]: + """ + Get list state with an iterator. + """ + return ListStateIterator(self._list_state_client, self._state_name) + + def put(self, new_state: List[Tuple]) -> None: + """ + Update the values of the list state. + """ + self._list_state_client.put(self._state_name, self.schema, new_state) + + def append_value(self, new_state: Tuple) -> None: + """ + Append a new value to the list state. + """ + self._list_state_client.append_value(self._state_name, self.schema, new_state) + + def append_list(self, new_state: List[Tuple]) -> None: + """ + Append a list of new values to the list state. + """ + self._list_state_client.append_list(self._state_name, self.schema, new_state) + + def clear(self) -> None: + """ + Remove this state. 
+ """ + self._list_state_client.clear(self._state_name) + + class StatefulProcessorHandle: """ Represents the operation handle provided to the stateful processor used in transformWithState @@ -112,6 +156,23 @@ def getValueState( self.stateful_processor_api_client.get_value_state(state_name, schema, ttl_duration_ms) return ValueState(ValueStateClient(self.stateful_processor_api_client), state_name, schema) + def getListState(self, state_name: str, schema: Union[StructType, str]) -> ListState: + """ + Function to create new or return existing single value state variable of given type. + The user must ensure to call this function only within the `init()` method of the + :class:`StatefulProcessor`. + + Parameters + ---------- + state_name : str + name of the state variable + schema : :class:`pyspark.sql.types.DataType` or str + The schema of the state variable. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + """ + self.stateful_processor_api_client.get_list_state(state_name, schema) + return ListState(ListStateClient(self.stateful_processor_api_client), state_name, schema) + class StatefulProcessor(ABC): """ diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index 9703aa17d3474..2a5e55159e766 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -17,10 +17,16 @@ from enum import Enum import os import socket -from typing import Any, Union, Optional, cast, Tuple +from typing import Any, List, Union, Optional, cast, Tuple from pyspark.serializers import write_int, read_int, UTF8Deserializer -from pyspark.sql.types import StructType, _parse_datatype_string, Row +from pyspark.sql.pandas.serializers import ArrowStreamSerializer +from pyspark.sql.types import ( + StructType, + _parse_datatype_string, + Row, +) +from pyspark.sql.pandas.types import convert_pandas_using_numpy_type from pyspark.sql.utils import has_numpy from pyspark.serializers import CPickleSerializer from pyspark.errors import PySparkRuntimeError @@ -46,6 +52,7 @@ def __init__(self, state_server_port: int, key_schema: StructType) -> None: self.handle_state = StatefulProcessorHandleState.CREATED self.utf8_deserializer = UTF8Deserializer() self.pickleSer = CPickleSerializer() + self.serializer = ArrowStreamSerializer() def set_handle_state(self, state: StatefulProcessorHandleState) -> None: import pyspark.sql.streaming.StateMessage_pb2 as stateMessage @@ -124,6 +131,25 @@ def get_value_state( # TODO(SPARK-49233): Classify user facing errors. raise PySparkRuntimeError(f"Error initializing value state: " f"{response_message[1]}") + def get_list_state(self, state_name: str, schema: Union[StructType, str]) -> None: + import pyspark.sql.streaming.StateMessage_pb2 as stateMessage + + if isinstance(schema, str): + schema = cast(StructType, _parse_datatype_string(schema)) + + state_call_command = stateMessage.StateCallCommand() + state_call_command.stateName = state_name + state_call_command.schema = schema.json() + call = stateMessage.StatefulProcessorCall(getListState=state_call_command) + message = stateMessage.StateRequest(statefulProcessorCall=call) + + self._send_proto_message(message.SerializeToString()) + response_message = self._receive_proto_message() + status = response_message[0] + if status != 0: + # TODO(SPARK-49233): Classify user facing errors. 
+ raise PySparkRuntimeError(f"Error initializing value state: " f"{response_message[1]}") + def _send_proto_message(self, message: bytes) -> None: # Writing zero here to indicate message version. This allows us to evolve the message # format or even changing the message protocol in the future. @@ -168,3 +194,18 @@ def _serialize_to_bytes(self, schema: StructType, data: Tuple) -> bytes: def _deserialize_from_bytes(self, value: bytes) -> Any: return self.pickleSer.loads(value) + + def _send_arrow_state(self, schema: StructType, state: List[Tuple]) -> None: + import pyarrow as pa + import pandas as pd + + column_names = [field.name for field in schema.fields] + pandas_df = convert_pandas_using_numpy_type( + pd.DataFrame(state, columns=column_names), schema + ) + batch = pa.RecordBatch.from_pandas(pandas_df) + self.serializer.dump_stream(iter([batch]), self.sockfile) + self.sockfile.flush() + + def _read_arrow_state(self) -> Any: + return self.serializer.load_stream(self.sockfile) diff --git a/python/pyspark/sql/streaming/value_state_client.py b/python/pyspark/sql/streaming/value_state_client.py index e902f70cb40a5..3fe32bcc5235c 100644 --- a/python/pyspark/sql/streaming/value_state_client.py +++ b/python/pyspark/sql/streaming/value_state_client.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Any, Union, cast, Tuple +from typing import Union, cast, Tuple, Optional from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient from pyspark.sql.types import StructType, _parse_datatype_string @@ -49,7 +49,7 @@ def exists(self, state_name: str) -> bool: f"Error checking value state exists: " f"{response_message[1]}" ) - def get(self, state_name: str) -> Any: + def get(self, state_name: str) -> Optional[Tuple]: import pyspark.sql.streaming.StateMessage_pb2 as stateMessage get_call = stateMessage.Get() @@ -63,8 +63,8 @@ def get(self, state_name: str) -> Any: if status == 0: if len(response_message[2]) == 0: return None - row = self._stateful_processor_api_client._deserialize_from_bytes(response_message[2]) - return row + data = self._stateful_processor_api_client._deserialize_from_bytes(response_message[2]) + return tuple(data) else: # TODO(SPARK-49233): Classify user facing errors. 
raise PySparkRuntimeError(f"Error getting value state: " f"{response_message[1]}") diff --git a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py index 8ad24704de3a4..99333ae6f5c26 100644 --- a/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py +++ b/python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py @@ -59,6 +59,7 @@ def conf(cls): "spark.sql.streaming.stateStore.providerClass", "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider", ) + cfg.set("spark.sql.execution.arrow.transformWithStateInPandas.maxRecordsPerBatch", "2") return cfg def _prepare_input_data(self, input_path, col1, col2): @@ -211,6 +212,15 @@ def test_transform_with_state_in_pandas_query_restarts(self): Row(id="1", countAsString="2"), } + def test_transform_with_state_in_pandas_list_state(self): + def check_results(batch_df, _): + assert set(batch_df.sort("id").collect()) == { + Row(id="0", countAsString="2"), + Row(id="1", countAsString="2"), + } + + self._test_transform_with_state_in_pandas_basic(ListStateProcessor(), check_results, True) + # test value state with ttl has the same behavior as value state when # state doesn't expire. def test_value_state_ttl_basic(self): @@ -394,6 +404,59 @@ def close(self) -> None: pass +class ListStateProcessor(StatefulProcessor): + # Dict to store the expected results. The key represents the grouping key string, and the value + # is a dictionary of pandas dataframe index -> expected temperature value. Since we set + # maxRecordsPerBatch to 2, we expect the pandas dataframe dictionary to have 2 entries. + dict = {0: 120, 1: 20} + + def init(self, handle: StatefulProcessorHandle) -> None: + state_schema = StructType([StructField("temperature", IntegerType(), True)]) + self.list_state1 = handle.getListState("listState1", state_schema) + self.list_state2 = handle.getListState("listState2", state_schema) + + def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]: + count = 0 + for pdf in rows: + list_state_rows = [(120,), (20,)] + self.list_state1.put(list_state_rows) + self.list_state2.put(list_state_rows) + self.list_state1.append_value((111,)) + self.list_state2.append_value((222,)) + self.list_state1.append_list(list_state_rows) + self.list_state2.append_list(list_state_rows) + pdf_count = pdf.count() + count += pdf_count.get("temperature") + iter1 = self.list_state1.get() + iter2 = self.list_state2.get() + # Mixing the iterator to test it we can resume from the correct point + assert next(iter1)[0] == self.dict[0] + assert next(iter2)[0] == self.dict[0] + assert next(iter1)[0] == self.dict[1] + assert next(iter2)[0] == self.dict[1] + # Get another iterator for list_state1 to test if the 2 iterators (iter1 and iter3) don't + # interfere with each other. 
+ iter3 = self.list_state1.get() + assert next(iter3)[0] == self.dict[0] + assert next(iter3)[0] == self.dict[1] + # the second arrow batch should contain the appended value 111 for list_state1 and + # 222 for list_state2 + assert next(iter1)[0] == 111 + assert next(iter2)[0] == 222 + assert next(iter3)[0] == 111 + # since we put another 2 rows after 111/222, check them here + assert next(iter1)[0] == self.dict[0] + assert next(iter2)[0] == self.dict[0] + assert next(iter3)[0] == self.dict[0] + assert next(iter1)[0] == self.dict[1] + assert next(iter2)[0] == self.dict[1] + assert next(iter3)[0] == self.dict[1] + yield pd.DataFrame({"id": key, "countAsString": str(count)}) + + def close(self) -> None: + pass + + class TransformWithStateInPandasTests(TransformWithStateInPandasTestsMixin, ReusedSQLTestCase): pass diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9d51afd064d10..c9c227a21cfff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3217,6 +3217,14 @@ object SQLConf { .intConf .createWithDefault(10000) + val ARROW_TRANSFORM_WITH_STATE_IN_PANDAS_MAX_RECORDS_PER_BATCH = + buildConf("spark.sql.execution.arrow.transformWithStateInPandas.maxRecordsPerBatch") + .doc("When using TransformWithStateInPandas, limit the maximum number of state records " + + "that can be written to a single ArrowRecordBatch in memory.") + .version("4.0.0") + .intConf + .createWithDefault(10000) + val ARROW_EXECUTION_USE_LARGE_VAR_TYPES = buildConf("spark.sql.execution.arrow.useLargeVarTypes") .doc("When using Apache Arrow, use large variable width vectors for string and binary " + @@ -5899,6 +5907,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) + def arrowTransformWithStateInPandasMaxRecordsPerBatch: Int = + getConf(ARROW_TRANSFORM_WITH_STATE_IN_PANDAS_MAX_RECORDS_PER_BATCH) + def arrowUseLargeVarTypes: Boolean = getConf(ARROW_EXECUTION_USE_LARGE_VAR_TYPES) def pandasUDFBufferSize: Int = getConf(PANDAS_UDF_BUFFER_SIZE) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto index 1ff90f27e173a..63728216ded1e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto @@ -46,6 +46,7 @@ message StatefulProcessorCall { message StateVariableRequest { oneof method { ValueStateCall valueStateCall = 1; + ListStateCall listStateCall = 2; } } @@ -72,6 +73,18 @@ message ValueStateCall { } } +message ListStateCall { + string stateName = 1; + oneof method { + Exists exists = 2; + ListStateGet listStateGet = 3; + ListStatePut listStatePut = 4; + AppendValue appendValue = 5; + AppendList appendList = 6; + Clear clear = 7; + } +} + message SetImplicitKey { bytes key = 1; } @@ -92,6 +105,20 @@ message ValueStateUpdate { message Clear { } +message ListStateGet { + string iteratorId = 1; +} + +message ListStatePut { +} + +message AppendValue { + bytes value = 1; +} + +message AppendList { +} + enum HandleState { CREATED = 0; INITIALIZED = 1; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java 
b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java index 4fbb20be05b7b..d6d56dd732775 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java @@ -3462,6 +3462,21 @@ public interface StateVariableRequestOrBuilder extends */ org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCallOrBuilder getValueStateCallOrBuilder(); + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + * @return Whether the listStateCall field is set. + */ + boolean hasListStateCall(); + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + * @return The listStateCall. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall getListStateCall(); + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder getListStateCallOrBuilder(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariableRequest.MethodCase getMethodCase(); } /** @@ -3510,6 +3525,7 @@ public enum MethodCase implements com.google.protobuf.Internal.EnumLite, com.google.protobuf.AbstractMessage.InternalOneOfEnum { VALUESTATECALL(1), + LISTSTATECALL(2), METHOD_NOT_SET(0); private final int value; private MethodCase(int value) { @@ -3528,6 +3544,7 @@ public static MethodCase valueOf(int value) { public static MethodCase forNumber(int value) { switch (value) { case 1: return VALUESTATECALL; + case 2: return LISTSTATECALL; case 0: return METHOD_NOT_SET; default: return null; } @@ -3574,6 +3591,37 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCal return org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall.getDefaultInstance(); } + public static final int LISTSTATECALL_FIELD_NUMBER = 2; + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + * @return Whether the listStateCall field is set. + */ + @java.lang.Override + public boolean hasListStateCall() { + return methodCase_ == 2; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + * @return The listStateCall. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall getListStateCall() { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder getListStateCallOrBuilder() { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -3591,6 +3639,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (methodCase_ == 1) { output.writeMessage(1, (org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall) method_); } + if (methodCase_ == 2) { + output.writeMessage(2, (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_); + } getUnknownFields().writeTo(output); } @@ -3604,6 +3655,10 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeMessageSize(1, (org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCall) method_); } + if (methodCase_ == 2) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(2, (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -3625,6 +3680,10 @@ public boolean equals(final java.lang.Object obj) { if (!getValueStateCall() .equals(other.getValueStateCall())) return false; break; + case 2: + if (!getListStateCall() + .equals(other.getListStateCall())) return false; + break; case 0: default: } @@ -3644,6 +3703,10 @@ public int hashCode() { hash = (37 * hash) + VALUESTATECALL_FIELD_NUMBER; hash = (53 * hash) + getValueStateCall().hashCode(); break; + case 2: + hash = (37 * hash) + LISTSTATECALL_FIELD_NUMBER; + hash = (53 * hash) + getListStateCall().hashCode(); + break; case 0: default: } @@ -3778,6 +3841,9 @@ public Builder clear() { if (valueStateCallBuilder_ != null) { valueStateCallBuilder_.clear(); } + if (listStateCallBuilder_ != null) { + listStateCallBuilder_.clear(); + } methodCase_ = 0; method_ = null; return this; @@ -3813,6 +3879,13 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.StateVariable result.method_ = valueStateCallBuilder_.build(); } } + if (methodCase_ == 2) { + if (listStateCallBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = listStateCallBuilder_.build(); + } + } result.methodCase_ = methodCase_; onBuilt(); return result; @@ -3867,6 +3940,10 @@ public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMes mergeValueStateCall(other.getValueStateCall()); break; } + case LISTSTATECALL: { + mergeListStateCall(other.getListStateCall()); + break; + } case METHOD_NOT_SET: { break; } @@ -3904,6 +3981,13 @@ public Builder mergeFrom( methodCase_ = 1; break; } // case 10 + case 18: { + input.readMessage( + getListStateCallFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 2; + break; + } // case 18 default: { if 
(!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -4076,6 +4160,148 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCal onChanged();; return valueStateCallBuilder_; } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder> listStateCallBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + * @return Whether the listStateCall field is set. + */ + @java.lang.Override + public boolean hasListStateCall() { + return methodCase_ == 2; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + * @return The listStateCall. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall getListStateCall() { + if (listStateCallBuilder_ == null) { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); + } else { + if (methodCase_ == 2) { + return listStateCallBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + public Builder setListStateCall(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall value) { + if (listStateCallBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + listStateCallBuilder_.setMessage(value); + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + public Builder setListStateCall( + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.Builder builderForValue) { + if (listStateCallBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + listStateCallBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + public Builder mergeListStateCall(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall value) { + if (listStateCallBuilder_ == null) { + if (methodCase_ == 2 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 2) { + listStateCallBuilder_.mergeFrom(value); + } else { + listStateCallBuilder_.setMessage(value); + } + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + public Builder clearListStateCall() { + if (listStateCallBuilder_ == null) { + if (methodCase_ == 2) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 2) { + 
methodCase_ = 0; + method_ = null; + } + listStateCallBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.Builder getListStateCallBuilder() { + return getListStateCallFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder getListStateCallOrBuilder() { + if ((methodCase_ == 2) && (listStateCallBuilder_ != null)) { + return listStateCallBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateCall listStateCall = 2; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder> + getListStateCallFieldBuilder() { + if (listStateCallBuilder_ == null) { + if (!(methodCase_ == 2)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); + } + listStateCallBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 2; + onChanged();; + return listStateCallBuilder_; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -7482,37 +7708,135 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateCal } - public interface SetImplicitKeyOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + public interface ListStateCallOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ListStateCall) com.google.protobuf.MessageOrBuilder { /** - * bytes key = 1; - * @return The key. + * string stateName = 1; + * @return The stateName. */ - com.google.protobuf.ByteString getKey(); + java.lang.String getStateName(); + /** + * string stateName = 1; + * @return The bytes for stateName. + */ + com.google.protobuf.ByteString + getStateNameBytes(); + + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return Whether the exists field is set. + */ + boolean hasExists(); + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return The exists. 
+ */ + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getExists(); + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder getExistsOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + * @return Whether the listStateGet field is set. + */ + boolean hasListStateGet(); + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + * @return The listStateGet. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getListStateGet(); + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder getListStateGetOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + * @return Whether the listStatePut field is set. + */ + boolean hasListStatePut(); + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + * @return The listStatePut. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getListStatePut(); + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder getListStatePutOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + * @return Whether the appendValue field is set. + */ + boolean hasAppendValue(); + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + * @return The appendValue. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getAppendValue(); + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder getAppendValueOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + * @return Whether the appendList field is set. + */ + boolean hasAppendList(); + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + * @return The appendList. + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getAppendList(); + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder getAppendListOrBuilder(); + + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + * @return Whether the clear field is set. + */ + boolean hasClear(); + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + * @return The clear. 
+ */ + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getClear(); + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder getClearOrBuilder(); + + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.MethodCase getMethodCase(); } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.SetImplicitKey} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStateCall} */ - public static final class SetImplicitKey extends + public static final class ListStateCall extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) - SetImplicitKeyOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ListStateCall) + ListStateCallOrBuilder { private static final long serialVersionUID = 0L; - // Use SetImplicitKey.newBuilder() to construct. - private SetImplicitKey(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use ListStateCall.newBuilder() to construct. + private ListStateCall(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private SetImplicitKey() { - key_ = com.google.protobuf.ByteString.EMPTY; + private ListStateCall() { + stateName_ = ""; } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new SetImplicitKey(); + return new ListStateCall(); } @java.lang.Override @@ -7522,31 +7846,3583 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.Builder.class); } - public static final int KEY_FIELD_NUMBER = 1; - private com.google.protobuf.ByteString key_; - /** - * bytes key = 1; - * @return The key. 
- */ - @java.lang.Override - public com.google.protobuf.ByteString getKey() { - return key_; - } + private int methodCase_ = 0; + private java.lang.Object method_; + public enum MethodCase + implements com.google.protobuf.Internal.EnumLite, + com.google.protobuf.AbstractMessage.InternalOneOfEnum { + EXISTS(2), + LISTSTATEGET(3), + LISTSTATEPUT(4), + APPENDVALUE(5), + APPENDLIST(6), + CLEAR(7), + METHOD_NOT_SET(0); + private final int value; + private MethodCase(int value) { + this.value = value; + } + /** + * @param value The number of the enum to look for. + * @return The enum associated with the given number. + * @deprecated Use {@link #forNumber(int)} instead. + */ + @java.lang.Deprecated + public static MethodCase valueOf(int value) { + return forNumber(value); + } - private byte memoizedIsInitialized = -1; - @java.lang.Override - public final boolean isInitialized() { + public static MethodCase forNumber(int value) { + switch (value) { + case 2: return EXISTS; + case 3: return LISTSTATEGET; + case 4: return LISTSTATEPUT; + case 5: return APPENDVALUE; + case 6: return APPENDLIST; + case 7: return CLEAR; + case 0: return METHOD_NOT_SET; + default: return null; + } + } + public int getNumber() { + return this.value; + } + }; + + public MethodCase + getMethodCase() { + return MethodCase.forNumber( + methodCase_); + } + + public static final int STATENAME_FIELD_NUMBER = 1; + private volatile java.lang.Object stateName_; + /** + * string stateName = 1; + * @return The stateName. + */ + @java.lang.Override + public java.lang.String getStateName() { + java.lang.Object ref = stateName_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + stateName_ = s; + return s; + } + } + /** + * string stateName = 1; + * @return The bytes for stateName. + */ + @java.lang.Override + public com.google.protobuf.ByteString + getStateNameBytes() { + java.lang.Object ref = stateName_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + stateName_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + public static final int EXISTS_FIELD_NUMBER = 2; + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return Whether the exists field is set. + */ + @java.lang.Override + public boolean hasExists() { + return methodCase_ == 2; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return The exists. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getExists() { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder getExistsOrBuilder() { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + + public static final int LISTSTATEGET_FIELD_NUMBER = 3; + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + * @return Whether the listStateGet field is set. + */ + @java.lang.Override + public boolean hasListStateGet() { + return methodCase_ == 3; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + * @return The listStateGet. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getListStateGet() { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder getListStateGetOrBuilder() { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); + } + + public static final int LISTSTATEPUT_FIELD_NUMBER = 4; + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + * @return Whether the listStatePut field is set. + */ + @java.lang.Override + public boolean hasListStatePut() { + return methodCase_ == 4; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + * @return The listStatePut. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getListStatePut() { + if (methodCase_ == 4) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder getListStatePutOrBuilder() { + if (methodCase_ == 4) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); + } + + public static final int APPENDVALUE_FIELD_NUMBER = 5; + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + * @return Whether the appendValue field is set. 
+ */ + @java.lang.Override + public boolean hasAppendValue() { + return methodCase_ == 5; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + * @return The appendValue. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getAppendValue() { + if (methodCase_ == 5) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder getAppendValueOrBuilder() { + if (methodCase_ == 5) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); + } + + public static final int APPENDLIST_FIELD_NUMBER = 6; + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + * @return Whether the appendList field is set. + */ + @java.lang.Override + public boolean hasAppendList() { + return methodCase_ == 6; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + * @return The appendList. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getAppendList() { + if (methodCase_ == 6) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder getAppendListOrBuilder() { + if (methodCase_ == 6) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); + } + + public static final int CLEAR_FIELD_NUMBER = 7; + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + * @return Whether the clear field is set. + */ + @java.lang.Override + public boolean hasClear() { + return methodCase_ == 7; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + * @return The clear. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getClear() { + if (methodCase_ == 7) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder getClearOrBuilder() { + if (methodCase_ == 7) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(stateName_)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, stateName_); + } + if (methodCase_ == 2) { + output.writeMessage(2, (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_); + } + if (methodCase_ == 3) { + output.writeMessage(3, (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_); + } + if (methodCase_ == 4) { + output.writeMessage(4, (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_); + } + if (methodCase_ == 5) { + output.writeMessage(5, (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_); + } + if (methodCase_ == 6) { + output.writeMessage(6, (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_); + } + if (methodCase_ == 7) { + output.writeMessage(7, (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_); + } + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(stateName_)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, stateName_); + } + if (methodCase_ == 2) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(2, (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_); + } + if (methodCase_ == 3) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(3, (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_); + } + if (methodCase_ == 4) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(4, (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_); + } + if (methodCase_ == 5) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(5, (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_); + } + if (methodCase_ == 6) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(6, (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_); + } + if (methodCase_ == 7) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(7, 
(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_); + } + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) obj; + + if (!getStateName() + .equals(other.getStateName())) return false; + if (!getMethodCase().equals(other.getMethodCase())) return false; + switch (methodCase_) { + case 2: + if (!getExists() + .equals(other.getExists())) return false; + break; + case 3: + if (!getListStateGet() + .equals(other.getListStateGet())) return false; + break; + case 4: + if (!getListStatePut() + .equals(other.getListStatePut())) return false; + break; + case 5: + if (!getAppendValue() + .equals(other.getAppendValue())) return false; + break; + case 6: + if (!getAppendList() + .equals(other.getAppendList())) return false; + break; + case 7: + if (!getClear() + .equals(other.getClear())) return false; + break; + case 0: + default: + } + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + STATENAME_FIELD_NUMBER; + hash = (53 * hash) + getStateName().hashCode(); + switch (methodCase_) { + case 2: + hash = (37 * hash) + EXISTS_FIELD_NUMBER; + hash = (53 * hash) + getExists().hashCode(); + break; + case 3: + hash = (37 * hash) + LISTSTATEGET_FIELD_NUMBER; + hash = (53 * hash) + getListStateGet().hashCode(); + break; + case 4: + hash = (37 * hash) + LISTSTATEPUT_FIELD_NUMBER; + hash = (53 * hash) + getListStatePut().hashCode(); + break; + case 5: + hash = (37 * hash) + APPENDVALUE_FIELD_NUMBER; + hash = (53 * hash) + getAppendValue().hashCode(); + break; + case 6: + hash = (37 * hash) + APPENDLIST_FIELD_NUMBER; + hash = (53 * hash) + getAppendList().hashCode(); + break; + case 7: + hash = (37 * hash) + CLEAR_FIELD_NUMBER; + hash = (53 * hash) + getClear().hashCode(); + break; + case 0: + default: + } + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws 
com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStateCall} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ListStateCall) + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCallOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + stateName_ = ""; + + if (existsBuilder_ != null) { + existsBuilder_.clear(); + } + if (listStateGetBuilder_ != null) { + listStateGetBuilder_.clear(); + } + if (listStatePutBuilder_ != null) { + listStatePutBuilder_.clear(); + } + if (appendValueBuilder_ != null) { + appendValueBuilder_.clear(); + } + if (appendListBuilder_ != null) { + appendListBuilder_.clear(); + } + if (clearBuilder_ != null) { + clearBuilder_.clear(); + } + methodCase_ = 0; + method_ = null; + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall(this); + result.stateName_ = stateName_; + if (methodCase_ == 2) { + if (existsBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = existsBuilder_.build(); + } + } + if (methodCase_ == 3) { + if 
(listStateGetBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = listStateGetBuilder_.build(); + } + } + if (methodCase_ == 4) { + if (listStatePutBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = listStatePutBuilder_.build(); + } + } + if (methodCase_ == 5) { + if (appendValueBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = appendValueBuilder_.build(); + } + } + if (methodCase_ == 6) { + if (appendListBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = appendListBuilder_.build(); + } + } + if (methodCase_ == 7) { + if (clearBuilder_ == null) { + result.method_ = method_; + } else { + result.method_ = clearBuilder_.build(); + } + } + result.methodCase_ = methodCase_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall.getDefaultInstance()) return this; + if (!other.getStateName().isEmpty()) { + stateName_ = other.stateName_; + onChanged(); + } + switch (other.getMethodCase()) { + case EXISTS: { + mergeExists(other.getExists()); + break; + } + case LISTSTATEGET: { + mergeListStateGet(other.getListStateGet()); + break; + } + case LISTSTATEPUT: { + mergeListStatePut(other.getListStatePut()); + break; + } + case APPENDVALUE: { + mergeAppendValue(other.getAppendValue()); + break; + } + case APPENDLIST: { + mergeAppendList(other.getAppendList()); + break; + } + case CLEAR: { + mergeClear(other.getClear()); + break; + } + case METHOD_NOT_SET: { + break; + } + } + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + 
case 10: { + stateName_ = input.readStringRequireUtf8(); + + break; + } // case 10 + case 18: { + input.readMessage( + getExistsFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 2; + break; + } // case 18 + case 26: { + input.readMessage( + getListStateGetFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 3; + break; + } // case 26 + case 34: { + input.readMessage( + getListStatePutFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 4; + break; + } // case 34 + case 42: { + input.readMessage( + getAppendValueFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 5; + break; + } // case 42 + case 50: { + input.readMessage( + getAppendListFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 6; + break; + } // case 50 + case 58: { + input.readMessage( + getClearFieldBuilder().getBuilder(), + extensionRegistry); + methodCase_ = 7; + break; + } // case 58 + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + private int methodCase_ = 0; + private java.lang.Object method_; + public MethodCase + getMethodCase() { + return MethodCase.forNumber( + methodCase_); + } + + public Builder clearMethod() { + methodCase_ = 0; + method_ = null; + onChanged(); + return this; + } + + + private java.lang.Object stateName_ = ""; + /** + * string stateName = 1; + * @return The stateName. + */ + public java.lang.String getStateName() { + java.lang.Object ref = stateName_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + stateName_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + * string stateName = 1; + * @return The bytes for stateName. + */ + public com.google.protobuf.ByteString + getStateNameBytes() { + java.lang.Object ref = stateName_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + stateName_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + * string stateName = 1; + * @param value The stateName to set. + * @return This builder for chaining. + */ + public Builder setStateName( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + + stateName_ = value; + onChanged(); + return this; + } + /** + * string stateName = 1; + * @return This builder for chaining. + */ + public Builder clearStateName() { + + stateName_ = getDefaultInstance().getStateName(); + onChanged(); + return this; + } + /** + * string stateName = 1; + * @param value The bytes for stateName to set. + * @return This builder for chaining. 
+ */ + public Builder setStateNameBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + + stateName_ = value; + onChanged(); + return this; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder> existsBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return Whether the exists field is set. + */ + @java.lang.Override + public boolean hasExists() { + return methodCase_ == 2; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + * @return The exists. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getExists() { + if (existsBuilder_ == null) { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } else { + if (methodCase_ == 2) { + return existsBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public Builder setExists(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists value) { + if (existsBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + existsBuilder_.setMessage(value); + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public Builder setExists( + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder builderForValue) { + if (existsBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + existsBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public Builder mergeExists(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists value) { + if (existsBuilder_ == null) { + if (methodCase_ == 2 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 2) { + existsBuilder_.mergeFrom(value); + } else { + existsBuilder_.setMessage(value); + } + } + methodCase_ = 2; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public Builder clearExists() { + if (existsBuilder_ == null) { + if (methodCase_ == 2) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 2) { + methodCase_ = 0; + method_ = null; + } + existsBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder getExistsBuilder() { + return getExistsFieldBuilder().getBuilder(); + } + 
/** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder getExistsOrBuilder() { + if ((methodCase_ == 2) && (existsBuilder_ != null)) { + return existsBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 2) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Exists exists = 2; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder> + getExistsFieldBuilder() { + if (existsBuilder_ == null) { + if (!(methodCase_ == 2)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + existsBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 2; + onChanged();; + return existsBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder> listStateGetBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + * @return Whether the listStateGet field is set. + */ + @java.lang.Override + public boolean hasListStateGet() { + return methodCase_ == 3; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + * @return The listStateGet. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getListStateGet() { + if (listStateGetBuilder_ == null) { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); + } else { + if (methodCase_ == 3) { + return listStateGetBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + public Builder setListStateGet(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet value) { + if (listStateGetBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + listStateGetBuilder_.setMessage(value); + } + methodCase_ = 3; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + public Builder setListStateGet( + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder builderForValue) { + if (listStateGetBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + listStateGetBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 3; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + public Builder mergeListStateGet(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet value) { + if (listStateGetBuilder_ == null) { + if (methodCase_ == 3 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 3) { + listStateGetBuilder_.mergeFrom(value); + } else { + listStateGetBuilder_.setMessage(value); + } + } + methodCase_ = 3; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + public Builder clearListStateGet() { + if (listStateGetBuilder_ == null) { + if (methodCase_ == 3) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 3) { + methodCase_ = 0; + method_ = null; + } + listStateGetBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder getListStateGetBuilder() { + return getListStateGetFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder getListStateGetOrBuilder() { + if ((methodCase_ == 3) && (listStateGetBuilder_ != null)) { + return listStateGetBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 3) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); + } + } + /** + * 
.org.apache.spark.sql.execution.streaming.state.ListStateGet listStateGet = 3; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder> + getListStateGetFieldBuilder() { + if (listStateGetBuilder_ == null) { + if (!(methodCase_ == 3)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); + } + listStateGetBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 3; + onChanged();; + return listStateGetBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder> listStatePutBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + * @return Whether the listStatePut field is set. + */ + @java.lang.Override + public boolean hasListStatePut() { + return methodCase_ == 4; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + * @return The listStatePut. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getListStatePut() { + if (listStatePutBuilder_ == null) { + if (methodCase_ == 4) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); + } else { + if (methodCase_ == 4) { + return listStatePutBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + public Builder setListStatePut(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut value) { + if (listStatePutBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + listStatePutBuilder_.setMessage(value); + } + methodCase_ = 4; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + public Builder setListStatePut( + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder builderForValue) { + if (listStatePutBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + listStatePutBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 4; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + public Builder mergeListStatePut(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut value) { + if (listStatePutBuilder_ == null) { + if (methodCase_ == 4 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 4) { + listStatePutBuilder_.mergeFrom(value); + } else { + listStatePutBuilder_.setMessage(value); + } + } + methodCase_ = 4; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + public Builder clearListStatePut() { + if (listStatePutBuilder_ == null) { + if (methodCase_ == 4) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 4) { + methodCase_ = 0; + method_ = null; + } + listStatePutBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder getListStatePutBuilder() { + return getListStatePutFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder getListStatePutOrBuilder() { + if ((methodCase_ == 4) && (listStatePutBuilder_ != null)) { + return listStatePutBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 4) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); + } + } + /** + * 
.org.apache.spark.sql.execution.streaming.state.ListStatePut listStatePut = 4; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder> + getListStatePutFieldBuilder() { + if (listStatePutBuilder_ == null) { + if (!(methodCase_ == 4)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); + } + listStatePutBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 4; + onChanged();; + return listStatePutBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder> appendValueBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + * @return Whether the appendValue field is set. + */ + @java.lang.Override + public boolean hasAppendValue() { + return methodCase_ == 5; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + * @return The appendValue. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getAppendValue() { + if (appendValueBuilder_ == null) { + if (methodCase_ == 5) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); + } else { + if (methodCase_ == 5) { + return appendValueBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + public Builder setAppendValue(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue value) { + if (appendValueBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + appendValueBuilder_.setMessage(value); + } + methodCase_ = 5; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + public Builder setAppendValue( + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder builderForValue) { + if (appendValueBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + appendValueBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 5; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + public Builder mergeAppendValue(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue value) { + if (appendValueBuilder_ == null) { + if (methodCase_ == 5 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 5) { + appendValueBuilder_.mergeFrom(value); + } else { + appendValueBuilder_.setMessage(value); + } + } + methodCase_ = 5; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + public Builder clearAppendValue() { + if (appendValueBuilder_ == null) { + if (methodCase_ == 5) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 5) { + methodCase_ = 0; + method_ = null; + } + appendValueBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder getAppendValueBuilder() { + return getAppendValueFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder getAppendValueOrBuilder() { + if ((methodCase_ == 5) && (appendValueBuilder_ != null)) { + return appendValueBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 5) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); + } + } + /** + * 
.org.apache.spark.sql.execution.streaming.state.AppendValue appendValue = 5; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder> + getAppendValueFieldBuilder() { + if (appendValueBuilder_ == null) { + if (!(methodCase_ == 5)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); + } + appendValueBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 5; + onChanged();; + return appendValueBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder> appendListBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + * @return Whether the appendList field is set. + */ + @java.lang.Override + public boolean hasAppendList() { + return methodCase_ == 6; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + * @return The appendList. + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getAppendList() { + if (appendListBuilder_ == null) { + if (methodCase_ == 6) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); + } else { + if (methodCase_ == 6) { + return appendListBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + public Builder setAppendList(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList value) { + if (appendListBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + appendListBuilder_.setMessage(value); + } + methodCase_ = 6; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + public Builder setAppendList( + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder builderForValue) { + if (appendListBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + appendListBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 6; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + public Builder mergeAppendList(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList value) { + if (appendListBuilder_ == null) { + if (methodCase_ == 6 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance()) { + 
method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 6) { + appendListBuilder_.mergeFrom(value); + } else { + appendListBuilder_.setMessage(value); + } + } + methodCase_ = 6; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + public Builder clearAppendList() { + if (appendListBuilder_ == null) { + if (methodCase_ == 6) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 6) { + methodCase_ = 0; + method_ = null; + } + appendListBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder getAppendListBuilder() { + return getAppendListFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder getAppendListOrBuilder() { + if ((methodCase_ == 6) && (appendListBuilder_ != null)) { + return appendListBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 6) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.AppendList appendList = 6; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder> + getAppendListFieldBuilder() { + if (appendListBuilder_ == null) { + if (!(methodCase_ == 6)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); + } + appendListBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 6; + onChanged();; + return appendListBuilder_; + } + + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder> clearBuilder_; + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + * @return Whether the clear field is set. + */ + @java.lang.Override + public boolean hasClear() { + return methodCase_ == 7; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + * @return The clear. 
+ */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getClear() { + if (clearBuilder_ == null) { + if (methodCase_ == 7) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } else { + if (methodCase_ == 7) { + return clearBuilder_.getMessage(); + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + public Builder setClear(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear value) { + if (clearBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + method_ = value; + onChanged(); + } else { + clearBuilder_.setMessage(value); + } + methodCase_ = 7; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + public Builder setClear( + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder builderForValue) { + if (clearBuilder_ == null) { + method_ = builderForValue.build(); + onChanged(); + } else { + clearBuilder_.setMessage(builderForValue.build()); + } + methodCase_ = 7; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + public Builder mergeClear(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear value) { + if (clearBuilder_ == null) { + if (methodCase_ == 7 && + method_ != org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance()) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.newBuilder((org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_) + .mergeFrom(value).buildPartial(); + } else { + method_ = value; + } + onChanged(); + } else { + if (methodCase_ == 7) { + clearBuilder_.mergeFrom(value); + } else { + clearBuilder_.setMessage(value); + } + } + methodCase_ = 7; + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + public Builder clearClear() { + if (clearBuilder_ == null) { + if (methodCase_ == 7) { + methodCase_ = 0; + method_ = null; + onChanged(); + } + } else { + if (methodCase_ == 7) { + methodCase_ = 0; + method_ = null; + } + clearBuilder_.clear(); + } + return this; + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder getClearBuilder() { + return getClearFieldBuilder().getBuilder(); + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder getClearOrBuilder() { + if ((methodCase_ == 7) && (clearBuilder_ != null)) { + return clearBuilder_.getMessageOrBuilder(); + } else { + if (methodCase_ == 7) { + return (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_; + } + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + } + /** + * .org.apache.spark.sql.execution.streaming.state.Clear clear = 7; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder, 
org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder> + getClearFieldBuilder() { + if (clearBuilder_ == null) { + if (!(methodCase_ == 7)) { + method_ = org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + } + clearBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder, org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder>( + (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) method_, + getParentForChildren(), + isClean()); + method_ = null; + } + methodCase_ = 7; + onChanged();; + return clearBuilder_; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ListStateCall) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ListStateCall) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public ListStateCall parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateCall getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface SetImplicitKeyOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + com.google.protobuf.MessageOrBuilder { + + /** + * bytes key = 1; + * @return The key. 
+ */ + com.google.protobuf.ByteString getKey(); + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.SetImplicitKey} + */ + public static final class SetImplicitKey extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + SetImplicitKeyOrBuilder { + private static final long serialVersionUID = 0L; + // Use SetImplicitKey.newBuilder() to construct. + private SetImplicitKey(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private SetImplicitKey() { + key_ = com.google.protobuf.ByteString.EMPTY; + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new SetImplicitKey(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.Builder.class); + } + + public static final int KEY_FIELD_NUMBER = 1; + private com.google.protobuf.ByteString key_; + /** + * bytes key = 1; + * @return The key. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString getKey() { + return key_; + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (!key_.isEmpty()) { + output.writeBytes(1, key_); + } + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (!key_.isEmpty()) { + size += com.google.protobuf.CodedOutputStream + .computeBytesSize(1, key_); + } + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey other = (org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey) obj; + + if (!getKey() + .equals(other.getKey())) return false; + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + KEY_FIELD_NUMBER; + hash = (53 * hash) + getKey().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.SetImplicitKey} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKeyOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + key_ = com.google.protobuf.ByteString.EMPTY; + + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey(this); + result.key_ = key_; + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override 
+ public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.getDefaultInstance()) return this; + if (other.getKey() != com.google.protobuf.ByteString.EMPTY) { + setKey(other.getKey()); + } + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + key_ = input.readBytes(); + + break; + } // case 10 + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + + private com.google.protobuf.ByteString key_ = com.google.protobuf.ByteString.EMPTY; + /** + * bytes key = 1; + * @return The key. + */ + @java.lang.Override + public com.google.protobuf.ByteString getKey() { + return key_; + } + /** + * bytes key = 1; + * @param value The key to set. + * @return This builder for chaining. + */ + public Builder setKey(com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + + key_ = value; + onChanged(); + return this; + } + /** + * bytes key = 1; + * @return This builder for chaining. 
+ */ + public Builder clearKey() { + + key_ = getDefaultInstance().getKey(); + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public SetImplicitKey parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface RemoveImplicitKeyOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + com.google.protobuf.MessageOrBuilder { + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey} + */ + public static final class RemoveImplicitKey extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + RemoveImplicitKeyOrBuilder { + private static final long serialVersionUID = 0L; + // Use RemoveImplicitKey.newBuilder() to construct. 
+ private RemoveImplicitKey(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private RemoveImplicitKey() { + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new RemoveImplicitKey(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.Builder.class); + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey other = (org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey) obj; + + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKeyOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey(this); + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder 
setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.getDefaultInstance()) return this; + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public RemoveImplicitKey parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw 
e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface ExistsOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Exists) + com.google.protobuf.MessageOrBuilder { + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Exists} + */ + public static final class Exists extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Exists) + ExistsOrBuilder { + private static final long serialVersionUID = 0L; + // Use Exists.newBuilder() to construct. + private Exists(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private Exists() { + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new Exists(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder.class); + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Exists)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) obj; + + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + 
@java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return 
com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Exists} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Exists) + org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Exists result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Exists(this); + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return 
super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Exists)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance()) return this; + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Exists) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Exists) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Exists DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Exists(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public Exists parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, 
extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface GetOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Get) + com.google.protobuf.MessageOrBuilder { + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Get} + */ + public static final class Get extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Get) + GetOrBuilder { + private static final long serialVersionUID = 0L; + // Use Get.newBuilder() to construct. + private Get(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private Get() { + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new Get(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.Get.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Get.Builder.class); + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Get)) { + return super.equals(obj); + } + org.apache.spark.sql.execution.streaming.state.StateMessage.Get other = 
(org.apache.spark.sql.execution.streaming.state.StateMessage.Get) obj; + + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + com.google.protobuf.CodedInputStream 
input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Get prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Get} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Get) + org.apache.spark.sql.execution.streaming.state.StateMessage.GetOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.Get.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Get.Builder.class); + } + + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Get.newBuilder() + private Builder() { + + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + + } + @java.lang.Override + public Builder clear() { + super.clear(); + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.Get.getDefaultInstance(); + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Get build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Get result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Get buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Get result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Get(this); + onBuilt(); + return result; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + 
com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + @java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Get) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Get)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Get other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Get.getDefaultInstance()) return this; + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Get) + } + + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Get) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Get DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Get(); + } + + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public Get parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = 
newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + + } + + public interface ValueStateUpdateOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + com.google.protobuf.MessageOrBuilder { + + /** + * bytes value = 1; + * @return The value. + */ + com.google.protobuf.ByteString getValue(); + } + /** + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ValueStateUpdate} + */ + public static final class ValueStateUpdate extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + ValueStateUpdateOrBuilder { + private static final long serialVersionUID = 0L; + // Use ValueStateUpdate.newBuilder() to construct. + private ValueStateUpdate(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private ValueStateUpdate() { + value_ = com.google.protobuf.ByteString.EMPTY; + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new ValueStateUpdate(); + } + + @java.lang.Override + public final com.google.protobuf.UnknownFieldSet + getUnknownFields() { + return this.unknownFields; + } + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.Builder.class); + } + + public static final int VALUE_FIELD_NUMBER = 1; + private com.google.protobuf.ByteString value_; + /** + * bytes value = 1; + * @return The value. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString getValue() { + return value_; + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { byte isInitialized = memoizedIsInitialized; if (isInitialized == 1) return true; if (isInitialized == 0) return false; @@ -7558,8 +11434,8 @@ public final boolean isInitialized() { @java.lang.Override public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { - if (!key_.isEmpty()) { - output.writeBytes(1, key_); + if (!value_.isEmpty()) { + output.writeBytes(1, value_); } getUnknownFields().writeTo(output); } @@ -7570,9 +11446,9 @@ public int getSerializedSize() { if (size != -1) return size; size = 0; - if (!key_.isEmpty()) { + if (!value_.isEmpty()) { size += com.google.protobuf.CodedOutputStream - .computeBytesSize(1, key_); + .computeBytesSize(1, value_); } size += getUnknownFields().getSerializedSize(); memoizedSize = size; @@ -7584,13 +11460,13 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey other = (org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate) obj; - if (!getKey() - .equals(other.getKey())) return false; + if (!getValue() + .equals(other.getValue())) return false; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; } @@ -7602,76 +11478,76 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); - hash = (37 * hash) + KEY_FIELD_NUMBER; - hash = (53 * hash) + getKey().hashCode(); + hash = (37 * hash) + VALUE_FIELD_NUMBER; + hash = (53 * hash) + getValue().hashCode(); hash = (29 * hash) + getUnknownFields().hashCode(); memoizedHashCode = hash; return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( 
com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -7684,7 +11560,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImp public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder 
newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -7700,26 +11576,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.SetImplicitKey} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ValueStateUpdate} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKeyOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdateOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.newBuilder() private Builder() { } @@ -7732,7 +11608,7 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); - key_ = com.google.protobuf.ByteString.EMPTY; + value_ = com.google.protobuf.ByteString.EMPTY; return this; } @@ -7740,17 +11616,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.getDefaultInstance(); + public 
org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -7758,9 +11634,9 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKe } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey(this); - result.key_ = key_; + public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate(this); + result.value_ = value_; onBuilt(); return result; } @@ -7799,18 +11675,18 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey.getDefaultInstance()) return this; - if (other.getKey() != com.google.protobuf.ByteString.EMPTY) { - setKey(other.getKey()); + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.getDefaultInstance()) return this; + if (other.getValue() != com.google.protobuf.ByteString.EMPTY) { + setValue(other.getValue()); } this.mergeUnknownFields(other.getUnknownFields()); onChanged(); @@ -7839,7 +11715,7 @@ public Builder mergeFrom( done = true; break; case 10: { - key_ = input.readBytes(); + value_ = input.readBytes(); break; } // case 10 @@ -7859,36 +11735,36 @@ public Builder mergeFrom( return this; } - private com.google.protobuf.ByteString key_ = com.google.protobuf.ByteString.EMPTY; + private com.google.protobuf.ByteString value_ = com.google.protobuf.ByteString.EMPTY; /** - * bytes key = 1; - * @return The key. + * bytes value = 1; + * @return The value. */ @java.lang.Override - public com.google.protobuf.ByteString getKey() { - return key_; + public com.google.protobuf.ByteString getValue() { + return value_; } /** - * bytes key = 1; - * @param value The key to set. 
+ * bytes value = 1; + * @param value The value to set. * @return This builder for chaining. */ - public Builder setKey(com.google.protobuf.ByteString value) { + public Builder setValue(com.google.protobuf.ByteString value) { if (value == null) { throw new NullPointerException(); } - key_ = value; + value_ = value; onChanged(); return this; } /** - * bytes key = 1; + * bytes value = 1; * @return This builder for chaining. */ - public Builder clearKey() { + public Builder clearValue() { - key_ = getDefaultInstance().getKey(); + value_ = getDefaultInstance().getValue(); onChanged(); return this; } @@ -7905,23 +11781,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.SetImplicitKey) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public SetImplicitKey parsePartialFrom( + public ValueStateUpdate parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -7940,46 +11816,46 @@ public SetImplicitKey parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.SetImplicitKey getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface RemoveImplicitKeyOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + public interface ClearOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Clear) com.google.protobuf.MessageOrBuilder { } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Clear} */ - public static final class RemoveImplicitKey extends + public static final class Clear extends 
com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) - RemoveImplicitKeyOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Clear) + ClearOrBuilder { private static final long serialVersionUID = 0L; - // Use RemoveImplicitKey.newBuilder() to construct. - private RemoveImplicitKey(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use Clear.newBuilder() to construct. + private Clear(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private RemoveImplicitKey() { + private Clear() { } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new RemoveImplicitKey(); + return new Clear(); } @java.lang.Override @@ -7989,15 +11865,15 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder.class); } private byte memoizedIsInitialized = -1; @@ -8033,10 +11909,10 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Clear)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey other = (org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) obj; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; @@ -8054,69 +11930,69 @@ public int hashCode() { return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( 
java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return 
com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -8129,7 +12005,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Remove public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -8145,26 +12021,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Clear} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKeyOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Clear) + org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.class, org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.newBuilder() private Builder() { } @@ -8183,17 +12059,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor; + return 
org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -8201,8 +12077,8 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplici } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey result = new org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.Clear result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Clear(this); onBuilt(); return result; } @@ -8241,16 +12117,16 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Clear)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance()) return this; this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -8305,23 +12181,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Clear) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.RemoveImplicitKey) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Clear) + private static final 
org.apache.spark.sql.execution.streaming.state.StateMessage.Clear DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Clear(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public RemoveImplicitKey parsePartialFrom( + public Clear parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -8340,46 +12216,59 @@ public RemoveImplicitKey parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.RemoveImplicitKey getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface ExistsOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Exists) + public interface ListStateGetOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ListStateGet) com.google.protobuf.MessageOrBuilder { + + /** + * string iteratorId = 1; + * @return The iteratorId. + */ + java.lang.String getIteratorId(); + /** + * string iteratorId = 1; + * @return The bytes for iteratorId. + */ + com.google.protobuf.ByteString + getIteratorIdBytes(); } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Exists} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStateGet} */ - public static final class Exists extends + public static final class ListStateGet extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Exists) - ExistsOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ListStateGet) + ListStateGetOrBuilder { private static final long serialVersionUID = 0L; - // Use Exists.newBuilder() to construct. - private Exists(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use ListStateGet.newBuilder() to construct. 
+ private ListStateGet(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private Exists() { + private ListStateGet() { + iteratorId_ = ""; } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new Exists(); + return new ListStateGet(); } @java.lang.Override @@ -8389,15 +12278,53 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder.class); + } + + public static final int ITERATORID_FIELD_NUMBER = 1; + private volatile java.lang.Object iteratorId_; + /** + * string iteratorId = 1; + * @return The iteratorId. + */ + @java.lang.Override + public java.lang.String getIteratorId() { + java.lang.Object ref = iteratorId_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + iteratorId_ = s; + return s; + } + } + /** + * string iteratorId = 1; + * @return The bytes for iteratorId. 
+ */ + @java.lang.Override + public com.google.protobuf.ByteString + getIteratorIdBytes() { + java.lang.Object ref = iteratorId_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + iteratorId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } } private byte memoizedIsInitialized = -1; @@ -8414,6 +12341,9 @@ public final boolean isInitialized() { @java.lang.Override public void writeTo(com.google.protobuf.CodedOutputStream output) throws java.io.IOException { + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(iteratorId_)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 1, iteratorId_); + } getUnknownFields().writeTo(output); } @@ -8423,6 +12353,9 @@ public int getSerializedSize() { if (size != -1) return size; size = 0; + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(iteratorId_)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(1, iteratorId_); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -8433,11 +12366,13 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Exists)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.Exists other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) obj; + if (!getIteratorId() + .equals(other.getIteratorId())) return false; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; } @@ -8449,74 +12384,76 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + ITERATORID_FIELD_NUMBER; + hash = (53 * hash) + getIteratorId().hashCode(); hash = (29 * hash) + getUnknownFields().hashCode(); memoizedHashCode = hash; return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite 
extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -8529,7 +12466,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet prototype) { return 
DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -8545,26 +12482,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Exists} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStateGet} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Exists) - org.apache.spark.sql.execution.streaming.state.StateMessage.ExistsOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ListStateGet) + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGetOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.newBuilder() private Builder() { } @@ -8577,23 +12514,25 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); + iteratorId_ = ""; + return this; } @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Exists result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet build() { + 
org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -8601,8 +12540,9 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists build( } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Exists result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Exists(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet(this); + result.iteratorId_ = iteratorId_; onBuilt(); return result; } @@ -8641,16 +12581,20 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Exists) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Exists)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Exists other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Exists.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet.getDefaultInstance()) return this; + if (!other.getIteratorId().isEmpty()) { + iteratorId_ = other.iteratorId_; + onChanged(); + } this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -8677,6 +12621,11 @@ public Builder mergeFrom( case 0: done = true; break; + case 10: { + iteratorId_ = input.readStringRequireUtf8(); + + break; + } // case 10 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -8692,6 +12641,82 @@ public Builder mergeFrom( } // finally return this; } + + private java.lang.Object iteratorId_ = ""; + /** + * string iteratorId = 1; + * @return The iteratorId. + */ + public java.lang.String getIteratorId() { + java.lang.Object ref = iteratorId_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + iteratorId_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + * string iteratorId = 1; + * @return The bytes for iteratorId. + */ + public com.google.protobuf.ByteString + getIteratorIdBytes() { + java.lang.Object ref = iteratorId_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + iteratorId_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + * string iteratorId = 1; + * @param value The iteratorId to set. + * @return This builder for chaining. 
+ */ + public Builder setIteratorId( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + + iteratorId_ = value; + onChanged(); + return this; + } + /** + * string iteratorId = 1; + * @return This builder for chaining. + */ + public Builder clearIteratorId() { + + iteratorId_ = getDefaultInstance().getIteratorId(); + onChanged(); + return this; + } + /** + * string iteratorId = 1; + * @param value The bytes for iteratorId to set. + * @return This builder for chaining. + */ + public Builder setIteratorIdBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + + iteratorId_ = value; + onChanged(); + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -8705,23 +12730,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Exists) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ListStateGet) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Exists) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Exists DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ListStateGet) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Exists(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public Exists parsePartialFrom( + public ListStateGet parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -8740,46 +12765,46 @@ public Exists parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Exists getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStateGet getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface GetOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Get) + public interface ListStatePutOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ListStatePut) com.google.protobuf.MessageOrBuilder { } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Get} + * Protobuf type {@code 
org.apache.spark.sql.execution.streaming.state.ListStatePut} */ - public static final class Get extends + public static final class ListStatePut extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Get) - GetOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ListStatePut) + ListStatePutOrBuilder { private static final long serialVersionUID = 0L; - // Use Get.newBuilder() to construct. - private Get(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use ListStatePut.newBuilder() to construct. + private ListStatePut(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private Get() { + private ListStatePut() { } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new Get(); + return new ListStatePut(); } @java.lang.Override @@ -8789,15 +12814,15 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Get.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Get.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder.class); } private byte memoizedIsInitialized = -1; @@ -8833,10 +12858,10 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Get)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.Get other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Get) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) obj; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; @@ -8854,69 +12879,69 @@ public int hashCode() { return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { 
return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -8929,7 +12954,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Get pa public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Get prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -8945,26 +12970,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Get} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ListStatePut} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Get) - org.apache.spark.sql.execution.streaming.state.StateMessage.GetOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ListStatePut) + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePutOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Get.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Get.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Get.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.newBuilder() private Builder() { } @@ -8983,17 +13008,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor; } @java.lang.Override - public 
org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.Get.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Get build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Get result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -9001,8 +13026,8 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Get build() { } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Get buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Get result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Get(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut(this); onBuilt(); return result; } @@ -9041,16 +13066,16 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Get) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Get)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Get other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Get.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut.getDefaultInstance()) return this; this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -9105,23 +13130,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Get) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ListStatePut) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Get) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Get DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ListStatePut) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Get(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut(); } - public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public Get parsePartialFrom( + public ListStatePut parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -9140,24 +13165,24 @@ public Get parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Get getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.ListStatePut getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface ValueStateUpdateOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + public interface AppendValueOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.AppendValue) com.google.protobuf.MessageOrBuilder { /** @@ -9167,18 +13192,18 @@ public interface ValueStateUpdateOrBuilder extends com.google.protobuf.ByteString getValue(); } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ValueStateUpdate} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendValue} */ - public static final class ValueStateUpdate extends + public static final class AppendValue extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) - ValueStateUpdateOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.AppendValue) + AppendValueOrBuilder { private static final long serialVersionUID = 0L; - // Use ValueStateUpdate.newBuilder() to construct. - private ValueStateUpdate(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use AppendValue.newBuilder() to construct. 
+ private AppendValue(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private ValueStateUpdate() { + private AppendValue() { value_ = com.google.protobuf.ByteString.EMPTY; } @@ -9186,7 +13211,7 @@ private ValueStateUpdate() { @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new ValueStateUpdate(); + return new AppendValue(); } @java.lang.Override @@ -9196,15 +13221,15 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder.class); } public static final int VALUE_FIELD_NUMBER = 1; @@ -9258,10 +13283,10 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate other = (org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue other = (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) obj; if (!getValue() .equals(other.getValue())) return false; @@ -9283,69 +13308,69 @@ public int hashCode() { return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( com.google.protobuf.ByteString data) 
throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws 
java.io.IOException { @@ -9358,7 +13383,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueS public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate prototype) { + public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -9374,26 +13399,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.ValueStateUpdate} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendValue} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdateOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.AppendValue) + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValueOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.class, org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.newBuilder() private Builder() { } @@ -9414,17 +13439,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue 
getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue build() { + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -9432,8 +13457,8 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpd } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate result = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue result = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue(this); result.value_ = value_; onBuilt(); return result; @@ -9473,16 +13498,16 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue.getDefaultInstance()) return this; if (other.getValue() != com.google.protobuf.ByteString.EMPTY) { setValue(other.getValue()); } @@ -9579,23 +13604,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.AppendValue) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.ValueStateUpdate) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.AppendValue) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate 
getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public ValueStateUpdate parsePartialFrom( + public AppendValue parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -9614,46 +13639,46 @@ public ValueStateUpdate parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.ValueStateUpdate getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendValue getDefaultInstanceForType() { return DEFAULT_INSTANCE; } } - public interface ClearOrBuilder extends - // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.Clear) + public interface AppendListOrBuilder extends + // @@protoc_insertion_point(interface_extends:org.apache.spark.sql.execution.streaming.state.AppendList) com.google.protobuf.MessageOrBuilder { } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Clear} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendList} */ - public static final class Clear extends + public static final class AppendList extends com.google.protobuf.GeneratedMessageV3 implements - // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.Clear) - ClearOrBuilder { + // @@protoc_insertion_point(message_implements:org.apache.spark.sql.execution.streaming.state.AppendList) + AppendListOrBuilder { private static final long serialVersionUID = 0L; - // Use Clear.newBuilder() to construct. - private Clear(com.google.protobuf.GeneratedMessageV3.Builder builder) { + // Use AppendList.newBuilder() to construct. 
+ private AppendList(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } - private Clear() { + private AppendList() { } @java.lang.Override @SuppressWarnings({"unused"}) protected java.lang.Object newInstance( UnusedPrivateParameter unused) { - return new Clear(); + return new AppendList(); } @java.lang.Override @@ -9663,15 +13688,15 @@ protected java.lang.Object newInstance( } public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder.class); } private byte memoizedIsInitialized = -1; @@ -9707,10 +13732,10 @@ public boolean equals(final java.lang.Object obj) { if (obj == this) { return true; } - if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Clear)) { + if (!(obj instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList)) { return super.equals(obj); } - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear other = (org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) obj; + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList other = (org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) obj; if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; @@ -9728,69 +13753,69 @@ public int hashCode() { return hash; } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( java.nio.ByteBuffer data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( java.nio.ByteBuffer data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( com.google.protobuf.ByteString data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static 
org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( com.google.protobuf.ByteString data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom(byte[] data) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom(byte[] data) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( byte[] data, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { return PARSER.parseFrom(data, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom(java.io.InputStream input) + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseDelimitedFrom(java.io.InputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseDelimitedFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseDelimitedFrom( java.io.InputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseDelimitedWithIOException(PARSER, input, extensionRegistry); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( com.google.protobuf.CodedInputStream input) throws java.io.IOException { return com.google.protobuf.GeneratedMessageV3 .parseWithIOException(PARSER, input); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear parseFrom( + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList parseFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws java.io.IOException { @@ -9803,7 +13828,7 @@ public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear public static Builder newBuilder() { return DEFAULT_INSTANCE.toBuilder(); } - public static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear prototype) { + public 
static Builder newBuilder(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList prototype) { return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); } @java.lang.Override @@ -9819,26 +13844,26 @@ protected Builder newBuilderForType( return builder; } /** - * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.Clear} + * Protobuf type {@code org.apache.spark.sql.execution.streaming.state.AppendList} */ public static final class Builder extends com.google.protobuf.GeneratedMessageV3.Builder implements - // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.Clear) - org.apache.spark.sql.execution.streaming.state.StateMessage.ClearOrBuilder { + // @@protoc_insertion_point(builder_implements:org.apache.spark.sql.execution.streaming.state.AppendList) + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendListOrBuilder { public static final com.google.protobuf.Descriptors.Descriptor getDescriptor() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor; } @java.lang.Override protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internalGetFieldAccessorTable() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_fieldAccessorTable .ensureFieldAccessorsInitialized( - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.class, org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.Builder.class); + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.class, org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.Builder.class); } - // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.newBuilder() + // Construct using org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.newBuilder() private Builder() { } @@ -9857,17 +13882,17 @@ public Builder clear() { @java.lang.Override public com.google.protobuf.Descriptors.Descriptor getDescriptorForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor; + return org.apache.spark.sql.execution.streaming.state.StateMessage.internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstanceForType() { - return org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getDefaultInstanceForType() { + return org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance(); } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear build() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear result = buildPartial(); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList build() { + 
org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList result = buildPartial(); if (!result.isInitialized()) { throw newUninitializedMessageException(result); } @@ -9875,8 +13900,8 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear build() } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear buildPartial() { - org.apache.spark.sql.execution.streaming.state.StateMessage.Clear result = new org.apache.spark.sql.execution.streaming.state.StateMessage.Clear(this); + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList buildPartial() { + org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList result = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList(this); onBuilt(); return result; } @@ -9915,16 +13940,16 @@ public Builder addRepeatedField( } @java.lang.Override public Builder mergeFrom(com.google.protobuf.Message other) { - if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.Clear) { - return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.Clear)other); + if (other instanceof org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList) { + return mergeFrom((org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList)other); } else { super.mergeFrom(other); return this; } } - public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.Clear other) { - if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.Clear.getDefaultInstance()) return this; + public Builder mergeFrom(org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList other) { + if (other == org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList.getDefaultInstance()) return this; this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -9979,23 +14004,23 @@ public final Builder mergeUnknownFields( } - // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.Clear) + // @@protoc_insertion_point(builder_scope:org.apache.spark.sql.execution.streaming.state.AppendList) } - // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.Clear) - private static final org.apache.spark.sql.execution.streaming.state.StateMessage.Clear DEFAULT_INSTANCE; + // @@protoc_insertion_point(class_scope:org.apache.spark.sql.execution.streaming.state.AppendList) + private static final org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList DEFAULT_INSTANCE; static { - DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.Clear(); + DEFAULT_INSTANCE = new org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList(); } - public static org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstance() { + public static org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getDefaultInstance() { return DEFAULT_INSTANCE; } - private static final com.google.protobuf.Parser - PARSER = new com.google.protobuf.AbstractParser() { + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { @java.lang.Override - public Clear parsePartialFrom( + public AppendList parsePartialFrom( com.google.protobuf.CodedInputStream input, com.google.protobuf.ExtensionRegistryLite extensionRegistry) throws com.google.protobuf.InvalidProtocolBufferException { @@ -10014,17 +14039,17 
@@ public Clear parsePartialFrom( } }; - public static com.google.protobuf.Parser parser() { + public static com.google.protobuf.Parser parser() { return PARSER; } @java.lang.Override - public com.google.protobuf.Parser getParserForType() { + public com.google.protobuf.Parser getParserForType() { return PARSER; } @java.lang.Override - public org.apache.spark.sql.execution.streaming.state.StateMessage.Clear getDefaultInstanceForType() { + public org.apache.spark.sql.execution.streaming.state.StateMessage.AppendList getDefaultInstanceForType() { return DEFAULT_INSTANCE; } @@ -11041,6 +15066,11 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateCall_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_fieldAccessorTable; private static final com.google.protobuf.Descriptors.Descriptor internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor; private static final @@ -11071,6 +15101,26 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get private static final com.google.protobuf.GeneratedMessageV3.FieldAccessorTable internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_fieldAccessorTable; + private static final com.google.protobuf.Descriptors.Descriptor + internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor; + private static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_fieldAccessorTable; private static final com.google.protobuf.Descriptors.Descriptor internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_descriptor; private static final @@ -11112,36 +15162,54 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get "xecution.streaming.state.StateCallComman" + "dH\000\022W\n\013getMapState\030\004 \001(\0132@.org.apache.sp" + "ark.sql.execution.streaming.state.StateC" + - "allCommandH\000B\010\n\006method\"z\n\024StateVariableR" + - "equest\022X\n\016valueStateCall\030\001 \001(\0132>.org.apa" 
+ - "che.spark.sql.execution.streaming.state." + - "ValueStateCallH\000B\010\n\006method\"\340\001\n\032ImplicitG" + - "roupingKeyRequest\022X\n\016setImplicitKey\030\001 \001(" + - "\0132>.org.apache.spark.sql.execution.strea" + - "ming.state.SetImplicitKeyH\000\022^\n\021removeImp" + - "licitKey\030\002 \001(\0132A.org.apache.spark.sql.ex" + - "ecution.streaming.state.RemoveImplicitKe" + - "yH\000B\010\n\006method\"}\n\020StateCallCommand\022\021\n\tsta" + - "teName\030\001 \001(\t\022\016\n\006schema\030\002 \001(\t\022F\n\003ttl\030\003 \001(" + - "\01329.org.apache.spark.sql.execution.strea" + - "ming.state.TTLConfig\"\341\002\n\016ValueStateCall\022" + - "\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002 \001(\01326.org" + + "allCommandH\000B\010\n\006method\"\322\001\n\024StateVariable" + + "Request\022X\n\016valueStateCall\030\001 \001(\0132>.org.ap" + + "ache.spark.sql.execution.streaming.state" + + ".ValueStateCallH\000\022V\n\rlistStateCall\030\002 \001(\013" + + "2=.org.apache.spark.sql.execution.stream" + + "ing.state.ListStateCallH\000B\010\n\006method\"\340\001\n\032" + + "ImplicitGroupingKeyRequest\022X\n\016setImplici" + + "tKey\030\001 \001(\0132>.org.apache.spark.sql.execut" + + "ion.streaming.state.SetImplicitKeyH\000\022^\n\021" + + "removeImplicitKey\030\002 \001(\0132A.org.apache.spa" + + "rk.sql.execution.streaming.state.RemoveI" + + "mplicitKeyH\000B\010\n\006method\"}\n\020StateCallComma" + + "nd\022\021\n\tstateName\030\001 \001(\t\022\016\n\006schema\030\002 \001(\t\022F\n" + + "\003ttl\030\003 \001(\01329.org.apache.spark.sql.execut" + + "ion.streaming.state.TTLConfig\"\341\002\n\016ValueS" + + "tateCall\022\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002 " + + "\001(\01326.org.apache.spark.sql.execution.str" + + "eaming.state.ExistsH\000\022B\n\003get\030\003 \001(\01323.org" + ".apache.spark.sql.execution.streaming.st" + - "ate.ExistsH\000\022B\n\003get\030\003 \001(\01323.org.apache.s" + - "park.sql.execution.streaming.state.GetH\000" + - "\022\\\n\020valueStateUpdate\030\004 \001(\0132@.org.apache." 
+ - "spark.sql.execution.streaming.state.Valu" + - "eStateUpdateH\000\022F\n\005clear\030\005 \001(\01325.org.apac" + - "he.spark.sql.execution.streaming.state.C" + - "learH\000B\010\n\006method\"\035\n\016SetImplicitKey\022\013\n\003ke" + - "y\030\001 \001(\014\"\023\n\021RemoveImplicitKey\"\010\n\006Exists\"\005" + - "\n\003Get\"!\n\020ValueStateUpdate\022\r\n\005value\030\001 \001(\014" + - "\"\007\n\005Clear\"\\\n\016SetHandleState\022J\n\005state\030\001 \001" + - "(\0162;.org.apache.spark.sql.execution.stre" + - "aming.state.HandleState\"\037\n\tTTLConfig\022\022\n\n" + - "durationMs\030\001 \001(\005*K\n\013HandleState\022\013\n\007CREAT" + - "ED\020\000\022\017\n\013INITIALIZED\020\001\022\022\n\016DATA_PROCESSED\020" + - "\002\022\n\n\006CLOSED\020\003b\006proto3" + "ate.GetH\000\022\\\n\020valueStateUpdate\030\004 \001(\0132@.or" + + "g.apache.spark.sql.execution.streaming.s" + + "tate.ValueStateUpdateH\000\022F\n\005clear\030\005 \001(\01325" + + ".org.apache.spark.sql.execution.streamin" + + "g.state.ClearH\000B\010\n\006method\"\220\004\n\rListStateC" + + "all\022\021\n\tstateName\030\001 \001(\t\022H\n\006exists\030\002 \001(\01326" + + ".org.apache.spark.sql.execution.streamin" + + "g.state.ExistsH\000\022T\n\014listStateGet\030\003 \001(\0132<" + + ".org.apache.spark.sql.execution.streamin" + + "g.state.ListStateGetH\000\022T\n\014listStatePut\030\004" + + " \001(\0132<.org.apache.spark.sql.execution.st" + + "reaming.state.ListStatePutH\000\022R\n\013appendVa" + + "lue\030\005 \001(\0132;.org.apache.spark.sql.executi" + + "on.streaming.state.AppendValueH\000\022P\n\nappe" + + "ndList\030\006 \001(\0132:.org.apache.spark.sql.exec" + + "ution.streaming.state.AppendListH\000\022F\n\005cl" + + "ear\030\007 \001(\01325.org.apache.spark.sql.executi" + + "on.streaming.state.ClearH\000B\010\n\006method\"\035\n\016" + + "SetImplicitKey\022\013\n\003key\030\001 \001(\014\"\023\n\021RemoveImp" + + "licitKey\"\010\n\006Exists\"\005\n\003Get\"!\n\020ValueStateU" + + "pdate\022\r\n\005value\030\001 \001(\014\"\007\n\005Clear\"\"\n\014ListSta" + + "teGet\022\022\n\niteratorId\030\001 \001(\t\"\016\n\014ListStatePu" + + "t\"\034\n\013AppendValue\022\r\n\005value\030\001 \001(\014\"\014\n\nAppen" + + "dList\"\\\n\016SetHandleState\022J\n\005state\030\001 \001(\0162;" + + ".org.apache.spark.sql.execution.streamin" + + "g.state.HandleState\"\037\n\tTTLConfig\022\022\n\ndura" + + "tionMs\030\001 \001(\005*K\n\013HandleState\022\013\n\007CREATED\020\000" + + "\022\017\n\013INITIALIZED\020\001\022\022\n\016DATA_PROCESSED\020\002\022\n\n" + + "\006CLOSED\020\003b\006proto3" }; descriptor = com.google.protobuf.Descriptors.FileDescriptor .internalBuildGeneratedFileFrom(descriptorData, @@ -11170,7 +15238,7 @@ public org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get internal_static_org_apache_spark_sql_execution_streaming_state_StateVariableRequest_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_StateVariableRequest_descriptor, - new java.lang.String[] { "ValueStateCall", "Method", }); + new java.lang.String[] { "ValueStateCall", "ListStateCall", "Method", }); internal_static_org_apache_spark_sql_execution_streaming_state_ImplicitGroupingKeyRequest_descriptor = getDescriptor().getMessageTypes().get(4); internal_static_org_apache_spark_sql_execution_streaming_state_ImplicitGroupingKeyRequest_fieldAccessorTable = new @@ -11189,50 +15257,80 @@ public 
org.apache.spark.sql.execution.streaming.state.StateMessage.TTLConfig get com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateCall_descriptor, new java.lang.String[] { "StateName", "Exists", "Get", "ValueStateUpdate", "Clear", "Method", }); - internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor = + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_descriptor = getDescriptor().getMessageTypes().get(7); + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateCall_descriptor, + new java.lang.String[] { "StateName", "Exists", "ListStateGet", "ListStatePut", "AppendValue", "AppendList", "Clear", "Method", }); + internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor = + getDescriptor().getMessageTypes().get(8); internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_SetImplicitKey_descriptor, new java.lang.String[] { "Key", }); internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor = - getDescriptor().getMessageTypes().get(8); + getDescriptor().getMessageTypes().get(9); internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_RemoveImplicitKey_descriptor, new java.lang.String[] { }); internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor = - getDescriptor().getMessageTypes().get(9); + getDescriptor().getMessageTypes().get(10); internal_static_org_apache_spark_sql_execution_streaming_state_Exists_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_Exists_descriptor, new java.lang.String[] { }); internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor = - getDescriptor().getMessageTypes().get(10); + getDescriptor().getMessageTypes().get(11); internal_static_org_apache_spark_sql_execution_streaming_state_Get_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_Get_descriptor, new java.lang.String[] { }); internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor = - getDescriptor().getMessageTypes().get(11); + getDescriptor().getMessageTypes().get(12); internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_ValueStateUpdate_descriptor, new java.lang.String[] { "Value", }); internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor = - getDescriptor().getMessageTypes().get(12); + getDescriptor().getMessageTypes().get(13); internal_static_org_apache_spark_sql_execution_streaming_state_Clear_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( 
internal_static_org_apache_spark_sql_execution_streaming_state_Clear_descriptor, new java.lang.String[] { }); + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor = + getDescriptor().getMessageTypes().get(14); + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_ListStateGet_descriptor, + new java.lang.String[] { "IteratorId", }); + internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor = + getDescriptor().getMessageTypes().get(15); + internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_ListStatePut_descriptor, + new java.lang.String[] { }); + internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor = + getDescriptor().getMessageTypes().get(16); + internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_AppendValue_descriptor, + new java.lang.String[] { "Value", }); + internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor = + getDescriptor().getMessageTypes().get(17); + internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_org_apache_spark_sql_execution_streaming_state_AppendList_descriptor, + new java.lang.String[] { }); internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_descriptor = - getDescriptor().getMessageTypes().get(13); + getDescriptor().getMessageTypes().get(18); internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_SetHandleState_descriptor, new java.lang.String[] { "State", }); internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor = - getDescriptor().getMessageTypes().get(14); + getDescriptor().getMessageTypes().get(19); internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_org_apache_spark_sql_execution_streaming_state_TTLConfig_descriptor, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasDeserializer.scala new file mode 100644 index 0000000000000..82d4978853cb6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasDeserializer.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io.DataInputStream + +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ + +import org.apache.arrow.vector.ipc.ArrowStreamReader + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} + +/** + * A helper class to deserialize state Arrow batches from the state socket in + * TransformWithStateInPandas. + */ +class TransformWithStateInPandasDeserializer(deserializer: ExpressionEncoder.Deserializer[Row]) + extends Logging { + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for transformWithStateInPandas state socket", 0, Long.MaxValue) + + /** + * Read Arrow batches from the given stream and deserialize them into rows. + */ + def readArrowBatches(stream: DataInputStream): Seq[Row] = { + val reader = new ArrowStreamReader(stream, allocator) + val root = reader.getVectorSchemaRoot + val vectors = root.getFieldVectors.asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + val rows = ArrayBuffer[Row]() + while (reader.loadNextBatch()) { + val batch = new ColumnarBatch(vectors) + batch.setNumRows(root.getRowCount) + rows.appendAll(batch.rowIterator().asScala.map(r => deserializer(r.copy()))) + } + reader.close(false) + rows.toSeq + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala index 7d0c177d1df8f..b4b516ba9e5a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala @@ -103,7 +103,8 @@ class TransformWithStateInPandasPythonRunner( executionContext.execute( new TransformWithStateInPandasStateServer(stateServerSocket, processorHandle, - groupingKeySchema)) + groupingKeySchema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes, + sqlConf.arrowTransformWithStateInPandasMaxRecordsPerBatch)) context.addTaskCompletionListener[Unit] { _ => logInfo(log"completion listener called") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala index b5ec26b401d28..d293e7a4a5bb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala @@ -24,15 +24,18 @@ import java.time.Duration import scala.collection.mutable import com.google.protobuf.ByteString +import org.apache.arrow.vector.VectorSchemaRoot +import 
org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark.internal.{Logging, LogKeys, MDC} import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.api.python.PythonSQLUtils import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState} -import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall} -import org.apache.spark.sql.streaming.{TTLConfig, ValueState} +import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, StatefulProcessorHandleImpl, StatefulProcessorHandleState, StateVariableType} +import org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, ImplicitGroupingKeyRequest, ListStateCall, StatefulProcessorCall, StateRequest, StateResponse, StateVariableRequest, ValueStateCall} +import org.apache.spark.sql.streaming.{ListState, TTLConfig, ValueState} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils /** * This class is used to handle the state requests from the Python side. It runs on a separate @@ -48,9 +51,16 @@ class TransformWithStateInPandasStateServer( stateServerSocket: ServerSocket, statefulProcessorHandle: StatefulProcessorHandleImpl, groupingKeySchema: StructType, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean, + arrowTransformWithStateInPandasMaxRecordsPerBatch: Int, outputStreamForTest: DataOutputStream = null, - valueStateMapForTest: mutable.HashMap[String, - (ValueState[Row], StructType, ExpressionEncoder.Deserializer[Row])] = null) + valueStateMapForTest: mutable.HashMap[String, ValueStateInfo] = null, + deserializerForTest: TransformWithStateInPandasDeserializer = null, + arrowStreamWriterForTest: BaseStreamingArrowWriter = null, + listStatesMapForTest : mutable.HashMap[String, ListStateInfo] = null, + listStateIteratorMapForTest: mutable.HashMap[String, Iterator[Row]] = null) extends Runnable with Logging { private val keyRowDeserializer: ExpressionEncoder.Deserializer[Row] = ExpressionEncoder(groupingKeySchema).resolveAndBind().createDeserializer() @@ -60,8 +70,22 @@ class TransformWithStateInPandasStateServer( private val valueStates = if (valueStateMapForTest != null) { valueStateMapForTest } else { - new mutable.HashMap[String, (ValueState[Row], StructType, - ExpressionEncoder.Deserializer[Row])]() + new mutable.HashMap[String, ValueStateInfo]() + } + // A map to store the list state name -> (list state, schema, list state row deserializer, + // list state row serializer) mapping. + private val listStates = if (listStatesMapForTest != null) { + listStatesMapForTest + } else { + new mutable.HashMap[String, ListStateInfo]() + } + // A map to store the iterator id -> iterator mapping. This is to keep track of the + // current iterator position for each list state in a grouping key in case user tries to fetch + // another list state before the current iterator is exhausted. 
+ private var listStateIterators = if (listStateIteratorMapForTest != null) { + listStateIteratorMapForTest + } else { + new mutable.HashMap[String, Iterator[Row]]() } def run(): Unit = { @@ -125,9 +149,13 @@ class TransformWithStateInPandasStateServer( // The key row is serialized as a byte array, we need to convert it back to a Row val keyRow = PythonSQLUtils.toJVMRow(keyBytes, groupingKeySchema, keyRowDeserializer) ImplicitGroupingKeyTracker.setImplicitKey(keyRow) + // Reset the list state iterators for a new grouping key. + listStateIterators = new mutable.HashMap[String, Iterator[Row]]() sendResponse(0) case ImplicitGroupingKeyRequest.MethodCase.REMOVEIMPLICITKEY => ImplicitGroupingKeyTracker.removeImplicitKey() + // Reset the list state iterators for a new grouping key. + listStateIterators = new mutable.HashMap[String, Iterator[Row]]() sendResponse(0) case _ => throw new IllegalArgumentException("Invalid method call") @@ -157,7 +185,12 @@ class TransformWithStateInPandasStateServer( val ttlDurationMs = if (message.getGetValueState.hasTtl) { Some(message.getGetValueState.getTtl.getDurationMs) } else None - initializeValueState(stateName, schema, ttlDurationMs) + initializeStateVariable(stateName, schema, StateVariableType.ValueState, ttlDurationMs) + case StatefulProcessorCall.MethodCase.GETLISTSTATE => + val stateName = message.getGetListState.getStateName + val schema = message.getGetListState.getSchema + // TODO(SPARK-49744): Add ttl support for list state. + initializeStateVariable(stateName, schema, StateVariableType.ListState, None) case _ => throw new IllegalArgumentException("Invalid method call") } @@ -167,6 +200,8 @@ class TransformWithStateInPandasStateServer( message.getMethodCase match { case StateVariableRequest.MethodCase.VALUESTATECALL => handleValueStateRequest(message.getValueStateCall) + case StateVariableRequest.MethodCase.LISTSTATECALL => + handleListStateRequest(message.getListStateCall) case _ => throw new IllegalArgumentException("Invalid method call") } @@ -179,16 +214,17 @@ class TransformWithStateInPandasStateServer( sendResponse(1, s"Value state $stateName is not initialized.") return } + val valueStateInfo = valueStates(stateName) message.getMethodCase match { case ValueStateCall.MethodCase.EXISTS => - if (valueStates(stateName)._1.exists()) { + if (valueStateInfo.valueState.exists()) { sendResponse(0) } else { // Send status code 2 to indicate that the value state doesn't have a value yet. 
sendResponse(2, s"state $stateName doesn't exist") } case ValueStateCall.MethodCase.GET => - val valueOption = valueStates(stateName)._1.getOption() + val valueOption = valueStateInfo.valueState.getOption() if (valueOption.isDefined) { // Serialize the value row as a byte array val valueBytes = PythonSQLUtils.toPyRow(valueOption.get) @@ -201,13 +237,95 @@ class TransformWithStateInPandasStateServer( } case ValueStateCall.MethodCase.VALUESTATEUPDATE => val byteArray = message.getValueStateUpdate.getValue.toByteArray - val valueStateTuple = valueStates(stateName) // The value row is serialized as a byte array, we need to convert it back to a Row - val valueRow = PythonSQLUtils.toJVMRow(byteArray, valueStateTuple._2, valueStateTuple._3) - valueStateTuple._1.update(valueRow) + val valueRow = PythonSQLUtils.toJVMRow(byteArray, valueStateInfo.schema, + valueStateInfo.deserializer) + valueStateInfo.valueState.update(valueRow) sendResponse(0) case ValueStateCall.MethodCase.CLEAR => - valueStates(stateName)._1.clear() + valueStateInfo.valueState.clear() + sendResponse(0) + case _ => + throw new IllegalArgumentException("Invalid method call") + } + } + + private[sql] def handleListStateRequest(message: ListStateCall): Unit = { + val stateName = message.getStateName + if (!listStates.contains(stateName)) { + logWarning(log"List state ${MDC(LogKeys.STATE_NAME, stateName)} is not initialized.") + sendResponse(1, s"List state $stateName is not initialized.") + return + } + val listStateInfo = listStates(stateName) + val deserializer = if (deserializerForTest != null) { + deserializerForTest + } else { + new TransformWithStateInPandasDeserializer(listStateInfo.deserializer) + } + message.getMethodCase match { + case ListStateCall.MethodCase.EXISTS => + if (listStateInfo.listState.exists()) { + sendResponse(0) + } else { + // Send status code 2 to indicate that the list state doesn't have a value yet. + sendResponse(2, s"state $stateName doesn't exist") + } + case ListStateCall.MethodCase.LISTSTATEPUT => + val rows = deserializer.readArrowBatches(inputStream) + listStateInfo.listState.put(rows.toArray) + sendResponse(0) + case ListStateCall.MethodCase.LISTSTATEGET => + val iteratorId = message.getListStateGet.getIteratorId + var iteratorOption = listStateIterators.get(iteratorId) + if (iteratorOption.isEmpty) { + iteratorOption = Some(listStateInfo.listState.get()) + listStateIterators.put(iteratorId, iteratorOption.get) + } + if (!iteratorOption.get.hasNext) { + sendResponse(2, s"List state $stateName doesn't contain any value.") + return + } else { + sendResponse(0) + } + outputStream.flush() + val arrowStreamWriter = if (arrowStreamWriterForTest != null) { + arrowStreamWriterForTest + } else { + val arrowSchema = ArrowUtils.toArrowSchema(listStateInfo.schema, timeZoneId, + errorOnDuplicatedFieldNames, largeVarTypes) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for transformWithStateInPandas state socket", 0, Long.MaxValue) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + new BaseStreamingArrowWriter(root, new ArrowStreamWriter(root, null, outputStream), + arrowTransformWithStateInPandasMaxRecordsPerBatch) + } + val listRowSerializer = listStateInfo.serializer + // Only write a single batch in each GET request. Stops writing row if rowCount reaches + // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. 
This is to handle a case + // when there are multiple state variables, user tries to access a different state variable + // while the current state variable is not exhausted yet. + var rowCount = 0 + while (iteratorOption.get.hasNext && + rowCount < arrowTransformWithStateInPandasMaxRecordsPerBatch) { + val row = iteratorOption.get.next() + val internalRow = listRowSerializer(row) + arrowStreamWriter.writeRow(internalRow) + rowCount += 1 + } + arrowStreamWriter.finalizeCurrentArrowBatch() + case ListStateCall.MethodCase.APPENDVALUE => + val byteArray = message.getAppendValue.getValue.toByteArray + val newRow = PythonSQLUtils.toJVMRow(byteArray, listStateInfo.schema, + listStateInfo.deserializer) + listStateInfo.listState.appendValue(newRow) + sendResponse(0) + case ListStateCall.MethodCase.APPENDLIST => + val rows = deserializer.readArrowBatches(inputStream) + listStateInfo.listState.appendList(rows.toArray) + sendResponse(0) + case ListStateCall.MethodCase.CLEAR => + listStates(stateName).listState.clear() sendResponse(0) case _ => throw new IllegalArgumentException("Invalid method call") @@ -232,23 +350,54 @@ class TransformWithStateInPandasStateServer( outputStream.write(responseMessageBytes) } - private def initializeValueState( + private def initializeStateVariable( stateName: String, schemaString: String, + stateType: StateVariableType.StateVariableType, ttlDurationMs: Option[Int]): Unit = { - if (!valueStates.contains(stateName)) { - val schema = StructType.fromString(schemaString) - val state = if (ttlDurationMs.isEmpty) { - statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema)) - } else { - statefulProcessorHandle.getValueState( - stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get))) - } - val valueRowDeserializer = ExpressionEncoder(schema).resolveAndBind().createDeserializer() - valueStates.put(stateName, (state, schema, valueRowDeserializer)) - sendResponse(0) - } else { - sendResponse(1, s"state $stateName already exists") + val schema = StructType.fromString(schemaString) + val expressionEncoder = ExpressionEncoder(schema).resolveAndBind() + stateType match { + case StateVariableType.ValueState => if (!valueStates.contains(stateName)) { + val state = if (ttlDurationMs.isEmpty) { + statefulProcessorHandle.getValueState[Row](stateName, Encoders.row(schema)) + } else { + statefulProcessorHandle.getValueState( + stateName, Encoders.row(schema), TTLConfig(Duration.ofMillis(ttlDurationMs.get))) + } + valueStates.put(stateName, + ValueStateInfo(state, schema, expressionEncoder.createDeserializer())) + sendResponse(0) + } else { + sendResponse(1, s"Value state $stateName already exists") + } + case StateVariableType.ListState => if (!listStates.contains(stateName)) { + // TODO(SPARK-49744): Add ttl support for list state. + listStates.put(stateName, + ListStateInfo(statefulProcessorHandle.getListState[Row](stateName, + Encoders.row(schema)), schema, expressionEncoder.createDeserializer(), + expressionEncoder.createSerializer())) + sendResponse(0) + } else { + sendResponse(1, s"List state $stateName already exists") + } } } } + +/** + * Case class to store the information of a value state. + */ +case class ValueStateInfo( + valueState: ValueState[Row], + schema: StructType, + deserializer: ExpressionEncoder.Deserializer[Row]) + +/** + * Case class to store the information of a list state. 
+ */ +case class ListStateInfo( + listState: ListState[Row], + schema: StructType, + deserializer: ExpressionEncoder.Deserializer[Row], + serializer: ExpressionEncoder.Serializer[Row]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala index 615e1e89f30b8..137e2531f4f46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServerSuite.scala @@ -32,32 +32,59 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema import org.apache.spark.sql.execution.streaming.{StatefulProcessorHandleImpl, StatefulProcessorHandleState} import org.apache.spark.sql.execution.streaming.state.StateMessage -import org.apache.spark.sql.execution.streaming.state.StateMessage.{Clear, Exists, Get, HandleState, SetHandleState, StateCallCommand, StatefulProcessorCall, ValueStateCall, ValueStateUpdate} -import org.apache.spark.sql.streaming.{TTLConfig, ValueState} +import org.apache.spark.sql.execution.streaming.state.StateMessage.{AppendList, AppendValue, Clear, Exists, Get, HandleState, ListStateCall, ListStateGet, ListStatePut, SetHandleState, StateCallCommand, StatefulProcessorCall, ValueStateCall, ValueStateUpdate} +import org.apache.spark.sql.streaming.{ListState, TTLConfig, ValueState} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with BeforeAndAfterEach { - val valueStateName = "test" - var statefulProcessorHandle: StatefulProcessorHandleImpl = _ + val stateName = "test" + val iteratorId = "testId" + val serverSocket: ServerSocket = mock(classOf[ServerSocket]) + val groupingKeySchema: StructType = StructType(Seq()) + val stateSchema: StructType = StructType(Array(StructField("value", IntegerType))) + // Below byte array is a serialized row with a single integer value 1. 
+ val byteArray: Array[Byte] = Array(0x80.toByte, 0x05.toByte, 0x95.toByte, 0x05.toByte, + 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, + 'K'.toByte, 0x01.toByte, 0x85.toByte, 0x94.toByte, '.'.toByte + ) + + var statefulProcessorHandle: StatefulProcessorHandleImpl = + mock(classOf[StatefulProcessorHandleImpl]) var outputStream: DataOutputStream = _ var valueState: ValueState[Row] = _ + var listState: ListState[Row] = _ var stateServer: TransformWithStateInPandasStateServer = _ - var valueSchema: StructType = _ - var valueDeserializer: ExpressionEncoder.Deserializer[Row] = _ + var stateDeserializer: ExpressionEncoder.Deserializer[Row] = _ + var stateSerializer: ExpressionEncoder.Serializer[Row] = _ + var transformWithStateInPandasDeserializer: TransformWithStateInPandasDeserializer = _ + var arrowStreamWriter: BaseStreamingArrowWriter = _ + var valueStateMap: mutable.HashMap[String, ValueStateInfo] = mutable.HashMap() + var listStateMap: mutable.HashMap[String, ListStateInfo] = mutable.HashMap() override def beforeEach(): Unit = { - val serverSocket = mock(classOf[ServerSocket]) statefulProcessorHandle = mock(classOf[StatefulProcessorHandleImpl]) - val groupingKeySchema = StructType(Seq()) outputStream = mock(classOf[DataOutputStream]) valueState = mock(classOf[ValueState[Row]]) - valueSchema = StructType(Array(StructField("value", IntegerType))) - valueDeserializer = ExpressionEncoder(valueSchema).resolveAndBind().createDeserializer() - val valueStateMap = mutable.HashMap[String, - (ValueState[Row], StructType, ExpressionEncoder.Deserializer[Row])](valueStateName -> - (valueState, valueSchema, valueDeserializer)) + listState = mock(classOf[ListState[Row]]) + stateDeserializer = ExpressionEncoder(stateSchema).resolveAndBind().createDeserializer() + stateSerializer = ExpressionEncoder(stateSchema).resolveAndBind().createSerializer() + valueStateMap = mutable.HashMap[String, ValueStateInfo](stateName -> + ValueStateInfo(valueState, stateSchema, stateDeserializer)) + listStateMap = mutable.HashMap[String, ListStateInfo](stateName -> + ListStateInfo(listState, stateSchema, stateDeserializer, stateSerializer)) + // Iterator map for list state. Please note that `handleImplicitGroupingKeyRequest` would + // reset the iterator map to empty so be careful to call it if you want to access the iterator + // map later. 
+ val listStateIteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId -> + Iterator(new GenericRowWithSchema(Array(1), stateSchema))) + transformWithStateInPandasDeserializer = mock(classOf[TransformWithStateInPandasDeserializer]) + arrowStreamWriter = mock(classOf[BaseStreamingArrowWriter]) stateServer = new TransformWithStateInPandasStateServer(serverSocket, - statefulProcessorHandle, groupingKeySchema, outputStream, valueStateMap) + statefulProcessorHandle, groupingKeySchema, "", false, false, 2, + outputStream, valueStateMap, transformWithStateInPandasDeserializer, arrowStreamWriter, + listStateMap, listStateIteratorMap) + when(transformWithStateInPandasDeserializer.readArrowBatches(any)) + .thenReturn(Seq(new GenericRowWithSchema(Array(1), stateSchema))) } test("set handle state") { @@ -92,14 +119,14 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo } test("value state exists") { - val message = ValueStateCall.newBuilder().setStateName(valueStateName) + val message = ValueStateCall.newBuilder().setStateName(stateName) .setExists(Exists.newBuilder().build()).build() stateServer.handleValueStateRequest(message) verify(valueState).exists() } test("value state get") { - val message = ValueStateCall.newBuilder().setStateName(valueStateName) + val message = ValueStateCall.newBuilder().setStateName(stateName) .setGet(Get.newBuilder().build()).build() val schema = new StructType().add("value", "int") when(valueState.getOption()).thenReturn(Some(new GenericRowWithSchema(Array(1), schema))) @@ -109,7 +136,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo } test("value state get - not exist") { - val message = ValueStateCall.newBuilder().setStateName(valueStateName) + val message = ValueStateCall.newBuilder().setStateName(stateName) .setGet(Get.newBuilder().build()).build() when(valueState.getOption()).thenReturn(None) stateServer.handleValueStateRequest(message) @@ -127,7 +154,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo } test("value state clear") { - val message = ValueStateCall.newBuilder().setStateName(valueStateName) + val message = ValueStateCall.newBuilder().setStateName(stateName) .setClear(Clear.newBuilder().build()).build() stateServer.handleValueStateRequest(message) verify(valueState).clear() @@ -135,16 +162,98 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo } test("value state update") { - // Below byte array is a serialized row with a single integer value 1. 
- val byteArray: Array[Byte] = Array(0x80.toByte, 0x05.toByte, 0x95.toByte, 0x05.toByte, - 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, 0x00.toByte, - 'K'.toByte, 0x01.toByte, 0x85.toByte, 0x94.toByte, '.'.toByte - ) val byteString: ByteString = ByteString.copyFrom(byteArray) - val message = ValueStateCall.newBuilder().setStateName(valueStateName) + val message = ValueStateCall.newBuilder().setStateName(stateName) .setValueStateUpdate(ValueStateUpdate.newBuilder().setValue(byteString).build()).build() stateServer.handleValueStateRequest(message) verify(valueState).update(any[Row]) verify(outputStream).writeInt(0) } + + test("list state exists") { + val message = ListStateCall.newBuilder().setStateName(stateName) + .setExists(Exists.newBuilder().build()).build() + stateServer.handleListStateRequest(message) + verify(listState).exists() + } + + test("list state get - iterator in map") { + val message = ListStateCall.newBuilder().setStateName(stateName) + .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build() + stateServer.handleListStateRequest(message) + verify(listState, times(0)).get() + verify(arrowStreamWriter).writeRow(any) + verify(arrowStreamWriter).finalizeCurrentArrowBatch() + } + + test("list state get - iterator in map with multiple batches") { + val maxRecordsPerBatch = 2 + val message = ListStateCall.newBuilder().setStateName(stateName) + .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build() + val iteratorMap = mutable.HashMap[String, Iterator[Row]](iteratorId -> + Iterator(new GenericRowWithSchema(Array(1), stateSchema), + new GenericRowWithSchema(Array(2), stateSchema), + new GenericRowWithSchema(Array(3), stateSchema), + new GenericRowWithSchema(Array(4), stateSchema))) + stateServer = new TransformWithStateInPandasStateServer(serverSocket, + statefulProcessorHandle, groupingKeySchema, "", false, false, + maxRecordsPerBatch, outputStream, valueStateMap, + transformWithStateInPandasDeserializer, arrowStreamWriter, listStateMap, iteratorMap) + // First call should send 2 records. + stateServer.handleListStateRequest(message) + verify(listState, times(0)).get() + verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any) + verify(arrowStreamWriter).finalizeCurrentArrowBatch() + // Second call should send the remaining 2 records. + stateServer.handleListStateRequest(message) + verify(listState, times(0)).get() + // Since Mockito's verify counts the total number of calls, the expected number of writeRow call + // should be 2 * maxRecordsPerBatch. 
+ verify(arrowStreamWriter, times(2 * maxRecordsPerBatch)).writeRow(any) + verify(arrowStreamWriter, times(2)).finalizeCurrentArrowBatch() + } + + test("list state get - iterator not in map") { + val maxRecordsPerBatch = 2 + val message = ListStateCall.newBuilder().setStateName(stateName) + .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build() + val iteratorMap: mutable.HashMap[String, Iterator[Row]] = mutable.HashMap() + stateServer = new TransformWithStateInPandasStateServer(serverSocket, + statefulProcessorHandle, groupingKeySchema, "", false, false, + maxRecordsPerBatch, outputStream, valueStateMap, + transformWithStateInPandasDeserializer, arrowStreamWriter, listStateMap, iteratorMap) + when(listState.get()).thenReturn(Iterator(new GenericRowWithSchema(Array(1), stateSchema), + new GenericRowWithSchema(Array(2), stateSchema), + new GenericRowWithSchema(Array(3), stateSchema))) + stateServer.handleListStateRequest(message) + verify(listState).get() + // Verify that only maxRecordsPerBatch (2) rows are written to the output stream while still + // having 1 row left in the iterator. + verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any) + verify(arrowStreamWriter).finalizeCurrentArrowBatch() + } + + test("list state put") { + val message = ListStateCall.newBuilder().setStateName(stateName) + .setListStatePut(ListStatePut.newBuilder().build()).build() + stateServer.handleListStateRequest(message) + verify(transformWithStateInPandasDeserializer).readArrowBatches(any) + verify(listState).put(any) + } + + test("list state append value") { + val byteString: ByteString = ByteString.copyFrom(byteArray) + val message = ListStateCall.newBuilder().setStateName(stateName) + .setAppendValue(AppendValue.newBuilder().setValue(byteString).build()).build() + stateServer.handleListStateRequest(message) + verify(listState).appendValue(any[Row]) + } + + test("list state append list") { + val message = ListStateCall.newBuilder().setStateName(stateName) + .setAppendList(AppendList.newBuilder().build()).build() + stateServer.handleListStateRequest(message) + verify(transformWithStateInPandasDeserializer).readArrowBatches(any) + verify(listState).appendList(any) + } } From a4fb6cbfda228de407e2be83e28c761381576276 Mon Sep 17 00:00:00 2001 From: Nikhil Sheoran <125331115+nikhilsheoran-db@users.noreply.github.com> Date: Wed, 25 Sep 2024 10:29:34 +0800 Subject: [PATCH 074/250] [SPARK-49743][SQL] OptimizeCsvJsonExpr should not change schema fields when pruning GetArrayStructFields ### What changes were proposed in this pull request? - When pruning the schema of the struct in `GetArrayStructFields`, rely on the existing `StructType` to obtain the pruned schema instead of using the accessed field. ### Why are the changes needed? - Fixes a bug in `OptimizeCsvJsonExprs` rule that would have otherwise changed the schema fields of the underlying struct to be extracted. - This would show up as a correctness issue where for a field instead of picking the right values we would have ended up giving null output. ### Does this PR introduce _any_ user-facing change? Yes. 
The query output would change for the queries of the following type: ``` SELECT from_json('[{"a": '||id||', "b": '|| (2*id) ||'}]', 'array>').a, from_json('[{"a": '||id||', "b": '|| (2*id) ||'}]', 'array>').A FROM range(3) as t ``` Earlier, the result would had been: ``` Array([ArraySeq(0),ArraySeq(null)], [ArraySeq(1),ArraySeq(null)], [ArraySeq(2),ArraySeq(null)]) ``` vs the new result is (verified through spark-shell): ``` Array([ArraySeq(0),ArraySeq(0)], [ArraySeq(1),ArraySeq(1)], [ArraySeq(2),ArraySeq(2)]) ``` ### How was this patch tested? - Added unit tests. - Without this change, the added test would fail as we would have modified the schema from `a` to `A`: ``` - SPARK-49743: prune unnecessary columns from GetArrayStructFields does not change schema *** FAILED *** == FAIL: Plans do not match === !Project [from_json(ArrayType(StructType(StructField(A,IntegerType,true)),true), json#0, Some(America/Los_Angeles)).A AS a#0] Project [from_json(ArrayType(StructType(S tructField(a,IntegerType,true)),true), json#0, Some(America/Los_Angeles)).A AS a#0] +- LocalRelation , [json#0] +- LocalRelation , [json#0] (PlanT est.scala:179) ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48190 from nikhilsheoran-db/SPARK-49743. Authored-by: Nikhil Sheoran <125331115+nikhilsheoran-db@users.noreply.github.com> Signed-off-by: Wenchen Fan --- .../optimizer/OptimizeCsvJsonExprs.scala | 7 ++++--- .../optimizer/OptimizeJsonExprsSuite.scala | 17 +++++++++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 13 +++++++++++++ 3 files changed, 34 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala index 4347137bf68b8..04cc230f99b44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCsvJsonExprs.scala @@ -112,9 +112,10 @@ object OptimizeCsvJsonExprs extends Rule[LogicalPlan] { val prunedSchema = StructType(Array(schema(ordinal))) g.copy(child = j.copy(schema = prunedSchema), ordinal = 0) - case g @ GetArrayStructFields(j @ JsonToStructs(schema: ArrayType, _, _, _), _, _, _, _) - if schema.elementType.asInstanceOf[StructType].length > 1 && j.options.isEmpty => - val prunedSchema = ArrayType(StructType(Array(g.field)), g.containsNull) + case g @ GetArrayStructFields(j @ JsonToStructs(ArrayType(schema: StructType, _), + _, _, _), _, ordinal, _, _) if schema.length > 1 && j.options.isEmpty => + // Obtain the pruned schema by picking the `ordinal` field of the struct. 
+ val prunedSchema = ArrayType(StructType(Array(schema(ordinal))), g.containsNull) g.copy(child = j.copy(schema = prunedSchema), ordinal = 0, numFields = 1) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala index c185de4c05d88..eed06da609f8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeJsonExprsSuite.scala @@ -307,4 +307,21 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper { comparePlans(optimized, query.analyze) } } + + test("SPARK-49743: prune unnecessary columns from GetArrayStructFields does not change schema") { + val options = Map.empty[String, String] + val schema = ArrayType(StructType.fromDDL("a int, b int"), containsNull = true) + + val field = StructField("A", IntegerType) // Instead of "a", use "A" to test case sensitivity. + val query = testRelation2 + .select(GetArrayStructFields( + JsonToStructs(schema, options, $"json"), field, 0, 2, true).as("a")) + val optimized = Optimizer.execute(query.analyze) + + val prunedSchema = ArrayType(StructType.fromDDL("a int"), containsNull = true) + val expected = testRelation2 + .select(GetArrayStructFields( + JsonToStructs(prunedSchema, options, $"json"), field, 0, 1, true).as("a")).analyze + comparePlans(optimized, expected) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 8176d02dbd02d..e3346684285a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4928,6 +4928,19 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark ) ) } + + test("SPARK-49743: OptimizeCsvJsonExpr does not change schema when pruning struct") { + val df = sql(""" + | SELECT + | from_json('[{"a": '||id||', "b": '|| (2*id) ||'}]', 'array>').a, + | from_json('[{"a": '||id||', "b": '|| (2*id) ||'}]', 'array>').A + | FROM + | range(3) as t + |""".stripMargin) + val expectedAnswer = Seq( + Row(Array(0), Array(0)), Row(Array(1), Array(1)), Row(Array(2), Array(2))) + checkAnswer(df, expectedAnswer) + } } case class Foo(bar: Option[String]) From e2d2ab510632cc1948cb6b4500e9da49036a96bd Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Wed, 25 Sep 2024 10:57:44 +0800 Subject: [PATCH 075/250] [SPARK-49552][PYTHON] Add DataFrame API support for new 'randstr' and 'uniform' SQL functions ### What changes were proposed in this pull request? In https://github.com/apache/spark/pull/48004 we added new SQL functions `randstr` and `uniform`. This PR adds DataFrame API support for them. For example, in Scala: ``` sql("create table t(col int not null) using csv") sql("insert into t values (0)") val df = sql("select col from t") df.select(randstr(lit(5), lit(0)).alias("x")).select(length(col("x"))) > 5 df.select(uniform(lit(10), lit(20), lit(0)).alias("x")).selectExpr("x > 5") > true ``` ### Why are the changes needed? This improves DataFrame parity with the SQL API. ### Does this PR introduce _any_ user-facing change? Yes, see above. ### How was this patch tested? This PR adds unit test coverage. ### Was this patch authored or co-authored using generative AI tooling? No. 
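For the Python side, a minimal sketch of the equivalent calls, adapted from the unit test added in this patch (it assumes an active `spark` session; the exact random values are not shown):

```
from pyspark.sql import functions as F

df = spark.createDataFrame([(0,)], ["a"])
# randstr(5, 0) always yields a 5-character string drawn from 0-9, a-z, A-Z.
df.select(F.randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)").show()
# uniform(10, 20, 0) always lies between 10 and 20, so the predicate below is true.
df.select(F.uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x")).selectExpr("x > 5").show()
# If both bounds are integers the result is an integer; a floating-point bound
# makes the result floating-point as well.
df.select(F.uniform(F.lit(10), F.lit(20.0), F.lit(0)).alias("y")).printSchema()
```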
Closes #48143 from dtenedor/dataframes-uniform-randstr. Authored-by: Daniel Tenedorio Signed-off-by: Ruifeng Zheng --- .../reference/pyspark.sql/functions.rst | 2 + .../pyspark/sql/connect/functions/builtin.py | 28 +++++ python/pyspark/sql/functions/builtin.py | 92 ++++++++++++++++ python/pyspark/sql/tests/test_functions.py | 21 +++- .../org/apache/spark/sql/functions.scala | 45 ++++++++ .../expressions/randomExpressions.scala | 49 +++++++-- .../spark/sql/DataFrameFunctionsSuite.scala | 104 ++++++++++++++++++ 7 files changed, 331 insertions(+), 10 deletions(-) diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst index 4910a5b59273b..6248e71331656 100644 --- a/python/docs/source/reference/pyspark.sql/functions.rst +++ b/python/docs/source/reference/pyspark.sql/functions.rst @@ -148,6 +148,7 @@ Mathematical Functions try_multiply try_subtract unhex + uniform width_bucket @@ -189,6 +190,7 @@ String Functions overlay position printf + randstr regexp_count regexp_extract regexp_extract_all diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 6953230f5b42e..27b12fff3c0ac 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -1007,6 +1007,22 @@ def unhex(col: "ColumnOrName") -> Column: unhex.__doc__ = pysparkfuncs.unhex.__doc__ +def uniform( + min: Union[Column, int, float], + max: Union[Column, int, float], + seed: Optional[Union[Column, int]] = None, +) -> Column: + if seed is None: + return _invoke_function_over_columns( + "uniform", lit(min), lit(max), lit(random.randint(0, sys.maxsize)) + ) + else: + return _invoke_function_over_columns("uniform", lit(min), lit(max), lit(seed)) + + +uniform.__doc__ = pysparkfuncs.uniform.__doc__ + + def approxCountDistinct(col: "ColumnOrName", rsd: Optional[float] = None) -> Column: warnings.warn("Deprecated in 3.4, use approx_count_distinct instead.", FutureWarning) return approx_count_distinct(col, rsd) @@ -2581,6 +2597,18 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: regexp_like.__doc__ = pysparkfuncs.regexp_like.__doc__ +def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column: + if seed is None: + return _invoke_function_over_columns( + "randstr", lit(length), lit(random.randint(0, sys.maxsize)) + ) + else: + return _invoke_function_over_columns("randstr", lit(length), lit(seed)) + + +randstr.__doc__ = pysparkfuncs.randstr.__doc__ + + def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: return _invoke_function_over_columns("regexp_count", str, regexp) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 09a286fe7c94e..4ca39562cb20b 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -11973,6 +11973,47 @@ def regexp_like(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: return _invoke_function_over_columns("regexp_like", str, regexp) +@_try_remote_functions +def randstr(length: Union[Column, int], seed: Optional[Union[Column, int]] = None) -> Column: + """Returns a string of the specified length whose characters are chosen uniformly at random from + the following pool of characters: 0-9, a-z, A-Z. The random seed is optional. The string length + must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively). + + .. 
versionadded:: 4.0.0 + + Parameters + ---------- + length : :class:`~pyspark.sql.Column` or int + Number of characters in the string to generate. + seed : :class:`~pyspark.sql.Column` or int + Optional random number seed to use. + + Returns + ------- + :class:`~pyspark.sql.Column` + The generated random string with the specified length. + + Examples + -------- + >>> spark.createDataFrame([('3',)], ['a']) \\ + ... .select(randstr(lit(5), lit(0)).alias('result')) \\ + ... .selectExpr("length(result) > 0").show() + +--------------------+ + |(length(result) > 0)| + +--------------------+ + | true| + +--------------------+ + """ + length = _enum_to_value(length) + length = lit(length) + if seed is None: + return _invoke_function_over_columns("randstr", length) + else: + seed = _enum_to_value(seed) + seed = lit(seed) + return _invoke_function_over_columns("randstr", length, seed) + + @_try_remote_functions def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: r"""Returns a count of the number of times that the Java regex pattern `regexp` is matched @@ -12339,6 +12380,57 @@ def unhex(col: "ColumnOrName") -> Column: return _invoke_function_over_columns("unhex", col) +@_try_remote_functions +def uniform( + min: Union[Column, int, float], + max: Union[Column, int, float], + seed: Optional[Union[Column, int]] = None, +) -> Column: + """Returns a random value with independent and identically distributed (i.i.d.) values with the + specified range of numbers. The random seed is optional. The provided numbers specifying the + minimum and maximum values of the range must be constant. If both of these numbers are integers, + then the result will also be an integer. Otherwise if one or both of these are floating-point + numbers, then the result will also be a floating-point number. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + min : :class:`~pyspark.sql.Column`, int, or float + Minimum value in the range. + max : :class:`~pyspark.sql.Column`, int, or float + Maximum value in the range. + seed : :class:`~pyspark.sql.Column` or int + Optional random number seed to use. + + Returns + ------- + :class:`~pyspark.sql.Column` + The generated random number within the specified range. + + Examples + -------- + >>> spark.createDataFrame([('3',)], ['a']) \\ + ... .select(uniform(lit(0), lit(10), lit(0)).alias('result')) \\ + ... .selectExpr("result < 15").show() + +-------------+ + |(result < 15)| + +-------------+ + | true| + +-------------+ + """ + min = _enum_to_value(min) + min = lit(min) + max = _enum_to_value(max) + max = lit(max) + if seed is None: + return _invoke_function_over_columns("uniform", min, max) + else: + seed = _enum_to_value(seed) + seed = lit(seed) + return _invoke_function_over_columns("uniform", min, max, seed) + + @_try_remote_functions def length(col: "ColumnOrName") -> Column: """Computes the character length of string data or number of bytes of binary data. 
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index a0ab9bc9c7d40..a51156e895c62 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -29,7 +29,7 @@ from pyspark.sql import Row, Window, functions as F, types from pyspark.sql.avro.functions import from_avro, to_avro from pyspark.sql.column import Column -from pyspark.sql.functions.builtin import nullifzero, zeroifnull +from pyspark.sql.functions.builtin import nullifzero, randstr, uniform, zeroifnull from pyspark.testing.sqlutils import ReusedSQLTestCase, SQLTestUtils from pyspark.testing.utils import have_numpy @@ -1610,6 +1610,25 @@ def test_nullifzero_zeroifnull(self): result = df.select(zeroifnull(df.a).alias("r")).collect() self.assertEqual([Row(r=0), Row(r=1)], result) + def test_randstr_uniform(self): + df = self.spark.createDataFrame([(0,)], ["a"]) + result = df.select(randstr(F.lit(5), F.lit(0)).alias("x")).selectExpr("length(x)").collect() + self.assertEqual([Row(5)], result) + # The random seed is optional. + result = df.select(randstr(F.lit(5)).alias("x")).selectExpr("length(x)").collect() + self.assertEqual([Row(5)], result) + + df = self.spark.createDataFrame([(0,)], ["a"]) + result = ( + df.select(uniform(F.lit(10), F.lit(20), F.lit(0)).alias("x")) + .selectExpr("x > 5") + .collect() + ) + self.assertEqual([Row(True)], result) + # The random seed is optional. + result = df.select(uniform(F.lit(10), F.lit(20)).alias("x")).selectExpr("x > 5").collect() + self.assertEqual([Row(True)], result) + class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin): pass diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index ab69789c75f50..93bff22621057 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -1896,6 +1896,26 @@ object functions { */ def randn(): Column = randn(SparkClassUtils.random.nextLong) + /** + * Returns a string of the specified length whose characters are chosen uniformly at random from + * the following pool of characters: 0-9, a-z, A-Z. The string length must be a constant + * two-byte or four-byte integer (SMALLINT or INT, respectively). + * + * @group string_funcs + * @since 4.0.0 + */ + def randstr(length: Column): Column = Column.fn("randstr", length) + + /** + * Returns a string of the specified length whose characters are chosen uniformly at random from + * the following pool of characters: 0-9, a-z, A-Z, with the chosen random seed. The string + * length must be a constant two-byte or four-byte integer (SMALLINT or INT, respectively). + * + * @group string_funcs + * @since 4.0.0 + */ + def randstr(length: Column, seed: Column): Column = Column.fn("randstr", length, seed) + /** * Partition ID. * @@ -3740,6 +3760,31 @@ object functions { */ def stack(cols: Column*): Column = Column.fn("stack", cols: _*) + /** + * Returns a random value with independent and identically distributed (i.i.d.) values with the + * specified range of numbers. The provided numbers specifying the minimum and maximum values of + * the range must be constant. If both of these numbers are integers, then the result will also + * be an integer. Otherwise if one or both of these are floating-point numbers, then the result + * will also be a floating-point number. 
+ * + * @group math_funcs + * @since 4.0.0 + */ + def uniform(min: Column, max: Column): Column = Column.fn("uniform", min, max) + + /** + * Returns a random value with independent and identically distributed (i.i.d.) values with the + * specified range of numbers, with the chosen random seed. The provided numbers specifying the + * minimum and maximum values of the range must be constant. If both of these numbers are + * integers, then the result will also be an integer. Otherwise if one or both of these are + * floating-point numbers, then the result will also be a floating-point number. + * + * @group math_funcs + * @since 4.0.0 + */ + def uniform(min: Column, max: Column, seed: Column): Column = + Column.fn("uniform", min, max, seed) + /** * Returns a random value with independent and identically distributed (i.i.d.) uniformly * distributed values in [0, 1). diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index f329f8346b0de..ada0a73a67958 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -206,15 +206,18 @@ object Randn { """, since = "4.0.0", group = "math_funcs") -case class Uniform(min: Expression, max: Expression, seedExpression: Expression) +case class Uniform(min: Expression, max: Expression, seedExpression: Expression, hideSeed: Boolean) extends RuntimeReplaceable with TernaryLike[Expression] with RDG { - def this(min: Expression, max: Expression) = this(min, max, UnresolvedSeed) + def this(min: Expression, max: Expression) = + this(min, max, UnresolvedSeed, hideSeed = true) + def this(min: Expression, max: Expression, seedExpression: Expression) = + this(min, max, seedExpression, hideSeed = false) final override lazy val deterministic: Boolean = false override val nodePatterns: Seq[TreePattern] = Seq(RUNTIME_REPLACEABLE, EXPRESSION_WITH_RANDOM_SEED) - override val dataType: DataType = { + override def dataType: DataType = { val first = min.dataType val second = max.dataType (min.dataType, max.dataType) match { @@ -240,6 +243,10 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression) case _ => false } + override def sql: String = { + s"uniform(${min.sql}, ${max.sql}${if (hideSeed) "" else s", ${seedExpression.sql}"})" + } + override def checkInputDataTypes(): TypeCheckResult = { var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess def requiredType = "integer or floating-point" @@ -277,11 +284,11 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression) override def third: Expression = seedExpression override def withNewSeed(newSeed: Long): Expression = - Uniform(min, max, Literal(newSeed, LongType)) + Uniform(min, max, Literal(newSeed, LongType), hideSeed) override def withNewChildrenInternal( newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = - Uniform(newFirst, newSecond, newThird) + Uniform(newFirst, newSecond, newThird, hideSeed) override def replacement: Expression = { if (Seq(min, max, seedExpression).exists(_.dataType == NullType)) { @@ -300,6 +307,13 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression) } } +object Uniform { + def apply(min: Expression, max: Expression): Uniform = + Uniform(min, max, UnresolvedSeed, hideSeed = true) + def 
apply(min: Expression, max: Expression, seedExpression: Expression): Uniform = + Uniform(min, max, seedExpression, hideSeed = false) +} + @ExpressionDescription( usage = """ _FUNC_(length[, seed]) - Returns a string of the specified length whose characters are chosen @@ -315,9 +329,13 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression) """, since = "4.0.0", group = "string_funcs") -case class RandStr(length: Expression, override val seedExpression: Expression) +case class RandStr( + length: Expression, override val seedExpression: Expression, hideSeed: Boolean) extends ExpressionWithRandomSeed with BinaryLike[Expression] with Nondeterministic { - def this(length: Expression) = this(length, UnresolvedSeed) + def this(length: Expression) = + this(length, UnresolvedSeed, hideSeed = true) + def this(length: Expression, seedExpression: Expression) = + this(length, seedExpression, hideSeed = false) override def nullable: Boolean = false override def dataType: DataType = StringType @@ -339,9 +357,14 @@ case class RandStr(length: Expression, override val seedExpression: Expression) rng = new XORShiftRandom(seed + partitionIndex) } - override def withNewSeed(newSeed: Long): Expression = RandStr(length, Literal(newSeed, LongType)) + override def withNewSeed(newSeed: Long): Expression = + RandStr(length, Literal(newSeed, LongType), hideSeed) override def withNewChildrenInternal(newFirst: Expression, newSecond: Expression): Expression = - RandStr(newFirst, newSecond) + RandStr(newFirst, newSecond, hideSeed) + + override def sql: String = { + s"randstr(${length.sql}${if (hideSeed) "" else s", ${seedExpression.sql}"})" + } override def checkInputDataTypes(): TypeCheckResult = { var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess @@ -422,3 +445,11 @@ case class RandStr(length: Expression, override val seedExpression: Expression) isNull = FalseLiteral) } } + +object RandStr { + def apply(length: Expression): RandStr = + RandStr(length, UnresolvedSeed, hideSeed = true) + def apply(length: Expression, seedExpression: Expression): RandStr = + RandStr(length, seedExpression, hideSeed = false) +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 0842b92e5d53c..016803635ff60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -411,6 +411,110 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(df.select(nvl2(col("b"), col("a"), col("c"))), Seq(Row(null))) } + test("randstr function") { + withTable("t") { + sql("create table t(col int not null) using csv") + sql("insert into t values (0)") + val df = sql("select col from t") + checkAnswer( + df.select(randstr(lit(5), lit(0)).alias("x")).select(length(col("x"))), + Seq(Row(5))) + // The random seed is optional. + checkAnswer( + df.select(randstr(lit(5)).alias("x")).select(length(col("x"))), + Seq(Row(5))) + } + // Here we exercise some error cases. 
+ val df = Seq((0)).toDF("a") + var expr = randstr(lit(10), lit("a")) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"randstr(10, a)\"", + "paramIndex" -> "second", + "inputSql" -> "\"a\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "INT or SMALLINT"), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "randstr", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) + expr = randstr(col("a"), lit(10)) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "inputName" -> "length", + "inputType" -> "INT or SMALLINT", + "inputExpr" -> "\"a\"", + "sqlExpr" -> "\"randstr(a, 10)\""), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "randstr", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) + } + + test("uniform function") { + withTable("t") { + sql("create table t(col int not null) using csv") + sql("insert into t values (0)") + val df = sql("select col from t") + checkAnswer( + df.select(uniform(lit(10), lit(20), lit(0)).alias("x")).selectExpr("x > 5"), + Seq(Row(true))) + // The random seed is optional. + checkAnswer( + df.select(uniform(lit(10), lit(20)).alias("x")).selectExpr("x > 5"), + Seq(Row(true))) + } + // Here we exercise some error cases. + val df = Seq((0)).toDF("a") + var expr = uniform(lit(10), lit("a")) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + parameters = Map( + "sqlExpr" -> "\"uniform(10, a)\"", + "paramIndex" -> "second", + "inputSql" -> "\"a\"", + "inputType" -> "\"STRING\"", + "requiredType" -> "integer or floating-point"), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "uniform", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) + expr = uniform(col("a"), lit(10)) + checkError( + intercept[AnalysisException](df.select(expr)), + condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", + parameters = Map( + "inputName" -> "min", + "inputType" -> "integer or floating-point", + "inputExpr" -> "\"a\"", + "sqlExpr" -> "\"uniform(a, 10)\""), + context = ExpectedContext( + contextType = QueryContextType.DataFrame, + fragment = "uniform", + objectType = "", + objectName = "", + callSitePattern = "", + startIndex = 0, + stopIndex = 0)) + } + test("zeroifnull function") { withTable("t") { // Here we exercise a non-nullable, non-foldable column. From 9aa11d1ee480498de58f0ebd660535effca8fcc6 Mon Sep 17 00:00:00 2001 From: Livia Zhu Date: Wed, 25 Sep 2024 12:42:01 +0900 Subject: [PATCH 076/250] [SPARK-49772][SS] Remove ColumnFamilyOptions and add configs directly to dbOptions in RocksDB ### What changes were proposed in this pull request? To reduce confusion from having vestigial `columnFamilyOptions` value, removed it and added the options directly to `dbOptions`. Also renamed `dbOptions` to `rocksDbOptions` ### Why are the changes needed? Refactoring to simplify and clarify the RocksDB options. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Updated and ensured that existing unit tests pass. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48232 from liviazhu-db/liviazhu-db/rocksdb-options. 
Lead-authored-by: Livia Zhu Co-authored-by: Jungtaek Lim Signed-off-by: Jungtaek Lim --- .../execution/streaming/state/RocksDB.scala | 39 +++++++++---------- .../streaming/state/RocksDBSuite.scala | 4 +- 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index 4a2aac43b3331..f8d0c8722c3f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -115,37 +115,34 @@ class RocksDB( tableFormatConfig.setPinL0FilterAndIndexBlocksInCache(true) } - private[state] val columnFamilyOptions = new ColumnFamilyOptions() + private[state] val rocksDbOptions = new Options() // options to open the RocksDB + + rocksDbOptions.setCreateIfMissing(true) // Set RocksDB options around MemTable memory usage. By default, we let RocksDB // use its internal default values for these settings. if (conf.writeBufferSizeMB > 0L) { - columnFamilyOptions.setWriteBufferSize(conf.writeBufferSizeMB * 1024 * 1024) + rocksDbOptions.setWriteBufferSize(conf.writeBufferSizeMB * 1024 * 1024) } if (conf.maxWriteBufferNumber > 0L) { - columnFamilyOptions.setMaxWriteBufferNumber(conf.maxWriteBufferNumber) + rocksDbOptions.setMaxWriteBufferNumber(conf.maxWriteBufferNumber) } - columnFamilyOptions.setCompressionType(getCompressionType(conf.compression)) - columnFamilyOptions.setMergeOperator(new StringAppendOperator()) - - private val dbOptions = - new Options(new DBOptions(), columnFamilyOptions) // options to open the RocksDB + rocksDbOptions.setCompressionType(getCompressionType(conf.compression)) - dbOptions.setCreateIfMissing(true) - dbOptions.setTableFormatConfig(tableFormatConfig) - dbOptions.setMaxOpenFiles(conf.maxOpenFiles) - dbOptions.setAllowFAllocate(conf.allowFAllocate) - dbOptions.setMergeOperator(new StringAppendOperator()) + rocksDbOptions.setTableFormatConfig(tableFormatConfig) + rocksDbOptions.setMaxOpenFiles(conf.maxOpenFiles) + rocksDbOptions.setAllowFAllocate(conf.allowFAllocate) + rocksDbOptions.setMergeOperator(new StringAppendOperator()) if (conf.boundedMemoryUsage) { - dbOptions.setWriteBufferManager(writeBufferManager) + rocksDbOptions.setWriteBufferManager(writeBufferManager) } private val dbLogger = createLogger() // for forwarding RocksDB native logs to log4j - dbOptions.setStatistics(new Statistics()) - private val nativeStats = dbOptions.statistics() + rocksDbOptions.setStatistics(new Statistics()) + private val nativeStats = rocksDbOptions.statistics() private val workingDir = createTempDir("workingDir") private val fileManager = new RocksDBFileManager(dfsRootDir, createTempDir("fileManager"), @@ -782,7 +779,7 @@ class RocksDB( readOptions.close() writeOptions.close() flushOptions.close() - dbOptions.close() + rocksDbOptions.close() dbLogger.close() synchronized { latestSnapshot.foreach(_.close()) @@ -941,7 +938,7 @@ class RocksDB( private def openDB(): Unit = { assert(db == null) - db = NativeRocksDB.open(dbOptions, workingDir.toString) + db = NativeRocksDB.open(rocksDbOptions, workingDir.toString) logInfo(log"Opened DB with conf ${MDC(LogKeys.CONFIG, conf)}") } @@ -962,7 +959,7 @@ class RocksDB( /** Create a native RocksDB logger that forwards native logs to log4j with correct log levels. 
*/ private def createLogger(): Logger = { - val dbLogger = new Logger(dbOptions.infoLogLevel()) { + val dbLogger = new Logger(rocksDbOptions.infoLogLevel()) { override def log(infoLogLevel: InfoLogLevel, logMsg: String) = { // Map DB log level to log4j levels // Warn is mapped to info because RocksDB warn is too verbose @@ -985,8 +982,8 @@ class RocksDB( dbLogger.setInfoLogLevel(dbLogLevel) // The log level set in dbLogger is effective and the one to dbOptions isn't applied to // customized logger. We still set it as it might show up in RocksDB config file or logging. - dbOptions.setInfoLogLevel(dbLogLevel) - dbOptions.setLogger(dbLogger) + rocksDbOptions.setInfoLogLevel(dbLogLevel) + rocksDbOptions.setLogger(dbLogger) logInfo(log"Set RocksDB native logging level to ${MDC(LogKeys.ROCKS_DB_LOG_LEVEL, dbLogLevel)}") dbLogger } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala index 608a22a284b6c..9fcd2001cce50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBSuite.scala @@ -526,12 +526,12 @@ class RocksDBSuite extends AlsoTestWithChangelogCheckpointingEnabled with Shared val conf = RocksDBConf().copy(compression = "zstd") withDB(remoteDir, conf = conf, useColumnFamilies = colFamiliesEnabled) { db => - assert(db.columnFamilyOptions.compressionType() == CompressionType.ZSTD_COMPRESSION) + assert(db.rocksDbOptions.compressionType() == CompressionType.ZSTD_COMPRESSION) } // Test the default is LZ4 withDB(remoteDir, conf = RocksDBConf().copy(), useColumnFamilies = colFamiliesEnabled) { db => - assert(db.columnFamilyOptions.compressionType() == CompressionType.LZ4_COMPRESSION) + assert(db.rocksDbOptions.compressionType() == CompressionType.LZ4_COMPRESSION) } } From 5134c68896738179d34e2220ac6171c317900f61 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Wed, 25 Sep 2024 14:29:12 +0900 Subject: [PATCH 077/250] [SPARK-49765][DOCS][PYTHON] Adjust documentation of "spark.sql.pyspark.plotting.max_rows" ### What changes were proposed in this pull request? Adjust documentation of "spark.sql.pyspark.plotting.max_rows". ### Why are the changes needed? Adjust for https://github.com/apache/spark/pull/48218, which eliminates the need for the "spark.sql.pyspark.plotting.sample_ratio" config. ### Does this PR introduce _any_ user-facing change? Doc change only. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48221 from xinrong-meng/conf_doc. Authored-by: Xinrong Meng Signed-off-by: Hyukjin Kwon --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c9c227a21cfff..9c46dd8e83ab2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3171,9 +3171,9 @@ object SQLConf { val PYSPARK_PLOT_MAX_ROWS = buildConf("spark.sql.pyspark.plotting.max_rows") - .doc( - "The visual limit on top-n-based plots. If set to 1000, the first 1000 data points " + - "will be used for plotting.") + .doc("The visual limit on plots. 
If set to 1000 for top-n-based plots (pie, bar, barh), " + + "the first 1000 data points will be used for plotting. For sampled-based plots " + + "(scatter, area, line), 1000 data points will be randomly sampled.") .version("4.0.0") .intConf .createWithDefault(1000) From 46c5accaa55101fe59bce916c17516a70fdfe134 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 25 Sep 2024 19:04:02 +0900 Subject: [PATCH 078/250] [SPARK-49609][PYTHON][TESTS][FOLLOW-UP] Skip Spark Connect tests if dependencies are not found ### What changes were proposed in this pull request? This PR is a followup of https://github.com/apache/spark/pull/48085 that skips the compatibility tests if Spark Connect dependencies are not installed. ### Why are the changes needed? To recover the PyPy3 build https://github.com/apache/spark/actions/runs/11016544408/job/30592416115 which does not have PyArrow installed. ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? Manually. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48239 from HyukjinKwon/SPARK-49609-followup. Authored-by: Hyukjin Kwon Signed-off-by: Haejoon Lee --- python/pyspark/sql/tests/test_connect_compatibility.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index ca1f828ef4d78..8f3e86f5186a8 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -18,6 +18,7 @@ import unittest import inspect +from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.sqlutils import ReusedSQLTestCase from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame @@ -172,6 +173,7 @@ def check_missing_methods(classic_cls, connect_cls, cls_name, expected_missing_m ) +@unittest.skipIf(not should_test_connect, connect_requirement_message) class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase): pass From 7f0ecd4221a7043b539fb20a792c00f379a5885e Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Wed, 25 Sep 2024 19:24:05 +0900 Subject: [PATCH 079/250] [SPARK-49764][PYTHON][CONNECT] Support area plots ### What changes were proposed in this pull request? Support area plots with plotly backend on both Spark Connect and Spark classic. ### Why are the changes needed? While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments. See more at [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) in progress. Part of https://issues.apache.org/jira/browse/SPARK-49530. ### Does this PR introduce _any_ user-facing change? Yes. Area plots are supported as shown below. ```py >>> from datetime import datetime >>> data = [ ... (3, 5, 20, datetime(2018, 1, 31)), ... (2, 5, 42, datetime(2018, 2, 28)), ... (3, 6, 28, datetime(2018, 3, 31)), ... 
(9, 12, 62, datetime(2018, 4, 30))] >>> columns = ["sales", "signups", "visits", "date"] >>> df = spark.createDataFrame(data, columns) >>> fig = df.plot.area(x="date", y=["sales", "signups", "visits"]) # df.plot(kind="area", x="date", y=["sales", "signups", "visits"]) >>> fig.show() ``` ![newplot (7)](https://github.com/user-attachments/assets/e603cd99-ce8b-4448-8e1f-cbc093097c45) ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48236 from xinrong-meng/plot_area. Authored-by: Xinrong Meng Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/plot/core.py | 35 +++++++++++++++++++ .../sql/tests/plot/test_frame_plot_plotly.py | 35 +++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index 0a3a0101e1898..9f83d00696524 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -93,6 +93,7 @@ def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame": class PySparkPlotAccessor: plot_data_map = { + "area": PySparkSampledPlotBase().get_sampled, "bar": PySparkTopNPlotBase().get_top_n, "barh": PySparkTopNPlotBase().get_top_n, "line": PySparkSampledPlotBase().get_sampled, @@ -264,3 +265,37 @@ def scatter(self, x: str, y: str, **kwargs: Any) -> "Figure": >>> df.plot.scatter(x='length', y='width') # doctest: +SKIP """ return self(kind="scatter", x=x, y=y, **kwargs) + + def area(self, x: str, y: str, **kwargs: Any) -> "Figure": + """ + Draw a stacked area plot. + + An area plot displays quantitative data visually. + + Parameters + ---------- + x : str + Name of column to use for the horizontal axis. + y : str or list of str + Name(s) of the column(s) to plot. + **kwargs: Optional + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> from datetime import datetime + >>> data = [ + ... (3, 5, 20, datetime(2018, 1, 31)), + ... (2, 5, 42, datetime(2018, 2, 28)), + ... (3, 6, 28, datetime(2018, 3, 31)), + ... (9, 12, 62, datetime(2018, 4, 30)) + ... 
] + >>> columns = ["sales", "signups", "visits", "date"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.area(x='date', y=['sales', 'signups', 'visits']) # doctest: +SKIP + """ + return self(kind="area", x=x, y=y, **kwargs) diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py index ccfe1a75424e0..6176525b49550 100644 --- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -16,6 +16,8 @@ # import unittest +from datetime import datetime + import pyspark.sql.plot # noqa: F401 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message @@ -34,6 +36,17 @@ def sdf2(self): columns = ["length", "width", "species"] return self.spark.createDataFrame(data, columns) + @property + def sdf3(self): + data = [ + (3, 5, 20, datetime(2018, 1, 31)), + (2, 5, 42, datetime(2018, 2, 28)), + (3, 6, 28, datetime(2018, 3, 31)), + (9, 12, 62, datetime(2018, 4, 30)), + ] + columns = ["sales", "signups", "visits", "date"] + return self.spark.createDataFrame(data, columns) + def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=""): if kind == "line": self.assertEqual(fig_data["mode"], "lines") @@ -46,6 +59,11 @@ def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name= elif kind == "scatter": self.assertEqual(fig_data["type"], "scatter") self.assertEqual(fig_data["orientation"], "v") + self.assertEqual(fig_data["mode"], "markers") + elif kind == "area": + self.assertEqual(fig_data["type"], "scatter") + self.assertEqual(fig_data["orientation"], "v") + self.assertEqual(fig_data["mode"], "lines") self.assertEqual(fig_data["xaxis"], "x") self.assertEqual(list(fig_data["x"]), expected_x) @@ -98,6 +116,23 @@ def test_scatter_plot(self): "scatter", fig["data"][0], [3.5, 3.0, 3.2, 3.2, 3.0], [5.1, 4.9, 7.0, 6.4, 5.9] ) + def test_area_plot(self): + # single column as vertical axis + fig = self.sdf3.plot(kind="area", x="date", y="sales") + expected_x = [ + datetime(2018, 1, 31, 0, 0), + datetime(2018, 2, 28, 0, 0), + datetime(2018, 3, 31, 0, 0), + datetime(2018, 4, 30, 0, 0), + ] + self._check_fig_data("area", fig["data"][0], expected_x, [3, 2, 3, 9]) + + # multiple columns as vertical axis + fig = self.sdf3.plot.area(x="date", y=["sales", "signups", "visits"]) + self._check_fig_data("area", fig["data"][0], expected_x, [3, 2, 3, 9], "sales") + self._check_fig_data("area", fig["data"][1], expected_x, [5, 5, 6, 12], "signups") + self._check_fig_data("area", fig["data"][2], expected_x, [20, 42, 28, 62], "visits") + class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase): pass From e1b2ac55b4b9463824d3f23eb7fbac88ede843d9 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 25 Sep 2024 19:25:47 +0900 Subject: [PATCH 080/250] [SPARK-49767][PS][CONNECT] Refactor the internal function invocation ### What changes were proposed in this pull request? Refactor the internal function invocation ### Why are the changes needed? by introducing a new helper function `_invoke_internal_function_over_columns`, we no longer need to add dedicated internal functions in `PythonSQLUtils` ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #48227 from zhengruifeng/py_fn. 
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/pyspark/pandas/plot/core.py | 2 +- python/pyspark/pandas/spark/functions.py | 175 +++--------------- python/pyspark/pandas/window.py | 3 +- .../spark/sql/api/python/PythonSQLUtils.scala | 43 +---- .../spark/sql/DataFrameSelfJoinSuite.scala | 3 +- .../org/apache/spark/sql/DataFrameSuite.scala | 3 +- 6 files changed, 33 insertions(+), 196 deletions(-) diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index 6f036b7669246..7333fae1ad432 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -215,7 +215,7 @@ def compute_hist(psdf, bins): # refers to org.apache.spark.ml.feature.Bucketizer#binarySearchForBuckets def binary_search_for_buckets(value: Column): - index = SF.binary_search(F.lit(bins), value) + index = SF.array_binary_search(F.lit(bins), value) bucket = F.when(index >= 0, index).otherwise(-index - 2) unboundErrMsg = F.lit(f"value %s out of the bins bounds: [{bins[0]}, {bins[-1]}]") return ( diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py index 4bcf07f6f6503..4d95466a98e12 100644 --- a/python/pyspark/pandas/spark/functions.py +++ b/python/pyspark/pandas/spark/functions.py @@ -19,197 +19,72 @@ """ from pyspark.sql import Column, functions as F from pyspark.sql.utils import is_remote -from typing import Union +from typing import Union, TYPE_CHECKING +if TYPE_CHECKING: + from pyspark.sql._typing import ColumnOrName -def product(col: Column, dropna: bool) -> Column: + +def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit + from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns - return _invoke_function_over_columns( - "pandas_product", - col, - lit(dropna), - ) + return _invoke_function_over_columns(name, *cols) else: + from pyspark.sql.classic.column import _to_seq, _to_java_column from pyspark import SparkContext sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.pandasProduct(col._jc, dropna)) + return Column(sc._jvm.PythonSQLUtils.internalFn(name, _to_seq(sc, cols, _to_java_column))) -def stddev(col: Column, ddof: int) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit - - return _invoke_function_over_columns( - "pandas_stddev", - col, - lit(ddof), - ) +def product(col: Column, dropna: bool) -> Column: + return _invoke_internal_function_over_columns("pandas_product", col, F.lit(dropna)) - else: - from pyspark import SparkContext - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.pandasStddev(col._jc, ddof)) +def stddev(col: Column, ddof: int) -> Column: + return _invoke_internal_function_over_columns("pandas_stddev", col, F.lit(ddof)) def var(col: Column, ddof: int) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit - - return _invoke_function_over_columns( - "pandas_var", - col, - lit(ddof), - ) - - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.pandasVariance(col._jc, ddof)) + return _invoke_internal_function_over_columns("pandas_var", col, F.lit(ddof)) def skew(col: Column) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import 
_invoke_function_over_columns - - return _invoke_function_over_columns( - "pandas_skew", - col, - ) - - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.pandasSkewness(col._jc)) + return _invoke_internal_function_over_columns("pandas_skew", col) def kurt(col: Column) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns - - return _invoke_function_over_columns( - "pandas_kurt", - col, - ) - - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.pandasKurtosis(col._jc)) + return _invoke_internal_function_over_columns("pandas_kurt", col) def mode(col: Column, dropna: bool) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit - - return _invoke_function_over_columns( - "pandas_mode", - col, - lit(dropna), - ) - - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.pandasMode(col._jc, dropna)) + return _invoke_internal_function_over_columns("pandas_mode", col, F.lit(dropna)) def covar(col1: Column, col2: Column, ddof: int) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit + return _invoke_internal_function_over_columns("pandas_covar", col1, col2, F.lit(ddof)) - return _invoke_function_over_columns( - "pandas_covar", - col1, - col2, - lit(ddof), - ) - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.pandasCovar(col1._jc, col2._jc, ddof)) - - -def ewm(col: Column, alpha: float, ignore_na: bool) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns, lit - - return _invoke_function_over_columns( - "ewm", - col, - lit(alpha), - lit(ignore_na), - ) - - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.ewm(col._jc, alpha, ignore_na)) +def ewm(col: Column, alpha: float, ignorena: bool) -> Column: + return _invoke_internal_function_over_columns("ewm", col, F.lit(alpha), F.lit(ignorena)) def null_index(col: Column) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns - - return _invoke_function_over_columns( - "null_index", - col, - ) - - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc)) + return _invoke_internal_function_over_columns("null_index", col) def distributed_sequence_id() -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function - - return _invoke_function("distributed_sequence_id") - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.distributed_sequence_id()) + return _invoke_internal_function_over_columns("distributed_sequence_id") def collect_top_k(col: Column, num: int, reverse: bool) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns + return _invoke_internal_function_over_columns("collect_top_k", col, F.lit(num), F.lit(reverse)) - return _invoke_function_over_columns("collect_top_k", col, F.lit(num), F.lit(reverse)) - else: - from pyspark import SparkContext - - sc = 
SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.collect_top_k(col._jc, num, reverse)) - - -def binary_search(col: Column, value: Column) -> Column: - if is_remote(): - from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns - - return _invoke_function_over_columns("array_binary_search", col, value) - - else: - from pyspark import SparkContext - - sc = SparkContext._active_spark_context - return Column(sc._jvm.PythonSQLUtils.binary_search(col._jc, value._jc)) +def array_binary_search(col: Column, value: Column) -> Column: + return _invoke_internal_function_over_columns("array_binary_search", col, value) def make_interval(unit: str, e: Union[Column, int, float]) -> Column: diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py index 0aaeb7df89be5..fb5dd29169e91 100644 --- a/python/pyspark/pandas/window.py +++ b/python/pyspark/pandas/window.py @@ -2434,7 +2434,8 @@ def _compute_unified_alpha(self) -> float: if opt_count != 1: raise ValueError("com, span, halflife, and alpha are mutually exclusive") - return unified_alpha + # convert possible numpy.float64 to float for lit function + return float(unified_alpha) @abstractmethod def _apply_as_series_or_frame(self, func: Callable[[Column], Column]) -> FrameLike: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index bc270e6ac64ad..3504f6e76f79d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.execution.{ExplainMode, QueryExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} @@ -147,45 +146,6 @@ private[sql] object PythonSQLUtils extends Logging { def castTimestampNTZToLong(c: Column): Column = Column.internalFn("timestamp_ntz_to_long", c) - def ewm(e: Column, alpha: Double, ignoreNA: Boolean): Column = - Column.internalFn("ewm", e, lit(alpha), lit(ignoreNA)) - - def nullIndex(e: Column): Column = Column.internalFn("null_index", e) - - def collect_top_k(e: Column, num: Int, reverse: Boolean): Column = - Column.internalFn("collect_top_k", e, lit(num), lit(reverse)) - - def binary_search(e: Column, value: Column): Column = - Column.internalFn("array_binary_search", e, value) - - def pandasProduct(e: Column, ignoreNA: Boolean): Column = - Column.internalFn("pandas_product", e, lit(ignoreNA)) - - def pandasStddev(e: Column, ddof: Int): Column = - Column.internalFn("pandas_stddev", e, lit(ddof)) - - def pandasVariance(e: Column, ddof: Int): Column = - Column.internalFn("pandas_var", e, lit(ddof)) - - def pandasSkewness(e: Column): Column = - Column.internalFn("pandas_skew", e) - - def pandasKurtosis(e: Column): Column = - Column.internalFn("pandas_kurt", e) - - def pandasMode(e: Column, ignoreNA: Boolean): Column = - Column.internalFn("pandas_mode", e, lit(ignoreNA)) - - def pandasCovar(col1: Column, col2: Column, ddof: Int): Column = - Column.internalFn("pandas_covar", col1, col2, lit(ddof)) - - /** - * A long column that increases one by one. 
- * This is for 'distributed-sequence' default index in pandas API on Spark. - */ - def distributed_sequence_id(): Column = - Column.internalFn("distributed_sequence_id") - def unresolvedNamedLambdaVariable(name: String): Column = Column(internal.UnresolvedNamedLambdaVariable.apply(name)) @@ -205,6 +165,9 @@ private[sql] object PythonSQLUtils extends Logging { @scala.annotation.varargs def fn(name: String, arguments: Column*): Column = Column.fn(name, arguments: _*) + + @scala.annotation.varargs + def internalFn(name: String, inputs: Column*): Column = Column.internalFn(name, inputs: _*) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala index 1d7698df2f1be..f0ed2241fd286 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import org.apache.spark.api.python.PythonEvalType -import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan} import org.apache.spark.sql.expressions.Window @@ -405,7 +404,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession { assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y"))) // Test for AttachDistributedSequence - val df13 = df1.select(distributed_sequence_id().alias("seq"), col("*")) + val df13 = df1.select(Column.internalFn("distributed_sequence_id").alias("seq"), col("*")) val df14 = df13.filter($"value" === "A2") assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2"))) assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2"))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e1774cab4a0de..2c0d9e29bb273 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -29,7 +29,6 @@ import org.scalatest.matchers.should.Matchers._ import org.apache.spark.SparkException import org.apache.spark.api.python.PythonEvalType -import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -2318,7 +2317,7 @@ class DataFrameSuite extends QueryTest test("SPARK-36338: DataFrame.withSequenceColumn should append unique sequence IDs") { val ids = spark.range(10).repartition(5).select( - distributed_sequence_id().alias("default_index"), col("id")) + Column.internalFn("distributed_sequence_id").alias("default_index"), col("id")) assert(ids.collect().map(_.getLong(0)).toSet === Range(0, 10).toSet) assert(ids.take(5).map(_.getLong(0)).toSet === Range(0, 5).toSet) } From c362d500acba5bcf476a2a91ac9b7441ba1e7e2d Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 25 Sep 2024 19:35:33 +0900 Subject: [PATCH 081/250] [SPARK-49775][SQL][TESTS] Make tests of `INVALID_PARAMETER_VALUE.CHARSET` deterministic ### What changes were proposed in this pull request? 
Make tests of `INVALID_PARAMETER_VALUE.CHARSET` deterministic ### Why are the changes needed? `VALID_CHARSETS` is a Set, so `VALID_CHARSETS.mkString(", ")` is non-deterministic, and cause failures in different testing environments, e.g. ``` org.scalatest.exceptions.TestFailedException: ansi/string-functions.sql Expected "...sets" : "UTF-16LE, U[TF-8, UTF-32, UTF-16BE, UTF-16, US-ASCII, ISO-8859-1]", "functionName...", but got "...sets" : "UTF-16LE, U[S-ASCII, ISO-8859-1, UTF-8, UTF-32, UTF-16BE, UTF-16]", "functionName..." Result did not match for query #93 select encode('hello', 'WINDOWS-1252') at org.scalatest.Assertions.newAssertionFailedException(Assertions.scala:472) at org.scalatest.Assertions.newAssertionFailedException$(Assertions.scala:471) at org.scalatest.funsuite.AnyFunSuite.newAssertionFailedException(AnyFunSuite.scala:1564) at org.scalatest.Assertions.assertResult(Assertions.scala:847) at org.scalatest.Assertions.assertResult$(Assertions.scala:842) at org.scalatest.funsuite.AnyFunSuite.assertResult(AnyFunSuite.scala:1564) ``` ### Does this PR introduce _any_ user-facing change? No, test only ### How was this patch tested? updated tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48235 from zhengruifeng/sql_test_sort_charset. Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- .../sql/catalyst/util/CharsetProvider.scala | 2 +- .../results/ansi/string-functions.sql.out | 16 ++++++++-------- .../sql-tests/results/string-functions.sql.out | 16 ++++++++-------- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharsetProvider.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharsetProvider.scala index 0e7fca24e1374..d85673f2ce811 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharsetProvider.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharsetProvider.scala @@ -24,7 +24,7 @@ private[sql] object CharsetProvider { final lazy val VALID_CHARSETS = - Set("us-ascii", "iso-8859-1", "utf-8", "utf-16be", "utf-16le", "utf-16", "utf-32") + Array("us-ascii", "iso-8859-1", "utf-8", "utf-16be", "utf-16le", "utf-16", "utf-32").sorted def forName( charset: String, diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out index cf1bce3c0e504..706673606625b 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/string-functions.sql.out @@ -842,7 +842,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -860,7 +860,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -878,7 +878,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "Windows-xxx", - "charsets" : 
"utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -896,7 +896,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "Windows-xxx", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -1140,7 +1140,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "Windows-xxx", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } @@ -1158,7 +1158,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "Windows-xxx", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } @@ -1208,7 +1208,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } @@ -1226,7 +1226,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 14d7b31f8c63f..3f9f24f817f2c 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -778,7 +778,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -796,7 +796,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -814,7 +814,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "Windows-xxx", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -832,7 +832,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", 
"messageParameters" : { "charset" : "Windows-xxx", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`encode`", "parameter" : "`charset`" } @@ -1076,7 +1076,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "Windows-xxx", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } @@ -1094,7 +1094,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "Windows-xxx", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } @@ -1144,7 +1144,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } @@ -1162,7 +1162,7 @@ org.apache.spark.SparkIllegalArgumentException "sqlState" : "22023", "messageParameters" : { "charset" : "WINDOWS-1252", - "charsets" : "utf-8, utf-16be, iso-8859-1, utf-16le, utf-16, utf-32, us-ascii", + "charsets" : "iso-8859-1, us-ascii, utf-16, utf-16be, utf-16le, utf-32, utf-8", "functionName" : "`decode`", "parameter" : "`charset`" } From 0ccf53ae6faabc4420317d379da77a299794c84c Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 25 Sep 2024 19:21:36 +0800 Subject: [PATCH 082/250] [SPARK-49609][PYTHON][FOLLOWUP] Correct the typehint for `filter` and `where` ### What changes were proposed in this pull request? Correct the typehint for `filter` and `where` ### Why are the changes needed? the input `str` should not be treated as column name ### Does this PR introduce _any_ user-facing change? doc change ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #48244 from zhengruifeng/py_filter_where. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/sql/classic/dataframe.py | 2 +- python/pyspark/sql/connect/dataframe.py | 2 +- python/pyspark/sql/dataframe.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index 23484fcf0051f..0dd66a9d86545 100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -1787,7 +1787,7 @@ def semanticHash(self) -> int: def inputFiles(self) -> List[str]: return list(self._jdf.inputFiles()) - def where(self, condition: "ColumnOrName") -> ParentDataFrame: + def where(self, condition: Union[Column, str]) -> ParentDataFrame: return self.filter(condition) # Two aliases below were added for pandas compatibility many years ago. 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index cb37af8868aad..146cfe11bc502 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -1260,7 +1260,7 @@ def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame: res._cached_schema = self._merge_cached_schema(other) return res - def where(self, condition: "ColumnOrName") -> ParentDataFrame: + def where(self, condition: Union[Column, str]) -> ParentDataFrame: if not isinstance(condition, (str, Column)): raise PySparkTypeError( errorClass="NOT_COLUMN_OR_STR", diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 2179a844b1e5e..142034583dbd2 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -3351,7 +3351,7 @@ def selectExpr(self, *expr: Union[str, List[str]]) -> "DataFrame": ... @dispatch_df_method - def filter(self, condition: "ColumnOrName") -> "DataFrame": + def filter(self, condition: Union[Column, str]) -> "DataFrame": """Filters rows using the given condition. :func:`where` is an alias for :func:`filter`. @@ -5902,7 +5902,7 @@ def inputFiles(self) -> List[str]: ... @dispatch_df_method - def where(self, condition: "ColumnOrName") -> "DataFrame": + def where(self, condition: Union[Column, str]) -> "DataFrame": """ :func:`where` is an alias for :func:`filter`. From d23023202185f9fd175059caf7499251848c0758 Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Wed, 25 Sep 2024 22:41:26 +0900 Subject: [PATCH 083/250] [SPARK-49745][SS] Add change to read registered timers through state data source reader ### What changes were proposed in this pull request? Add change to read registered timers through state data source reader ### Why are the changes needed? Without this, users cannot read registered timers per grouping key within the transformWithState operator ### Does this PR introduce _any_ user-facing change? Yes Users can now read registered timers using the following query: ``` val stateReaderDf = spark.read .format("statestore") .option(StateSourceOptions.PATH, ) .option(StateSourceOptions.READ_REGISTERED_TIMERS, true) .load() ``` ### How was this patch tested? Added unit tests ``` [info] Run completed in 20 seconds, 834 milliseconds. [info] Total number of tests run: 4 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 4, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #48205 from anishshri-db/task/SPARK-49745. 
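For reference, a rough PySpark sketch of the same read path (the checkpoint location is a placeholder, and `key.value` assumes a string grouping key, mirroring the Scala tests added below):

```py
# Placeholder checkpoint location of a transformWithState query.
timers_df = (
    spark.read.format("statestore")
    .option("readRegisteredTimers", True)
    .load("/path/to/checkpoint")
)

# Each timer row carries the grouping key, its expiry timestamp and the partition id.
timers_df.selectExpr(
    "key.value AS groupingKey",
    "expiration_timestamp_ms AS expiryTimestamp",
    "partition_id",
).show()
```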
Lead-authored-by: Anish Shrigondekar Co-authored-by: Jungtaek Lim Signed-off-by: Jungtaek Lim --- .../v2/state/StateDataSource.scala | 50 ++++++-- .../v2/state/StatePartitionReader.scala | 5 +- .../v2/state/utils/SchemaUtil.scala | 33 ++++++ .../StateStoreColumnFamilySchemaUtils.scala | 12 ++ .../streaming/StateTypesEncoderUtils.scala | 3 + .../StatefulProcessorHandleImpl.scala | 16 +++ .../execution/streaming/TimerStateImpl.scala | 9 ++ .../TransformWithStateVariableUtils.scala | 6 +- .../v2/state/StateDataSourceReadSuite.scala | 19 +++ ...ateDataSourceTransformWithStateSuite.scala | 109 +++++++++++++++++- .../TransformWithValueStateTTLSuite.scala | 21 +++- 11 files changed, 263 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 429464ea5438d..39bc4dd9fb9c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -29,15 +29,16 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.DataSourceOptions import org.apache.spark.sql.connector.catalog.{Table, TableProvider} import org.apache.spark.sql.connector.expressions.Transform -import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, STATE_VAR_NAME} +import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, READ_REGISTERED_TIMERS, STATE_VAR_NAME} import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues import org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader, StateMetadataTableEntry} import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil -import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata, TransformWithStateOperatorProperties, TransformWithStateVariableInfo} +import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata, TimerStateUtils, TransformWithStateOperatorProperties, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.streaming.TimeMode import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -132,7 +133,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging sourceOptions: StateSourceOptions, stateStoreMetadata: Array[StateMetadataTableEntry]): Unit = { val twsShortName = "transformWithStateExec" - if (sourceOptions.stateVarName.isDefined) { + if (sourceOptions.stateVarName.isDefined || sourceOptions.readRegisteredTimers) { // Perform checks for transformWithState operator in case state variable name is provided 
require(stateStoreMetadata.size == 1) val opMetadata = stateStoreMetadata.head @@ -153,10 +154,21 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging "No state variable names are defined for the transformWithState operator") } + val twsOperatorProperties = TransformWithStateOperatorProperties.fromJson(operatorProperties) + val timeMode = twsOperatorProperties.timeMode + if (sourceOptions.readRegisteredTimers && timeMode == TimeMode.None().toString) { + throw StateDataSourceErrors.invalidOptionValue(READ_REGISTERED_TIMERS, + "Registered timers are not available in TimeMode=None.") + } + // if the state variable is not one of the defined/available state variables, then we // fail the query - val stateVarName = sourceOptions.stateVarName.get - val twsOperatorProperties = TransformWithStateOperatorProperties.fromJson(operatorProperties) + val stateVarName = if (sourceOptions.readRegisteredTimers) { + TimerStateUtils.getTimerStateVarName(timeMode) + } else { + sourceOptions.stateVarName.get + } + val stateVars = twsOperatorProperties.stateVariables if (stateVars.filter(stateVar => stateVar.stateName == stateVarName).size != 1) { throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, @@ -196,9 +208,10 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging var keyStateEncoderSpecOpt: Option[KeyStateEncoderSpec] = None var stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema] = None var transformWithStateVariableInfoOpt: Option[TransformWithStateVariableInfo] = None + var timeMode: String = TimeMode.None.toString if (sourceOptions.joinSide == JoinSideValues.none) { - val stateVarName = sourceOptions.stateVarName + var stateVarName = sourceOptions.stateVarName .getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME) // Read the schema file path from operator metadata version v2 onwards @@ -208,6 +221,12 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging val storeMetadataEntry = storeMetadata.head val operatorProperties = TransformWithStateOperatorProperties.fromJson( storeMetadataEntry.operatorPropertiesJson) + timeMode = operatorProperties.timeMode + + if (sourceOptions.readRegisteredTimers) { + stateVarName = TimerStateUtils.getTimerStateVarName(timeMode) + } + val stateVarInfoList = operatorProperties.stateVariables .filter(stateVar => stateVar.stateName == stateVarName) require(stateVarInfoList.size == 1, s"Failed to find unique state variable info " + @@ -304,6 +323,7 @@ case class StateSourceOptions( fromSnapshotOptions: Option[FromSnapshotOptions], readChangeFeedOptions: Option[ReadChangeFeedOptions], stateVarName: Option[String], + readRegisteredTimers: Boolean, flattenCollectionTypes: Boolean) { def stateCheckpointLocation: Path = new Path(resolvedCpLocation, DIR_NAME_STATE) @@ -336,6 +356,7 @@ object StateSourceOptions extends DataSourceOptions { val CHANGE_START_BATCH_ID = newOption("changeStartBatchId") val CHANGE_END_BATCH_ID = newOption("changeEndBatchId") val STATE_VAR_NAME = newOption("stateVarName") + val READ_REGISTERED_TIMERS = newOption("readRegisteredTimers") val FLATTEN_COLLECTION_TYPES = newOption("flattenCollectionTypes") object JoinSideValues extends Enumeration { @@ -377,6 +398,19 @@ object StateSourceOptions extends DataSourceOptions { val stateVarName = Option(options.get(STATE_VAR_NAME)) .map(_.trim) + val readRegisteredTimers = try { + Option(options.get(READ_REGISTERED_TIMERS)) + .map(_.toBoolean).getOrElse(false) + } catch { + case _: IllegalArgumentException => + 
throw StateDataSourceErrors.invalidOptionValue(READ_REGISTERED_TIMERS, + "Boolean value is expected") + } + + if (readRegisteredTimers && stateVarName.isDefined) { + throw StateDataSourceErrors.conflictOptions(Seq(READ_REGISTERED_TIMERS, STATE_VAR_NAME)) + } + val flattenCollectionTypes = try { Option(options.get(FLATTEN_COLLECTION_TYPES)) .map(_.toBoolean).getOrElse(true) @@ -489,8 +523,8 @@ object StateSourceOptions extends DataSourceOptions { StateSourceOptions( resolvedCpLocation, batchId.get, operatorId, storeName, joinSide, - readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, stateVarName, - flattenCollectionTypes) + readChangeFeed, fromSnapshotOptions, readChangeFeedOptions, + stateVarName, readRegisteredTimers, flattenCollectionTypes) } private def resolvedCheckpointLocation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index ae12b18c1f627..d77d97f0057fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -107,6 +107,8 @@ abstract class StatePartitionReaderBase( useColumnFamilies = useColFamilies, storeConf, hadoopConf.value, useMultipleValuesPerKey = useMultipleValuesPerKey) + val isInternal = partition.sourceOptions.readRegisteredTimers + if (useColFamilies) { val store = provider.getStore(partition.sourceOptions.batchId + 1) require(stateStoreColFamilySchemaOpt.isDefined) @@ -117,7 +119,8 @@ abstract class StatePartitionReaderBase( stateStoreColFamilySchema.keySchema, stateStoreColFamilySchema.valueSchema, stateStoreColFamilySchema.keyStateEncoderSpec.get, - useMultipleValuesPerKey = useMultipleValuesPerKey) + useMultipleValuesPerKey = useMultipleValuesPerKey, + isInternal = isInternal) } provider } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala index dc0d6af951143..c337d548fa42b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/utils/SchemaUtil.scala @@ -230,6 +230,7 @@ object SchemaUtil { "map_value" -> classOf[MapType], "user_map_key" -> classOf[StructType], "user_map_value" -> classOf[StructType], + "expiration_timestamp_ms" -> classOf[LongType], "partition_id" -> classOf[IntegerType]) val expectedFieldNames = if (sourceOptions.readChangeFeed) { @@ -256,6 +257,9 @@ object SchemaUtil { Seq("key", "map_value", "partition_id") } + case TimerState => + Seq("key", "expiration_timestamp_ms", "partition_id") + case _ => throw StateDataSourceErrors .internalError(s"Unsupported state variable type $stateVarType") @@ -322,6 +326,14 @@ object SchemaUtil { .add("partition_id", IntegerType) } + case TimerState => + val groupingKeySchema = SchemaUtil.getSchemaAsDataType( + stateStoreColFamilySchema.keySchema, "key") + new StructType() + .add("key", groupingKeySchema) + .add("expiration_timestamp_ms", LongType) + .add("partition_id", IntegerType) + case _ => throw StateDataSourceErrors.internalError(s"Unsupported state variable type $stateVarType") } @@ -407,9 +419,30 @@ object SchemaUtil { 
unifyMapStateRowPair(store.iterator(stateVarName), compositeKeySchema, partitionId, stateSourceOptions) + case StateVariableType.TimerState => + store + .iterator(stateVarName) + .map { pair => + unifyTimerRow(pair.key, compositeKeySchema, partitionId) + } + case _ => throw new IllegalStateException( s"Unsupported state variable type: $stateVarType") } } + + private def unifyTimerRow( + rowKey: UnsafeRow, + groupingKeySchema: StructType, + partitionId: Int): InternalRow = { + val groupingKey = rowKey.get(0, groupingKeySchema).asInstanceOf[UnsafeRow] + val expirationTimestamp = rowKey.getLong(1) + + val row = new GenericInternalRow(3) + row.update(0, groupingKey) + row.update(1, expirationTimestamp) + row.update(2, partitionId) + row + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala index 99229c6132eb2..7da8408f98b0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateStoreColumnFamilySchemaUtils.scala @@ -20,6 +20,7 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._ import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateStoreColFamilySchema} +import org.apache.spark.sql.types.StructType object StateStoreColumnFamilySchemaUtils { @@ -61,4 +62,15 @@ object StateStoreColumnFamilySchemaUtils { Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)), Some(userKeyEnc.schema)) } + + def getTimerStateSchema( + stateName: String, + keySchema: StructType, + valSchema: StructType): StateStoreColFamilySchema = { + StateStoreColFamilySchema( + stateName, + keySchema, + valSchema, + Some(PrefixKeyScanStateEncoderSpec(keySchema, 1))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala index 1f5ad2fc85470..b70f9699195d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StateTypesEncoderUtils.scala @@ -288,6 +288,9 @@ class TimerKeyEncoder(keyExprEnc: ExpressionEncoder[Any]) { .add("key", new StructType(keyExprEnc.schema.fields)) .add("expiryTimestampMs", LongType, nullable = false) + val schemaForValueRow: StructType = + StructType(Array(StructField("__dummy__", NullType))) + private val keySerializer = keyExprEnc.createSerializer() private val keyDeserializer = keyExprEnc.resolveAndBind().createDeserializer() private val prefixKeyProjection = UnsafeProjection.create(schemaForPrefixKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 942d395dec0e2..8beacbec7e6ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -308,6 +308,12 @@ class 
DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi private val stateVariableInfos: mutable.Map[String, TransformWithStateVariableInfo] = new mutable.HashMap[String, TransformWithStateVariableInfo]() + // If timeMode is not None, add a timer column family schema to the operator metadata so that + // registered timers can be read using the state data source reader. + if (timeMode != TimeMode.None()) { + addTimerColFamily() + } + def getColumnFamilySchemas: Map[String, StateStoreColFamilySchema] = columnFamilySchemas.toMap def getStateVariableInfos: Map[String, TransformWithStateVariableInfo] = stateVariableInfos.toMap @@ -318,6 +324,16 @@ class DriverStatefulProcessorHandleImpl(timeMode: TimeMode, keyExprEnc: Expressi } } + private def addTimerColFamily(): Unit = { + val stateName = TimerStateUtils.getTimerStateVarName(timeMode.toString) + val timerEncoder = new TimerKeyEncoder(keyExprEnc) + val colFamilySchema = StateStoreColumnFamilySchemaUtils. + getTimerStateSchema(stateName, timerEncoder.schemaForKeyRow, timerEncoder.schemaForValueRow) + columnFamilySchemas.put(stateName, colFamilySchema) + val stateVariableInfo = TransformWithStateVariableUtils.getTimerState(stateName) + stateVariableInfos.put(stateName, stateVariableInfo) + } + override def getValueState[T](stateName: String, valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state", PRE_INIT) val colFamilySchema = StateStoreColumnFamilySchemaUtils. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala index 82a4226fcfd54..d0fbaf6600609 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TimerStateImpl.scala @@ -34,6 +34,15 @@ object TimerStateUtils { val EVENT_TIMERS_STATE_NAME = "$eventTimers" val KEY_TO_TIMESTAMP_CF = "_keyToTimestamp" val TIMESTAMP_TO_KEY_CF = "_timestampToKey" + + def getTimerStateVarName(timeMode: String): String = { + assert(timeMode == TimeMode.EventTime.toString || timeMode == TimeMode.ProcessingTime.toString) + if (timeMode == TimeMode.EventTime.toString) { + TimerStateUtils.EVENT_TIMERS_STATE_NAME + TimerStateUtils.KEY_TO_TIMESTAMP_CF + } else { + TimerStateUtils.PROC_TIMERS_STATE_NAME + TimerStateUtils.KEY_TO_TIMESTAMP_CF + } + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala index 0a32564f973a3..4a192b3e51c71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateVariableUtils.scala @@ -43,12 +43,16 @@ object TransformWithStateVariableUtils { def getMapState(stateName: String, ttlEnabled: Boolean): TransformWithStateVariableInfo = { TransformWithStateVariableInfo(stateName, StateVariableType.MapState, ttlEnabled) } + + def getTimerState(stateName: String): TransformWithStateVariableInfo = { + TransformWithStateVariableInfo(stateName, StateVariableType.TimerState, ttlEnabled = false) + } } // Enum of possible State Variable types object StateVariableType extends Enumeration { type StateVariableType = Value - val ValueState, ListState, MapState = Value + val ValueState, ListState, 
MapState, TimerState = Value } case class TransformWithStateVariableInfo( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index 8707facc4c126..5f55848d540df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -288,6 +288,25 @@ class StateDataSourceNegativeTestSuite extends StateDataSourceTestBase { } } + test("ERROR: trying to specify state variable name along with " + + "readRegisteredTimers should fail") { + withTempDir { tempDir => + val exc = intercept[StateDataSourceConflictOptions] { + spark.read.format("statestore") + // trick to bypass getting the last committed batch before validating operator ID + .option(StateSourceOptions.BATCH_ID, 0) + .option(StateSourceOptions.STATE_VAR_NAME, "test") + .option(StateSourceOptions.READ_REGISTERED_TIMERS, true) + .load(tempDir.getAbsolutePath) + } + checkError(exc, "STDS_CONFLICT_OPTIONS", "42613", + Map("options" -> + s"['${ + StateSourceOptions.READ_REGISTERED_TIMERS + }', '${StateSourceOptions.STATE_VAR_NAME}']")) + } + } + test("ERROR: trying to specify non boolean value for " + "flattenCollectionTypes") { withTempDir { tempDir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala index 69df86fd5f746..bd047d1132fbe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala @@ -21,9 +21,9 @@ import java.time.Duration import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.execution.streaming.state.{AlsoTestWithChangelogCheckpointingEnabled, RocksDBStateStoreProvider, TestClass} -import org.apache.spark.sql.functions.explode +import org.apache.spark.sql.functions.{explode, timestamp_seconds} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.{ExpiredTimerInfo, InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, OutputMode, RunningCountStatefulProcessor, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState} +import org.apache.spark.sql.streaming.{ExpiredTimerInfo, InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, OutputMode, RunningCountStatefulProcessor, RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState} import org.apache.spark.sql.streaming.util.StreamManualClock /** Stateful processor of single value state var with non-primitive type */ @@ -176,8 +176,19 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest assert(ex.isInstanceOf[StateDataSourceInvalidOptionValue]) assert(ex.getMessage.contains("State variable non-exist is 
not defined")) - // TODO: this should be removed when readChangeFeed is supported for value state + // Verify that trying to read timers in TimeMode as None fails val ex1 = intercept[Exception] { + spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.READ_REGISTERED_TIMERS, true) + .load() + } + assert(ex1.isInstanceOf[StateDataSourceInvalidOptionValue]) + assert(ex1.getMessage.contains("Registered timers are not available")) + + // TODO: this should be removed when readChangeFeed is supported for value state + val ex2 = intercept[Exception] { spark.read .format("statestore") .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) @@ -186,7 +197,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) .load() } - assert(ex1.isInstanceOf[StateDataSourceConflictOptions]) + assert(ex2.isInstanceOf[StateDataSourceConflictOptions]) } } } @@ -563,4 +574,94 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest } } } + + test("state data source - processing-time timers integration") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { tempDir => + val clock = new StreamManualClock + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState( + new RunningCountStatefulProcessorWithProcTimeTimerUpdates(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(Trigger.ProcessingTime("1 second"), triggerClock = clock, + checkpointLocation = tempDir.getCanonicalPath), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), // at batch 0, ts = 1, timer = "a" -> [6] (= 1 + 5) + AddData(inputData, "a"), + AdvanceManualClock(2 * 1000), + CheckNewAnswer(("a", "2")), // at batch 1, ts = 3, timer = "a" -> [10.5] (3 + 7.5) + StopStream) + + val stateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.READ_REGISTERED_TIMERS, true) + .load() + + val resultDf = stateReaderDf.selectExpr( + "key.value AS groupingKey", + "expiration_timestamp_ms AS expiryTimestamp", + "partition_id") + + checkAnswer(resultDf, + Seq(Row("a", 10500L, 0))) + } + } + } + + test("state data source - event-time timers integration") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { tempDir => + val inputData = MemoryStream[(String, Int)] + val result = + inputData.toDS() + .select($"_1".as("key"), timestamp_seconds($"_2").as("eventTime")) + .withWatermark("eventTime", "10 seconds") + .as[(String, Long)] + .groupByKey(_._1) + .transformWithState( + new MaxEventTimeStatefulProcessor(), + TimeMode.EventTime(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = tempDir.getCanonicalPath), + + AddData(inputData, ("a", 11), ("a", 13), ("a", 15)), + // Max event time = 15. Timeout timestamp for "a" = 15 + 5 = 20. Watermark = 15 - 10 = 5. 
+ CheckNewAnswer(("a", 15)), // Output = max event time of a + + AddData(inputData, ("a", 4)), // Add data older than watermark for "a" + CheckNewAnswer(), // No output as data should get filtered by watermark + StopStream) + + val stateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.READ_REGISTERED_TIMERS, true) + .load() + + val resultDf = stateReaderDf.selectExpr( + "key.value AS groupingKey", + "expiration_timestamp_ms AS expiryTimestamp", + "partition_id") + + checkAnswer(resultDf, + Seq(Row("a", 20000L, 0))) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala index 45056d104e84e..1fbeaeb817bd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithValueStateTTLSuite.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoders -import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, ValueStateImpl, ValueStateImplWithTTL} +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, ListStateImplWithTTL, MapStateImplWithTTL, MemoryStream, TimerStateUtils, ValueStateImpl, ValueStateImplWithTTL} import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.util.StreamManualClock @@ -265,7 +265,16 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf) val keySchema = new StructType().add("value", StringType) + val schemaForKeyRow: StructType = new StructType() + .add("key", new StructType(keySchema.fields)) + .add("expiryTimestampMs", LongType, nullable = false) + val schemaForValueRow: StructType = StructType(Array(StructField("__dummy__", NullType))) val schema0 = StateStoreColFamilySchema( + TimerStateUtils.getTimerStateVarName(TimeMode.ProcessingTime().toString), + schemaForKeyRow, + schemaForValueRow, + Some(PrefixKeyScanStateEncoderSpec(schemaForKeyRow, 1))) + val schema1 = StateStoreColFamilySchema( "valueStateTTL", keySchema, new StructType().add("value", @@ -275,14 +284,14 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { Some(NoPrefixKeyStateEncoderSpec(keySchema)), None ) - val schema1 = StateStoreColFamilySchema( + val schema2 = StateStoreColFamilySchema( "valueState", keySchema, new StructType().add("value", IntegerType, false), Some(NoPrefixKeyStateEncoderSpec(keySchema)), None ) - val schema2 = StateStoreColFamilySchema( + val schema3 = StateStoreColFamilySchema( "listState", keySchema, new StructType().add("value", @@ -300,7 +309,7 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { val compositeKeySchema = new StructType() .add("key", new StructType().add("value", StringType)) .add("userKey", userKeySchema) - val schema3 = StateStoreColFamilySchema( + val schema4 = StateStoreColFamilySchema( "mapState", compositeKeySchema, new StructType().add("value", @@ -351,9 +360,9 @@ class TransformWithValueStateTTLSuite extends TransformWithStateTTLTest { q.lastProgress.stateOperators.head.customMetrics .get("numMapStateWithTTLVars").toInt) - 
assert(colFamilySeq.length == 4) + assert(colFamilySeq.length == 5) assert(colFamilySeq.map(_.toString).toSet == Set( - schema0, schema1, schema2, schema3 + schema0, schema1, schema2, schema3, schema4 ).map(_.toString)) }, StopStream From 983f6f434af335b9270a0748dc5b4b18c7dc4846 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 25 Sep 2024 07:50:20 -0700 Subject: [PATCH 084/250] [SPARK-49746][BUILD] Upgrade Scala to 2.13.15 ### What changes were proposed in this pull request? The pr aims to upgrade `scala` from `2.13.14` to `2.13.15`. ### Why are the changes needed? https://contributors.scala-lang.org/t/scala-2-13-15-release-planning/6649 image **Note: since 2.13.15, "-Wconf:cat=deprecation:wv,any:e" no longer takes effect and needs to be changed to "-Wconf:any:e", "-Wconf:cat=deprecation:wv", please refer to the details: https://github.com/scala/scala/pull/10708** ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48192 from panbingkun/SPARK-49746. Lead-authored-by: panbingkun Co-authored-by: YangJie Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 8 ++++---- docs/_config.yml | 2 +- pom.xml | 7 ++++--- project/SparkBuild.scala | 6 +++++- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 88526995293f5..19b8a237d30aa 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -144,7 +144,7 @@ jetty-util-ajax/11.0.23//jetty-util-ajax-11.0.23.jar jetty-util/11.0.23//jetty-util-11.0.23.jar jjwt-api/0.12.6//jjwt-api-0.12.6.jar jline/2.14.6//jline-2.14.6.jar -jline/3.25.1//jline-3.25.1.jar +jline/3.26.3//jline-3.26.3.jar jna/5.14.0//jna-5.14.0.jar joda-time/2.13.0//joda-time-2.13.0.jar jodd-core/3.5.2//jodd-core-3.5.2.jar @@ -252,11 +252,11 @@ py4j/0.10.9.7//py4j-0.10.9.7.jar remotetea-oncrpc/1.1.2//remotetea-oncrpc-1.1.2.jar rocksdbjni/9.5.2//rocksdbjni-9.5.2.jar scala-collection-compat_2.13/2.7.0//scala-collection-compat_2.13-2.7.0.jar -scala-compiler/2.13.14//scala-compiler-2.13.14.jar -scala-library/2.13.14//scala-library-2.13.14.jar +scala-compiler/2.13.15//scala-compiler-2.13.15.jar +scala-library/2.13.15//scala-library-2.13.15.jar scala-parallel-collections_2.13/1.0.4//scala-parallel-collections_2.13-1.0.4.jar scala-parser-combinators_2.13/2.4.0//scala-parser-combinators_2.13-2.4.0.jar -scala-reflect/2.13.14//scala-reflect-2.13.14.jar +scala-reflect/2.13.15//scala-reflect-2.13.15.jar scala-xml_2.13/2.3.0//scala-xml_2.13-2.3.0.jar slf4j-api/2.0.16//slf4j-api-2.0.16.jar snakeyaml-engine/2.7//snakeyaml-engine-2.7.jar diff --git a/docs/_config.yml b/docs/_config.yml index e74eda0470417..089d6bf2097b8 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -22,7 +22,7 @@ include: SPARK_VERSION: 4.0.0-SNAPSHOT SPARK_VERSION_SHORT: 4.0.0 SCALA_BINARY_VERSION: "2.13" -SCALA_VERSION: "2.13.14" +SCALA_VERSION: "2.13.15" SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK SPARK_GITHUB_URL: https://github.com/apache/spark # Before a new release, we should: diff --git a/pom.xml b/pom.xml index 131e754da8157..f3dc92426ac4e 100644 --- a/pom.xml +++ b/pom.xml @@ -169,7 +169,7 @@ 3.2.2 4.4 - 2.13.14 + 2.13.15 2.13 2.2.0 4.9.1 @@ -226,7 +226,7 @@ and ./python/packaging/connect/setup.py too. 
--> 17.0.0 - 3.0.0-M2 + 3.0.0 0.12.6 @@ -3051,7 +3051,8 @@ -explaintypes -release 17 - -Wconf:cat=deprecation:wv,any:e + -Wconf:any:e + -Wconf:cat=deprecation:wv -Wunused:imports -Wconf:cat=scaladoc:wv -Wconf:msg=^(?=.*?method|value|type|object|trait|inheritance)(?=.*?deprecated)(?=.*?since 2.13).+$:e diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 2f390cb70baa8..82950fb30287a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -234,7 +234,11 @@ object SparkBuild extends PomBuild { // replace -Xfatal-warnings with fine-grained configuration, since 2.13.2 // verbose warning on deprecation, error on all others // see `scalac -Wconf:help` for details - "-Wconf:cat=deprecation:wv,any:e", + // since 2.13.15, "-Wconf:cat=deprecation:wv,any:e" no longer takes effect and needs to + // be changed to "-Wconf:any:e", "-Wconf:cat=deprecation:wv", + // please refer to the details: https://github.com/scala/scala/pull/10708 + "-Wconf:any:e", + "-Wconf:cat=deprecation:wv", // 2.13-specific warning hits to be muted (as narrowly as possible) and addressed separately "-Wunused:imports", "-Wconf:msg=^(?=.*?method|value|type|object|trait|inheritance)(?=.*?deprecated)(?=.*?since 2.13).+$:e", From 1f2e7b87db76ef60eded8a6db09f6690238471ce Mon Sep 17 00:00:00 2001 From: Enrico Minack Date: Wed, 25 Sep 2024 07:53:12 -0700 Subject: [PATCH 085/250] [SPARK-49731][K8S] Support K8s volume `mount.subPathExpr` and `hostPath` volume `type` ### What changes were proposed in this pull request? Add the following config options: - `spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.subPathExpr` - `spark.kubernetes.executor.volumes.hostPath.[VolumeName].options.type` ### Why are the changes needed? K8s Spec - https://kubernetes.io/docs/concepts/storage/volumes/#hostpath-volume-types - https://kubernetes.io/docs/concepts/storage/volumes/#using-subpath-expanded-environment These are natural extensions of the existing options - `spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.subPath` - `spark.kubernetes.executor.volumes.hostPath.[VolumeName].options.path` ### Does this PR introduce _any_ user-facing change? Above config options. ### How was this patch tested? Unit tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48181 from EnricoMi/k8s-volume-options. 
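For illustration, a minimal sketch of how the new keys compose with the existing ones. The volume name (`data`), the paths, the `$(SPARK_EXECUTOR_ID)` expansion and the `DirectoryOrCreate` type value are placeholder choices, not part of this patch; only the configuration key names come from it. Note that `mount.subPath` and `mount.subPathExpr` are mutually exclusive, so a volume sets at most one of them.

```
// Hypothetical values; only the config keys are from this change.
val conf = new org.apache.spark.SparkConf()
  .set("spark.kubernetes.executor.volumes.hostPath.data.mount.path", "/data")
  .set("spark.kubernetes.executor.volumes.hostPath.data.mount.subPathExpr", "$(SPARK_EXECUTOR_ID)")
  .set("spark.kubernetes.executor.volumes.hostPath.data.mount.readOnly", "false")
  .set("spark.kubernetes.executor.volumes.hostPath.data.options.path", "/mnt/data")
  .set("spark.kubernetes.executor.volumes.hostPath.data.options.type", "DirectoryOrCreate")
```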
Authored-by: Enrico Minack Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/deploy/k8s/Config.scala | 2 + .../deploy/k8s/KubernetesVolumeSpec.scala | 3 +- .../deploy/k8s/KubernetesVolumeUtils.scala | 18 ++++- .../features/MountVolumesFeatureStep.scala | 6 +- .../spark/deploy/k8s/KubernetesTestConf.scala | 11 ++- .../k8s/KubernetesVolumeUtilsSuite.scala | 42 ++++++++++- .../features/LocalDirsFeatureStepSuite.scala | 3 +- .../MountVolumesFeatureStepSuite.scala | 72 ++++++++++++++++++- 8 files changed, 144 insertions(+), 13 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 3a4d68c19014d..9c50f8ddb00cc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -769,8 +769,10 @@ private[spark] object Config extends Logging { val KUBERNETES_VOLUMES_NFS_TYPE = "nfs" val KUBERNETES_VOLUMES_MOUNT_PATH_KEY = "mount.path" val KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY = "mount.subPath" + val KUBERNETES_VOLUMES_MOUNT_SUBPATHEXPR_KEY = "mount.subPathExpr" val KUBERNETES_VOLUMES_MOUNT_READONLY_KEY = "mount.readOnly" val KUBERNETES_VOLUMES_OPTIONS_PATH_KEY = "options.path" + val KUBERNETES_VOLUMES_OPTIONS_TYPE_KEY = "options.type" val KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY = "options.claimName" val KUBERNETES_VOLUMES_OPTIONS_CLAIM_STORAGE_CLASS_KEY = "options.storageClass" val KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY = "options.medium" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala index 9dfd40a773eb1..b4fe414e3cde5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.k8s private[spark] sealed trait KubernetesVolumeSpecificConf -private[spark] case class KubernetesHostPathVolumeConf(hostPath: String) +private[spark] case class KubernetesHostPathVolumeConf(hostPath: String, volumeType: String) extends KubernetesVolumeSpecificConf private[spark] case class KubernetesPVCVolumeConf( @@ -42,5 +42,6 @@ private[spark] case class KubernetesVolumeSpec( volumeName: String, mountPath: String, mountSubPath: String, + mountSubPathExpr: String, mountReadOnly: Boolean, volumeConf: KubernetesVolumeSpecificConf) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala index 6463512c0114b..88bb998d88b7d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala @@ -45,7 +45,9 @@ object KubernetesVolumeUtils { val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_PATH_KEY" val readOnlyKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_READONLY_KEY" val subPathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY" + val subPathExprKey = 
s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_SUBPATHEXPR_KEY" val labelKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_LABEL_KEY" + verifyMutuallyExclusiveOptionKeys(properties, subPathKey, subPathExprKey) val volumeLabelsMap = properties .filter(_._1.startsWith(labelKey)) @@ -57,6 +59,7 @@ object KubernetesVolumeUtils { volumeName = volumeName, mountPath = properties(pathKey), mountSubPath = properties.getOrElse(subPathKey, ""), + mountSubPathExpr = properties.getOrElse(subPathExprKey, ""), mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean), volumeConf = parseVolumeSpecificConf(properties, volumeType, volumeName, Option(volumeLabelsMap))) @@ -87,8 +90,11 @@ object KubernetesVolumeUtils { volumeType match { case KUBERNETES_VOLUMES_HOSTPATH_TYPE => val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_PATH_KEY" + val typeKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_TYPE_KEY" verifyOptionKey(options, pathKey, KUBERNETES_VOLUMES_HOSTPATH_TYPE) - KubernetesHostPathVolumeConf(options(pathKey)) + // "" means that no checks will be performed before mounting the hostPath volume + // backward compatibility default + KubernetesHostPathVolumeConf(options(pathKey), options.getOrElse(typeKey, "")) case KUBERNETES_VOLUMES_PVC_TYPE => val claimNameKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY" @@ -129,6 +135,16 @@ object KubernetesVolumeUtils { } } + private def verifyMutuallyExclusiveOptionKeys( + options: Map[String, String], + keys: String*): Unit = { + val givenKeys = keys.filter(options.contains) + if (givenKeys.length > 1) { + throw new IllegalArgumentException("These config options are mutually exclusive: " + + s"${givenKeys.mkString(", ")}") + } + } + private def verifySize(size: Option[String]): Unit = { size.foreach { v => if (v.forall(_.isDigit) && parseLong(v) < 1024) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala index 5cc61c746b0e0..eea4604010b21 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -65,14 +65,14 @@ private[spark] class MountVolumesFeatureStep(conf: KubernetesConf) .withMountPath(spec.mountPath) .withReadOnly(spec.mountReadOnly) .withSubPath(spec.mountSubPath) + .withSubPathExpr(spec.mountSubPathExpr) .withName(spec.volumeName) .build() val volumeBuilder = spec.volumeConf match { - case KubernetesHostPathVolumeConf(hostPath) => - /* "" means that no checks will be performed before mounting the hostPath volume */ + case KubernetesHostPathVolumeConf(hostPath, volumeType) => new VolumeBuilder() - .withHostPath(new HostPathVolumeSource(hostPath, "")) + .withHostPath(new HostPathVolumeSource(hostPath, volumeType)) case KubernetesPVCVolumeConf(claimNameTemplate, storageClass, size, labels) => val claimName = conf match { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala index 7e0a65bcdda90..e0ddcd3d416f0 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala +++ 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala @@ -113,9 +113,10 @@ object KubernetesTestConf { volumes.foreach { case spec => val (vtype, configs) = spec.volumeConf match { - case KubernetesHostPathVolumeConf(path) => - (KUBERNETES_VOLUMES_HOSTPATH_TYPE, - Map(KUBERNETES_VOLUMES_OPTIONS_PATH_KEY -> path)) + case KubernetesHostPathVolumeConf(hostPath, volumeType) => + (KUBERNETES_VOLUMES_HOSTPATH_TYPE, Map( + KUBERNETES_VOLUMES_OPTIONS_PATH_KEY -> hostPath, + KUBERNETES_VOLUMES_OPTIONS_TYPE_KEY -> volumeType)) case KubernetesPVCVolumeConf(claimName, storageClass, sizeLimit, labels) => val sconf = storageClass @@ -145,6 +146,10 @@ object KubernetesTestConf { conf.set(key(vtype, spec.volumeName, KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY), spec.mountSubPath) } + if (spec.mountSubPathExpr.nonEmpty) { + conf.set(key(vtype, spec.volumeName, KUBERNETES_VOLUMES_MOUNT_SUBPATHEXPR_KEY), + spec.mountSubPathExpr) + } conf.set(key(vtype, spec.volumeName, KUBERNETES_VOLUMES_MOUNT_READONLY_KEY), spec.mountReadOnly.toString) configs.foreach { case (k, v) => diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala index 5c103739d3082..1e62db725fb6e 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala @@ -30,7 +30,20 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { assert(volumeSpec.mountPath === "/path") assert(volumeSpec.mountReadOnly) assert(volumeSpec.volumeConf.asInstanceOf[KubernetesHostPathVolumeConf] === - KubernetesHostPathVolumeConf("/hostPath")) + KubernetesHostPathVolumeConf("/hostPath", "")) + } + + test("Parses hostPath volume type correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.hostPath.volumeName.mount.path", "/path") + sparkConf.set("test.hostPath.volumeName.options.path", "/hostPath") + sparkConf.set("test.hostPath.volumeName.options.type", "Type") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesHostPathVolumeConf] === + KubernetesHostPathVolumeConf("/hostPath", "Type")) } test("Parses subPath correctly") { @@ -43,6 +56,33 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { assert(volumeSpec.volumeName === "volumeName") assert(volumeSpec.mountPath === "/path") assert(volumeSpec.mountSubPath === "subPath") + assert(volumeSpec.mountSubPathExpr === "") + } + + test("Parses subPathExpr correctly") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.readOnly", "true") + sparkConf.set("test.emptyDir.volumeName.mount.subPathExpr", "subPathExpr") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountSubPath === "") + assert(volumeSpec.mountSubPathExpr === "subPathExpr") + } + + test("Rejects mutually exclusive subPath and subPathExpr") { + val sparkConf = new SparkConf(false) + 
sparkConf.set("test.emptyDir.volumeName.mount.path", "/path") + sparkConf.set("test.emptyDir.volumeName.mount.subPath", "subPath") + sparkConf.set("test.emptyDir.volumeName.mount.subPathExpr", "subPathExpr") + + val msg = intercept[IllegalArgumentException] { + KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + }.getMessage + assert(msg === "These config options are mutually exclusive: " + + "emptyDir.volumeName.mount.subPath, emptyDir.volumeName.mount.subPathExpr") } test("Parses persistentVolumeClaim volumes correctly") { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala index eaadad163f064..3a9561051a894 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/LocalDirsFeatureStepSuite.scala @@ -137,8 +137,9 @@ class LocalDirsFeatureStepSuite extends SparkFunSuite { "spark-local-dir-test", "/tmp", "", + "", false, - KubernetesHostPathVolumeConf("/hostPath/tmp") + KubernetesHostPathVolumeConf("/hostPath/tmp", "") ) val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) val mountVolumeStep = new MountVolumesFeatureStep(kubernetesConf) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala index 6a68898c5f61c..c94a7a6ec26a7 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -27,8 +27,9 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", false, - KubernetesHostPathVolumeConf("/hostPath/tmp") + KubernetesHostPathVolumeConf("/hostPath/tmp", "type") ) val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) val step = new MountVolumesFeatureStep(kubernetesConf) @@ -36,6 +37,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { assert(configuredPod.pod.getSpec.getVolumes.size() === 1) assert(configuredPod.pod.getSpec.getVolumes.get(0).getHostPath.getPath === "/hostPath/tmp") + assert(configuredPod.pod.getSpec.getVolumes.get(0).getHostPath.getType === "type") assert(configuredPod.container.getVolumeMounts.size() === 1) assert(configuredPod.container.getVolumeMounts.get(0).getMountPath === "/tmp") assert(configuredPod.container.getVolumeMounts.get(0).getName === "testVolume") @@ -47,6 +49,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, KubernetesPVCVolumeConf("pvcClaim") ) @@ -69,6 +72,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, KubernetesPVCVolumeConf("pvc-spark-SPARK_EXECUTOR_ID") ) @@ -94,6 +98,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, KubernetesPVCVolumeConf("pvc-spark-SPARK_EXECUTOR_ID", Some("fast"), Some("512M")) ) @@ -119,6 +124,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", 
true, KubernetesPVCVolumeConf("OnDemand") ) @@ -136,6 +142,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, KubernetesPVCVolumeConf(claimName = MountVolumesFeatureStep.PVC_ON_DEMAND, storageClass = Some("gp3"), @@ -156,6 +163,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, KubernetesPVCVolumeConf(claimName = MountVolumesFeatureStep.PVC_ON_DEMAND, storageClass = Some("gp3"), @@ -177,6 +185,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "checkpointVolume1", "/checkpoints1", "", + "", true, KubernetesPVCVolumeConf(claimName = "pvcClaim1", storageClass = Some("gp3"), @@ -188,6 +197,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "checkpointVolume2", "/checkpoints2", "", + "", true, KubernetesPVCVolumeConf(claimName = "pvcClaim2", storageClass = Some("gp3"), @@ -209,6 +219,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, KubernetesPVCVolumeConf(MountVolumesFeatureStep.PVC_ON_DEMAND) ) @@ -226,6 +237,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", false, KubernetesEmptyDirVolumeConf(Some("Memory"), Some("6G")) ) @@ -249,6 +261,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", false, KubernetesEmptyDirVolumeConf(None, None) ) @@ -271,6 +284,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", false, KubernetesNFSVolumeConf("/share/name", "nfs.example.com") ) @@ -293,6 +307,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "", + "", true, KubernetesNFSVolumeConf("/share/name", "nfs.example.com") ) @@ -315,13 +330,15 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "hpVolume", "/tmp", "", + "", false, - KubernetesHostPathVolumeConf("/hostPath/tmp") + KubernetesHostPathVolumeConf("/hostPath/tmp", "") ) val pvcVolumeConf = KubernetesVolumeSpec( "checkpointVolume", "/checkpoints", "", + "", true, KubernetesPVCVolumeConf("pvcClaim") ) @@ -339,13 +356,15 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "hpVolume", "/data", "", + "", false, - KubernetesHostPathVolumeConf("/hostPath/tmp") + KubernetesHostPathVolumeConf("/hostPath/tmp", "") ) val pvcVolumeConf = KubernetesVolumeSpec( "checkpointVolume", "/data", "", + "", true, KubernetesPVCVolumeConf("pvcClaim") ) @@ -364,6 +383,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testVolume", "/tmp", "foo", + "", false, KubernetesEmptyDirVolumeConf(None, None) ) @@ -378,11 +398,32 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { assert(emptyDirMount.getSubPath === "foo") } + test("Mounts subpathexpr on emptyDir") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "", + "foo", + false, + KubernetesEmptyDirVolumeConf(None, None) + ) + val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val emptyDirMount = configuredPod.container.getVolumeMounts.get(0) + assert(emptyDirMount.getMountPath === "/tmp") + assert(emptyDirMount.getName === "testVolume") + assert(emptyDirMount.getSubPathExpr === "foo") + } + test("Mounts subpath on persistentVolumeClaims") { val volumeConf = KubernetesVolumeSpec( "testVolume", "/tmp", 
"bar", + "", true, KubernetesPVCVolumeConf("pvcClaim") ) @@ -400,12 +441,36 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { assert(pvcMount.getSubPath === "bar") } + test("Mounts subpathexpr on persistentVolumeClaims") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "", + "bar", + true, + KubernetesPVCVolumeConf("pvcClaim") + ) + val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(pvcClaim.getClaimName === "pvcClaim") + assert(configuredPod.container.getVolumeMounts.size() === 1) + val pvcMount = configuredPod.container.getVolumeMounts.get(0) + assert(pvcMount.getMountPath === "/tmp") + assert(pvcMount.getName === "testVolume") + assert(pvcMount.getSubPathExpr === "bar") + } + test("Mounts multiple subpaths") { val volumeConf = KubernetesEmptyDirVolumeConf(None, None) val emptyDirSpec = KubernetesVolumeSpec( "testEmptyDir", "/tmp/foo", "foo", + "", true, KubernetesEmptyDirVolumeConf(None, None) ) @@ -413,6 +478,7 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { "testPVC", "/tmp/bar", "bar", + "", true, KubernetesEmptyDirVolumeConf(None, None) ) From 09209f0ff503b29f9da92ba7db8aa820c03b3c0f Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 25 Sep 2024 07:57:08 -0700 Subject: [PATCH 086/250] [SPARK-49775][SQL][FOLLOW-UP] Use SortedSet instead of Array with sorting ### What changes were proposed in this pull request? This PR is a followup of https://github.com/apache/spark/pull/48235 that addresses https://github.com/apache/spark/pull/48235#discussion_r1775020195 comment. ### Why are the changes needed? For better performance (in theory) ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests should verify them ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48245 from HyukjinKwon/SPARK-49775-followup. 
Authored-by: Hyukjin Kwon Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/sql/catalyst/util/CharsetProvider.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharsetProvider.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharsetProvider.scala index d85673f2ce811..f805d2ed87b52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharsetProvider.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharsetProvider.scala @@ -18,13 +18,15 @@ import java.nio.charset.{Charset, CharsetDecoder, CharsetEncoder, CodingErrorAction, IllegalCharsetNameException, UnsupportedCharsetException} import java.util.Locale + import scala.collection.SortedSet + import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf private[sql] object CharsetProvider { final lazy val VALID_CHARSETS = - Array("us-ascii", "iso-8859-1", "utf-8", "utf-16be", "utf-16le", "utf-16", "utf-32").sorted + SortedSet("us-ascii", "iso-8859-1", "utf-8", "utf-16be", "utf-16le", "utf-16", "utf-32") def forName( charset: String, From 80d6651cf6a1835d0de3e12e08253d2a9816d499 Mon Sep 17 00:00:00 2001 From: Julek Sompolski Date: Wed, 25 Sep 2024 23:34:23 +0800 Subject: [PATCH 087/250] [SPARK-48195][FOLLOWUP] Accumulator reset() no longer needed in CollectMetricsExec.doExecute() ### What changes were proposed in this pull request? Small followup to https://github.com/apache/spark/pull/48037. `collector.reset()` is no longer needed in `CollectMetricsExec.doExecute()` because it is reset in `resetMetrics()`. This doesn't really matter in practice, but removing to clean up. ### Why are the changes needed? Tiny cleanup. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? This change doesn't matter in practice. Just cleanup. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48243 from juliuszsompolski/SPARK-48195-followup. Authored-by: Julek Sompolski Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/execution/CollectMetricsExec.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala index 2115e21f81d71..0a487bac77696 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CollectMetricsExec.scala @@ -67,7 +67,6 @@ case class CollectMetricsExec( override protected def doExecute(): RDD[InternalRow] = { val collector = accumulator - collector.reset() child.execute().mapPartitions { rows => // Only publish the value of the accumulator when the task has completed. This is done by // updating a task local accumulator ('updater') which will be merged with the actual From c0984e70469d99595b8e6eda0d943308f590aaec Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 26 Sep 2024 13:17:59 +0900 Subject: [PATCH 088/250] [SPARK-49609][PYTHON][TESTS][FOLLOW-UP] Avoid import connect modules when connect dependencies not installed ### What changes were proposed in this pull request? This PR is a followup of https://github.com/apache/spark/pull/48085 that skips the connect import which requires Connect dependencies. ### Why are the changes needed? 
To recover the PyPy3 build https://github.com/apache/spark/actions/runs/11035779484/job/30652736098 which does not have PyArrow installed. ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? Manually. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48259 from HyukjinKwon/SPARK-49609-followup2. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/tests/test_connect_compatibility.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index 8f3e86f5186a8..dfa0fa63b2dd5 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -21,11 +21,13 @@ from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.sqlutils import ReusedSQLTestCase from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame -from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame from pyspark.sql.classic.column import Column as ClassicColumn -from pyspark.sql.connect.column import Column as ConnectColumn from pyspark.sql.session import SparkSession as ClassicSparkSession -from pyspark.sql.connect.session import SparkSession as ConnectSparkSession + +if should_test_connect: + from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame + from pyspark.sql.connect.column import Column as ConnectColumn + from pyspark.sql.connect.session import SparkSession as ConnectSparkSession class ConnectCompatibilityTestsMixin: From 5629779287724a891c81b16f982f9529bd379c39 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 25 Sep 2024 22:34:35 -0700 Subject: [PATCH 089/250] [SPARK-49786][K8S] Lower `KubernetesClusterSchedulerBackend.onDisconnected` log level to debug ### What changes were proposed in this pull request? This PR aims to lower `KubernetesClusterSchedulerBackend.onDisconnected` log level to debug. ### Why are the changes needed? This INFO-level message was added here. We already propagate the disconnection reason to UI, and `No executor found` has been used when an unknown peer is connect or disconnect. - https://github.com/apache/spark/pull/37821 The driver can be accessed by non-executors by design. And, all other resource managers do not complain at INFO level. ``` INFO KubernetesClusterSchedulerBackend$KubernetesDriverEndpoint: No executor found for x.x.x.0:x ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manual review because this is a log level change. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48249 from dongjoon-hyun/SPARK-49786. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../cluster/k8s/KubernetesClusterSchedulerBackend.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 4e4634504a0f3..09faa2a7fb1b3 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -32,7 +32,7 @@ import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.submit.KubernetesClientUtils import org.apache.spark.deploy.security.HadoopDelegationTokenManager -import org.apache.spark.internal.LogKeys.{COUNT, HOST_PORT, TOTAL} +import org.apache.spark.internal.LogKeys.{COUNT, TOTAL} import org.apache.spark.internal.MDC import org.apache.spark.internal.config.SCHEDULER_MIN_REGISTERED_RESOURCES_RATIO import org.apache.spark.resource.ResourceProfile @@ -356,7 +356,7 @@ private[spark] class KubernetesClusterSchedulerBackend( execIDRequester -= rpcAddress // Expected, executors re-establish a connection with an ID case _ => - logInfo(log"No executor found for ${MDC(HOST_PORT, rpcAddress)}") + logDebug(s"No executor found for ${rpcAddress}") } } } From 913a0f7813c5b2d2bf105160bf8e55e08b34513b Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 26 Sep 2024 15:15:37 +0800 Subject: [PATCH 090/250] [SPARK-49784][PYTHON][TESTS] Add more test for `spark.sql` ### What changes were proposed in this pull request? add more test for `spark.sql` ### Why are the changes needed? for test coverage ### Does this PR introduce _any_ user-facing change? no, test only ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #48246 from zhengruifeng/py_sql_test. 
Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- dev/sparktestsupport/modules.py | 2 + .../sql/tests/connect/test_parity_sql.py | 37 ++++ python/pyspark/sql/tests/test_sql.py | 185 ++++++++++++++++++ 3 files changed, 224 insertions(+) create mode 100644 python/pyspark/sql/tests/connect/test_parity_sql.py create mode 100644 python/pyspark/sql/tests/test_sql.py diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index eda6b063350e5..d2c000b702a64 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -520,6 +520,7 @@ def __hash__(self): "pyspark.sql.tests.test_errors", "pyspark.sql.tests.test_functions", "pyspark.sql.tests.test_group", + "pyspark.sql.tests.test_sql", "pyspark.sql.tests.pandas.test_pandas_cogrouped_map", "pyspark.sql.tests.pandas.test_pandas_grouped_map", "pyspark.sql.tests.pandas.test_pandas_grouped_map_with_state", @@ -1032,6 +1033,7 @@ def __hash__(self): "pyspark.sql.tests.connect.test_parity_serde", "pyspark.sql.tests.connect.test_parity_functions", "pyspark.sql.tests.connect.test_parity_group", + "pyspark.sql.tests.connect.test_parity_sql", "pyspark.sql.tests.connect.test_parity_dataframe", "pyspark.sql.tests.connect.test_parity_collection", "pyspark.sql.tests.connect.test_parity_creation", diff --git a/python/pyspark/sql/tests/connect/test_parity_sql.py b/python/pyspark/sql/tests/connect/test_parity_sql.py new file mode 100644 index 0000000000000..4c6b11c60cbe9 --- /dev/null +++ b/python/pyspark/sql/tests/connect/test_parity_sql.py @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.sql.tests.test_sql import SQLTestsMixin +from pyspark.testing.connectutils import ReusedConnectTestCase + + +class SQLParityTests(SQLTestsMixin, ReusedConnectTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.connect.test_parity_sql import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_sql.py b/python/pyspark/sql/tests/test_sql.py new file mode 100644 index 0000000000000..bf50bbc11ac33 --- /dev/null +++ b/python/pyspark/sql/tests/test_sql.py @@ -0,0 +1,185 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest + +from pyspark.sql import Row +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class SQLTestsMixin: + def test_simple(self): + res = self.spark.sql("SELECT 1 + 1").collect() + self.assertEqual(len(res), 1) + self.assertEqual(res[0][0], 2) + + def test_args_dict(self): + with self.tempView("test"): + self.spark.range(10).createOrReplaceTempView("test") + df = self.spark.sql( + "SELECT * FROM IDENTIFIER(:table_name)", + args={"table_name": "test"}, + ) + + self.assertEqual(df.count(), 10) + self.assertEqual(df.limit(5).count(), 5) + self.assertEqual(df.offset(5).count(), 5) + + self.assertEqual(df.take(1), [Row(id=0)]) + self.assertEqual(df.tail(1), [Row(id=9)]) + + def test_args_list(self): + with self.tempView("test"): + self.spark.range(10).createOrReplaceTempView("test") + df = self.spark.sql( + "SELECT * FROM test WHERE ? < id AND id < ?", + args=[1, 6], + ) + + self.assertEqual(df.count(), 4) + self.assertEqual(df.limit(3).count(), 3) + self.assertEqual(df.offset(3).count(), 1) + + self.assertEqual(df.take(1), [Row(id=2)]) + self.assertEqual(df.tail(1), [Row(id=5)]) + + def test_kwargs_literal(self): + with self.tempView("test"): + self.spark.range(10).createOrReplaceTempView("test") + + df = self.spark.sql( + "SELECT * FROM IDENTIFIER(:table_name) WHERE {m1} < id AND id < {m2} OR id = {m3}", + args={"table_name": "test"}, + m1=3, + m2=7, + m3=9, + ) + + self.assertEqual(df.count(), 4) + self.assertEqual(df.collect(), [Row(id=4), Row(id=5), Row(id=6), Row(id=9)]) + self.assertEqual(df.take(1), [Row(id=4)]) + self.assertEqual(df.tail(1), [Row(id=9)]) + + def test_kwargs_literal_multiple_ref(self): + with self.tempView("test"): + self.spark.range(10).createOrReplaceTempView("test") + + df = self.spark.sql( + "SELECT * FROM IDENTIFIER(:table_name) WHERE {m} = id OR id > {m} OR {m} < 0", + args={"table_name": "test"}, + m=6, + ) + + self.assertEqual(df.count(), 4) + self.assertEqual(df.collect(), [Row(id=6), Row(id=7), Row(id=8), Row(id=9)]) + self.assertEqual(df.take(1), [Row(id=6)]) + self.assertEqual(df.tail(1), [Row(id=9)]) + + def test_kwargs_dataframe(self): + df0 = self.spark.range(10) + df1 = self.spark.sql( + "SELECT * FROM {df} WHERE id > 4", + df=df0, + ) + + self.assertEqual(df0.schema, df1.schema) + self.assertEqual(df1.count(), 5) + self.assertEqual(df1.take(1), [Row(id=5)]) + self.assertEqual(df1.tail(1), [Row(id=9)]) + + def test_kwargs_dataframe_with_column(self): + df0 = self.spark.range(10) + df1 = self.spark.sql( + "SELECT * FROM {df} WHERE {df.id} > :m1 AND {df[id]} < :m2", + {"m1": 4, "m2": 9}, + df=df0, + ) + + self.assertEqual(df0.schema, df1.schema) + self.assertEqual(df1.count(), 4) + self.assertEqual(df1.take(1), [Row(id=5)]) + self.assertEqual(df1.tail(1), [Row(id=8)]) + + def test_nested_view(self): + with self.tempView("v1", "v2", "v3", "v4"): + self.spark.range(10).createOrReplaceTempView("v1") + self.spark.sql( + "SELECT * FROM IDENTIFIER(:view) WHERE id > :m", + args={"view": "v1", "m": 1}, + ).createOrReplaceTempView("v2") + self.spark.sql( + "SELECT * FROM IDENTIFIER(:view) WHERE id > :m", + args={"view": "v2", "m": 
2}, + ).createOrReplaceTempView("v3") + self.spark.sql( + "SELECT * FROM IDENTIFIER(:view) WHERE id > :m", + args={"view": "v3", "m": 3}, + ).createOrReplaceTempView("v4") + + df = self.spark.sql("select * from v4") + self.assertEqual(df.count(), 6) + self.assertEqual(df.take(1), [Row(id=4)]) + self.assertEqual(df.tail(1), [Row(id=9)]) + + def test_nested_dataframe(self): + df0 = self.spark.range(10) + df1 = self.spark.sql( + "SELECT * FROM {df} WHERE id > ?", + args=[1], + df=df0, + ) + df2 = self.spark.sql( + "SELECT * FROM {df} WHERE id > ?", + args=[2], + df=df1, + ) + df3 = self.spark.sql( + "SELECT * FROM {df} WHERE id > ?", + args=[3], + df=df2, + ) + + self.assertEqual(df0.schema, df1.schema) + self.assertEqual(df1.count(), 8) + self.assertEqual(df1.take(1), [Row(id=2)]) + self.assertEqual(df1.tail(1), [Row(id=9)]) + + self.assertEqual(df0.schema, df2.schema) + self.assertEqual(df2.count(), 7) + self.assertEqual(df2.take(1), [Row(id=3)]) + self.assertEqual(df2.tail(1), [Row(id=9)]) + + self.assertEqual(df0.schema, df3.schema) + self.assertEqual(df3.count(), 6) + self.assertEqual(df3.take(1), [Row(id=4)]) + self.assertEqual(df3.tail(1), [Row(id=9)]) + + +class SQLTests(SQLTestsMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + from pyspark.sql.tests.test_sql import * # noqa: F401 + + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) From fe1cf3200223c33ed4670bfa5924d5a4053c8ef9 Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Thu, 26 Sep 2024 17:38:58 +0900 Subject: [PATCH 091/250] [SPARK-49656][SS] Add support for state variables with value state collection types and read change feed options ### What changes were proposed in this pull request? Add support for state variables with value state collection types and read change feed options ### Why are the changes needed? Without this, we cannot support reading per key changes for state variables used with stateful processors. ### Does this PR introduce _any_ user-facing change? Yes Users can now query value state variables with the following query: ``` val changeFeedDf = spark.read .format("statestore") .option(StateSourceOptions.PATH, ) .option(StateSourceOptions.STATE_VAR_NAME, ) .option(StateSourceOptions.READ_CHANGE_FEED, true) .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) .load() ``` ### How was this patch tested? Added unit tests ``` [info] Run completed in 17 seconds, 318 milliseconds. [info] Total number of tests run: 2 [info] Suites: completed 1, aborted 0 [info] Tests: succeeded 2, failed 0, canceled 0, ignored 0, pending 0 [info] All tests passed. ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #48148 from anishshri-db/task/SPARK-49656. 
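For reference, a fuller sketch of the reader, modelled on the test added below: the checkpoint path is a placeholder for wherever the transformWithState query checkpoints, and "valueState" is the value state variable name used with `StatefulProcessorWithSingleValueVar` in that test.

```
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions

// <your-checkpoint-path> is a placeholder; the option names match the new test.
val changeFeedDf = spark.read
  .format("statestore")
  .option(StateSourceOptions.PATH, "<your-checkpoint-path>")
  .option(StateSourceOptions.STATE_VAR_NAME, "valueState")
  .option(StateSourceOptions.READ_CHANGE_FEED, true)
  .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0)
  .load()

// One row per tracked change, e.g. change_type = "update" for upserts.
changeFeedDf.selectExpr(
  "change_type",
  "key.value AS groupingKey",
  "value.id AS valueId",
  "partition_id").show()
```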
Authored-by: Anish Shrigondekar Signed-off-by: Jungtaek Lim --- .../v2/state/StateDataSource.scala | 10 +- .../v2/state/StatePartitionReader.scala | 10 +- .../state/HDFSBackedStateStoreProvider.scala | 10 +- .../state/RocksDBStateStoreProvider.scala | 79 ++++++++++--- .../streaming/state/StateStore.scala | 6 +- .../streaming/state/StateStoreChangelog.scala | 11 +- ...ateDataSourceTransformWithStateSuite.scala | 107 +++++++++++++++--- 7 files changed, 190 insertions(+), 43 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala index 39bc4dd9fb9c8..edddfbd6ccaef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{J import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues import org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader, StateMetadataTableEntry} import org.apache.spark.sql.execution.datasources.v2.state.utils.SchemaUtil -import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata, TimerStateUtils, TransformWithStateOperatorProperties, TransformWithStateVariableInfo} +import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata, StateVariableType, TimerStateUtils, TransformWithStateOperatorProperties, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide} import org.apache.spark.sql.execution.streaming.state.{KeyStateEncoderSpec, NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, StateSchemaCompatibilityChecker, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreProviderId} @@ -170,13 +170,15 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging } val stateVars = twsOperatorProperties.stateVariables - if (stateVars.filter(stateVar => stateVar.stateName == stateVarName).size != 1) { + val stateVarInfo = stateVars.filter(stateVar => stateVar.stateName == stateVarName) + if (stateVarInfo.size != 1) { throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, s"State variable $stateVarName is not defined for the transformWithState operator.") } - // TODO: Support change feed and transformWithState together - if (sourceOptions.readChangeFeed) { + // TODO: add support for list and map type + if (sourceOptions.readChangeFeed && + stateVarInfo.head.stateVariableType != StateVariableType.ValueState) { throw StateDataSourceErrors.conflictOptions(Seq(StateSourceOptions.READ_CHANGE_FEED, StateSourceOptions.STATE_VAR_NAME)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala index d77d97f0057fb..b925aee5b627a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StatePartitionReader.scala @@ -223,10 +223,18 @@ class StateStoreChangeDataPartitionReader( throw StateStoreErrors.stateStoreProviderDoesNotSupportFineGrainedReplay( provider.getClass.toString) } + + val colFamilyNameOpt = if (stateVariableInfoOpt.isDefined) { + Some(stateVariableInfoOpt.get.stateName) + } else { + None + } + provider.asInstanceOf[SupportsFineGrainedReplay] .getStateStoreChangeDataReader( partition.sourceOptions.readChangeFeedOptions.get.changeStartBatchId + 1, - partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1) + partition.sourceOptions.readChangeFeedOptions.get.changeEndBatchId + 1, + colFamilyNameOpt) } override lazy val iter: Iterator[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index d9f4443b79618..884b8aa3853cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -991,8 +991,16 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with result } - override def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): + override def getStateStoreChangeDataReader( + startVersion: Long, + endVersion: Long, + colFamilyNameOpt: Option[String] = None): StateStoreChangeDataReader = { + // Multiple column families are not supported with HDFSBackedStateStoreProvider + if (colFamilyNameOpt.isDefined) { + throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName) + } + new HDFSBackedStateStoreChangeDataReader(fm, baseDir, startVersion, endVersion, CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec), keySchema, valueSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 85f80ce9eb1ae..6ab634668bc2a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -498,7 +498,10 @@ private[sql] class RocksDBStateStoreProvider } } - override def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): + override def getStateStoreChangeDataReader( + startVersion: Long, + endVersion: Long, + colFamilyNameOpt: Option[String] = None): StateStoreChangeDataReader = { val statePath = stateStoreId.storeCheckpointLocation() val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) @@ -508,7 +511,8 @@ private[sql] class RocksDBStateStoreProvider startVersion, endVersion, CompressionCodec.createCodec(sparkConf, storeConf.compressionCodec), - keyValueEncoderMap) + keyValueEncoderMap, + colFamilyNameOpt) } /** @@ -676,27 +680,70 @@ class RocksDBStateStoreChangeDataReader( endVersion: Long, compressionCodec: CompressionCodec, keyValueEncoderMap: - ConcurrentHashMap[String, (RocksDBKeyStateEncoder, RocksDBValueStateEncoder)]) + ConcurrentHashMap[String, (RocksDBKeyStateEncoder, RocksDBValueStateEncoder)], + colFamilyNameOpt: Option[String] = None) extends StateStoreChangeDataReader( - fm, 
stateLocation, startVersion, endVersion, compressionCodec) { + fm, stateLocation, startVersion, endVersion, compressionCodec, colFamilyNameOpt) { override protected var changelogSuffix: String = "changelog" + private def getColFamilyIdBytes: Option[Array[Byte]] = { + if (colFamilyNameOpt.isDefined) { + val colFamilyName = colFamilyNameOpt.get + if (!keyValueEncoderMap.containsKey(colFamilyName)) { + throw new IllegalStateException( + s"Column family $colFamilyName not found in the key value encoder map") + } + Some(keyValueEncoderMap.get(colFamilyName)._1.getColumnFamilyIdBytes()) + } else { + None + } + } + + private val colFamilyIdBytesOpt: Option[Array[Byte]] = getColFamilyIdBytes + override def getNext(): (RecordType.Value, UnsafeRow, UnsafeRow, Long) = { - val reader = currentChangelogReader() - if (reader == null) { - return null + var currRecord: (RecordType.Value, Array[Byte], Array[Byte]) = null + val currEncoder: (RocksDBKeyStateEncoder, RocksDBValueStateEncoder) = + keyValueEncoderMap.get(colFamilyNameOpt + .getOrElse(StateStore.DEFAULT_COL_FAMILY_NAME)) + + if (colFamilyIdBytesOpt.isDefined) { + // If we are reading records for a particular column family, the corresponding vcf id + // will be encoded in the key byte array. We need to extract that and compare for the + // expected column family id. If it matches, we return the record. If not, we move to + // the next record. Note that this has be handled across multiple changelog files and we + // rely on the currentChangelogReader to move to the next changelog file when needed. + while (currRecord == null) { + val reader = currentChangelogReader() + if (reader == null) { + return null + } + + val nextRecord = reader.next() + val colFamilyIdBytes: Array[Byte] = colFamilyIdBytesOpt.get + val endIndex = colFamilyIdBytes.size + // Function checks for byte arrays being equal + // from index 0 to endIndex - 1 (both inclusive) + if (java.util.Arrays.equals(nextRecord._2, 0, endIndex, + colFamilyIdBytes, 0, endIndex)) { + currRecord = nextRecord + } + } + } else { + val reader = currentChangelogReader() + if (reader == null) { + return null + } + currRecord = reader.next() } - val (recordType, keyArray, valueArray) = reader.next() - // Todo: does not support multiple virtual column families - val (rocksDBKeyStateEncoder, rocksDBValueStateEncoder) = - keyValueEncoderMap.get(StateStore.DEFAULT_COL_FAMILY_NAME) - val keyRow = rocksDBKeyStateEncoder.decodeKey(keyArray) - if (valueArray == null) { - (recordType, keyRow, null, currentChangelogVersion - 1) + + val keyRow = currEncoder._1.decodeKey(currRecord._2) + if (currRecord._3 == null) { + (currRecord._1, keyRow, null, currentChangelogVersion - 1) } else { - val valueRow = rocksDBValueStateEncoder.decodeValue(valueArray) - (recordType, keyRow, valueRow, currentChangelogVersion - 1) + val valueRow = currEncoder._2.decodeValue(currRecord._3) + (currRecord._1, keyRow, valueRow, currentChangelogVersion - 1) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index d55a973a14e16..6e616cc71a80c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -519,10 +519,14 @@ trait SupportsFineGrainedReplay { * * @param startVersion starting changelog version * @param endVersion ending changelog version + * @param 
colFamilyNameOpt optional column family name to read from * @return iterator that gives tuple(recordType: [[RecordType.Value]], nested key: [[UnsafeRow]], * nested value: [[UnsafeRow]], batchId: [[Long]]) */ - def getStateStoreChangeDataReader(startVersion: Long, endVersion: Long): + def getStateStoreChangeDataReader( + startVersion: Long, + endVersion: Long, + colFamilyNameOpt: Option[String] = None): NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala index 651d72da16095..e89550da37e03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreChangelog.scala @@ -397,13 +397,15 @@ class StateStoreChangelogReaderV2( * @param startVersion start version of the changelog file to read * @param endVersion end version of the changelog file to read * @param compressionCodec de-compression method using for reading changelog file + * @param colFamilyNameOpt optional column family name to read from */ abstract class StateStoreChangeDataReader( fm: CheckpointFileManager, stateLocation: Path, startVersion: Long, endVersion: Long, - compressionCodec: CompressionCodec) + compressionCodec: CompressionCodec, + colFamilyNameOpt: Option[String] = None) extends NextIterator[(RecordType.Value, UnsafeRow, UnsafeRow, Long)] with Logging { assert(startVersion >= 1) @@ -451,9 +453,12 @@ abstract class StateStoreChangeDataReader( finished = true return null } - // Todo: Does not support StateStoreChangelogReaderV2 - changelogReader = + + changelogReader = if (colFamilyNameOpt.isDefined) { + new StateStoreChangelogReaderV2(fm, fileIterator.next(), compressionCodec) + } else { new StateStoreChangelogReaderV1(fm, fileIterator.next(), compressionCodec) + } } changelogReader } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala index bd047d1132fbe..84c6eb54681a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala @@ -186,18 +186,49 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest } assert(ex1.isInstanceOf[StateDataSourceInvalidOptionValue]) assert(ex1.getMessage.contains("Registered timers are not available")) + } + } + } - // TODO: this should be removed when readChangeFeed is supported for value state - val ex2 = intercept[Exception] { - spark.read + testWithChangelogCheckpointingEnabled("state data source cdf integration - " + + "value state with single variable") { + withTempDir { tempDir => + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new StatefulProcessorWithSingleValueVar(), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, 
OutputMode.Update())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, "a"), + CheckNewAnswer(("a", "1")), + AddData(inputData, "b"), + CheckNewAnswer(("b", "1")), + StopStream + ) + + val changeFeedDf = spark.read .format("statestore") .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) .option(StateSourceOptions.STATE_VAR_NAME, "valueState") - .option(StateSourceOptions.READ_CHANGE_FEED, "true") + .option(StateSourceOptions.READ_CHANGE_FEED, true) .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) .load() - } - assert(ex2.isInstanceOf[StateDataSourceConflictOptions]) + + val opDf = changeFeedDf.selectExpr( + "change_type", + "key.value AS groupingKey", + "value.id AS valueId", "value.name AS valueName", + "partition_id") + + checkAnswer(opDf, + Seq(Row("update", "a", 1L, "dummyKey", 0), Row("update", "b", 1L, "dummyKey", 1))) } } } @@ -260,19 +291,61 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest } assert(ex.isInstanceOf[StateDataSourceInvalidOptionValue]) assert(ex.getMessage.contains("State variable non-exist is not defined")) + } + } + } - // TODO: this should be removed when readChangeFeed is supported for TTL based state - // variables - val ex1 = intercept[Exception] { - spark.read - .format("statestore") - .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) - .option(StateSourceOptions.STATE_VAR_NAME, "countState") - .option(StateSourceOptions.READ_CHANGE_FEED, "true") - .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) - .load() + testWithChangelogCheckpointingEnabled("state data source cdf integration - " + + "value state with single variable and TTL") { + withTempDir { tempDir => + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new StatefulProcessorWithTTL(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = tempDir.getAbsolutePath), + AddData(inputData, "a"), + AddData(inputData, "b"), + Execute { _ => + // wait for the batch to run since we are using processing time + Thread.sleep(5000) + }, + StopStream + ) + + val stateReaderDf = spark.read + .format("statestore") + .option(StateSourceOptions.PATH, tempDir.getAbsolutePath) + .option(StateSourceOptions.STATE_VAR_NAME, "countState") + .option(StateSourceOptions.READ_CHANGE_FEED, true) + .option(StateSourceOptions.CHANGE_START_BATCH_ID, 0) + .load() + + val resultDf = stateReaderDf.selectExpr( + "key.value", "value.value", "value.ttlExpirationMs", "partition_id") + + var count = 0L + resultDf.collect().foreach { row => + count = count + 1 + assert(row.getLong(2) > 0) } - assert(ex1.isInstanceOf[StateDataSourceConflictOptions]) + + // verify that 2 state rows are present + assert(count === 2) + + val answerDf = stateReaderDf.selectExpr( + "change_type", + "key.value AS groupingKey", + "value.value.value AS valueId", "partition_id") + checkAnswer(answerDf, + Seq(Row("update", "a", 1L, 0), Row("update", "b", 1L, 1))) } } } From a116a5bf708dbd2e0efc0b1f63f3f655d3e830da Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 26 Sep 2024 08:37:04 -0400 Subject: [PATCH 092/250] [SPARK-49416][CONNECT][SQL] Add Shared DataStreamReader interface ### What changes were proposed in this pull request? 
This PR adds a shared DataStreamReader to sql. ### Why are the changes needed? We are creating a unified Scala interface for sql. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48213 from hvanhovell/SPARK-49416. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../org/apache/spark/sql/SparkSession.scala | 10 +- .../sql/streaming/DataStreamReader.scala | 295 ++++------------ .../CheckConnectJvmClientCompatibility.scala | 8 +- project/MimaExcludes.scala | 1 + .../spark/sql/api/DataStreamReader.scala | 297 ++++++++++++++++ .../apache/spark/sql/api/SparkSession.scala | 11 + .../org/apache/spark/sql/SparkSession.scala | 10 +- .../sql/streaming/DataStreamReader.scala | 325 ++++-------------- 8 files changed, 438 insertions(+), 519 deletions(-) create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/api/DataStreamReader.scala diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 5313369a2c987..1b41566ca1d1d 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -209,15 +209,7 @@ class SparkSession private[sql] ( /** @inheritdoc */ def read: DataFrameReader = new DataFrameReader(this) - /** - * Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`. - * {{{ - * sparkSession.readStream.parquet("/path/to/directory/of/parquet/files") - * sparkSession.readStream.schema(schema).json("/path/to/directory/of/json/files") - * }}} - * - * @since 3.5.0 - */ + /** @inheritdoc */ def readStream: DataStreamReader = new DataStreamReader(this) lazy val streams: StreamingQueryManager = new StreamingQueryManager(this) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 789425c9daea1..2ff34a6343644 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -21,11 +21,9 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Evolving import org.apache.spark.connect.proto.Read.DataSource -import org.apache.spark.internal.Logging -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Dataset -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder +import org.apache.spark.sql.{api, DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.connect.ConnectConversions._ +import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.types.StructType /** @@ -35,101 +33,49 @@ import org.apache.spark.sql.types.StructType * @since 3.5.0 */ @Evolving -final class DataStreamReader private[sql] (sparkSession: SparkSession) extends Logging { +final class DataStreamReader private[sql] (sparkSession: SparkSession) + extends api.DataStreamReader { - /** - * Specifies the input data source format. 
- * - * @since 3.5.0 - */ - def format(source: String): DataStreamReader = { + private val sourceBuilder = DataSource.newBuilder() + + /** @inheritdoc */ + def format(source: String): this.type = { sourceBuilder.setFormat(source) this } - /** - * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema - * automatically from data. By specifying the schema here, the underlying data source can skip - * the schema inference step, and thus speed up data loading. - * - * @since 3.5.0 - */ - def schema(schema: StructType): DataStreamReader = { + /** @inheritdoc */ + def schema(schema: StructType): this.type = { if (schema != null) { sourceBuilder.setSchema(schema.json) // Use json. DDL does not retail all the attributes. } this } - /** - * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) - * can infer the input schema automatically from data. By specifying the schema here, the - * underlying data source can skip the schema inference step, and thus speed up data loading. - * - * @since 3.5.0 - */ - def schema(schemaString: String): DataStreamReader = { + /** @inheritdoc */ + override def schema(schemaString: String): this.type = { sourceBuilder.setSchema(schemaString) this } - /** - * Adds an input option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: String): DataStreamReader = { + /** @inheritdoc */ + def option(key: String, value: String): this.type = { sourceBuilder.putOptions(key, value) this } - /** - * Adds an input option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Boolean): DataStreamReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Long): DataStreamReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * @since 3.5.0 - */ - def option(key: String, value: Double): DataStreamReader = option(key, value.toString) - - /** - * (Scala-specific) Adds input options for the underlying data source. - * - * @since 3.5.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamReader = { + /** @inheritdoc */ + def options(options: scala.collection.Map[String, String]): this.type = { this.options(options.asJava) - this } - /** - * (Java-specific) Adds input options for the underlying data source. - * - * @since 3.5.0 - */ - def options(options: java.util.Map[String, String]): DataStreamReader = { + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = { sourceBuilder.putAllOptions(options) this } - /** - * Loads input data stream in as a `DataFrame`, for data streams that don't require a path (e.g. - * external key-value stores). - * - * @since 3.5.0 - */ + /** @inheritdoc */ def load(): DataFrame = { sparkSession.newDataFrame { relationBuilder => relationBuilder.getReadBuilder @@ -138,120 +84,14 @@ final class DataStreamReader private[sql] (sparkSession: SparkSession) extends L } } - /** - * Loads input in as a `DataFrame`, for data streams that read from some path. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def load(path: String): DataFrame = { sourceBuilder.clearPaths() sourceBuilder.addPaths(path) load() } - /** - * Loads a JSON file stream and returns the results as a `DataFrame`. - * - * JSON Lines (newline-delimited JSON) is supported by - * default. 
For JSON (one record per file), set the `multiLine` option to true. - * - * This function goes through the input once to determine the input schema. If you know the - * schema in advance, use the version that specifies the schema to avoid the extra scan. - * - * You can set the following option(s):
<ul> <li> `maxFilesPerTrigger` (default: no max limit):
- * sets the maximum number of new files to be considered in every trigger. </li> <li>
- * `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to
- * be considered in every trigger. </li> </ul>
    - * - * You can find the JSON-specific options for reading JSON file stream in - * Data Source Option in the version you use. - * - * @since 3.5.0 - */ - def json(path: String): DataFrame = { - format("json").load(path) - } - - /** - * Loads a CSV file stream and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can set the following option(s):
<ul> <li> `maxFilesPerTrigger` (default: no max limit):
- * sets the maximum number of new files to be considered in every trigger. </li> <li>
- * `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to
- * be considered in every trigger. </li> </ul>
    - * - * You can find the CSV-specific options for reading CSV file stream in - * Data Source Option in the version you use. - * - * @since 3.5.0 - */ - def csv(path: String): DataFrame = format("csv").load(path) - - /** - * Loads a XML file stream and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can set the following option(s):
<ul> <li> `maxFilesPerTrigger` (default: no max limit):
- * sets the maximum number of new files to be considered in every trigger. </li> <li>
- * `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to
- * be considered in every trigger. </li> </ul>
    - * - * You can find the XML-specific options for reading XML file stream in - * Data Source Option in the version you use. - * - * @since 4.0.0 - */ - def xml(path: String): DataFrame = format("xml").load(path) - - /** - * Loads a ORC file stream, returning the result as a `DataFrame`. - * - * You can set the following option(s):
<ul> <li> `maxFilesPerTrigger` (default: no max limit):
- * sets the maximum number of new files to be considered in every trigger. </li> <li>
- * `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to
- * be considered in every trigger. </li> </ul>
    - * - * ORC-specific option(s) for reading ORC file stream can be found in Data - * Source Option in the version you use. - * - * @since 3.5.0 - */ - def orc(path: String): DataFrame = format("orc").load(path) - - /** - * Loads a Parquet file stream, returning the result as a `DataFrame`. - * - * You can set the following option(s):
<ul> <li> `maxFilesPerTrigger` (default: no max limit):
- * sets the maximum number of new files to be considered in every trigger. </li> <li>
- * `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to
- * be considered in every trigger. </li> </ul>
    - * - * Parquet-specific option(s) for reading Parquet file stream can be found in Data - * Source Option in the version you use. - * - * @since 3.5.0 - */ - def parquet(path: String): DataFrame = format("parquet").load(path) - - /** - * Define a Streaming DataFrame on a Table. The DataSource corresponding to the table should - * support streaming mode. - * @param tableName - * The name of the table - * @since 3.5.0 - */ + /** @inheritdoc */ def table(tableName: String): DataFrame = { require(tableName != null, "The table name can't be null") sparkSession.newDataFrame { builder => @@ -263,59 +103,44 @@ final class DataStreamReader private[sql] (sparkSession: SparkSession) extends L } } - /** - * Loads text files and returns a `DataFrame` whose schema starts with a string column named - * "value", and followed by partitioned columns if there are any. The text files must be encoded - * as UTF-8. - * - * By default, each line in the text files is a new row in the resulting DataFrame. For example: - * {{{ - * // Scala: - * spark.readStream.text("/path/to/directory/") - * - * // Java: - * spark.readStream().text("/path/to/directory/") - * }}} - * - * You can set the following option(s):
<ul> <li> `maxFilesPerTrigger` (default: no max limit):
- * sets the maximum number of new files to be considered in every trigger. </li> <li>
- * `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to
- * be considered in every trigger. </li> </ul>
    - * - * You can find the text-specific options for reading text files in - * Data Source Option in the version you use. - * - * @since 3.5.0 - */ - def text(path: String): DataFrame = format("text").load(path) - - /** - * Loads text file(s) and returns a `Dataset` of String. The underlying schema of the Dataset - * contains a single string column named "value". The text files must be encoded as UTF-8. - * - * If the directory structure of the text files contains partitioning information, those are - * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. - * - * By default, each line in the text file is a new element in the resulting Dataset. For - * example: - * {{{ - * // Scala: - * spark.readStream.textFile("/path/to/spark/README.md") - * - * // Java: - * spark.readStream().textFile("/path/to/spark/README.md") - * }}} - * - * You can set the text-specific options as specified in `DataStreamReader.text`. - * - * @param path - * input path - * @since 3.5.0 - */ - def textFile(path: String): Dataset[String] = { - text(path).select("value").as[String](StringEncoder) + override protected def assertNoSpecifiedSchema(operation: String): Unit = { + if (sourceBuilder.hasSchema) { + throw DataTypeErrors.userSpecifiedSchemaUnsupportedError(operation) + } } - private val sourceBuilder = DataSource.newBuilder() + /////////////////////////////////////////////////////////////////////////////////////// + // Covariant overrides. + /////////////////////////////////////////////////////////////////////////////////////// + + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + override def json(path: String): DataFrame = super.json(path) + + /** @inheritdoc */ + override def csv(path: String): DataFrame = super.csv(path) + + /** @inheritdoc */ + override def xml(path: String): DataFrame = super.xml(path) + + /** @inheritdoc */ + override def orc(path: String): DataFrame = super.orc(path) + + /** @inheritdoc */ + override def parquet(path: String): DataFrame = super.parquet(path) + + /** @inheritdoc */ + override def text(path: String): DataFrame = super.text(path) + + /** @inheritdoc */ + override def textFile(path: String): Dataset[String] = super.textFile(path) + } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 16f6983efb187..c8776af18a14a 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -304,7 +304,13 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[DirectMissingMethodProblem]( "org.apache.spark.sql.DataFrameReader.validateJsonSchema"), ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.sql.DataFrameReader.validateXmlSchema")) + "org.apache.spark.sql.DataFrameReader.validateXmlSchema"), + + // Protected DataStreamReader methods... 
+ ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.streaming.DataStreamReader.validateJsonSchema"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.streaming.DataStreamReader.validateXmlSchema")) checkMiMaCompatibility(clientJar, sqlJar, includedRules, excludeRules) } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9a89ebb4797c9..0bd0121e6e141 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -179,6 +179,7 @@ object MimaExcludes { // SPARK-49282: Shared SparkSessionBuilder ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SparkSession$Builder"), ) ++ loggingExcludes("org.apache.spark.sql.DataFrameReader") ++ + loggingExcludes("org.apache.spark.sql.streaming.DataStreamReader") ++ loggingExcludes("org.apache.spark.sql.SparkSession#Builder") // Default exclude rules diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataStreamReader.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/DataStreamReader.scala new file mode 100644 index 0000000000000..219ecb77d4033 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataStreamReader.scala @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.api + +import scala.jdk.CollectionConverters._ + +import _root_.java + +import org.apache.spark.annotation.Evolving +import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.types.StructType + +/** + * Interface used to load a streaming `Dataset` from external storage systems (e.g. file systems, + * key-value stores, etc). Use `SparkSession.readStream` to access this. + * + * @since 2.0.0 + */ +@Evolving +abstract class DataStreamReader { + + /** + * Specifies the input data source format. + * + * @since 2.0.0 + */ + def format(source: String): this.type + + /** + * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema + * automatically from data. By specifying the schema here, the underlying data source can skip + * the schema inference step, and thus speed up data loading. + * + * @since 2.0.0 + */ + def schema(schema: StructType): this.type + + /** + * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) + * can infer the input schema automatically from data. By specifying the schema here, the + * underlying data source can skip the schema inference step, and thus speed up data loading. + * + * @since 2.3.0 + */ + def schema(schemaString: String): this.type = { + schema(StructType.fromDDL(schemaString)) + } + + /** + * Adds an input option for the underlying data source. 
+ * + * @since 2.0.0 + */ + def option(key: String, value: String): this.type + + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Boolean): this.type = option(key, value.toString) + + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Long): this.type = option(key, value.toString) + + /** + * Adds an input option for the underlying data source. + * + * @since 2.0.0 + */ + def option(key: String, value: Double): this.type = option(key, value.toString) + + /** + * (Scala-specific) Adds input options for the underlying data source. + * + * @since 2.0.0 + */ + def options(options: scala.collection.Map[String, String]): this.type + + /** + * (Java-specific) Adds input options for the underlying data source. + * + * @since 2.0.0 + */ + def options(options: java.util.Map[String, String]): this.type = { + this.options(options.asScala) + this + } + + /** + * Loads input data stream in as a `DataFrame`, for data streams that don't require a path (e.g. + * external key-value stores). + * + * @since 2.0.0 + */ + def load(): Dataset[Row] + + /** + * Loads input in as a `DataFrame`, for data streams that read from some path. + * + * @since 2.0.0 + */ + def load(path: String): Dataset[Row] + + /** + * Loads a JSON file stream and returns the results as a `DataFrame`. + * + * JSON Lines (newline-delimited JSON) is supported by + * default. For JSON (one record per file), set the `multiLine` option to true. + * + * This function goes through the input once to determine the input schema. If you know the + * schema in advance, use the version that specifies the schema to avoid the extra scan. + * + * You can set the following option(s):
<ul> <li> `maxFilesPerTrigger` (default: no max limit):
+ * sets the maximum number of new files to be considered in every trigger. </li> <li>
+ * `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to
+ * be considered in every trigger. </li> </ul>
    + * + * You can find the JSON-specific options for reading JSON file stream in + * Data Source Option in the version you use. + * + * @since 2.0.0 + */ + def json(path: String): Dataset[Row] = { + validateJsonSchema() + format("json").load(path) + } + + /** + * Loads a CSV file stream and returns the result as a `DataFrame`. + * + * This function will go through the input once to determine the input schema if `inferSchema` + * is enabled. To avoid going through the entire data once, disable `inferSchema` option or + * specify the schema explicitly using `schema`. + * + * You can set the following option(s):
<ul> <li> `maxFilesPerTrigger` (default: no max limit):
+ * sets the maximum number of new files to be considered in every trigger. </li> <li>
+ * `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to
+ * be considered in every trigger. </li> </ul>
    + * + * You can find the CSV-specific options for reading CSV file stream in + * Data Source Option in the version you use. + * + * @since 2.0.0 + */ + def csv(path: String): Dataset[Row] = format("csv").load(path) + + /** + * Loads a XML file stream and returns the result as a `DataFrame`. + * + * This function will go through the input once to determine the input schema if `inferSchema` + * is enabled. To avoid going through the entire data once, disable `inferSchema` option or + * specify the schema explicitly using `schema`. + * + * You can set the following option(s):
<ul> <li> `maxFilesPerTrigger` (default: no max limit):
+ * sets the maximum number of new files to be considered in every trigger. </li> <li>
+ * `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to
+ * be considered in every trigger. </li> </ul>
    + * + * You can find the XML-specific options for reading XML file stream in + * Data Source Option in the version you use. + * + * @since 4.0.0 + */ + def xml(path: String): Dataset[Row] = { + validateXmlSchema() + format("xml").load(path) + } + + /** + * Loads a ORC file stream, returning the result as a `DataFrame`. + * + * You can set the following option(s):
<ul> <li> `maxFilesPerTrigger` (default: no max limit):
+ * sets the maximum number of new files to be considered in every trigger. </li> <li>
+ * `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to
+ * be considered in every trigger. </li> </ul>
    + * + * ORC-specific option(s) for reading ORC file stream can be found in Data + * Source Option in the version you use. + * + * @since 2.3.0 + */ + def orc(path: String): Dataset[Row] = { + format("orc").load(path) + } + + /** + * Loads a Parquet file stream, returning the result as a `DataFrame`. + * + * You can set the following option(s):
<ul> <li> `maxFilesPerTrigger` (default: no max limit):
+ * sets the maximum number of new files to be considered in every trigger. </li> <li>
+ * `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to
+ * be considered in every trigger. </li> </ul>
    + * + * Parquet-specific option(s) for reading Parquet file stream can be found in Data + * Source Option in the version you use. + * + * @since 2.0.0 + */ + def parquet(path: String): Dataset[Row] = { + format("parquet").load(path) + } + + /** + * Define a Streaming DataFrame on a Table. The DataSource corresponding to the table should + * support streaming mode. + * @param tableName + * The name of the table + * @since 3.1.0 + */ + def table(tableName: String): Dataset[Row] + + /** + * Loads text files and returns a `DataFrame` whose schema starts with a string column named + * "value", and followed by partitioned columns if there are any. The text files must be encoded + * as UTF-8. + * + * By default, each line in the text files is a new row in the resulting DataFrame. For example: + * {{{ + * // Scala: + * spark.readStream.text("/path/to/directory/") + * + * // Java: + * spark.readStream().text("/path/to/directory/") + * }}} + * + * You can set the following option(s):
<ul> <li> `maxFilesPerTrigger` (default: no max limit):
+ * sets the maximum number of new files to be considered in every trigger. </li> <li>
+ * `maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files to
+ * be considered in every trigger. </li> </ul>
    + * + * You can find the text-specific options for reading text files in + * Data Source Option in the version you use. + * + * @since 2.0.0 + */ + def text(path: String): Dataset[Row] = format("text").load(path) + + /** + * Loads text file(s) and returns a `Dataset` of String. The underlying schema of the Dataset + * contains a single string column named "value". The text files must be encoded as UTF-8. + * + * If the directory structure of the text files contains partitioning information, those are + * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. + * + * By default, each line in the text file is a new element in the resulting Dataset. For + * example: + * {{{ + * // Scala: + * spark.readStream.textFile("/path/to/spark/README.md") + * + * // Java: + * spark.readStream().textFile("/path/to/spark/README.md") + * }}} + * + * You can set the text-specific options as specified in `DataStreamReader.text`. + * + * @param path + * input path + * @since 2.1.0 + */ + def textFile(path: String): Dataset[String] = { + assertNoSpecifiedSchema("textFile") + text(path).select("value").as(Encoders.STRING) + } + + protected def assertNoSpecifiedSchema(operation: String): Unit + + protected def validateJsonSchema(): Unit = () + + protected def validateXmlSchema(): Unit = () + +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala index 2295c153cd51c..0f73a94c3c4a4 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala @@ -506,6 +506,17 @@ abstract class SparkSession extends Serializable with Closeable { */ def read: DataFrameReader + /** + * Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`. + * {{{ + * sparkSession.readStream.parquet("/path/to/directory/of/parquet/files") + * sparkSession.readStream.schema(schema).json("/path/to/directory/of/json/files") + * }}} + * + * @since 2.0.0 + */ + def readStream: DataStreamReader + /** * (Scala-specific) Implicit methods available in Scala for converting common Scala objects into * `DataFrame`s. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index fe139d629eb24..983cc24718fd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -739,15 +739,7 @@ class SparkSession private( /** @inheritdoc */ def read: DataFrameReader = new DataFrameReader(self) - /** - * Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`. 
- * {{{ - * sparkSession.readStream.parquet("/path/to/directory/of/parquet/files") - * sparkSession.readStream.schema(schema).json("/path/to/directory/of/json/files") - * }}} - * - * @since 2.0.0 - */ + /** @inheritdoc */ def readStream: DataStreamReader = new DataStreamReader(self) // scalastyle:off diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 24d769fc8fc87..f42d8b667ab12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -22,12 +22,12 @@ import java.util.Locale import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Evolving -import org.apache.spark.internal.Logging -import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.{api, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} +import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.connector.catalog.{SupportsRead, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.errors.QueryCompilationErrors @@ -49,25 +49,15 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * @since 2.0.0 */ @Evolving -final class DataStreamReader private[sql](sparkSession: SparkSession) extends Logging { - /** - * Specifies the input data source format. - * - * @since 2.0.0 - */ - def format(source: String): DataStreamReader = { +final class DataStreamReader private[sql](sparkSession: SparkSession) extends api.DataStreamReader { + /** @inheritdoc */ + def format(source: String): this.type = { this.source = source this } - /** - * Specifies the input schema. Some data sources (e.g. JSON) can infer the input schema - * automatically from data. By specifying the schema here, the underlying data source can - * skip the schema inference step, and thus speed up data loading. - * - * @since 2.0.0 - */ - def schema(schema: StructType): DataStreamReader = { + /** @inheritdoc */ + def schema(schema: StructType): this.type = { if (schema != null) { val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] this.userSpecifiedSchema = Option(replaced) @@ -75,75 +65,19 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo this } - /** - * Specifies the schema by using the input DDL-formatted string. Some data sources (e.g. JSON) can - * infer the input schema automatically from data. By specifying the schema here, the underlying - * data source can skip the schema inference step, and thus speed up data loading. - * - * @since 2.3.0 - */ - def schema(schemaString: String): DataStreamReader = { - schema(StructType.fromDDL(schemaString)) - } - - /** - * Adds an input option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: String): DataStreamReader = { + /** @inheritdoc */ + def option(key: String, value: String): this.type = { this.extraOptions += (key -> value) this } - /** - * Adds an input option for the underlying data source. 
- * - * @since 2.0.0 - */ - def option(key: String, value: Boolean): DataStreamReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: Long): DataStreamReader = option(key, value.toString) - - /** - * Adds an input option for the underlying data source. - * - * @since 2.0.0 - */ - def option(key: String, value: Double): DataStreamReader = option(key, value.toString) - - /** - * (Scala-specific) Adds input options for the underlying data source. - * - * @since 2.0.0 - */ - def options(options: scala.collection.Map[String, String]): DataStreamReader = { + /** @inheritdoc */ + def options(options: scala.collection.Map[String, String]): this.type = { this.extraOptions ++= options this } - /** - * (Java-specific) Adds input options for the underlying data source. - * - * @since 2.0.0 - */ - def options(options: java.util.Map[String, String]): DataStreamReader = { - this.options(options.asScala) - this - } - - - /** - * Loads input data stream in as a `DataFrame`, for data streams that don't require a path - * (e.g. external key-value stores). - * - * @since 2.0.0 - */ + /** @inheritdoc */ def load(): DataFrame = loadInternal(None) private def loadInternal(path: Option[String]): DataFrame = { @@ -205,11 +139,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo } } - /** - * Loads input in as a `DataFrame`, for data streams that read from some path. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def load(path: String): DataFrame = { if (!sparkSession.sessionState.conf.legacyPathOptionBehavior && extraOptions.contains("path")) { @@ -218,133 +148,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo loadInternal(Some(path)) } - /** - * Loads a JSON file stream and returns the results as a `DataFrame`. - * - * JSON Lines (newline-delimited JSON) is supported by - * default. For JSON (one record per file), set the `multiLine` option to true. - * - * This function goes through the input once to determine the input schema. If you know the - * schema in advance, use the version that specifies the schema to avoid the extra scan. - * - * You can set the following option(s): - *
<ul>
- * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
- * considered in every trigger.</li>
- * <li>`maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files
- * to be considered in every trigger.</li>
- * </ul>
    - * - * You can find the JSON-specific options for reading JSON file stream in - * - * Data Source Option in the version you use. - * - * @since 2.0.0 - */ - def json(path: String): DataFrame = { - userSpecifiedSchema.foreach(checkJsonSchema) - format("json").load(path) - } - - /** - * Loads a CSV file stream and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can set the following option(s): - *
<ul>
- * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
- * considered in every trigger.</li>
- * <li>`maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files
- * to be considered in every trigger.</li>
- * </ul>
    - * - * You can find the CSV-specific options for reading CSV file stream in - * - * Data Source Option in the version you use. - * - * @since 2.0.0 - */ - def csv(path: String): DataFrame = format("csv").load(path) - - /** - * Loads a XML file stream and returns the result as a `DataFrame`. - * - * This function will go through the input once to determine the input schema if `inferSchema` - * is enabled. To avoid going through the entire data once, disable `inferSchema` option or - * specify the schema explicitly using `schema`. - * - * You can set the following option(s): - *
<ul>
- * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
- * considered in every trigger.</li>
- * <li>`maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files
- * to be considered in every trigger.</li>
- * </ul>
    - * - * You can find the XML-specific options for reading XML file stream in - * - * Data Source Option in the version you use. - * - * @since 4.0.0 - */ - def xml(path: String): DataFrame = { - userSpecifiedSchema.foreach(checkXmlSchema) - format("xml").load(path) - } - - /** - * Loads a ORC file stream, returning the result as a `DataFrame`. - * - * You can set the following option(s): - *
<ul>
- * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
- * considered in every trigger.</li>
- * <li>`maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files
- * to be considered in every trigger.</li>
- * </ul>
    - * - * ORC-specific option(s) for reading ORC file stream can be found in - * - * Data Source Option in the version you use. - * - * @since 2.3.0 - */ - def orc(path: String): DataFrame = { - format("orc").load(path) - } - - /** - * Loads a Parquet file stream, returning the result as a `DataFrame`. - * - * You can set the following option(s): - *
<ul>
- * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
- * considered in every trigger.</li>
- * <li>`maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files
- * to be considered in every trigger.</li>
- * </ul>
    - * - * Parquet-specific option(s) for reading Parquet file stream can be found in - * - * Data Source Option in the version you use. - * - * @since 2.0.0 - */ - def parquet(path: String): DataFrame = { - format("parquet").load(path) - } - - /** - * Define a Streaming DataFrame on a Table. The DataSource corresponding to the table should - * support streaming mode. - * @param tableName The name of the table - * @since 3.1.0 - */ + /** @inheritdoc */ def table(tableName: String): DataFrame = { require(tableName != null, "The table name can't be null") val identifier = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName) @@ -356,65 +160,56 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo isStreaming = true)) } - /** - * Loads text files and returns a `DataFrame` whose schema starts with a string column named - * "value", and followed by partitioned columns if there are any. - * The text files must be encoded as UTF-8. - * - * By default, each line in the text files is a new row in the resulting DataFrame. For example: - * {{{ - * // Scala: - * spark.readStream.text("/path/to/directory/") - * - * // Java: - * spark.readStream().text("/path/to/directory/") - * }}} - * - * You can set the following option(s): - *
<ul>
- * <li>`maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
- * considered in every trigger.</li>
- * <li>`maxBytesPerTrigger` (default: no max limit): sets the maximum total size of new files
- * to be considered in every trigger.</li>
- * </ul>
    - * - * You can find the text-specific options for reading text files in - * - * Data Source Option in the version you use. - * - * @since 2.0.0 - */ - def text(path: String): DataFrame = format("text").load(path) - - /** - * Loads text file(s) and returns a `Dataset` of String. The underlying schema of the Dataset - * contains a single string column named "value". - * The text files must be encoded as UTF-8. - * - * If the directory structure of the text files contains partitioning information, those are - * ignored in the resulting Dataset. To include partitioning information as columns, use `text`. - * - * By default, each line in the text file is a new element in the resulting Dataset. For example: - * {{{ - * // Scala: - * spark.readStream.textFile("/path/to/spark/README.md") - * - * // Java: - * spark.readStream().textFile("/path/to/spark/README.md") - * }}} - * - * You can set the text-specific options as specified in `DataStreamReader.text`. - * - * @param path input path - * @since 2.1.0 - */ - def textFile(path: String): Dataset[String] = { + override protected def assertNoSpecifiedSchema(operation: String): Unit = { if (userSpecifiedSchema.nonEmpty) { - throw QueryCompilationErrors.userSpecifiedSchemaUnsupportedError("textFile") + throw QueryCompilationErrors.userSpecifiedSchemaUnsupportedError(operation) } - text(path).select("value").as[String](sparkSession.implicits.newStringEncoder) } + override protected def validateJsonSchema(): Unit = userSpecifiedSchema.foreach(checkJsonSchema) + + override protected def validateXmlSchema(): Unit = userSpecifiedSchema.foreach(checkXmlSchema) + + /////////////////////////////////////////////////////////////////////////////////////// + // Covariant overrides. + /////////////////////////////////////////////////////////////////////////////////////// + + /** @inheritdoc */ + override def schema(schemaString: String): this.type = super.schema(schemaString) + + /** @inheritdoc */ + override def option(key: String, value: Boolean): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Long): this.type = super.option(key, value) + + /** @inheritdoc */ + override def option(key: String, value: Double): this.type = super.option(key, value) + + /** @inheritdoc */ + override def options(options: java.util.Map[String, String]): this.type = super.options(options) + + /** @inheritdoc */ + override def json(path: String): DataFrame = super.json(path) + + /** @inheritdoc */ + override def csv(path: String): DataFrame = super.csv(path) + + /** @inheritdoc */ + override def xml(path: String): DataFrame = super.xml(path) + + /** @inheritdoc */ + override def orc(path: String): DataFrame = super.orc(path) + + /** @inheritdoc */ + override def parquet(path: String): DataFrame = super.parquet(path) + + /** @inheritdoc */ + override def text(path: String): DataFrame = super.text(path) + + /** @inheritdoc */ + override def textFile(path: String): Dataset[String] = super.textFile(path) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// From 7b20e5841a856cd0d81821e330b3ec33098bb9be Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 26 Sep 2024 09:28:16 -0400 Subject: [PATCH 093/250] [SPARK-49286][CONNECT][SQL] Move Avro/Protobuf functions to sql/api ### What changes were proposed in this pull request? 
This PR moves avro and protobuf functions to sql/api. ### Why are the changes needed? We are creating a unified Scala SQL interface. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48258 from hvanhovell/SPARK-49286. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../org/apache/spark/sql/avro/functions.scala | 93 ----- .../apache/spark/sql/protobuf/functions.scala | 324 ------------------ project/MimaExcludes.scala | 6 + .../org/apache/spark/sql/avro/functions.scala | 8 +- .../apache/spark/sql/protobuf/functions.scala | 82 +++-- 5 files changed, 50 insertions(+), 463 deletions(-) delete mode 100755 connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala delete mode 100644 connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala rename {connector/connect/client/jvm => sql/api}/src/main/scala/org/apache/spark/sql/avro/functions.scala (97%) rename {connector/connect/client/jvm => sql/api}/src/main/scala/org/apache/spark/sql/protobuf/functions.scala (90%) diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala deleted file mode 100755 index 828a609a10e9c..0000000000000 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/functions.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.avro - -import scala.jdk.CollectionConverters._ - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.Column -import org.apache.spark.sql.internal.ExpressionUtils.{column, expression} - - -// scalastyle:off: object.name -object functions { -// scalastyle:on: object.name - - /** - * Converts a binary column of avro format into its corresponding catalyst value. The specified - * schema must match the read data, otherwise the behavior is undefined: it may fail or return - * arbitrary result. - * - * @param data the binary column. - * @param jsonFormatSchema the avro schema in JSON string format. - * - * @since 3.0.0 - */ - @Experimental - def from_avro( - data: Column, - jsonFormatSchema: String): Column = { - AvroDataToCatalyst(data, jsonFormatSchema, Map.empty) - } - - /** - * Converts a binary column of Avro format into its corresponding catalyst value. - * The specified schema must match actual schema of the read data, otherwise the behavior - * is undefined: it may fail or return arbitrary result. - * To deserialize the data with a compatible and evolved schema, the expected Avro schema can be - * set via the option avroSchema. - * - * @param data the binary column. 
- * @param jsonFormatSchema the avro schema in JSON string format. - * @param options options to control how the Avro record is parsed. - * - * @since 3.0.0 - */ - @Experimental - def from_avro( - data: Column, - jsonFormatSchema: String, - options: java.util.Map[String, String]): Column = { - AvroDataToCatalyst(data, jsonFormatSchema, options.asScala.toMap) - } - - /** - * Converts a column into binary of avro format. - * - * @param data the data column. - * - * @since 3.0.0 - */ - @Experimental - def to_avro(data: Column): Column = { - CatalystDataToAvro(data, None) - } - - /** - * Converts a column into binary of avro format. - * - * @param data the data column. - * @param jsonFormatSchema user-specified output avro schema in JSON string format. - * - * @since 3.0.0 - */ - @Experimental - def to_avro(data: Column, jsonFormatSchema: String): Column = { - CatalystDataToAvro(data, Some(jsonFormatSchema)) - } -} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala deleted file mode 100644 index 3b0def8fc73f7..0000000000000 --- a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala +++ /dev/null @@ -1,324 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.protobuf - -import scala.jdk.CollectionConverters._ - -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.Column -import org.apache.spark.sql.functions.lit -import org.apache.spark.sql.protobuf.utils.ProtobufUtils - -// scalastyle:off: object.name -object functions { -// scalastyle:on: object.name - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value. The - * Protobuf definition is provided through Protobuf descriptor file. - * - * @param data - * the binary column. - * @param messageName - * the protobuf message name to look for in descriptor file. - * @param descFilePath - * The Protobuf descriptor file. This file is usually created using `protoc` with - * `--descriptor_set_out` and `--include_imports` options. - * @param options - * @since 3.4.0 - */ - @Experimental - def from_protobuf( - data: Column, - messageName: String, - descFilePath: String, - options: java.util.Map[String, String]): Column = { - val descriptorFileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - from_protobuf(data, messageName, descriptorFileContent, options) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value.The - * Protobuf definition is provided through Protobuf `FileDescriptorSet`. - * - * @param data - * the binary column. 
- * @param messageName - * the protobuf MessageName to look for in the descriptor set. - * @param binaryFileDescriptorSet - * Serialized Protobuf descriptor (`FileDescriptorSet`). Typically contents of file created - * using `protoc` with `--descriptor_set_out` and `--include_imports` options. - * @param options - * @since 3.5.0 - */ - @Experimental - def from_protobuf( - data: Column, - messageName: String, - binaryFileDescriptorSet: Array[Byte], - options: java.util.Map[String, String]): Column = { - Column.fnWithOptions( - "from_protobuf", - options.asScala.iterator, - data, - lit(messageName), - lit(binaryFileDescriptorSet) - ) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value. The - * Protobuf definition is provided through Protobuf descriptor file. - * - * @param data - * the binary column. - * @param messageName - * the protobuf MessageName to look for in descriptor file. - * @param descFilePath - * The Protobuf descriptor file. This file is usually created using `protoc` with - * `--descriptor_set_out` and `--include_imports` options. - * @since 3.4.0 - */ - @Experimental - def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = { - val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - from_protobuf(data, messageName, fileContent) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value.The - * Protobuf definition is provided through Protobuf `FileDescriptorSet`. - * - * @param data - * the binary column. - * @param messageName - * the protobuf MessageName to look for in the descriptor set. - * @param binaryFileDescriptorSet - * Serialized Protobuf descriptor (`FileDescriptorSet`). Typically contents of file created - * using `protoc` with `--descriptor_set_out` and `--include_imports` options. - * @since 3.5.0 - */ - @Experimental - def from_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) - : Column = { - Column.fn( - "from_protobuf", - data, - lit(messageName), - lit(binaryFileDescriptorSet) - ) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value. - * `messageClassName` points to Protobuf Java class. The jar containing Java class should be - * shaded. Specifically, `com.google.protobuf.*` should be shaded to - * `org.sparkproject.spark_protobuf.protobuf.*`. - * https://github.com/rangadi/shaded-protobuf-classes is useful to create shaded jar from - * Protobuf files. - * - * @param data - * the binary column. - * @param messageClassName - * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. - * The jar with these classes needs to be shaded as described above. - * @since 3.4.0 - */ - @Experimental - def from_protobuf(data: Column, messageClassName: String): Column = { - Column.fn( - "from_protobuf", - data, - lit(messageClassName) - ) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value. - * `messageClassName` points to Protobuf Java class. The jar containing Java class should be - * shaded. Specifically, `com.google.protobuf.*` should be shaded to - * `org.sparkproject.spark_protobuf.protobuf.*`. - * https://github.com/rangadi/shaded-protobuf-classes is useful to create shaded jar from - * Protobuf files. - * - * @param data - * the binary column. - * @param messageClassName - * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. 
- * The jar with these classes needs to be shaded as described above. - * @param options - * @since 3.4.0 - */ - @Experimental - def from_protobuf( - data: Column, - messageClassName: String, - options: java.util.Map[String, String]): Column = { - Column.fnWithOptions( - "from_protobuf", - options.asScala.iterator, - data, - lit(messageClassName) - ) - } - - /** - * Converts a column into binary of protobuf format. The Protobuf definition is provided - * through Protobuf descriptor file. - * - * @param data - * the data column. - * @param messageName - * the protobuf MessageName to look for in descriptor file. - * @param descFilePath - * The Protobuf descriptor file. This file is usually created using `protoc` with - * `--descriptor_set_out` and `--include_imports` options. - * @since 3.4.0 - */ - @Experimental - def to_protobuf(data: Column, messageName: String, descFilePath: String): Column = { - to_protobuf(data, messageName, descFilePath, Map.empty[String, String].asJava) - } - - /** - * Converts a column into binary of protobuf format.The Protobuf definition is provided - * through Protobuf `FileDescriptorSet`. - * - * @param data - * the binary column. - * @param messageName - * the protobuf MessageName to look for in the descriptor set. - * @param binaryFileDescriptorSet - * Serialized Protobuf descriptor (`FileDescriptorSet`). Typically contents of file created - * using `protoc` with `--descriptor_set_out` and `--include_imports` options. - * - * @since 3.5.0 - */ - @Experimental - def to_protobuf(data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]) - : Column = { - Column.fn( - "to_protobuf", - data, - lit(messageName), - lit(binaryFileDescriptorSet) - ) - } - /** - * Converts a column into binary of protobuf format. The Protobuf definition is provided - * through Protobuf descriptor file. - * - * @param data - * the data column. - * @param messageName - * the protobuf MessageName to look for in descriptor file. - * @param descFilePath - * the protobuf descriptor file. - * @param options - * @since 3.4.0 - */ - @Experimental - def to_protobuf( - data: Column, - messageName: String, - descFilePath: String, - options: java.util.Map[String, String]): Column = { - val fileContent = ProtobufUtils.readDescriptorFileContent(descFilePath) - to_protobuf(data, messageName, fileContent, options) - } - - /** - * Converts a column into binary of protobuf format.The Protobuf definition is provided - * through Protobuf `FileDescriptorSet`. - * - * @param data - * the binary column. - * @param messageName - * the protobuf MessageName to look for in the descriptor set. - * @param binaryFileDescriptorSet - * Serialized Protobuf descriptor (`FileDescriptorSet`). Typically contents of file created - * using `protoc` with `--descriptor_set_out` and `--include_imports` options. - * @param options - * @since 3.5.0 - */ - @Experimental - def to_protobuf( - data: Column, - messageName: String, - binaryFileDescriptorSet: Array[Byte], - options: java.util.Map[String, String] - ): Column = { - Column.fnWithOptions( - "to_protobuf", - options.asScala.iterator, - data, - lit(messageName), - lit(binaryFileDescriptorSet) - ) - } - - /** - * Converts a column into binary of protobuf format. - * `messageClassName` points to Protobuf Java class. The jar containing Java class should be - * shaded. Specifically, `com.google.protobuf.*` should be shaded to - * `org.sparkproject.spark_protobuf.protobuf.*`. 
- * https://github.com/rangadi/shaded-protobuf-classes is useful to create shaded jar from - * Protobuf files. - * - * @param data - * the data column. - * @param messageClassName - * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. - * The jar with these classes needs to be shaded as described above. - * @since 3.4.0 - */ - @Experimental - def to_protobuf(data: Column, messageClassName: String): Column = { - Column.fn( - "to_protobuf", - data, - lit(messageClassName) - ) - } - - /** - * Converts a column into binary of protobuf format. - * `messageClassName` points to Protobuf Java class. The jar containing Java class should be - * shaded. Specifically, `com.google.protobuf.*` should be shaded to - * `org.sparkproject.spark_protobuf.protobuf.*`. - * https://github.com/rangadi/shaded-protobuf-classes is useful to create shaded jar from - * Protobuf files. - * - * @param data - * the data column. - * @param messageClassName - * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. - * The jar with these classes needs to be shaded as described above. - * @param options - * @since 3.4.0 - */ - @Experimental - def to_protobuf(data: Column, messageClassName: String, options: java.util.Map[String, String]) - : Column = { - Column.fnWithOptions( - "to_protobuf", - options.asScala.iterator, - data, - lit(messageClassName) - ) - } -} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 0bd0121e6e141..41f547a43b698 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -178,6 +178,12 @@ object MimaExcludes { // SPARK-49282: Shared SparkSessionBuilder ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.SparkSession$Builder"), + + // SPARK-49286: Avro/Protobuf functions in sql/api + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.avro.functions"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.avro.functions$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.protobuf.functions"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.protobuf.functions$"), ) ++ loggingExcludes("org.apache.spark.sql.DataFrameReader") ++ loggingExcludes("org.apache.spark.sql.streaming.DataStreamReader") ++ loggingExcludes("org.apache.spark.sql.SparkSession#Builder") diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/avro/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/avro/functions.scala similarity index 97% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/avro/functions.scala rename to sql/api/src/main/scala/org/apache/spark/sql/avro/functions.scala index e80bccfee4c9c..fffad557aca5e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/avro/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/avro/functions.scala @@ -37,7 +37,7 @@ object functions { * @param jsonFormatSchema * the avro schema in JSON string format. * - * @since 3.5.0 + * @since 3.0.0 */ @Experimental def from_avro(data: Column, jsonFormatSchema: String): Column = { @@ -57,7 +57,7 @@ object functions { * @param options * options to control how the Avro record is parsed. * - * @since 3.5.0 + * @since 3.0.0 */ @Experimental def from_avro( @@ -73,7 +73,7 @@ object functions { * @param data * the data column. 
* - * @since 3.5.0 + * @since 3.0.0 */ @Experimental def to_avro(data: Column): Column = { @@ -88,7 +88,7 @@ object functions { * @param jsonFormatSchema * user-specified output avro schema in JSON string format. * - * @since 3.5.0 + * @since 3.0.0 */ @Experimental def to_avro(data: Column, jsonFormatSchema: String): Column = { diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala similarity index 90% rename from connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/protobuf/functions.scala rename to sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala index 2c953fbd07b9e..ea9e3c429d65a 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/protobuf/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/protobuf/functions.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.protobuf import java.io.FileNotFoundException import java.nio.file.{Files, NoSuchFileException, Paths} -import java.util.Collections import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal @@ -30,7 +29,7 @@ import org.apache.spark.sql.functions.lit // scalastyle:off: object.name object functions { - // scalastyle:on: object.name +// scalastyle:on: object.name /** * Converts a binary column of Protobuf format into its corresponding catalyst value. The @@ -44,7 +43,7 @@ object functions { * The Protobuf descriptor file. This file is usually created using `protoc` with * `--descriptor_set_out` and `--include_imports` options. * @param options - * @since 3.5.0 + * @since 3.4.0 */ @Experimental def from_protobuf( @@ -52,8 +51,8 @@ object functions { messageName: String, descFilePath: String, options: java.util.Map[String, String]): Column = { - val binaryFileDescSet = readDescriptorFileContent(descFilePath) - from_protobuf(data, messageName, binaryFileDescSet, options) + val descriptorFileContent = readDescriptorFileContent(descFilePath) + from_protobuf(data, messageName, descriptorFileContent, options) } /** @@ -95,31 +94,12 @@ object functions { * @param descFilePath * The Protobuf descriptor file. This file is usually created using `protoc` with * `--descriptor_set_out` and `--include_imports` options. - * @since 3.5.0 + * @since 3.4.0 */ @Experimental def from_protobuf(data: Column, messageName: String, descFilePath: String): Column = { - from_protobuf(data, messageName, descFilePath, emptyOptions) - } - - /** - * Converts a binary column of Protobuf format into its corresponding catalyst value. - * `messageClassName` points to Protobuf Java class. The jar containing Java class should be - * shaded. Specifically, `com.google.protobuf.*` should be shaded to - * `org.sparkproject.spark_protobuf.protobuf.*`. - * https://github.com/rangadi/shaded-protobuf-classes is useful to create shaded jar from - * Protobuf files. - * - * @param data - * the binary column. - * @param messageClassName - * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. - * The jar with these classes needs to be shaded as described above. 
- * @since 3.5.0 - */ - @Experimental - def from_protobuf(data: Column, messageClassName: String): Column = { - Column.fn("from_protobuf", data, lit(messageClassName)) + val fileContent = readDescriptorFileContent(descFilePath) + from_protobuf(data, messageName, fileContent) } /** @@ -140,7 +120,27 @@ object functions { data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]): Column = { - from_protobuf(data, messageName, binaryFileDescriptorSet, emptyOptions) + Column.fn("from_protobuf", data, lit(messageName), lit(binaryFileDescriptorSet)) + } + + /** + * Converts a binary column of Protobuf format into its corresponding catalyst value. + * `messageClassName` points to Protobuf Java class. The jar containing Java class should be + * shaded. Specifically, `com.google.protobuf.*` should be shaded to + * `org.sparkproject.spark_protobuf.protobuf.*`. + * https://github.com/rangadi/shaded-protobuf-classes is useful to create shaded jar from + * Protobuf files. + * + * @param data + * the binary column. + * @param messageClassName + * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. + * The jar with these classes needs to be shaded as described above. + * @since 3.4.0 + */ + @Experimental + def from_protobuf(data: Column, messageClassName: String): Column = { + Column.fn("from_protobuf", data, lit(messageClassName)) } /** @@ -157,7 +157,7 @@ object functions { * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. * The jar with these classes needs to be shaded as described above. * @param options - * @since 3.5.0 + * @since 3.4.0 */ @Experimental def from_protobuf( @@ -178,11 +178,11 @@ object functions { * @param descFilePath * The Protobuf descriptor file. This file is usually created using `protoc` with * `--descriptor_set_out` and `--include_imports` options. - * @since 3.5.0 + * @since 3.4.0 */ @Experimental def to_protobuf(data: Column, messageName: String, descFilePath: String): Column = { - to_protobuf(data, messageName, descFilePath, emptyOptions) + to_protobuf(data, messageName, descFilePath, Map.empty[String, String].asJava) } /** @@ -204,7 +204,7 @@ object functions { data: Column, messageName: String, binaryFileDescriptorSet: Array[Byte]): Column = { - to_protobuf(data, messageName, binaryFileDescriptorSet, emptyOptions) + Column.fn("to_protobuf", data, lit(messageName), lit(binaryFileDescriptorSet)) } /** @@ -216,10 +216,9 @@ object functions { * @param messageName * the protobuf MessageName to look for in descriptor file. * @param descFilePath - * The Protobuf descriptor file. This file is usually created using `protoc` with - * `--descriptor_set_out` and `--include_imports` options. + * the protobuf descriptor file. * @param options - * @since 3.5.0 + * @since 3.4.0 */ @Experimental def to_protobuf( @@ -227,8 +226,8 @@ object functions { messageName: String, descFilePath: String, options: java.util.Map[String, String]): Column = { - val binaryFileDescriptorSet = readDescriptorFileContent(descFilePath) - to_protobuf(data, messageName, binaryFileDescriptorSet, options) + val fileContent = readDescriptorFileContent(descFilePath) + to_protobuf(data, messageName, fileContent, options) } /** @@ -271,7 +270,7 @@ object functions { * @param messageClassName * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. * The jar with these classes needs to be shaded as described above. 
- * @since 3.5.0 + * @since 3.4.0 */ @Experimental def to_protobuf(data: Column, messageClassName: String): Column = { @@ -291,7 +290,7 @@ object functions { * The full name for Protobuf Java class. E.g. com.example.protos.ExampleEvent. * The jar with these classes needs to be shaded as described above. * @param options - * @since 3.5.0 + * @since 3.4.0 */ @Experimental def to_protobuf( @@ -301,8 +300,6 @@ object functions { Column.fnWithOptions("to_protobuf", options.asScala.iterator, data, lit(messageClassName)) } - private def emptyOptions: java.util.Map[String, String] = Collections.emptyMap[String, String]() - // This method is copied from org.apache.spark.sql.protobuf.util.ProtobufUtils private def readDescriptorFileContent(filePath: String): Array[Byte] = { try { @@ -312,7 +309,8 @@ object functions { throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex) case ex: NoSuchFileException => throw CompilationErrors.cannotFindDescriptorFileError(filePath, ex) - case NonFatal(ex) => throw CompilationErrors.descriptorParseError(ex) + case NonFatal(ex) => + throw CompilationErrors.descriptorParseError(ex) } } } From 218051a566c78244573077a53d4be43ccc01311d Mon Sep 17 00:00:00 2001 From: panbingkun Date: Thu, 26 Sep 2024 22:30:33 +0800 Subject: [PATCH 094/250] [MINOR][SQL][TESTS] Use `formatString.format(value)` instead of `value.formatted(formatString)` ### What changes were proposed in this pull request? The pr aims to use `formatString.format(value)` instead of `value.formatted(formatString)` for eliminating Warning. ### Why are the changes needed? image ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48262 from panbingkun/minor_formatted. Authored-by: panbingkun Signed-off-by: yangjie01 --- .../columnar/compression/CompressionSchemeBenchmark.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala index 05ae575305299..290cfd56b8bce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/CompressionSchemeBenchmark.scala @@ -91,7 +91,7 @@ object CompressionSchemeBenchmark extends BenchmarkBase with AllCompressionSchem schemes.filter(_.supports(tpe)).foreach { scheme => val (compressFunc, compressionRatio, buf) = prepareEncodeInternal(count, tpe, scheme, input) - val label = s"${getFormattedClassName(scheme)}(${compressionRatio.formatted("%.3f")})" + val label = s"${getFormattedClassName(scheme)}(${"%.3f".format(compressionRatio)})" benchmark.addCase(label)({ i: Int => for (n <- 0L until iters) { From 87b5ffb220824449d943cf3c7fff3eb3682526fc Mon Sep 17 00:00:00 2001 From: panbingkun Date: Thu, 26 Sep 2024 07:37:22 -0700 Subject: [PATCH 095/250] [SPARK-49797][INFRA] Align the running OS image of `maven_test.yml` to `ubuntu-latest` ### What changes were proposed in this pull request? The pr aims to align the running OS image of `maven_test.yml` to `ubuntu-latest` (from `ubuntu-22.04` to `ubuntu-24.04`) ### Why are the changes needed? 
https://github.com/actions/runner-images/releases/tag/ubuntu24%2F20240922.1 After https://github.com/actions/runner-images/issues/10636, `ubuntu-latest` already points to `ubuntu-24.04` instead of `ubuntu-22.04`. I have checked all tasks running on `Ubuntu OS` (except for the 2 related to `TPCDS`), and they are all using `ubuntu-latest`. Currently, only `maven_test.yml` is using `ubuntu-22.04`. Let's align it. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48263 from panbingkun/SPARK-49797. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- .github/workflows/maven_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/maven_test.yml b/.github/workflows/maven_test.yml index 82b72bd7e91d2..dd089d665d6e3 100644 --- a/.github/workflows/maven_test.yml +++ b/.github/workflows/maven_test.yml @@ -40,7 +40,7 @@ on: description: OS to run this build. required: false type: string - default: ubuntu-22.04 + default: ubuntu-latest envs: description: Additional environment variables to set when running the tests. Should be in JSON format. required: false From 624eda5030eb3a4a426a1c225952af40dba30d1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vladan=20Vasi=C4=87?= Date: Thu, 26 Sep 2024 23:00:22 +0800 Subject: [PATCH 096/250] [SPARK-49444][SQL] Modified UnivocityParser to throw runtime exceptions caused by ArrayIndexOutOfBounds with more user-oriented messages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? I propose to catch and rethrow runtime `ArrayIndexOutOfBounds` exceptions in the `UnivocityParser` class's `parse` method, but with more user-oriented messages. Instead of throwing exceptions in the original format, I propose to inform the users which CSV record caused the error. ### Why are the changes needed? Informing users clearly about such errors improves the user experience. Instead of throwing an `ArrayIndexOutOfBounds` exception with no clear indication of why it happened, the proposed changes throw a `SparkRuntimeException` whose message includes the original CSV line which caused the error. ### Does this PR introduce _any_ user-facing change? This PR introduces a user-facing change which happens when `UnivocityParser` parses a malformed CSV line from the input. More specifically, the change is reproduced in the test case within `UnivocityParserSuite` when the user specifies `maxColumns` in the parser options and the parsed CSV record has more columns. Instead of resulting in `ArrayIndexOutOfBounds` as mentioned in the HMR ticket, users now get a `SparkRuntimeException` with a message that contains the input line which caused the error. ### How was this patch tested? This patch was tested in `UnivocityParserSuite`. The test named "Array index out of bounds when parsing CSV with more columns than expected" covers this patch. Additionally, a test for bad records in `UnivocityParser`'s `PERMISSIVE` mode is added to confirm that `BadRecordException` is being thrown properly. ### Was this patch authored or co-authored using generative AI tooling? No Closes #47906 from vladanvasi-db/vladanvasi-db/univocity-parser-index-out-of-bounds-handling.
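For illustration, a minimal sketch of how the reworked error surfaces to a caller (the schema and file path are hypothetical, a `spark` session as in `spark-shell` is assumed; the wrapping of `MALFORMED_CSV_RECORD` inside the file-read failure mirrors the tests added in this patch):

```scala
// Hypothetical repro: /tmp/more-columns.csv is assumed to contain a line such as
// "1,3.14,string,5,7", i.e. more columns than the configured `maxColumns`.
import org.apache.spark.{SparkException, SparkRuntimeException}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

val schema = new StructType().add("i", IntegerType).add("s", StringType)

try {
  spark.read
    .schema(schema)
    .option("header", "false")
    .option("maxColumns", "2")
    .csv("/tmp/more-columns.csv")
    .collect()
} catch {
  case e: SparkException =>
    // Previously this surfaced as a bare ArrayIndexOutOfBoundsException; now the
    // file-read failure wraps a SparkRuntimeException whose MALFORMED_CSV_RECORD
    // message carries the offending CSV line.
    assert(e.getCause.isInstanceOf[SparkRuntimeException])
}
```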
Authored-by: Vladan Vasić Signed-off-by: Wenchen Fan --- .../sql/catalyst/csv/UnivocityParser.scala | 19 ++++++++- .../catalyst/csv/UnivocityParserSuite.scala | 39 ++++++++++++++++++- .../test/resources/test-data/more-columns.csv | 1 + .../apache/spark/sql/CsvFunctionsSuite.scala | 5 ++- .../execution/datasources/csv/CSVSuite.scala | 34 ++++++++++++++++ 5 files changed, 92 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/test/resources/test-data/more-columns.csv diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala index ccc8f30a9a9c3..0fd0601803a6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala @@ -21,9 +21,10 @@ import java.io.InputStream import scala.util.control.NonFatal +import com.univocity.parsers.common.TextParsingException import com.univocity.parsers.csv.CsvParser -import org.apache.spark.SparkUpgradeException +import org.apache.spark.{SparkRuntimeException, SparkUpgradeException} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, OrderedFilters} import org.apache.spark.sql.catalyst.expressions.{ExprUtils, GenericInternalRow} @@ -294,6 +295,20 @@ class UnivocityParser( } } + private def parseLine(line: String): Array[String] = { + try { + tokenizer.parseLine(line) + } + catch { + case e: TextParsingException if e.getCause.isInstanceOf[ArrayIndexOutOfBoundsException] => + throw new SparkRuntimeException( + errorClass = "MALFORMED_CSV_RECORD", + messageParameters = Map("badRecord" -> line), + cause = e + ) + } + } + /** * Parses a single CSV string and turns it into either one resulting row or no row (if the * the record is malformed). 
@@ -306,7 +321,7 @@ class UnivocityParser( (_: String) => Some(InternalRow.empty) } else { // parse if the columnPruning is disabled or requiredSchema is nonEmpty - (input: String) => convert(tokenizer.parseLine(input)) + (input: String) => convert(parseLine(input)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala index 514b529ea8cc0..7974bf68bdd31 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/UnivocityParserSuite.scala @@ -23,12 +23,12 @@ import java.util.{Locale, TimeZone} import org.apache.commons.lang3.time.FastDateFormat -import org.apache.spark.{SparkFunSuite, SparkIllegalArgumentException} +import org.apache.spark.{SparkFunSuite, SparkIllegalArgumentException, SparkRuntimeException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.SQLHelper +import org.apache.spark.sql.catalyst.util.{BadRecordException, DateTimeUtils} import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.sources.{EqualTo, Filter, StringStartsWith} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -323,6 +323,41 @@ class UnivocityParserSuite extends SparkFunSuite with SQLHelper { parameters = Map("fieldName" -> "`i`", "fields" -> "")) } + test("Bad records test in permissive mode") { + def checkBadRecord( + input: String = "1,a", + dataSchema: StructType = StructType.fromDDL("i INTEGER, s STRING, d DOUBLE"), + requiredSchema: StructType = StructType.fromDDL("i INTEGER, s STRING"), + options: Map[String, String] = Map("mode" -> "PERMISSIVE")): BadRecordException = { + val csvOptions = new CSVOptions(options, false, "UTC") + val parser = new UnivocityParser(dataSchema, requiredSchema, csvOptions, Seq()) + intercept[BadRecordException] { + parser.parse(input) + } + } + + // Bad record exception caused by conversion error + checkBadRecord(input = "1.5,a,10.3") + + // Bad record exception caused by insufficient number of columns + checkBadRecord(input = "2") + } + + test("Array index out of bounds when parsing CSV with more columns than expected") { + val input = "1,string,3.14,5,7" + val dataSchema: StructType = StructType.fromDDL("i INTEGER, a STRING") + val requiredSchema: StructType = StructType.fromDDL("i INTEGER, a STRING") + val options = new CSVOptions(Map("maxColumns" -> "2"), false, "UTC") + val filters = Seq() + val parser = new UnivocityParser(dataSchema, requiredSchema, options, filters) + checkError( + exception = intercept[SparkRuntimeException] { + parser.parse(input) + }, + condition = "MALFORMED_CSV_RECORD", + parameters = Map("badRecord" -> "1,string,3.14,5,7")) + } + test("SPARK-30960: parse date/timestamp string with legacy format") { def check(parser: UnivocityParser): Unit = { // The legacy format allows 1 or 2 chars for some fields. 
diff --git a/sql/core/src/test/resources/test-data/more-columns.csv b/sql/core/src/test/resources/test-data/more-columns.csv new file mode 100644 index 0000000000000..06db38f0a145a --- /dev/null +++ b/sql/core/src/test/resources/test-data/more-columns.csv @@ -0,0 +1 @@ +1,3.14,string,5,7 \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index 6589282fd3a51..e6907b8656482 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -24,7 +24,8 @@ import java.util.Locale import scala.jdk.CollectionConverters._ -import org.apache.spark.{SparkException, SparkUnsupportedOperationException, SparkUpgradeException} +import org.apache.spark.{SparkException, SparkRuntimeException, + SparkUnsupportedOperationException, SparkUpgradeException} import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -234,7 +235,7 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { val schema = new StructType().add("str", StringType) val options = Map("maxCharsPerColumn" -> "2") - val exception = intercept[SparkException] { + val exception = intercept[SparkRuntimeException] { df.select(from_csv($"value", schema, options)).collect() }.getCause.getMessage diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index e2d1d9b05c3c2..023f401516dc3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -85,6 +85,7 @@ abstract class CSVSuite private val badAfterGoodFile = "test-data/bad_after_good.csv" private val malformedRowFile = "test-data/malformedRow.csv" private val charFile = "test-data/char.csv" + private val moreColumnsFile = "test-data/more-columns.csv" /** Verifies data and schema. 
*/ private def verifyCars( @@ -3439,6 +3440,39 @@ abstract class CSVSuite expected) } } + + test("SPARK-49444: CSV parsing failure with more than max columns") { + val schema = new StructType() + .add("intColumn", IntegerType, nullable = true) + .add("decimalColumn", DecimalType(10, 2), nullable = true) + + val fileReadException = intercept[SparkException] { + spark + .read + .schema(schema) + .option("header", "false") + .option("maxColumns", "2") + .csv(testFile(moreColumnsFile)) + .collect() + } + + checkErrorMatchPVals( + exception = fileReadException, + condition = "FAILED_READ_FILE.NO_HINT", + parameters = Map("path" -> s".*$moreColumnsFile")) + + val malformedCSVException = fileReadException.getCause.asInstanceOf[SparkRuntimeException] + + checkError( + exception = malformedCSVException, + condition = "MALFORMED_CSV_RECORD", + parameters = Map("badRecord" -> "1,3.14,string,5,7"), + sqlState = "KD000") + + assert(malformedCSVException.getCause.isInstanceOf[TextParsingException]) + val textParsingException = malformedCSVException.getCause.asInstanceOf[TextParsingException] + assert(textParsingException.getCause.isInstanceOf[ArrayIndexOutOfBoundsException]) + } } class CSVv1Suite extends CSVSuite { From 54e62a158ead91d832d477a76aace40ef5b54121 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Thu, 26 Sep 2024 13:37:39 -0700 Subject: [PATCH 097/250] [SPARK-49800][BUILD][K8S] Upgrade `kubernetes-client` to 6.13.4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Upgrade `kubernetes-client` from 6.13.3 to 6.13.4 ### Why are the changes needed? New version that have 5 fixes [Release log 6.13.4](https://github.com/fabric8io/kubernetes-client/releases/tag/v6.13.4) ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48268 from bjornjorgensen/k8sclient6.13.4. 
Authored-by: Bjørn Jørgensen Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 50 +++++++++++++-------------- pom.xml | 2 +- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 19b8a237d30aa..c9a32757554be 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -159,31 +159,31 @@ jsr305/3.0.0//jsr305-3.0.0.jar jta/1.1//jta-1.1.jar jul-to-slf4j/2.0.16//jul-to-slf4j-2.0.16.jar kryo-shaded/4.0.2//kryo-shaded-4.0.2.jar -kubernetes-client-api/6.13.3//kubernetes-client-api-6.13.3.jar -kubernetes-client/6.13.3//kubernetes-client-6.13.3.jar -kubernetes-httpclient-okhttp/6.13.3//kubernetes-httpclient-okhttp-6.13.3.jar -kubernetes-model-admissionregistration/6.13.3//kubernetes-model-admissionregistration-6.13.3.jar -kubernetes-model-apiextensions/6.13.3//kubernetes-model-apiextensions-6.13.3.jar -kubernetes-model-apps/6.13.3//kubernetes-model-apps-6.13.3.jar -kubernetes-model-autoscaling/6.13.3//kubernetes-model-autoscaling-6.13.3.jar -kubernetes-model-batch/6.13.3//kubernetes-model-batch-6.13.3.jar -kubernetes-model-certificates/6.13.3//kubernetes-model-certificates-6.13.3.jar -kubernetes-model-common/6.13.3//kubernetes-model-common-6.13.3.jar -kubernetes-model-coordination/6.13.3//kubernetes-model-coordination-6.13.3.jar -kubernetes-model-core/6.13.3//kubernetes-model-core-6.13.3.jar -kubernetes-model-discovery/6.13.3//kubernetes-model-discovery-6.13.3.jar -kubernetes-model-events/6.13.3//kubernetes-model-events-6.13.3.jar -kubernetes-model-extensions/6.13.3//kubernetes-model-extensions-6.13.3.jar -kubernetes-model-flowcontrol/6.13.3//kubernetes-model-flowcontrol-6.13.3.jar -kubernetes-model-gatewayapi/6.13.3//kubernetes-model-gatewayapi-6.13.3.jar -kubernetes-model-metrics/6.13.3//kubernetes-model-metrics-6.13.3.jar -kubernetes-model-networking/6.13.3//kubernetes-model-networking-6.13.3.jar -kubernetes-model-node/6.13.3//kubernetes-model-node-6.13.3.jar -kubernetes-model-policy/6.13.3//kubernetes-model-policy-6.13.3.jar -kubernetes-model-rbac/6.13.3//kubernetes-model-rbac-6.13.3.jar -kubernetes-model-resource/6.13.3//kubernetes-model-resource-6.13.3.jar -kubernetes-model-scheduling/6.13.3//kubernetes-model-scheduling-6.13.3.jar -kubernetes-model-storageclass/6.13.3//kubernetes-model-storageclass-6.13.3.jar +kubernetes-client-api/6.13.4//kubernetes-client-api-6.13.4.jar +kubernetes-client/6.13.4//kubernetes-client-6.13.4.jar +kubernetes-httpclient-okhttp/6.13.4//kubernetes-httpclient-okhttp-6.13.4.jar +kubernetes-model-admissionregistration/6.13.4//kubernetes-model-admissionregistration-6.13.4.jar +kubernetes-model-apiextensions/6.13.4//kubernetes-model-apiextensions-6.13.4.jar +kubernetes-model-apps/6.13.4//kubernetes-model-apps-6.13.4.jar +kubernetes-model-autoscaling/6.13.4//kubernetes-model-autoscaling-6.13.4.jar +kubernetes-model-batch/6.13.4//kubernetes-model-batch-6.13.4.jar +kubernetes-model-certificates/6.13.4//kubernetes-model-certificates-6.13.4.jar +kubernetes-model-common/6.13.4//kubernetes-model-common-6.13.4.jar +kubernetes-model-coordination/6.13.4//kubernetes-model-coordination-6.13.4.jar +kubernetes-model-core/6.13.4//kubernetes-model-core-6.13.4.jar +kubernetes-model-discovery/6.13.4//kubernetes-model-discovery-6.13.4.jar +kubernetes-model-events/6.13.4//kubernetes-model-events-6.13.4.jar +kubernetes-model-extensions/6.13.4//kubernetes-model-extensions-6.13.4.jar 
+kubernetes-model-flowcontrol/6.13.4//kubernetes-model-flowcontrol-6.13.4.jar +kubernetes-model-gatewayapi/6.13.4//kubernetes-model-gatewayapi-6.13.4.jar +kubernetes-model-metrics/6.13.4//kubernetes-model-metrics-6.13.4.jar +kubernetes-model-networking/6.13.4//kubernetes-model-networking-6.13.4.jar +kubernetes-model-node/6.13.4//kubernetes-model-node-6.13.4.jar +kubernetes-model-policy/6.13.4//kubernetes-model-policy-6.13.4.jar +kubernetes-model-rbac/6.13.4//kubernetes-model-rbac-6.13.4.jar +kubernetes-model-resource/6.13.4//kubernetes-model-resource-6.13.4.jar +kubernetes-model-scheduling/6.13.4//kubernetes-model-scheduling-6.13.4.jar +kubernetes-model-storageclass/6.13.4//kubernetes-model-storageclass-6.13.4.jar lapack/3.0.3//lapack-3.0.3.jar leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar diff --git a/pom.xml b/pom.xml index f3dc92426ac4e..22048b55da27f 100644 --- a/pom.xml +++ b/pom.xml @@ -231,7 +231,7 @@ org.fusesource.leveldbjni - 6.13.3 + 6.13.4 1.17.6 ${java.home} From 339dd5b93316fecd0455b53b2cedee2b5333a184 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 26 Sep 2024 13:39:02 -0700 Subject: [PATCH 098/250] [SPARK-49791][SQL] Make DelegatingCatalogExtension more extendable ### What changes were proposed in this pull request? This PR updates `DelegatingCatalogExtension` so that it's more extendable - `initialize` becomes not final, so that sub-classes can overwrite it - `delegate` becomes `protected`, so that sub-classes can access it In addition, this PR fixes a mistake that `DelegatingCatalogExtension` is just a convenient default implementation, it's actually the `CatalogExtension` interface that indicates this catalog implementation will delegate requests to the Spark session catalog. https://github.com/apache/spark/pull/47724 should use `CatalogExtension` instead. ### Why are the changes needed? Unblock the Iceberg extension. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48257 from cloud-fan/catalog. 
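For illustration, a minimal sketch of the kind of session-catalog extension this change unblocks (the class name and the `warehouse` option key are hypothetical), now that `initialize` can be overridden and `delegate` is visible to subclasses:

```scala
// Hypothetical extension; it would be registered via spark.sql.catalog.spark_catalog.
import org.apache.spark.sql.connector.catalog.{DelegatingCatalogExtension, Identifier, Table}
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class MyCatalogExtension extends DelegatingCatalogExtension {
  private var warehouse: String = _

  // Possible after this change: initialize() is no longer final, so the
  // extension can pick up its own options when the catalog is loaded.
  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
    warehouse = options.getOrDefault("warehouse", "/tmp/warehouse")
  }

  // `delegate` is now protected, so subclasses can reach the wrapped session
  // catalog; anything not customized still goes through the inherited behavior.
  override def loadTable(ident: Identifier): Table = {
    // e.g. consult an external metastore under `warehouse` first (omitted here),
    // then fall back to the built-in session catalog.
    super.loadTable(ident)
  }
}
```

Requests such a subclass does not customize keep delegating to the built-in session catalog, which is the behavior the `CatalogExtension` checks in the diff below rely on.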
Lead-authored-by: Wenchen Fan Co-authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../sql/connector/catalog/DelegatingCatalogExtension.java | 4 ++-- .../spark/sql/catalyst/analysis/ResolveSessionCatalog.scala | 4 ++-- .../org/apache/spark/sql/internal/DataFrameWriterImpl.scala | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java index f6686d2e4d3b6..786821514822e 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java @@ -38,7 +38,7 @@ @Evolving public abstract class DelegatingCatalogExtension implements CatalogExtension { - private CatalogPlugin delegate; + protected CatalogPlugin delegate; @Override public final void setDelegateCatalog(CatalogPlugin delegate) { @@ -51,7 +51,7 @@ public String name() { } @Override - public final void initialize(String name, CaseInsensitiveStringMap options) {} + public void initialize(String name, CaseInsensitiveStringMap options) {} @Override public Set capabilities() { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 02ad2e79a5645..a9ad7523c8fbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, toPrettySQL, ResolveDefaultColumns => DefaultCols} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ -import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, DelegatingCatalogExtension, LookupCatalog, SupportsNamespaces, V1Table} +import org.apache.spark.sql.connector.catalog.{CatalogExtension, CatalogManager, CatalogPlugin, CatalogV2Util, LookupCatalog, SupportsNamespaces, V1Table} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command._ @@ -706,6 +706,6 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) private def supportsV1Command(catalog: CatalogPlugin): Boolean = { isSessionCatalog(catalog) && ( SQLConf.get.getConf(SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isEmpty || - catalog.isInstanceOf[DelegatingCatalogExtension]) + catalog.isInstanceOf[CatalogExtension]) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala index f0eef9ae1cbb0..8164d33f46fee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/DataFrameWriterImpl.scala @@ -429,7 +429,7 @@ final class DataFrameWriterImpl[T] private[sql](ds: Dataset[T]) extends DataFram val canUseV2 = lookupV2Provider().isDefined || (df.sparkSession.sessionState.conf.getConf( SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION).isDefined && 
!df.sparkSession.sessionState.catalogManager.catalog(CatalogManager.SESSION_CATALOG_NAME) - .isInstanceOf[DelegatingCatalogExtension]) + .isInstanceOf[CatalogExtension]) session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match { case nameParts @ NonSessionCatalogAndIdentifier(catalog, ident) => From fc9d421a2345987105aa97947c867ac80ba48a05 Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Fri, 27 Sep 2024 08:26:24 +0800 Subject: [PATCH 099/250] [SPARK-49211][SQL][FOLLOW-UP] Support catalog in QualifiedTableName ### What changes were proposed in this pull request? Support catalog in QualifiedTableName and remove `FullQualifiedTableName`. ### Why are the changes needed? Consolidate and remove duplicate code. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing UT ### Was this patch authored or co-authored using generative AI tooling? No Closes #48255 from amaliujia/qualifedtablename. Authored-by: Rui Wang Signed-off-by: Wenchen Fan --- .../sql/catalyst/catalog/SessionCatalog.scala | 18 ++++++++-------- .../spark/sql/catalyst/identifiers.scala | 21 +++++++++++++++---- .../catalog/SessionCatalogSuite.scala | 8 +++---- .../datasources/DataSourceStrategy.scala | 4 ++-- .../datasources/v2/V2SessionCatalog.scala | 4 ++-- .../sql/StatisticsCollectionTestBase.scala | 4 ++-- .../sql/connector/DataSourceV2SQLSuite.scala | 10 ++++----- .../sql/execution/command/DDLSuite.scala | 6 +++--- .../command/v1/TruncateTableSuite.scala | 4 ++-- .../spark/sql/hive/HiveMetastoreCatalog.scala | 12 +++++------ 10 files changed, 52 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index d3a6cb6ae2845..a0f7af10fefaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -197,7 +197,7 @@ class SessionCatalog( } } - private val tableRelationCache: Cache[FullQualifiedTableName, LogicalPlan] = { + private val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = { var builder = CacheBuilder.newBuilder() .maximumSize(cacheSize) @@ -205,33 +205,33 @@ class SessionCatalog( builder = builder.expireAfterWrite(cacheTTL, TimeUnit.SECONDS) } - builder.build[FullQualifiedTableName, LogicalPlan]() + builder.build[QualifiedTableName, LogicalPlan]() } /** This method provides a way to get a cached plan. */ - def getCachedPlan(t: FullQualifiedTableName, c: Callable[LogicalPlan]): LogicalPlan = { + def getCachedPlan(t: QualifiedTableName, c: Callable[LogicalPlan]): LogicalPlan = { tableRelationCache.get(t, c) } /** This method provides a way to get a cached plan if the key exists. */ - def getCachedTable(key: FullQualifiedTableName): LogicalPlan = { + def getCachedTable(key: QualifiedTableName): LogicalPlan = { tableRelationCache.getIfPresent(key) } /** This method provides a way to cache a plan. */ - def cacheTable(t: FullQualifiedTableName, l: LogicalPlan): Unit = { + def cacheTable(t: QualifiedTableName, l: LogicalPlan): Unit = { tableRelationCache.put(t, l) } /** This method provides a way to invalidate a cached plan. 
*/ - def invalidateCachedTable(key: FullQualifiedTableName): Unit = { + def invalidateCachedTable(key: QualifiedTableName): Unit = { tableRelationCache.invalidate(key) } /** This method discards any cached table relation plans for the given table identifier. */ def invalidateCachedTable(name: TableIdentifier): Unit = { val qualified = qualifyIdentifier(name) - invalidateCachedTable(FullQualifiedTableName( + invalidateCachedTable(QualifiedTableName( qualified.catalog.get, qualified.database.get, qualified.table)) } @@ -301,7 +301,7 @@ class SessionCatalog( } if (cascade && databaseExists(dbName)) { listTables(dbName).foreach { t => - invalidateCachedTable(FullQualifiedTableName(SESSION_CATALOG_NAME, dbName, t.table)) + invalidateCachedTable(QualifiedTableName(SESSION_CATALOG_NAME, dbName, t.table)) } } externalCatalog.dropDatabase(dbName, ignoreIfNotExists, cascade) @@ -1183,7 +1183,7 @@ class SessionCatalog( def refreshTable(name: TableIdentifier): Unit = synchronized { getLocalOrGlobalTempView(name).map(_.refresh()).getOrElse { val qualifiedIdent = qualifyIdentifier(name) - val qualifiedTableName = FullQualifiedTableName( + val qualifiedTableName = QualifiedTableName( qualifiedIdent.catalog.get, qualifiedIdent.database.get, qualifiedIdent.table) tableRelationCache.invalidate(qualifiedTableName) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index cc881539002b6..ceced9313940a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.sql.connector.catalog.CatalogManager + /** * An identifier that optionally specifies a database. * @@ -107,14 +109,25 @@ case class TableIdentifier(table: String, database: Option[String], catalog: Opt } /** A fully qualified identifier for a table (i.e., database.tableName) */ -case class QualifiedTableName(database: String, name: String) { - override def toString: String = s"$database.$name" -} +case class QualifiedTableName(catalog: String, database: String, name: String) { + /** Two argument ctor for backward compatibility. 
*/ + def this(database: String, name: String) = this( + catalog = CatalogManager.SESSION_CATALOG_NAME, + database = database, + name = name) -case class FullQualifiedTableName(catalog: String, database: String, name: String) { override def toString: String = s"$catalog.$database.$name" } +object QualifiedTableName { + def apply(catalog: String, database: String, name: String): QualifiedTableName = { + new QualifiedTableName(catalog, database, name) + } + + def apply(database: String, name: String): QualifiedTableName = + new QualifiedTableName(database = database, name = name) +} + object TableIdentifier { def apply(tableName: String): TableIdentifier = new TableIdentifier(tableName) def apply(table: String, database: Option[String]): TableIdentifier = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index fbe63f71ae029..cfbc507fb5c74 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -22,7 +22,7 @@ import scala.concurrent.duration._ import org.scalatest.concurrent.Eventually import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{AliasIdentifier, FullQualifiedTableName, FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{AliasIdentifier, FunctionIdentifier, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -1883,7 +1883,7 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { conf.setConf(StaticSQLConf.METADATA_CACHE_TTL_SECONDS, 1L) withConfAndEmptyCatalog(conf) { catalog => - val table = FullQualifiedTableName( + val table = QualifiedTableName( CatalogManager.SESSION_CATALOG_NAME, catalog.getCurrentDatabase, "test") // First, make sure the test table is not cached. 
@@ -1903,14 +1903,14 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually { test("SPARK-34197: refreshTable should not invalidate the relation cache for temporary views") { withBasicCatalog { catalog => createTempView(catalog, "tbl1", Range(1, 10, 1, 10), false) - val qualifiedName1 = FullQualifiedTableName(SESSION_CATALOG_NAME, "default", "tbl1") + val qualifiedName1 = QualifiedTableName(SESSION_CATALOG_NAME, "default", "tbl1") catalog.cacheTable(qualifiedName1, Range(1, 10, 1, 10)) catalog.refreshTable(TableIdentifier("tbl1")) assert(catalog.getCachedTable(qualifiedName1) != null) createGlobalTempView(catalog, "tbl2", Range(2, 10, 1, 10), false) val qualifiedName2 = - FullQualifiedTableName(SESSION_CATALOG_NAME, catalog.globalTempDatabase, "tbl2") + QualifiedTableName(SESSION_CATALOG_NAME, catalog.globalTempDatabase, "tbl2") catalog.cacheTable(qualifiedName2, Range(2, 10, 1, 10)) catalog.refreshTable(TableIdentifier("tbl2", Some(catalog.globalTempDatabase))) assert(catalog.getCachedTable(qualifiedName2) != null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2be4b236872f0..a2707da2d1023 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -28,7 +28,7 @@ import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.PREDICATES import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, FullQualifiedTableName, InternalRow, SQLConfHelper} +import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow, QualifiedTableName, SQLConfHelper} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ @@ -249,7 +249,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] private def readDataSourceTable( table: CatalogTable, extraOptions: CaseInsensitiveStringMap): LogicalPlan = { val qualifiedTableName = - FullQualifiedTableName(table.identifier.catalog.get, table.database, table.identifier.table) + QualifiedTableName(table.identifier.catalog.get, table.database, table.identifier.table) val catalog = sparkSession.sessionState.catalog val dsOptions = DataSourceUtils.generateDatasourceOptions(extraOptions, table) catalog.getCachedPlan(qualifiedTableName, () => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index bd1df87d15c3c..22c13fd98ced1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import scala.jdk.CollectionConverters._ import org.apache.spark.SparkUnsupportedOperationException -import org.apache.spark.sql.catalyst.{FullQualifiedTableName, FunctionIdentifier, SQLConfHelper, TableIdentifier} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, QualifiedTableName, SQLConfHelper, TableIdentifier} import 
org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils, ClusterBySpec, SessionCatalog} import org.apache.spark.sql.catalyst.util.TypeUtils._ @@ -93,7 +93,7 @@ class V2SessionCatalog(catalog: SessionCatalog) // table here. To avoid breaking it we do not resolve the table provider and still return // `V1Table` if the custom session catalog is present. if (table.provider.isDefined && !hasCustomSessionCatalog) { - val qualifiedTableName = FullQualifiedTableName( + val qualifiedTableName = QualifiedTableName( table.identifier.catalog.get, table.database, table.identifier.table) // Check if the table is in the v1 table cache to skip the v2 table lookup. if (catalog.getCachedTable(qualifiedTableName) != null) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala index 7fa29dd38fd96..74329ac0e0d23 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionTestBase.scala @@ -25,7 +25,7 @@ import java.time.LocalDateTime import scala.collection.mutable import scala.util.Random -import org.apache.spark.sql.catalyst.{FullQualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{CatalogColumnStat, CatalogStatistics, CatalogTable, HiveTableRelation} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.AttributeMap @@ -270,7 +270,7 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils def getTableFromCatalogCache(tableName: String): LogicalPlan = { val catalog = spark.sessionState.catalog - val qualifiedTableName = FullQualifiedTableName( + val qualifiedTableName = QualifiedTableName( CatalogManager.SESSION_CATALOG_NAME, catalog.getCurrentDatabase, tableName) catalog.getCachedTable(qualifiedTableName) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 7aaec6d500ba0..dac066bbef838 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -27,7 +27,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.{SparkException, SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{FullQualifiedTableName, InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.{InternalRow, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.CurrentUserContext.CURRENT_USER import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchNamespaceException, TableAlreadyExistsException} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils} @@ -3713,7 +3713,7 @@ class DataSourceV2SQLSuiteV1Filter // Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can // configure a new implementation. 
- val table1 = FullQualifiedTableName(SESSION_CATALOG_NAME, "default", "t") + val table1 = QualifiedTableName(SESSION_CATALOG_NAME, "default", "t") spark.sessionState.catalogManager.reset() withSQLConf( V2_SESSION_CATALOG_IMPLEMENTATION.key -> @@ -3722,7 +3722,7 @@ class DataSourceV2SQLSuiteV1Filter checkParquet(table1.toString, path.getAbsolutePath) } } - val table2 = FullQualifiedTableName("testcat3", "default", "t") + val table2 = QualifiedTableName("testcat3", "default", "t") withSQLConf( "spark.sql.catalog.testcat3" -> classOf[V2CatalogSupportBuiltinDataSource].getName) { withTempPath { path => @@ -3741,7 +3741,7 @@ class DataSourceV2SQLSuiteV1Filter // Reset CatalogManager to clear the materialized `spark_catalog` instance, so that we can // configure a new implementation. spark.sessionState.catalogManager.reset() - val table1 = FullQualifiedTableName(SESSION_CATALOG_NAME, "default", "t") + val table1 = QualifiedTableName(SESSION_CATALOG_NAME, "default", "t") withSQLConf( V2_SESSION_CATALOG_IMPLEMENTATION.key -> classOf[V2CatalogSupportBuiltinDataSource].getName) { @@ -3750,7 +3750,7 @@ class DataSourceV2SQLSuiteV1Filter } } - val table2 = FullQualifiedTableName("testcat3", "default", "t") + val table2 = QualifiedTableName("testcat3", "default", "t") withSQLConf( "spark.sql.catalog.testcat3" -> classOf[V2CatalogSupportBuiltinDataSource].getName) { withTempPath { path => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 8307326f17fcf..e07f6406901e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.fs.permission.{AclEntry, AclStatus} import org.apache.spark.{SparkClassNotFoundException, SparkException, SparkFiles, SparkRuntimeException} import org.apache.spark.internal.config import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} -import org.apache.spark.sql.catalyst.{FullQualifiedTableName, FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.TempTableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -219,7 +219,7 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { test("SPARK-25403 refresh the table after inserting data") { withTable("t") { val catalog = spark.sessionState.catalog - val table = FullQualifiedTableName( + val table = QualifiedTableName( CatalogManager.SESSION_CATALOG_NAME, catalog.getCurrentDatabase, "t") sql("CREATE TABLE t (a INT) USING parquet") sql("INSERT INTO TABLE t VALUES (1)") @@ -233,7 +233,7 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSparkSession { withTable("t") { withTempDir { dir => val catalog = spark.sessionState.catalog - val table = FullQualifiedTableName( + val table = QualifiedTableName( CatalogManager.SESSION_CATALOG_NAME, catalog.getCurrentDatabase, "t") val p1 = s"${dir.getCanonicalPath}/p1" val p2 = s"${dir.getCanonicalPath}/p2" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala index 348b216aeb044..40ae35bbe8aa3 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v1/TruncateTableSuite.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.{AclEntry, AclEntryScope, AclEntryType, FsAction, FsPermission} import org.apache.spark.sql.{AnalysisException, Row} -import org.apache.spark.sql.catalyst.{FullQualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution.command import org.apache.spark.sql.execution.command.FakeLocalFsFileSystem @@ -148,7 +148,7 @@ trait TruncateTableSuiteBase extends command.TruncateTableSuiteBase { val catalog = spark.sessionState.catalog val qualifiedTableName = - FullQualifiedTableName(CatalogManager.SESSION_CATALOG_NAME, "ns", "tbl") + QualifiedTableName(CatalogManager.SESSION_CATALOG_NAME, "ns", "tbl") val cachedPlan = catalog.getCachedTable(qualifiedTableName) assert(cachedPlan.stats.sizeInBytes == 0) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 7873c36222da0..1f87db31ffa52 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.SparkException import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.{FullQualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -56,7 +56,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log private val tableCreationLocks = Striped.lazyWeakLock(100) /** Acquires a lock on the table cache for the duration of `f`. 
*/ - private def withTableCreationLock[A](tableName: FullQualifiedTableName, f: => A): A = { + private def withTableCreationLock[A](tableName: QualifiedTableName, f: => A): A = { val lock = tableCreationLocks.get(tableName) lock.lock() try f finally { @@ -66,7 +66,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // For testing only private[hive] def getCachedDataSourceTable(table: TableIdentifier): LogicalPlan = { - val key = FullQualifiedTableName( + val key = QualifiedTableName( // scalastyle:off caselocale table.catalog.getOrElse(CatalogManager.SESSION_CATALOG_NAME).toLowerCase, table.database.getOrElse(sessionState.catalog.getCurrentDatabase).toLowerCase, @@ -76,7 +76,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } private def getCached( - tableIdentifier: FullQualifiedTableName, + tableIdentifier: QualifiedTableName, pathsInMetastore: Seq[Path], schemaInMetastore: StructType, expectedFileFormat: Class[_ <: FileFormat], @@ -120,7 +120,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log } private def logWarningUnexpectedFileFormat( - tableIdentifier: FullQualifiedTableName, + tableIdentifier: QualifiedTableName, expectedFileFormat: Class[_ <: FileFormat], actualFileFormat: String): Unit = { logWarning(log"Table ${MDC(TABLE_NAME, tableIdentifier)} should be stored as " + @@ -201,7 +201,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log fileType: String, isWrite: Boolean): LogicalRelation = { val metastoreSchema = relation.tableMeta.schema - val tableIdentifier = FullQualifiedTableName(relation.tableMeta.identifier.catalog.get, + val tableIdentifier = QualifiedTableName(relation.tableMeta.identifier.catalog.get, relation.tableMeta.database, relation.tableMeta.identifier.table) val lazyPruningEnabled = sparkSession.sessionState.conf.manageFilesourcePartitions From 09b7aa67ce64d7d4ecc803215eaf85464df181c5 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 26 Sep 2024 17:32:29 -0700 Subject: [PATCH 100/250] [SPARK-49803][SQL][TESTS] Increase `spark.test.docker.connectionTimeout` to 10min ### What changes were proposed in this pull request? This PR aims to increase `spark.test.docker.connectionTimeout` to 10min. ### Why are the changes needed? Recently, various DB images fails at `connection` stage on multiple branches. **MASTER** branch https://github.com/apache/spark/actions/runs/11045311764/job/30682732260 ``` [info] OracleIntegrationSuite: [info] org.apache.spark.sql.jdbc.OracleIntegrationSuite *** ABORTED *** (5 minutes, 17 seconds) [info] The code passed to eventually never returned normally. Attempted 298 times over 5.0045005511500005 minutes. Last failure message: ORA-12541: Cannot connect. No listener at host 10.1.0.41 port 41079. (CONNECTION_ID=n9ZWIh+nQn+G9fkwKyoBQA==) ``` **branch-3.5** branch https://github.com/apache/spark/actions/runs/10939696926/job/30370552237 ``` [info] MsSqlServerNamespaceSuite: [info] org.apache.spark.sql.jdbc.v2.MsSqlServerNamespaceSuite *** ABORTED *** (5 minutes, 42 seconds) [info] The code passed to eventually never returned normally. Attempted 11 times over 5.487631282400001 minutes. Last failure message: The TCP/IP connection to the host 10.1.0.56, port 35345 has failed. Error: "Connection refused (Connection refused). Verify the connection properties. Make sure that an instance of SQL Server is running on the host and accepting TCP/IP connections at the port. 
Make sure that TCP connections to the port are not blocked by a firewall.".. (DockerJDBCIntegrationSuite.scala:166) ``` **branch-3.4** branch https://github.com/apache/spark/actions/runs/10937842509/job/30364658576 ``` [info] MsSqlServerNamespaceSuite: [info] org.apache.spark.sql.jdbc.v2.MsSqlServerNamespaceSuite *** ABORTED *** (5 minutes, 42 seconds) [info] The code passed to eventually never returned normally. Attempted 11 times over 5.487555645633333 minutes. Last failure message: The TCP/IP connection to the host 10.1.0.153, port 46153 has failed. Error: "Connection refused (Connection refused). Verify the connection properties. Make sure that an instance of SQL Server is running on the host and accepting TCP/IP connections at the port. Make sure that TCP connections to the port are not blocked by a firewall.".. (DockerJDBCIntegrationSuite.scala:166) ``` ### Does this PR introduce _any_ user-facing change? No, this is a test-only change. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48272 from dongjoon-hyun/SPARK-49803. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala index 8d17e0b4e36e6..1df01bd3bfb62 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -115,7 +115,7 @@ abstract class DockerJDBCIntegrationSuite protected val startContainerTimeout: Long = timeStringAsSeconds(sys.props.getOrElse("spark.test.docker.startContainerTimeout", "5min")) protected val connectionTimeout: PatienceConfiguration.Timeout = { - val timeoutStr = sys.props.getOrElse("spark.test.docker.connectionTimeout", "5min") + val timeoutStr = sys.props.getOrElse("spark.test.docker.connectionTimeout", "10min") timeout(timeStringAsSeconds(timeoutStr).seconds) } From 488c3f604490c8632dde67a00118d49ccfcbf578 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Fri, 27 Sep 2024 08:35:10 +0800 Subject: [PATCH 101/250] [SPARK-49776][PYTHON][CONNECT] Support pie plots ### What changes were proposed in this pull request? Support pie plots with plotly backend on both Spark Connect and Spark classic. ### Why are the changes needed? While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments. See more at [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) in progress. Part of https://issues.apache.org/jira/browse/SPARK-49530. ### Does this PR introduce _any_ user-facing change? Yes. Pie plots are supported as shown below. ```py >>> from datetime import datetime >>> data = [ ... (3, 5, 20, datetime(2018, 1, 31)), ... (2, 5, 42, datetime(2018, 2, 28)), ... (3, 6, 28, datetime(2018, 3, 31)), ...
(9, 12, 62, datetime(2018, 4, 30))] >>> columns = ["sales", "signups", "visits", "date"] >>> df = spark.createDataFrame(data, columns) >>> fig = df.plot(kind="pie", x="date", y="sales") # df.plot(kind="pie", x="date", y="sales") >>> fig.show() ``` ![newplot (8)](https://github.com/user-attachments/assets/c4078bb7-4d84-4607-bcd7-bdd6fbbf8e28) ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48256 from xinrong-meng/plot_pie. Authored-by: Xinrong Meng Signed-off-by: Xinrong Meng --- python/pyspark/errors/error-conditions.json | 5 +++ python/pyspark/sql/plot/core.py | 41 ++++++++++++++++++- python/pyspark/sql/plot/plotly.py | 15 +++++++ .../sql/tests/plot/test_frame_plot_plotly.py | 25 +++++++++++ 4 files changed, 85 insertions(+), 1 deletion(-) diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 115ad658e32f5..ed62ea117d369 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -812,6 +812,11 @@ "Pipe function `` exited with error code ." ] }, + "PLOT_NOT_NUMERIC_COLUMN": { + "message": [ + "Argument must be a numerical column for plotting, got ." + ] + }, "PYTHON_HASH_SEED_NOT_SET": { "message": [ "Randomness of hash of string should be disabled via PYTHONHASHSEED." diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index 9f83d00696524..f9667ee2c0d69 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -17,7 +17,8 @@ from typing import Any, TYPE_CHECKING, Optional, Union from types import ModuleType -from pyspark.errors import PySparkRuntimeError, PySparkValueError +from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError +from pyspark.sql.types import NumericType from pyspark.sql.utils import require_minimum_plotly_version @@ -97,6 +98,7 @@ class PySparkPlotAccessor: "bar": PySparkTopNPlotBase().get_top_n, "barh": PySparkTopNPlotBase().get_top_n, "line": PySparkSampledPlotBase().get_sampled, + "pie": PySparkTopNPlotBase().get_top_n, "scatter": PySparkSampledPlotBase().get_sampled, } _backends = {} # type: ignore[var-annotated] @@ -299,3 +301,40 @@ def area(self, x: str, y: str, **kwargs: Any) -> "Figure": >>> df.plot.area(x='date', y=['sales', 'signups', 'visits']) # doctest: +SKIP """ return self(kind="area", x=x, y=y, **kwargs) + + def pie(self, x: str, y: str, **kwargs: Any) -> "Figure": + """ + Generate a pie plot. + + A pie plot is a proportional representation of the numerical data in a + column. + + Parameters + ---------- + x : str + Name of column to be used as the category labels for the pie plot. + y : str + Name of the column to plot. + **kwargs + Additional keyword arguments. 
+ + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + """ + schema = self.data.schema + + # Check if 'y' is a numerical column + y_field = schema[y] if y in schema.names else None + if y_field is None or not isinstance(y_field.dataType, NumericType): + raise PySparkTypeError( + errorClass="PLOT_NOT_NUMERIC_COLUMN", + messageParameters={ + "arg_name": "y", + "arg_type": str(y_field.dataType) if y_field else "None", + }, + ) + return self(kind="pie", x=x, y=y, **kwargs) diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py index 5efc19476057f..91f5363464717 100644 --- a/python/pyspark/sql/plot/plotly.py +++ b/python/pyspark/sql/plot/plotly.py @@ -27,4 +27,19 @@ def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure": import plotly + if kind == "pie": + return plot_pie(data, **kwargs) + return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs) + + +def plot_pie(data: "DataFrame", **kwargs: Any) -> "Figure": + # TODO(SPARK-49530): Support pie subplots with plotly backend + from plotly import express + + pdf = PySparkPlotAccessor.plot_data_map["pie"](data) + x = kwargs.pop("x", None) + y = kwargs.pop("y", None) + fig = express.pie(pdf, values=y, names=x, **kwargs) + + return fig diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py index 6176525b49550..70a1b336f734a 100644 --- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -19,6 +19,7 @@ from datetime import datetime import pyspark.sql.plot # noqa: F401 +from pyspark.errors import PySparkTypeError from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message @@ -64,6 +65,11 @@ def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name= self.assertEqual(fig_data["type"], "scatter") self.assertEqual(fig_data["orientation"], "v") self.assertEqual(fig_data["mode"], "lines") + elif kind == "pie": + self.assertEqual(fig_data["type"], "pie") + self.assertEqual(list(fig_data["labels"]), expected_x) + self.assertEqual(list(fig_data["values"]), expected_y) + return self.assertEqual(fig_data["xaxis"], "x") self.assertEqual(list(fig_data["x"]), expected_x) @@ -133,6 +139,25 @@ def test_area_plot(self): self._check_fig_data("area", fig["data"][1], expected_x, [5, 5, 6, 12], "signups") self._check_fig_data("area", fig["data"][2], expected_x, [20, 42, 28, 62], "visits") + def test_pie_plot(self): + fig = self.sdf3.plot(kind="pie", x="date", y="sales") + expected_x = [ + datetime(2018, 1, 31, 0, 0), + datetime(2018, 2, 28, 0, 0), + datetime(2018, 3, 31, 0, 0), + datetime(2018, 4, 30, 0, 0), + ] + self._check_fig_data("pie", fig["data"][0], expected_x, [3, 2, 3, 9]) + + # y is not a numerical column + with self.assertRaises(PySparkTypeError) as pe: + self.sdf.plot.pie(x="int_val", y="category") + self.check_error( + exception=pe.exception, + errorClass="PLOT_NOT_NUMERIC_COLUMN", + messageParameters={"arg_name": "y", "arg_type": "StringType()"}, + ) + class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase): pass From 27d4a77f2a8ccdbc4d7c3afd6743ec845dc1294b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Fri, 27 Sep 2024 11:33:03 +0900 Subject: [PATCH 102/250] [SPARK-49801][PYTHON][PS][BUILD] Update `pandas` to 2.2.3 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### 
What changes were proposed in this pull request? Update pandas from 2.2.2 to 2.2.3 ### Why are the changes needed? [Release notes](https://pandas.pydata.org/pandas-docs/version/2.2.3/whatsnew/v2.2.3.html) ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48269 from bjornjorgensen/pandas2.2.3. Authored-by: Bjørn Jørgensen Signed-off-by: Hyukjin Kwon --- dev/infra/Dockerfile | 4 ++-- python/pyspark/pandas/supported_api_gen.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 5939e429b2f35..a40e43bb659f8 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -91,10 +91,10 @@ RUN mkdir -p /usr/local/pypy/pypy3.9 && \ ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.9 && \ ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3 RUN curl -sS https://bootstrap.pypa.io/get-pip.py | pypy3 -RUN pypy3 -m pip install 'numpy==1.26.4' 'six==1.16.0' 'pandas==2.2.2' scipy coverage matplotlib lxml +RUN pypy3 -m pip install 'numpy==1.26.4' 'six==1.16.0' 'pandas==2.2.3' scipy coverage matplotlib lxml -ARG BASIC_PIP_PKGS="numpy==1.26.4 pyarrow>=15.0.0 six==1.16.0 pandas==2.2.2 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" +ARG BASIC_PIP_PKGS="numpy==1.26.4 pyarrow>=15.0.0 six==1.16.0 pandas==2.2.3 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" # Python deps for Spark Connect ARG CONNECT_PIP_PKGS="grpcio==1.62.0 grpcio-status==1.62.0 protobuf==4.25.1 googleapis-common-protos==1.56.4 graphviz==0.20.3" diff --git a/python/pyspark/pandas/supported_api_gen.py b/python/pyspark/pandas/supported_api_gen.py index bbf0b3cbc3d67..f2a73cb1c1adf 100644 --- a/python/pyspark/pandas/supported_api_gen.py +++ b/python/pyspark/pandas/supported_api_gen.py @@ -38,7 +38,7 @@ MAX_MISSING_PARAMS_SIZE = 5 COMMON_PARAMETER_SET = {"kwargs", "args", "cls"} MODULE_GROUP_MATCH = [(pd, ps), (pdw, psw), (pdg, psg)] -PANDAS_LATEST_VERSION = "2.2.2" +PANDAS_LATEST_VERSION = "2.2.3" RST_HEADER = """ ===================== From 5d701f2d5add05b7af3889d6b87a192c11872298 Mon Sep 17 00:00:00 2001 From: "oleksii.diagiliev" Date: Thu, 26 Sep 2024 21:59:12 -0700 Subject: [PATCH 103/250] [SPARK-49804][K8S] Fix to use the exit code of executor container always ### What changes were proposed in this pull request? When deploying Spark pods on Kubernetes with sidecars, the reported executor's exit code may be incorrect. For example, the reported executor's exit code is 0(success), but the actual is 52 (OOM). ``` 2024-09-25 02:35:29,383 ERROR TaskSchedulerImpl.logExecutorLoss - Lost executor 1 on XXXXX: The executor with id 1 exited with exit code 0(success). 
The API gave the following container statuses: container name: fluentd container image: docker-images-release.XXXXX.com/XXXXX/fluentd:XXXXX container state: terminated container started at: 2024-09-25T02:32:17Z container finished at: 2024-09-25T02:34:52Z exit code: 0 termination reason: Completed container name: istio-proxy container image: docker-images-release.XXXXX.com/XXXXX-istio/proxyv2:XXXXX container state: running container started at: 2024-09-25T02:32:16Z container name: spark-kubernetes-executor container image: docker-dev-artifactory.XXXXX.com/XXXXX/spark-XXXXX:XXXXX container state: terminated container started at: 2024-09-25T02:32:17Z container finished at: 2024-09-25T02:35:28Z exit code: 52 termination reason: Error ``` The `ExecutorPodsLifecycleManager.findExitCode()` looks for any terminated container and may choose the sidecar instead of the main executor container. I'm changing it to look for the executor container always. Note, it may happen that the pod fails because of the failure of the sidecar container while executor's container is still running, with my changes the reported exit code will be -1 (`UNKNOWN_EXIT_CODE`). ### Why are the changes needed? To correctly report executor failure reason on UI, in the logs and for the event listeners `SparkListener.onExecutorRemoved()` ### Does this PR introduce _any_ user-facing change? Yes, the executor's exit code is taken from the main container instead of the sidecar. ### How was this patch tested? Added unit test and tested manually on the Kubernetes cluster by simulating different types of executor failure (JVM OOM and container eviction due to disk pressure on the node). ### Was this patch authored or co-authored using generative AI tooling? No Closes #48275 from fe2s/SPARK-49804-fix-exit-code. 
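A minimal, self-contained sketch of the selection change (not part of the patch; it uses simplified stand-in types rather than the fabric8 model classes in the diff below, and the container names only mirror the log quoted above):

```scala
// Simplified stand-in types for illustration only; the real code works on the
// fabric8 ContainerStatus model shown in the diff below.
object ExecutorExitCodeSketch {
  final case class Terminated(exitCode: Int)
  final case class ContainerStatus(name: String, terminated: Option[Terminated])

  private val UnknownExitCode = -1

  // Old behaviour: the first terminated container wins, which may be a sidecar.
  def exitCodeOfAnyTerminated(statuses: Seq[ContainerStatus]): Int =
    statuses.flatMap(_.terminated).headOption.map(_.exitCode).getOrElse(UnknownExitCode)

  // New behaviour: only the configured executor container is consulted.
  def exitCodeOfExecutor(statuses: Seq[ContainerStatus], executorContainer: String): Int =
    statuses
      .find(s => s.name == executorContainer && s.terminated.isDefined)
      .flatMap(_.terminated)
      .map(_.exitCode)
      .getOrElse(UnknownExitCode)

  def main(args: Array[String]): Unit = {
    val statuses = Seq(
      ContainerStatus("fluentd", Some(Terminated(0))),                    // sidecar exited cleanly
      ContainerStatus("istio-proxy", None),                               // sidecar still running
      ContainerStatus("spark-kubernetes-executor", Some(Terminated(52)))) // executor OOM
    assert(exitCodeOfAnyTerminated(statuses) == 0)                        // misleading "success"
    assert(exitCodeOfExecutor(statuses, "spark-kubernetes-executor") == 52)
  }
}
```

As noted above, if only a sidecar has terminated while the executor container is still running, the lookup now falls through to the unknown exit code (-1).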
Lead-authored-by: oleksii.diagiliev Co-authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../k8s/ExecutorPodsLifecycleManager.scala | 6 ++- .../k8s/ExecutorLifecycleTestUtils.scala | 37 ++++++++++++++++++- .../ExecutorPodsLifecycleManagerSuite.scala | 14 ++++++- 3 files changed, 53 insertions(+), 4 deletions(-) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala index 0d79efa06e497..992be9099639e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala @@ -62,6 +62,9 @@ private[spark] class ExecutorPodsLifecycleManager( private val namespace = conf.get(KUBERNETES_NAMESPACE) + private val sparkContainerName = conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME) + .getOrElse(DEFAULT_EXECUTOR_CONTAINER_NAME) + def start(schedulerBackend: KubernetesClusterSchedulerBackend): Unit = { val eventProcessingInterval = conf.get(KUBERNETES_EXECUTOR_EVENT_PROCESSING_INTERVAL) snapshotsStore.addSubscriber(eventProcessingInterval) { @@ -246,7 +249,8 @@ private[spark] class ExecutorPodsLifecycleManager( private def findExitCode(podState: FinalPodState): Int = { podState.pod.getStatus.getContainerStatuses.asScala.find { containerStatus => - containerStatus.getState.getTerminated != null + containerStatus.getName == sparkContainerName && + containerStatus.getState.getTerminated != null }.map { terminatedContainer => terminatedContainer.getState.getTerminated.getExitCode.toInt }.getOrElse(UNKNOWN_EXIT_CODE) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala index 299979071b5d7..fc75414e4a7e0 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorLifecycleTestUtils.scala @@ -29,6 +29,7 @@ import org.apache.spark.resource.ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID object ExecutorLifecycleTestUtils { val TEST_SPARK_APP_ID = "spark-app-id" + val TEST_SPARK_EXECUTOR_CONTAINER_NAME = "spark-executor" def failedExecutorWithoutDeletion( executorId: Long, rpId: Int = DEFAULT_RESOURCE_PROFILE_ID): Pod = { @@ -37,7 +38,7 @@ object ExecutorLifecycleTestUtils { .withPhase("failed") .withStartTime(Instant.now.toString) .addNewContainerStatus() - .withName("spark-executor") + .withName(TEST_SPARK_EXECUTOR_CONTAINER_NAME) .withImage("k8s-spark") .withNewState() .withNewTerminated() @@ -49,6 +50,38 @@ object ExecutorLifecycleTestUtils { .addNewContainerStatus() .withName("spark-executor-sidecar") .withImage("k8s-spark-sidecar") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(2) + .endTerminated() + .endState() + .endContainerStatus() + .withMessage("Executor failed.") + .withReason("Executor failed because of a thrown error.") + .endStatus() + .build() + } + + def failedExecutorWithSidecarStatusListedFirst( + executorId: Long, rpId: Int = DEFAULT_RESOURCE_PROFILE_ID): Pod = { + new 
PodBuilder(podWithAttachedContainerForId(executorId, rpId)) + .editOrNewStatus() + .withPhase("failed") + .withStartTime(Instant.now.toString) + .addNewContainerStatus() // sidecar status listed before executor's container status + .withName("spark-executor-sidecar") + .withImage("k8s-spark-sidecar") + .withNewState() + .withNewTerminated() + .withMessage("Failed") + .withExitCode(2) + .endTerminated() + .endState() + .endContainerStatus() + .addNewContainerStatus() + .withName(TEST_SPARK_EXECUTOR_CONTAINER_NAME) + .withImage("k8s-spark") .withNewState() .withNewTerminated() .withMessage("Failed") @@ -200,7 +233,7 @@ object ExecutorLifecycleTestUtils { .endSpec() .build() val container = new ContainerBuilder() - .withName("spark-executor") + .withName(TEST_SPARK_EXECUTOR_CONTAINER_NAME) .withImage("k8s-spark") .build() SparkPod(pod, container) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala index 96be5dfabd121..d3b7213807afb 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala @@ -33,6 +33,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy.k8s.Config +import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.deploy.k8s.Fabric8Aliases._ import org.apache.spark.deploy.k8s.KubernetesUtils._ @@ -60,6 +61,8 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte before { MockitoAnnotations.openMocks(this).close() + val sparkConf = new SparkConf() + .set(KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME, TEST_SPARK_EXECUTOR_CONTAINER_NAME) snapshotsStore = new DeterministicExecutorPodsSnapshotsStore() namedExecutorPods = mutable.Map.empty[String, PodResource] when(schedulerBackend.getExecutorsWithRegistrationTs()).thenReturn(Map.empty[String, Long]) @@ -67,7 +70,7 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte when(podOperations.inNamespace(anyString())).thenReturn(podsWithNamespace) when(podsWithNamespace.withName(any(classOf[String]))).thenAnswer(namedPodsAnswer()) eventHandlerUnderTest = new ExecutorPodsLifecycleManager( - new SparkConf(), + sparkConf, kubernetesClient, snapshotsStore) eventHandlerUnderTest.start(schedulerBackend) @@ -162,6 +165,15 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte .edit(any[UnaryOperator[Pod]]()) } + test("SPARK-49804: Use the exit code of executor container always") { + val failedPod = failedExecutorWithSidecarStatusListedFirst(1) + snapshotsStore.updatePod(failedPod) + snapshotsStore.notifySubscribers() + val msg = exitReasonMessage(1, failedPod, 1) + val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg) + verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason) + } + private def exitReasonMessage(execId: Int, failedPod: Pod, exitCode: Int): String = { val reason = Option(failedPod.getStatus.getReason) val message = Option(failedPod.getStatus.getMessage) From f18c4e7722b46e8573e959f5f3b063ed0efa5d23 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 27 Sep 2024 
15:27:34 +0800 Subject: [PATCH 104/250] [SPARK-49805][SQL][ML] Remove private[xxx] functions from `function.scala` ### What changes were proposed in this pull request? Remove private[xxx] functions from `function.scala` ### Why are the changes needed? internal functions can be directly invoked by `Column.internalFn`, no need to add them in `function.scala` ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #48276 from zhengruifeng/move_private_func. Authored-by: Ruifeng Zheng Signed-off-by: yangjie01 --- .../main/scala/org/apache/spark/ml/recommendation/ALS.scala | 5 ++++- .../apache/spark/ml/recommendation/CollectTopKSuite.scala | 3 ++- sql/api/src/main/scala/org/apache/spark/sql/functions.scala | 3 --- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 1a004f71749e1..5899bf891ec9d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -517,7 +517,7 @@ class ALSModel private[ml] ( ) ratings.groupBy(srcOutputColumn) - .agg(collect_top_k(struct(ratingColumn, dstOutputColumn), num, false)) + .agg(ALSModel.collect_top_k(struct(ratingColumn, dstOutputColumn), num, false)) .as[(Int, Seq[(Float, Int)])] .map(t => (t._1, t._2.map(p => (p._2, p._1)))) .toDF(srcOutputColumn, recommendColumn) @@ -546,6 +546,9 @@ object ALSModel extends MLReadable[ALSModel] { private val Drop = "drop" private[recommendation] final val supportedColdStartStrategies = Array(NaN, Drop) + private[recommendation] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column = + Column.internalFn("collect_top_k", e, lit(num), lit(reverse)) + @Since("1.6.0") override def read: MLReader[ALSModel] = new ALSModelReader diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/CollectTopKSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/CollectTopKSuite.scala index b79e10d0d267e..bd83d5498ae6f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/CollectTopKSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/CollectTopKSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.ml.recommendation +import org.apache.spark.ml.recommendation.ALSModel.collect_top_k import org.apache.spark.ml.util.MLTest import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.functions.{col, collect_top_k, struct} +import org.apache.spark.sql.functions.{col, struct} class CollectTopKSuite extends MLTest { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 93bff22621057..e6fd06f2ec632 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -401,9 +401,6 @@ object functions { def count_min_sketch(e: Column, eps: Column, confidence: Column): Column = count_min_sketch(e, eps, confidence, lit(SparkClassUtils.random.nextLong)) - private[spark] def collect_top_k(e: Column, num: Int, reverse: Boolean): Column = - Column.internalFn("collect_top_k", e, lit(num), lit(reverse)) - /** * Aggregate function: returns the Pearson Correlation Coefficient for two columns. 
* From 9b739d415cd51c8dd3f9332bae225196bab17d48 Mon Sep 17 00:00:00 2001 From: Mikhail Nikoliukin Date: Fri, 27 Sep 2024 21:48:35 +0800 Subject: [PATCH 105/250] [SPARK-49757][SQL] Support IDENTIFIER expression in SET CATALOG statement ### What changes were proposed in this pull request? This pr adds possibility to use `IDENTIFIER(...)` for a catalog name in `SET CATALOG` statement. For instance `SET CATALOG IDENTIFIER('test')` now works the same as `SET CATALOG test` ### Why are the changes needed? 1. Consistency of API. It can be confusing for user that he can use IDENTIFIER in some contexts but cannot for catalogs. 2. Parametrization. It allows user to write `SET CATALOG IDENTIFIER(:user_data)` and doesn't worry about SQL injections. ### Does this PR introduce _any_ user-facing change? Yes, now `SET CATALOG IDENTIFIER(...)` works. It can be used with any string expressions and parametrization. But multipart identifiers (like `IDENTIFIER('database.table')`) are banned and will rise ParseException with new type `INVALID_SQL_SYNTAX.MULTI_PART_CATALOG_NAME`. This restriction always has been on grammar level, but now user can try to bind such identifier via parameters. ### How was this patch tested? Unit tests with several new covering new behavior. ### Was this patch authored or co-authored using generative AI tooling? Yes, some code suggestions Generated-by: GitHub Copilot Closes #48228 from mikhailnik-db/SPARK-49757. Authored-by: Mikhail Nikoliukin Signed-off-by: Wenchen Fan --- .../resources/error/error-conditions.json | 2 +- .../sql/catalyst/parser/SqlBaseParser.g4 | 8 +++- .../spark/sql/errors/QueryParsingErrors.scala | 14 +++++-- .../spark/sql/execution/SparkSqlParser.scala | 35 ++++++++++++---- .../identifier-clause.sql.out | 2 +- .../results/identifier-clause.sql.out | 2 +- .../sql/connector/DataSourceV2SQLSuite.scala | 42 +++++++++++++++++++ .../sql/errors/QueryParsingErrorsSuite.scala | 4 +- .../execution/command/DDLParserSuite.scala | 4 +- 9 files changed, 93 insertions(+), 20 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index e83202d9e5ee3..3fcb53426eccf 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3023,7 +3023,7 @@ }, "MULTI_PART_NAME" : { "message" : [ - " with multiple part function name() is not allowed." + " with multiple part name() is not allowed." ] }, "OPTION_IS_INVALID" : { diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 094f7f5315b80..866634b041280 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -148,7 +148,7 @@ statement | ctes? dmlStatementNoWith #dmlStatement | USE identifierReference #use | USE namespace identifierReference #useNamespace - | SET CATALOG (errorCapturingIdentifier | stringLit) #setCatalog + | SET CATALOG catalogIdentifierReference #setCatalog | CREATE namespace (IF errorCapturingNot EXISTS)? 
identifierReference (commentSpec | locationSpec | @@ -594,6 +594,12 @@ identifierReference | multipartIdentifier ; +catalogIdentifierReference + : IDENTIFIER_KW LEFT_PAREN expression RIGHT_PAREN + | errorCapturingIdentifier + | stringLit + ; + queryOrganization : (ORDER BY order+=sortItem (COMMA order+=sortItem)*)? (CLUSTER BY clusterBy+=expression (COMMA clusterBy+=expression)*)? diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index b19607a28f06c..b0743d6de4772 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -621,9 +621,8 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { def unsupportedFunctionNameError(funcName: Seq[String], ctx: ParserRuleContext): Throwable = { new ParseException( errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", - messageParameters = Map( - "statement" -> toSQLStmt("CREATE TEMPORARY FUNCTION"), - "funcName" -> toSQLId(funcName)), + messageParameters = + Map("statement" -> toSQLStmt("CREATE TEMPORARY FUNCTION"), "name" -> toSQLId(funcName)), ctx) } @@ -665,7 +664,14 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { new ParseException( errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", messageParameters = - Map("statement" -> toSQLStmt("DROP TEMPORARY FUNCTION"), "funcName" -> toSQLId(name)), + Map("statement" -> toSQLStmt("DROP TEMPORARY FUNCTION"), "name" -> toSQLId(name)), + ctx) + } + + def invalidNameForSetCatalog(name: Seq[String], ctx: ParserRuleContext): Throwable = { + new ParseException( + errorClass = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", + messageParameters = Map("statement" -> toSQLStmt("SET CATALOG"), "name" -> toSQLId(name)), ctx) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index a8261e5d98ba0..1c735154f25ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -27,7 +27,7 @@ import org.antlr.v4.runtime.tree.TerminalNode import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, PersistedView, SchemaEvolution, SchemaTypeEvolution, UnresolvedFunctionName, UnresolvedIdentifier, UnresolvedNamespace} +import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, PersistedView, PlanWithUnresolvedIdentifier, SchemaEvolution, SchemaTypeEvolution, UnresolvedFunctionName, UnresolvedIdentifier, UnresolvedNamespace} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.parser._ @@ -67,6 +67,25 @@ class SparkSqlAstBuilder extends AstBuilder { private val configValueDef = """([^;]*);*""".r private val strLiteralDef = """(".*?[^\\]"|'.*?[^\\]'|[^ \n\r\t"']+)""".r + private def withCatalogIdentClause( + ctx: CatalogIdentifierReferenceContext, + builder: Seq[String] => LogicalPlan): LogicalPlan = { + val exprCtx = ctx.expression + if (exprCtx != null) { + // resolve later in analyzer + PlanWithUnresolvedIdentifier(withOrigin(exprCtx) { expression(exprCtx) }, Nil, + (ident, _) => builder(ident)) + } 
else if (ctx.errorCapturingIdentifier() != null) { + // resolve immediately + builder.apply(Seq(ctx.errorCapturingIdentifier().getText)) + } else if (ctx.stringLit() != null) { + // resolve immediately + builder.apply(Seq(string(visitStringLit(ctx.stringLit())))) + } else { + throw SparkException.internalError("Invalid catalog name") + } + } + /** * Create a [[SetCommand]] logical plan. * @@ -276,13 +295,13 @@ class SparkSqlAstBuilder extends AstBuilder { * Create a [[SetCatalogCommand]] logical command. */ override def visitSetCatalog(ctx: SetCatalogContext): LogicalPlan = withOrigin(ctx) { - if (ctx.errorCapturingIdentifier() != null) { - SetCatalogCommand(ctx.errorCapturingIdentifier().getText) - } else if (ctx.stringLit() != null) { - SetCatalogCommand(string(visitStringLit(ctx.stringLit()))) - } else { - throw SparkException.internalError("Invalid catalog name") - } + withCatalogIdentClause(ctx.catalogIdentifierReference, identifiers => { + if (identifiers.size > 1) { + // can occur when user put multipart string in IDENTIFIER(...) clause + throw QueryParsingErrors.invalidNameForSetCatalog(identifiers, ctx) + } + SetCatalogCommand(identifiers.head) + }) } /** diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out index f0bf8b883dd8b..20e6ca1e6a2ec 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out @@ -893,7 +893,7 @@ org.apache.spark.sql.catalyst.parser.ParseException "errorClass" : "INVALID_SQL_SYNTAX.MULTI_PART_NAME", "sqlState" : "42000", "messageParameters" : { - "funcName" : "`default`.`myDoubleAvg`", + "name" : "`default`.`myDoubleAvg`", "statement" : "DROP TEMPORARY FUNCTION" }, "queryContext" : [ { diff --git a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out index 952fb8fdc2bd2..596745b4ba5d8 100644 --- a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out @@ -1024,7 +1024,7 @@ org.apache.spark.sql.catalyst.parser.ParseException "errorClass" : "INVALID_SQL_SYNTAX.MULTI_PART_NAME", "sqlState" : "42000", "messageParameters" : { - "funcName" : "`default`.`myDoubleAvg`", + "name" : "`default`.`myDoubleAvg`", "statement" : "DROP TEMPORARY FUNCTION" }, "queryContext" : [ { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index dac066bbef838..6b58d23e92603 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -2887,6 +2887,48 @@ class DataSourceV2SQLSuiteV1Filter "config" -> "\"spark.sql.catalog.not_exist_catalog\"")) } + test("SPARK-49757: SET CATALOG statement with IDENTIFIER should work") { + val catalogManager = spark.sessionState.catalogManager + assert(catalogManager.currentCatalog.name() == SESSION_CATALOG_NAME) + + sql("SET CATALOG IDENTIFIER('testcat')") + assert(catalogManager.currentCatalog.name() == "testcat") + + spark.sql("SET CATALOG IDENTIFIER(:param)", Map("param" -> "testcat2")) + assert(catalogManager.currentCatalog.name() == "testcat2") + + checkError( + exception = 
intercept[CatalogNotFoundException] { + sql("SET CATALOG IDENTIFIER('not_exist_catalog')") + }, + condition = "CATALOG_NOT_FOUND", + parameters = Map( + "catalogName" -> "`not_exist_catalog`", + "config" -> "\"spark.sql.catalog.not_exist_catalog\"") + ) + } + + test("SPARK-49757: SET CATALOG statement with IDENTIFIER with multipart name should fail") { + val catalogManager = spark.sessionState.catalogManager + assert(catalogManager.currentCatalog.name() == SESSION_CATALOG_NAME) + + val sqlText = "SET CATALOG IDENTIFIER(:param)" + checkError( + exception = intercept[ParseException] { + spark.sql(sqlText, Map("param" -> "testcat.ns1")) + }, + condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", + parameters = Map( + "name" -> "`testcat`.`ns1`", + "statement" -> "SET CATALOG" + ), + context = ExpectedContext( + fragment = sqlText, + start = 0, + stop = 29) + ) + } + test("SPARK-35973: ShowCatalogs") { val schema = new StructType() .add("catalog", StringType, nullable = false) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala index da7b6e7f63c85..666f85e19c1c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala @@ -334,7 +334,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL sqlState = "42000", parameters = Map( "statement" -> "CREATE TEMPORARY FUNCTION", - "funcName" -> "`ns`.`db`.`func`"), + "name" -> "`ns`.`db`.`func`"), context = ExpectedContext( fragment = sqlText, start = 0, @@ -367,7 +367,7 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL sqlState = "42000", parameters = Map( "statement" -> "DROP TEMPORARY FUNCTION", - "funcName" -> "`db`.`func`"), + "name" -> "`db`.`func`"), context = ExpectedContext( fragment = sqlText, start = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 176eb7c290764..8b868c0e17230 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -688,7 +688,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { checkError( exception = parseException(sql1), condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", - parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "funcName" -> "`a`.`b`"), + parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "name" -> "`a`.`b`"), context = ExpectedContext( fragment = sql1, start = 0, @@ -698,7 +698,7 @@ class DDLParserSuite extends AnalysisTest with SharedSparkSession { checkError( exception = parseException(sql2), condition = "INVALID_SQL_SYNTAX.MULTI_PART_NAME", - parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "funcName" -> "`a`.`b`"), + parameters = Map("statement" -> "DROP TEMPORARY FUNCTION", "name" -> "`a`.`b`"), context = ExpectedContext( fragment = sql2, start = 0, From d7abddc454ffef6ac16e8f6df6f601eec621ddfd Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Fri, 27 Sep 2024 21:53:10 +0800 Subject: [PATCH 106/250] [SPARK-49808][SQL] Fix a deadlock in subquery execution due to lazy vals ### What changes were proposed in this pull request? 
Fix a deadlock in subquery execution due to lazy vals ### Why are the changes needed? we observed a deadlock between `QueryPlan.canonicalized` and `QueryPlan.references`: ``` 24/09/04 04:46:54 ERROR DeadlockDetector: Found 2 new deadlock thread(s): "ScalaTest-run-running-SubquerySuite" prio=5 Id=1 BLOCKED on org.apache.spark.sql.execution.aggregate.HashAggregateExec87abc7f owned by "subquery-5" Id=112 at app//org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:684) - blocked on org.apache.spark.sql.execution.aggregate.HashAggregateExec87abc7f at app//org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:684) at app//org.apache.spark.sql.catalyst.plans.QueryPlan.$anonfun$doCanonicalize$2(QueryPlan.scala:716) at app//org.apache.spark.sql.catalyst.plans.QueryPlan$$Lambda$4058/0x00007f740f3d0cb0.apply(Unknown Source) at app//org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren(TreeNode.scala:1314) at app//org.apache.spark.sql.catalyst.trees.UnaryLike.mapChildren$(TreeNode.scala:1313) at app//org.apache.spark.sql.execution.WholeStageCodegenExec.mapChildren(WholeStageCodegenExec.scala:639) at app//org.apache.spark.sql.catalyst.plans.QueryPlan.doCanonicalize(QueryPlan.scala:716) ... "subquery-5" daemon prio=5 Id=112 BLOCKED on org.apache.spark.sql.execution.WholeStageCodegenExec132a3243 owned by "ScalaTest-run-running-SubquerySuite" Id=1 at app//org.apache.spark.sql.catalyst.plans.QueryPlan.references$lzycompute(QueryPlan.scala:101) - blocked on org.apache.spark.sql.execution.WholeStageCodegenExec132a3243 at app//org.apache.spark.sql.catalyst.plans.QueryPlan.references(QueryPlan.scala:101) at app//org.apache.spark.sql.execution.CodegenSupport.usedInputs(WholeStageCodegenExec.scala:325) at app//org.apache.spark.sql.execution.CodegenSupport.usedInputs$(WholeStageCodegenExec.scala:325) at app//org.apache.spark.sql.execution.WholeStageCodegenExec.usedInputs(WholeStageCodegenExec.scala:639) at app//org.apache.spark.sql.execution.CodegenSupport.consume(WholeStageCodegenExec.scala:187) at app//org.apache.spark.sql.execution.CodegenSupport.consume$(WholeStageCodegenExec.scala:157) at app//org.apache.spark.sql.execution.aggregate.HashAggregateExec.consume(HashAggregateExec.scala:53) ``` The main thread `TakeOrderedAndProject.doExecute` is trying to compute `outputOrdering`, it top-down traverse the tree, and requires the lock of `QueryPlan.canonicalized` in the path. In this deadlock, it successfully obtained the lock of `WholeStageCodegenExec` and requires the lock of `HashAggregateExec`; Concurrently, a subquery execution thread is performing code generation and bottom-up traverses the tree via `def consume`, which checks `WholeStageCodegenExec.usedInputs` and refererences a lazy val `QueryPlan.references`. It requires the lock of `QueryPlan.references` in the path. In this deadlock, it successfully obtained the lock of `HashAggregateExec` and requires the lock of `WholeStageCodegenExec`; This is due to Scala's lazy val internally calls this.synchronized on the instance that contains the val. This creates a potential for deadlocks. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? manually checked with `com.databricks.spark.sql.SubquerySuite` we encountered this issue multiple times before this fix in `SubquerySuite`, and after this fix we didn't hit this issue in multiple runs. ### Was this patch authored or co-authored using generative AI tooling? no Closes #48279 from zhengruifeng/fix_deadlock. 
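A minimal standalone sketch of the failure mode (not Spark code): two singletons whose lazy vals read each other from two threads reproduce the same lock-ordering cycle as `canonicalized` and `references` on two plan nodes, because each thread holds one enclosing instance's monitor while waiting for the other's.

```scala
object LazyValDeadlockSketch {
  // Scala lazy vals synchronize on the enclosing instance during initialization.
  object A { lazy val a: Int = { Thread.sleep(100); B.b + 1 } }
  object B { lazy val b: Int = { Thread.sleep(100); A.a + 1 } }

  def main(args: Array[String]): Unit = {
    // t1 holds A's monitor and waits for B's; t2 holds B's and waits for A's.
    val t1 = new Thread(new Runnable { def run(): Unit = { val _ = A.a } })
    val t2 = new Thread(new Runnable { def run(): Unit = { val _ = B.b } })
    t1.start(); t2.start()
    t1.join(); t2.join() // with this timing the two threads typically never return
  }
}
```

The diff below avoids the plan-node monitor entirely by computing `references` through `LazyTry`.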
Authored-by: Ruifeng Zheng Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 3f417644082c3..ca5ff78b10e91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.trees.TreePatternBits import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.LazyTry import org.apache.spark.util.collection.BitSet /** @@ -94,9 +95,11 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] * All Attributes that appear in expressions from this operator. Note that this set does not * include attributes that are implicitly referenced by being passed through to the output tuple. */ + def references: AttributeSet = lazyReferences.get + @transient - lazy val references: AttributeSet = { - AttributeSet.fromAttributeSets(expressions.map(_.references)) -- producedAttributes + private val lazyReferences = LazyTry { + AttributeSet(expressions) -- producedAttributes } /** From dd692e90b7384a789142cfccff0dbf10cead6a21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Fri, 27 Sep 2024 07:49:06 -0700 Subject: [PATCH 107/250] [SPARK-49801][FOLLOWUP][INFRA] Update `pandas` to 2.2.3 in `pages.yml` too MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Fix doc build. ### Why are the changes needed? in https://github.com/apache/spark/pull/48269 > Oh, this seems to break GitHub Action Jekyll. https://github.com/apache/spark/actions/runs/11063509911/job/30742286270 Traceback (most recent call last): File "/opt/hostedtoolcache/Python/3.9.20/x64/lib/python3.9/site-packages/sphinx/config.py", line 332, in eval_config_file exec(code, namespace) File "/home/runner/work/spark/spark/python/docs/source/conf.py", line 33, in generate_supported_api(output_rst_file_path) File "/home/runner/work/spark/spark/python/pyspark/pandas/supported_api_gen.py", line 102, in generate_supported_api _check_pandas_version() File "/home/runner/work/spark/spark/python/pyspark/pandas/supported_api_gen.py", line 116, in _check_pandas_version raise ImportError(msg) ImportError: Warning: pandas 2.2.3 is required; your version is 2.2.2" ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48278 from bjornjorgensen/fix-pandas-2.2.3. 
Authored-by: Bjørn Jørgensen Signed-off-by: Dongjoon Hyun --- .github/workflows/pages.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml index 8faeb0557fbfb..f78f7895a183f 100644 --- a/.github/workflows/pages.yml +++ b/.github/workflows/pages.yml @@ -60,7 +60,7 @@ jobs: - name: Install Python dependencies run: | pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \ - ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow 'pandas==2.2.2' 'plotly>=4.8' 'docutils<0.18.0' \ + ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.20.0' pyarrow 'pandas==2.2.3' 'plotly>=4.8' 'docutils<0.18.0' \ 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.9.1' \ 'pandas-stubs==1.2.0.53' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \ 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' From 6dc628c31cdf48769ccd80cd2b81f7bd6386276f Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 27 Sep 2024 08:51:47 -0700 Subject: [PATCH 108/250] [SPARK-49809][BUILD] Use `sbt.IO` in `SparkBuild.scala` to avoid naming conflicts with `java.io.IO` in Java 23 ### What changes were proposed in this pull request? This pr change to use `sbt.IO` in `SparkBuild.scala` to avoid naming conflicts with `java.io.IO` in Java 23, and after this PR, Spark can be built using sbt with Java 23(current pr does not focus on the results of `sbt/test` with Java 23) ### Why are the changes needed? Make Spark be compiled using sbt with Java 23. Because Java 23 has added `java.io.IO`, and `SparkBuild.scala` imports both `java.io._` and `sbt._`, this results in the following error when executing ``` build/sbt -Phadoop-3 -Phive-thriftserver -Pspark-ganglia-lgpl -Pdocker-integration-tests -Pyarn -Pvolcano -Pkubernetes -Pkinesis-asl -Phive -Phadoop-cloud Test/package streaming-kinesis-asl-assembly/assembly connect/assembly ``` with Java 23 ``` build/sbt -Phadoop-3 -Phive-thriftserver -Pspark-ganglia-lgpl -Pdocker-integration-tests -Pyarn -Pvolcano -Pkubernetes -Pkinesis-asl -Phive -Phadoop-cloud Test/package streaming-kinesis-asl-assembly/assembly connect/assembly Using /Users/yangjie01/Tools/zulu23 as default JAVA_HOME. Note, this will be overridden by -java-home if it is set. [info] welcome to sbt 1.9.3 (Azul Systems, Inc. Java 23) [info] loading settings for project global-plugins from idea.sbt ... [info] loading global plugins from /Users/yangjie01/.sbt/1.0/plugins [info] loading settings for project spark-sbt-build from plugins.sbt ... [info] loading project definition from /Users/yangjie01/SourceCode/git/spark-sbt/project [info] compiling 3 Scala sources to /Users/yangjie01/SourceCode/git/spark-sbt/project/target/scala-2.12/sbt-1.0/classes ... [error] /Users/yangjie01/SourceCode/git/spark-sbt/project/SparkBuild.scala:1209:7: reference to IO is ambiguous; [error] it is imported twice in the same scope by [error] import sbt._ [error] and import java.io._ [error] IO.write(file, s"$hadoopProvidedProp = $isHadoopProvided") [error] ^ [error] one error found [error] (Compile / compileIncremental) Compilation failed ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? 
- Pass Github Actions - Manual check: ``` build/sbt -Phadoop-3 -Phive-thriftserver -Pspark-ganglia-lgpl -Pdocker-integration-tests -Pyarn -Pvolcano -Pkubernetes -Pkinesis-asl -Phive -Phadoop-cloud Test/package streaming-kinesis-asl-assembly/assembly connect/assembly ``` with Java 23, after this pr, the aforementioned command can be executed successfully. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48280 from LuciferYang/build-with-java23. Authored-by: yangjie01 Signed-off-by: Dongjoon Hyun --- project/SparkBuild.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 82950fb30287a..6137984a53c0a 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -1206,7 +1206,7 @@ object YARN { genConfigProperties := { val file = (Compile / classDirectory).value / s"org/apache/spark/deploy/yarn/$propFileName" val isHadoopProvided = SbtPomKeys.effectivePom.value.getProperties.get(hadoopProvidedProp) - IO.write(file, s"$hadoopProvidedProp = $isHadoopProvided") + sbt.IO.write(file, s"$hadoopProvidedProp = $isHadoopProvided") }, Compile / copyResources := (Def.taskDyn { val c = (Compile / copyResources).value From b6681fbf32fa3596d7649d413f20cc5c6da64991 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 27 Sep 2024 13:42:54 -0700 Subject: [PATCH 109/250] [SPARK-49787][SQL] Cast between UDT and other types ### What changes were proposed in this pull request? This patch adds UDT support to `Cast` expression. ### Why are the changes needed? Our customer faced an error when migrating queries that write UDT column from Hive to Iceberg table. The error happens when Spark tries to cast UDT column to the data type (i.e., the sql type of the UDT) of the table column. The cast is added by table column resolution rule for V2 writing commands. Currently `Cast` expression doesn't support casting between UDT and other types. However, underlying an UDT, it is serialized as its `sqlType`, `Cast` should be able to cast between the `sqlType` and other types. ### Does this PR introduce _any_ user-facing change? Yes. User query can cast between UDT and other types. ### How was this patch tested? Unit test ### Was this patch authored or co-authored using generative AI tooling? No Closes #48251 from viirya/cast_udt. 
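A condensed restatement of the dispatch rule the diff below adds to `Cast.canCast`/`canAnsiCast` and `UpCastRule` (an illustration, not the actual Spark source): a UDT on either side of a cast is handled through its `sqlType`.

```scala
import org.apache.spark.sql.types._

object UdtCastSketch {
  // The IntegerType/LongType case stands in for the full castability table.
  def canCastSketch(from: DataType, to: DataType): Boolean = (from, to) match {
    case (f, t) if f == t             => true
    case (udt: UserDefinedType[_], t) => canCastSketch(udt.sqlType, t)
    case (f, udt: UserDefinedType[_]) => canCastSketch(f, udt.sqlType)
    case (IntegerType, LongType)      => true
    case _                            => false
  }
}
```

With this unwrapping in place, the cast that the V2 write path inserts from a UDT column to the target table column's type can resolve instead of failing.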
Authored-by: Liang-Chi Hsieh Signed-off-by: huaxingao --- python/pyspark/sql/tests/test_types.py | 16 +- .../apache/spark/sql/types/UpCastRule.scala | 4 + .../spark/sql/catalyst/expressions/Cast.scala | 175 ++++++++++-------- .../sql/catalyst/expressions/literals.scala | 84 +++++---- .../catalyst/expressions/CastSuiteBase.scala | 42 ++++- 5 files changed, 202 insertions(+), 119 deletions(-) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 8610ace52d86a..c240a84d1edb9 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -28,7 +28,6 @@ from pyspark.sql import Row from pyspark.sql import functions as F from pyspark.errors import ( - AnalysisException, ParseException, PySparkTypeError, PySparkValueError, @@ -1130,10 +1129,17 @@ def test_cast_to_string_with_udt(self): def test_cast_to_udt_with_udt(self): row = Row(point=ExamplePoint(1.0, 2.0), python_only_point=PythonOnlyPoint(1.0, 2.0)) df = self.spark.createDataFrame([row]) - with self.assertRaises(AnalysisException): - df.select(F.col("point").cast(PythonOnlyUDT())).collect() - with self.assertRaises(AnalysisException): - df.select(F.col("python_only_point").cast(ExamplePointUDT())).collect() + result = df.select(F.col("point").cast(PythonOnlyUDT())).collect() + self.assertEqual( + result, + [Row(point=PythonOnlyPoint(1.0, 2.0))], + ) + + result = df.select(F.col("python_only_point").cast(ExamplePointUDT())).collect() + self.assertEqual( + result, + [Row(python_only_point=ExamplePoint(1.0, 2.0))], + ) def test_struct_type(self): struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala index 4993e249b3059..6f2fd41f1f799 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/UpCastRule.scala @@ -66,6 +66,10 @@ private[sql] object UpCastRule { case (from: UserDefinedType[_], to: UserDefinedType[_]) if to.acceptsType(from) => true + case (udt: UserDefinedType[_], toType) => canUpCast(udt.sqlType, toType) + + case (fromType, udt: UserDefinedType[_]) => canUpCast(fromType, udt.sqlType) + case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 7a2799e99fe2d..9a29cb4a2bfb3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -150,6 +150,10 @@ object Cast extends QueryErrorsBase { case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt2.acceptsType(udt1) => true + case (udt: UserDefinedType[_], toType) => canAnsiCast(udt.sqlType, toType) + + case (fromType, udt: UserDefinedType[_]) => canAnsiCast(fromType, udt.sqlType) + case _ => false } @@ -267,6 +271,10 @@ object Cast extends QueryErrorsBase { case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt2.acceptsType(udt1) => true + case (udt: UserDefinedType[_], toType) => canCast(udt.sqlType, toType) + + case (fromType, udt: UserDefinedType[_]) => canCast(fromType, udt.sqlType) + case _ => false } @@ -1123,33 +1131,42 @@ case class Cast( variant.VariantGet.cast(v, to, evalMode != EvalMode.TRY, timeZoneId, zoneId) }) } else { - to match { - case dt if dt == from 
=> identity[Any] - case VariantType => input => variant.VariantExpressionEvalUtils.castToVariant(input, from) - case _: StringType => castToString(from) - case BinaryType => castToBinary(from) - case DateType => castToDate(from) - case decimal: DecimalType => castToDecimal(from, decimal) - case TimestampType => castToTimestamp(from) - case TimestampNTZType => castToTimestampNTZ(from) - case CalendarIntervalType => castToInterval(from) - case it: DayTimeIntervalType => castToDayTimeInterval(from, it) - case it: YearMonthIntervalType => castToYearMonthInterval(from, it) - case BooleanType => castToBoolean(from) - case ByteType => castToByte(from) - case ShortType => castToShort(from) - case IntegerType => castToInt(from) - case FloatType => castToFloat(from) - case LongType => castToLong(from) - case DoubleType => castToDouble(from) - case array: ArrayType => - castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) - case map: MapType => castMap(from.asInstanceOf[MapType], map) - case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) - case udt: UserDefinedType[_] if udt.acceptsType(from) => - identity[Any] - case _: UserDefinedType[_] => - throw QueryExecutionErrors.cannotCastError(from, to) + from match { + // `castToString` has special handling for `UserDefinedType` + case udt: UserDefinedType[_] if !to.isInstanceOf[StringType] => + castInternal(udt.sqlType, to) + case _ => + to match { + case dt if dt == from => identity[Any] + case VariantType => input => + variant.VariantExpressionEvalUtils.castToVariant(input, from) + case _: StringType => castToString(from) + case BinaryType => castToBinary(from) + case DateType => castToDate(from) + case decimal: DecimalType => castToDecimal(from, decimal) + case TimestampType => castToTimestamp(from) + case TimestampNTZType => castToTimestampNTZ(from) + case CalendarIntervalType => castToInterval(from) + case it: DayTimeIntervalType => castToDayTimeInterval(from, it) + case it: YearMonthIntervalType => castToYearMonthInterval(from, it) + case BooleanType => castToBoolean(from) + case ByteType => castToByte(from) + case ShortType => castToShort(from) + case IntegerType => castToInt(from) + case FloatType => castToFloat(from) + case LongType => castToLong(from) + case DoubleType => castToDouble(from) + case array: ArrayType => + castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) + case map: MapType => castMap(from.asInstanceOf[MapType], map) + case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) + case udt: UserDefinedType[_] if udt.acceptsType(from) => + identity[Any] + case udt: UserDefinedType[_] => + castInternal(from, udt.sqlType) + case _ => + throw QueryExecutionErrors.cannotCastError(from, to) + } } } } @@ -1211,54 +1228,64 @@ case class Cast( private[this] def nullSafeCastFunction( from: DataType, to: DataType, - ctx: CodegenContext): CastFunction = to match { - - case _ if from == NullType => (c, evPrim, evNull) => code"$evNull = true;" - case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;" - case _ if from.isInstanceOf[VariantType] => (c, evPrim, evNull) => - val tmp = ctx.freshVariable("tmp", classOf[Object]) - val dataTypeArg = ctx.addReferenceObj("dataType", to) - val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId) - val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) - val failOnError = evalMode != EvalMode.TRY - val cls = classOf[variant.VariantGet].getName - code""" - Object $tmp = $cls.cast($c, 
$dataTypeArg, $failOnError, $zoneStrArg, $zoneIdArg); - if ($tmp == null) { - $evNull = true; - } else { - $evPrim = (${CodeGenerator.boxedType(to)})$tmp; + ctx: CodegenContext): CastFunction = { + from match { + // `castToStringCode` has special handling for `UserDefinedType` + case udt: UserDefinedType[_] if !to.isInstanceOf[StringType] => + nullSafeCastFunction(udt.sqlType, to, ctx) + case _ => + to match { + + case _ if from == NullType => (c, evPrim, evNull) => code"$evNull = true;" + case _ if to == from => (c, evPrim, evNull) => code"$evPrim = $c;" + case _ if from.isInstanceOf[VariantType] => (c, evPrim, evNull) => + val tmp = ctx.freshVariable("tmp", classOf[Object]) + val dataTypeArg = ctx.addReferenceObj("dataType", to) + val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId) + val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) + val failOnError = evalMode != EvalMode.TRY + val cls = classOf[variant.VariantGet].getName + code""" + Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneStrArg, $zoneIdArg); + if ($tmp == null) { + $evNull = true; + } else { + $evPrim = (${CodeGenerator.boxedType(to)})$tmp; + } + """ + case VariantType => + val cls = variant.VariantExpressionEvalUtils.getClass.getName.stripSuffix("$") + val fromArg = ctx.addReferenceObj("from", from) + (c, evPrim, evNull) => code"$evPrim = $cls.castToVariant($c, $fromArg);" + case _: StringType => (c, evPrim, _) => castToStringCode(from, ctx).apply(c, evPrim) + case BinaryType => castToBinaryCode(from) + case DateType => castToDateCode(from, ctx) + case decimal: DecimalType => castToDecimalCode(from, decimal, ctx) + case TimestampType => castToTimestampCode(from, ctx) + case TimestampNTZType => castToTimestampNTZCode(from, ctx) + case CalendarIntervalType => castToIntervalCode(from) + case it: DayTimeIntervalType => castToDayTimeIntervalCode(from, it) + case it: YearMonthIntervalType => castToYearMonthIntervalCode(from, it) + case BooleanType => castToBooleanCode(from, ctx) + case ByteType => castToByteCode(from, ctx) + case ShortType => castToShortCode(from, ctx) + case IntegerType => castToIntCode(from, ctx) + case FloatType => castToFloatCode(from, ctx) + case LongType => castToLongCode(from, ctx) + case DoubleType => castToDoubleCode(from, ctx) + + case array: ArrayType => + castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx) + case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx) + case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) + case udt: UserDefinedType[_] if udt.acceptsType(from) => + (c, evPrim, evNull) => code"$evPrim = $c;" + case udt: UserDefinedType[_] => + nullSafeCastFunction(from, udt.sqlType, ctx) + case _ => + throw QueryExecutionErrors.cannotCastError(from, to) } - """ - case VariantType => - val cls = variant.VariantExpressionEvalUtils.getClass.getName.stripSuffix("$") - val fromArg = ctx.addReferenceObj("from", from) - (c, evPrim, evNull) => code"$evPrim = $cls.castToVariant($c, $fromArg);" - case _: StringType => (c, evPrim, _) => castToStringCode(from, ctx).apply(c, evPrim) - case BinaryType => castToBinaryCode(from) - case DateType => castToDateCode(from, ctx) - case decimal: DecimalType => castToDecimalCode(from, decimal, ctx) - case TimestampType => castToTimestampCode(from, ctx) - case TimestampNTZType => castToTimestampNTZCode(from, ctx) - case CalendarIntervalType => castToIntervalCode(from) - case it: DayTimeIntervalType => castToDayTimeIntervalCode(from, it) - 
case it: YearMonthIntervalType => castToYearMonthIntervalCode(from, it) - case BooleanType => castToBooleanCode(from, ctx) - case ByteType => castToByteCode(from, ctx) - case ShortType => castToShortCode(from, ctx) - case IntegerType => castToIntCode(from, ctx) - case FloatType => castToFloatCode(from, ctx) - case LongType => castToLongCode(from, ctx) - case DoubleType => castToDoubleCode(from, ctx) - - case array: ArrayType => - castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx) - case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx) - case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) - case udt: UserDefinedType[_] if udt.acceptsType(from) => - (c, evPrim, evNull) => code"$evPrim = $c;" - case _: UserDefinedType[_] => - throw QueryExecutionErrors.cannotCastError(from, to) + } } // Since we need to cast input expressions recursively inside ComplexTypes, such as Map's diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 4cffc7f0b53a3..362bb9af1661e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -441,47 +441,53 @@ case class Literal (value: Any, dataType: DataType) extends LeafExpression { override def eval(input: InternalRow): Any = value override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val javaType = CodeGenerator.javaType(dataType) - if (value == null) { - ExprCode.forNullValue(dataType) - } else { - def toExprCode(code: String): ExprCode = { - ExprCode.forNonNullValue(JavaCode.literal(code, dataType)) - } - dataType match { - case BooleanType | IntegerType | DateType | _: YearMonthIntervalType => - toExprCode(value.toString) - case FloatType => - value.asInstanceOf[Float] match { - case v if v.isNaN => - toExprCode("Float.NaN") - case Float.PositiveInfinity => - toExprCode("Float.POSITIVE_INFINITY") - case Float.NegativeInfinity => - toExprCode("Float.NEGATIVE_INFINITY") - case _ => - toExprCode(s"${value}F") - } - case DoubleType => - value.asInstanceOf[Double] match { - case v if v.isNaN => - toExprCode("Double.NaN") - case Double.PositiveInfinity => - toExprCode("Double.POSITIVE_INFINITY") - case Double.NegativeInfinity => - toExprCode("Double.NEGATIVE_INFINITY") - case _ => - toExprCode(s"${value}D") - } - case ByteType | ShortType => - ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType)) - case TimestampType | TimestampNTZType | LongType | _: DayTimeIntervalType => - toExprCode(s"${value}L") - case _ => - val constRef = ctx.addReferenceObj("literal", value, javaType) - ExprCode.forNonNullValue(JavaCode.global(constRef, dataType)) + def gen(ctx: CodegenContext, ev: ExprCode, dataType: DataType): ExprCode = { + val javaType = CodeGenerator.javaType(dataType) + if (value == null) { + ExprCode.forNullValue(dataType) + } else { + def toExprCode(code: String): ExprCode = { + ExprCode.forNonNullValue(JavaCode.literal(code, dataType)) + } + + dataType match { + case BooleanType | IntegerType | DateType | _: YearMonthIntervalType => + toExprCode(value.toString) + case FloatType => + value.asInstanceOf[Float] match { + case v if v.isNaN => + toExprCode("Float.NaN") + case Float.PositiveInfinity => + toExprCode("Float.POSITIVE_INFINITY") + case Float.NegativeInfinity => + 
toExprCode("Float.NEGATIVE_INFINITY") + case _ => + toExprCode(s"${value}F") + } + case DoubleType => + value.asInstanceOf[Double] match { + case v if v.isNaN => + toExprCode("Double.NaN") + case Double.PositiveInfinity => + toExprCode("Double.POSITIVE_INFINITY") + case Double.NegativeInfinity => + toExprCode("Double.NEGATIVE_INFINITY") + case _ => + toExprCode(s"${value}D") + } + case ByteType | ShortType => + ExprCode.forNonNullValue(JavaCode.expression(s"($javaType)$value", dataType)) + case TimestampType | TimestampNTZType | LongType | _: DayTimeIntervalType => + toExprCode(s"${value}L") + case udt: UserDefinedType[_] => + gen(ctx, ev, udt.sqlType) + case _ => + val constRef = ctx.addReferenceObj("literal", value, javaType) + ExprCode.forNonNullValue(JavaCode.global(constRef, dataType)) + } } } + gen(ctx, ev, dataType) } override def sql: String = (value, dataType) match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala index e87b54339821f..f915d6efeb827 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} -import java.time.{Duration, LocalDate, LocalDateTime, Period} +import java.time.{Duration, LocalDate, LocalDateTime, Period, Year => JYear} import java.time.temporal.ChronoUnit import java.util.{Calendar, Locale, TimeZone} @@ -37,6 +37,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes} import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND} +import org.apache.spark.sql.types.TestUDT._ import org.apache.spark.sql.types.UpCastRule.numericPrecedence import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR} import org.apache.spark.unsafe.types.UTF8String @@ -1409,4 +1410,43 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { assert(!Cast(timestampLiteral, TimestampNTZType).resolved) assert(!Cast(timestampNTZLiteral, TimestampType).resolved) } + + test("SPARK-49787: Cast between UDT and other types") { + val value = new MyDenseVector(Array(1.0, 2.0, -1.0)) + val udtType = new MyDenseVectorUDT() + val targetType = ArrayType(DoubleType, containsNull = false) + + val serialized = udtType.serialize(value) + + checkEvaluation(Cast(new Literal(serialized, udtType), targetType), serialized) + checkEvaluation(Cast(new Literal(serialized, targetType), udtType), serialized) + + val year = JYear.parse("2024") + val yearUDTType = new YearUDT() + + val yearSerialized = yearUDTType.serialize(year) + + checkEvaluation(Cast(new Literal(yearSerialized, yearUDTType), IntegerType), 2024) + checkEvaluation(Cast(new Literal(2024, IntegerType), yearUDTType), yearSerialized) + + val yearString = UTF8String.fromString("2024") + checkEvaluation(Cast(new Literal(yearSerialized, yearUDTType), StringType), yearString) + checkEvaluation(Cast(new Literal(yearString, StringType), yearUDTType), yearSerialized) + } +} + +private[sql] class YearUDT extends UserDefinedType[JYear] { + override def sqlType: DataType = IntegerType + + override def serialize(obj: JYear): Int = { + obj.getValue + } + + def deserialize(datum: Any): JYear = 
datum match { + case value: Int => JYear.of(value) + } + + override def userClass: Class[JYear] = classOf[JYear] + + private[spark] override def asNullable: YearUDT = this } From 4d70954b1aeb10767cea82250eb975e2c85f1f3b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 27 Sep 2024 14:21:33 -0700 Subject: [PATCH 110/250] [SPARK-49817][BUILD] Upgrade `gcs-connector` to `2.2.25` ### What changes were proposed in this pull request? This PR aims to upgrade `gcs-connector` to 2.2.25. ### Why are the changes needed? To bring the latest bug fixes. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass the CIs and manual test. ``` $ dev/make-distribution.sh -Phadoop-cloud $ cd dist $ export KEYFILE=~/.ssh/apache-spark.json $ export EMAIL=$(jq -r '.client_email' < $KEYFILE) $ export PRIVATE_KEY_ID=$(jq -r '.private_key_id' < $KEYFILE) $ export PRIVATE_KEY="$(jq -r '.private_key' < $KEYFILE)" $ bin/spark-shell \ -c spark.hadoop.fs.gs.auth.service.account.email=$EMAIL \ -c spark.hadoop.fs.gs.auth.service.account.private.key.id=$PRIVATE_KEY_ID \ -c spark.hadoop.fs.gs.auth.service.account.private.key="$PRIVATE_KEY" WARNING: Using incubator modules: jdk.incubator.vector Using Spark's default log4j profile: org/apache/spark/log4j2-pattern-layout-defaults.properties Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 4.0.0-SNAPSHOT /_/ Using Scala version 2.13.15 (OpenJDK 64-Bit Server VM, Java 21.0.4) Type in expressions to have them evaluated. Type :help for more information. 24/09/27 09:34:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable Spark context Web UI available at http://localhost:4040 Spark context available as 'sc' (master = local[*], app id = local-1727454893738). Spark session available as 'spark'. scala> spark.read.text("gs://apache-spark-bucket/README.md").count() val res0: Long = 124 scala> spark.read.orc("examples/src/main/resources/users.orc").write.mode("overwrite").orc("gs://apache-spark-bucket/users.orc") scala> spark.read.orc("gs://apache-spark-bucket/users.orc").show() +------+--------------+----------------+ | name|favorite_color|favorite_numbers| +------+--------------+----------------+ |Alyssa| NULL| [3, 9, 15, 20]| | Ben| red| []| +------+--------------+----------------+ scala> ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48285 from dongjoon-hyun/SPARK-49817. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 2 +- pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index c9a32757554be..95a667ccfc72d 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -67,7 +67,7 @@ error_prone_annotations/2.26.1//error_prone_annotations-2.26.1.jar esdk-obs-java/3.20.4.2//esdk-obs-java-3.20.4.2.jar failureaccess/1.0.2//failureaccess-1.0.2.jar flatbuffers-java/24.3.25//flatbuffers-java-24.3.25.jar -gcs-connector/hadoop3-2.2.21/shaded/gcs-connector-hadoop3-2.2.21-shaded.jar +gcs-connector/hadoop3-2.2.25/shaded/gcs-connector-hadoop3-2.2.25-shaded.jar gmetric4j/1.0.10//gmetric4j-1.0.10.jar gson/2.11.0//gson-2.11.0.jar guava/33.2.1-jre//guava-33.2.1-jre.jar diff --git a/pom.xml b/pom.xml index 22048b55da27f..4bdb92d86a727 100644 --- a/pom.xml +++ b/pom.xml @@ -161,7 +161,7 @@ 0.12.8 - hadoop3-2.2.21 + hadoop3-2.2.25 4.5.14 4.4.16 From d813f5467e930ed4a22d2ea6aa4333cf379ea7f9 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Fri, 27 Sep 2024 19:41:37 -0400 Subject: [PATCH 111/250] [SPARK-49417][CONNECT][SQL] Add Shared StreamingQueryManager interface ### What changes were proposed in this pull request? This PR adds a shared StreamingQueryManager interface. ### Why are the changes needed? We are working on a shared Scala SQL interface for Classic and Connect. This change is part of this work. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48217 from hvanhovell/SPARK-49417. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../org/apache/spark/sql/SparkSession.scala | 1 + .../sql/streaming/StreamingQueryManager.scala | 93 ++----------- .../apache/spark/sql/api/SparkSession.scala | 11 +- .../spark/sql/api/StreamingQueryManager.scala | 130 ++++++++++++++++++ .../org/apache/spark/sql/SparkSession.scala | 7 +- .../sql/streaming/StreamingQueryManager.scala | 95 ++----------- 6 files changed, 169 insertions(+), 168 deletions(-) create mode 100644 sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQueryManager.scala diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 1b41566ca1d1d..b31670c1da57e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -212,6 +212,7 @@ class SparkSession private[sql] ( /** @inheritdoc */ def readStream: DataStreamReader = new DataStreamReader(this) + /** @inheritdoc */ lazy val streams: StreamingQueryManager = new StreamingQueryManager(this) /** @inheritdoc */ diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 7efced227d6d1..647d29c714dbb 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -27,7 +27,7 @@ import org.apache.spark.connect.proto.Command import 
org.apache.spark.connect.proto.StreamingQueryManagerCommand import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{api, SparkSession} import org.apache.spark.sql.connect.common.InvalidPlanInput /** @@ -36,7 +36,9 @@ import org.apache.spark.sql.connect.common.InvalidPlanInput * @since 3.5.0 */ @Evolving -class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Logging { +class StreamingQueryManager private[sql] (sparkSession: SparkSession) + extends api.StreamingQueryManager + with Logging { // Mapping from id to StreamingQueryListener. There's another mapping from id to // StreamingQueryListener on server side. This is used by removeListener() to find the id @@ -53,29 +55,17 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo streamingQueryListenerBus.close() } - /** - * Returns a list of active queries associated with this SQLContext - * - * @since 3.5.0 - */ + /** @inheritdoc */ def active: Array[StreamingQuery] = { executeManagerCmd(_.setActive(true)).getActive.getActiveQueriesList.asScala.map { q => RemoteStreamingQuery.fromStreamingQueryInstanceResponse(sparkSession, q) }.toArray } - /** - * Returns the query if there is an active query with the given id, or null. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def get(id: UUID): StreamingQuery = get(id.toString) - /** - * Returns the query if there is an active query with the given id, or null. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def get(id: String): StreamingQuery = { val response = executeManagerCmd(_.setGetQuery(id)) if (response.hasQuery) { @@ -85,52 +75,13 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } } - /** - * Wait until any of the queries on the associated SQLContext has terminated since the creation - * of the context, or since `resetTerminated()` was called. If any query was terminated with an - * exception, then the exception will be thrown. - * - * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either - * return immediately (if the query was terminated by `query.stop()`), or throw the exception - * immediately (if the query was terminated with exception). Use `resetTerminated()` to clear - * past terminations and wait for new terminations. - * - * In the case where multiple queries have terminated since `resetTermination()` was called, if - * any query has terminated with exception, then `awaitAnyTermination()` will throw any of the - * exception. For correctly documenting exceptions across multiple queries, users need to stop - * all of them after any of them terminates with exception, and then check the - * `query.exception()` for each query. - * - * @throws StreamingQueryException - * if any query has terminated with an exception - * @since 3.5.0 - */ + /** @inheritdoc */ @throws[StreamingQueryException] def awaitAnyTermination(): Unit = { executeManagerCmd(_.getAwaitAnyTerminationBuilder.build()) } - /** - * Wait until any of the queries on the associated SQLContext has terminated since the creation - * of the context, or since `resetTerminated()` was called. Returns whether any query has - * terminated or not (multiple may have terminated). If any query has terminated with an - * exception, then the exception will be thrown. 
- * - * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either - * return `true` immediately (if the query was terminated by `query.stop()`), or throw the - * exception immediately (if the query was terminated with exception). Use `resetTerminated()` - * to clear past terminations and wait for new terminations. - * - * In the case where multiple queries have terminated since `resetTermination()` was called, if - * any query has terminated with exception, then `awaitAnyTermination()` will throw any of the - * exception. For correctly documenting exceptions across multiple queries, users need to stop - * all of them after any of them terminates with exception, and then check the - * `query.exception()` for each query. - * - * @throws StreamingQueryException - * if any query has terminated with an exception - * @since 3.5.0 - */ + /** @inheritdoc */ @throws[StreamingQueryException] def awaitAnyTermination(timeoutMs: Long): Boolean = { require(timeoutMs > 0, "Timeout has to be positive") @@ -139,40 +90,22 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo timeoutMs)).getAwaitAnyTermination.getTerminated } - /** - * Forget about past terminated queries so that `awaitAnyTermination()` can be used again to - * wait for new terminations. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def resetTerminated(): Unit = { executeManagerCmd(_.setResetTerminated(true)) } - /** - * Register a [[StreamingQueryListener]] to receive up-calls for life cycle events of - * [[StreamingQuery]]. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def addListener(listener: StreamingQueryListener): Unit = { streamingQueryListenerBus.append(listener) } - /** - * Deregister a [[StreamingQueryListener]]. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def removeListener(listener: StreamingQueryListener): Unit = { streamingQueryListenerBus.remove(listener) } - /** - * List all [[StreamingQueryListener]]s attached to this [[StreamingQueryManager]]. - * - * @since 3.5.0 - */ + /** @inheritdoc */ def listListeners(): Array[StreamingQueryListener] = { streamingQueryListenerBus.list() } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala index 0f73a94c3c4a4..4dfeb87a11d92 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala @@ -25,7 +25,7 @@ import _root_.java.lang import _root_.java.net.URI import _root_.java.util -import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable} import org.apache.spark.sql.{Encoder, Row, RuntimeConfig} import org.apache.spark.sql.types.StructType import org.apache.spark.util.SparkClassUtils @@ -93,6 +93,15 @@ abstract class SparkSession extends Serializable with Closeable { */ def udf: UDFRegistration + /** + * Returns a `StreamingQueryManager` that allows managing all the `StreamingQuery`s active on + * `this`. + * + * @since 2.0.0 + */ + @Unstable + def streams: StreamingQueryManager + /** * Start a new session with isolated SQL configurations, temporary tables, registered functions * are isolated, but sharing the underlying `SparkContext` and cached data. 
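For context, here is a minimal usage sketch of the `streams` handle that this shared interface unifies (a hedged example, assuming an active `spark` session; the rate source and console sink are only illustrative and not part of this patch). The manager calls used below (`active`, `get`, `awaitAnyTermination`) are the ones declared on the shared `StreamingQueryManager`:

```scala
// Start a trivial streaming query, then manage it through spark.streams.
val query = spark.readStream.format("rate").load()
  .writeStream.format("console").start()

spark.streams.active.foreach(q => println(q.id))  // enumerate active queries
val same = spark.streams.get(query.id)            // look a query up by its UUID
spark.streams.awaitAnyTermination(1000L)          // bounded wait, returns Boolean
query.stop()
```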
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQueryManager.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQueryManager.scala new file mode 100644 index 0000000000000..88ba9a493d063 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/StreamingQueryManager.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.api + +import _root_.java.util.UUID + +import org.apache.spark.annotation.Evolving +import org.apache.spark.sql.streaming.{StreamingQueryException, StreamingQueryListener} + +/** + * A class to manage all the [[StreamingQuery]] active in a `SparkSession`. + * + * @since 2.0.0 + */ +@Evolving +abstract class StreamingQueryManager { + + /** + * Returns a list of active queries associated with this SQLContext + * + * @since 2.0.0 + */ + def active: Array[_ <: StreamingQuery] + + /** + * Returns the query if there is an active query with the given id, or null. + * + * @since 2.1.0 + */ + def get(id: UUID): StreamingQuery + + /** + * Returns the query if there is an active query with the given id, or null. + * + * @since 2.1.0 + */ + def get(id: String): StreamingQuery + + /** + * Wait until any of the queries on the associated SQLContext has terminated since the creation + * of the context, or since `resetTerminated()` was called. If any query was terminated with an + * exception, then the exception will be thrown. + * + * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either + * return immediately (if the query was terminated by `query.stop()`), or throw the exception + * immediately (if the query was terminated with exception). Use `resetTerminated()` to clear + * past terminations and wait for new terminations. + * + * In the case where multiple queries have terminated since `resetTermination()` was called, if + * any query has terminated with exception, then `awaitAnyTermination()` will throw any of the + * exception. For correctly documenting exceptions across multiple queries, users need to stop + * all of them after any of them terminates with exception, and then check the + * `query.exception()` for each query. + * + * @throws org.apache.spark.sql.streaming.StreamingQueryException + * if any query has terminated with an exception + * @since 2.0.0 + */ + @throws[StreamingQueryException] + def awaitAnyTermination(): Unit + + /** + * Wait until any of the queries on the associated SQLContext has terminated since the creation + * of the context, or since `resetTerminated()` was called. Returns whether any query has + * terminated or not (multiple may have terminated). If any query has terminated with an + * exception, then the exception will be thrown. 
+ * + * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either + * return `true` immediately (if the query was terminated by `query.stop()`), or throw the + * exception immediately (if the query was terminated with exception). Use `resetTerminated()` + * to clear past terminations and wait for new terminations. + * + * In the case where multiple queries have terminated since `resetTermination()` was called, if + * any query has terminated with exception, then `awaitAnyTermination()` will throw any of the + * exception. For correctly documenting exceptions across multiple queries, users need to stop + * all of them after any of them terminates with exception, and then check the + * `query.exception()` for each query. + * + * @throws org.apache.spark.sql.streaming.StreamingQueryException + * if any query has terminated with an exception + * @since 2.0.0 + */ + @throws[StreamingQueryException] + def awaitAnyTermination(timeoutMs: Long): Boolean + + /** + * Forget about past terminated queries so that `awaitAnyTermination()` can be used again to + * wait for new terminations. + * + * @since 2.0.0 + */ + def resetTerminated(): Unit + + /** + * Register a [[org.apache.spark.sql.streaming.StreamingQueryListener]] to receive up-calls for + * life cycle events of [[StreamingQuery]]. + * + * @since 2.0.0 + */ + def addListener(listener: StreamingQueryListener): Unit + + /** + * Deregister a [[org.apache.spark.sql.streaming.StreamingQueryListener]]. + * + * @since 2.0.0 + */ + def removeListener(listener: StreamingQueryListener): Unit + + /** + * List all [[org.apache.spark.sql.streaming.StreamingQueryListener]]s attached to this + * [[StreamingQueryManager]]. + * + * @since 3.0.0 + */ + def listListeners(): Array[StreamingQueryListener] +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 983cc24718fd2..eeb46fbf145d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -229,12 +229,7 @@ class SparkSession private( @Unstable def dataSource: DataSourceRegistration = sessionState.dataSourceRegistration - /** - * Returns a `StreamingQueryManager` that allows managing all the - * `StreamingQuery`s active on `this`. 
- * - * @since 2.0.0 - */ + /** @inheritdoc */ @Unstable def streams: StreamingQueryManager = sessionState.streamingQueryManager diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 9d6fd2e28dea4..42f6d04466b08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -27,7 +27,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Evolving import org.apache.spark.internal.{Logging, MDC} import org.apache.spark.internal.LogKeys.{CLASS_NAME, QUERY_ID, RUN_ID} -import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.{api, Dataset, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.streaming.{WriteToStream, WriteToStreamStatement} import org.apache.spark.sql.connector.catalog.{Identifier, SupportsWrite, Table, TableCatalog} @@ -47,7 +47,9 @@ import org.apache.spark.util.{Clock, SystemClock, Utils} @Evolving class StreamingQueryManager private[sql] ( sparkSession: SparkSession, - sqlConf: SQLConf) extends Logging { + sqlConf: SQLConf) + extends api.StreamingQueryManager + with Logging { private[sql] val stateStoreCoordinator = StateStoreCoordinatorRef.forDriver(sparkSession.sparkContext.env) @@ -70,7 +72,7 @@ class StreamingQueryManager private[sql] ( * failed. The exception is the exception of the last failed query. */ @GuardedBy("awaitTerminationLock") - private var lastTerminatedQueryException: Option[StreamingQueryException] = null + private var lastTerminatedQueryException: Option[StreamingQueryException] = _ try { sparkSession.sparkContext.conf.get(STREAMING_QUERY_LISTENERS).foreach { classNames => @@ -90,51 +92,20 @@ class StreamingQueryManager private[sql] ( throw QueryExecutionErrors.registeringStreamingQueryListenerError(e) } - /** - * Returns a list of active queries associated with this SQLContext - * - * @since 2.0.0 - */ + /** @inheritdoc */ def active: Array[StreamingQuery] = activeQueriesSharedLock.synchronized { activeQueries.values.toArray } - /** - * Returns the query if there is an active query with the given id, or null. - * - * @since 2.1.0 - */ + /** @inheritdoc */ def get(id: UUID): StreamingQuery = activeQueriesSharedLock.synchronized { activeQueries.get(id).orNull } - /** - * Returns the query if there is an active query with the given id, or null. - * - * @since 2.1.0 - */ + /** @inheritdoc */ def get(id: String): StreamingQuery = get(UUID.fromString(id)) - /** - * Wait until any of the queries on the associated SQLContext has terminated since the - * creation of the context, or since `resetTerminated()` was called. If any query was terminated - * with an exception, then the exception will be thrown. - * - * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either - * return immediately (if the query was terminated by `query.stop()`), - * or throw the exception immediately (if the query was terminated with exception). Use - * `resetTerminated()` to clear past terminations and wait for new terminations. - * - * In the case where multiple queries have terminated since `resetTermination()` was called, - * if any query has terminated with exception, then `awaitAnyTermination()` will - * throw any of the exception. 
For correctly documenting exceptions across multiple queries, - * users need to stop all of them after any of them terminates with exception, and then check the - * `query.exception()` for each query. - * - * @throws StreamingQueryException if any query has terminated with an exception - * - * @since 2.0.0 - */ + /** @inheritdoc */ @throws[StreamingQueryException] def awaitAnyTermination(): Unit = { awaitTerminationLock.synchronized { @@ -147,27 +118,7 @@ class StreamingQueryManager private[sql] ( } } - /** - * Wait until any of the queries on the associated SQLContext has terminated since the - * creation of the context, or since `resetTerminated()` was called. Returns whether any query - * has terminated or not (multiple may have terminated). If any query has terminated with an - * exception, then the exception will be thrown. - * - * If a query has terminated, then subsequent calls to `awaitAnyTermination()` will either - * return `true` immediately (if the query was terminated by `query.stop()`), - * or throw the exception immediately (if the query was terminated with exception). Use - * `resetTerminated()` to clear past terminations and wait for new terminations. - * - * In the case where multiple queries have terminated since `resetTermination()` was called, - * if any query has terminated with exception, then `awaitAnyTermination()` will - * throw any of the exception. For correctly documenting exceptions across multiple queries, - * users need to stop all of them after any of them terminates with exception, and then check the - * `query.exception()` for each query. - * - * @throws StreamingQueryException if any query has terminated with an exception - * - * @since 2.0.0 - */ + /** @inheritdoc */ @throws[StreamingQueryException] def awaitAnyTermination(timeoutMs: Long): Boolean = { @@ -187,42 +138,24 @@ class StreamingQueryManager private[sql] ( } } - /** - * Forget about past terminated queries so that `awaitAnyTermination()` can be used again to - * wait for new terminations. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def resetTerminated(): Unit = { awaitTerminationLock.synchronized { lastTerminatedQueryException = null } } - /** - * Register a [[StreamingQueryListener]] to receive up-calls for life cycle events of - * [[StreamingQuery]]. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def addListener(listener: StreamingQueryListener): Unit = { listenerBus.addListener(listener) } - /** - * Deregister a [[StreamingQueryListener]]. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def removeListener(listener: StreamingQueryListener): Unit = { listenerBus.removeListener(listener) } - /** - * List all [[StreamingQueryListener]]s attached to this [[StreamingQueryManager]]. - * - * @since 3.0.0 - */ + /** @inheritdoc */ def listListeners(): Array[StreamingQueryListener] = { listenerBus.listeners.asScala.toArray } From 0c1905951f8c31482b0f5ea334c29c13a83cc3c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B8rn=20J=C3=B8rgensen?= Date: Sat, 28 Sep 2024 08:52:13 +0900 Subject: [PATCH 112/250] [SPARK-49820][PYTHON] Change `raise IOError` to `raise OSError` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Change `raise IOError` to `raise OSError` ### Why are the changes needed? > OSError is the builtin error type used for exceptions that relate to the operating system. > > In Python 3.3, a variety of other exceptions, like WindowsError were aliased to OSError. 
These aliases remain in place for compatibility with older versions of Python, but may be removed in future versions.
>
> Prefer using OSError directly, as it is more idiomatic and future-proof.
>
[RUFF rule](https://docs.astral.sh/ruff/rules/os-error-alias/)
[Python OSError](https://docs.python.org/3/library/exceptions.html#OSError)

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Pass GA.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48287 from bjornjorgensen/IOError-to--OSError.

Authored-by: Bjørn Jørgensen
Signed-off-by: Hyukjin Kwon
---
 python/pyspark/install.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/install.py b/python/pyspark/install.py
index 90b0150b0a8ca..ba67a157e964d 100644
--- a/python/pyspark/install.py
+++ b/python/pyspark/install.py
@@ -163,7 +163,7 @@ def install_spark(dest, spark_version, hadoop_version, hive_version):
         tar.close()
         if os.path.exists(package_local_path):
             os.remove(package_local_path)
-        raise IOError("Unable to download %s." % pretty_pkg_name)
+        raise OSError("Unable to download %s." % pretty_pkg_name)


 def get_preferred_mirrors():

From f9a2077fd32faf63796a68cbb3483b486f220b1c Mon Sep 17 00:00:00 2001
From: Ruifeng Zheng
Date: Sat, 28 Sep 2024 16:21:30 +0900
Subject: [PATCH 113/250] [SPARK-49810][PYTHON] Extract the preparation of
 `DataFrame.sort` to parent class

### What changes were proposed in this pull request?
Extract the preparation of `df.sort` to the parent class.

### Why are the changes needed?
Deduplicates code; the logic in the two classes is nearly identical.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48282 from zhengruifeng/py_sql_sort.
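As a quick illustration of the behavior the shared helper preserves (a hedged sketch, assuming an active `spark` session; the data and column names are made up), the same `sort` calls should resolve identically in PySpark Classic and Connect:

```python
import pyspark.sql.functions as sf

df = spark.createDataFrame([(2, "b"), (1, "a")], ["id", "name"])
df.sort("id").show()                    # by column name
df.sort(sf.col("name").desc()).show()   # by Column expression
df.sort(2, ascending=False).show()      # 1-based ordinal, descending via kwarg
df.sort(-1).show()                      # negative ordinal also sorts descending
```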
Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/classic/dataframe.py | 52 +++-------------------- python/pyspark/sql/connect/dataframe.py | 53 ++--------------------- python/pyspark/sql/dataframe.py | 56 +++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 96 deletions(-) diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index 0dd66a9d86545..9f9dedbd38207 100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -55,6 +55,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync from pyspark.sql.column import Column +from pyspark.sql.functions import builtin as F from pyspark.sql.classic.column import _to_seq, _to_list, _to_java_column from pyspark.sql.readwriter import DataFrameWriter, DataFrameWriterV2 from pyspark.sql.merge import MergeIntoWriter @@ -873,7 +874,8 @@ def sortWithinPartitions( *cols: Union[int, str, Column, List[Union[int, str, Column]]], **kwargs: Any, ) -> ParentDataFrame: - jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs)) + _cols = self._preapare_cols_for_sort(F.col, cols, kwargs) + jdf = self._jdf.sortWithinPartitions(self._jseq(_cols, _to_java_column)) return DataFrame(jdf, self.sparkSession) def sort( @@ -881,7 +883,8 @@ def sort( *cols: Union[int, str, Column, List[Union[int, str, Column]]], **kwargs: Any, ) -> ParentDataFrame: - jdf = self._jdf.sort(self._sort_cols(cols, kwargs)) + _cols = self._preapare_cols_for_sort(F.col, cols, kwargs) + jdf = self._jdf.sort(self._jseq(_cols, _to_java_column)) return DataFrame(jdf, self.sparkSession) orderBy = sort @@ -928,51 +931,6 @@ def _jcols_ordinal(self, *cols: "ColumnOrNameOrOrdinal") -> "JavaObject": _cols.append(c) # type: ignore[arg-type] return self._jseq(_cols, _to_java_column) - def _sort_cols( - self, - cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]], - kwargs: Dict[str, Any], - ) -> "JavaObject": - """Return a JVM Seq of Columns that describes the sort order""" - if not cols: - raise PySparkValueError( - errorClass="CANNOT_BE_EMPTY", - messageParameters={"item": "column"}, - ) - if len(cols) == 1 and isinstance(cols[0], list): - cols = cols[0] - - jcols = [] - for c in cols: - if isinstance(c, int) and not isinstance(c, bool): - # ordinal is 1-based - if c > 0: - _c = self[c - 1] - # negative ordinal means sort by desc - elif c < 0: - _c = self[-c - 1].desc() - else: - raise PySparkIndexError( - errorClass="ZERO_INDEX", - messageParameters={}, - ) - else: - _c = c # type: ignore[assignment] - jcols.append(_to_java_column(cast("ColumnOrName", _c))) - - ascending = kwargs.get("ascending", True) - if isinstance(ascending, (bool, int)): - if not ascending: - jcols = [jc.desc() for jc in jcols] - elif isinstance(ascending, list): - jcols = [jc if asc else jc.desc() for asc, jc in zip(ascending, jcols)] - else: - raise PySparkTypeError( - errorClass="NOT_BOOL_OR_LIST", - messageParameters={"arg_name": "ascending", "arg_type": type(ascending).__name__}, - ) - return self._jseq(jcols) - def describe(self, *cols: Union[str, List[str]]) -> ParentDataFrame: if len(cols) == 1 and isinstance(cols[0], list): cols = cols[0] # type: ignore[assignment] diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 146cfe11bc502..136fe60532df4 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -739,62 +739,16 @@ def limit(self, num: 
int) -> ParentDataFrame: def tail(self, num: int) -> List[Row]: return DataFrame(plan.Tail(child=self._plan, limit=num), session=self._session).collect() - def _sort_cols( - self, - cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]], - kwargs: Dict[str, Any], - ) -> List[Column]: - """Return a JVM Seq of Columns that describes the sort order""" - if cols is None: - raise PySparkValueError( - errorClass="CANNOT_BE_EMPTY", - messageParameters={"item": "cols"}, - ) - - if len(cols) == 1 and isinstance(cols[0], list): - cols = cols[0] - - _cols: List[Column] = [] - for c in cols: - if isinstance(c, int) and not isinstance(c, bool): - # ordinal is 1-based - if c > 0: - _c = self[c - 1] - # negative ordinal means sort by desc - elif c < 0: - _c = self[-c - 1].desc() - else: - raise PySparkIndexError( - errorClass="ZERO_INDEX", - messageParameters={}, - ) - else: - _c = c # type: ignore[assignment] - _cols.append(F._to_col(cast("ColumnOrName", _c))) - - ascending = kwargs.get("ascending", True) - if isinstance(ascending, (bool, int)): - if not ascending: - _cols = [c.desc() for c in _cols] - elif isinstance(ascending, list): - _cols = [c if asc else c.desc() for asc, c in zip(ascending, _cols)] - else: - raise PySparkTypeError( - errorClass="NOT_BOOL_OR_LIST", - messageParameters={"arg_name": "ascending", "arg_type": type(ascending).__name__}, - ) - - return [F._sort_col(c) for c in _cols] - def sort( self, *cols: Union[int, str, Column, List[Union[int, str, Column]]], **kwargs: Any, ) -> ParentDataFrame: + _cols = self._preapare_cols_for_sort(F.col, cols, kwargs) res = DataFrame( plan.Sort( self._plan, - columns=self._sort_cols(cols, kwargs), + columns=[F._sort_col(c) for c in _cols], is_global=True, ), session=self._session, @@ -809,10 +763,11 @@ def sortWithinPartitions( *cols: Union[int, str, Column, List[Union[int, str, Column]]], **kwargs: Any, ) -> ParentDataFrame: + _cols = self._preapare_cols_for_sort(F.col, cols, kwargs) res = DataFrame( plan.Sort( self._plan, - columns=self._sort_cols(cols, kwargs), + columns=[F._sort_col(c) for c in _cols], is_global=False, ), session=self._session, diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 142034583dbd2..5906108163b46 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -2891,6 +2891,62 @@ def sort( """ ... 
+    def _preapare_cols_for_sort(
+        self,
+        _to_col: Callable[[str], Column],
+        cols: Sequence[Union[int, str, Column, List[Union[int, str, Column]]]],
+        kwargs: Dict[str, Any],
+    ) -> Sequence[Column]:
+        from pyspark.errors import PySparkTypeError, PySparkValueError, PySparkIndexError
+
+        if not cols:
+            raise PySparkValueError(
+                errorClass="CANNOT_BE_EMPTY", messageParameters={"item": "cols"}
+            )
+
+        if len(cols) == 1 and isinstance(cols[0], list):
+            cols = cols[0]
+
+        _cols: List[Column] = []
+        for c in cols:
+            if isinstance(c, int) and not isinstance(c, bool):
+                # ordinal is 1-based
+                if c > 0:
+                    _cols.append(self[c - 1])
+                # negative ordinal means sort by desc
+                elif c < 0:
+                    _cols.append(self[-c - 1].desc())
+                else:
+                    raise PySparkIndexError(
+                        errorClass="ZERO_INDEX",
+                        messageParameters={},
+                    )
+            elif isinstance(c, Column):
+                _cols.append(c)
+            elif isinstance(c, str):
+                _cols.append(_to_col(c))
+            else:
+                raise PySparkTypeError(
+                    errorClass="NOT_COLUMN_OR_INT_OR_STR",
+                    messageParameters={
+                        "arg_name": "col",
+                        "arg_type": type(c).__name__,
+                    },
+                )
+
+        ascending = kwargs.get("ascending", True)
+        if isinstance(ascending, (bool, int)):
+            if not ascending:
+                _cols = [c.desc() for c in _cols]
+        elif isinstance(ascending, list):
+            _cols = [c if asc else c.desc() for asc, c in zip(ascending, _cols)]
+        else:
+            raise PySparkTypeError(
+                errorClass="NOT_COLUMN_OR_INT_OR_STR",
+                messageParameters={"arg_name": "ascending", "arg_type": type(ascending).__name__},
+            )
+        return _cols
+
     orderBy = sort

     @dispatch_df_method

From 4c12c78801b8de39020981678ec426af8bea00f3 Mon Sep 17 00:00:00 2001
From: panbingkun
Date: Sat, 28 Sep 2024 16:25:00 +0900
Subject: [PATCH 114/250] [SPARK-49814][CONNECT] When Spark Connect Client
 starts, show the `spark version` of the `connect server`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What changes were proposed in this pull request?
This PR aims to show the Spark version of the Connect server when the Spark Connect client starts.

### Why are the changes needed?
As the Spark Connect module is gradually adopted, explicitly displaying the Spark version of the `connect server` when the client starts reduces confusion: a newer client may expect features that an older server does not provide, and without the version shown users have to troubleshoot such mismatches manually.

[screenshot of the Connect REPL startup banner omitted]

### Does this PR introduce _any_ user-facing change?
Yes, Connect's end-users can immediately see the `Spark version` on the `server side` when starting the client, reducing confusion.

### How was this patch tested?
Manually checked.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48283 from panbingkun/SPARK-49814.

Authored-by: panbingkun
Signed-off-by: Hyukjin Kwon
---
 .../scala/org/apache/spark/sql/application/ConnectRepl.scala | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
index 63fa2821a6c6a..bff6db25a21f2 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/application/ConnectRepl.scala
@@ -50,8 +50,9 @@ object ConnectRepl {
    /_/

 Type in expressions to have them evaluated.
+Spark connect server version %s.
Spark session available as 'spark'. - """.format(spark_version) + """ def main(args: Array[String]): Unit = doMain(args) @@ -102,7 +103,7 @@ Spark session available as 'spark'. // Please note that we make ammonite generate classes instead of objects. // Classes tend to have superior serialization behavior when using UDFs. val main = new ammonite.Main( - welcomeBanner = Option(splash), + welcomeBanner = Option(splash.format(spark_version, spark.version)), predefCode = predefCode, replCodeWrapper = ExtendedCodeClassWrapper, scriptCodeWrapper = ExtendedCodeClassWrapper, From 550c2071bf8e1e740e595ae9321ae11015d77917 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Sat, 28 Sep 2024 16:23:06 -0700 Subject: [PATCH 115/250] [SPARK-49822][SQL][TESTS] Update postgres docker image to 17.0 ### What changes were proposed in this pull request? This PR aims to update the `postgres` docker image from `16.3` to `17.0`. ### Why are the changes needed? This will help Apache Spark test the latest postgres. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48291 from panbingkun/SPARK-49822. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- .../apache/spark/sql/jdbc/PostgresIntegrationSuite.scala | 6 +++--- .../apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala | 6 +++--- .../spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala | 6 +++--- .../spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala | 6 +++--- .../apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala | 6 +++--- .../apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala | 6 +++--- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 3076b599ef4ef..071b976f044c3 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -32,9 +32,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.4-alpine): + * To run this test suite for a specific version (e.g., postgres:17.0-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.4-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.0-alpine * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.PostgresIntegrationSuite" * }}} @@ -42,7 +42,7 @@ import org.apache.spark.tags.DockerTest @DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.4-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:17.0-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala index 5acb6423bbd9b..62f9c6e0256f3 100644 --- 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresKrbIntegrationSuite.scala @@ -25,9 +25,9 @@ import org.apache.spark.sql.execution.datasources.jdbc.connection.SecureConnecti import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.4-alpine): + * To run this test suite for a specific version (e.g., postgres:17.0-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.4-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.0-alpine * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly *PostgresKrbIntegrationSuite" * }}} @@ -38,7 +38,7 @@ class PostgresKrbIntegrationSuite extends DockerKrbJDBCIntegrationSuite { override protected val keytabFileName = "postgres.keytab" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.4-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:17.0-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala index 8d367f476403f..a79bbf39a71b8 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/GeneratedSubquerySuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.tags.DockerTest /** * This suite is used to generate subqueries, and test Spark against Postgres. - * To run this test suite for a specific version (e.g., postgres:16.4-alpine): + * To run this test suite for a specific version (e.g., postgres:17.0-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.4-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.0-alpine * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.GeneratedSubquerySuite" * }}} @@ -39,7 +39,7 @@ import org.apache.spark.tags.DockerTest class GeneratedSubquerySuite extends DockerJDBCIntegrationSuite with QueryGeneratorHelper { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.4-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:17.0-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala index f3a08541365c1..80ba35df6c893 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/querytest/PostgreSQLQueryTestSuite.scala @@ -30,9 +30,9 @@ import org.apache.spark.tags.DockerTest * confidence, and you won't have to manually verify the golden files generated with your test. 
* 2. Add this line to your .sql file: --ONLY_IF spark * - * Note: To run this test suite for a specific version (e.g., postgres:16.4-alpine): + * Note: To run this test suite for a specific version (e.g., postgres:17.0-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.4-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.0-alpine * ./build/sbt -Pdocker-integration-tests * "testOnly org.apache.spark.sql.jdbc.PostgreSQLQueryTestSuite" * }}} @@ -45,7 +45,7 @@ class PostgreSQLQueryTestSuite extends CrossDbmsQueryTestSuite { protected val customInputFilePath: String = new File(inputFilePath, "subquery").getAbsolutePath override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.4-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:17.0-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 850391e8dc33c..6bb415a928837 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.4-alpine) + * To run this test suite for a specific version (e.g., postgres:17.0-alpine) * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.4-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.0-alpine * ./build/sbt -Pdocker-integration-tests "testOnly *v2.PostgresIntegrationSuite" * }}} */ @@ -38,7 +38,7 @@ import org.apache.spark.tags.DockerTest class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "postgresql" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.4-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:17.0-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala index 665746f1d5770..6d4f1cc2fd3fc 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala @@ -26,16 +26,16 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., postgres:16.4-alpine): + * To run this test suite for a specific version (e.g., postgres:17.0-alpine): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:16.4-alpine + * ENABLE_DOCKER_INTEGRATION_TESTS=1 POSTGRES_DOCKER_IMAGE_NAME=postgres:17.0-alpine * ./build/sbt -Pdocker-integration-tests "testOnly 
*v2.PostgresNamespaceSuite" * }}} */ @DockerTest class PostgresNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:16.4-alpine") + override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:17.0-alpine") override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) From 47d2c9ca064e9d80a444d21cfac47ca334230242 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sat, 28 Sep 2024 16:27:13 -0700 Subject: [PATCH 116/250] [SPARK-49712][SQL] Remove encoderFor from connect-client-jvm ### What changes were proposed in this pull request? This PR removes `sql.encoderFor` from the connect-client-jvm module and replaces it by `AgnosticEncoders.agnosticEncoderFor`. ### Why are the changes needed? It will cause a clash when we swap the interface and the implementation. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48266 from hvanhovell/SPARK-49712. Authored-by: Herman van Hovell Signed-off-by: Dongjoon Hyun --- .../main/scala/org/apache/spark/sql/Dataset.scala | 10 +++++----- .../apache/spark/sql/KeyValueGroupedDataset.scala | 14 +++++++------- .../spark/sql/RelationalGroupedDataset.scala | 7 ++++++- .../scala/org/apache/spark/sql/SparkSession.scala | 4 ++-- .../spark/sql/internal/UdfToProtoUtils.scala | 10 +++++----- .../main/scala/org/apache/spark/sql/package.scala | 6 ------ .../apache/spark/sql/SQLImplicitsTestSuite.scala | 3 ++- .../connect/client/arrow/ArrowEncoderSuite.scala | 8 ++++---- 8 files changed, 31 insertions(+), 31 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index d2877ccaf06c9..6bae04ef80231 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -143,7 +143,7 @@ class Dataset[T] private[sql] ( // Make sure we don't forget to set plan id. 
assert(plan.getRoot.getCommon.hasPlanId) - private[sql] val agnosticEncoder: AgnosticEncoder[T] = encoderFor(encoder) + private[sql] val agnosticEncoder: AgnosticEncoder[T] = agnosticEncoderFor(encoder) override def toString: String = { try { @@ -437,7 +437,7 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = { - val encoder = encoderFor(c1.encoder) + val encoder = agnosticEncoderFor(c1.encoder) val col = if (encoder.schema == encoder.dataType) { functions.inline(functions.array(c1)) } else { @@ -452,7 +452,7 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val encoder = ProductEncoder.tuple(columns.map(c => encoderFor(c.encoder))) + val encoder = ProductEncoder.tuple(columns.map(c => agnosticEncoderFor(c.encoder))) selectUntyped(encoder, columns) } @@ -526,7 +526,7 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = { - KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func) + KeyValueGroupedDatasetImpl[K, T](this, agnosticEncoderFor[K], func) } /** @inheritdoc */ @@ -881,7 +881,7 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ def mapPartitions[U: Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = { - val outputEncoder = encoderFor[U] + val outputEncoder = agnosticEncoderFor[U] val udf = SparkUserDefinedFunction( function = func, inputEncoders = agnosticEncoder :: Nil, diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 6bf2518901470..63b5f27c4745e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -25,7 +25,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.api.java.function._ import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, ProductEncoder} import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.UdfUtils import org.apache.spark.sql.expressions.SparkUserDefinedFunction @@ -398,7 +398,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( new KeyValueGroupedDatasetImpl[L, V, IK, IV]( sparkSession, plan, - encoderFor[L], + agnosticEncoderFor[L], ivEncoder, vEncoder, groupingExprs, @@ -412,7 +412,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( plan, kEncoder, ivEncoder, - encoderFor[W], + agnosticEncoderFor[W], groupingExprs, valueMapFunc .map(_.andThen(valueFunc)) @@ -430,7 +430,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( f: (K, Iterator[V]) => IterableOnce[U]): Dataset[U] = { // Apply mapValues changes to the udf val nf = UDFAdaptors.flatMapGroupsWithMappedValues(f, valueMapFunc) - val outputEncoder = encoderFor[U] + val outputEncoder = agnosticEncoderFor[U] sparkSession.newDataset[U](outputEncoder) { builder => builder.getGroupMapBuilder .setInput(plan.getRoot) @@ -446,7 +446,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( val otherImpl = other.asInstanceOf[KeyValueGroupedDatasetImpl[K, U, _, Any]] // Apply mapValues changes to the 
udf val nf = UDFAdaptors.coGroupWithMappedValues(f, valueMapFunc, otherImpl.valueMapFunc) - val outputEncoder = encoderFor[R] + val outputEncoder = agnosticEncoderFor[R] sparkSession.newDataset[R](outputEncoder) { builder => builder.getCoGroupMapBuilder .setInput(plan.getRoot) @@ -461,7 +461,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( override protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { // TODO(SPARK-43415): For each column, apply the valueMap func first... - val rEnc = ProductEncoder.tuple(kEncoder +: columns.map(c => encoderFor(c.encoder))) + val rEnc = ProductEncoder.tuple(kEncoder +: columns.map(c => agnosticEncoderFor(c.encoder))) sparkSession.newDataset(rEnc) { builder => builder.getAggregateBuilder .setInput(plan.getRoot) @@ -501,7 +501,7 @@ private class KeyValueGroupedDatasetImpl[K, V, IK, IV]( null } - val outputEncoder = encoderFor[U] + val outputEncoder = agnosticEncoderFor[U] val nf = UDFAdaptors.flatMapGroupsWithStateWithMappedValues(func, valueMapFunc) sparkSession.newDataset[U](outputEncoder) { builder => diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 14ceb3f4bb144..5bded40b0d132 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor import org.apache.spark.sql.connect.ConnectConversions._ /** @@ -82,7 +83,11 @@ class RelationalGroupedDataset private[sql] ( /** @inheritdoc */ def as[K: Encoder, T: Encoder]: KeyValueGroupedDataset[K, T] = { - KeyValueGroupedDatasetImpl[K, T](df, encoderFor[K], encoderFor[T], groupingExprs) + KeyValueGroupedDatasetImpl[K, T]( + df, + agnosticEncoderFor[K], + agnosticEncoderFor[T], + groupingExprs) } /** @inheritdoc */ diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index b31670c1da57e..222b5ea79508e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -36,7 +36,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BoxedLongEncoder, UnboundRowEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BoxedLongEncoder, UnboundRowEncoder} import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult} import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration import org.apache.spark.sql.connect.client.arrow.ArrowSerializer @@ -136,7 +136,7 @@ class SparkSession private[sql] ( /** @inheritdoc */ def createDataset[T: Encoder](data: Seq[T]): Dataset[T] = { - createDataset(encoderFor[T], data.iterator) + createDataset(agnosticEncoderFor[T], data.iterator) } /** @inheritdoc */ 
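(Editorial illustration, not part of this patch: the removed connect-client helper was a plain unchecked cast of the implicit Encoder, while the call sites above now resolve encoders through `AgnosticEncoders.agnosticEncoderFor`. The two wrapper methods below only mirror the call forms that appear in this diff; their names and the enclosing object are invented for the example.)

```
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor

object EncoderResolutionSketch {
  // The removed helper (see the package.scala hunk further below) was just a cast:
  //   private[sql] def encoderFor[E: Encoder]: AgnosticEncoder[E] =
  //     implicitly[Encoder[E]].asInstanceOf[AgnosticEncoder[E]]

  // New call sites resolve through the shared entry point, either via a context bound ...
  def viaContextBound[T: Encoder]: AgnosticEncoder[T] = agnosticEncoderFor[T]

  // ... or by passing an Encoder value explicitly, as Dataset does for TypedColumn encoders.
  def viaExplicitEncoder[T](enc: Encoder[T]): AgnosticEncoder[T] = agnosticEncoderFor(enc)
}
```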
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala index 85ce2cb820437..409c43f480b8e 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/UdfToProtoUtils.scala @@ -25,9 +25,9 @@ import com.google.protobuf.ByteString import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor import org.apache.spark.sql.connect.common.DataTypeProtoConverter.toConnectProtoType import org.apache.spark.sql.connect.common.UdfPacket -import org.apache.spark.sql.encoderFor import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction} import org.apache.spark.util.{ClosureCleaner, SparkClassUtils, SparkSerDeUtils} @@ -79,12 +79,12 @@ private[sql] object UdfToProtoUtils { udf match { case f: SparkUserDefinedFunction => val outputEncoder = f.outputEncoder - .map(e => encoderFor(e)) + .map(e => agnosticEncoderFor(e)) .getOrElse(RowEncoder.encoderForDataType(f.dataType, lenient = false)) val inputEncoders = if (f.inputEncoders.forall(_.isEmpty)) { Nil // Java UDFs have no bindings for their inputs. } else { - f.inputEncoders.map(e => encoderFor(e.get)) // TODO support Any and UnboundRow. + f.inputEncoders.map(e => agnosticEncoderFor(e.get)) // TODO support Any and UnboundRow. } inputEncoders.foreach(e => protoUdf.addInputTypes(toConnectProtoType(e.dataType))) protoUdf @@ -93,8 +93,8 @@ private[sql] object UdfToProtoUtils { .setAggregate(false) f.givenName.foreach(invokeUdf.setFunctionName) case f: UserDefinedAggregator[_, _, _] => - val outputEncoder = encoderFor(f.aggregator.outputEncoder) - val inputEncoder = encoderFor(f.inputEncoder) + val outputEncoder = agnosticEncoderFor(f.aggregator.outputEncoder) + val inputEncoder = agnosticEncoderFor(f.inputEncoder) protoUdf .setPayload(toUdfPacketBytes(f.aggregator, inputEncoder :: Nil, outputEncoder)) .addInputTypes(toConnectProtoType(inputEncoder.dataType)) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala index 556b472283a37..ada94b76fcbcd 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala @@ -17,12 +17,6 @@ package org.apache.spark -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder - package object sql { type DataFrame = Dataset[Row] - - private[sql] def encoderFor[E: Encoder]: AgnosticEncoder[E] = { - implicitly[Encoder[E]].asInstanceOf[AgnosticEncoder[E]] - } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala index 57342e12fcb51..b3b8020b1e4c7 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SQLImplicitsTestSuite.scala @@ -26,6 +26,7 @@ import org.apache.arrow.memory.RootAllocator import org.apache.commons.lang3.SystemUtils 
import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.agnosticEncoderFor import org.apache.spark.sql.connect.client.SparkConnectClient import org.apache.spark.sql.connect.client.arrow.{ArrowDeserializers, ArrowSerializer} import org.apache.spark.sql.test.ConnectFunSuite @@ -55,7 +56,7 @@ class SQLImplicitsTestSuite extends ConnectFunSuite with BeforeAndAfterAll { import org.apache.spark.util.ArrayImplicits._ import spark.implicits._ def testImplicit[T: Encoder](expected: T): Unit = { - val encoder = encoderFor[T] + val encoder = agnosticEncoderFor[T] val allocator = new RootAllocator() try { val batch = ArrowSerializer.serialize( diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 5397dae9dcc5f..7176c582d0bbc 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -30,11 +30,11 @@ import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector.VarBinaryVector import org.scalatest.BeforeAndAfterAll -import org.apache.spark.{sql, SparkRuntimeException, SparkUnsupportedOperationException} +import org.apache.spark.{SparkRuntimeException, SparkUnsupportedOperationException} import org.apache.spark.sql.{AnalysisException, Encoders, Row} import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec, OuterScopes} -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, RowEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder} +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, RowEncoder, ScalaDecimalEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, YearMonthIntervalEncoder} import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => toRowEncoder} import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkStringUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_SECOND @@ -770,7 +770,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { } test("java serialization") { - 
val encoder = sql.encoderFor(Encoders.javaSerialization[(Int, String)]) + val encoder = agnosticEncoderFor(Encoders.javaSerialization[(Int, String)]) roundTripAndCheckIdentical(encoder) { () => Iterator.tabulate(10)(i => (i, "itr_" + i)) } @@ -778,7 +778,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { test("kryo serialization") { val e = intercept[SparkRuntimeException] { - val encoder = sql.encoderFor(Encoders.kryo[(Int, String)]) + val encoder = agnosticEncoderFor(Encoders.kryo[(Int, String)]) roundTripAndCheckIdentical(encoder) { () => Iterator.tabulate(10)(i => (i, "itr_" + i)) } From 8dfecc1463ff0c2a3a18e7a4409736344c2dc3b8 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Sat, 28 Sep 2024 16:30:15 -0700 Subject: [PATCH 117/250] [SPARK-49434][SPARK-49435][CONNECT][SQL] Move aggregators to sql/api ### What changes were proposed in this pull request? This PR moves all user facing Aggregators from sql/core to sql/api. ### Why are the changes needed? We are create a unifies Scala SQL interface. This is part of that effort. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48267 from hvanhovell/SPARK-49434. Authored-by: Herman van Hovell Signed-off-by: Dongjoon Hyun --- project/MimaExcludes.scala | 5 ++++ .../spark/sql/expressions/javalang/typed.java | 10 +++---- .../sql/expressions/ReduceAggregator.scala | 16 +++++------ .../sql/expressions/scalalang/typed.scala | 4 +-- .../sql/internal}/typedaggregators.scala | 27 +++++++++---------- ...ColumnNodeToExpressionConverterSuite.scala | 2 +- 6 files changed, 31 insertions(+), 33 deletions(-) rename sql/{core => api}/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java (88%) rename sql/{core => api}/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala (82%) rename sql/{core => api}/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala (94%) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/aggregate => api/src/main/scala/org/apache/spark/sql/internal}/typedaggregators.scala (81%) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 41f547a43b698..2b3d76eb0c2c3 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -184,6 +184,11 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.avro.functions$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.protobuf.functions"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.protobuf.functions$"), + + // SPARK-49434: Move aggregators to sql/api + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.javalang.typed"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.typed"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.typed$"), ) ++ loggingExcludes("org.apache.spark.sql.DataFrameReader") ++ loggingExcludes("org.apache.spark.sql.streaming.DataStreamReader") ++ loggingExcludes("org.apache.spark.sql.SparkSession#Builder") diff --git a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java b/sql/api/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java similarity index 88% rename from sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java rename to 
sql/api/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java index e1e4ba4c8e0dc..91a1231ec0303 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java +++ b/sql/api/src/main/java/org/apache/spark/sql/expressions/javalang/typed.java @@ -19,13 +19,13 @@ import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.TypedColumn; -import org.apache.spark.sql.execution.aggregate.TypedAverage; -import org.apache.spark.sql.execution.aggregate.TypedCount; -import org.apache.spark.sql.execution.aggregate.TypedSumDouble; -import org.apache.spark.sql.execution.aggregate.TypedSumLong; +import org.apache.spark.sql.internal.TypedAverage; +import org.apache.spark.sql.internal.TypedCount; +import org.apache.spark.sql.internal.TypedSumDouble; +import org.apache.spark.sql.internal.TypedSumLong; /** - * Type-safe functions available for {@link org.apache.spark.sql.Dataset} operations in Java. + * Type-safe functions available for {@link org.apache.spark.sql.api.Dataset} operations in Java. * * Scala users should use {@link org.apache.spark.sql.expressions.scalalang.typed}. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/api/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala similarity index 82% rename from sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala rename to sql/api/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala index 192b5bf65c4c5..9d98d1a98b00d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala @@ -18,19 +18,17 @@ package org.apache.spark.sql.expressions import org.apache.spark.SparkException -import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder -import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{PrimitiveBooleanEncoder, ProductEncoder} +import org.apache.spark.sql.{Encoder, Encoders} /** * An aggregator that uses a single associative and commutative reduce function. This reduce - * function can be used to go through all input values and reduces them to a single value. - * If there is no input, a null value is returned. + * function can be used to go through all input values and reduces them to a single value. If + * there is no input, a null value is returned. * * This class currently assumes there is at least one input row. 
*/ private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T) - extends Aggregator[T, (Boolean, T), T] { + extends Aggregator[T, (Boolean, T), T] { @transient private val encoder = implicitly[Encoder[T]] @@ -47,10 +45,8 @@ private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T) override def zero: (Boolean, T) = (false, _zero.asInstanceOf[T]) - override def bufferEncoder: Encoder[(Boolean, T)] = { - ProductEncoder.tuple(Seq(PrimitiveBooleanEncoder, encoder.asInstanceOf[AgnosticEncoder[T]])) - .asInstanceOf[Encoder[(Boolean, T)]] - } + override def bufferEncoder: Encoder[(Boolean, T)] = + Encoders.tuple(Encoders.scalaBoolean, encoder) override def outputEncoder: Encoder[T] = encoder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala b/sql/api/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala rename to sql/api/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala index 8d17edd42442e..9ea3ab8cd4e1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.expressions.scalalang -import org.apache.spark.sql._ -import org.apache.spark.sql.execution.aggregate._ +import org.apache.spark.sql.TypedColumn +import org.apache.spark.sql.internal.{TypedAverage, TypedCount, TypedSumDouble, TypedSumLong} /** * Type-safe functions available for `Dataset` operations in Scala. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/typedaggregators.scala similarity index 81% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala rename to sql/api/src/main/scala/org/apache/spark/sql/internal/typedaggregators.scala index b6550bf3e4aac..aabb3a6f00fd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/typedaggregators.scala @@ -15,26 +15,24 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.aggregate +package org.apache.spark.sql.internal import org.apache.spark.api.java.function.MapFunction -import org.apache.spark.sql.{Encoder, TypedColumn} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.{Encoder, Encoders, TypedColumn} import org.apache.spark.sql.expressions.Aggregator //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines internal implementations for aggregators. 
//////////////////////////////////////////////////////////////////////////////////////////////////// - class TypedSumDouble[IN](val f: IN => Double) extends Aggregator[IN, Double, Double] { override def zero: Double = 0.0 override def reduce(b: Double, a: IN): Double = b + f(a) override def merge(b1: Double, b2: Double): Double = b1 + b2 override def finish(reduction: Double): Double = reduction - override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]() - override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]() + override def bufferEncoder: Encoder[Double] = Encoders.scalaDouble + override def outputEncoder: Encoder[Double] = Encoders.scalaDouble // Java api support def this(f: MapFunction[IN, java.lang.Double]) = this((x: IN) => f.call(x).asInstanceOf[Double]) @@ -44,15 +42,14 @@ class TypedSumDouble[IN](val f: IN => Double) extends Aggregator[IN, Double, Dou } } - class TypedSumLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] { override def zero: Long = 0L override def reduce(b: Long, a: IN): Long = b + f(a) override def merge(b1: Long, b2: Long): Long = b1 + b2 override def finish(reduction: Long): Long = reduction - override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]() - override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]() + override def bufferEncoder: Encoder[Long] = Encoders.scalaLong + override def outputEncoder: Encoder[Long] = Encoders.scalaLong // Java api support def this(f: MapFunction[IN, java.lang.Long]) = this((x: IN) => f.call(x).asInstanceOf[Long]) @@ -62,7 +59,6 @@ class TypedSumLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] { } } - class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] { override def zero: Long = 0 override def reduce(b: Long, a: IN): Long = { @@ -71,8 +67,8 @@ class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] { override def merge(b1: Long, b2: Long): Long = b1 + b2 override def finish(reduction: Long): Long = reduction - override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]() - override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]() + override def bufferEncoder: Encoder[Long] = Encoders.scalaLong + override def outputEncoder: Encoder[Long] = Encoders.scalaLong // Java api support def this(f: MapFunction[IN, Object]) = this((x: IN) => f.call(x).asInstanceOf[Any]) @@ -81,7 +77,6 @@ class TypedCount[IN](val f: IN => Any) extends Aggregator[IN, Long, Long] { } } - class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long), Double] { override def zero: (Double, Long) = (0.0, 0L) override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2) @@ -90,8 +85,10 @@ class TypedAverage[IN](val f: IN => Double) extends Aggregator[IN, (Double, Long (b1._1 + b2._1, b1._2 + b2._2) } - override def bufferEncoder: Encoder[(Double, Long)] = ExpressionEncoder[(Double, Long)]() - override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]() + override def bufferEncoder: Encoder[(Double, Long)] = + Encoders.tuple(Encoders.scalaDouble, Encoders.scalaLong) + + override def outputEncoder: Encoder[Double] = Encoders.scalaDouble // Java api support def this(f: MapFunction[IN, java.lang.Double]) = this((x: IN) => f.call(x).asInstanceOf[Double]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala index 
c993aa8e52031..76fcdfc380950 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/ColumnNodeToExpressionConverterSuite.scala @@ -324,7 +324,7 @@ class ColumnNodeToExpressionConverterSuite extends SparkFunSuite { a.asInstanceOf[AgnosticEncoder[Any]] test("udf") { - val int2LongSum = new aggregate.TypedSumLong[Int]((i: Int) => i.toLong) + val int2LongSum = new TypedSumLong[Int]((i: Int) => i.toLong) val bufferEncoder = encoderFor(int2LongSum.bufferEncoder) val outputEncoder = encoderFor(int2LongSum.outputEncoder) val bufferAttrs = bufferEncoder.namedExpressions.map { From 039fd13eacb1cef835045e3a60cebf958589e1a2 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 28 Sep 2024 19:45:52 -0700 Subject: [PATCH 118/250] [SPARK-49749][CORE] Change log level to debug in BlockManagerInfo ### What changes were proposed in this pull request? This PR changes the log level to debug in `BlockManagerInfo`. ### Why are the changes needed? Before this PR: Logging in `BlockManagerMasterEndpoint` uses 3.25% of the CPU and generates 60.5% of the logs. image ``` cat spark.20240921-09.log | grep "in memory on" | wc -l 8587851 cat spark.20240921-09.log | wc -l 14185544 ``` After this PR: image ``` cat spark.20240926-09.log | grep "in memory on" | wc -l 0 cat spark.20240926-09.log | wc -l 2224037 ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? N/A. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48197 from wangyum/SPARK-49749. Authored-by: Yuming Wang Signed-off-by: Dongjoon Hyun --- .../spark/storage/BlockManagerMasterEndpoint.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 73f89ea0e86e5..fc4e6e771aad7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -1059,13 +1059,13 @@ private[spark] class BlockManagerInfo( _blocks.put(blockId, blockStatus) _remainingMem -= memSize if (blockExists) { - logInfo(log"Updated ${MDC(BLOCK_ID, blockId)} in memory on " + + logDebug(log"Updated ${MDC(BLOCK_ID, blockId)} in memory on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} (current size: " + log"${MDC(CURRENT_MEMORY_SIZE, Utils.bytesToString(memSize))}, original " + log"size: ${MDC(ORIGINAL_MEMORY_SIZE, Utils.bytesToString(originalMemSize))}, " + log"free: ${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(_remainingMem))})") } else { - logInfo(log"Added ${MDC(BLOCK_ID, blockId)} in memory on " + + logDebug(log"Added ${MDC(BLOCK_ID, blockId)} in memory on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} " + log"(size: ${MDC(CURRENT_MEMORY_SIZE, Utils.bytesToString(memSize))}, " + log"free: ${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(_remainingMem))})") @@ -1075,12 +1075,12 @@ private[spark] class BlockManagerInfo( blockStatus = BlockStatus(storageLevel, memSize = 0, diskSize = diskSize) _blocks.put(blockId, blockStatus) if (blockExists) { - logInfo(log"Updated ${MDC(BLOCK_ID, blockId)} on disk on " + + logDebug(log"Updated ${MDC(BLOCK_ID, blockId)} on disk on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} " + log"(current size: ${MDC(CURRENT_DISK_SIZE, Utils.bytesToString(diskSize))}," + log" original size: 
${MDC(ORIGINAL_DISK_SIZE, Utils.bytesToString(originalDiskSize))})") } else { - logInfo(log"Added ${MDC(BLOCK_ID, blockId)} on disk on " + + logDebug(log"Added ${MDC(BLOCK_ID, blockId)} on disk on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} (size: " + log"${MDC(CURRENT_DISK_SIZE, Utils.bytesToString(diskSize))})") } @@ -1098,13 +1098,13 @@ private[spark] class BlockManagerInfo( blockStatus.remove(blockId) } if (originalLevel.useMemory) { - logInfo(log"Removed ${MDC(BLOCK_ID, blockId)} on " + + logDebug(log"Removed ${MDC(BLOCK_ID, blockId)} on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} in memory " + log"(size: ${MDC(ORIGINAL_MEMORY_SIZE, Utils.bytesToString(originalMemSize))}, " + log"free: ${MDC(FREE_MEMORY_SIZE, Utils.bytesToString(_remainingMem))})") } if (originalLevel.useDisk) { - logInfo(log"Removed ${MDC(BLOCK_ID, blockId)} on " + + logDebug(log"Removed ${MDC(BLOCK_ID, blockId)} on " + log"${MDC(HOST_PORT, blockManagerId.hostPort)} on disk" + log" (size: ${MDC(ORIGINAL_DISK_SIZE, Utils.bytesToString(originalDiskSize))})") } From 885c3fac724611ca59add984eb0629d32644b56f Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Mon, 30 Sep 2024 15:02:40 +0900 Subject: [PATCH 119/250] [SPARK-49823][SS] Avoid flush during shutdown in rocksdb close path ### What changes were proposed in this pull request? Avoid flush during shutdown in rocksdb close path ### Why are the changes needed? Without this change, we see sometimes that `cancelAllBackgroundWork` gets hung if there are memtables that need to be flushed. We also don't need to flush in this path, because we only assume that sync flush is required in the commit path. ``` at app//org.rocksdb.RocksDB.cancelAllBackgroundWork(Native Method) at app//org.rocksdb.RocksDB.cancelAllBackgroundWork(RocksDB.java:4053) at app//org.apache.spark.sql.execution.streaming.state.RocksDB.closeDB(RocksDB.scala:1406) at app//org.apache.spark.sql.execution.streaming.state.RocksDB.load(RocksDB.scala:383) ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Verified the config is passed manually in the logs and existing unit tests. Before: ``` sql/core/target/unit-tests.log:141:18:20:06.223 pool-1-thread-1-ScalaTest-running-RocksDBSuite INFO RocksDB [Thread-17]: [NativeRocksDB-1] Options.avoid_flush_during_shutdown: 0 sql/core/target/unit-tests.log:776:18:20:06.871 pool-1-thread-1-ScalaTest-running-RocksDBSuite INFO RocksDB [Thread-17]: [NativeRocksDB-1] Options.avoid_flush_during_shutdown: 0 sql/core/target/unit-tests.log:1096:18:20:07.129 pool-1-thread-1-ScalaTest-running-RocksDBSuite INFO RocksDB [Thread-17]: [NativeRocksDB-1] Options.avoid_flush_during_shutdown: 0 ``` After: ``` sql/core/target/unit-tests.log:6561:18:17:42.723 pool-1-thread-1-ScalaTest-running-RocksDBSuite INFO RocksDB [Thread-17]: [NativeRocksDB-1] Options.avoid_flush_during_shutdown: 1 sql/core/target/unit-tests.log:6947:18:17:43.035 pool-1-thread-1-ScalaTest-running-RocksDBSuite INFO RocksDB [Thread-17]: [NativeRocksDB-1] Options.avoid_flush_during_shutdown: 1 sql/core/target/unit-tests.log:7344:18:17:43.313 pool-1-thread-1-ScalaTest-running-RocksDBSuite INFO RocksDB [Thread-17]: [NativeRocksDB-1] Options.avoid_flush_during_shutdown: 1 ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #48292 from anishshri-db/task/SPARK-49823. 
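For readers less familiar with the RocksDB JNI API, the sketch below shows the option this patch enables, outside of Spark's own `RocksDB` wrapper. It is a minimal, self-contained example; the object name, database path, and key/value bytes are made up for illustration.

```
import org.rocksdb.{Options, RocksDB}

object AvoidFlushOnCloseSketch {
  def main(args: Array[String]): Unit = {
    RocksDB.loadLibrary()
    val options = new Options()
      .setCreateIfMissing(true)
      // The flag this patch turns on: skip flushing live memtables while the DB shuts
      // down, so cancelAllBackgroundWork() cannot block waiting for a pending flush.
      .setAvoidFlushDuringShutdown(true)
    val db = RocksDB.open(options, "/tmp/rocksdb-avoid-flush-sketch") // illustrative path
    try {
      db.put("key".getBytes("UTF-8"), "value".getBytes("UTF-8"))
      // No flush here: in Spark's case durability comes from the explicit sync flush in
      // the commit path, not from close().
    } finally {
      db.close()
      options.close()
    }
  }
}
```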
Authored-by: Anish Shrigondekar Signed-off-by: Jungtaek Lim --- .../org/apache/spark/sql/execution/streaming/state/RocksDB.scala | 1 + 1 file changed, 1 insertion(+)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index f8d0c8722c3f5..c7f8434e5345b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -134,6 +134,7 @@ class RocksDB( rocksDbOptions.setTableFormatConfig(tableFormatConfig) rocksDbOptions.setMaxOpenFiles(conf.maxOpenFiles) rocksDbOptions.setAllowFAllocate(conf.allowFAllocate) + rocksDbOptions.setAvoidFlushDuringShutdown(true) rocksDbOptions.setMergeOperator(new StringAppendOperator()) if (conf.boundedMemoryUsage) {
From d85e7bc0beb49dd1d894d487cf6a5a02075280dd Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Mon, 30 Sep 2024 17:41:42 +0800 Subject: [PATCH 120/250] [SPARK-49811][SQL] Rename StringTypeAnyCollation
### What changes were proposed in this pull request? Rename StringTypeAnyCollation to StringTypeWithCaseAccentSensitivity. The name StringTypeAnyCollation is unfortunate: whenever a new kind of collation specifier is added, the type has to be renamed.
### Why are the changes needed? The name StringTypeAnyCollation is unfortunate: each time a new specifier is added (for example, a trim specifier), the type would have to be renamed again (to something like AllCollationsExceptTrimCollation) until the new collation is implemented in all functions. It gets even more confusing when several collations are unsupported for some functions. Instead, the name should list only the specifiers that are actually supported rather than claiming to cover all of them.
### Does this PR introduce _any_ user-facing change? No.
### How was this patch tested? This is a pure rename; all existing tests pass.
### Was this patch authored or co-authored using generative AI tooling? No.
Closes #48265 from jovanpavl-db/rename-string-type-collations.
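As an illustration of the pattern this rename touches, the toy expression below (invented for the example, not part of the patch) declares that it accepts string input under any case- or accent-sensitivity collation simply by listing the abstract type in its `inputTypes`:

```
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity
import org.apache.spark.sql.types.{AbstractDataType, DataType}

// Toy expression: echoes its input, but only accepts (possibly collated) string input,
// which is exactly what StringTypeWithCaseAccentSensitivity expresses.
case class EchoString(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback {
  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity)
  override def dataType: DataType = child.dataType // preserve the input's collation
  override protected def nullSafeEval(input: Any): Any = input
  override protected def withNewChildInternal(newChild: Expression): EchoString =
    copy(child = newChild)
}
```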
Authored-by: Jovan Pavlovic Signed-off-by: Wenchen Fan --- .../internal/types/AbstractStringType.scala | 7 +- .../sql/catalyst/analysis/TypeCoercion.scala | 5 +- .../expressions/CallMethodViaReflection.scala | 9 +- .../catalyst/expressions/CollationKey.scala | 4 +- .../sql/catalyst/expressions/ExprUtils.scala | 5 +- .../aggregate/datasketchesAggregates.scala | 6 +- .../expressions/collationExpressions.scala | 6 +- .../expressions/collectionOperations.scala | 13 ++- .../catalyst/expressions/csvExpressions.scala | 4 +- .../expressions/datetimeExpressions.scala | 41 ++++--- .../expressions/jsonExpressions.scala | 14 ++- .../expressions/maskExpressions.scala | 10 +- .../expressions/mathExpressions.scala | 8 +- .../spark/sql/catalyst/expressions/misc.scala | 13 ++- .../expressions/numberFormatExpressions.scala | 7 +- .../expressions/regexpExpressions.scala | 18 +-- .../expressions/stringExpressions.scala | 103 ++++++++++-------- .../catalyst/expressions/urlExpressions.scala | 13 ++- .../variant/variantExpressions.scala | 7 +- .../sql/catalyst/expressions/xml/xpath.scala | 6 +- .../catalyst/expressions/xmlExpressions.scala | 4 +- .../analysis/AnsiTypeCoercionSuite.scala | 20 ++-- .../expressions/StringExpressionsSuite.scala | 4 +- .../sql/CollationExpressionWalkerSuite.scala | 51 +++++---- 24 files changed, 218 insertions(+), 160 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala index dc4ee013fd189..6feb662632763 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/AbstractStringType.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} /** - * StringTypeCollated is an abstract class for StringType with collation support. + * AbstractStringType is an abstract class for StringType with collation support. */ abstract class AbstractStringType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType @@ -46,9 +46,10 @@ case object StringTypeBinaryLcase extends AbstractStringType { } /** - * Use StringTypeAnyCollation for expressions supporting all possible collation types. + * Use StringTypeWithCaseAccentSensitivity for expressions supporting all collation types (binary + * and ICU) but limited to using case and accent sensitivity specifiers. 
*/ -case object StringTypeAnyCollation extends AbstractStringType { +case object StringTypeWithCaseAccentSensitivity extends AbstractStringType { override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 5983346ff1e27..e0298b19931c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -32,7 +32,8 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, + StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types._ import org.apache.spark.sql.types.UpCastRule.numericPrecedence @@ -438,7 +439,7 @@ abstract class TypeCoercionBase { } case aj @ ArrayJoin(arr, d, nr) - if !AbstractArrayType(StringTypeAnyCollation).acceptsType(arr.dataType) && + if !AbstractArrayType(StringTypeWithCaseAccentSensitivity).acceptsType(arr.dataType) && ArrayType.acceptsType(arr.dataType) => val containsNull = arr.dataType.asInstanceOf[ArrayType].containsNull implicitCast(arr, ArrayType(StringType, containsNull)) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index 13ea8c77c41b4..6aa11b6fd16df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -84,7 +84,7 @@ case class CallMethodViaReflection( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("class"), - "inputType" -> toSQLType(StringTypeAnyCollation), + "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), "inputExpr" -> toSQLExpr(children.head) ) ) @@ -97,7 +97,7 @@ case class CallMethodViaReflection( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("method"), - "inputType" -> toSQLType(StringTypeAnyCollation), + "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), "inputExpr" -> toSQLExpr(children(1)) ) ) @@ -114,7 +114,8 @@ case class CallMethodViaReflection( "paramIndex" -> ordinalNumber(idx), "requiredType" -> toSQLType( TypeCollection(BooleanType, ByteType, ShortType, - IntegerType, LongType, 
FloatType, DoubleType, StringTypeAnyCollation)), + IntegerType, LongType, FloatType, DoubleType, + StringTypeWithCaseAccentSensitivity)), "inputSql" -> toSQLExpr(e), "inputType" -> toSQLType(e.dataType)) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala index 6e400d026e0ee..28ec8482e5cdd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.util.CollationFactory -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType = BinaryType final lazy val collationId: Int = expr.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 749152f135e92..08cb03edb78b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} -import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType} import org.apache.spark.unsafe.types.UTF8String @@ -61,7 +61,8 @@ object ExprUtils extends QueryErrorsBase { def convertToMapData(exp: Expression): Map[String, String] = exp match { case m: CreateMap - if AbstractMapType(StringTypeAnyCollation, StringTypeAnyCollation).acceptsType(m.dataType) => + if AbstractMapType(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) + .acceptsType(m.dataType) => val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => key.toString -> value.toString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala index 2102428131f64..78bd02d5703cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala @@ -27,7 
+27,7 @@ import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types.{AbstractDataType, BinaryType, BooleanType, DataType, IntegerType, LongType, StringType, TypeCollection} import org.apache.spark.unsafe.types.UTF8String @@ -105,7 +105,9 @@ case class HllSketchAgg( override def prettyName: String = "hll_sketch_agg" override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(IntegerType, LongType, StringTypeAnyCollation, BinaryType), IntegerType) + Seq( + TypeCollection(IntegerType, LongType, StringTypeWithCaseAccentSensitivity, BinaryType), + IntegerType) override def dataType: DataType = BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index d45ca533f9392..0cff70436db7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ // scalastyle:off line.contains.tab @@ -73,7 +73,7 @@ case class Collate(child: Expression, collationName: String) extends UnaryExpression with ExpectsInputTypes { private val collationId = CollationFactory.collationNameToId(collationName) override def dataType: DataType = StringType(collationId) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override protected def withNewChildInternal( newChild: Expression): Expression = copy(newChild) @@ -111,5 +111,5 @@ case class Collation(child: Expression) val collationName = CollationFactory.fetchCollation(collationId).collationName Literal.create(collationName, SQLConf.get.defaultStringType) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 5cdd3c7eb62d1..c091d51fc177f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import 
org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SQLOpenHashSet import org.apache.spark.unsafe.UTF8StringBuilder @@ -1348,7 +1348,7 @@ case class Reverse(child: Expression) // Input types are utilized by type coercion in ImplicitTypeCasts. override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, ArrayType)) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, ArrayType)) override def dataType: DataType = child.dataType @@ -2134,9 +2134,12 @@ case class ArrayJoin( this(array, delimiter, Some(nullReplacement)) override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) { - Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation, StringTypeAnyCollation) + Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity), + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity) } else { - Seq(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation) + Seq(AbstractArrayType(StringTypeWithCaseAccentSensitivity), + StringTypeWithCaseAccentSensitivity) } override def children: Seq[Expression] = if (nullReplacement.isDefined) { @@ -2857,7 +2860,7 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio with QueryErrorsBase { private def allowedTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, BinaryType, ArrayType) + Seq(StringTypeWithCaseAccentSensitivity, BinaryType, ArrayType) final override val nodePatterns: Seq[TreePattern] = Seq(CONCAT) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index cb10440c48328..2f4462c0664f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.TypeUtils._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -147,7 +147,7 @@ case class CsvToStructs( converter(parser.parse(csv)) } - override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil override def prettyName: String = "from_csv" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 36bd53001594e..b166d235557fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ import org.apache.spark.sql.catalyst.util.LegacyDateFormats.SIMPLE_DATE_FORMAT import 
org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType.DAY import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -961,7 +961,8 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti override def dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(TimestampType, StringTypeWithCaseAccentSensitivity) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) @@ -1269,8 +1270,10 @@ abstract class ToTimestamp override def forTimestampNTZ: Boolean = left.dataType == TimestampNTZType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType), - StringTypeAnyCollation) + Seq(TypeCollection( + StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType + ), + StringTypeWithCaseAccentSensitivity) override def dataType: DataType = LongType override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true @@ -1441,7 +1444,8 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(LongType, StringTypeWithCaseAccentSensitivity) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) @@ -1549,7 +1553,8 @@ case class NextDay( def this(left: Expression, right: Expression) = this(left, right, SQLConf.get.ansiEnabled) - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(DateType, StringTypeWithCaseAccentSensitivity) override def dataType: DataType = DateType override def nullable: Boolean = true @@ -1760,7 +1765,8 @@ sealed trait UTCTimestamp extends BinaryExpression with ImplicitCastInputTypes w val func: (Long, String) => Long val funcName: String - override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(TimestampType, StringTypeWithCaseAccentSensitivity) override def dataType: DataType = TimestampType override def nullSafeEval(time: Any, timezone: Any): Any = { @@ -2100,8 +2106,9 @@ case class ParseToDate( override def inputTypes: Seq[AbstractDataType] = { // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. 
- TypeCollection(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType) +: - format.map(_ => StringTypeAnyCollation).toSeq + TypeCollection( + StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType) +: + format.map(_ => StringTypeWithCaseAccentSensitivity).toSeq } override protected def withNewChildrenInternal( @@ -2172,10 +2179,10 @@ case class ParseToTimestamp( override def inputTypes: Seq[AbstractDataType] = { // Note: ideally this function should only take string input, but we allow more types here to // be backward compatible. - val types = Seq(StringTypeAnyCollation, DateType, TimestampType, TimestampNTZType) + val types = Seq(StringTypeWithCaseAccentSensitivity, DateType, TimestampType, TimestampNTZType) TypeCollection( (if (dataType.isInstanceOf[TimestampType]) types :+ NumericType else types): _* - ) +: format.map(_ => StringTypeAnyCollation).toSeq + ) +: format.map(_ => StringTypeWithCaseAccentSensitivity).toSeq } override protected def withNewChildrenInternal( @@ -2305,7 +2312,8 @@ case class TruncDate(date: Expression, format: Expression) override def left: Expression = date override def right: Expression = format - override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(DateType, StringTypeWithCaseAccentSensitivity) override def dataType: DataType = DateType override def prettyName: String = "trunc" override val instant = date @@ -2374,7 +2382,8 @@ case class TruncTimestamp( override def left: Expression = format override def right: Expression = timestamp - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, TimestampType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity, TimestampType) override def dataType: TimestampType = TimestampType override def prettyName: String = "date_trunc" override val instant = timestamp @@ -2675,7 +2684,7 @@ case class MakeTimestamp( // casted into decimal safely, we use DecimalType(16, 6) which is wider than DecimalType(10, 0). 
override def inputTypes: Seq[AbstractDataType] = Seq(IntegerType, IntegerType, IntegerType, IntegerType, IntegerType, DecimalType(16, 6)) ++ - timezone.map(_ => StringTypeAnyCollation) + timezone.map(_ => StringTypeWithCaseAccentSensitivity) override def nullable: Boolean = if (failOnError) children.exists(_.nullable) else true override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = @@ -3122,8 +3131,8 @@ case class ConvertTimezone( override def second: Expression = targetTz override def third: Expression = sourceTs - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, - StringTypeAnyCollation, TimestampNTZType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity, TimestampNTZType) override def dataType: DataType = TimestampNTZType override def nullSafeEval(srcTz: Any, tgtTz: Any, micros: Any): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 2037eb22fede6..bdcf3f0c1eeab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePatt import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{UTF8String, VariantVal} import org.apache.spark.util.Utils @@ -134,7 +134,7 @@ case class GetJsonObject(json: Expression, path: Expression) override def left: Expression = json override def right: Expression = path override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def prettyName: String = "get_json_object" @@ -489,7 +489,9 @@ case class JsonTuple(children: Seq[Expression]) throw QueryCompilationErrors.wrongNumArgsError( toSQLId(prettyName), Seq("> 1"), children.length ) - } else if (children.forall(child => StringTypeAnyCollation.acceptsType(child.dataType))) { + } else if ( + children.forall( + child => StringTypeWithCaseAccentSensitivity.acceptsType(child.dataType))) { TypeCheckResult.TypeCheckSuccess } else { DataTypeMismatch( @@ -726,7 +728,7 @@ case class JsonToStructs( converter(parser.parse(json.asInstanceOf[UTF8String])) } - override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil override def sql: String = schema match { case _: MapType => "entries" @@ -968,7 +970,7 @@ case class SchemaOfJson( case class LengthOfJsonArray(child: Expression) extends UnaryExpression with CodegenFallback with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType 
= IntegerType override def nullable: Boolean = true override def prettyName: String = "json_array_length" @@ -1041,7 +1043,7 @@ case class LengthOfJsonArray(child: Expression) extends UnaryExpression case class JsonObjectKeys(child: Expression) extends UnaryExpression with CodegenFallback with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType) override def nullable: Boolean = true override def prettyName: String = "json_object_keys" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala index c11357352c79a..cb62fa2cc3bd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter} import org.apache.spark.sql.errors.QueryErrorsBase import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types.{AbstractDataType, DataType} import org.apache.spark.unsafe.types.UTF8String @@ -192,8 +192,12 @@ case class Mask( * NumericType, IntegralType, FractionalType. */ override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation, - StringTypeAnyCollation, StringTypeAnyCollation) + Seq( + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity) override def nullable: Boolean = true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index ddba820414ae4..e46acf467db22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -453,7 +453,7 @@ case class Conv( override def second: Expression = fromBaseExpr override def third: Expression = toBaseExpr override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, IntegerType, IntegerType) + Seq(StringTypeWithCaseAccentSensitivity, IntegerType, IntegerType) override def dataType: DataType = first.dataType override def nullable: Boolean = true @@ -1114,7 +1114,7 @@ case class Hex(child: Expression) extends UnaryExpression with 
ImplicitCastInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, BinaryType, StringTypeAnyCollation)) + Seq(TypeCollection(LongType, BinaryType, StringTypeWithCaseAccentSensitivity)) override def dataType: DataType = child.dataType match { case st: StringType => st @@ -1158,7 +1158,7 @@ case class Unhex(child: Expression, failOnError: Boolean = false) def this(expr: Expression) = this(expr, false) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nullable: Boolean = true override def dataType: DataType = BinaryType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 6629f724c4dda..cb846f606632b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -85,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType: override def foldable: Boolean = false override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, MapType(StringType, StringType)) + Seq(StringTypeWithCaseAccentSensitivity, MapType(StringType, StringType)) override def left: Expression = errorClass override def right: Expression = errorParms @@ -415,7 +415,9 @@ case class AesEncrypt( override def prettyName: String = "aes_encrypt" override def inputTypes: Seq[AbstractDataType] = - Seq(BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, + Seq(BinaryType, BinaryType, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, BinaryType, BinaryType) override def children: Seq[Expression] = Seq(input, key, mode, padding, iv, aad) @@ -489,7 +491,10 @@ case class AesDecrypt( this(input, key, Literal("GCM")) override def inputTypes: Seq[AbstractDataType] = { - Seq(BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType) + Seq(BinaryType, + BinaryType, + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, BinaryType) } override def prettyName: String = "aes_decrypt" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index e914190c06456..5bd2ab6035e10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.util.ToNumberParser import org.apache.spark.sql.errors.QueryCompilationErrors import 
org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, DatetimeType, Decimal, DecimalType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -50,7 +50,7 @@ abstract class ToNumberBase(left: Expression, right: Expression, errorOnFail: Bo } override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() @@ -284,7 +284,8 @@ case class ToCharacter(left: Expression, right: Expression) } override def dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(DecimalType, StringTypeWithCaseAccentSensitivity) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() if (inputTypeCheck.isSuccess) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 970397c76a1cd..fdc3c27890469 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -35,7 +35,8 @@ import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{LIKE_FAMLIY, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, TreePattern} import org.apache.spark.sql.catalyst.util.{CollationSupport, GenericArrayData, StringUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.internal.types.{StringTypeAnyCollation, StringTypeBinaryLcase} +import org.apache.spark.sql.internal.types.{ + StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -46,7 +47,7 @@ abstract class StringRegexExpression extends BinaryExpression def matches(regex: Pattern, str: String): Boolean override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId final lazy val collationRegexFlags: Int = CollationSupport.collationAwareRegexFlags(collationId) @@ -278,7 +279,7 @@ case class ILike( this(left, right, '\\') override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Expression = { @@ -567,7 +568,7 @@ case class StringSplit(str: Expression, regex: Expression, limit: Expression) override def dataType: DataType = ArrayType(str.dataType, containsNull = false) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity, IntegerType) override def first: Expression = str 
override def second: Expression = regex override def third: Expression = limit @@ -711,7 +712,8 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def dataType: DataType = subject.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation, StringTypeBinaryLcase, IntegerType) + Seq(StringTypeBinaryLcase, + StringTypeWithCaseAccentSensitivity, StringTypeBinaryLcase, IntegerType) final lazy val collationId: Int = subject.dataType.asInstanceOf[StringType].collationId override def prettyName: String = "regexp_replace" @@ -799,7 +801,7 @@ abstract class RegExpExtractBase final override val nodePatterns: Seq[TreePattern] = Seq(REGEXP_EXTRACT_FAMILY) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation, IntegerType) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity, IntegerType) override def first: Expression = subject override def second: Expression = regexp override def third: Expression = idx @@ -1052,7 +1054,7 @@ case class RegExpCount(left: Expression, right: Expression) override def children: Seq[Expression] = Seq(left, right) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): RegExpCount = @@ -1092,7 +1094,7 @@ case class RegExpSubStr(left: Expression, right: Expression) override def children: Seq[Expression] = Seq(left, right) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeBinaryLcase, StringTypeAnyCollation) + Seq(StringTypeBinaryLcase, StringTypeWithCaseAccentSensitivity) override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): RegExpSubStr = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 786c3968be0fe..c91c57ee1eb3e 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -38,7 +38,8 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LO import org.apache.spark.sql.catalyst.util.{ArrayData, CharsetProvider, CollationFactory, CollationSupport, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation, StringTypeNonCSAICollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, + StringTypeNonCSAICollation, StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.array.ByteArrayMethods @@ -81,8 +82,10 @@ case class ConcatWs(children: Seq[Expression]) /** The 1st child (separator) is str, and rest are either str or array of str. 
*/ override def inputTypes: Seq[AbstractDataType] = { val arrayOrStr = - TypeCollection(AbstractArrayType(StringTypeAnyCollation), StringTypeAnyCollation) - StringTypeAnyCollation +: Seq.fill(children.size - 1)(arrayOrStr) + TypeCollection(AbstractArrayType(StringTypeWithCaseAccentSensitivity), + StringTypeWithCaseAccentSensitivity + ) + StringTypeWithCaseAccentSensitivity +: Seq.fill(children.size - 1)(arrayOrStr) } override def dataType: DataType = children.head.dataType @@ -433,7 +436,7 @@ trait String2StringExpression extends ImplicitCastInputTypes { def convert(v: UTF8String): UTF8String override def dataType: DataType = child.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) protected override def nullSafeEval(input: Any): Any = convert(input.asInstanceOf[UTF8String]) @@ -515,7 +518,7 @@ abstract class StringPredicate extends BinaryExpression def compare(l: UTF8String, r: UTF8String): Boolean override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) protected override def nullSafeEval(input1: Any, input2: Any): Any = compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String]) @@ -732,7 +735,7 @@ case class IsValidUTF8(input: Expression) extends RuntimeReplaceable with Implic override lazy val replacement: Expression = Invoke(input, "isValid", BooleanType) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nodeName: String = "is_valid_utf8" @@ -779,7 +782,7 @@ case class MakeValidUTF8(input: Expression) extends RuntimeReplaceable with Impl override lazy val replacement: Expression = Invoke(input, "makeValid", input.dataType) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nodeName: String = "make_valid_utf8" @@ -824,7 +827,7 @@ case class ValidateUTF8(input: Expression) extends RuntimeReplaceable with Impli Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nodeName: String = "validate_utf8" @@ -873,7 +876,7 @@ case class TryValidateUTF8(input: Expression) extends RuntimeReplaceable with Im Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nodeName: String = "try_validate_utf8" @@ -1008,8 +1011,8 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len: override def dataType: DataType = input.dataType override def inputTypes: Seq[AbstractDataType] = Seq( - TypeCollection(StringTypeAnyCollation, BinaryType), - TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType, IntegerType) + TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), + TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), IntegerType, IntegerType) override def checkInputDataTypes(): TypeCheckResult = { val inputTypeCheck = super.checkInputDataTypes() @@ -1213,7 +1216,7 @@ case class FindInSet(left: Expression, right: 
Expression) extends BinaryExpressi final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override protected def nullSafeEval(word: Any, set: Any): Any = { CollationSupport.FindInSet. @@ -1241,7 +1244,8 @@ trait String2TrimExpression extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = srcStr +: trimStr.toSeq override def dataType: DataType = srcStr.dataType - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq.fill(children.size)(StringTypeWithCaseAccentSensitivity) final lazy val collationId: Int = srcStr.dataType.asInstanceOf[StringType].collationId @@ -1846,7 +1850,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, IntegerType, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, IntegerType, StringTypeWithCaseAccentSensitivity) override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { string.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) @@ -1926,7 +1930,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression = Litera override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, IntegerType, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, IntegerType, StringTypeWithCaseAccentSensitivity) override def nullSafeEval(string: Any, len: Any, pad: Any): Any = { string.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) @@ -1971,7 +1975,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC override def dataType: DataType = children(0).dataType override def inputTypes: Seq[AbstractDataType] = - StringTypeAnyCollation :: List.fill(children.size - 1)(AnyDataType) + StringTypeWithCaseAccentSensitivity :: List.fill(children.size - 1)(AnyDataType) override def checkInputDataTypes(): TypeCheckResult = { if (children.isEmpty) { @@ -2082,7 +2086,7 @@ case class InitCap(child: Expression) // Flag to indicate whether to use ICU instead of JVM case mappings for UTF8_BINARY collation. 
private final lazy val useICU = SQLConf.get.getConf(SQLConf.ICU_CASE_MAPPINGS_ENABLED) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType = child.dataType override def nullSafeEval(string: Any): Any = { @@ -2114,7 +2118,8 @@ case class StringRepeat(str: Expression, times: Expression) override def left: Expression = str override def right: Expression = times override def dataType: DataType = str.dataType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity, IntegerType) override def nullSafeEval(string: Any, n: Any): Any = { string.asInstanceOf[UTF8String].repeat(n.asInstanceOf[Integer]) @@ -2207,7 +2212,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def dataType: DataType = str.dataType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType, IntegerType) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), IntegerType, IntegerType) override def first: Expression = str override def second: Expression = pos @@ -2265,7 +2270,8 @@ case class Right(str: Expression, len: Expression) extends RuntimeReplaceable ) ) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation, IntegerType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeWithCaseAccentSensitivity, IntegerType) override def left: Expression = str override def right: Expression = len override protected def withNewChildrenInternal( @@ -2296,7 +2302,7 @@ case class Left(str: Expression, len: Expression) extends RuntimeReplaceable override lazy val replacement: Expression = Substring(str, Literal(1), len) override def inputTypes: Seq[AbstractDataType] = { - Seq(TypeCollection(StringTypeAnyCollation, BinaryType), IntegerType) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType), IntegerType) } override def left: Expression = str @@ -2332,7 +2338,7 @@ case class Length(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, BinaryType)) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType)) protected override def nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numChars @@ -2367,7 +2373,7 @@ case class BitLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, BinaryType)) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType)) protected override def nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numBytes * 8 @@ -2406,7 +2412,7 @@ case class OctetLength(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(StringTypeAnyCollation, BinaryType)) + Seq(TypeCollection(StringTypeWithCaseAccentSensitivity, BinaryType)) protected override def 
nullSafeEval(value: Any): Any = child.dataType match { case _: StringType => value.asInstanceOf[UTF8String].numBytes @@ -2466,8 +2472,9 @@ case class Levenshtein( } override def inputTypes: Seq[AbstractDataType] = threshold match { - case Some(_) => Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType) - case _ => Seq(StringTypeAnyCollation, StringTypeAnyCollation) + case Some(_) => + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity, IntegerType) + case _ => Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) } override def children: Seq[Expression] = threshold match { @@ -2592,7 +2599,7 @@ case class SoundEx(child: Expression) override def dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def nullSafeEval(input: Any): Any = input.asInstanceOf[UTF8String].soundex() @@ -2622,7 +2629,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = IntegerType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) protected override def nullSafeEval(string: Any): Any = { // only pick the first character to reduce the `toString` cost @@ -2767,7 +2774,7 @@ case class UnBase64(child: Expression, failOnError: Boolean = false) extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant { override def dataType: DataType = BinaryType - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) def this(expr: Expression) = this(expr, false) @@ -2946,7 +2953,8 @@ case class StringDecode( this(bin, charset, SQLConf.get.legacyJavaCharsets, SQLConf.get.legacyCodingErrorAction) override val dataType: DataType = SQLConf.get.defaultStringType - override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(BinaryType, StringTypeWithCaseAccentSensitivity) override def prettyName: String = "decode" override def toString: String = s"$prettyName($bin, $charset)" @@ -2955,7 +2963,7 @@ case class StringDecode( SQLConf.get.defaultStringType, "decode", Seq(bin, charset, Literal(legacyCharsets), Literal(legacyErrorAction)), - Seq(BinaryType, StringTypeAnyCollation, BooleanType, BooleanType)) + Seq(BinaryType, StringTypeWithCaseAccentSensitivity, BooleanType, BooleanType)) override def children: Seq[Expression] = Seq(bin, charset) override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = @@ -3012,15 +3020,20 @@ case class Encode( override def dataType: DataType = BinaryType override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override lazy val replacement: Expression = StaticInvoke( classOf[Encode], BinaryType, "encode", Seq( - str, charset, Literal(legacyCharsets, BooleanType), Literal(legacyErrorAction, BooleanType)), - Seq(StringTypeAnyCollation, StringTypeAnyCollation, BooleanType, BooleanType)) + str, charset, Literal(legacyCharsets, BooleanType), Literal(legacyErrorAction, BooleanType) + ), + Seq( + 
StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, + BooleanType, + BooleanType)) override def toString: String = s"$prettyName($str, $charset)" @@ -3104,7 +3117,8 @@ case class ToBinary( override def children: Seq[Expression] = expr +: format.toSeq - override def inputTypes: Seq[AbstractDataType] = children.map(_ => StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + children.map(_ => StringTypeWithCaseAccentSensitivity) override def checkInputDataTypes(): TypeCheckResult = { def isValidFormat: Boolean = { @@ -3120,7 +3134,8 @@ case class ToBinary( errorSubClass = "INVALID_ARG_VALUE", messageParameters = Map( "inputName" -> "fmt", - "requireType" -> s"case-insensitive ${toSQLType(StringTypeAnyCollation)}", + "requireType" -> + s"case-insensitive ${toSQLType(StringTypeWithCaseAccentSensitivity)}", "validValues" -> "'hex', 'utf-8', 'utf8', or 'base64'", "inputValue" -> toSQLValue(fmt, f.dataType) ) @@ -3131,7 +3146,7 @@ case class ToBinary( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("fmt"), - "inputType" -> toSQLType(StringTypeAnyCollation), + "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), "inputExpr" -> toSQLExpr(f) ) ) @@ -3140,7 +3155,8 @@ case class ToBinary( errorSubClass = "INVALID_ARG_VALUE", messageParameters = Map( "inputName" -> "fmt", - "requireType" -> s"case-insensitive ${toSQLType(StringTypeAnyCollation)}", + "requireType" -> + s"case-insensitive ${toSQLType(StringTypeWithCaseAccentSensitivity)}", "validValues" -> "'hex', 'utf-8', 'utf8', or 'base64'", "inputValue" -> toSQLValue(f.eval(), f.dataType) ) @@ -3189,7 +3205,7 @@ case class FormatNumber(x: Expression, d: Expression) override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(NumericType, TypeCollection(IntegerType, StringTypeAnyCollation)) + Seq(NumericType, TypeCollection(IntegerType, StringTypeWithCaseAccentSensitivity)) private val defaultFormat = "#,###,###,###,###,###,##0" @@ -3394,7 +3410,9 @@ case class Sentences( override def dataType: DataType = ArrayType(ArrayType(str.dataType, containsNull = false), containsNull = false) override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) + Seq( + StringTypeWithCaseAccentSensitivity, + StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override def first: Expression = str override def second: Expression = language override def third: Expression = country @@ -3540,10 +3558,9 @@ case class Luhncheck(input: Expression) extends RuntimeReplaceable with Implicit classOf[ExpressionImplUtils], BooleanType, "isLuhnNumber", - Seq(input), - inputTypes) + Seq(input), inputTypes) - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def prettyName: String = "luhn_check" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala index 3e4e4f992002a..09e91da65484f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala @@ -29,7 +29,7 @@ import 
org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType} import org.apache.spark.unsafe.types.UTF8String @@ -59,13 +59,13 @@ case class UrlEncode(child: Expression) SQLConf.get.defaultStringType, "encode", Seq(child), - Seq(StringTypeAnyCollation)) + Seq(StringTypeWithCaseAccentSensitivity)) override protected def withNewChildInternal(newChild: Expression): Expression = { copy(child = newChild) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def prettyName: String = "url_encode" } @@ -98,13 +98,13 @@ case class UrlDecode(child: Expression, failOnError: Boolean = true) SQLConf.get.defaultStringType, "decode", Seq(child, Literal(failOnError)), - Seq(StringTypeAnyCollation, BooleanType)) + Seq(StringTypeWithCaseAccentSensitivity, BooleanType)) override protected def withNewChildInternal(newChild: Expression): Expression = { copy(child = newChild) } - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def prettyName: String = "url_decode" } @@ -190,7 +190,8 @@ case class ParseUrl(children: Seq[Expression], failOnError: Boolean = SQLConf.ge def this(children: Seq[Expression]) = this(children, SQLConf.get.ansiEnabled) override def nullable: Boolean = true - override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq.fill(children.size)(StringTypeWithCaseAccentSensitivity) override def dataType: DataType = SQLConf.get.defaultStringType override def prettyName: String = "parse_url" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala index 2c8ca1e8bb2bb..323f6e42f3e50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.types.variant._ import org.apache.spark.types.variant.VariantUtil.{IntervalFields, Type} @@ -66,7 +66,7 @@ case class ParseJson(child: Expression, failOnError: Boolean = true) inputTypes :+ BooleanType :+ BooleanType, returnNullable = !failOnError) - override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil + override def inputTypes: Seq[AbstractDataType] = 
StringTypeWithCaseAccentSensitivity :: Nil override def dataType: DataType = VariantType @@ -271,7 +271,8 @@ case class VariantGet( final override def nodePatternsInternal(): Seq[TreePattern] = Seq(VARIANT_GET) - override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(VariantType, StringTypeWithCaseAccentSensitivity) override def prettyName: String = if (failOnError) "variant_get" else "try_variant_get" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala index 31e65cf0abc95..6c38bd88144b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -42,7 +42,7 @@ abstract class XPathExtract override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, StringTypeAnyCollation) + Seq(StringTypeWithCaseAccentSensitivity, StringTypeWithCaseAccentSensitivity) override def checkInputDataTypes(): TypeCheckResult = { if (!path.foldable) { @@ -50,7 +50,7 @@ abstract class XPathExtract errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("path"), - "inputType" -> toSQLType(StringTypeAnyCollation), + "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), "inputExpr" -> toSQLExpr(path) ) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala index 48a87db291a8d..6f1430b04ed67 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xmlExpressions.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils._ import org.apache.spark.sql.catalyst.xml.{StaxXmlGenerator, StaxXmlParser, ValidatorUtil, XmlInferSchema, XmlOptions} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -124,7 +124,7 @@ case class XmlToStructs( defineCodeGen(ctx, ev, input => s"(InternalRow) $expr.nullSafeEval($input)") } - override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil + override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil override def prettyName: String = "from_xml" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala index de600d881b624..342dcbd8e6b6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeWithCaseAccentSensitivity} import org.apache.spark.sql.types._ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { @@ -1057,11 +1057,11 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { ArrayType(IntegerType)) shouldCast( ArrayType(StringType), - AbstractArrayType(StringTypeAnyCollation), + AbstractArrayType(StringTypeWithCaseAccentSensitivity), ArrayType(StringType)) shouldCast( ArrayType(IntegerType), - AbstractArrayType(StringTypeAnyCollation), + AbstractArrayType(StringTypeWithCaseAccentSensitivity), ArrayType(StringType)) shouldCast( ArrayType(StringType), @@ -1075,11 +1075,11 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { ArrayType(ArrayType(IntegerType))) shouldCast( ArrayType(ArrayType(StringType)), - AbstractArrayType(AbstractArrayType(StringTypeAnyCollation)), + AbstractArrayType(AbstractArrayType(StringTypeWithCaseAccentSensitivity)), ArrayType(ArrayType(StringType))) shouldCast( ArrayType(ArrayType(IntegerType)), - AbstractArrayType(AbstractArrayType(StringTypeAnyCollation)), + AbstractArrayType(AbstractArrayType(StringTypeWithCaseAccentSensitivity)), ArrayType(ArrayType(StringType))) shouldCast( ArrayType(ArrayType(StringType)), @@ -1088,14 +1088,16 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { // Invalid casts involving casting arrays into non-complex types. shouldNotCast(ArrayType(IntegerType), IntegerType) - shouldNotCast(ArrayType(StringType), StringTypeAnyCollation) + shouldNotCast(ArrayType(StringType), StringTypeWithCaseAccentSensitivity) shouldNotCast(ArrayType(StringType), IntegerType) - shouldNotCast(ArrayType(IntegerType), StringTypeAnyCollation) + shouldNotCast(ArrayType(IntegerType), StringTypeWithCaseAccentSensitivity) // Invalid casts involving casting arrays of arrays into arrays of non-complex types. 
shouldNotCast(ArrayType(ArrayType(IntegerType)), AbstractArrayType(IntegerType)) - shouldNotCast(ArrayType(ArrayType(StringType)), AbstractArrayType(StringTypeAnyCollation)) + shouldNotCast(ArrayType(ArrayType(StringType)), + AbstractArrayType(StringTypeWithCaseAccentSensitivity)) shouldNotCast(ArrayType(ArrayType(StringType)), AbstractArrayType(IntegerType)) - shouldNotCast(ArrayType(ArrayType(IntegerType)), AbstractArrayType(StringTypeAnyCollation)) + shouldNotCast(ArrayType(ArrayType(IntegerType)), + AbstractArrayType(StringTypeWithCaseAccentSensitivity)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 9b454ba764f92..1aae2f10b7326 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio import org.apache.spark.sql.catalyst.util.CharsetProvider import org.apache.spark.sql.errors.QueryExecutionErrors.toSQLId import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.internal.types.StringTypeWithCaseAccentSensitivity import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1466,7 +1466,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( "inputName" -> toSQLId("fmt"), - "inputType" -> toSQLType(StringTypeAnyCollation), + "inputType" -> toSQLType(StringTypeWithCaseAccentSensitivity), "inputExpr" -> toSQLExpr(wrongFmt) ) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 1d23774a51692..879c0c480943d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -66,10 +66,10 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi collationType: CollationType): Any = inputEntry match { case e: Class[_] if e.isAssignableFrom(classOf[Expression]) => - generateLiterals(StringTypeAnyCollation, collationType) + generateLiterals(StringTypeWithCaseAccentSensitivity, collationType) case se: Class[_] if se.isAssignableFrom(classOf[Seq[Expression]]) => - CreateArray(Seq(generateLiterals(StringTypeAnyCollation, collationType), - generateLiterals(StringTypeAnyCollation, collationType))) + CreateArray(Seq(generateLiterals(StringTypeWithCaseAccentSensitivity, collationType), + generateLiterals(StringTypeWithCaseAccentSensitivity, collationType))) case oe: Class[_] if oe.isAssignableFrom(classOf[Option[Any]]) => None case b: Class[_] if b.isAssignableFrom(classOf[Boolean]) => false case dt: Class[_] if dt.isAssignableFrom(classOf[DataType]) => StringType @@ -142,12 +142,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) ).head case ArrayType => - generateLiterals(StringTypeAnyCollation, collationType).map( + generateLiterals(StringTypeWithCaseAccentSensitivity, 
collationType).map( lit => Literal.create(Seq(lit.asInstanceOf[Literal].value), ArrayType(lit.dataType)) ).head case MapType => - val key = generateLiterals(StringTypeAnyCollation, collationType) - val value = generateLiterals(StringTypeAnyCollation, collationType) + val key = generateLiterals(StringTypeWithCaseAccentSensitivity, collationType) + val value = generateLiterals(StringTypeWithCaseAccentSensitivity, collationType) CreateMap(Seq(key, value)) case MapType(keyType, valueType, _) => val key = generateLiterals(keyType, collationType) @@ -159,8 +159,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi CreateMap(Seq(key, value)) case StructType => CreateNamedStruct( - Seq(Literal("start"), generateLiterals(StringTypeAnyCollation, collationType), - Literal("end"), generateLiterals(StringTypeAnyCollation, collationType))) + Seq(Literal("start"), + generateLiterals(StringTypeWithCaseAccentSensitivity, collationType), + Literal("end"), generateLiterals(StringTypeWithCaseAccentSensitivity, collationType))) } /** @@ -209,10 +210,10 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case ArrayType(elementType, _) => "array(" + generateInputAsString(elementType, collationType) + ")" case ArrayType => - "array(" + generateInputAsString(StringTypeAnyCollation, collationType) + ")" + "array(" + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")" case MapType => - "map(" + generateInputAsString(StringTypeAnyCollation, collationType) + ", " + - generateInputAsString(StringTypeAnyCollation, collationType) + ")" + "map(" + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ", " + + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")" case MapType(keyType, valueType, _) => "map(" + generateInputAsString(keyType, collationType) + ", " + generateInputAsString(valueType, collationType) + ")" @@ -220,8 +221,9 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi "map(" + generateInputAsString(keyType, collationType) + ", " + generateInputAsString(valueType, collationType) + ")" case StructType => - "named_struct( 'start', " + generateInputAsString(StringTypeAnyCollation, collationType) + - ", 'end', " + generateInputAsString(StringTypeAnyCollation, collationType) + ")" + "named_struct( 'start', " + + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ", 'end', " + + generateInputAsString(StringTypeWithCaseAccentSensitivity, collationType) + ")" case StructType(fields) => "named_struct(" + fields.map(f => "'" + f.name + "', " + generateInputAsString(f.dataType, collationType)).mkString(", ") + ")" @@ -267,10 +269,12 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi case ArrayType(elementType, _) => "array<" + generateInputTypeAsStrings(elementType, collationType) + ">" case ArrayType => - "array<" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ">" + "array<" + generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + + ">" case MapType => - "map<" + generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ", " + - generateInputTypeAsStrings(StringTypeAnyCollation, collationType) + ">" + "map<" + generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + + ", " + + generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + ">" case MapType(keyType, valueType, _) => "map<" + 
generateInputTypeAsStrings(keyType, collationType) + ", " + generateInputTypeAsStrings(valueType, collationType) + ">" @@ -278,9 +282,10 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi "map<" + generateInputTypeAsStrings(keyType, collationType) + ", " + generateInputTypeAsStrings(valueType, collationType) + ">" case StructType => - "struct" + generateInputTypeAsStrings(StringTypeWithCaseAccentSensitivity, collationType) + ">" case StructType(fields) => "named_struct<" + fields.map(f => "'" + f.name + "', " + generateInputTypeAsStrings(f.dataType, collationType)).mkString(", ") + ">" @@ -293,8 +298,8 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi */ def hasStringType(inputType: AbstractDataType): Boolean = { inputType match { - case _: StringType | StringTypeAnyCollation | StringTypeBinaryLcase | AnyDataType => - true + case _: StringType | StringTypeWithCaseAccentSensitivity | StringTypeBinaryLcase | AnyDataType + => true case ArrayType => true case MapType => true case MapType(keyType, valueType, _) => hasStringType(keyType) || hasStringType(valueType) @@ -408,7 +413,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi var input: Seq[Expression] = Seq.empty var i = 0 for (_ <- 1 to 10) { - input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary) + input = input :+ generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary) try { method.invoke(null, funInfo.getClassName, input).asInstanceOf[ExpectsInputTypes] } @@ -498,7 +503,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi var input: Seq[Expression] = Seq.empty var result: Expression = null for (_ <- 1 to 10) { - input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary) + input = input :+ generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary) try { val tempResult = method.invoke(null, f.getClassName, input) if (result == null) result = tempResult.asInstanceOf[Expression] @@ -609,7 +614,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi var input: Seq[Expression] = Seq.empty var result: Expression = null for (_ <- 1 to 10) { - input = input :+ generateLiterals(StringTypeAnyCollation, Utf8Binary) + input = input :+ generateLiterals(StringTypeWithCaseAccentSensitivity, Utf8Binary) try { val tempResult = method.invoke(null, f.getClassName, input) if (result == null) result = tempResult.asInstanceOf[Expression] From c54c017e93090a5fb2edf1b5ef029561b6387a3f Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Mon, 30 Sep 2024 17:44:13 +0800 Subject: [PATCH 121/250] [SPARK-49666][SQL] Add feature flag for trim collation feature ### What changes were proposed in this pull request? Introducing new specifier for trim collations (both leading and trailing trimming). These are initial changes so that trim specifier is recognized and put under feature flag (all code paths blocked). ### Why are the changes needed? Support for trailing space trimming is one of the requested feature by users. ### Does this PR introduce _any_ user-facing change? This is guarded by feature flag. ### How was this patch tested? Added tests to CollationSuite, SqlConfSuite and QueryCompilationErrorSuite. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48222 from jovanpavl-db/trim-collation-feature-initial-support. 
Authored-by: Jovan Pavlovic Signed-off-by: Wenchen Fan --- .../sql/catalyst/util/CollationFactory.java | 341 +++++++++++++----- .../unsafe/types/CollationFactorySuite.scala | 5 +- .../resources/error/error-conditions.json | 10 +- .../expressions/collationExpressions.scala | 4 + .../sql/catalyst/parser/AstBuilder.scala | 4 + .../sql/errors/QueryCompilationErrors.scala | 7 + .../apache/spark/sql/internal/SQLConf.scala | 14 + .../spark/sql/execution/SparkSqlParser.scala | 4 + .../org/apache/spark/sql/CollationSuite.scala | 56 ++- .../errors/QueryCompilationErrorsSuite.scala | 33 ++ .../spark/sql/internal/SQLConfSuite.scala | 7 + 11 files changed, 381 insertions(+), 104 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index d5dbca7eb89bc..e368e2479a3a1 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -99,7 +99,8 @@ public record CollationMeta( String icuVersion, String padAttribute, boolean accentSensitivity, - boolean caseSensitivity) { } + boolean caseSensitivity, + String spaceTrimming) { } /** * Entry encapsulating all information about a collation. @@ -200,6 +201,7 @@ public Collation( * bit 28-24: Reserved. * bit 23-22: Reserved for version. * bit 21-18: Reserved for space trimming. + * 0000 = none, 0001 = left trim, 0010 = right trim, 0011 = trim. * bit 17-0: Depend on collation family. * --- * INDETERMINATE collation ID binary layout: @@ -214,7 +216,8 @@ public Collation( * UTF8_BINARY collation ID binary layout: * bit 31-24: Zeroes. * bit 23-22: Zeroes, reserved for version. - * bit 21-18: Zeroes, reserved for space trimming. + * bit 21-18: Reserved for space trimming. + * 0000 = none, 0001 = left trim, 0010 = right trim, 0011 = trim. * bit 17-3: Zeroes. * bit 2: 0, reserved for accent sensitivity. * bit 1: 0, reserved for uppercase and case-insensitive. @@ -225,7 +228,8 @@ public Collation( * bit 29: 1 * bit 28-24: Zeroes. * bit 23-22: Zeroes, reserved for version. - * bit 21-18: Zeroes, reserved for space trimming. + * bit 21-18: Reserved for space trimming. + * 0000 = none, 0001 = left trim, 0010 = right trim, 0011 = trim. * bit 17: 0 = case-sensitive, 1 = case-insensitive. * bit 16: 0 = accent-sensitive, 1 = accent-insensitive. * bit 15-14: Zeroes, reserved for punctuation sensitivity. @@ -238,7 +242,13 @@ public Collation( * - UNICODE -> 0x20000000 * - UNICODE_AI -> 0x20010000 * - UNICODE_CI -> 0x20020000 + * - UNICODE_LTRIM -> 0x20040000 + * - UNICODE_RTRIM -> 0x20080000 + * - UNICODE_TRIM -> 0x200C0000 * - UNICODE_CI_AI -> 0x20030000 + * - UNICODE_CI_TRIM -> 0x200E0000 + * - UNICODE_AI_TRIM -> 0x200D0000 + * - UNICODE_CI_AI_TRIM-> 0x200F0000 * - af -> 0x20000001 * - af_CI_AI -> 0x20030001 */ @@ -259,6 +269,15 @@ protected enum ImplementationProvider { UTF8_BINARY, ICU } + /** + * Bits 19-18 having value 00 for no space trimming, 01 for left space trimming + * 10 for right space trimming and 11 for both sides space trimming. Bits 21, 20 + * remained reserved (and fixed to 0) for future use. + */ + protected enum SpaceTrimming { + NONE, LTRIM, RTRIM, TRIM + } + /** * Offset in binary collation ID layout. */ @@ -279,6 +298,17 @@ protected enum ImplementationProvider { */ protected static final int IMPLEMENTATION_PROVIDER_MASK = 0b1; + + /** + * Offset in binary collation ID layout. 
+ */ + protected static final int SPACE_TRIMMING_OFFSET = 18; + + /** + * Bitmask corresponding to width in bits in binary collation ID layout. + */ + protected static final int SPACE_TRIMMING_MASK = 0b11; + private static final int INDETERMINATE_COLLATION_ID = -1; /** @@ -303,6 +333,14 @@ private static DefinitionOrigin getDefinitionOrigin(int collationId) { DEFINITION_ORIGIN_OFFSET, DEFINITION_ORIGIN_MASK)]; } + /** + * Utility function to retrieve `SpaceTrimming` enum instance from collation ID. + */ + protected static SpaceTrimming getSpaceTrimming(int collationId) { + return SpaceTrimming.values()[SpecifierUtils.getSpecValue(collationId, + SPACE_TRIMMING_OFFSET, SPACE_TRIMMING_MASK)]; + } + /** * Main entry point for retrieving `Collation` instance from collation ID. */ @@ -358,6 +396,8 @@ private static int collationNameToId(String collationName) throws SparkException protected abstract CollationMeta buildCollationMeta(); + protected abstract String normalizedCollationName(); + static List listCollations() { return Stream.concat( CollationSpecUTF8.listCollations().stream(), @@ -398,48 +438,99 @@ private enum CaseSensitivity { private static final String UTF8_LCASE_COLLATION_NAME = "UTF8_LCASE"; private static final int UTF8_BINARY_COLLATION_ID = - new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).collationId; + new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED, SpaceTrimming.NONE).collationId; private static final int UTF8_LCASE_COLLATION_ID = - new CollationSpecUTF8(CaseSensitivity.LCASE).collationId; + new CollationSpecUTF8(CaseSensitivity.LCASE, SpaceTrimming.NONE).collationId; protected static Collation UTF8_BINARY_COLLATION = - new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED).buildCollation(); + new CollationSpecUTF8(CaseSensitivity.UNSPECIFIED, SpaceTrimming.NONE).buildCollation(); protected static Collation UTF8_LCASE_COLLATION = - new CollationSpecUTF8(CaseSensitivity.LCASE).buildCollation(); + new CollationSpecUTF8(CaseSensitivity.LCASE, SpaceTrimming.NONE).buildCollation(); + private final CaseSensitivity caseSensitivity; + private final SpaceTrimming spaceTrimming; private final int collationId; - private CollationSpecUTF8(CaseSensitivity caseSensitivity) { - this.collationId = + private CollationSpecUTF8( + CaseSensitivity caseSensitivity, + SpaceTrimming spaceTrimming) { + this.caseSensitivity = caseSensitivity; + this.spaceTrimming = spaceTrimming; + + int collationId = SpecifierUtils.setSpecValue(0, CASE_SENSITIVITY_OFFSET, caseSensitivity); + this.collationId = + SpecifierUtils.setSpecValue(collationId, SPACE_TRIMMING_OFFSET, spaceTrimming); } private static int collationNameToId(String originalName, String collationName) throws SparkException { - if (UTF8_BINARY_COLLATION.collationName.equals(collationName)) { - return UTF8_BINARY_COLLATION_ID; - } else if (UTF8_LCASE_COLLATION.collationName.equals(collationName)) { - return UTF8_LCASE_COLLATION_ID; + + int baseId; + String collationNamePrefix; + + if (collationName.startsWith(UTF8_BINARY_COLLATION.collationName)) { + baseId = UTF8_BINARY_COLLATION_ID; + collationNamePrefix = UTF8_BINARY_COLLATION.collationName; + } else if (collationName.startsWith(UTF8_LCASE_COLLATION.collationName)) { + baseId = UTF8_LCASE_COLLATION_ID; + collationNamePrefix = UTF8_LCASE_COLLATION.collationName; } else { // Throw exception with original (before case conversion) collation name. 
throw collationInvalidNameException(originalName); } + + String remainingSpecifiers = collationName.substring(collationNamePrefix.length()); + if(remainingSpecifiers.isEmpty()) { + return baseId; + } + if(!remainingSpecifiers.startsWith("_")){ + throw collationInvalidNameException(originalName); + } + + SpaceTrimming spaceTrimming = SpaceTrimming.NONE; + String remainingSpec = remainingSpecifiers.substring(1); + if (remainingSpec.equals("LTRIM")) { + spaceTrimming = SpaceTrimming.LTRIM; + } else if (remainingSpec.equals("RTRIM")) { + spaceTrimming = SpaceTrimming.RTRIM; + } else if(remainingSpec.equals("TRIM")) { + spaceTrimming = SpaceTrimming.TRIM; + } else { + throw collationInvalidNameException(originalName); + } + + return SpecifierUtils.setSpecValue(baseId, SPACE_TRIMMING_OFFSET, spaceTrimming); } private static CollationSpecUTF8 fromCollationId(int collationId) { // Extract case sensitivity from collation ID. int caseConversionOrdinal = SpecifierUtils.getSpecValue(collationId, CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); - // Verify only case sensitivity bits were set settable in UTF8_BINARY family of collations. - assert (SpecifierUtils.removeSpec(collationId, - CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK) == 0); - return new CollationSpecUTF8(CaseSensitivity.values()[caseConversionOrdinal]); + // Extract space trimming from collation ID. + int spaceTrimmingOrdinal = getSpaceTrimming(collationId).ordinal(); + assert(isValidCollationId(collationId)); + return new CollationSpecUTF8( + CaseSensitivity.values()[caseConversionOrdinal], + SpaceTrimming.values()[spaceTrimmingOrdinal]); + } + + private static boolean isValidCollationId(int collationId) { + collationId = SpecifierUtils.removeSpec( + collationId, + SPACE_TRIMMING_OFFSET, + SPACE_TRIMMING_MASK); + collationId = SpecifierUtils.removeSpec( + collationId, + CASE_SENSITIVITY_OFFSET, + CASE_SENSITIVITY_MASK); + return collationId == 0; } @Override protected Collation buildCollation() { - if (collationId == UTF8_BINARY_COLLATION_ID) { + if (caseSensitivity == CaseSensitivity.UNSPECIFIED) { return new Collation( - UTF8_BINARY_COLLATION_NAME, + normalizedCollationName(), PROVIDER_SPARK, null, UTF8String::binaryCompare, @@ -450,7 +541,7 @@ protected Collation buildCollation() { /* supportsLowercaseEquality = */ false); } else { return new Collation( - UTF8_LCASE_COLLATION_NAME, + normalizedCollationName(), PROVIDER_SPARK, null, CollationAwareUTF8String::compareLowerCase, @@ -464,29 +555,52 @@ protected Collation buildCollation() { @Override protected CollationMeta buildCollationMeta() { - if (collationId == UTF8_BINARY_COLLATION_ID) { + if (caseSensitivity == CaseSensitivity.UNSPECIFIED) { return new CollationMeta( CATALOG, SCHEMA, - UTF8_BINARY_COLLATION_NAME, + normalizedCollationName(), /* language = */ null, /* country = */ null, /* icuVersion = */ null, COLLATION_PAD_ATTRIBUTE, /* accentSensitivity = */ true, - /* caseSensitivity = */ true); + /* caseSensitivity = */ true, + spaceTrimming.toString()); } else { return new CollationMeta( CATALOG, SCHEMA, - UTF8_LCASE_COLLATION_NAME, + normalizedCollationName(), /* language = */ null, /* country = */ null, /* icuVersion = */ null, COLLATION_PAD_ATTRIBUTE, /* accentSensitivity = */ true, - /* caseSensitivity = */ false); + /* caseSensitivity = */ false, + spaceTrimming.toString()); + } + } + + /** + * Compute normalized collation name. 
Components of collation name are given in order: + * - Base collation name (UTF8_BINARY or UTF8_LCASE) + * - Optional space trimming when non-default preceded by underscore + * Examples: UTF8_BINARY, UTF8_BINARY_LCASE_LTRIM, UTF8_BINARY_TRIM. + */ + @Override + protected String normalizedCollationName() { + StringBuilder builder = new StringBuilder(); + if(caseSensitivity == CaseSensitivity.UNSPECIFIED){ + builder.append(UTF8_BINARY_COLLATION_NAME); + } else{ + builder.append(UTF8_LCASE_COLLATION_NAME); } + if (spaceTrimming != SpaceTrimming.NONE) { + builder.append('_'); + builder.append(spaceTrimming.toString()); + } + return builder.toString(); } static List listCollations() { @@ -620,21 +734,33 @@ private enum AccentSensitivity { } } - private static final int UNICODE_COLLATION_ID = - new CollationSpecICU("UNICODE", CaseSensitivity.CS, AccentSensitivity.AS).collationId; - private static final int UNICODE_CI_COLLATION_ID = - new CollationSpecICU("UNICODE", CaseSensitivity.CI, AccentSensitivity.AS).collationId; + private static final int UNICODE_COLLATION_ID = new CollationSpecICU( + "UNICODE", + CaseSensitivity.CS, + AccentSensitivity.AS, + SpaceTrimming.NONE).collationId; + + private static final int UNICODE_CI_COLLATION_ID = new CollationSpecICU( + "UNICODE", + CaseSensitivity.CI, + AccentSensitivity.AS, + SpaceTrimming.NONE).collationId; private final CaseSensitivity caseSensitivity; private final AccentSensitivity accentSensitivity; + private final SpaceTrimming spaceTrimming; private final String locale; private final int collationId; - private CollationSpecICU(String locale, CaseSensitivity caseSensitivity, - AccentSensitivity accentSensitivity) { + private CollationSpecICU( + String locale, + CaseSensitivity caseSensitivity, + AccentSensitivity accentSensitivity, + SpaceTrimming spaceTrimming) { this.locale = locale; this.caseSensitivity = caseSensitivity; this.accentSensitivity = accentSensitivity; + this.spaceTrimming = spaceTrimming; // Construct collation ID from locale, case-sensitivity and accent-sensitivity specifiers. int collationId = ICULocaleToId.get(locale); // Mandatory ICU implementation provider. @@ -644,6 +770,8 @@ private CollationSpecICU(String locale, CaseSensitivity caseSensitivity, caseSensitivity); collationId = SpecifierUtils.setSpecValue(collationId, ACCENT_SENSITIVITY_OFFSET, accentSensitivity); + collationId = SpecifierUtils.setSpecValue(collationId, SPACE_TRIMMING_OFFSET, + spaceTrimming); this.collationId = collationId; } @@ -661,58 +789,88 @@ private static int collationNameToId( } if (lastPos == -1) { throw collationInvalidNameException(originalName); - } else { - String locale = collationName.substring(0, lastPos); - int collationId = ICULocaleToId.get(ICULocaleMapUppercase.get(locale)); - - // Try all combinations of AS/AI and CS/CI. 
- CaseSensitivity caseSensitivity; - AccentSensitivity accentSensitivity; - if (collationName.equals(locale) || - collationName.equals(locale + "_AS") || - collationName.equals(locale + "_CS") || - collationName.equals(locale + "_AS_CS") || - collationName.equals(locale + "_CS_AS") - ) { - caseSensitivity = CaseSensitivity.CS; - accentSensitivity = AccentSensitivity.AS; - } else if (collationName.equals(locale + "_CI") || - collationName.equals(locale + "_AS_CI") || - collationName.equals(locale + "_CI_AS")) { - caseSensitivity = CaseSensitivity.CI; - accentSensitivity = AccentSensitivity.AS; - } else if (collationName.equals(locale + "_AI") || - collationName.equals(locale + "_CS_AI") || - collationName.equals(locale + "_AI_CS")) { - caseSensitivity = CaseSensitivity.CS; - accentSensitivity = AccentSensitivity.AI; - } else if (collationName.equals(locale + "_AI_CI") || - collationName.equals(locale + "_CI_AI")) { - caseSensitivity = CaseSensitivity.CI; - accentSensitivity = AccentSensitivity.AI; - } else { - throw collationInvalidNameException(originalName); - } + } + String locale = collationName.substring(0, lastPos); + int collationId = ICULocaleToId.get(ICULocaleMapUppercase.get(locale)); + collationId = SpecifierUtils.setSpecValue(collationId, + IMPLEMENTATION_PROVIDER_OFFSET, ImplementationProvider.ICU); - // Build collation ID from computed specifiers. - collationId = SpecifierUtils.setSpecValue(collationId, - IMPLEMENTATION_PROVIDER_OFFSET, ImplementationProvider.ICU); - collationId = SpecifierUtils.setSpecValue(collationId, - CASE_SENSITIVITY_OFFSET, caseSensitivity); - collationId = SpecifierUtils.setSpecValue(collationId, - ACCENT_SENSITIVITY_OFFSET, accentSensitivity); + // No other specifiers present. + if(collationName.equals(locale)){ return collationId; } + if(collationName.charAt(locale.length()) != '_'){ + throw collationInvalidNameException(originalName); + } + // Extract remaining specifiers and trim "_" separator. + String remainingSpecifiers = collationName.substring(lastPos + 1); + + // Initialize default specifier flags. + // Case sensitive, accent sensitive, no space trimming. + boolean isCaseSpecifierSet = false; + boolean isAccentSpecifierSet = false; + boolean isSpaceTrimmingSpecifierSet = false; + CaseSensitivity caseSensitivity = CaseSensitivity.CS; + AccentSensitivity accentSensitivity = AccentSensitivity.AS; + SpaceTrimming spaceTrimming = SpaceTrimming.NONE; + + String[] specifiers = remainingSpecifiers.split("_"); + + // Iterate through specifiers and set corresponding flags + for (String specifier : specifiers) { + switch (specifier) { + case "CI": + case "CS": + if (isCaseSpecifierSet) { + throw collationInvalidNameException(originalName); + } + caseSensitivity = CaseSensitivity.valueOf(specifier); + isCaseSpecifierSet = true; + break; + case "AI": + case "AS": + if (isAccentSpecifierSet) { + throw collationInvalidNameException(originalName); + } + accentSensitivity = AccentSensitivity.valueOf(specifier); + isAccentSpecifierSet = true; + break; + case "LTRIM": + case "RTRIM": + case "TRIM": + if (isSpaceTrimmingSpecifierSet) { + throw collationInvalidNameException(originalName); + } + spaceTrimming = SpaceTrimming.valueOf(specifier); + isSpaceTrimmingSpecifierSet = true; + break; + default: + throw collationInvalidNameException(originalName); + } + } + + // Build collation ID from computed specifiers. 
+ collationId = SpecifierUtils.setSpecValue(collationId, + CASE_SENSITIVITY_OFFSET, caseSensitivity); + collationId = SpecifierUtils.setSpecValue(collationId, + ACCENT_SENSITIVITY_OFFSET, accentSensitivity); + collationId = SpecifierUtils.setSpecValue(collationId, + SPACE_TRIMMING_OFFSET, spaceTrimming); + return collationId; } private static CollationSpecICU fromCollationId(int collationId) { // Parse specifiers from collation ID. + int spaceTrimmingOrdinal = SpecifierUtils.getSpecValue(collationId, + SPACE_TRIMMING_OFFSET, SPACE_TRIMMING_MASK); int caseSensitivityOrdinal = SpecifierUtils.getSpecValue(collationId, CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); int accentSensitivityOrdinal = SpecifierUtils.getSpecValue(collationId, ACCENT_SENSITIVITY_OFFSET, ACCENT_SENSITIVITY_MASK); collationId = SpecifierUtils.removeSpec(collationId, IMPLEMENTATION_PROVIDER_OFFSET, IMPLEMENTATION_PROVIDER_MASK); + collationId = SpecifierUtils.removeSpec(collationId, + SPACE_TRIMMING_OFFSET, SPACE_TRIMMING_MASK); collationId = SpecifierUtils.removeSpec(collationId, CASE_SENSITIVITY_OFFSET, CASE_SENSITIVITY_MASK); collationId = SpecifierUtils.removeSpec(collationId, @@ -723,8 +881,9 @@ private static CollationSpecICU fromCollationId(int collationId) { assert(localeId >= 0 && localeId < ICULocaleNames.length); CaseSensitivity caseSensitivity = CaseSensitivity.values()[caseSensitivityOrdinal]; AccentSensitivity accentSensitivity = AccentSensitivity.values()[accentSensitivityOrdinal]; + SpaceTrimming spaceTrimming = SpaceTrimming.values()[spaceTrimmingOrdinal]; String locale = ICULocaleNames[localeId]; - return new CollationSpecICU(locale, caseSensitivity, accentSensitivity); + return new CollationSpecICU(locale, caseSensitivity, accentSensitivity, spaceTrimming); } @Override @@ -752,7 +911,7 @@ protected Collation buildCollation() { // Freeze ICU collator to ensure thread safety. collator.freeze(); return new Collation( - collationName(), + normalizedCollationName(), PROVIDER_ICU, collator, (s1, s2) -> collator.compare(s1.toValidString(), s2.toValidString()), @@ -768,13 +927,14 @@ protected CollationMeta buildCollationMeta() { return new CollationMeta( CATALOG, SCHEMA, - collationName(), + normalizedCollationName(), ICULocaleMap.get(locale).getDisplayLanguage(), ICULocaleMap.get(locale).getDisplayCountry(), VersionInfo.ICU_VERSION.toString(), COLLATION_PAD_ATTRIBUTE, accentSensitivity == AccentSensitivity.AS, - caseSensitivity == CaseSensitivity.CS); + caseSensitivity == CaseSensitivity.CS, + spaceTrimming.toString()); } /** @@ -782,9 +942,11 @@ protected CollationMeta buildCollationMeta() { * - Locale name * - Optional case sensitivity when non-default preceded by underscore * - Optional accent sensitivity when non-default preceded by underscore - * Examples: en, en_USA_CI_AI, sr_Cyrl_SRB_AI. + * - Optional space trimming when non-default preceded by underscore + * Examples: en, en_USA_CI_LTRIM, en_USA_CI_AI, en_USA_CI_AI_TRIM, sr_Cyrl_SRB_AI. 
*/ - private String collationName() { + @Override + protected String normalizedCollationName() { StringBuilder builder = new StringBuilder(); builder.append(locale); if (caseSensitivity != CaseSensitivity.CS) { @@ -795,20 +957,21 @@ private String collationName() { builder.append('_'); builder.append(accentSensitivity.toString()); } + if(spaceTrimming != SpaceTrimming.NONE) { + builder.append('_'); + builder.append(spaceTrimming.toString()); + } return builder.toString(); } private static List allCollationNames() { List collationNames = new ArrayList<>(); - for (String locale: ICULocaleToId.keySet()) { - // CaseSensitivity.CS + AccentSensitivity.AS - collationNames.add(locale); - // CaseSensitivity.CS + AccentSensitivity.AI - collationNames.add(locale + "_AI"); - // CaseSensitivity.CI + AccentSensitivity.AS - collationNames.add(locale + "_CI"); - // CaseSensitivity.CI + AccentSensitivity.AI - collationNames.add(locale + "_CI_AI"); + List caseAccentSpecifiers = Arrays.asList("", "_AI", "_CI", "_CI_AI"); + for (String locale : ICULocaleToId.keySet()) { + for (String caseAccent : caseAccentSpecifiers) { + String collationName = locale + caseAccent; + collationNames.add(collationName); + } } return collationNames.stream().sorted().toList(); } @@ -933,6 +1096,14 @@ public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) { Collation.CollationSpecICU.AccentSensitivity.AI; } + /** + * Returns whether the collation uses trim collation for the given collation id. + */ + public static boolean usesTrimCollation(int collationId) { + return Collation.CollationSpec.getSpaceTrimming(collationId) != + Collation.CollationSpec.SpaceTrimming.NONE; + } + public static void assertValidProvider(String provider) throws SparkException { if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) { Map params = Map.of( diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 321d1ccd700f2..054c44f7286b7 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -369,9 +369,8 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig 1 << 15, // UTF8_BINARY mandatory zero bit 15 breach. 1 << 16, // UTF8_BINARY mandatory zero bit 16 breach. 1 << 17, // UTF8_BINARY mandatory zero bit 17 breach. - 1 << 18, // UTF8_BINARY mandatory zero bit 18 breach. - 1 << 19, // UTF8_BINARY mandatory zero bit 19 breach. 1 << 20, // UTF8_BINARY mandatory zero bit 20 breach. + 1 << 21, // UTF8_BINARY mandatory zero bit 21 breach. 1 << 23, // UTF8_BINARY mandatory zero bit 23 breach. 1 << 24, // UTF8_BINARY mandatory zero bit 24 breach. 1 << 25, // UTF8_BINARY mandatory zero bit 25 breach. @@ -382,8 +381,6 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig (1 << 29) | (1 << 13), // ICU mandatory zero bit 13 breach. (1 << 29) | (1 << 14), // ICU mandatory zero bit 14 breach. (1 << 29) | (1 << 15), // ICU mandatory zero bit 15 breach. - (1 << 29) | (1 << 18), // ICU mandatory zero bit 18 breach. - (1 << 29) | (1 << 19), // ICU mandatory zero bit 19 breach. (1 << 29) | (1 << 20), // ICU mandatory zero bit 20 breach. (1 << 29) | (1 << 21), // ICU mandatory zero bit 21 breach. (1 << 29) | (1 << 22), // ICU mandatory zero bit 22 breach. 
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 3fcb53426eccf..fcaf2b1d9d301 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -4886,11 +4886,6 @@ "Catalog does not support ." ] }, - "COLLATION" : { - "message" : [ - "Collation is not yet supported." - ] - }, "COMBINATION_QUERY_RESULT_CLAUSES" : { "message" : [ "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY." @@ -5117,6 +5112,11 @@ "message" : [ "TRANSFORM with SERDE is only supported in hive mode." ] + }, + "TRIM_COLLATION" : { + "message" : [ + "TRIM specifier in the collation." + ] } }, "sqlState" : "0A000" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala index 0cff70436db7d..b67e66323bbbd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -52,6 +52,10 @@ object CollateExpressionBuilder extends ExpressionBuilder { if (evalCollation == null) { throw QueryCompilationErrors.unexpectedNullError("collation", collationExpr) } else { + if (!SQLConf.get.trimCollationEnabled && + evalCollation.toString.toUpperCase().contains("TRIM")) { + throw QueryCompilationErrors.trimCollationNotEnabledError() + } Collate(e, evalCollation.toString) } case (_: StringType, false) => throw QueryCompilationErrors.nonFoldableArgumentError( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 674005caaf1b2..ed6cf329eeca8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2557,6 +2557,10 @@ class AstBuilder extends DataTypeAstBuilder } override def visitCollateClause(ctx: CollateClauseContext): String = withOrigin(ctx) { + val collationName = ctx.collationName.getText + if (!SQLConf.get.trimCollationEnabled && collationName.toUpperCase().contains("TRIM")) { + throw QueryCompilationErrors.trimCollationNotEnabledError() + } ctx.identifier.getText } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 0b5255e95f073..0d27f7bedbd3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -351,6 +351,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } + def trimCollationNotEnabledError(): Throwable = { + new AnalysisException( + errorClass = "UNSUPPORTED_FEATURE.TRIM_COLLATION", + messageParameters = Map.empty + ) + } + def unresolvedUsingColForJoinError( colName: String, suggestion: String, side: String): Throwable = { new AnalysisException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9c46dd8e83ab2..ea187c0316c17 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -759,6 +759,18 @@ object SQLConf { .checkValue(_ > 0, "The initial number of partitions must be positive.") .createOptional + lazy val TRIM_COLLATION_ENABLED = + buildConf("spark.sql.collation.trim.enabled") + .internal() + .doc( + "Trim collation feature is under development and its use should be done under this" + + "feature flag. Trim collation trims leading, trailing or both spaces depending of" + + "specifier (LTRIM, RTRIM, TRIM)." + ) + .version("4.0.0") + .booleanConf + .createWithDefault(Utils.isTesting) + val DEFAULT_COLLATION = buildConf(SqlApiConfHelper.DEFAULT_COLLATION) .doc("Sets default collation to use for string literals, parameter markers or the string" + @@ -5482,6 +5494,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { } } + def trimCollationEnabled: Boolean = getConf(TRIM_COLLATION_ENABLED) + override def defaultStringType: StringType = { if (getConf(DEFAULT_COLLATION).toUpperCase(Locale.ROOT) == "UTF8_BINARY") { StringType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 1c735154f25ed..8fc860c503c96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -168,6 +168,10 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitSetCollation(ctx: SetCollationContext): LogicalPlan = withOrigin(ctx) { + val collationName = ctx.collationName.getText + if (!SQLConf.get.trimCollationEnabled && collationName.toUpperCase().contains("TRIM")) { + throw QueryCompilationErrors.trimCollationNotEnabledError() + } val key = SQLConf.DEFAULT_COLLATION.key SetCommand(Some(key -> Some(ctx.identifier.getText.toUpperCase(Locale.ROOT)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 632b9305feb57..03d3ed6ac7cb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -44,27 +44,57 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { private val allFileBasedDataSources = collationPreservingSources ++ collationNonPreservingSources test("collate returns proper type") { - Seq("utf8_binary", "utf8_lcase", "unicode", "unicode_ci").foreach { collationName => + Seq( + "utf8_binary", + "utf8_lcase", + "unicode", + "unicode_ci", + "unicode_ltrim_ci", + "utf8_lcase_trim", + "utf8_binary_rtrim" + ).foreach { collationName => checkAnswer(sql(s"select 'aaa' collate $collationName"), Row("aaa")) val collationId = CollationFactory.collationNameToId(collationName) - assert(sql(s"select 'aaa' collate $collationName").schema(0).dataType - == StringType(collationId)) + assert( + sql(s"select 'aaa' collate $collationName").schema(0).dataType + == StringType(collationId) + ) } } test("collation name is case insensitive") { - Seq("uTf8_BiNaRy", "utf8_lcase", "uNicOde", "UNICODE_ci").foreach { collationName => + Seq( + "uTf8_BiNaRy", + "utf8_lcase", + "uNicOde", + "UNICODE_ci", + "uNiCoDE_ltRIm_cI", + "UtF8_lCaSE_tRIM", + "utf8_biNAry_RtRiM" + ).foreach { collationName => checkAnswer(sql(s"select 'aaa' collate $collationName"), Row("aaa")) val collationId 
= CollationFactory.collationNameToId(collationName) - assert(sql(s"select 'aaa' collate $collationName").schema(0).dataType - == StringType(collationId)) + assert( + sql(s"select 'aaa' collate $collationName").schema(0).dataType + == StringType(collationId) + ) } } test("collation expression returns name of collation") { - Seq("utf8_binary", "utf8_lcase", "unicode", "unicode_ci").foreach { collationName => + Seq( + "utf8_binary", + "utf8_lcase", + "unicode", + "unicode_ci", + "unicode_ci_ltrim", + "utf8_lcase_trim", + "utf8_binary_rtrim" + ).foreach { collationName => checkAnswer( - sql(s"select collation('aaa' collate $collationName)"), Row(collationName.toUpperCase())) + sql(s"select collation('aaa' collate $collationName)"), + Row(collationName.toUpperCase()) + ) } } @@ -77,9 +107,15 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("collate function syntax with default collation set") { withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UTF8_LCASE") { - assert(sql(s"select collate('aaa', 'utf8_lcase')").schema(0).dataType == - StringType("UTF8_LCASE")) + assert( + sql(s"select collate('aaa', 'utf8_lcase')").schema(0).dataType == + StringType("UTF8_LCASE") + ) assert(sql(s"select collate('aaa', 'UNICODE')").schema(0).dataType == StringType("UNICODE")) + assert( + sql(s"select collate('aaa', 'UNICODE_TRIM')").schema(0).dataType == + StringType("UNICODE_TRIM") + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 832e1873af6a4..5abdca326f2fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -868,6 +868,39 @@ class QueryCompilationErrorsSuite "inputTypes" -> "[\"INT\", \"STRING\", \"STRING\"]")) } + test("SPARK-49666: the trim collation feature is off without collate builder call") { + withSQLConf(SQLConf.TRIM_COLLATION_ENABLED.key -> "false") { + Seq( + "CREATE TABLE t(col STRING COLLATE EN_TRIM_CI) USING parquet", + "CREATE TABLE t(col STRING COLLATE UTF8_LCASE_TRIM) USING parquet", + "SELECT 'aaa' COLLATE UNICODE_LTRIM_CI" + ).foreach { sqlText => + checkError( + exception = intercept[AnalysisException](sql(sqlText)), + condition = "UNSUPPORTED_FEATURE.TRIM_COLLATION" + ) + } + } + } + + test("SPARK-49666: the trim collation feature is off with collate builder call") { + withSQLConf(SQLConf.TRIM_COLLATION_ENABLED.key -> "false") { + Seq( + "SELECT collate('aaa', 'UNICODE_TRIM')", + "SELECT collate('aaa', 'UTF8_BINARY_TRIM')", + "SELECT collate('aaa', 'EN_AI_RTRIM')" + ).foreach { sqlText => + checkError( + exception = intercept[AnalysisException](sql(sqlText)), + condition = "UNSUPPORTED_FEATURE.TRIM_COLLATION", + parameters = Map.empty, + context = + ExpectedContext(fragment = sqlText.substring(7), start = 7, stop = sqlText.length - 1) + ) + } + } + } + test("UNSUPPORTED_CALL: call the unsupported method update()") { checkError( exception = intercept[SparkUnsupportedOperationException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 82795e551b6bf..094c65c63bfdc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -517,6 +517,13 @@ class 
SQLConfSuite extends QueryTest with SharedSparkSession { "confName" -> "spark.sql.session.collation.default", "proposals" -> "UNICODE" )) + + withSQLConf(SQLConf.TRIM_COLLATION_ENABLED.key -> "false") { + checkError( + exception = intercept[AnalysisException](sql(s"SET COLLATION UNICODE_CI_TRIM")), + condition = "UNSUPPORTED_FEATURE.TRIM_COLLATION" + ) + } } test("SPARK-43028: config not found error") { From 97ae372634b119b2b67304df67463b95b20febd9 Mon Sep 17 00:00:00 2001 From: Nick Young Date: Mon, 30 Sep 2024 20:44:51 +0800 Subject: [PATCH 122/250] [SPARK-49819] Disable CollapseProject for correlated subqueries in projection over aggregate correctly ### What changes were proposed in this pull request? CollapseProject should block collapsing with an aggregate if any correlated subquery is present. There are other correlated subqueries that are not ScalarSubquery that are not accounted for here. ### Why are the changes needed? Availability issue. ### Does this PR introduce _any_ user-facing change? Previously failing queries will not fail anymore. ### How was this patch tested? UT. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48286 from n-young-db/n-young-db/collapse-project-correlated-subquery-check. Lead-authored-by: Nick Young Co-authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/Optimizer.scala | 8 +++----- .../org/apache/spark/sql/SubquerySuite.scala | 18 ++++++++++++++++++ 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 7fc12f7d1fc16..fb234c7bda4c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression.hasCorrelatedSubquery import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans._ @@ -1232,11 +1233,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { * in aggregate if they are also part of the grouping expressions. Otherwise the plan * after subquery rewrite will not be valid. 
*/ - private def canCollapseAggregate(p: Project, a: Aggregate): Boolean = { - p.projectList.forall(_.collect { - case s: ScalarSubquery if s.outerAttrs.nonEmpty => s - }.isEmpty) - } + private def canCollapseAggregate(p: Project, a: Aggregate): Boolean = + !p.projectList.exists(hasCorrelatedSubquery) def buildCleanedProjectList( upper: Seq[NamedExpression], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 6e160b4407ca8..f17cf25565145 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -2142,6 +2142,24 @@ class SubquerySuite extends QueryTest } } + test("SPARK-49819: Do not collapse projects with exist subqueries") { + withTempView("v") { + Seq((0, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("v") + checkAnswer( + sql(""" + |SELECT m, CASE WHEN EXISTS (SELECT SUM(c2) FROM v WHERE c1 = m) THEN 1 ELSE 0 END + |FROM (SELECT MIN(c2) AS m FROM v) + |""".stripMargin), + Row(1, 1) :: Nil) + checkAnswer( + sql(""" + |SELECT c, CASE WHEN EXISTS (SELECT SUM(c2) FROM v WHERE c1 = c) THEN 1 ELSE 0 END + |FROM (SELECT c1 AS c FROM v GROUP BY c1) + |""".stripMargin), + Row(0, 1) :: Row(1, 1) :: Nil) + } + } + test("SPARK-37199: deterministic in QueryPlan considers subquery") { val deterministicQueryPlan = sql("select (select 1 as b) as b") .queryExecution.executedPlan From dbfa909422ad82b0428b258671813510caa6eeac Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 30 Sep 2024 21:15:42 +0800 Subject: [PATCH 123/250] [SPARK-49816][SQL] Should only update out-going-ref-count for referenced outer CTE relation ### What changes were proposed in this pull request? This PR fixes a long-standing reference counting bug in the rule `InlineCTE`. Let's look at the minimal repro: ``` sql( """ |WITH |t1 AS (SELECT 1 col), |t2 AS (SELECT * FROM t1) |SELECT * FROM t2 |""".stripMargin).createTempView("v") // r1 is un-referenced, but it should not decrease the ref count of t2 inside view v. val df = sql( """ |WITH |r1 AS (SELECT * FROM v), |r2 AS (SELECT * FROM v) |SELECT * FROM r2 |""".stripMargin) ``` The logical plan is something like below ``` WithCTE CTEDef r1 View v WithCTE CTEDef t1 OneRowRelation CTEDef t2 CTERef t1 CTERef t2 // main query of the inner WithCTE CTEDef r2 View v // exactly the same as the view v above WithCTE CTEDef t1 OneRowRelation CTEDef t2 CTERef t1 CTERef t2 CTERef r2 // main query of the outer WithCTE ``` Ideally, the ref count of `t1`, `t2` and `r2` should be all `1`. They will be inlined and the final plan is the `OneRowRelation`. However, in `InlineCTE#buildCTEMap`, when we traverse into `CTEDef r1` and hit `CTERef t2`, we mistakenly update the out-going-ref-count of `r1`, which means that `r1` references `t2` and this is totally wrong. Later on, in `InlineCTE#cleanCTEMap`, we find that `r1` is not referenced at all, so we decrease the ref count of its out-going-ref, which is `t2`, and the ref count of `t2` becomes `0`. Finally, in `InlineCTE#inlineCTE`, we leave the plan of `t2` unchanged because its ref count is `0`, and the plan of `t2` contains `CTERef t1`. `t2` is still inlined so we end up with `CTERef t1` as the final plan without the `WithCTE` node. ### Why are the changes needed? bug fix ### Does this PR introduce _any_ user-facing change? Yes, the query failed before and now can work ### How was this patch tested? 
new test ### Was this patch authored or co-authored using generative AI tooling? no Closes #48284 from cloud-fan/cte. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../sql/catalyst/optimizer/InlineCTE.scala | 31 ++++++++++++------- .../org/apache/spark/sql/CTEInlineSuite.scala | 21 +++++++++++++ 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala index 19aa1d96ccd3f..b3384c4e29566 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InlineCTE.scala @@ -71,13 +71,13 @@ case class InlineCTE( * @param plan The plan to collect the CTEs from * @param cteMap A mutable map that accumulates the CTEs and their reference information by CTE * ids. - * @param outerCTEId While collecting the map we use this optional CTE id to identify the - * current outer CTE. + * @param collectCTERefs A function to collect CTE references so that the caller side can do some + * bookkeeping work. */ private def buildCTEMap( plan: LogicalPlan, cteMap: mutable.Map[Long, CTEReferenceInfo], - outerCTEId: Option[Long] = None): Unit = { + collectCTERefs: CTERelationRef => Unit = _ => ()): Unit = { plan match { case WithCTE(child, cteDefs) => cteDefs.foreach { cteDef => @@ -89,26 +89,35 @@ case class InlineCTE( ) } cteDefs.foreach { cteDef => - buildCTEMap(cteDef, cteMap, Some(cteDef.id)) + buildCTEMap(cteDef, cteMap, ref => { + // A CTE relation can references CTE relations defined before it in the same `WithCTE`. + // Here we update the out-going-ref-count for it, in case this CTE relation is not + // referenced at all and can be optimized out, and we need to decrease the ref counts + // for CTE relations that are referenced by it. + if (cteDefs.exists(_.id == ref.cteId)) { + cteMap(cteDef.id).increaseOutgoingRefCount(ref.cteId, 1) + } + // Similarly, a CTE relation can reference CTE relations defined in the outer `WithCTE`. + // Here we call the `collectCTERefs` function so that the outer CTE can also update the + // out-going-ref-count if needed. 
+ collectCTERefs(ref) + }) } - buildCTEMap(child, cteMap, outerCTEId) + buildCTEMap(child, cteMap, collectCTERefs) case ref: CTERelationRef => cteMap(ref.cteId) = cteMap(ref.cteId).withRefCountIncreased(1) - outerCTEId.foreach { cteId => - cteMap(cteId).increaseOutgoingRefCount(ref.cteId, 1) - } - + collectCTERefs(ref) case _ => if (plan.containsPattern(CTE)) { plan.children.foreach { child => - buildCTEMap(child, cteMap, outerCTEId) + buildCTEMap(child, cteMap, collectCTERefs) } plan.expressions.foreach { expr => if (expr.containsAllPatterns(PLAN_EXPRESSION, CTE)) { expr.foreach { - case e: SubqueryExpression => buildCTEMap(e.plan, cteMap, outerCTEId) + case e: SubqueryExpression => buildCTEMap(e.plan, cteMap, collectCTERefs) case _ => } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala index 7b608b7438c29..7a2ce1d7836b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CTEInlineSuite.scala @@ -714,6 +714,27 @@ abstract class CTEInlineSuiteBase |""".stripMargin) checkAnswer(df, Row(1)) } + + test("SPARK-49816: should only update out-going-ref-count for referenced outer CTE relation") { + withView("v") { + sql( + """ + |WITH + |t1 AS (SELECT 1 col), + |t2 AS (SELECT * FROM t1) + |SELECT * FROM t2 + |""".stripMargin).createTempView("v") + // r1 is un-referenced, but it should not decrease the ref count of t2 inside view v. + val df = sql( + """ + |WITH + |r1 AS (SELECT * FROM v), + |r2 AS (SELECT * FROM v) + |SELECT * FROM r2 + |""".stripMargin) + checkAnswer(df, Row(1)) + } + } } class CTEInlineSuiteAEOff extends CTEInlineSuiteBase with DisableAdaptiveExecutionSuite From 3065dd92ab8f36b019c7be06da59d47c1865fe60 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Mon, 30 Sep 2024 21:31:15 +0800 Subject: [PATCH 124/250] [SPARK-49561][SQL] Add SQL pipe syntax for the PIVOT and UNPIVOT operators ### What changes were proposed in this pull request? This PR adds SQL pipe syntax support for the PIVOT and UNPIVOT operators. For example: ``` CREATE TEMPORARY VIEW courseSales AS SELECT * FROM VALUES ("dotNET", 2012, 10000), ("Java", 2012, 20000), ("dotNET", 2012, 5000), ("dotNET", 2013, 48000), ("Java", 2013, 30000) as courseSales(course, year, earnings); TABLE courseSales |> SELECT `year`, course, earnings |> PIVOT ( SUM(earnings) FOR course IN ('dotNET', 'Java') ); 2012 15000 20000 2013 48000 30000 ``` ### Why are the changes needed? The SQL pipe operator syntax will let users compose queries in a more flexible fashion. ### Does this PR introduce _any_ user-facing change? Yes, see above. ### How was this patch tested? This PR adds a few unit test cases, but mostly relies on golden file test coverage. I did this to make sure the answers are correct as this feature is implemented and also so we can look at the analyzer output plans to ensure they look right as well. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48093 from dtenedor/pipe-pivot. 
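The new pipe syntax covers UNPIVOT as well; a minimal sketch adapted from the golden-file tests added in this patch (rows shown as in the results file):

```
CREATE TEMPORARY VIEW courseEarnings AS SELECT * FROM VALUES
  ("dotNET", 15000, 48000, 22500),
  ("Java", 20000, 30000, NULL)
  as courseEarnings(course, `2012`, `2013`, `2014`);

TABLE courseEarnings
|> UNPIVOT (
   earningsYear FOR `year` IN (`2012`, `2013`, `2014`)
   );

Java    2012    20000
Java    2013    30000
dotNET  2012    15000
dotNET  2013    48000
dotNET  2014    22500
```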
Authored-by: Daniel Tenedorio Signed-off-by: Wenchen Fan --- .../sql/catalyst/parser/SqlBaseParser.g4 | 5 + .../sql/catalyst/parser/AstBuilder.scala | 12 +- .../analyzer-results/pipe-operators.sql.out | 352 ++++++++++++++++++ .../sql-tests/inputs/pipe-operators.sql | 141 +++++++ .../sql-tests/results/pipe-operators.sql.out | 309 +++++++++++++++ .../sql/execution/SparkSqlParserSuite.scala | 39 +- 6 files changed, 849 insertions(+), 9 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 866634b041280..33ac3249eb663 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1499,6 +1499,11 @@ version operatorPipeRightSide : selectClause | whereClause + // The following two cases match the PIVOT or UNPIVOT clause, respectively. + // For each one, we add the other clause as an option in order to return high-quality error + // messages in the event that both are present (this is not allowed). + | pivotClause unpivotClause? + | unpivotClause pivotClause? ; // When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index ed6cf329eeca8..e2350474a8708 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5893,7 +5893,17 @@ class AstBuilder extends DataTypeAstBuilder SubqueryAlias(SubqueryAlias.generateSubqueryName(), left) } withWhereClause(c, withSubqueryAlias) - }.get) + }.getOrElse(Option(ctx.pivotClause()).map { c => + if (ctx.unpivotClause() != null) { + throw QueryParsingErrors.unpivotWithPivotInFromClauseNotAllowedError(ctx) + } + withPivot(c, left) + }.getOrElse(Option(ctx.unpivotClause()).map { c => + if (ctx.pivotClause() != null) { + throw QueryParsingErrors.unpivotWithPivotInFromClauseNotAllowedError(ctx) + } + withUnpivot(c, left) + }.get))) } /** diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out index c44ce153a2f41..8cd062aeb01a3 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out @@ -62,6 +62,74 @@ InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_d +- LocalRelation [col1#x, col2#x] +-- !query +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings) +-- !query analysis +CreateViewCommand `courseSales`, select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings), false, false, LocalTempView, UNSUPPORTED, true + +- Project [course#x, year#x, earnings#x] + +- SubqueryAlias courseSales + +- LocalRelation [course#x, year#x, earnings#x] + + +-- !query +create temporary view courseEarnings as select * from values + ("dotNET", 15000, 
48000, 22500), + ("Java", 20000, 30000, NULL) + as courseEarnings(course, `2012`, `2013`, `2014`) +-- !query analysis +CreateViewCommand `courseEarnings`, select * from values + ("dotNET", 15000, 48000, 22500), + ("Java", 20000, 30000, NULL) + as courseEarnings(course, `2012`, `2013`, `2014`), false, false, LocalTempView, UNSUPPORTED, true + +- Project [course#x, 2012#x, 2013#x, 2014#x] + +- SubqueryAlias courseEarnings + +- LocalRelation [course#x, 2012#x, 2013#x, 2014#x] + + +-- !query +create temporary view courseEarningsAndSales as select * from values + ("dotNET", 15000, NULL, 48000, 1, 22500, 1), + ("Java", 20000, 1, 30000, 2, NULL, NULL) + as courseEarningsAndSales( + course, earnings2012, sales2012, earnings2013, sales2013, earnings2014, sales2014) +-- !query analysis +CreateViewCommand `courseEarningsAndSales`, select * from values + ("dotNET", 15000, NULL, 48000, 1, 22500, 1), + ("Java", 20000, 1, 30000, 2, NULL, NULL) + as courseEarningsAndSales( + course, earnings2012, sales2012, earnings2013, sales2013, earnings2014, sales2014), false, false, LocalTempView, UNSUPPORTED, true + +- Project [course#x, earnings2012#x, sales2012#x, earnings2013#x, sales2013#x, earnings2014#x, sales2014#x] + +- SubqueryAlias courseEarningsAndSales + +- LocalRelation [course#x, earnings2012#x, sales2012#x, earnings2013#x, sales2013#x, earnings2014#x, sales2014#x] + + +-- !query +create temporary view yearsWithComplexTypes as select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s) +-- !query analysis +CreateViewCommand `yearsWithComplexTypes`, select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s), false, false, LocalTempView, UNSUPPORTED, true + +- Project [y#x, a#x, m#x, s#x] + +- SubqueryAlias yearsWithComplexTypes + +- LocalRelation [y#x, a#x, m#x, s#x] + + -- !query table t |> select 1 as x @@ -569,6 +637,290 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +table courseSales +|> select `year`, course, earnings +|> pivot ( + sum(earnings) + for course in ('dotNET', 'Java') + ) +-- !query analysis +Project [year#x, __pivot_sum(coursesales.earnings) AS `sum(coursesales.earnings)`#x[0] AS dotNET#xL, __pivot_sum(coursesales.earnings) AS `sum(coursesales.earnings)`#x[1] AS Java#xL] ++- Aggregate [year#x], [year#x, pivotfirst(course#x, sum(coursesales.earnings)#xL, dotNET, Java, 0, 0) AS __pivot_sum(coursesales.earnings) AS `sum(coursesales.earnings)`#x] + +- Aggregate [year#x, course#x], [year#x, course#x, sum(earnings#x) AS sum(coursesales.earnings)#xL] + +- Project [year#x, course#x, earnings#x] + +- SubqueryAlias coursesales + +- View (`courseSales`, [course#x, year#x, earnings#x]) + +- Project [cast(course#x as string) AS course#x, cast(year#x as int) AS year#x, cast(earnings#x as int) AS earnings#x] + +- Project [course#x, year#x, earnings#x] + +- SubqueryAlias courseSales + +- LocalRelation [course#x, year#x, earnings#x] + + +-- !query +table courseSales +|> select `year` as y, course as c, earnings as e +|> pivot ( + sum(e) as s, avg(e) as a + for y in (2012 as firstYear, 2013 as secondYear) + ) +-- !query analysis +Project [c#x, __pivot_sum(e) AS s AS `sum(e) AS s`#x[0] AS firstYear_s#xL, __pivot_avg(e) AS a AS `avg(e) AS a`#x[0] AS firstYear_a#x, __pivot_sum(e) AS s AS `sum(e) AS s`#x[1] AS secondYear_s#xL, __pivot_avg(e) AS a AS `avg(e) AS a`#x[1] AS 
secondYear_a#x] ++- Aggregate [c#x], [c#x, pivotfirst(y#x, sum(e) AS s#xL, 2012, 2013, 0, 0) AS __pivot_sum(e) AS s AS `sum(e) AS s`#x, pivotfirst(y#x, avg(e) AS a#x, 2012, 2013, 0, 0) AS __pivot_avg(e) AS a AS `avg(e) AS a`#x] + +- Aggregate [c#x, y#x], [c#x, y#x, sum(e#x) AS sum(e) AS s#xL, avg(e#x) AS avg(e) AS a#x] + +- Project [pipeselect(year#x) AS y#x, pipeselect(course#x) AS c#x, pipeselect(earnings#x) AS e#x] + +- SubqueryAlias coursesales + +- View (`courseSales`, [course#x, year#x, earnings#x]) + +- Project [cast(course#x as string) AS course#x, cast(year#x as int) AS year#x, cast(earnings#x as int) AS earnings#x] + +- Project [course#x, year#x, earnings#x] + +- SubqueryAlias courseSales + +- LocalRelation [course#x, year#x, earnings#x] + + +-- !query +select course, `year`, y, a +from courseSales +join yearsWithComplexTypes on `year` = y +|> pivot ( + max(a) + for (y, course) in ((2012, 'dotNET'), (2013, 'Java')) + ) +-- !query analysis +Aggregate [year#x], [year#x, max(if ((named_struct(y, y#x, course, course#x) <=> cast(named_struct(col1, 2012, col2, dotNET) as struct))) a#x else cast(null as array)) AS {2012, dotNET}#x, max(if ((named_struct(y, y#x, course, course#x) <=> cast(named_struct(col1, 2013, col2, Java) as struct))) a#x else cast(null as array)) AS {2013, Java}#x] ++- Project [course#x, year#x, y#x, a#x] + +- Join Inner, (year#x = y#x) + :- SubqueryAlias coursesales + : +- View (`courseSales`, [course#x, year#x, earnings#x]) + : +- Project [cast(course#x as string) AS course#x, cast(year#x as int) AS year#x, cast(earnings#x as int) AS earnings#x] + : +- Project [course#x, year#x, earnings#x] + : +- SubqueryAlias courseSales + : +- LocalRelation [course#x, year#x, earnings#x] + +- SubqueryAlias yearswithcomplextypes + +- View (`yearsWithComplexTypes`, [y#x, a#x, m#x, s#x]) + +- Project [cast(y#x as int) AS y#x, cast(a#x as array) AS a#x, cast(m#x as map) AS m#x, cast(s#x as struct) AS s#x] + +- Project [y#x, a#x, m#x, s#x] + +- SubqueryAlias yearsWithComplexTypes + +- LocalRelation [y#x, a#x, m#x, s#x] + + +-- !query +select earnings, `year`, s +from courseSales +join yearsWithComplexTypes on `year` = y +|> pivot ( + sum(earnings) + for s in ((1, 'a'), (2, 'b')) + ) +-- !query analysis +Project [year#x, __pivot_sum(coursesales.earnings) AS `sum(coursesales.earnings)`#x[0] AS {1, a}#xL, __pivot_sum(coursesales.earnings) AS `sum(coursesales.earnings)`#x[1] AS {2, b}#xL] ++- Aggregate [year#x], [year#x, pivotfirst(s#x, sum(coursesales.earnings)#xL, [1,a], [2,b], 0, 0) AS __pivot_sum(coursesales.earnings) AS `sum(coursesales.earnings)`#x] + +- Aggregate [year#x, s#x], [year#x, s#x, sum(earnings#x) AS sum(coursesales.earnings)#xL] + +- Project [earnings#x, year#x, s#x] + +- Join Inner, (year#x = y#x) + :- SubqueryAlias coursesales + : +- View (`courseSales`, [course#x, year#x, earnings#x]) + : +- Project [cast(course#x as string) AS course#x, cast(year#x as int) AS year#x, cast(earnings#x as int) AS earnings#x] + : +- Project [course#x, year#x, earnings#x] + : +- SubqueryAlias courseSales + : +- LocalRelation [course#x, year#x, earnings#x] + +- SubqueryAlias yearswithcomplextypes + +- View (`yearsWithComplexTypes`, [y#x, a#x, m#x, s#x]) + +- Project [cast(y#x as int) AS y#x, cast(a#x as array) AS a#x, cast(m#x as map) AS m#x, cast(s#x as struct) AS s#x] + +- Project [y#x, a#x, m#x, s#x] + +- SubqueryAlias yearsWithComplexTypes + +- LocalRelation [y#x, a#x, m#x, s#x] + + +-- !query +table courseEarnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + 
) +-- !query analysis +Filter isnotnull(coalesce(earningsYear#x)) ++- Expand [[course#x, 2012, 2012#x], [course#x, 2013, 2013#x], [course#x, 2014, 2014#x]], [course#x, year#x, earningsYear#x] + +- SubqueryAlias courseearnings + +- View (`courseEarnings`, [course#x, 2012#x, 2013#x, 2014#x]) + +- Project [cast(course#x as string) AS course#x, cast(2012#x as int) AS 2012#x, cast(2013#x as int) AS 2013#x, cast(2014#x as int) AS 2014#x] + +- Project [course#x, 2012#x, 2013#x, 2014#x] + +- SubqueryAlias courseEarnings + +- LocalRelation [course#x, 2012#x, 2013#x, 2014#x] + + +-- !query +table courseEarnings +|> unpivot include nulls ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) +-- !query analysis +Expand [[course#x, 2012, 2012#x], [course#x, 2013, 2013#x], [course#x, 2014, 2014#x]], [course#x, year#x, earningsYear#x] ++- SubqueryAlias courseearnings + +- View (`courseEarnings`, [course#x, 2012#x, 2013#x, 2014#x]) + +- Project [cast(course#x as string) AS course#x, cast(2012#x as int) AS 2012#x, cast(2013#x as int) AS 2013#x, cast(2014#x as int) AS 2014#x] + +- Project [course#x, 2012#x, 2013#x, 2014#x] + +- SubqueryAlias courseEarnings + +- LocalRelation [course#x, 2012#x, 2013#x, 2014#x] + + +-- !query +table courseEarningsAndSales +|> unpivot include nulls ( + (earnings, sales) for `year` in ( + (earnings2012, sales2012) as `2012`, + (earnings2013, sales2013) as `2013`, + (earnings2014, sales2014) as `2014`) + ) +-- !query analysis +Expand [[course#x, 2012, earnings2012#x, sales2012#x], [course#x, 2013, earnings2013#x, sales2013#x], [course#x, 2014, earnings2014#x, sales2014#x]], [course#x, year#x, earnings#x, sales#x] ++- SubqueryAlias courseearningsandsales + +- View (`courseEarningsAndSales`, [course#x, earnings2012#x, sales2012#x, earnings2013#x, sales2013#x, earnings2014#x, sales2014#x]) + +- Project [cast(course#x as string) AS course#x, cast(earnings2012#x as int) AS earnings2012#x, cast(sales2012#x as int) AS sales2012#x, cast(earnings2013#x as int) AS earnings2013#x, cast(sales2013#x as int) AS sales2013#x, cast(earnings2014#x as int) AS earnings2014#x, cast(sales2014#x as int) AS sales2014#x] + +- Project [course#x, earnings2012#x, sales2012#x, earnings2013#x, sales2013#x, earnings2014#x, sales2014#x] + +- SubqueryAlias courseEarningsAndSales + +- LocalRelation [course#x, earnings2012#x, sales2012#x, earnings2013#x, sales2013#x, earnings2014#x, sales2014#x] + + +-- !query +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`year`", + "proposal" : "`course`, `earnings`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 49, + "stopIndex" : 111, + "fragment" : "pivot (\n sum(earnings)\n for `year` in (2012, 2013)\n )" + } ] +} + + +-- !query +table courseSales +|> pivot ( + sum(earnings) + for `year` in (course, 2013) + ) +-- !query analysis +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "NON_LITERAL_PIVOT_VALUES", + "sqlState" : "42K08", + "messageParameters" : { + "expression" : "\"course\"" + } +} + + +-- !query +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) + unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + 
"errorClass" : "NOT_ALLOWED_IN_FROM.UNPIVOT_WITH_PIVOT", + "sqlState" : "42601", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 186, + "fragment" : "table courseSales\n|> select course, earnings\n|> pivot (\n sum(earnings)\n for `year` in (2012, 2013)\n )\n unpivot (\n earningsYear for `year` in (`2012`, `2013`, `2014`)\n )" + } ] +} + + +-- !query +table courseSales +|> select course, earnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NOT_ALLOWED_IN_FROM.UNPIVOT_WITH_PIVOT", + "sqlState" : "42601", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 186, + "fragment" : "table courseSales\n|> select course, earnings\n|> unpivot (\n earningsYear for `year` in (`2012`, `2013`, `2014`)\n )\n pivot (\n sum(earnings)\n for `year` in (2012, 2013)\n )" + } ] +} + + +-- !query +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'pivot'", + "hint" : "" + } +} + + +-- !query +table courseSales +|> select course, earnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'unpivot'", + "hint" : "" + } +} + + -- !query drop table t -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql index 49a72137ee047..3aa01d472e83f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql @@ -12,6 +12,30 @@ drop table if exists st; create table st(x int, col struct) using parquet; insert into st values (1, (2, 3)); +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings); + +create temporary view courseEarnings as select * from values + ("dotNET", 15000, 48000, 22500), + ("Java", 20000, 30000, NULL) + as courseEarnings(course, `2012`, `2013`, `2014`); + +create temporary view courseEarningsAndSales as select * from values + ("dotNET", 15000, NULL, 48000, 1, 22500, 1), + ("Java", 20000, 1, 30000, 2, NULL, NULL) + as courseEarningsAndSales( + course, earnings2012, sales2012, earnings2013, sales2013, earnings2014, sales2014); + +create temporary view yearsWithComplexTypes as select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s); + -- SELECT operators: positive tests. 
--------------------------------------- @@ -185,6 +209,123 @@ table t (select x, sum(length(y)) as sum_len from t group by x) |> where sum(length(y)) = 3; +-- Pivot and unpivot operators: positive tests. +----------------------------------------------- + +table courseSales +|> select `year`, course, earnings +|> pivot ( + sum(earnings) + for course in ('dotNET', 'Java') + ); + +table courseSales +|> select `year` as y, course as c, earnings as e +|> pivot ( + sum(e) as s, avg(e) as a + for y in (2012 as firstYear, 2013 as secondYear) + ); + +-- Pivot on multiple pivot columns with aggregate columns of complex data types. +select course, `year`, y, a +from courseSales +join yearsWithComplexTypes on `year` = y +|> pivot ( + max(a) + for (y, course) in ((2012, 'dotNET'), (2013, 'Java')) + ); + +-- Pivot on pivot column of struct type. +select earnings, `year`, s +from courseSales +join yearsWithComplexTypes on `year` = y +|> pivot ( + sum(earnings) + for s in ((1, 'a'), (2, 'b')) + ); + +table courseEarnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ); + +table courseEarnings +|> unpivot include nulls ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ); + +table courseEarningsAndSales +|> unpivot include nulls ( + (earnings, sales) for `year` in ( + (earnings2012, sales2012) as `2012`, + (earnings2013, sales2013) as `2013`, + (earnings2014, sales2014) as `2014`) + ); + +-- Pivot and unpivot operators: negative tests. +----------------------------------------------- + +-- The PIVOT operator refers to a column 'year' is not available in the input relation. +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ); + +-- Non-literal PIVOT values are not supported. +table courseSales +|> pivot ( + sum(earnings) + for `year` in (course, 2013) + ); + +-- The PIVOT and UNPIVOT clauses are mutually exclusive. +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) + unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ); + +table courseSales +|> select course, earnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ); + +-- Multiple PIVOT and/or UNPIVOT clauses are not supported in the same pipe operator. +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ); + +table courseSales +|> select course, earnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ); + -- Cleanup. 
----------- drop table t; diff --git a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out index 38436b0941034..2c6abe2a277ad 100644 --- a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out @@ -71,6 +71,54 @@ struct<> +-- !query +create temporary view courseSales as select * from values + ("dotNET", 2012, 10000), + ("Java", 2012, 20000), + ("dotNET", 2012, 5000), + ("dotNET", 2013, 48000), + ("Java", 2013, 30000) + as courseSales(course, year, earnings) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view courseEarnings as select * from values + ("dotNET", 15000, 48000, 22500), + ("Java", 20000, 30000, NULL) + as courseEarnings(course, `2012`, `2013`, `2014`) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view courseEarningsAndSales as select * from values + ("dotNET", 15000, NULL, 48000, 1, 22500, 1), + ("Java", 20000, 1, 30000, 2, NULL, NULL) + as courseEarningsAndSales( + course, earnings2012, sales2012, earnings2013, sales2013, earnings2014, sales2014) +-- !query schema +struct<> +-- !query output + + + +-- !query +create temporary view yearsWithComplexTypes as select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s) +-- !query schema +struct<> +-- !query output + + + -- !query table t |> select 1 as x @@ -552,6 +600,267 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +table courseSales +|> select `year`, course, earnings +|> pivot ( + sum(earnings) + for course in ('dotNET', 'Java') + ) +-- !query schema +struct +-- !query output +2012 15000 20000 +2013 48000 30000 + + +-- !query +table courseSales +|> select `year` as y, course as c, earnings as e +|> pivot ( + sum(e) as s, avg(e) as a + for y in (2012 as firstYear, 2013 as secondYear) + ) +-- !query schema +struct +-- !query output +Java 20000 20000.0 30000 30000.0 +dotNET 15000 7500.0 48000 48000.0 + + +-- !query +select course, `year`, y, a +from courseSales +join yearsWithComplexTypes on `year` = y +|> pivot ( + max(a) + for (y, course) in ((2012, 'dotNET'), (2013, 'Java')) + ) +-- !query schema +struct,{2013, Java}:array> +-- !query output +2012 [1,1] NULL +2013 NULL [2,2] + + +-- !query +select earnings, `year`, s +from courseSales +join yearsWithComplexTypes on `year` = y +|> pivot ( + sum(earnings) + for s in ((1, 'a'), (2, 'b')) + ) +-- !query schema +struct +-- !query output +2012 35000 NULL +2013 NULL 78000 + + +-- !query +table courseEarnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) +-- !query schema +struct +-- !query output +Java 2012 20000 +Java 2013 30000 +dotNET 2012 15000 +dotNET 2013 48000 +dotNET 2014 22500 + + +-- !query +table courseEarnings +|> unpivot include nulls ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) +-- !query schema +struct +-- !query output +Java 2012 20000 +Java 2013 30000 +Java 2014 NULL +dotNET 2012 15000 +dotNET 2013 48000 +dotNET 2014 22500 + + +-- !query +table courseEarningsAndSales +|> unpivot include nulls ( + (earnings, sales) for `year` in ( + (earnings2012, sales2012) as `2012`, + (earnings2013, sales2013) as `2013`, + (earnings2014, sales2014) as `2014`) + ) +-- !query schema +struct +-- !query output +Java 2012 20000 1 +Java 2013 30000 2 +Java 2014 NULL NULL +dotNET 
2012 15000 NULL +dotNET 2013 48000 1 +dotNET 2014 22500 1 + + +-- !query +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42703", + "messageParameters" : { + "objectName" : "`year`", + "proposal" : "`course`, `earnings`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 49, + "stopIndex" : 111, + "fragment" : "pivot (\n sum(earnings)\n for `year` in (2012, 2013)\n )" + } ] +} + + +-- !query +table courseSales +|> pivot ( + sum(earnings) + for `year` in (course, 2013) + ) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "NON_LITERAL_PIVOT_VALUES", + "sqlState" : "42K08", + "messageParameters" : { + "expression" : "\"course\"" + } +} + + +-- !query +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) + unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NOT_ALLOWED_IN_FROM.UNPIVOT_WITH_PIVOT", + "sqlState" : "42601", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 186, + "fragment" : "table courseSales\n|> select course, earnings\n|> pivot (\n sum(earnings)\n for `year` in (2012, 2013)\n )\n unpivot (\n earningsYear for `year` in (`2012`, `2013`, `2014`)\n )" + } ] +} + + +-- !query +table courseSales +|> select course, earnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "NOT_ALLOWED_IN_FROM.UNPIVOT_WITH_PIVOT", + "sqlState" : "42601", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 186, + "fragment" : "table courseSales\n|> select course, earnings\n|> unpivot (\n earningsYear for `year` in (`2012`, `2013`, `2014`)\n )\n pivot (\n sum(earnings)\n for `year` in (2012, 2013)\n )" + } ] +} + + +-- !query +table courseSales +|> select course, earnings +|> pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'pivot'", + "hint" : "" + } +} + + +-- !query +table courseSales +|> select course, earnings +|> unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + unpivot ( + earningsYear for `year` in (`2012`, `2013`, `2014`) + ) + pivot ( + sum(earnings) + for `year` in (2012, 2013) + ) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "PARSE_SYNTAX_ERROR", + "sqlState" : "42601", + "messageParameters" : { + "error" : "'unpivot'", + "hint" : "" + } +} + + -- !query drop table t -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index ab949c5a21e44..1111a65c6a526 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, UnresolvedAlias, Un import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Concat, GreaterThan, Literal, NullsFirst, SortOrder, UnresolvedWindowExpression, UnspecifiedFrame, WindowSpecDefinition, WindowSpecReference} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreePattern.{FILTER, LOCAL_RELATION, PROJECT, UNRESOLVED_RELATION} +import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.connector.catalog.TableCatalog import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTempViewUsing, RefreshResource} @@ -887,24 +887,47 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { // Basic selection. // Here we check that every parsed plan contains a projection and a source relation or // inline table. - def checkPipeSelect(query: String): Unit = { + def check(query: String, patterns: Seq[TreePattern]): Unit = { val plan: LogicalPlan = parser.parsePlan(query) - assert(plan.containsPattern(PROJECT)) + assert(patterns.exists(plan.containsPattern)) assert(plan.containsAnyPattern(UNRESOLVED_RELATION, LOCAL_RELATION)) } + def checkPipeSelect(query: String): Unit = check(query, Seq(PROJECT)) checkPipeSelect("TABLE t |> SELECT 1 AS X") checkPipeSelect("TABLE t |> SELECT 1 AS X, 2 AS Y |> SELECT X + Y AS Z") checkPipeSelect("VALUES (0), (1) tab(col) |> SELECT col * 2 AS result") // Basic WHERE operators. - def checkPipeWhere(query: String): Unit = { - val plan: LogicalPlan = parser.parsePlan(query) - assert(plan.containsPattern(FILTER)) - assert(plan.containsAnyPattern(UNRESOLVED_RELATION, LOCAL_RELATION)) - } + def checkPipeWhere(query: String): Unit = check(query, Seq(FILTER)) checkPipeWhere("TABLE t |> WHERE X = 1") checkPipeWhere("TABLE t |> SELECT X, LENGTH(Y) AS Z |> WHERE X + LENGTH(Y) < 4") checkPipeWhere("TABLE t |> WHERE X = 1 AND Y = 2 |> WHERE X + Y = 3") checkPipeWhere("VALUES (0), (1) tab(col) |> WHERE col < 1") + // PIVOT and UNPIVOT operations + def checkPivotUnpivot(query: String): Unit = check(query, Seq(PIVOT, UNPIVOT)) + checkPivotUnpivot( + """ + |SELECT * FROM VALUES + | ("dotNET", 2012, 10000), + | ("Java", 2012, 20000), + | ("dotNET", 2012, 5000), + | ("dotNET", 2013, 48000), + | ("Java", 2013, 30000) + | AS courseSales(course, year, earnings) + ||> PIVOT ( + | SUM(earnings) + | FOR course IN ('dotNET', 'Java') + |) + |""".stripMargin) + checkPivotUnpivot( + """ + |SELECT * FROM VALUES + | ("dotNET", 15000, 48000, 22500), + | ("Java", 20000, 30000, NULL) + | AS courseEarnings(course, `2012`, `2013`, `2014`) + ||> UNPIVOT ( + | earningsYear FOR year IN (`2012`, `2013`, `2014`) + |) + |""".stripMargin) } } } From a7fa2700e0f0f70ec6306f48a5bd137225029b80 Mon Sep 17 00:00:00 2001 From: Julek Sompolski Date: Mon, 30 Sep 2024 23:39:50 +0800 Subject: [PATCH 125/250] [SPARK-48196][SQL] Turn QueryExecution lazy val plans into LazyTry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? 
Currently, when evaluation of `lazy val` of some of the plans fails in QueryExecution, this `lazy val` remains not initialized, and another attempt will be made to initialize it the next time it's referenced. This leads to planning being performed multiple times, resulting in inefficiencies, and potential duplication of side effects, for example from ConvertToLocalRelation that can pull in UDFs with side effects. ### Why are the changes needed? Current behaviour leads to inefficiencies and subtle problems in accidental situations, for example when plans are accessed for logging purposes. ### Does this PR introduce _any_ user-facing change? Yes. This change would bring slight behaviour changes: Examples: ``` val df = a.join(b) spark.conf.set(“spark.sql.crossJoin.enabled”, “false”) try { df.collect() } catch { case _ => } spark.conf.set(“spark.sql.crossJoin.enabled”, “true”) df.collect() ``` This used to succeed, because the first time around the plan will not be initialized because it threw an error because of the cartprod, and the second time around it will try to initialize it again and pick up the new config. This will now fail, because the second execution will retrieve the error from the first time around instead of retrying. The old semantics is if plan evaluation fails, try again next time it's accessed and if plan evaluation ever succeeded, keep that plan. The new semantics is that if plan evaluation fails, it keeps that error and rethrows it next time the plan is accessed. A new QueryExecution object / new Dataset is needed to reset it. Spark 4.0 may be a good candidate for a slight change in this, to make sure that we don't re-execute the optimizer, and potential side effects of it. Note: These behaviour changes have already happened in Spark Connect mode, where the Dataset object is not reused across execution. This change makes Spark Classic and Spark Connect behave the same again. ### How was this patch tested? Existing tests shows no issues, except for the tests that exhibit the behaviour change described above. ### Was this patch authored or co-authored using generative AI tooling? Trivial code completion suggestions. Generated-by: Github Copilot Closes #48211 from juliuszsompolski/SPARK-48196-lazyplans. Authored-by: Julek Sompolski Signed-off-by: Wenchen Fan --- python/pyspark/sql/tests/test_udf.py | 3 +- .../spark/sql/execution/QueryExecution.scala | 63 ++++++++++++------- 2 files changed, 44 insertions(+), 22 deletions(-) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index 6f672b0ae5fb3..879329bd80c0b 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -237,11 +237,12 @@ def test_udf_in_join_condition(self): f = udf(lambda a, b: a == b, BooleanType()) # The udf uses attributes from both sides of join, so it is pulled out as Filter + # Cross join. 
- df = left.join(right, f("a", "b")) with self.sql_conf({"spark.sql.crossJoin.enabled": False}): + df = left.join(right, f("a", "b")) with self.assertRaisesRegex(AnalysisException, "Detected implicit cartesian product"): df.collect() with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + df = left.join(right, f("a", "b")) self.assertEqual(df.collect(), [Row(a=1, b=1)]) def test_udf_in_left_outer_join_condition(self): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 5c894eb7555b1..6ff2c5d4b9d32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -46,8 +46,8 @@ import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, WatermarkPropagator} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.util.{LazyTry, Utils} import org.apache.spark.util.ArrayImplicits._ -import org.apache.spark.util.Utils /** * The primary workflow for executing relational queries using Spark. Designed to allow easy @@ -86,7 +86,7 @@ class QueryExecution( } } - lazy val analyzed: LogicalPlan = { + private val lazyAnalyzed = LazyTry { val plan = executePhase(QueryPlanningTracker.ANALYSIS) { // We can't clone `logical` here, which will reset the `_analyzed` flag. sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker) @@ -95,12 +95,18 @@ class QueryExecution( plan } - lazy val commandExecuted: LogicalPlan = mode match { - case CommandExecutionMode.NON_ROOT => analyzed.mapChildren(eagerlyExecuteCommands) - case CommandExecutionMode.ALL => eagerlyExecuteCommands(analyzed) - case CommandExecutionMode.SKIP => analyzed + def analyzed: LogicalPlan = lazyAnalyzed.get + + private val lazyCommandExecuted = LazyTry { + mode match { + case CommandExecutionMode.NON_ROOT => analyzed.mapChildren(eagerlyExecuteCommands) + case CommandExecutionMode.ALL => eagerlyExecuteCommands(analyzed) + case CommandExecutionMode.SKIP => analyzed + } } + def commandExecuted: LogicalPlan = lazyCommandExecuted.get + private def commandExecutionName(command: Command): String = command match { case _: CreateTableAsSelect => "create" case _: ReplaceTableAsSelect => "replace" @@ -141,22 +147,28 @@ class QueryExecution( } } - // The plan that has been normalized by custom rules, so that it's more likely to hit cache. - lazy val normalized: LogicalPlan = { + private val lazyNormalized = LazyTry { QueryExecution.normalize(sparkSession, commandExecuted, Some(tracker)) } - lazy val withCachedData: LogicalPlan = sparkSession.withActive { - assertAnalyzed() - assertSupported() - // clone the plan to avoid sharing the plan instance between different stages like analyzing, - // optimizing and planning. - sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) + // The plan that has been normalized by custom rules, so that it's more likely to hit cache. + def normalized: LogicalPlan = lazyNormalized.get + + private val lazyWithCachedData = LazyTry { + sparkSession.withActive { + assertAnalyzed() + assertSupported() + // clone the plan to avoid sharing the plan instance between different stages like analyzing, + // optimizing and planning. 
+ sparkSession.sharedState.cacheManager.useCachedData(normalized.clone()) + } } + def withCachedData: LogicalPlan = lazyWithCachedData.get + def assertCommandExecuted(): Unit = commandExecuted - lazy val optimizedPlan: LogicalPlan = { + private val lazyOptimizedPlan = LazyTry { // We need to materialize the commandExecuted here because optimizedPlan is also tracked under // the optimizing phase assertCommandExecuted() @@ -174,9 +186,11 @@ class QueryExecution( } } + def optimizedPlan: LogicalPlan = lazyOptimizedPlan.get + def assertOptimized(): Unit = optimizedPlan - lazy val sparkPlan: SparkPlan = { + private val lazySparkPlan = LazyTry { // We need to materialize the optimizedPlan here because sparkPlan is also tracked under // the planning phase assertOptimized() @@ -187,11 +201,11 @@ class QueryExecution( } } + def sparkPlan: SparkPlan = lazySparkPlan.get + def assertSparkPlanPrepared(): Unit = sparkPlan - // executedPlan should not be used to initialize any SparkPlan. It should be - // only used for execution. - lazy val executedPlan: SparkPlan = { + private val lazyExecutedPlan = LazyTry { // We need to materialize the optimizedPlan here, before tracking the planning phase, to ensure // that the optimization time is not counted as part of the planning phase. assertOptimized() @@ -206,8 +220,16 @@ class QueryExecution( plan } + // executedPlan should not be used to initialize any SparkPlan. It should be + // only used for execution. + def executedPlan: SparkPlan = lazyExecutedPlan.get + def assertExecutedPlanPrepared(): Unit = executedPlan + val lazyToRdd = LazyTry { + new SQLExecutionRDD(executedPlan.execute(), sparkSession.sessionState.conf) + } + /** * Internal version of the RDD. Avoids copies and has no schema. * Note for callers: Spark may apply various optimization including reusing object: this means @@ -218,8 +240,7 @@ class QueryExecution( * Given QueryExecution is not a public class, end users are discouraged to use this: please * use `Dataset.rdd` instead where conversion will be applied. */ - lazy val toRdd: RDD[InternalRow] = new SQLExecutionRDD( - executedPlan.execute(), sparkSession.sessionState.conf) + def toRdd: RDD[InternalRow] = lazyToRdd.get /** Get the metrics observed during the execution of the query plan. */ def observedMetrics: Map[String, Row] = CollectMetricsExec.collect(executedPlan) From d68048b06a046cc67ff431fdd8a687b0a1f43603 Mon Sep 17 00:00:00 2001 From: prathit06 Date: Mon, 30 Sep 2024 14:26:01 -0700 Subject: [PATCH 126/250] [SPARK-49833][K8S] Support user-defined annotations for OnDemand PVCs ### What changes were proposed in this pull request? Currently for on-demand PVCs we cannot add user-defined annotations, user-defined annotations can greatly help to add tags in underlying storage. For e.g. if we add `k8s-pvc-tagger/tags` annotation & provide a map like {"env":"dev"}, the same tags are reflected on underlying storage (for e.g. AWS EBS) ### Why are the changes needed? Changes are needed so users can set custom annotations to PVCs ### Does this PR introduce _any_ user-facing change? It does not break any existing behaviour but adds a new feature/improvement to enable custom annotations additions to ondemand PVCs ### How was this patch tested? This was tested in internal/production k8 cluster ### Was this patch authored or co-authored using generative AI tooling? No Closes #48299 from prathit06/ondemand-pvc-annotations. 
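Editor's note: as an illustration of the new configuration surface (not part of the patch itself), the sketch below sets the documented `...volumes.persistentVolumeClaim.[VolumeName].annotation.[AnnotationName]` key programmatically, e.g. from `spark-shell`. The volume name `checkpointpvc`, the `gp3` storage class and the `10Gi` size are placeholders; the annotation key/value mirror the `k8s-pvc-tagger/tags` example above. The same keys exist under the `spark.kubernetes.executor.volumes.*` prefix.

```scala
import org.apache.spark.SparkConf

// Minimal sketch: request an on-demand PVC for the driver and attach a
// user-defined annotation that a storage tagger can pick up.
val conf = new SparkConf()
  .set("spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.mount.path",
    "/checkpoints")
  .set("spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.claimName",
    "OnDemand")
  .set("spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.storageClass",
    "gp3")
  .set("spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.options.sizeLimit",
    "10Gi")
  // New in this patch: copied verbatim into the metadata of the generated PVC.
  .set("spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.annotation.k8s-pvc-tagger/tags",
    """{"env":"dev"}""")
```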
Authored-by: prathit06 Signed-off-by: Dongjoon Hyun --- docs/running-on-kubernetes.md | 18 +++++ .../org/apache/spark/deploy/k8s/Config.scala | 1 + .../deploy/k8s/KubernetesVolumeSpec.scala | 3 +- .../deploy/k8s/KubernetesVolumeUtils.scala | 14 +++- .../features/MountVolumesFeatureStep.scala | 7 +- .../spark/deploy/k8s/KubernetesTestConf.scala | 8 +- .../k8s/KubernetesVolumeUtilsSuite.scala | 42 +++++++++- .../MountVolumesFeatureStepSuite.scala | 77 +++++++++++++++++++ 8 files changed, 160 insertions(+), 10 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index d8be32e047717..f8b935fd77f5c 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -1191,6 +1191,15 @@ See the [configuration page](configuration.html) for information on Spark config 4.0.0 + + spark.kubernetes.driver.volumes.[VolumeType].[VolumeName].annotation.[AnnotationName] + (none) + + Configure Kubernetes Volume annotations passed to the Kubernetes with AnnotationName as key having specified value, must conform with Kubernetes annotations format. For example, + spark.kubernetes.driver.volumes.persistentVolumeClaim.checkpointpvc.annotation.foo=bar. + + 4.0.0 + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].mount.path (none) @@ -1236,6 +1245,15 @@ See the [configuration page](configuration.html) for information on Spark config 4.0.0 + + spark.kubernetes.executor.volumes.[VolumeType].[VolumeName].annotation.[AnnotationName] + (none) + + Configure Kubernetes Volume annotations passed to the Kubernetes with AnnotationName as key having specified value, must conform with Kubernetes annotations format. For example, + spark.kubernetes.executor.volumes.persistentVolumeClaim.checkpointpvc.annotation.foo=bar. + + 4.0.0 + spark.kubernetes.local.dirs.tmpfs false diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 9c50f8ddb00cc..db7fc85976c2a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -779,6 +779,7 @@ private[spark] object Config extends Logging { val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit" val KUBERNETES_VOLUMES_OPTIONS_SERVER_KEY = "options.server" val KUBERNETES_VOLUMES_LABEL_KEY = "label." + val KUBERNETES_VOLUMES_ANNOTATION_KEY = "annotation." val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." 
val KUBERNETES_DNS_SUBDOMAIN_NAME_MAX_LENGTH = 253 diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala index b4fe414e3cde5..b7113a562fa06 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeSpec.scala @@ -25,7 +25,8 @@ private[spark] case class KubernetesPVCVolumeConf( claimName: String, storageClass: Option[String] = None, size: Option[String] = None, - labels: Option[Map[String, String]] = None) + labels: Option[Map[String, String]] = None, + annotations: Option[Map[String, String]] = None) extends KubernetesVolumeSpecificConf private[spark] case class KubernetesEmptyDirVolumeConf( diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala index 88bb998d88b7d..95821a909f351 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtils.scala @@ -47,6 +47,7 @@ object KubernetesVolumeUtils { val subPathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_SUBPATH_KEY" val subPathExprKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_MOUNT_SUBPATHEXPR_KEY" val labelKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_LABEL_KEY" + val annotationKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_ANNOTATION_KEY" verifyMutuallyExclusiveOptionKeys(properties, subPathKey, subPathExprKey) val volumeLabelsMap = properties @@ -54,6 +55,11 @@ object KubernetesVolumeUtils { .map { case (k, v) => k.replaceAll(labelKey, "") -> v } + val volumeAnnotationsMap = properties + .filter(_._1.startsWith(annotationKey)) + .map { + case (k, v) => k.replaceAll(annotationKey, "") -> v + } KubernetesVolumeSpec( volumeName = volumeName, @@ -62,7 +68,7 @@ object KubernetesVolumeUtils { mountSubPathExpr = properties.getOrElse(subPathExprKey, ""), mountReadOnly = properties.get(readOnlyKey).exists(_.toBoolean), volumeConf = parseVolumeSpecificConf(properties, - volumeType, volumeName, Option(volumeLabelsMap))) + volumeType, volumeName, Option(volumeLabelsMap), Option(volumeAnnotationsMap))) }.toSeq } @@ -86,7 +92,8 @@ object KubernetesVolumeUtils { options: Map[String, String], volumeType: String, volumeName: String, - labels: Option[Map[String, String]]): KubernetesVolumeSpecificConf = { + labels: Option[Map[String, String]], + annotations: Option[Map[String, String]]): KubernetesVolumeSpecificConf = { volumeType match { case KUBERNETES_VOLUMES_HOSTPATH_TYPE => val pathKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_PATH_KEY" @@ -107,7 +114,8 @@ object KubernetesVolumeUtils { options(claimNameKey), options.get(storageClassKey), options.get(sizeLimitKey), - labels) + labels, + annotations) case KUBERNETES_VOLUMES_EMPTYDIR_TYPE => val mediumKey = s"$volumeType.$volumeName.$KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY" diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala 
index eea4604010b21..3d89696f19fcc 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStep.scala @@ -74,7 +74,7 @@ private[spark] class MountVolumesFeatureStep(conf: KubernetesConf) new VolumeBuilder() .withHostPath(new HostPathVolumeSource(hostPath, volumeType)) - case KubernetesPVCVolumeConf(claimNameTemplate, storageClass, size, labels) => + case KubernetesPVCVolumeConf(claimNameTemplate, storageClass, size, labels, annotations) => val claimName = conf match { case c: KubernetesExecutorConf => claimNameTemplate @@ -91,12 +91,17 @@ private[spark] class MountVolumesFeatureStep(conf: KubernetesConf) case Some(customLabelsMap) => (customLabelsMap ++ defaultVolumeLabels).asJava case None => defaultVolumeLabels.asJava } + val volumeAnnotations = annotations match { + case Some(value) => value.asJava + case None => Map[String, String]().asJava + } additionalResources.append(new PersistentVolumeClaimBuilder() .withKind(PVC) .withApiVersion("v1") .withNewMetadata() .withName(claimName) .addToLabels(volumeLabels) + .addToAnnotations(volumeAnnotations) .endMetadata() .withNewSpec() .withStorageClassName(storageClass.get) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala index e0ddcd3d416f0..e5ed79718d733 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesTestConf.scala @@ -118,7 +118,7 @@ object KubernetesTestConf { KUBERNETES_VOLUMES_OPTIONS_PATH_KEY -> hostPath, KUBERNETES_VOLUMES_OPTIONS_TYPE_KEY -> volumeType)) - case KubernetesPVCVolumeConf(claimName, storageClass, sizeLimit, labels) => + case KubernetesPVCVolumeConf(claimName, storageClass, sizeLimit, labels, annotations) => val sconf = storageClass .map { s => (KUBERNETES_VOLUMES_OPTIONS_CLAIM_STORAGE_CLASS_KEY, s) }.toMap val lconf = sizeLimit.map { l => (KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY, l) }.toMap @@ -126,9 +126,13 @@ object KubernetesTestConf { case Some(value) => value.map { case(k, v) => s"label.$k" -> v } case None => Map() } + val aannotations = annotations match { + case Some(value) => value.map { case (k, v) => s"annotation.$k" -> v } + case None => Map() + } (KUBERNETES_VOLUMES_PVC_TYPE, Map(KUBERNETES_VOLUMES_OPTIONS_CLAIM_NAME_KEY -> claimName) ++ - sconf ++ lconf ++ llabels) + sconf ++ lconf ++ llabels ++ aannotations) case KubernetesEmptyDirVolumeConf(medium, sizeLimit) => val mconf = medium.map { m => (KUBERNETES_VOLUMES_OPTIONS_MEDIUM_KEY, m) }.toMap diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala index 1e62db725fb6e..3c57cba9a7ff0 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesVolumeUtilsSuite.scala @@ -96,7 +96,7 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { assert(volumeSpec.mountPath === "/path") assert(volumeSpec.mountReadOnly) 
assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === - KubernetesPVCVolumeConf("claimName", labels = Some(Map()))) + KubernetesPVCVolumeConf("claimName", labels = Some(Map()), annotations = Some(Map()))) } test("SPARK-49598: Parses persistentVolumeClaim volumes correctly with labels") { @@ -113,7 +113,8 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { assert(volumeSpec.mountReadOnly) assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === KubernetesPVCVolumeConf(claimName = "claimName", - labels = Some(Map("env" -> "test", "foo" -> "bar")))) + labels = Some(Map("env" -> "test", "foo" -> "bar")), + annotations = Some(Map()))) } test("SPARK-49598: Parses persistentVolumeClaim volumes & puts " + @@ -128,7 +129,8 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { assert(volumeSpec.mountPath === "/path") assert(volumeSpec.mountReadOnly) assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === - KubernetesPVCVolumeConf(claimName = "claimName", labels = Some(Map()))) + KubernetesPVCVolumeConf(claimName = "claimName", labels = Some(Map()), + annotations = Some(Map()))) } test("Parses emptyDir volumes correctly") { @@ -280,4 +282,38 @@ class KubernetesVolumeUtilsSuite extends SparkFunSuite { }.getMessage assert(m.contains("smaller than 1KiB. Missing units?")) } + + test("SPARK-49833: Parses persistentVolumeClaim volumes correctly with annotations") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true") + sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimName") + sparkConf.set("test.persistentVolumeClaim.volumeName.annotation.key1", "value1") + sparkConf.set("test.persistentVolumeClaim.volumeName.annotation.key2", "value2") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === + KubernetesPVCVolumeConf(claimName = "claimName", + labels = Some(Map()), + annotations = Some(Map("key1" -> "value1", "key2" -> "value2")))) + } + + test("SPARK-49833: Parses persistentVolumeClaim volumes & puts " + + "annotations as empty Map if not provided") { + val sparkConf = new SparkConf(false) + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.path", "/path") + sparkConf.set("test.persistentVolumeClaim.volumeName.mount.readOnly", "true") + sparkConf.set("test.persistentVolumeClaim.volumeName.options.claimName", "claimName") + + val volumeSpec = KubernetesVolumeUtils.parseVolumesWithPrefix(sparkConf, "test.").head + assert(volumeSpec.volumeName === "volumeName") + assert(volumeSpec.mountPath === "/path") + assert(volumeSpec.mountReadOnly) + assert(volumeSpec.volumeConf.asInstanceOf[KubernetesPVCVolumeConf] === + KubernetesPVCVolumeConf(claimName = "claimName", labels = Some(Map()), + annotations = Some(Map()))) + } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala index c94a7a6ec26a7..293773ddb9ec5 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala +++ 
b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountVolumesFeatureStepSuite.scala @@ -496,4 +496,81 @@ class MountVolumesFeatureStepSuite extends SparkFunSuite { assert(mounts(1).getMountPath === "/tmp/bar") assert(mounts(1).getSubPath === "bar") } + + test("SPARK-49833: Create and mounts persistentVolumeClaims in driver with annotations") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "", + "", + true, + KubernetesPVCVolumeConf(claimName = MountVolumesFeatureStep.PVC_ON_DEMAND, + storageClass = Some("gp3"), + size = Some("1Mi"), + annotations = Some(Map("env" -> "test"))) + ) + + val kubernetesConf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeConf)) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + assert(configuredPod.pod.getSpec.getVolumes.size() === 1) + val pvcClaim = configuredPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(pvcClaim.getClaimName.endsWith("-driver-pvc-0")) + } + + test("SPARK-49833: Create and mounts persistentVolumeClaims in executors with annotations") { + val volumeConf = KubernetesVolumeSpec( + "testVolume", + "/tmp", + "", + "", + true, + KubernetesPVCVolumeConf(claimName = MountVolumesFeatureStep.PVC_ON_DEMAND, + storageClass = Some("gp3"), + size = Some("1Mi"), + annotations = Some(Map("env" -> "exec-test"))) + ) + + val executorConf = KubernetesTestConf.createExecutorConf(volumes = Seq(volumeConf)) + val executorStep = new MountVolumesFeatureStep(executorConf) + val executorPod = executorStep.configurePod(SparkPod.initialPod()) + + assert(executorPod.pod.getSpec.getVolumes.size() === 1) + val executorPVC = executorPod.pod.getSpec.getVolumes.get(0).getPersistentVolumeClaim + assert(executorPVC.getClaimName.endsWith("-exec-1-pvc-0")) + } + + test("SPARK-49833: Mount multiple volumes to executor with annotations") { + val pvcVolumeConf1 = KubernetesVolumeSpec( + "checkpointVolume1", + "/checkpoints1", + "", + "", + true, + KubernetesPVCVolumeConf(claimName = "pvcClaim1", + storageClass = Some("gp3"), + size = Some("1Mi"), + annotations = Some(Map("env1" -> "exec-test-1"))) + ) + + val pvcVolumeConf2 = KubernetesVolumeSpec( + "checkpointVolume2", + "/checkpoints2", + "", + "", + true, + KubernetesPVCVolumeConf(claimName = "pvcClaim2", + storageClass = Some("gp3"), + size = Some("1Mi"), + annotations = Some(Map("env2" -> "exec-test-2"))) + ) + + val kubernetesConf = KubernetesTestConf.createExecutorConf( + volumes = Seq(pvcVolumeConf1, pvcVolumeConf2)) + val step = new MountVolumesFeatureStep(kubernetesConf) + val configuredPod = step.configurePod(SparkPod.initialPod()) + + assert(configuredPod.pod.getSpec.getVolumes.size() === 2) + assert(configuredPod.container.getVolumeMounts.size() === 2) + } } From 123361137bbe4db4120111777091829c5abc807a Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Mon, 30 Sep 2024 15:56:58 -0700 Subject: [PATCH 127/250] [SPARK-49732][CORE][K8S] Spark deamons should respect `spark.log.structuredLogging.enabled` conf ### What changes were proposed in this pull request? Explicitly call `Logging.uninitialize()` after `SparkConf` loading `spark-defaults.conf` ### Why are the changes needed? 
SPARK-49015 fixes a similar issue that affects services started through `SparkSubmit`, while for other services like SHS, there is still a chance that the logging system is initialized before `SparkConf` constructed, so `spark.log.structuredLogging.enabled` configured at `spark-defaults.conf` won't take effect. The issue only happens when the logging system is initialized before `SparkConf` loading `spark-defaults.conf`. [example 1](https://github.com/apache/spark/pull/47500#issuecomment-2320426384), when `java.net.InetAddress.getLocalHost` returns `127.0.0.1`, ``` scala> java.net.InetAddress.getLocalHost res0: java.net.InetAddress = H27212-MAC-01.local/127.0.0.1 ``` the logging system will be initialized early. ``` {"ts":"2024-09-22T12:50:37.082Z","level":"WARN","msg":"Your hostname, H27212-MAC-01.local, resolves to a loopback address: 127.0.0.1; using 192.168.32.130 instead (on interface en0)","context":{"host":"H27212-MAC-01.local","host_port":"127.0.0.1","host_port2":"192.168.32.130","network_if":"en0"},"logger":"Utils"} {"ts":"2024-09-22T12:50:37.085Z","level":"WARN","msg":"Set SPARK_LOCAL_IP if you need to bind to another address","logger":"Utils"} ``` example 2: SHS calls `Utils.initDaemon(log)` before loading `spark-defaults.conf`(inside construction of `HistoryServerArguments`) https://github.com/apache/spark/blob/d2e8c1cb60e34a1c7e92374c07d682aa5ca79145/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala#L301-L302 ``` {"ts":"2024-09-22T13:20:31.978Z","level":"INFO","msg":"Started daemon with process name: 41505H27212-MAC-01.local","logger":"HistoryServer"} {"ts":"2024-09-22T13:20:31.980Z","level":"INFO","msg":"Registering signal handler for TERM","logger":"SignalUtils"} {"ts":"2024-09-22T13:20:31.981Z","level":"INFO","msg":"Registering signal handler for HUP","logger":"SignalUtils"} {"ts":"2024-09-22T13:20:31.981Z","level":"INFO","msg":"Registering signal handler for INT","logger":"SignalUtils"} ``` then loads `spark-defaults.conf` and ignores `spark.log.structuredLogging.enabled`. ### Does this PR introduce _any_ user-facing change? No, spark structured logging is an unreleased feature. ### How was this patch tested? 
Write `spark.log.structuredLogging.enabled=false` in `spark-defaults.conf` 4.0.0-preview2 ``` $ SPARK_NO_DAEMONIZE=1 sbin/start-history-server.sh starting org.apache.spark.deploy.history.HistoryServer, logging to /Users/chengpan/app/spark-4.0.0-preview2-bin-hadoop3/logs/spark-chengpan-org.apache.spark.deploy.history.HistoryServer-1-H27212-MAC-01.local.out Spark Command: /Users/chengpan/.sdkman/candidates/java/current/bin/java -cp /Users/chengpan/app/spark-4.0.0-preview2-bin-hadoop3/conf/:/Users/chengpan/app/spark-4.0.0-preview2-bin-hadoop3/jars/slf4j-api-2.0.16.jar:/Users/chengpan/app/spark-4.0.0-preview2-bin-hadoop3/jars/* -Xmx1g org.apache.spark.deploy.history.HistoryServer ======================================== Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties {"ts":"2024-09-22T12:50:37.082Z","level":"WARN","msg":"Your hostname, H27212-MAC-01.local, resolves to a loopback address: 127.0.0.1; using 192.168.32.130 instead (on interface en0)","context":{"host":"H27212-MAC-01.local","host_port":"127.0.0.1","host_port2":"192.168.32.130","network_if":"en0"},"logger":"Utils"} {"ts":"2024-09-22T12:50:37.085Z","level":"WARN","msg":"Set SPARK_LOCAL_IP if you need to bind to another address","logger":"Utils"} {"ts":"2024-09-22T12:50:37.109Z","level":"INFO","msg":"Started daemon with process name: 37764H27212-MAC-01.local","logger":"HistoryServer"} {"ts":"2024-09-22T12:50:37.112Z","level":"INFO","msg":"Registering signal handler for TERM","logger":"SignalUtils"} {"ts":"2024-09-22T12:50:37.112Z","level":"INFO","msg":"Registering signal handler for HUP","logger":"SignalUtils"} {"ts":"2024-09-22T12:50:37.112Z","level":"INFO","msg":"Registering signal handler for INT","logger":"SignalUtils"} {"ts":"2024-09-22T12:50:37.258Z","level":"WARN","msg":"Unable to load native-hadoop library for your platform... using builtin-java classes where applicable","logger":"NativeCodeLoader"} {"ts":"2024-09-22T12:50:37.275Z","level":"INFO","msg":"Changing view acls to: chengpan","logger":"SecurityManager"} {"ts":"2024-09-22T12:50:37.275Z","level":"INFO","msg":"Changing modify acls to: chengpan","logger":"SecurityManager"} {"ts":"2024-09-22T12:50:37.276Z","level":"INFO","msg":"Changing view acls groups to: chengpan","logger":"SecurityManager"} {"ts":"2024-09-22T12:50:37.276Z","level":"INFO","msg":"Changing modify acls groups to: chengpan","logger":"SecurityManager"} {"ts":"2024-09-22T12:50:37.277Z","level":"INFO","msg":"SecurityManager: authentication disabled; ui acls disabled; users with view permissions: chengpan groups with view permissions: EMPTY; users with modify permissions: chengpan; groups with modify permissions: EMPTY; RPC SSL disabled","logger":"SecurityManager"} {"ts":"2024-09-22T12:50:37.309Z","level":"INFO","msg":"History server ui acls disabled; users with admin permissions: ; groups with admin permissions: ","logger":"FsHistoryProvider"} {"ts":"2024-09-22T12:50:37.409Z","level":"INFO","msg":"Start Jetty 0.0.0.0:18080 for HistoryServerUI","logger":"JettyUtils"} {"ts":"2024-09-22T12:50:37.466Z","level":"INFO","msg":"Successfully started service 'HistoryServerUI' on port 18080.","logger":"Utils"} {"ts":"2024-09-22T12:50:37.491Z","level":"INFO","msg":"Bound HistoryServer to 0.0.0.0, and started at http://192.168.32.130:18080","logger":"HistoryServer"} ... 
``` This PR ``` $ SPARK_NO_DAEMONIZE=1 sbin/start-history-server.sh starting org.apache.spark.deploy.history.HistoryServer, logging to /Users/chengpan/Projects/apache-spark/dist/logs/spark-chengpan-org.apache.spark.deploy.history.HistoryServer-1-H27212-MAC-01.local.out Spark Command: /Users/chengpan/.sdkman/candidates/java/current/bin/java -cp /Users/chengpan/Projects/apache-spark/dist/conf/:/Users/chengpan/Projects/apache-spark/dist/jars/slf4j-api-2.0.16.jar:/Users/chengpan/Projects/apache-spark/dist/jars/* -Xmx1g org.apache.spark.deploy.history.HistoryServer ======================================== Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties {"ts":"2024-09-22T13:20:31.903Z","level":"WARN","msg":"Your hostname, H27212-MAC-01.local, resolves to a loopback address: 127.0.0.1; using 192.168.32.130 instead (on interface en0)","context":{"host":"H27212-MAC-01.local","host_port":"127.0.0.1","host_port2":"192.168.32.130","network_if":"en0"},"logger":"Utils"} {"ts":"2024-09-22T13:20:31.905Z","level":"WARN","msg":"Set SPARK_LOCAL_IP if you need to bind to another address","logger":"Utils"} {"ts":"2024-09-22T13:20:31.978Z","level":"INFO","msg":"Started daemon with process name: 41505H27212-MAC-01.local","logger":"HistoryServer"} {"ts":"2024-09-22T13:20:31.980Z","level":"INFO","msg":"Registering signal handler for TERM","logger":"SignalUtils"} {"ts":"2024-09-22T13:20:31.981Z","level":"INFO","msg":"Registering signal handler for HUP","logger":"SignalUtils"} {"ts":"2024-09-22T13:20:31.981Z","level":"INFO","msg":"Registering signal handler for INT","logger":"SignalUtils"} {"ts":"2024-09-22T13:20:32.136Z","level":"WARN","msg":"Unable to load native-hadoop library for your platform... using builtin-java classes where applicable","logger":"NativeCodeLoader"} Using Spark's default log4j profile: org/apache/spark/log4j2-pattern-layout-defaults.properties 24/09/22 21:20:32 INFO SecurityManager: Changing view acls to: chengpan 24/09/22 21:20:32 INFO SecurityManager: Changing modify acls to: chengpan 24/09/22 21:20:32 INFO SecurityManager: Changing view acls groups to: chengpan 24/09/22 21:20:32 INFO SecurityManager: Changing modify acls groups to: chengpan 24/09/22 21:20:32 INFO SecurityManager: SecurityManager: authentication disabled; ui acls disabled; users with view permissions: chengpan groups with view permissions: EMPTY; users with modify permissions: chengpan; groups with modify permissions: EMPTY; RPC SSL disabled 24/09/22 21:20:32 INFO FsHistoryProvider: History server ui acls disabled; users with admin permissions: ; groups with admin permissions: 24/09/22 21:20:32 INFO JettyUtils: Start Jetty 0.0.0.0:18080 for HistoryServerUI 24/09/22 21:20:32 INFO Utils: Successfully started service 'HistoryServerUI' on port 18080. 24/09/22 21:20:32 INFO HistoryServer: Bound HistoryServer to 0.0.0.0, and started at http://192.168.32.130:18080 ... ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48198 from pan3793/SPARK-49732. 
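Editor's note: a condensed sketch of the initialization order this patch enforces in each daemon entry point (History Server, Master, Worker, external shuffle service, executor backends). `Utils.resetStructuredLogging` and `Logging.uninitialize` are `private[spark]` helpers, so this only compiles from an `org.apache.spark` sub-package; the object name and the simplified argument handling are illustrative, not part of the patch.

```scala
package org.apache.spark.deploy

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils

object DaemonInitSketch {
  def loadConf(propertiesFile: String): SparkConf = {
    val conf = new SparkConf
    // spark-defaults.conf is only read here, possibly after something (for example an
    // early hostname warning from Utils) has already initialized the logging system
    // with the default structured layout.
    Utils.loadDefaultSparkProperties(conf, propertiesFile)
    // Re-evaluate spark.log.structuredLogging.enabled now that the defaults file is
    // loaded, then drop the existing logging context so the next log call
    // re-initializes it with the effective layout.
    Utils.resetStructuredLogging(conf)
    Logging.uninitialize()
    conf
  }
}
```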
Authored-by: Cheng Pan Signed-off-by: Dongjoon Hyun --- .../spark/deploy/ExternalShuffleService.scala | 3 +++ .../scala/org/apache/spark/deploy/SparkSubmit.scala | 6 +----- .../deploy/history/HistoryServerArguments.scala | 3 +++ .../spark/deploy/master/MasterArguments.scala | 3 +++ .../spark/deploy/worker/WorkerArguments.scala | 4 ++++ .../executor/CoarseGrainedExecutorBackend.scala | 4 ++++ .../main/scala/org/apache/spark/util/Utils.scala | 13 +++++++++++++ .../cluster/k8s/KubernetesExecutorBackend.scala | 4 ++++ 8 files changed, 35 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index f0dcf344ce0da..57b0647e59fd9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -169,6 +169,9 @@ object ExternalShuffleService extends Logging { Utils.initDaemon(log) val sparkConf = new SparkConf Utils.loadDefaultSparkProperties(sparkConf) + // Initialize logging system again after `spark.log.structuredLogging.enabled` takes effect + Utils.resetStructuredLogging(sparkConf) + Logging.uninitialize() val securityManager = new SecurityManager(sparkConf) // we override this value since this service is started from the command line diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index f3833e85a482e..85ed441d58fd1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -79,11 +79,7 @@ private[spark] class SparkSubmit extends Logging { } else { // For non-shell applications, enable structured logging if it's not explicitly disabled // via the configuration `spark.log.structuredLogging.enabled`. 
- if (sparkConf.getBoolean(STRUCTURED_LOGGING_ENABLED.key, defaultValue = true)) { - Logging.enableStructuredLogging() - } else { - Logging.disableStructuredLogging() - } + Utils.resetStructuredLogging(sparkConf) } // We should initialize log again after `spark.log.structuredLogging.enabled` effected diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 2fdf7a473a298..f1343a0551384 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -53,6 +53,9 @@ private[history] class HistoryServerArguments(conf: SparkConf, args: Array[Strin // This mutates the SparkConf, so all accesses to it must be made after this line Utils.loadDefaultSparkProperties(conf, propertiesFile) + // Initialize logging system again after `spark.log.structuredLogging.enabled` takes effect + Utils.resetStructuredLogging(conf) + Logging.uninitialize() // scalastyle:off line.size.limit println private def printUsageAndExit(exitCode: Int, error: String = ""): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 045a3da74dcd0..6647b11874d72 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -53,6 +53,9 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) exte // This mutates the SparkConf, so all accesses to it must be made after this line propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + // Initialize logging system again after `spark.log.structuredLogging.enabled` takes effect + Utils.resetStructuredLogging(conf) + Logging.uninitialize() if (conf.contains(MASTER_UI_PORT.key)) { webUiPort = conf.get(MASTER_UI_PORT) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 94a27e1a3e6da..f24cd59418300 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -22,6 +22,7 @@ import java.lang.management.ManagementFactory import scala.annotation.tailrec import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Worker._ import org.apache.spark.util.{IntParam, MemoryParam, Utils} @@ -59,6 +60,9 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { // This mutates the SparkConf, so all accesses to it must be made after this line propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + // Initialize logging system again after `spark.log.structuredLogging.enabled` takes effect + Utils.resetStructuredLogging(conf) + Logging.uninitialize() conf.get(WORKER_UI_PORT).foreach { webUiPort = _ } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index eaa07b9a81f5b..e880cf8da9ec2 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -468,6 
+468,10 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } } + // Initialize logging system again after `spark.log.structuredLogging.enabled` takes effect + Utils.resetStructuredLogging(driverConf) + Logging.uninitialize() + cfg.hadoopDelegationCreds.foreach { tokens => SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 52213f36a2cd1..5703128aacbb9 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2673,6 +2673,19 @@ private[spark] object Utils } } + /** + * Utility function to enable or disable structured logging based on SparkConf. + * This is designed for a code path which logging system may be initilized before + * loading SparkConf. + */ + def resetStructuredLogging(sparkConf: SparkConf): Unit = { + if (sparkConf.getBoolean(STRUCTURED_LOGGING_ENABLED.key, defaultValue = true)) { + Logging.enableStructuredLogging() + } else { + Logging.disableStructuredLogging() + } + } + /** * Return the jar files pointed by the "spark.jars" property. Spark internally will distribute * these jars through file server. In the YARN mode, it will return an empty list, since YARN diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBackend.scala index c515ae5e3a246..e44d7e29ef606 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBackend.scala @@ -116,6 +116,10 @@ private[spark] object KubernetesExecutorBackend extends Logging { } } + // Initialize logging system again after `spark.log.structuredLogging.enabled` takes effect + Utils.resetStructuredLogging(driverConf) + Logging.uninitialize() + cfg.hadoopDelegationCreds.foreach { tokens => SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) } From da106f86260b8138df7c5da5e05af9c801fc318d Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 30 Sep 2024 23:33:06 -0700 Subject: [PATCH 128/250] [SPARK-49840][INFRA] Use `MacOS 15` in `build_maven_java21_macos14.yml` ### What changes were proposed in this pull request? This PR aims to upgrade `MacOS` from `14` to `15` in `build_maven_java21_macos14.yml`. ### Why are the changes needed? To use the latest MacOS as a part of Apache Spark 4.0.0 preparation. - https://github.com/actions/runner-images/blob/main/images/macos/macos-15-arm64-Readme.md ### Does this PR introduce _any_ user-facing change? No. This is an infra change. ### How was this patch tested? N/A. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48305 from dongjoon-hyun/SPARK-49840. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- ...aven_java21_macos14.yml => build_maven_java21_macos15.yml} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename .github/workflows/{build_maven_java21_macos14.yml => build_maven_java21_macos15.yml} (92%) diff --git a/.github/workflows/build_maven_java21_macos14.yml b/.github/workflows/build_maven_java21_macos15.yml similarity index 92% rename from .github/workflows/build_maven_java21_macos14.yml rename to .github/workflows/build_maven_java21_macos15.yml index fb5e609f4eae0..cc6d0ea4e90da 100644 --- a/.github/workflows/build_maven_java21_macos14.yml +++ b/.github/workflows/build_maven_java21_macos15.yml @@ -17,7 +17,7 @@ # under the License. # -name: "Build / Maven (master, Scala 2.13, Hadoop 3, JDK 21, macos-14)" +name: "Build / Maven (master, Scala 2.13, Hadoop 3, JDK 21, MacOS-15)" on: schedule: @@ -32,7 +32,7 @@ jobs: if: github.repository == 'apache/spark' with: java: 21 - os: macos-14 + os: macos-15 envs: >- { "OBJC_DISABLE_INITIALIZE_FORK_SAFETY": "YES" From 8d0f6fb902219adfa5dd019a88c5ef4e8bf2ed7c Mon Sep 17 00:00:00 2001 From: panbingkun Date: Mon, 30 Sep 2024 23:36:39 -0700 Subject: [PATCH 129/250] [SPARK-49826][BUILD] Upgrade jackson to 2.18.0 ### What changes were proposed in this pull request? The pr aims to upgrade `jackson` from `2.17.2` to `2.18.0` ### Why are the changes needed? The full release notes: https://github.com/FasterXML/jackson/wiki/Jackson-Release-2.18.0 image ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48294 from panbingkun/SPARK-49826. Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 14 +++++++------- pom.xml | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 95a667ccfc72d..f6ce3d25ebc8a 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -105,16 +105,16 @@ ini4j/0.5.4//ini4j-0.5.4.jar istack-commons-runtime/3.0.8//istack-commons-runtime-3.0.8.jar ivy/2.5.2//ivy-2.5.2.jar j2objc-annotations/3.0.0//j2objc-annotations-3.0.0.jar -jackson-annotations/2.17.2//jackson-annotations-2.17.2.jar +jackson-annotations/2.18.0//jackson-annotations-2.18.0.jar jackson-core-asl/1.9.13//jackson-core-asl-1.9.13.jar -jackson-core/2.17.2//jackson-core-2.17.2.jar -jackson-databind/2.17.2//jackson-databind-2.17.2.jar -jackson-dataformat-cbor/2.17.2//jackson-dataformat-cbor-2.17.2.jar -jackson-dataformat-yaml/2.17.2//jackson-dataformat-yaml-2.17.2.jar +jackson-core/2.18.0//jackson-core-2.18.0.jar +jackson-databind/2.18.0//jackson-databind-2.18.0.jar +jackson-dataformat-cbor/2.18.0//jackson-dataformat-cbor-2.18.0.jar +jackson-dataformat-yaml/2.18.0//jackson-dataformat-yaml-2.18.0.jar jackson-datatype-jdk8/2.17.0//jackson-datatype-jdk8-2.17.0.jar -jackson-datatype-jsr310/2.17.2//jackson-datatype-jsr310-2.17.2.jar +jackson-datatype-jsr310/2.18.0//jackson-datatype-jsr310-2.18.0.jar jackson-mapper-asl/1.9.13//jackson-mapper-asl-1.9.13.jar -jackson-module-scala_2.13/2.17.2//jackson-module-scala_2.13-2.17.2.jar +jackson-module-scala_2.13/2.18.0//jackson-module-scala_2.13-2.18.0.jar jakarta.annotation-api/2.0.0//jakarta.annotation-api-2.0.0.jar jakarta.inject-api/2.0.1//jakarta.inject-api-2.0.1.jar jakarta.servlet-api/5.0.0//jakarta.servlet-api-5.0.0.jar diff --git a/pom.xml b/pom.xml index 
4bdb92d86a727..6a77da703dbd2 100644 --- a/pom.xml +++ b/pom.xml @@ -180,8 +180,8 @@ true true 1.9.13 - 2.17.2 - 2.17.2 + 2.18.0 + 2.18.0 2.3.1 1.1.10.7 3.0.3 From c0a1ea2a4c4218fc15b8f990ed2f5ea99755d322 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Mon, 30 Sep 2024 23:45:21 -0700 Subject: [PATCH 130/250] [SPARK-49795][CORE][SQL][SS][DSTREAM][ML][MLLIB][K8S][YARN][EXAMPLES] Clean up deprecated Guava API usage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? In order to clean up the usage of deprecated Guava API, the following changes were made in this pr: 1. Replaced `Files.write(from, to, charset)` with `Files.asCharSink(to, charset).write(from)`. This change was made with reference to: https://github.com/google/guava/blob/0c33dd12b193402cdf6962d43d69743521aa2f76/guava/src/com/google/common/io/Files.java#L275-L291 ```java /** * Writes a character sequence (such as a string) to a file using the given character set. * * param from the character sequence to write * param to the destination file * param charset the charset used to encode the output stream; see {link StandardCharsets} for * helpful predefined constants * throws IOException if an I/O error occurs * deprecated Prefer {code asCharSink(to, charset).write(from)}. */ Deprecated InlineMe( replacement = "Files.asCharSink(to, charset).write(from)", imports = "com.google.common.io.Files") public static void write(CharSequence from, File to, Charset charset) throws IOException { asCharSink(to, charset).write(from); } ``` 2. Replaced `Files.append(from, to, charset)` with `Files.asCharSink(to, charset, FileWriteMode.APPEND).write(from)`. This change was made with reference to: https://github.com/google/guava/blob/0c33dd12b193402cdf6962d43d69743521aa2f76/guava/src/com/google/common/io/Files.java#L350-L368 ```java /** * Appends a character sequence (such as a string) to a file using the given character set. * * param from the character sequence to append * param to the destination file * param charset the charset used to encode the output stream; see {link StandardCharsets} for * helpful predefined constants * throws IOException if an I/O error occurs * deprecated Prefer {code asCharSink(to, charset, FileWriteMode.APPEND).write(from)}. This * method is scheduled to be removed in October 2019. */ Deprecated InlineMe( replacement = "Files.asCharSink(to, charset, FileWriteMode.APPEND).write(from)", imports = {"com.google.common.io.FileWriteMode", "com.google.common.io.Files"}) public static void append(CharSequence from, File to, Charset charset) throws IOException { asCharSink(to, charset, FileWriteMode.APPEND).write(from); } ``` 3. Replaced `Files.toString(file, charset)` with `Files.asCharSource(file, charset).read()`. This change was made with reference to: https://github.com/google/guava/blob/0c33dd12b193402cdf6962d43d69743521aa2f76/guava/src/com/google/common/io/Files.java#L243-L259 ```java /** * Reads all characters from a file into a {link String}, using the given character set. * * param file the file to read from * param charset the charset used to decode the input stream; see {link StandardCharsets} for * helpful predefined constants * return a string containing all the characters from the file * throws IOException if an I/O error occurs * deprecated Prefer {code asCharSource(file, charset).read()}. 
*/ Deprecated InlineMe( replacement = "Files.asCharSource(file, charset).read()", imports = "com.google.common.io.Files") public static String toString(File file, Charset charset) throws IOException { return asCharSource(file, charset).read(); } ``` 4. Replaced `HashFunction.murmur3_32()` with `HashFunction.murmur3_32_fixed()`. This change was made with reference to: https://github.com/google/guava/blob/0c33dd12b193402cdf6962d43d69743521aa2f76/guava/src/com/google/common/hash/Hashing.java#L99-L115 ```java /** * Returns a hash function implementing the 32-bit murmur3 * algorithm, x86 variant (little-endian variant), using the given seed value, with a known * bug as described in the deprecation text. * *

    The C++ equivalent is the MurmurHash3_x86_32 function (Murmur3A), which however does not * have the bug. * * deprecated This implementation produces incorrect hash values from the {link * HashFunction#hashString} method if the string contains non-BMP characters. Use {link * #murmur3_32_fixed()} instead. */ Deprecated public static HashFunction murmur3_32() { return Murmur3_32HashFunction.MURMUR3_32; } ``` This change is safe for Spark. The difference between `MURMUR3_32` and `MURMUR3_32_FIXED` lies in the different `supplementaryPlaneFix` parameters passed when constructing the `Murmur3_32HashFunction`: https://github.com/google/guava/blob/0c33dd12b193402cdf6962d43d69743521aa2f76/guava/src/com/google/common/hash/Murmur3_32HashFunction.java#L56-L59 ```java static final HashFunction MURMUR3_32 = new Murmur3_32HashFunction(0, /* supplementaryPlaneFix= */ false); static final HashFunction MURMUR3_32_FIXED = new Murmur3_32HashFunction(0, /* supplementaryPlaneFix= */ true); ``` However, the `supplementaryPlaneFix` parameter is only used in `Murmur3_32HashFunction#hashString`, and Spark only utilizes `Murmur3_32HashFunction#hashInt`. Therefore, there will be no logical changes to this method after this change. https://github.com/google/guava/blob/0c33dd12b193402cdf6962d43d69743521aa2f76/guava/src/com/google/common/hash/Murmur3_32HashFunction.java#L108-L114 ```java Override public HashCode hashInt(int input) { int k1 = mixK1(input); int h1 = mixH1(seed, k1); return fmix(h1, Ints.BYTES); } ``` 5. Replaced `Throwables.propagateIfPossible(throwable, declaredType)` with `Throwables.throwIfInstanceOf(throwable, declaredType)` + `Throwables.throwIfUnchecked(throwable)`. This change was made with reference to: https://github.com/google/guava/blob/0c33dd12b193402cdf6962d43d69743521aa2f76/guava/src/com/google/common/base/Throwables.java#L156-L175 ``` /** * Propagates {code throwable} exactly as-is, if and only if it is an instance of {link * RuntimeException}, {link Error}, or {code declaredType}. * *

    Discouraged in favor of calling {link #throwIfInstanceOf} and {link * #throwIfUnchecked}. * * param throwable the Throwable to possibly propagate * param declaredType the single checked exception type declared by the calling method * deprecated Use a combination of {link #throwIfInstanceOf} and {link #throwIfUnchecked}, * which togther provide the same behavior except that they reject {code null}. */ Deprecated J2ktIncompatible GwtIncompatible // propagateIfInstanceOf public static void propagateIfPossible( CheckForNull Throwable throwable, Class declaredType) throws X { propagateIfInstanceOf(throwable, declaredType); propagateIfPossible(throwable); } ``` 6. Made modifications to `Throwables.propagate` with reference to https://github.com/google/guava/wiki/Why-we-deprecated-Throwables.propagate - For cases where it is known to be a checked exception, including `IOException`, `GeneralSecurityException`, `SaslException`, and `RocksDBException`, none of which are subclasses of `RuntimeException` or `Error`, directly replaced `Throwables.propagate(e)` with `throw new RuntimeException(e);`. - For cases where it cannot be determined whether it is a checked exception or an unchecked exception or Error, use ```java throwIfUnchecked(e); throw new RuntimeException(e); ``` to replace `Throwables.propagate(e)`。 https://github.com/google/guava/blob/0c33dd12b193402cdf6962d43d69743521aa2f76/guava/src/com/google/common/base/Throwables.java#L199-L235 ```java /** * ... * deprecated To preserve behavior, use {code throw e} or {code throw new RuntimeException(e)} * directly, or use a combination of {link #throwIfUnchecked} and {code throw new * RuntimeException(e)}. But consider whether users would be better off if your API threw a * different type of exception. For background on the deprecation, read Why we deprecated {code Throwables.propagate}. */ CanIgnoreReturnValue J2ktIncompatible GwtIncompatible Deprecated public static RuntimeException propagate(Throwable throwable) { throwIfUnchecked(throwable); throw new RuntimeException(throwable); } ``` ### Why are the changes needed? Clean up deprecated Guava API usage. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Pass GitHub Actions ### Was this patch authored or co-authored using generative AI tooling? No Closes #48248 from LuciferYang/guava-deprecation. 
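(Editor's note: taken together, the substitutions described above reduce to a handful of mechanical rewrites. The following sketch is not part of the patch; the temp file, strings, and the `riskyIo` placeholder are invented purely to show the before/after shapes in Scala.)

```scala
import java.io.File
import java.nio.charset.StandardCharsets.UTF_8

import com.google.common.base.Throwables
import com.google.common.io.{FileWriteMode, Files}

object GuavaMigrationSketch {
  def main(args: Array[String]): Unit = {
    val f = File.createTempFile("guava-migration", ".txt")
    f.deleteOnExit()

    // Old: Files.write("hello", f, UTF_8)
    Files.asCharSink(f, UTF_8).write("hello")

    // Old: Files.append(" world", f, UTF_8)
    Files.asCharSink(f, UTF_8, FileWriteMode.APPEND).write(" world")

    // Old: val text = Files.toString(f, UTF_8)
    val text = Files.asCharSource(f, UTF_8).read()
    assert(text == "hello world")

    // Old: catch { case e: Exception => throw Throwables.propagate(e) }
    // New: rethrow unchecked exceptions as-is, wrap checked ones explicitly.
    try {
      riskyIo()
    } catch {
      case e: Exception =>
        Throwables.throwIfUnchecked(e)
        throw new RuntimeException(e)
    }
  }

  // Placeholder for a call that may throw a checked exception.
  private def riskyIo(): Unit = ()
}
```

The `throwIfUnchecked` + `new RuntimeException(e)` pair matches the guidance in the linked Guava wiki page: `RuntimeException`s and `Error`s are rethrown unchanged, while checked exceptions are wrapped explicitly — exactly what `Throwables.propagate` used to do implicitly.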
Authored-by: yangjie01 Signed-off-by: Dongjoon Hyun --- .../apache/spark/util/kvstore/LevelDB.java | 3 +- .../spark/util/kvstore/LevelDBIterator.java | 5 +-- .../apache/spark/util/kvstore/RocksDB.java | 3 +- .../spark/util/kvstore/RocksDBIterator.java | 5 +-- .../spark/network/client/TransportClient.java | 6 ++-- .../client/TransportClientFactory.java | 3 +- .../network/crypto/AuthClientBootstrap.java | 3 +- .../spark/network/crypto/AuthRpcHandler.java | 3 +- .../spark/network/sasl/SparkSaslClient.java | 7 ++--- .../spark/network/sasl/SparkSaslServer.java | 5 ++- .../network/shuffledb/LevelDBIterator.java | 4 +-- .../spark/network/shuffledb/RocksDB.java | 7 ++--- .../network/shuffledb/RocksDBIterator.java | 3 +- .../spark/sql/kafka010/KafkaTestUtils.scala | 4 +-- .../apache/spark/io/ReadAheadInputStream.java | 3 +- .../scala/org/apache/spark/TestUtils.scala | 2 +- .../spark/deploy/worker/DriverRunner.scala | 4 +-- .../spark/deploy/worker/ExecutorRunner.scala | 2 +- .../spark/util/collection/AppendOnlyMap.scala | 2 +- .../spark/util/collection/OpenHashSet.scala | 2 +- .../test/org/apache/spark/JavaAPISuite.java | 2 +- .../scala/org/apache/spark/FileSuite.scala | 4 +-- .../org/apache/spark/SparkContextSuite.scala | 31 ++++++++++--------- .../history/EventLogFileReadersSuite.scala | 6 ++-- .../history/FsHistoryProviderSuite.scala | 3 +- .../history/HistoryServerArgumentsSuite.scala | 4 +-- .../deploy/history/HistoryServerSuite.scala | 2 +- .../plugin/PluginContainerSuite.scala | 2 +- .../ResourceDiscoveryPluginSuite.scala | 2 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 12 +++---- .../apache/spark/util/FileAppenderSuite.scala | 6 ++-- .../org/apache/spark/util/UtilsSuite.scala | 6 ++-- .../JavaRecoverableNetworkWordCount.java | 4 ++- .../RecoverableNetworkWordCount.scala | 5 +-- .../libsvm/JavaLibSVMRelationSuite.java | 2 +- .../source/libsvm/LibSVMRelationSuite.scala | 6 ++-- .../spark/mllib/util/MLUtilsSuite.scala | 6 ++-- .../k8s/SparkKubernetesClientFactory.scala | 2 +- .../HadoopConfDriverFeatureStep.scala | 2 +- .../KerberosConfDriverFeatureStep.scala | 2 +- .../features/PodTemplateConfigMapStep.scala | 2 +- ...ubernetesCredentialsFeatureStepSuite.scala | 2 +- .../HadoopConfDriverFeatureStepSuite.scala | 2 +- .../HadoopConfExecutorFeatureStepSuite.scala | 2 +- .../KerberosConfDriverFeatureStepSuite.scala | 4 +-- .../integrationtest/DecommissionSuite.scala | 6 ++-- .../k8s/integrationtest/KubernetesSuite.scala | 2 +- .../deploy/yarn/BaseYarnClusterSuite.scala | 6 ++-- .../spark/deploy/yarn/YarnClusterSuite.scala | 29 +++++++++-------- .../yarn/YarnShuffleIntegrationSuite.scala | 2 +- .../arrow/ArrowConvertersSuite.scala | 6 ++-- .../HiveThriftServer2Suites.scala | 6 ++-- .../hive/thriftserver/UISeleniumSuite.scala | 6 ++-- .../sql/hive/execution/SQLQuerySuite.scala | 17 +++++----- .../apache/spark/streaming/JavaAPISuite.java | 2 +- .../spark/streaming/CheckpointSuite.scala | 2 +- .../spark/streaming/InputStreamsSuite.scala | 10 +++--- .../spark/streaming/MasterFailureTest.scala | 2 +- 58 files changed, 148 insertions(+), 145 deletions(-) diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index 13a9d89f4705c..7f8d6c58aec7e 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -255,7 +255,8 @@ public Iterator iterator() { iteratorTracker.add(new WeakReference<>(it)); 
return it; } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } }; diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java index 69757fdc65d68..29ed37ffa44e5 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java @@ -127,7 +127,7 @@ public boolean hasNext() { try { close(); } catch (IOException ioe) { - throw Throwables.propagate(ioe); + throw new RuntimeException(ioe); } } return next != null; @@ -151,7 +151,8 @@ public T next() { next = null; return ret; } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java index dc7ad0be5c007..4bc2b233fe12d 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java @@ -287,7 +287,8 @@ public Iterator iterator() { iteratorTracker.add(new WeakReference<>(it)); return it; } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } }; diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java index a98b0482e35cc..e350ddc2d445a 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java @@ -113,7 +113,7 @@ public boolean hasNext() { try { close(); } catch (IOException ioe) { - throw Throwables.propagate(ioe); + throw new RuntimeException(ioe); } } return next != null; @@ -137,7 +137,8 @@ public T next() { next = null; return ret; } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 4c144a73a9299..a9df47645d36f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -290,9 +290,11 @@ public void onFailure(Throwable e) { try { return result.get(timeoutMs, TimeUnit.MILLISECONDS); } catch (ExecutionException e) { - throw Throwables.propagate(e.getCause()); + Throwables.throwIfUnchecked(e.getCause()); + throw new RuntimeException(e.getCause()); } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index e1f19f956cc0a..d64b8c8f838e9 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java 
@@ -342,7 +342,8 @@ public void operationComplete(final Future handshakeFuture) { logger.error("Exception while bootstrapping client after {} ms", e, MDC.of(LogKeys.BOOTSTRAP_TIME$.MODULE$, bootstrapTimeMs)); client.close(); - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } long postBootstrap = System.nanoTime(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java index 08e2c084fe67b..2e9ccd0e0ad21 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java @@ -22,7 +22,6 @@ import java.security.GeneralSecurityException; import java.util.concurrent.TimeoutException; -import com.google.common.base.Throwables; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.Channel; @@ -80,7 +79,7 @@ public void doBootstrap(TransportClient client, Channel channel) { doSparkAuth(client, channel); client.setClientId(appId); } catch (GeneralSecurityException | IOException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } catch (RuntimeException e) { // There isn't a good exception that can be caught here to know whether it's really // OK to switch back to SASL (because the server doesn't speak the new protocol). So diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 65367743e24f9..087e3d21e22bb 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -132,7 +132,8 @@ protected boolean doAuthChallenge( try { engine.close(); } catch (Exception e) { - throw Throwables.propagate(e); + Throwables.throwIfUnchecked(e); + throw new RuntimeException(e); } } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java index 3600c1045dbf4..a61b1c3c0c416 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -29,7 +29,6 @@ import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import org.apache.spark.internal.SparkLogger; @@ -62,7 +61,7 @@ public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, bool this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM, saslProps, new ClientCallbackHandler()); } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } @@ -72,7 +71,7 @@ public synchronized byte[] firstToken() { try { return saslClient.evaluateChallenge(new byte[0]); } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } else { return new byte[0]; @@ -98,7 +97,7 @@ public synchronized byte[] response(byte[] token) { try { return saslClient != null ? 
saslClient.evaluateChallenge(token) : new byte[0]; } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index b897650afe832..f32fd5145c7c5 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -31,7 +31,6 @@ import java.util.Map; import com.google.common.base.Preconditions; -import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -94,7 +93,7 @@ public SparkSaslServer( this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, saslProps, new DigestCallbackHandler()); } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } @@ -119,7 +118,7 @@ public synchronized byte[] response(byte[] token) { try { return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0]; } catch (SaslException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/LevelDBIterator.java b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/LevelDBIterator.java index 5796e34a6f05e..2ac549775449a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/LevelDBIterator.java +++ b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/LevelDBIterator.java @@ -17,8 +17,6 @@ package org.apache.spark.network.shuffledb; -import com.google.common.base.Throwables; - import java.io.IOException; import java.util.Map; import java.util.NoSuchElementException; @@ -47,7 +45,7 @@ public boolean hasNext() { try { close(); } catch (IOException ioe) { - throw Throwables.propagate(ioe); + throw new RuntimeException(ioe); } } return next != null; diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDB.java b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDB.java index d33895d6c2d62..2737ab8ed754c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDB.java +++ b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDB.java @@ -19,7 +19,6 @@ import java.io.IOException; -import com.google.common.base.Throwables; import org.rocksdb.RocksDBException; /** @@ -37,7 +36,7 @@ public void put(byte[] key, byte[] value) { try { db.put(key, value); } catch (RocksDBException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } @@ -46,7 +45,7 @@ public byte[] get(byte[] key) { try { return db.get(key); } catch (RocksDBException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } @@ -55,7 +54,7 @@ public void delete(byte[] key) { try { db.delete(key); } catch (RocksDBException e) { - throw Throwables.propagate(e); + throw new RuntimeException(e); } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDBIterator.java b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDBIterator.java index 78562f91a4b75..829a7ded6330b 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDBIterator.java +++ 
b/common/network-common/src/main/java/org/apache/spark/network/shuffledb/RocksDBIterator.java @@ -22,7 +22,6 @@ import java.util.Map; import java.util.NoSuchElementException; -import com.google.common.base.Throwables; import org.rocksdb.RocksIterator; /** @@ -52,7 +51,7 @@ public boolean hasNext() { try { close(); } catch (IOException ioe) { - throw Throwables.propagate(ioe); + throw new RuntimeException(ioe); } } return next != null; diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala index 7852bc814ccd4..c3f02eebab23a 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTestUtils.scala @@ -176,7 +176,7 @@ class KafkaTestUtils( } kdc.getKrb5conf.delete() - Files.write(krb5confStr, kdc.getKrb5conf, StandardCharsets.UTF_8) + Files.asCharSink(kdc.getKrb5conf, StandardCharsets.UTF_8).write(krb5confStr) logDebug(s"krb5.conf file content: $krb5confStr") } @@ -240,7 +240,7 @@ class KafkaTestUtils( | principal="$kafkaServerUser@$realm"; |}; """.stripMargin.trim - Files.write(content, file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(content) logDebug(s"Created JAAS file: ${file.getPath}") logDebug(s"JAAS file content: $content") file.getAbsolutePath() diff --git a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java index 5e9f1b78273a5..7dd87df713e6e 100644 --- a/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java +++ b/core/src/main/java/org/apache/spark/io/ReadAheadInputStream.java @@ -120,7 +120,8 @@ private boolean isEndOfStream() { private void checkReadException() throws IOException { if (readAborted) { - Throwables.propagateIfPossible(readException, IOException.class); + Throwables.throwIfInstanceOf(readException, IOException.class); + Throwables.throwIfUnchecked(readException); throw new IOException(readException); } } diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 5e3078d7292ba..fed15a067c00f 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -421,7 +421,7 @@ private[spark] object TestUtils extends SparkTestUtils { def createTempScriptWithExpectedOutput(dir: File, prefix: String, output: String): String = { val file = File.createTempFile(prefix, ".sh", dir) val script = s"cat < expected = Arrays.asList("1", "2", "3", "4"); diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 5651dc9b2dbdc..5f9912cbd021d 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -334,8 +334,8 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { for (i <- 0 until 8) { val tempFile = new File(tempDir, s"part-0000$i") - Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", tempFile, - StandardCharsets.UTF_8) + Files.asCharSink(tempFile, StandardCharsets.UTF_8) + .write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1") } for (p <- Seq(1, 2, 8)) { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala 
b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 12f9d2f83c777..44b2da603a1f6 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -119,8 +119,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu val absolutePath2 = file2.getAbsolutePath try { - Files.write("somewords1", file1, StandardCharsets.UTF_8) - Files.write("somewords2", file2, StandardCharsets.UTF_8) + Files.asCharSink(file1, StandardCharsets.UTF_8).write("somewords1") + Files.asCharSink(file2, StandardCharsets.UTF_8).write("somewords2") val length1 = file1.length() val length2 = file2.length() @@ -178,10 +178,10 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu s"${jarFile.getParent}/../${jarFile.getParentFile.getName}/${jarFile.getName}#zoo" try { - Files.write("somewords1", file1, StandardCharsets.UTF_8) - Files.write("somewords22", file2, StandardCharsets.UTF_8) - Files.write("somewords333", file3, StandardCharsets.UTF_8) - Files.write("somewords4444", file4, StandardCharsets.UTF_8) + Files.asCharSink(file1, StandardCharsets.UTF_8).write("somewords1") + Files.asCharSink(file2, StandardCharsets.UTF_8).write("somewords22") + Files.asCharSink(file3, StandardCharsets.UTF_8).write("somewords333") + Files.asCharSink(file4, StandardCharsets.UTF_8).write("somewords4444") val length1 = file1.length() val length2 = file2.length() val length3 = file1.length() @@ -373,8 +373,8 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu assert(subdir2.mkdir()) val file1 = new File(subdir1, "file") val file2 = new File(subdir2, "file") - Files.write("old", file1, StandardCharsets.UTF_8) - Files.write("new", file2, StandardCharsets.UTF_8) + Files.asCharSink(file1, StandardCharsets.UTF_8).write("old") + Files.asCharSink(file2, StandardCharsets.UTF_8).write("new") sc = new SparkContext("local-cluster[1,1,1024]", "test") sc.addFile(file1.getAbsolutePath) def getAddedFileContents(): String = { @@ -503,12 +503,15 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu try { // Create 5 text files. 
- Files.write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1", file1, - StandardCharsets.UTF_8) - Files.write("someline1 in file2\nsomeline2 in file2", file2, StandardCharsets.UTF_8) - Files.write("someline1 in file3", file3, StandardCharsets.UTF_8) - Files.write("someline1 in file4\nsomeline2 in file4", file4, StandardCharsets.UTF_8) - Files.write("someline1 in file2\nsomeline2 in file5", file5, StandardCharsets.UTF_8) + Files.asCharSink(file1, StandardCharsets.UTF_8) + .write("someline1 in file1\nsomeline2 in file1\nsomeline3 in file1") + Files.asCharSink(file2, StandardCharsets.UTF_8) + .write("someline1 in file2\nsomeline2 in file2") + Files.asCharSink(file3, StandardCharsets.UTF_8).write("someline1 in file3") + Files.asCharSink(file4, StandardCharsets.UTF_8) + .write("someline1 in file4\nsomeline2 in file4") + Files.asCharSink(file5, StandardCharsets.UTF_8) + .write("someline1 in file2\nsomeline2 in file5") sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala index f34f792881f90..7501a98a1a573 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileReadersSuite.scala @@ -221,7 +221,7 @@ class SingleFileEventLogFileReaderSuite extends EventLogFileReadersSuite { val entry = is.getNextEntry assert(entry != null) val actual = new String(ByteStreams.toByteArray(is), StandardCharsets.UTF_8) - val expected = Files.toString(new File(logPath.toString), StandardCharsets.UTF_8) + val expected = Files.asCharSource(new File(logPath.toString), StandardCharsets.UTF_8).read() assert(actual === expected) assert(is.getNextEntry === null) } @@ -368,8 +368,8 @@ class RollingEventLogFilesReaderSuite extends EventLogFileReadersSuite { assert(allFileNames.contains(fileName)) val actual = new String(ByteStreams.toByteArray(is), StandardCharsets.UTF_8) - val expected = Files.toString(new File(logPath.toString, fileName), - StandardCharsets.UTF_8) + val expected = Files.asCharSource( + new File(logPath.toString, fileName), StandardCharsets.UTF_8).read() assert(actual === expected) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 3013a5bf4a294..852f94bda870d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -708,7 +708,8 @@ abstract class FsHistoryProviderSuite extends SparkFunSuite with Matchers with P while (entry != null) { val actual = new String(ByteStreams.toByteArray(inputStream), StandardCharsets.UTF_8) val expected = - Files.toString(logs.find(_.getName == entry.getName).get, StandardCharsets.UTF_8) + Files.asCharSource(logs.find(_.getName == entry.getName).get, StandardCharsets.UTF_8) + .read() actual should be (expected) totalEntries += 1 entry = inputStream.getNextEntry diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala index 2b9b110a41424..807e5ec3e823e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerArgumentsSuite.scala @@ -45,8 +45,8 @@ class HistoryServerArgumentsSuite extends SparkFunSuite { test("Properties File Arguments Parsing --properties-file") { withTempDir { tmpDir => val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir) - Files.write("spark.test.CustomPropertyA blah\n" + - "spark.test.CustomPropertyB notblah\n", outFile, UTF_8) + Files.asCharSink(outFile, UTF_8).write("spark.test.CustomPropertyA blah\n" + + "spark.test.CustomPropertyB notblah\n") val argStrings = Array("--properties-file", outFile.getAbsolutePath) val hsa = new HistoryServerArguments(conf, argStrings) assert(conf.get("spark.test.CustomPropertyA") === "blah") diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index abb5ae720af07..6b2bd90cd4314 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -283,7 +283,7 @@ abstract class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with val expectedFile = { new File(logDir, entry.getName) } - val expected = Files.toString(expectedFile, StandardCharsets.UTF_8) + val expected = Files.asCharSource(expectedFile, StandardCharsets.UTF_8).read() val actual = new String(ByteStreams.toByteArray(zipStream), StandardCharsets.UTF_8) actual should be (expected) filesCompared += 1 diff --git a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala index 79fa8d21bf3f1..fc8f48df2cb7d 100644 --- a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala @@ -383,7 +383,7 @@ object NonLocalModeSparkPlugin { resources: Map[String, ResourceInformation]): Unit = { val path = conf.get(TEST_PATH_CONF) val strToWrite = createFileStringWithGpuAddrs(id, resources) - Files.write(strToWrite, new File(path, s"$filePrefix$id"), StandardCharsets.UTF_8) + Files.asCharSink(new File(path, s"$filePrefix$id"), StandardCharsets.UTF_8).write(strToWrite) } def reset(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala b/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala index ff7d680352177..edf138df9e207 100644 --- a/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala +++ b/core/src/test/scala/org/apache/spark/resource/ResourceDiscoveryPluginSuite.scala @@ -148,7 +148,7 @@ object TestResourceDiscoveryPlugin { def writeFile(conf: SparkConf, id: String): Unit = { val path = conf.get(TEST_PATH_CONF) val fileName = s"$id - ${UUID.randomUUID.toString}" - Files.write(id, new File(path, fileName), StandardCharsets.UTF_8) + Files.asCharSink(new File(path, fileName), StandardCharsets.UTF_8).write(id) } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 3ef382573517b..66b1ee7b58ac8 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -868,23 +868,23 @@ abstract class RpcEnvSuite extends SparkFunSuite { val conf = createSparkConf() val file = new File(tempDir, "file") - 
Files.write(UUID.randomUUID().toString(), file, UTF_8) + Files.asCharSink(file, UTF_8).write(UUID.randomUUID().toString) val fileWithSpecialChars = new File(tempDir, "file name") - Files.write(UUID.randomUUID().toString(), fileWithSpecialChars, UTF_8) + Files.asCharSink(fileWithSpecialChars, UTF_8).write(UUID.randomUUID().toString) val empty = new File(tempDir, "empty") - Files.write("", empty, UTF_8); + Files.asCharSink(empty, UTF_8).write("") val jar = new File(tempDir, "jar") - Files.write(UUID.randomUUID().toString(), jar, UTF_8) + Files.asCharSink(jar, UTF_8).write(UUID.randomUUID().toString) val dir1 = new File(tempDir, "dir1") assert(dir1.mkdir()) val subFile1 = new File(dir1, "file1") - Files.write(UUID.randomUUID().toString(), subFile1, UTF_8) + Files.asCharSink(subFile1, UTF_8).write(UUID.randomUUID().toString) val dir2 = new File(tempDir, "dir2") assert(dir2.mkdir()) val subFile2 = new File(dir2, "file2") - Files.write(UUID.randomUUID().toString(), subFile2, UTF_8) + Files.asCharSink(subFile2, UTF_8).write(UUID.randomUUID().toString) val fileUri = env.fileServer.addFile(file) val fileWithSpecialCharsUri = env.fileServer.addFile(fileWithSpecialChars) diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 35ef0587b9b4c..4497ea1b2b798 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -54,11 +54,11 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter { val inputStream = new ByteArrayInputStream(testString.getBytes(StandardCharsets.UTF_8)) // The `header` should not be covered val header = "Add header" - Files.write(header, testFile, StandardCharsets.UTF_8) + Files.asCharSink(testFile, StandardCharsets.UTF_8).write(header) val appender = new FileAppender(inputStream, testFile) inputStream.close() appender.awaitTermination() - assert(Files.toString(testFile, StandardCharsets.UTF_8) === header + testString) + assert(Files.asCharSource(testFile, StandardCharsets.UTF_8).read() === header + testString) } test("SPARK-35027: basic file appender - close stream") { @@ -392,7 +392,7 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter { IOUtils.closeQuietly(inputStream) } } else { - Files.toString(file, StandardCharsets.UTF_8) + Files.asCharSource(file, StandardCharsets.UTF_8).read() } }.mkString("") assert(allText === expectedText) diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index a694e08def89c..a6e3345fc600c 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -735,8 +735,8 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties { withTempDir { tmpDir => val outFile = File.createTempFile("test-load-spark-properties", "test", tmpDir) System.setProperty("spark.test.fileNameLoadB", "2") - Files.write("spark.test.fileNameLoadA true\n" + - "spark.test.fileNameLoadB 1\n", outFile, UTF_8) + Files.asCharSink(outFile, UTF_8).write("spark.test.fileNameLoadA true\n" + + "spark.test.fileNameLoadB 1\n") val properties = Utils.getPropertiesFromFile(outFile.getAbsolutePath) properties .filter { case (k, v) => k.startsWith("spark.")} @@ -765,7 +765,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties { val innerSourceDir = Utils.createTempDir(root = sourceDir.getPath) val 
sourceFile = File.createTempFile("someprefix", "somesuffix", innerSourceDir) val targetDir = new File(tempDir, "target-dir") - Files.write("some text", sourceFile, UTF_8) + Files.asCharSink(sourceFile, UTF_8).write("some text") val path = if (Utils.isWindows) { diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java index 0c11c40cfe7ed..1052f47ea496e 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.regex.Pattern; +import com.google.common.io.FileWriteMode; import scala.Tuple2; import com.google.common.io.Files; @@ -152,7 +153,8 @@ private static JavaStreamingContext createContext(String ip, System.out.println(output); System.out.println("Dropped " + droppedWordsCounter.value() + " word(s) totally"); System.out.println("Appending to " + outputFile.getAbsolutePath()); - Files.append(output + "\n", outputFile, Charset.defaultCharset()); + Files.asCharSink(outputFile, Charset.defaultCharset(), FileWriteMode.APPEND) + .write(output + "\n"); }); return ssc; diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 98539d6494231..1ec6ee4abd327 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -21,7 +21,7 @@ package org.apache.spark.examples.streaming import java.io.File import java.nio.charset.Charset -import com.google.common.io.Files +import com.google.common.io.{Files, FileWriteMode} import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.broadcast.Broadcast @@ -134,7 +134,8 @@ object RecoverableNetworkWordCount { println(output) println(s"Dropped ${droppedWordsCounter.value} word(s) totally") println(s"Appending to ${outputFile.getAbsolutePath}") - Files.append(output + "\n", outputFile, Charset.defaultCharset()) + Files.asCharSink(outputFile, Charset.defaultCharset(), FileWriteMode.APPEND) + .write(output + "\n") } ssc } diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java index c3038fa9e1f8f..5f0d22ea2a8aa 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -50,7 +50,7 @@ public void setUp() throws IOException { tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource"); File file = new File(tempDir, "part-00000"); String s = "1 1:1.0 3:2.0 5:3.0\n0\n0 2:4.0 4:5.0 6:6.0"; - Files.write(s, file, StandardCharsets.UTF_8); + Files.asCharSink(file, StandardCharsets.UTF_8).write(s); path = tempDir.toURI().toString(); } diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index f2bb145614725..6a0d7b1237ee4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -65,9 +65,9 @@ class LibSVMRelationSuite val succ = new File(dir, "_SUCCESS") val file0 = new File(dir, "part-00000") val file1 = new File(dir, "part-00001") - Files.write("", succ, StandardCharsets.UTF_8) - Files.write(lines0, file0, StandardCharsets.UTF_8) - Files.write(lines1, file1, StandardCharsets.UTF_8) + Files.asCharSink(succ, StandardCharsets.UTF_8).write("") + Files.asCharSink(file0, StandardCharsets.UTF_8).write(lines0) + Files.asCharSink(file1, StandardCharsets.UTF_8).write(lines1) path = dir.getPath } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index a90c9c80d4959..1a02e26b9260c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -93,7 +93,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { """.stripMargin val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") - Files.write(lines, file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(lines) val path = tempDir.toURI.toString val pointsWithNumFeatures = loadLibSVMFile(sc, path, 6).collect() @@ -126,7 +126,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { """.stripMargin val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") - Files.write(lines, file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(lines) val path = tempDir.toURI.toString intercept[SparkException] { @@ -143,7 +143,7 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext { """.stripMargin val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") - Files.write(lines, file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(lines) val path = tempDir.toURI.toString intercept[SparkException] { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala index 79f76e96474e3..2c28dc380046c 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -107,7 +107,7 @@ object SparkKubernetesClientFactory extends Logging { (token, configBuilder) => configBuilder.withOauthToken(token) }.withOption(oauthTokenFile) { (file, configBuilder) => - configBuilder.withOauthToken(Files.toString(file, Charsets.UTF_8)) + configBuilder.withOauthToken(Files.asCharSource(file, Charsets.UTF_8).read()) }.withOption(caCertFile) { (file, configBuilder) => configBuilder.withCaCertFile(file) }.withOption(clientKeyFile) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala index e266d0f904e46..d64378a65d66f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala +++ 
b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala @@ -116,7 +116,7 @@ private[spark] class HadoopConfDriverFeatureStep(conf: KubernetesConf) override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { if (confDir.isDefined) { val fileMap = confFiles.map { file => - (file.getName(), Files.toString(file, StandardCharsets.UTF_8)) + (file.getName(), Files.asCharSource(file, StandardCharsets.UTF_8).read()) }.toMap.asJava Seq(new ConfigMapBuilder() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala index 82bda88892d04..89aefe47e46d1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala @@ -229,7 +229,7 @@ private[spark] class KerberosConfDriverFeatureStep(kubernetesConf: KubernetesDri .endMetadata() .withImmutable(true) .addToData( - Map(file.getName() -> Files.toString(file, StandardCharsets.UTF_8)).asJava) + Map(file.getName() -> Files.asCharSource(file, StandardCharsets.UTF_8).read()).asJava) .build() } } ++ { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala index cdc0112294113..f94dad2d15dc1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala @@ -81,7 +81,7 @@ private[spark] class PodTemplateConfigMapStep(conf: KubernetesConf) val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf.sparkConf) val uri = downloadFile(podTemplateFile, Utils.createTempDir(), conf.sparkConf, hadoopConf) val file = new java.net.URI(uri).getPath - val podTemplateString = Files.toString(new File(file), StandardCharsets.UTF_8) + val podTemplateString = Files.asCharSource(new File(file), StandardCharsets.UTF_8).read() Seq(new ConfigMapBuilder() .withNewMetadata() .withName(configmapName) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index f1dd8b94f17ff..a72152a851c4f 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -128,7 +128,7 @@ class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite { private def writeCredentials(credentialsFileName: String, credentialsContents: String): File = { val credentialsFile = new File(credentialsTempDirectory, credentialsFileName) - Files.write(credentialsContents, credentialsFile, Charsets.UTF_8) + Files.asCharSink(credentialsFile, Charsets.UTF_8).write(credentialsContents) credentialsFile } } diff --git 
a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala index 8f21b95236a9c..4310ac0220e5e 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala @@ -48,7 +48,7 @@ class HadoopConfDriverFeatureStepSuite extends SparkFunSuite { val confFiles = Set("core-site.xml", "hdfs-site.xml") confFiles.foreach { f => - Files.write("some data", new File(confDir, f), UTF_8) + Files.asCharSink(new File(confDir, f), UTF_8).write("some data") } val sparkConf = new SparkConfWithEnv(Map(ENV_HADOOP_CONF_DIR -> confDir.getAbsolutePath())) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStepSuite.scala index a60227814eb13..04e20258d068f 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStepSuite.scala @@ -36,7 +36,7 @@ class HadoopConfExecutorFeatureStepSuite extends SparkFunSuite { val confFiles = Set("core-site.xml", "hdfs-site.xml") confFiles.foreach { f => - Files.write("some data", new File(confDir, f), UTF_8) + Files.asCharSink(new File(confDir, f), UTF_8).write("some data") } Seq( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala index 163d87643abd3..b172bdc06ddca 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala @@ -55,7 +55,7 @@ class KerberosConfDriverFeatureStepSuite extends SparkFunSuite { test("create krb5.conf config map if local config provided") { val krbConf = File.createTempFile("krb5", ".conf", tmpDir) - Files.write("some data", krbConf, UTF_8) + Files.asCharSink(krbConf, UTF_8).write("some data") val sparkConf = new SparkConf(false) .set(KUBERNETES_KERBEROS_KRB5_FILE, krbConf.getAbsolutePath()) @@ -70,7 +70,7 @@ class KerberosConfDriverFeatureStepSuite extends SparkFunSuite { test("create keytab secret if client keytab file used") { val keytab = File.createTempFile("keytab", ".bin", tmpDir) - Files.write("some data", keytab, UTF_8) + Files.asCharSink(keytab, UTF_8).write("some data") val sparkConf = new SparkConf(false) .set(KEYTAB, keytab.getAbsolutePath()) diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala index ae5f037c6b7d4..950079dcb5362 100644 --- 
a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/DecommissionSuite.scala @@ -40,7 +40,7 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => val logConfFilePath = s"${sparkHomeDir.toFile}/conf/log4j2.properties" try { - Files.write( + Files.asCharSink(new File(logConfFilePath), StandardCharsets.UTF_8).write( """rootLogger.level = info |rootLogger.appenderRef.stdout.ref = console |appender.console.type = Console @@ -51,9 +51,7 @@ private[spark] trait DecommissionSuite { k8sSuite: KubernetesSuite => | |logger.spark.name = org.apache.spark |logger.spark.level = debug - """.stripMargin, - new File(logConfFilePath), - StandardCharsets.UTF_8) + """.stripMargin) f() } finally { diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala index 0b0b30e5e04fd..cf129677ad9c2 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala @@ -129,7 +129,7 @@ class KubernetesSuite extends SparkFunSuite val tagFile = new File(path) require(tagFile.isFile, s"No file found for image tag at ${tagFile.getAbsolutePath}.") - Files.toString(tagFile, Charsets.UTF_8).trim + Files.asCharSource(tagFile, Charsets.UTF_8).read().trim } .orElse(sys.props.get(CONFIG_KEY_IMAGE_TAG)) .getOrElse { diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index f0177541accc1..e0dfac62847ea 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -86,7 +86,7 @@ abstract class BaseYarnClusterSuite extends SparkFunSuite with Matchers { logConfDir.mkdir() val logConfFile = new File(logConfDir, "log4j2.properties") - Files.write(LOG4J_CONF, logConfFile, StandardCharsets.UTF_8) + Files.asCharSink(logConfFile, StandardCharsets.UTF_8).write(LOG4J_CONF) // Disable the disk utilization check to avoid the test hanging when people's disks are // getting full. 
@@ -232,11 +232,11 @@ abstract class BaseYarnClusterSuite extends SparkFunSuite with Matchers { // an error message val output = new Object() { override def toString: String = outFile - .map(Files.toString(_, StandardCharsets.UTF_8)) + .map(Files.asCharSource(_, StandardCharsets.UTF_8).read()) .getOrElse("(stdout/stderr was not captured)") } assert(finalState === SparkAppHandle.State.FINISHED, output) - val resultString = Files.toString(result, StandardCharsets.UTF_8) + val resultString = Files.asCharSource(result, StandardCharsets.UTF_8).read() assert(resultString === expected, output) } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 806efd39800fb..92d9f2d62d1c1 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -141,7 +141,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { | | |""".stripMargin - Files.write(coreSite, new File(customConf, "core-site.xml"), StandardCharsets.UTF_8) + Files.asCharSink(new File(customConf, "core-site.xml"), StandardCharsets.UTF_8).write(coreSite) val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(false, @@ -295,23 +295,22 @@ class YarnClusterSuite extends BaseYarnClusterSuite { test("running Spark in yarn-cluster mode displays driver log links") { val log4jConf = new File(tempDir, "log4j.properties") val logOutFile = new File(tempDir, "logs") - Files.write( + Files.asCharSink(log4jConf, StandardCharsets.UTF_8).write( s"""rootLogger.level = debug |rootLogger.appenderRef.file.ref = file |appender.file.type = File |appender.file.name = file |appender.file.fileName = $logOutFile |appender.file.layout.type = PatternLayout - |""".stripMargin, - log4jConf, StandardCharsets.UTF_8) + |""".stripMargin) // Since this test is trying to extract log output from the SparkSubmit process itself, // standard options to the Spark process don't take effect. Leverage the java-opts file which // will get picked up for the SparkSubmit process. 
val confDir = new File(tempDir, "conf") confDir.mkdir() val javaOptsFile = new File(confDir, "java-opts") - Files.write(s"-Dlog4j.configurationFile=file://$log4jConf\n", javaOptsFile, - StandardCharsets.UTF_8) + Files.asCharSink(javaOptsFile, StandardCharsets.UTF_8) + .write(s"-Dlog4j.configurationFile=file://$log4jConf\n") val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(clientMode = false, @@ -320,7 +319,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { extraEnv = Map("SPARK_CONF_DIR" -> confDir.getAbsolutePath), extraConf = Map(CLIENT_INCLUDE_DRIVER_LOGS_LINK.key -> true.toString)) checkResult(finalState, result) - val logOutput = Files.toString(logOutFile, StandardCharsets.UTF_8) + val logOutput = Files.asCharSource(logOutFile, StandardCharsets.UTF_8).read() val logFilePattern = raw"""(?s).+\sDriver Logs \(\): https?://.+/(\?\S+)?\s.+""" logOutput should fullyMatch regex logFilePattern.replace("", "stdout") logOutput should fullyMatch regex logFilePattern.replace("", "stderr") @@ -374,7 +373,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { extraEnv: Map[String, String] = Map()): Unit = { assume(isPythonAvailable) val primaryPyFile = new File(tempDir, "test.py") - Files.write(TEST_PYFILE, primaryPyFile, StandardCharsets.UTF_8) + Files.asCharSink(primaryPyFile, StandardCharsets.UTF_8).write(TEST_PYFILE) // When running tests, let's not assume the user has built the assembly module, which also // creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the @@ -396,7 +395,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { subdir } val pyModule = new File(moduleDir, "mod1.py") - Files.write(TEST_PYMODULE, pyModule, StandardCharsets.UTF_8) + Files.asCharSink(pyModule, StandardCharsets.UTF_8).write(TEST_PYMODULE) val mod2Archive = TestUtils.createJarWithFiles(Map("mod2.py" -> TEST_PYMODULE), moduleDir) val pyFiles = Seq(pyModule.getAbsolutePath(), mod2Archive.getPath()).mkString(",") @@ -443,7 +442,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { def createEmptyIvySettingsFile: File = { val emptyIvySettings = File.createTempFile("ivy", ".xml") - Files.write("", emptyIvySettings, StandardCharsets.UTF_8) + Files.asCharSink(emptyIvySettings, StandardCharsets.UTF_8).write("") emptyIvySettings } @@ -555,7 +554,7 @@ private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matc } result = "success" } finally { - Files.write(result, status, StandardCharsets.UTF_8) + Files.asCharSink(status, StandardCharsets.UTF_8).write(result) sc.stop() } } @@ -658,7 +657,7 @@ private object YarnClusterDriver extends Logging with Matchers { assert(driverAttributes === expectationAttributes) } } finally { - Files.write(result, status, StandardCharsets.UTF_8) + Files.asCharSink(status, StandardCharsets.UTF_8).write(result) sc.stop() } } @@ -707,7 +706,7 @@ private object YarnClasspathTest extends Logging { case t: Throwable => error(s"loading test.resource to $resultPath", t) } finally { - Files.write(result, new File(resultPath), StandardCharsets.UTF_8) + Files.asCharSink(new File(resultPath), StandardCharsets.UTF_8).write(result) } } @@ -751,7 +750,7 @@ private object YarnAddJarTest extends Logging { result = "success" } } finally { - Files.write(result, new File(resultPath), StandardCharsets.UTF_8) + Files.asCharSink(new File(resultPath), StandardCharsets.UTF_8).write(result) sc.stop() } } @@ -796,7 +795,7 @@ private object ExecutorEnvTestApp { executorEnvs.get(k).contains(v) } - 
Files.write(result.toString, new File(status), StandardCharsets.UTF_8) + Files.asCharSink(new File(status), StandardCharsets.UTF_8).write(result.toString) sc.stop() } diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala index f745265eddfd9..f8d69c0ae568e 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnShuffleIntegrationSuite.scala @@ -181,7 +181,7 @@ private object YarnExternalShuffleDriver extends Logging with Matchers { if (execStateCopy != null) { FileUtils.deleteDirectory(execStateCopy) } - Files.write(result, status, StandardCharsets.UTF_8) + Files.asCharSink(status, StandardCharsets.UTF_8).write(result) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 275b35947182c..c90b1d3ca5978 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -1217,8 +1217,8 @@ class ArrowConvertersSuite extends SharedSparkSession { val tempFile1 = new File(tempDataPath, "testData2-ints-part1.json") val tempFile2 = new File(tempDataPath, "testData2-ints-part2.json") - Files.write(json1, tempFile1, StandardCharsets.UTF_8) - Files.write(json2, tempFile2, StandardCharsets.UTF_8) + Files.asCharSink(tempFile1, StandardCharsets.UTF_8).write(json1) + Files.asCharSink(tempFile2, StandardCharsets.UTF_8).write(json2) validateConversion(schema, arrowBatches(0), tempFile1) validateConversion(schema, arrowBatches(1), tempFile2) @@ -1501,7 +1501,7 @@ class ArrowConvertersSuite extends SharedSparkSession { // NOTE: coalesce to single partition because can only load 1 batch in validator val batchBytes = df.coalesce(1).toArrowBatchRdd.collect().head val tempFile = new File(tempDataPath, file) - Files.write(json, tempFile, StandardCharsets.UTF_8) + Files.asCharSink(tempFile, StandardCharsets.UTF_8).write(json) validateConversion(df.schema, batchBytes, tempFile, timeZoneId, errorOnDuplicatedFieldNames) } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 4575549005f33..f1f0befcb0d30 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -1222,7 +1222,7 @@ abstract class HiveThriftServer2TestBase extends SparkFunSuite with BeforeAndAft // overrides all other potential log4j configurations contained in other dependency jar files. 
val tempLog4jConf = Utils.createTempDir().getCanonicalPath - Files.write( + Files.asCharSink(new File(s"$tempLog4jConf/log4j2.properties"), StandardCharsets.UTF_8).write( """rootLogger.level = info |rootLogger.appenderRef.stdout.ref = console |appender.console.type = Console @@ -1230,9 +1230,7 @@ abstract class HiveThriftServer2TestBase extends SparkFunSuite with BeforeAndAft |appender.console.target = SYSTEM_ERR |appender.console.layout.type = PatternLayout |appender.console.layout.pattern = %d{HH:mm:ss.SSS} %p %c: %maxLen{%m}{512}%n%ex{8}%n - """.stripMargin, - new File(s"$tempLog4jConf/log4j2.properties"), - StandardCharsets.UTF_8) + """.stripMargin) tempLog4jConf } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala index 2b2cbec41d643..8d4a9886a2b25 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala @@ -75,7 +75,7 @@ class UISeleniumSuite // overrides all other potential log4j configurations contained in other dependency jar files. val tempLog4jConf = org.apache.spark.util.Utils.createTempDir().getCanonicalPath - Files.write( + Files.asCharSink(new File(s"$tempLog4jConf/log4j2.properties"), StandardCharsets.UTF_8).write( """rootLogger.level = info |rootLogger.appenderRef.file.ref = console |appender.console.type = Console @@ -83,9 +83,7 @@ class UISeleniumSuite |appender.console.target = SYSTEM_ERR |appender.console.layout.type = PatternLayout |appender.console.layout.pattern = %d{HH:mm:ss.SSS} %p %c: %maxLen{%m}{512}%n%ex{8}%n - """.stripMargin, - new File(s"$tempLog4jConf/log4j2.properties"), - StandardCharsets.UTF_8) + """.stripMargin) tempLog4jConf } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 14051034a588e..1c45b02375b30 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.util.{Locale, Set} -import com.google.common.io.Files +import com.google.common.io.{Files, FileWriteMode} import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.{SparkException, TestUtils} @@ -1947,10 +1947,10 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi val path = dir.toURI.toString.stripSuffix("/") val dirPath = dir.getAbsoluteFile for (i <- 1 to 3) { - Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + Files.asCharSink(new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8).write(s"$i") } for (i <- 5 to 7) { - Files.write(s"$i", new File(dirPath, s"part-s-0000$i"), StandardCharsets.UTF_8) + Files.asCharSink(new File(dirPath, s"part-s-0000$i"), StandardCharsets.UTF_8).write(s"$i") } withTable("load_t") { @@ -1971,7 +1971,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi val path = dir.toURI.toString.stripSuffix("/") val dirPath = dir.getAbsoluteFile for (i <- 1 to 3) { - Files.write(s"$i", new File(dirPath, s"part-r-0000 $i"), StandardCharsets.UTF_8) + Files.asCharSink(new File(dirPath, s"part-r-0000 $i"), 
StandardCharsets.UTF_8).write(s"$i") } withTable("load_t") { sql("CREATE TABLE load_t (a STRING) USING hive") @@ -1986,7 +1986,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi val path = dir.toURI.toString.stripSuffix("/") val dirPath = dir.getAbsoluteFile for (i <- 1 to 3) { - Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + Files.asCharSink(new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8).write(s"$i") } withTable("load_t") { sql("CREATE TABLE load_t (a STRING) USING hive") @@ -2010,7 +2010,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi val path = dir.toURI.toString.stripSuffix("/") val dirPath = dir.getAbsoluteFile for (i <- 1 to 3) { - Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + Files.asCharSink(new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8).write(s"$i") } withTable("load_t1") { sql("CREATE TABLE load_t1 (a STRING) USING hive") @@ -2025,7 +2025,7 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi val path = dir.toURI.toString.stripSuffix("/") val dirPath = dir.getAbsoluteFile for (i <- 1 to 3) { - Files.write(s"$i", new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8) + Files.asCharSink(new File(dirPath, s"part-r-0000$i"), StandardCharsets.UTF_8).write(s"$i") } withTable("load_t2") { sql("CREATE TABLE load_t2 (a STRING) USING hive") @@ -2039,7 +2039,8 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi withTempDir { dir => val path = dir.toURI.toString.stripSuffix("/") val dirPath = dir.getAbsoluteFile - Files.append("1", new File(dirPath, "part-r-000011"), StandardCharsets.UTF_8) + Files.asCharSink( + new File(dirPath, "part-r-000011"), StandardCharsets.UTF_8, FileWriteMode.APPEND).write("1") withTable("part_table") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { sql( diff --git a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java index f8d961fa8dd8e..73c2e89f3729a 100644 --- a/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/test/org/apache/spark/streaming/JavaAPISuite.java @@ -1641,7 +1641,7 @@ public void testRawSocketStream() { private static List> fileTestPrepare(File testDir) throws IOException { File existingFile = new File(testDir, "0"); - Files.write("0\n", existingFile, StandardCharsets.UTF_8); + Files.asCharSink(existingFile, StandardCharsets.UTF_8).write("0\n"); Assertions.assertTrue(existingFile.setLastModified(1000)); Assertions.assertEquals(1000, existingFile.lastModified()); return Arrays.asList(Arrays.asList("0")); diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 43b0835df7cbf..4aeb0e043a973 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -649,7 +649,7 @@ class CheckpointSuite extends TestSuiteBase with LocalStreamingContext with DStr */ def writeFile(i: Int, clock: Clock): Unit = { val file = new File(testDir, i.toString) - Files.write(s"$i\n", file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(s"$i\n") assert(file.setLastModified(clock.getTimeMillis())) // Check that the file's modification date 
is actually the value we wrote, since rounding or // truncation will break the test: diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 66fd1ac7bb22e..64335a96045bf 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -132,7 +132,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val batchDuration = Seconds(2) // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") - Files.write("0\n", existingFile, StandardCharsets.UTF_8) + Files.asCharSink(existingFile, StandardCharsets.UTF_8).write("0\n") assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) // Set up the streaming context and input streams @@ -191,7 +191,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") - Files.write("0\n", existingFile, StandardCharsets.UTF_8) + Files.asCharSink(existingFile, StandardCharsets.UTF_8).write("0\n") assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) val pathWithWildCard = testDir.toString + "/*/" @@ -215,7 +215,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { def createFileAndAdvanceTime(data: Int, dir: File): Unit = { val file = new File(testSubDir1, data.toString) - Files.write(s"$data\n", file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(s"$data\n") assert(file.setLastModified(clock.getTimeMillis())) assert(file.lastModified === clock.getTimeMillis()) logInfo(s"Created file $file") @@ -478,7 +478,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val batchDuration = Seconds(2) // Create a file that exists before the StreamingContext is created: val existingFile = new File(testDir, "0") - Files.write("0\n", existingFile, StandardCharsets.UTF_8) + Files.asCharSink(existingFile, StandardCharsets.UTF_8).write("0\n") assert(existingFile.setLastModified(10000) && existingFile.lastModified === 10000) // Set up the streaming context and input streams @@ -502,7 +502,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val input = Seq(1, 2, 3, 4, 5) input.foreach { i => val file = new File(testDir, i.toString) - Files.write(s"$i\n", file, StandardCharsets.UTF_8) + Files.asCharSink(file, StandardCharsets.UTF_8).write(s"$i\n") assert(file.setLastModified(clock.getTimeMillis())) assert(file.lastModified === clock.getTimeMillis()) logInfo("Created file " + file) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 771e65ed40b51..2dc43a231d9b8 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -375,7 +375,7 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) val localFile = new File(localTestDir, (i + 1).toString) val hadoopFile = new Path(testDir, (i + 1).toString) val tempHadoopFile = new Path(testDir, ".tmp_" + (i + 1).toString) - Files.write(input(i) + "\n", localFile, StandardCharsets.UTF_8) + Files.asCharSink(localFile, 
StandardCharsets.UTF_8).write(input(i) + "\n") var tries = 0 var done = false while (!done && tries < maxTries) { From 97e9bb3ac4b66711ced640ea466eeea5da6d1fd2 Mon Sep 17 00:00:00 2001 From: Gideon P Date: Tue, 1 Oct 2024 15:09:35 +0200 Subject: [PATCH 131/250] [SPARK-48700][SQL] Mode expression for complex types (all collations) ### What changes were proposed in this pull request? Add support for complex types with subfields that are collated strings, for the mode operator. ### Why are the changes needed? Full support for collations as per SPARK-48700 ### Does this PR introduce _any_ user-facing change? Yes. ### How was this patch tested? Unit tests only, so far. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47154 from GideonPotok/collationmodecomplex. Lead-authored-by: Gideon P Co-authored-by: Gideon Potok <31429832+GideonPotok@users.noreply.github.com> Signed-off-by: Max Gekk --- .../resources/error/error-conditions.json | 10 + .../catalyst/expressions/aggregate/Mode.scala | 85 ++++-- .../sql/CollationSQLExpressionsSuite.scala | 257 ++++++++++++------ 3 files changed, 250 insertions(+), 102 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index fcaf2b1d9d301..3786643125a9f 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -631,6 +631,11 @@ "Cannot process input data types for the expression: ." ], "subClass" : { + "BAD_INPUTS" : { + "message" : [ + "The input data types to must be valid, but found the input types ." + ] + }, "MISMATCHED_TYPES" : { "message" : [ "All input types must be the same except nullable, containsNull, valueContainsNull flags, but found the input types ." @@ -1011,6 +1016,11 @@ "The input of can't be type data." ] }, + "UNSUPPORTED_MODE_DATA_TYPE" : { + "message" : [ + "The does not support the data type, because there is a \"MAP\" type with keys and/or values that have collated sub-fields." + ] + }, "UNSUPPORTED_UDF_INPUT_TYPE" : { "message" : [ "UDFs do not support '' as an input data type." 
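The core of the Mode change below is the grouping step that merges buffer counts for values that compare equal under a collation. A simplified standalone sketch of that step, where `toLowerCase` stands in for `CollationFactory.getCollationKey` purely for illustration:

```
object CollationAwareGroupingSketch {
  def main(args: Array[String]): Unit = {
    // Mode's buffer: value -> count. Under a case-insensitive collation, "b" and "B"
    // should collapse into a single group with their counts summed.
    val buffer: Seq[(String, Long)] = Seq("a" -> 5L, "b" -> 4L, "B" -> 3L)
    val merged = buffer
      .groupMapReduce { case (k, _) => k.toLowerCase }(x => x)((x, y) => (x._1, x._2 + y._2))
      .values
    // One representative per group survives with the summed count: (b, 7) and (a, 5).
    println(merged.toSeq.sortBy(-_._2))
  }
}
```

The new `collationAwareTransform` applies the same idea recursively, deriving the grouping key from the collation keys of string fields nested inside structs and arrays, while map types remain unsupported for now.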
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala index e254a670991a1..8998348f0571b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala @@ -17,14 +17,17 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup} import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder} +import org.apache.spark.sql.catalyst.expressions.Cast.toSQLExpr import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.types.PhysicalDataType -import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData, UnsafeRowUtils} +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, UnsafeRowUtils} +import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType} import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, StringType} +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, MapType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.OpenHashMap @@ -50,17 +53,20 @@ case class Mode( override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) override def checkInputDataTypes(): TypeCheckResult = { - if (UnsafeRowUtils.isBinaryStable(child.dataType) || child.dataType.isInstanceOf[StringType]) { + // TODO: SPARK-49358: Mode expression for map type with collated fields + if (UnsafeRowUtils.isBinaryStable(child.dataType) || + !child.dataType.existsRecursively(f => f.isInstanceOf[MapType] && + !UnsafeRowUtils.isBinaryStable(f))) { /* * The Mode class uses collation awareness logic to handle string data. - * Complex types with collated fields are not yet supported. + * All complex types except MapType with collated fields are supported. 
*/ - // TODO: SPARK-48700: Mode expression for complex types (all collations) super.checkInputDataTypes() } else { - TypeCheckResult.TypeCheckFailure("The input to the function 'mode' was" + - " a type of binary-unstable type that is " + - s"not currently supported by ${prettyName}.") + TypeCheckResult.DataTypeMismatch("UNSUPPORTED_MODE_DATA_TYPE", + messageParameters = + Map("child" -> toSQLType(child.dataType), + "mode" -> toSQLId(prettyName))) } } @@ -86,6 +92,54 @@ case class Mode( buffer } + private def getCollationAwareBuffer( + childDataType: DataType, + buffer: OpenHashMap[AnyRef, Long]): Iterable[(AnyRef, Long)] = { + def groupAndReduceBuffer(groupingFunction: AnyRef => _): Iterable[(AnyRef, Long)] = { + buffer.groupMapReduce(t => + groupingFunction(t._1))(x => x)((x, y) => (x._1, x._2 + y._2)).values + } + def determineBufferingFunction( + childDataType: DataType): Option[AnyRef => _] = { + childDataType match { + case _ if UnsafeRowUtils.isBinaryStable(child.dataType) => None + case _ => Some(collationAwareTransform(_, childDataType)) + } + } + determineBufferingFunction(childDataType).map(groupAndReduceBuffer).getOrElse(buffer) + } + + protected[sql] def collationAwareTransform(data: AnyRef, dataType: DataType): AnyRef = { + dataType match { + case _ if UnsafeRowUtils.isBinaryStable(dataType) => data + case st: StructType => + processStructTypeWithBuffer(data.asInstanceOf[InternalRow].toSeq(st).zip(st.fields)) + case at: ArrayType => processArrayTypeWithBuffer(at, data.asInstanceOf[ArrayData]) + case st: StringType => + CollationFactory.getCollationKey(data.asInstanceOf[UTF8String], st.collationId) + case _ => + throw new SparkIllegalArgumentException( + errorClass = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS", + messageParameters = Map( + "expression" -> toSQLExpr(this), + "functionName" -> toSQLType(prettyName), + "dataType" -> toSQLType(child.dataType)) + ) + } + } + + private def processStructTypeWithBuffer( + tuples: Seq[(Any, StructField)]): Seq[Any] = { + tuples.map(t => collationAwareTransform(t._1.asInstanceOf[AnyRef], t._2.dataType)) + } + + private def processArrayTypeWithBuffer( + a: ArrayType, + data: ArrayData): Seq[Any] = { + (0 until data.numElements()).map(i => + collationAwareTransform(data.get(i, a.elementType), a.elementType)) + } + override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = { if (buffer.isEmpty) { return null @@ -102,17 +156,12 @@ case class Mode( * to a single value (the sum of the counts), and finally reduces the groups to a single map. * * The new map is then used in the rest of the Mode evaluation logic. + * + * It is expected to work for all simple and complex types with + * collated fields, except for MapType (temporarily). 
*/ - val collationAwareBuffer = child.dataType match { - case c: StringType if - !CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality => - val collationId = c.collationId - val modeMap = buffer.toSeq.groupMapReduce { - case (k, _) => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId) - }(x => x)((x, y) => (x._1, x._2 + y._2)).values - modeMap - case _ => buffer - } + val collationAwareBuffer = getCollationAwareBuffer(child.dataType, buffer) + reverseOpt.map { reverse => val defaultKeyOrdering = if (reverse) { PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 941d5cd31db40..9930709cd8bf3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat +import java.util.Locale import scala.collection.immutable.Seq -import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException} -import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException, SparkThrowable} +import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Mode import org.apache.spark.sql.internal.{SqlApiConf, SQLConf} @@ -1752,7 +1753,7 @@ class CollationSQLExpressionsSuite UTF8StringModeTestCase("unicode_ci", bufferValuesUTF8String, "b"), UTF8StringModeTestCase("unicode", bufferValuesUTF8String, "a")) - testCasesUTF8String.foreach(t => { + testCasesUTF8String.foreach ( t => { val buffer = new OpenHashMap[AnyRef, Long](5) val myMode = Mode(child = Literal.create("some_column_name", StringType(t.collationId))) t.bufferValues.foreach { case (k, v) => buffer.update(k, v) } @@ -1760,6 +1761,40 @@ class CollationSQLExpressionsSuite }) } + test("Support Mode.eval(buffer) with complex types") { + case class UTF8StringModeTestCase[R]( + collationId: String, + bufferValues: Map[InternalRow, Long], + result: R) + + val bufferValuesUTF8String: Map[Any, Long] = Map( + UTF8String.fromString("a") -> 5L, + UTF8String.fromString("b") -> 4L, + UTF8String.fromString("B") -> 3L, + UTF8String.fromString("d") -> 2L, + UTF8String.fromString("e") -> 1L) + + val bufferValuesComplex = bufferValuesUTF8String.map{ + case (k, v) => (InternalRow.fromSeq(Seq(k, k, k)), v) + } + val testCasesUTF8String = Seq( + UTF8StringModeTestCase("utf8_binary", bufferValuesComplex, "[a,a,a]"), + UTF8StringModeTestCase("UTF8_LCASE", bufferValuesComplex, "[b,b,b]"), + UTF8StringModeTestCase("unicode_ci", bufferValuesComplex, "[b,b,b]"), + UTF8StringModeTestCase("unicode", bufferValuesComplex, "[a,a,a]")) + + testCasesUTF8String.foreach { t => + val buffer = new OpenHashMap[AnyRef, Long](5) + val myMode = Mode(child = Literal.create(null, StructType(Seq( + StructField("f1", StringType(t.collationId), true), + StructField("f2", StringType(t.collationId), true), + StructField("f3", StringType(t.collationId), true) + )))) + t.bufferValues.foreach { case (k, v) => buffer.update(k, v) } + 
assert(myMode.eval(buffer).toString.toLowerCase() == t.result.toLowerCase()) + } + } + test("Support mode for string expression with collated strings in struct") { case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) val testCases = Seq( @@ -1780,33 +1815,7 @@ class CollationSQLExpressionsSuite t.collationId + ", f2: INT>) USING parquet") sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) val query = s"SELECT lower(mode(i).f1) FROM ${tableName}" - if(t.collationId == "UTF8_LCASE" || - t.collationId == "unicode_ci" || - t.collationId == "unicode") { - // Cannot resolve "mode(i)" due to data type mismatch: - // Input to function mode was a complex type with strings collated on non-binary - // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13; - val params = Seq(("sqlExpr", "\"mode(i)\""), - ("msg", "The input to the function 'mode'" + - " was a type of binary-unstable type that is not currently supported by mode."), - ("hint", "")).toMap - checkError( - exception = intercept[AnalysisException] { - sql(query) - }, - condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", - parameters = params, - queryContext = Array( - ExpectedContext(objectType = "", - objectName = "", - startIndex = 13, - stopIndex = 19, - fragment = "mode(i)") - ) - ) - } else { - checkAnswer(sql(query), Row(t.result)) - } + checkAnswer(sql(query), Row(t.result)) } }) } @@ -1819,47 +1828,21 @@ class CollationSQLExpressionsSuite ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") ) - testCases.foreach(t => { + testCases.foreach { t => val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => (0L to numRepeats).map(_ => s"named_struct('f1', " + s"named_struct('f2', collate('$elt', '${t.collationId}')), 'f3', 1)").mkString(",") }.mkString(",") - val tableName = s"t_${t.collationId}_mode_nested_struct" + val tableName = s"t_${t.collationId}_mode_nested_struct1" withTable(tableName) { sql(s"CREATE TABLE ${tableName}(i STRUCT, f3: INT>) USING parquet") sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) val query = s"SELECT lower(mode(i).f1.f2) FROM ${tableName}" - if(t.collationId == "UTF8_LCASE" || - t.collationId == "unicode_ci" || - t.collationId == "unicode") { - // Cannot resolve "mode(i)" due to data type mismatch: - // Input to function mode was a complex type with strings collated on non-binary - // collations, which is not yet supported.. 
SQLSTATE: 42K09; line 1 pos 13; - val params = Seq(("sqlExpr", "\"mode(i)\""), - ("msg", "The input to the function 'mode' " + - "was a type of binary-unstable type that is not currently supported by mode."), - ("hint", "")).toMap - checkError( - exception = intercept[AnalysisException] { - sql(query) - }, - condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", - parameters = params, - queryContext = Array( - ExpectedContext(objectType = "", - objectName = "", - startIndex = 13, - stopIndex = 19, - fragment = "mode(i)") - ) - ) - } else { - checkAnswer(sql(query), Row(t.result)) - } + checkAnswer(sql(query), Row(t.result)) } - }) + } } test("Support mode for string expression with collated strings in array complex type") { @@ -1870,44 +1853,150 @@ class CollationSQLExpressionsSuite ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") ) - testCases.foreach(t => { + testCases.foreach { t => + val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => + (0L to numRepeats).map(_ => s"array(named_struct('f2', " + + s"collate('$elt', '${t.collationId}'), 'f3', 1))").mkString(",") + }.mkString(",") + + val tableName = s"t_${t.collationId}_mode_nested_struct2" + withTable(tableName) { + sql(s"CREATE TABLE ${tableName}(" + + s"i ARRAY< STRUCT>)" + + s" USING parquet") + sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) + val query = s"SELECT lower(element_at(mode(i).f2, 1)) FROM ${tableName}" + checkAnswer(sql(query), Row(t.result)) + } + } + } + + test("Support mode for string expression with collated strings in 3D array type") { + case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) + val testCases = Seq( + ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ) + testCases.foreach { t => + val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => + (0L to numRepeats).map(_ => + s"array(array(array(collate('$elt', '${t.collationId}'))))").mkString(",") + }.mkString(",") + + val tableName = s"t_${t.collationId}_mode_nested_3d_array" + withTable(tableName) { + sql(s"CREATE TABLE ${tableName}(i ARRAY>>) USING parquet") + sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) + val query = s"SELECT lower(" + + s"element_at(element_at(element_at(mode(i),1),1),1)) FROM ${tableName}" + checkAnswer(sql(query), Row(t.result)) + } + } + } + + test("Support mode for string expression with collated complex type - Highly nested") { + case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) + val testCases = Seq( + ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ) + testCases.foreach { t => val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => (0L to numRepeats).map(_ => s"array(named_struct('s1', named_struct('a2', " + s"array(collate('$elt', '${t.collationId}'))), 'f3', 1))").mkString(",") }.mkString(",") - val tableName = s"t_${t.collationId}_mode_nested_struct" + val tableName = s"t_${t.collationId}_mode_highly_nested_struct" withTable(tableName) 
{ sql(s"CREATE TABLE ${tableName}(" + s"i ARRAY>, f3: INT>>)" + s" USING parquet") sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) val query = s"SELECT lower(element_at(element_at(mode(i), 1).s1.a2, 1)) FROM ${tableName}" - if(t.collationId == "UTF8_LCASE" || - t.collationId == "unicode_ci" || t.collationId == "unicode") { - val params = Seq(("sqlExpr", "\"mode(i)\""), - ("msg", "The input to the function 'mode' was a type" + - " of binary-unstable type that is not currently supported by mode."), - ("hint", "")).toMap - checkError( - exception = intercept[AnalysisException] { - sql(query) - }, - condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", - parameters = params, - queryContext = Array( - ExpectedContext(objectType = "", - objectName = "", - startIndex = 35, - stopIndex = 41, - fragment = "mode(i)") - ) - ) - } else { + checkAnswer(sql(query), Row(t.result)) + } + } + } + + test("Support mode expression with collated in recursively nested struct with map with keys") { + case class ModeTestCase(collationId: String, bufferValues: Map[String, Long], result: String) + Seq( + ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"), + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"), + ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}") + ).foreach { t1 => + def checkThisError(t: ModeTestCase, query: String): Any = { + val c = s"STRUCT>" + val c1 = s"\"${c}\"" + checkError( + exception = intercept[SparkThrowable] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNSUPPORTED_MODE_DATA_TYPE", + parameters = Map( + ("sqlExpr", "\"mode(i)\""), + ("child", c1), + ("mode", "`mode`")), + queryContext = Seq(ExpectedContext("mode(i)", 18, 24)).toArray + ) + } + + def getValuesToAdd(t: ModeTestCase): String = { + val valuesToAdd = t.bufferValues.map { + case (elt, numRepeats) => + (0L to numRepeats).map(i => + s"named_struct('m1', map(collate('$elt', '${t.collationId}'), 1))" + ).mkString(",") + }.mkString(",") + valuesToAdd + } + val tableName = s"t_${t1.collationId}_mode_nested_map_struct1" + withTable(tableName) { + sql(s"CREATE TABLE ${tableName}(" + + s"i STRUCT>) USING parquet") + sql(s"INSERT INTO ${tableName} VALUES ${getValuesToAdd(t1)}") + val query = "SELECT lower(cast(mode(i).m1 as string))" + + s" FROM ${tableName}" + if (t1.collationId == "utf8_binary") { + checkAnswer(sql(query), Row(t1.result)) + } else { + checkThisError(t1, query) } } - }) + } + } + + test("UDT with collation - Mode (throw exception)") { + case class ModeTestCase(collationId: String, bufferValues: Map[String, Long], result: String) + Seq( + ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ).foreach { t1 => + checkError( + exception = intercept[SparkIllegalArgumentException] { + Mode( + child = Literal.create(null, + MapType(StringType(t1.collationId), IntegerType)) + ).collationAwareTransform( + data = Map.empty[String, Any], + dataType = MapType(StringType(t1.collationId), IntegerType) + ) + }, + condition = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS", + parameters = Map( + "expression" -> "\"mode(NULL)\"", + "functionName" -> "\"MODE\"", + "dataType" -> s"\"MAP\"") + ) + } } test("SPARK-48430: Map value extraction with collations") { From 
3093ad68d2a3c6bab9c1605381d27e700766be22 Mon Sep 17 00:00:00 2001 From: exmy Date: Tue, 1 Oct 2024 15:22:29 +0200 Subject: [PATCH 132/250] [MINOR] Fix a typo in First aggregate expression ### What changes were proposed in this pull request? Find a typo for the comment on code `mergeExpressions` of `First` aggregate expression, fix from `first.right` to `first.left`. ### Why are the changes needed? Fix typo, it's confused. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? N.A ### Was this patch authored or co-authored using generative AI tooling? No Closes #48298 from exmy/fix-comment. Authored-by: exmy Signed-off-by: Max Gekk --- .../apache/spark/sql/catalyst/expressions/aggregate/First.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 4fe00099ddc91..9a39a6fe98796 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -104,7 +104,7 @@ case class First(child: Expression, ignoreNulls: Boolean) override lazy val mergeExpressions: Seq[Expression] = { // For first, we can just check if valueSet.left is set to true. If it is set - // to true, we use first.right. If not, we use first.right (even if valueSet.right is + // to true, we use first.left. If not, we use first.right (even if valueSet.right is // false, we are safe to do so because first.right will be null in this case). Seq( /* first = */ If(valueSet.left, first.left, first.right), From 3551a9ee6d388f68f326cce1c0c9dad51e33ef58 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 1 Oct 2024 21:14:24 -0700 Subject: [PATCH 133/250] [SPARK-49845][CORE] Make `appArgs` and `environmentVariables` optional in REST API ### What changes were proposed in this pull request? This PR aims to make `appArgs` and `environmentVariables` fields optional in REST API. ### Why are the changes needed? `appArgs` and `environmentVariables` became mandatory due to the Apache Mesos limitation at Spark 2.2.2. Technically, this is a revert of SPARK-22574. - https://github.com/apache/spark/pull/19966 Since Apache Spark 4.0 removed Mesos support, we don't need these requirements. - https://github.com/apache/spark/pull/43135 ### Does this PR introduce _any_ user-facing change? No because this is a relaxation of enforcement. ### How was this patch tested? Pass the CIs with the revised test case. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48316 from dongjoon-hyun/SPARK-49845. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../org/apache/spark/deploy/rest/StandaloneRestServer.scala | 5 +++-- .../apache/spark/deploy/rest/SubmitRestProtocolRequest.scala | 2 -- .../apache/spark/deploy/rest/SubmitRestProtocolSuite.scala | 2 -- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 31673f666173a..c92e79381ca9b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -218,11 +218,12 @@ private[rest] class StandaloneSubmitRequestServlet( val (_, masterPort) = Utils.extractHostPortFromSparkUrl(masterUrl) val updatedMasters = masters.map( _.replace(s":$masterRestPort", s":$masterPort")).getOrElse(masterUrl) - val appArgs = request.appArgs + val appArgs = Option(request.appArgs).getOrElse(Array[String]()) // Filter SPARK_LOCAL_(IP|HOSTNAME) environment variables from being set on the remote system. // In addition, the placeholders are replaced into the values of environment variables. val environmentVariables = - request.environmentVariables.filterNot(x => x._1.matches("SPARK_LOCAL_(IP|HOSTNAME)")) + Option(request.environmentVariables).getOrElse(Map.empty[String, String]) + .filterNot(x => x._1.matches("SPARK_LOCAL_(IP|HOSTNAME)")) .map(x => (x._1, replacePlaceHolder(x._2))) // Construct driver description diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala index 7f462148c71a1..63882259adcb5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala @@ -47,8 +47,6 @@ private[rest] class CreateSubmissionRequest extends SubmitRestProtocolRequest { super.doValidate() assert(sparkProperties != null, "No Spark properties set!") assertFieldIsSet(appResource, "appResource") - assertFieldIsSet(appArgs, "appArgs") - assertFieldIsSet(environmentVariables, "environmentVariables") assertPropertyIsSet("spark.app.name") assertPropertyIsBoolean(config.DRIVER_SUPERVISE.key) assertPropertyIsNumeric(config.DRIVER_CORES.key) diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 9eb5172583120..f2807f258f2d1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -87,8 +87,6 @@ class SubmitRestProtocolSuite extends SparkFunSuite { message.clientSparkVersion = "1.2.3" message.appResource = "honey-walnut-cherry.jar" message.mainClass = "org.apache.spark.examples.SparkPie" - message.appArgs = Array("two slices") - message.environmentVariables = Map("PATH" -> "/dev/null") val conf = new SparkConf(false) conf.set("spark.app.name", "SparkPie") message.sparkProperties = conf.getAll.toMap From 077a31989c99cb6302a325c953d2ee92ba573a8b Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 2 Oct 2024 14:46:46 +0200 Subject: [PATCH 134/250] [SPARK-49843][SQL] Fix change comment on char/varchar columns ### What changes were proposed in this pull request? 
Fix the issue in `AlterTableChangeColumnCommand` where changing the comment of a char/varchar column also tries to change the column type to string. ### Why are the changes needed? Because the newColumn will always be a `StringType` even when the metadata says that it was originally char/varchar. ### Does this PR introduce _any_ user-facing change? Yes, the query will no longer fail when using this code path. ### How was this patch tested? New query in golden files. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48315 from stefankandic/fixAlterVarcharCol. Authored-by: Stefan Kandic Signed-off-by: Max Gekk --- .../analysis/ResolveSessionCatalog.scala | 10 ++++-- .../analyzer-results/charvarchar.sql.out | 12 +++++++ .../sql-tests/inputs/charvarchar.sql | 2 ++ .../sql-tests/results/charvarchar.sql.out | 32 ++++++++++++++----- 4 files changed, 45 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index a9ad7523c8fbc..884c870e8eed3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, toPrettySQL, ResolveDefaultColumns => DefaultCols} +import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, toPrettySQL, CharVarcharUtils, ResolveDefaultColumns => DefaultCols} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{CatalogExtension, CatalogManager, CatalogPlugin, CatalogV2Util, LookupCatalog, SupportsNamespaces, V1Table} import org.apache.spark.sql.connector.expressions.Transform @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1} import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.connector.V1Function -import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} +import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField, StructType} import org.apache.spark.util.ArrayImplicits._ /** @@ -87,7 +87,11 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) val colName = a.column.name(0) val dataType = a.dataType.getOrElse { table.schema.findNestedField(Seq(colName), resolver = conf.resolver) - .map(_._2.dataType) + .map { + case (_, StructField(_, st: StringType, _, metadata)) => + CharVarcharUtils.getRawType(metadata).getOrElse(st) + case (_, field) => field.dataType + } .getOrElse { throw QueryCompilationErrors.unresolvedColumnError( toSQLId(a.column.name), table.schema.fieldNames) diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out index 5c1417f7c0aae..524797015a2f6 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out @@ -263,6 +263,18 
@@ desc formatted char_part DescribeTableCommand `spark_catalog`.`default`.`char_part`, true, [col_name#x, data_type#x, comment#x] +-- !query +alter table char_part change column c1 comment 'char comment' +-- !query analysis +AlterTableChangeColumnCommand `spark_catalog`.`default`.`char_part`, c1, StructField(c1,CharType(5),true) + + +-- !query +alter table char_part change column v1 comment 'varchar comment' +-- !query analysis +AlterTableChangeColumnCommand `spark_catalog`.`default`.`char_part`, v1, StructField(v1,VarcharType(6),true) + + -- !query alter table char_part add partition (v2='ke', c2='nt') location 'loc1' -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/charvarchar.sql b/sql/core/src/test/resources/sql-tests/inputs/charvarchar.sql index 8117dec53f4ab..be038e1083cd8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/charvarchar.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/charvarchar.sql @@ -49,6 +49,8 @@ desc formatted char_tbl1; create table char_part(c1 char(5), c2 char(2), v1 varchar(6), v2 varchar(2)) using parquet partitioned by (v2, c2); desc formatted char_part; +alter table char_part change column c1 comment 'char comment'; +alter table char_part change column v1 comment 'varchar comment'; alter table char_part add partition (v2='ke', c2='nt') location 'loc1'; desc formatted char_part; diff --git a/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out b/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out index 568c9f3b29e87..8aafa25c5caaf 100644 --- a/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out @@ -556,6 +556,22 @@ Location [not included in comparison]/{warehouse_dir}/char_part Partition Provider Catalog +-- !query +alter table char_part change column c1 comment 'char comment' +-- !query schema +struct<> +-- !query output + + + +-- !query +alter table char_part change column v1 comment 'varchar comment' +-- !query schema +struct<> +-- !query output + + + -- !query alter table char_part add partition (v2='ke', c2='nt') location 'loc1' -- !query schema @@ -569,8 +585,8 @@ desc formatted char_part -- !query schema struct -- !query output -c1 char(5) -v1 varchar(6) +c1 char(5) char comment +v1 varchar(6) varchar comment v2 varchar(2) c2 char(2) # Partition Information @@ -612,8 +628,8 @@ desc formatted char_part -- !query schema struct -- !query output -c1 char(5) -v1 varchar(6) +c1 char(5) char comment +v1 varchar(6) varchar comment v2 varchar(2) c2 char(2) # Partition Information @@ -647,8 +663,8 @@ desc formatted char_part -- !query schema struct -- !query output -c1 char(5) -v1 varchar(6) +c1 char(5) char comment +v1 varchar(6) varchar comment v2 varchar(2) c2 char(2) # Partition Information @@ -682,8 +698,8 @@ desc formatted char_part -- !query schema struct -- !query output -c1 char(5) -v1 varchar(6) +c1 char(5) char comment +v1 varchar(6) varchar comment v2 varchar(2) c2 char(2) # Partition Information From 18dbaa5a070c74007137780e8529321b75b10b48 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Wed, 2 Oct 2024 13:19:03 -0700 Subject: [PATCH 135/250] [SPARK-49560][SQL] Add SQL pipe syntax for the TABLESAMPLE operator ### What changes were proposed in this pull request? WIP This PR adds SQL pipe syntax support for the TABLESAMPLE operator. 
For example: ``` CREATE TABLE t(x INT, y STRING) USING CSV; INSERT INTO t VALUES (0, 'abc'), (1, 'def'); TABLE t |> TABLESAMPLE (100 PERCENT) REPEATABLE (0) |> TABLESAMPLE (5 ROWS) REPEATABLE (0) |> TABLESAMPLE (BUCKET 1 OUT OF 1) REPEATABLE (0); 0 abc 1 def ``` ### Why are the changes needed? The SQL pipe operator syntax will let users compose queries in a more flexible fashion. ### Does this PR introduce _any_ user-facing change? Yes, see above. ### How was this patch tested? This PR adds a few unit test cases, but mostly relies on golden file test coverage. I did this to make sure the answers are correct as this feature is implemented and also so we can look at the analyzer output plans to ensure they look right as well. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48168 from dtenedor/pipe-tablesample. Authored-by: Daniel Tenedorio Signed-off-by: Gengliang Wang --- .../sql/catalyst/parser/SqlBaseParser.g4 | 1 + .../sql/catalyst/parser/AstBuilder.scala | 4 +- .../analyzer-results/pipe-operators.sql.out | 184 ++++++++++++++++ .../sql-tests/inputs/pipe-operators.sql | 49 +++++ .../sql-tests/results/pipe-operators.sql.out | 198 ++++++++++++++++++ .../sql/execution/SparkSqlParserSuite.scala | 9 + 6 files changed, 444 insertions(+), 1 deletion(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 33ac3249eb663..e8e2e980135a2 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1504,6 +1504,7 @@ operatorPipeRightSide // messages in the event that both are present (this is not allowed). | pivotClause unpivotClause? | unpivotClause pivotClause? + | sample ; // When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e2350474a8708..9ce96ae652fed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -5903,7 +5903,9 @@ class AstBuilder extends DataTypeAstBuilder throw QueryParsingErrors.unpivotWithPivotInFromClauseNotAllowedError(ctx) } withUnpivot(c, left) - }.get))) + }.getOrElse(Option(ctx.sample).map { c => + withSample(c, left) + }.get)))) } /** diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out index 8cd062aeb01a3..aee8da46aafbe 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out @@ -921,6 +921,190 @@ org.apache.spark.sql.catalyst.parser.ParseException } +-- !query +table t +|> tablesample (100 percent) repeatable (0) +-- !query analysis +Sample 0.0, 1.0, false, 0 ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> tablesample (2 rows) repeatable (0) +-- !query analysis +GlobalLimit 2 ++- LocalLimit 2 + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> tablesample (bucket 1 out of 1) repeatable (0) +-- !query analysis +Sample 0.0, 1.0, false, 0 ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> tablesample (100 percent) repeatable (0) +|> tablesample (5 rows) repeatable (0) +|> tablesample (bucket 1 out of 1) repeatable (0) +-- !query analysis +Sample 0.0, 1.0, false, 0 ++- GlobalLimit 5 + +- LocalLimit 5 + +- Sample 0.0, 1.0, false, 0 + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> tablesample () +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0014", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 25, + "fragment" : "tablesample ()" + } ] +} + + +-- !query +table t +|> tablesample (-100 percent) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0064", + "messageParameters" : { + "msg" : "Sampling fraction (-1.0) must be on interval [0, 1]" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 37, + "fragment" : "tablesample (-100 percent)" + } ] +} + + +-- !query +table t +|> tablesample (-5 rows) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_LIMIT_LIKE_EXPRESSION.IS_NEGATIVE", + "sqlState" : "42K0E", + "messageParameters" : { + "expr" : "\"-5\"", + "name" : "limit", + "v" : "-5" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 25, + "stopIndex" : 26, + "fragment" : "-5" + } ] +} + + +-- !query +table t +|> tablesample (x rows) +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_LIMIT_LIKE_EXPRESSION.IS_UNFOLDABLE", + "sqlState" : "42K0E", + "messageParameters" : { + "expr" : "\"x\"", + 
"name" : "limit" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 25, + "stopIndex" : 25, + "fragment" : "x" + } ] +} + + +-- !query +table t +|> tablesample (bucket 2 out of 1) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0064", + "messageParameters" : { + "msg" : "Sampling fraction (2.0) must be on interval [0, 1]" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 42, + "fragment" : "tablesample (bucket 2 out of 1)" + } ] +} + + +-- !query +table t +|> tablesample (200b) repeatable (0) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0015", + "messageParameters" : { + "msg" : "byteLengthLiteral" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 44, + "fragment" : "tablesample (200b) repeatable (0)" + } ] +} + + +-- !query +table t +|> tablesample (200) repeatable (0) +-- !query analysis +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0016", + "messageParameters" : { + "bytesStr" : "200" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 43, + "fragment" : "tablesample (200) repeatable (0)" + } ] +} + + -- !query drop table t -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql index 3aa01d472e83f..31748fe1125ab 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql @@ -326,6 +326,55 @@ table courseSales for `year` in (2012, 2013) ); +-- Sampling operators: positive tests. +-------------------------------------- + +-- We will use the REPEATABLE clause and/or adjust the sampling options to either remove no rows or +-- all rows to help keep the tests deterministic. +table t +|> tablesample (100 percent) repeatable (0); + +table t +|> tablesample (2 rows) repeatable (0); + +table t +|> tablesample (bucket 1 out of 1) repeatable (0); + +table t +|> tablesample (100 percent) repeatable (0) +|> tablesample (5 rows) repeatable (0) +|> tablesample (bucket 1 out of 1) repeatable (0); + +-- Sampling operators: negative tests. +-------------------------------------- + +-- The sampling method is required. +table t +|> tablesample (); + +-- Negative sampling options are not supported. +table t +|> tablesample (-100 percent); + +table t +|> tablesample (-5 rows); + +-- The sampling method may not refer to attribute names from the input relation. +table t +|> tablesample (x rows); + +-- The bucket number is invalid. +table t +|> tablesample (bucket 2 out of 1); + +-- Byte literals are not supported. +table t +|> tablesample (200b) repeatable (0); + +-- Invalid byte literal syntax. +table t +|> tablesample (200) repeatable (0); + -- Cleanup. 
----------- drop table t; diff --git a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out index 2c6abe2a277ad..78b610b0d97c6 100644 --- a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out @@ -861,6 +861,204 @@ org.apache.spark.sql.catalyst.parser.ParseException } +-- !query +table t +|> tablesample (100 percent) repeatable (0) +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> tablesample (2 rows) repeatable (0) +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> tablesample (bucket 1 out of 1) repeatable (0) +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> tablesample (100 percent) repeatable (0) +|> tablesample (5 rows) repeatable (0) +|> tablesample (bucket 1 out of 1) repeatable (0) +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> tablesample () +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0014", + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 25, + "fragment" : "tablesample ()" + } ] +} + + +-- !query +table t +|> tablesample (-100 percent) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0064", + "messageParameters" : { + "msg" : "Sampling fraction (-1.0) must be on interval [0, 1]" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 37, + "fragment" : "tablesample (-100 percent)" + } ] +} + + +-- !query +table t +|> tablesample (-5 rows) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_LIMIT_LIKE_EXPRESSION.IS_NEGATIVE", + "sqlState" : "42K0E", + "messageParameters" : { + "expr" : "\"-5\"", + "name" : "limit", + "v" : "-5" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 25, + "stopIndex" : 26, + "fragment" : "-5" + } ] +} + + +-- !query +table t +|> tablesample (x rows) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INVALID_LIMIT_LIKE_EXPRESSION.IS_UNFOLDABLE", + "sqlState" : "42K0E", + "messageParameters" : { + "expr" : "\"x\"", + "name" : "limit" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 25, + "stopIndex" : 25, + "fragment" : "x" + } ] +} + + +-- !query +table t +|> tablesample (bucket 2 out of 1) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0064", + "messageParameters" : { + "msg" : "Sampling fraction (2.0) must be on interval [0, 1]" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 42, + "fragment" : "tablesample (bucket 2 out of 1)" + } ] +} + + +-- !query +table t +|> tablesample (200b) repeatable (0) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0015", + "messageParameters" : { + "msg" : "byteLengthLiteral" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + 
"startIndex" : 12, + "stopIndex" : 44, + "fragment" : "tablesample (200b) repeatable (0)" + } ] +} + + +-- !query +table t +|> tablesample (200) repeatable (0) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.parser.ParseException +{ + "errorClass" : "_LEGACY_ERROR_TEMP_0016", + "messageParameters" : { + "bytesStr" : "200" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 12, + "stopIndex" : 43, + "fragment" : "tablesample (200) repeatable (0)" + } ] +} + + -- !query drop table t -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 1111a65c6a526..c76d44a1b82cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -928,6 +928,15 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { | earningsYear FOR year IN (`2012`, `2013`, `2014`) |) |""".stripMargin) + // Sampling operations + def checkSample(query: String): Unit = { + val plan: LogicalPlan = parser.parsePlan(query) + assert(plan.collectFirst(_.isInstanceOf[Sample]).nonEmpty) + assert(plan.containsAnyPattern(UNRESOLVED_RELATION, LOCAL_RELATION)) + } + checkSample("TABLE t |> TABLESAMPLE (50 PERCENT)") + checkSample("TABLE t |> TABLESAMPLE (5 ROWS)") + checkSample("TABLE t |> TABLESAMPLE (BUCKET 4 OUT OF 10)") } } } From d97acc17dd0bce476a1f44e7cce14e8d13d95a51 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 2 Oct 2024 14:20:49 -0700 Subject: [PATCH 136/250] [SPARK-49853][SQL][TESTS] Increase test timeout of `PythonForeachWriterSuite` to `60s` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This PR aims to increase test timeout of `PythonForeachWriterSuite` to `60s`. ### Why are the changes needed? To stablize `PythonForeachWriterSuite` in GitHub Action MacOS 15 Runner. For the failed cases, the data is still under generation. - https://github.com/apache/spark/actions/runs/11132652698/job/30936988757 ``` - UnsafeRowBuffer: handles more data than memory *** FAILED *** The code passed to eventually never returned normally. Attempted 237 times over 20.075615666999997 seconds. Last failure message: ArraySeq(1, ..., 1815) did not equal Range$Inclusive(1, ..., 2000) ``` GitHub Runners have different spec and macOS has very limited resources among them. - https://docs.github.com/en/actions/using-github-hosted-runners/using-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories | Virtual Machine | Processor (CPU) | Memory (RAM) | Storage (SSD) | Workflow label | | -- | -- | -- | -- | -- | | Linux | 4 | 16 GB | 14 GB | ubuntu-latest,ubuntu-24.04,ubuntu-22.04,ubuntu-20.04 | | macOS | 3 (M1) | 7 GB | 14 GB | macos-latest,macos-14, macos-15 [Beta] | ### Does this PR introduce _any_ user-facing change? No, this is a test-only change. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48319 from dongjoon-hyun/SPARK-49853. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../spark/sql/execution/python/PythonForeachWriterSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala index 3a8ce569d1ba9..a2d3318361837 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonForeachWriterSuite.scala @@ -99,7 +99,7 @@ class PythonForeachWriterSuite extends SparkFunSuite with Eventually with Mockit } private val iterator = buffer.iterator private val outputBuffer = new ArrayBuffer[Int] - private val testTimeout = timeout(20.seconds) + private val testTimeout = timeout(60.seconds) private val intProj = UnsafeProjection.create(Array[DataType](IntegerType)) private val thread = new Thread() { override def run(): Unit = { From ce5762649435086f0eeacbfa721d5f4686135abc Mon Sep 17 00:00:00 2001 From: ivanjevtic-db Date: Thu, 3 Oct 2024 08:59:18 +0200 Subject: [PATCH 137/250] [SPARK-49837][SQL][TESTS] Add more tests for NULLIF function ### What changes were proposed in this pull request? In this pull request, the proposed changes include introducing tests for the **NULLIF** function. These tests will help prevent potential regressions by ensuring that future modifications do not unintentionally break the behavior of **NULLIF**. I have written several tests, along with queries that combine NULLIF with GROUP BY to cover more complex use cases. ### Why are the changes needed? Currently, there is a lack of tests for the NULLIF function. We should add tests to prevent regressions. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48302 from ivanjevtic-db/nullif-tests. 
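For readers skimming the new cases (an illustrative spark-shell sketch, not one of the added tests): `NULLIF(a, b)` returns NULL when the two arguments are equal and otherwise returns `a`, so grouping by `NULLIF(id, 1)` maps the row with `id = 1` to a NULL group key while every other id keeps its own group. The `spark` session below is assumed.

```
// Each id in range(10) forms its own single-row group; id = 1 shows up as a NULL key.
spark.sql(
  """SELECT NULLIF(id, 1) AS k, COUNT(*) AS cnt
    |FROM range(10)
    |GROUP BY NULLIF(id, 1)
    |ORDER BY cnt DESC, k""".stripMargin).show()
```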
Authored-by: ivanjevtic-db Signed-off-by: Max Gekk --- .../spark/sql/DataFrameFunctionsSuite.scala | 47 +++++++++++++++---- 1 file changed, 38 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 016803635ff60..47691e1ccd40f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -315,6 +315,44 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(df.select(isnotnull(col("a"))), Seq(Row(false))) } + test("nullif function") { + Seq(true, false).foreach { alwaysInlineCommonExpr => + withSQLConf(SQLConf.ALWAYS_INLINE_COMMON_EXPR.key -> alwaysInlineCommonExpr.toString) { + Seq( + "SELECT NULLIF(1, 1)" -> Seq(Row(null)), + "SELECT NULLIF(1, 2)" -> Seq(Row(1)), + "SELECT NULLIF(NULL, 1)" -> Seq(Row(null)), + "SELECT NULLIF(1, NULL)" -> Seq(Row(1)), + "SELECT NULLIF(NULL, NULL)" -> Seq(Row(null)), + "SELECT NULLIF('abc', 'abc')" -> Seq(Row(null)), + "SELECT NULLIF('abc', 'xyz')" -> Seq(Row("abc")), + "SELECT NULLIF(id, 1) " + + "FROM range(10) " + + "GROUP BY NULLIF(id, 1)" -> Seq(Row(null), Row(2), Row(3), Row(4), Row(5), Row(6), + Row(7), Row(8), Row(9), Row(0)), + "SELECT NULLIF(id, 1), COUNT(*)" + + "FROM range(10) " + + "GROUP BY NULLIF(id, 1) " + + "HAVING COUNT(*) > 1" -> Seq.empty[Row] + ).foreach { + case (sqlText, expected) => checkAnswer(sql(sqlText), expected) + } + + checkError( + exception = intercept[AnalysisException] { + sql("SELECT NULLIF(id, 1), COUNT(*) " + + "FROM range(10) " + + "GROUP BY NULLIF(id, 2)") + }, + condition = "MISSING_AGGREGATION", + parameters = Map( + "expression" -> "\"id\"", + "expressionAnyValue" -> "\"any_value(id)\"") + ) + } + } + } + test("equal_null function") { val df = Seq[(Integer, Integer)]((null, 8)).toDF("a", "b") checkAnswer(df.selectExpr("equal_null(a, b)"), Seq(Row(false))) @@ -324,15 +362,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { checkAnswer(df.select(equal_null(col("a"), col("a"))), Seq(Row(true))) } - test("nullif function") { - val df = Seq((5, 8)).toDF("a", "b") - checkAnswer(df.selectExpr("nullif(5, 8)"), Seq(Row(5))) - checkAnswer(df.select(nullif(lit(5), lit(8))), Seq(Row(5))) - - checkAnswer(df.selectExpr("nullif(a, a)"), Seq(Row(null))) - checkAnswer(df.select(nullif(lit(5), lit(5))), Seq(Row(null))) - } - test("nullifzero function") { withTable("t") { // Here we exercise a non-nullable, non-foldable column. From 216f761bcb122a253d42793466a9fe97e7ba3336 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 3 Oct 2024 09:12:23 +0200 Subject: [PATCH 138/250] [SPARK-48357][SQL] Support for LOOP statement MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? In this PR, support for LOOP statement in SQL scripting is introduced. Changes summary: Grammar/parser changes: - `loopStatement` grammar rule - `visitLoopStatement` rule visitor - `LoopStatement` logical operator `LoopStatementExec` execution node Iterator implementation - repeatedly execute body (only way to stop the loop is with LEAVE, or if an exception occurs) `SqlScriptingInterpreter` - added logic to transform LoopStatement logical operator to LoopStatementExec execution node ### Why are the changes needed? 
This is a part of SQL Scripting introduced to Spark, LOOP statement is a basic control flow construct in the SQL language. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? New tests are introduced to scripting test suites: `SqlScriptingParserSuite`, `SqlScriptingExecutionNodeSuite` and `SqlScriptingInterpreterSuite`. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48323 from dusantism-db/sql-scripting-loop. Authored-by: Dušan Tišma Signed-off-by: Max Gekk --- docs/sql-ref-ansi-compliance.md | 1 + .../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 1 + .../sql/catalyst/parser/SqlBaseParser.g4 | 7 + .../sql/catalyst/parser/AstBuilder.scala | 11 + .../parser/SqlScriptingLogicalOperators.scala | 18 +- .../parser/SqlScriptingParserSuite.scala | 205 ++++++++++++++++++ .../scripting/SqlScriptingExecutionNode.scala | 52 +++++ .../scripting/SqlScriptingInterpreter.scala | 6 +- .../sql-tests/results/ansi/keywords.sql.out | 1 + .../sql-tests/results/keywords.sql.out | 1 + .../SqlScriptingExecutionNodeSuite.scala | 17 ++ .../SqlScriptingInterpreterSuite.scala | 152 +++++++++++++ .../ThriftServerWithSparkContextSuite.scala | 2 +- 13 files changed, 469 insertions(+), 5 deletions(-) diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md index 12dff1e325c49..b4446b1538cd6 100644 --- a/docs/sql-ref-ansi-compliance.md +++ b/docs/sql-ref-ansi-compliance.md @@ -581,6 +581,7 @@ Below is a list of all the keywords in Spark SQL. |LOCKS|non-reserved|non-reserved|non-reserved| |LOGICAL|non-reserved|non-reserved|non-reserved| |LONG|non-reserved|non-reserved|non-reserved| +|LOOP|non-reserved|non-reserved|non-reserved| |MACRO|non-reserved|non-reserved|non-reserved| |MAP|non-reserved|non-reserved|non-reserved| |MATCHED|non-reserved|non-reserved|non-reserved| diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 index de28041acd41f..7391e8c353dee 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 @@ -301,6 +301,7 @@ LOCK: 'LOCK'; LOCKS: 'LOCKS'; LOGICAL: 'LOGICAL'; LONG: 'LONG'; +LOOP: 'LOOP'; MACRO: 'MACRO'; MAP: 'MAP' {incComplexTypeLevelCounter();}; MATCHED: 'MATCHED'; diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index e8e2e980135a2..644c7e732fbf0 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -69,6 +69,7 @@ compoundStatement | repeatStatement | leaveStatement | iterateStatement + | loopStatement ; setStatementWithOptionalVarKeyword @@ -106,6 +107,10 @@ caseStatement (ELSE elseBody=compoundBody)? END CASE #simpleCaseStatement ; +loopStatement + : beginLabel? LOOP compoundBody END LOOP endLabel? 
+ ; + singleStatement : (statement|setResetStatement) SEMICOLON* EOF ; @@ -1658,6 +1663,7 @@ ansiNonReserved | LOCKS | LOGICAL | LONG + | LOOP | MACRO | MAP | MATCHED @@ -2016,6 +2022,7 @@ nonReserved | LOCKS | LOGICAL | LONG + | LOOP | MACRO | MAP | MATCHED diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 9ce96ae652fed..f1d211f517789 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -337,6 +337,10 @@ class AstBuilder extends DataTypeAstBuilder if Option(c.beginLabel()).isDefined && c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) => true + case c: LoopStatementContext + if Option(c.beginLabel()).isDefined && + c.beginLabel().multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label) + => true case _ => false } } @@ -373,6 +377,13 @@ class AstBuilder extends DataTypeAstBuilder CurrentOrigin.get, labelText, "ITERATE") } + override def visitLoopStatement(ctx: LoopStatementContext): LoopStatement = { + val labelText = generateLabelText(Option(ctx.beginLabel()), Option(ctx.endLabel())) + val body = visitCompoundBody(ctx.compoundBody()) + + LoopStatement(body, Some(labelText)) + } + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { Option(ctx.statement().asInstanceOf[ParserRuleContext]) .orElse(Option(ctx.setResetStatement().asInstanceOf[ParserRuleContext])) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala index ed40a5fd734b6..9fd87f51bd57e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingLogicalOperators.scala @@ -81,7 +81,7 @@ case class IfElseStatement( * Body is executed as long as the condition evaluates to true * @param body Compound body is a collection of statements that are executed if condition is true. * @param label An optional label for the loop which is unique amongst all labels for statements - * within which the LOOP statement is contained. + * within which the WHILE statement is contained. * If an end label is specified it must match the beginning label. * The label can be used to LEAVE or ITERATE the loop. */ @@ -97,7 +97,7 @@ case class WhileStatement( * @param body Compound body is a collection of statements that are executed once no matter what, * and then as long as condition is false. * @param label An optional label for the loop which is unique amongst all labels for statements - * within which the LOOP statement is contained. + * within which the REPEAT statement is contained. * If an end label is specified it must match the beginning label. * The label can be used to LEAVE or ITERATE the loop. */ @@ -106,7 +106,6 @@ case class RepeatStatement( body: CompoundBody, label: Option[String]) extends CompoundPlanStatement - /** * Logical operator for LEAVE statement. * The statement can be used both for compounds or any kind of loops. 
@@ -138,3 +137,16 @@ case class CaseStatement( elseBody: Option[CompoundBody]) extends CompoundPlanStatement { assert(conditions.length == conditionalBodies.length) } + +/** + * Logical operator for LOOP statement. + * @param body Compound body is a collection of statements that are executed until the + * LOOP statement is terminated by using the LEAVE statement. + * @param label An optional label for the loop which is unique amongst all labels for statements + * within which the LOOP statement is contained. + * If an end label is specified it must match the beginning label. + * The label can be used to LEAVE or ITERATE the loop. + */ +case class LoopStatement( + body: CompoundBody, + label: Option[String]) extends CompoundPlanStatement diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index ba634333e06fb..2972ba2db21de 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -1400,6 +1400,211 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { .getText == "SELECT 42") } + test("loop statement") { + val sqlScriptText = + """BEGIN + |lbl: LOOP + | SELECT 1; + | SELECT 2; + |END LOOP lbl; + |END + """.stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[LoopStatement]) + + val whileStmt = tree.collection.head.asInstanceOf[LoopStatement] + + assert(whileStmt.body.isInstanceOf[CompoundBody]) + assert(whileStmt.body.collection.length == 2) + assert(whileStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(whileStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(whileStmt.label.contains("lbl")) + } + + test("loop with if else block") { + val sqlScriptText = + """BEGIN + |lbl: LOOP + | IF 1 = 1 THEN + | SELECT 1; + | ELSE + | SELECT 2; + | END IF; + |END LOOP lbl; + |END + """.stripMargin + + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[LoopStatement]) + + val loopStmt = tree.collection.head.asInstanceOf[LoopStatement] + + assert(loopStmt.body.isInstanceOf[CompoundBody]) + assert(loopStmt.body.collection.length == 1) + assert(loopStmt.body.collection.head.isInstanceOf[IfElseStatement]) + val ifStmt = loopStmt.body.collection.head.asInstanceOf[IfElseStatement] + + assert(ifStmt.conditions.length == 1) + assert(ifStmt.conditionalBodies.length == 1) + assert(ifStmt.elseBody.isDefined) + + assert(ifStmt.conditions.head.isInstanceOf[SingleStatement]) + assert(ifStmt.conditions.head.getText == "1 = 1") + + assert(ifStmt.conditionalBodies.head.collection.head.isInstanceOf[SingleStatement]) + assert(ifStmt.conditionalBodies.head.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 1") + + assert(ifStmt.elseBody.get.collection.head.isInstanceOf[SingleStatement]) + assert(ifStmt.elseBody.get.collection.head.asInstanceOf[SingleStatement] + .getText == "SELECT 2") + + assert(loopStmt.label.contains("lbl")) + } + + test("nested loop") { + val sqlScriptText = + """BEGIN + |lbl: LOOP + | LOOP + | SELECT 42; + | END LOOP; + |END LOOP lbl; + |END + """.stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + 
assert(tree.collection.head.isInstanceOf[LoopStatement]) + + val loopStmt = tree.collection.head.asInstanceOf[LoopStatement] + + assert(loopStmt.body.isInstanceOf[CompoundBody]) + assert(loopStmt.body.collection.length == 1) + assert(loopStmt.body.collection.head.isInstanceOf[LoopStatement]) + val nestedLoopStmt = loopStmt.body.collection.head.asInstanceOf[LoopStatement] + + assert(nestedLoopStmt.body.isInstanceOf[CompoundBody]) + assert(nestedLoopStmt.body.collection.length == 1) + assert(nestedLoopStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(nestedLoopStmt.body.collection. + head.asInstanceOf[SingleStatement].getText == "SELECT 42") + + assert(loopStmt.label.contains("lbl")) + } + + test("leave loop statement") { + val sqlScriptText = + """ + |BEGIN + | lbl: LOOP + | SELECT 1; + | LEAVE lbl; + | END LOOP; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[LoopStatement]) + + val loopStmt = tree.collection.head.asInstanceOf[LoopStatement] + + assert(loopStmt.body.isInstanceOf[CompoundBody]) + assert(loopStmt.body.collection.length == 2) + + assert(loopStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(loopStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(loopStmt.body.collection(1).isInstanceOf[LeaveStatement]) + assert(loopStmt.body.collection(1).asInstanceOf[LeaveStatement].label == "lbl") + } + + test("iterate loop statement") { + val sqlScriptText = + """ + |BEGIN + | lbl: LOOP + | SELECT 1; + | ITERATE lbl; + | END LOOP; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[LoopStatement]) + + val loopStmt = tree.collection.head.asInstanceOf[LoopStatement] + + assert(loopStmt.body.isInstanceOf[CompoundBody]) + assert(loopStmt.body.collection.length == 2) + + assert(loopStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert(loopStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(loopStmt.body.collection(1).isInstanceOf[IterateStatement]) + assert(loopStmt.body.collection(1).asInstanceOf[IterateStatement].label == "lbl") + } + + test("leave outer loop from nested loop statement") { + val sqlScriptText = + """ + |BEGIN + | lbl: LOOP + | lbl2: LOOP + | SELECT 1; + | LEAVE lbl; + | END LOOP; + | END LOOP; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[LoopStatement]) + + val loopStmt = tree.collection.head.asInstanceOf[LoopStatement] + + assert(loopStmt.body.isInstanceOf[CompoundBody]) + assert(loopStmt.body.collection.length == 1) + + val nestedLoopStmt = loopStmt.body.collection.head.asInstanceOf[LoopStatement] + + assert(nestedLoopStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert( + nestedLoopStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(nestedLoopStmt.body.collection(1).isInstanceOf[LeaveStatement]) + assert(nestedLoopStmt.body.collection(1).asInstanceOf[LeaveStatement].label == "lbl") + } + + test("iterate outer loop from nested loop statement") { + val sqlScriptText = + """ + |BEGIN + | lbl: LOOP + | lbl2: LOOP + | SELECT 1; + | ITERATE lbl; + | END LOOP; + | END LOOP; + |END""".stripMargin + val tree = parseScript(sqlScriptText) + assert(tree.collection.length == 1) + assert(tree.collection.head.isInstanceOf[LoopStatement]) + + 
val loopStmt = tree.collection.head.asInstanceOf[LoopStatement] + + assert(loopStmt.body.isInstanceOf[CompoundBody]) + assert(loopStmt.body.collection.length == 1) + + val nestedLoopStmt = loopStmt.body.collection.head.asInstanceOf[LoopStatement] + + assert(nestedLoopStmt.body.collection.head.isInstanceOf[SingleStatement]) + assert( + nestedLoopStmt.body.collection.head.asInstanceOf[SingleStatement].getText == "SELECT 1") + + assert(nestedLoopStmt.body.collection(1).isInstanceOf[IterateStatement]) + assert(nestedLoopStmt.body.collection(1).asInstanceOf[IterateStatement].label == "lbl") + } + // Helper methods def cleanupStatementString(statementStr: String): String = { statementStr diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala index af9fd5464277c..9fdb9626556f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala @@ -592,3 +592,55 @@ class IterateStatementExec(val label: String) extends LeafStatementExec { var hasBeenMatched: Boolean = false override def reset(): Unit = hasBeenMatched = false } + +class LoopStatementExec( + body: CompoundBodyExec, + val label: Option[String]) extends NonLeafStatementExec { + + /** + * Loop can be interrupted by LeaveStatementExec + */ + private var interrupted: Boolean = false + + /** + * Loop can be iterated by IterateStatementExec + */ + private var iterated: Boolean = false + + private lazy val treeIterator = + new Iterator[CompoundStatementExec] { + override def hasNext: Boolean = !interrupted + + override def next(): CompoundStatementExec = { + if (!body.getTreeIterator.hasNext || iterated) { + reset() + } + + val retStmt = body.getTreeIterator.next() + + retStmt match { + case leaveStatementExec: LeaveStatementExec if !leaveStatementExec.hasBeenMatched => + if (label.contains(leaveStatementExec.label)) { + leaveStatementExec.hasBeenMatched = true + } + interrupted = true + case iterStatementExec: IterateStatementExec if !iterStatementExec.hasBeenMatched => + if (label.contains(iterStatementExec.label)) { + iterStatementExec.hasBeenMatched = true + } + iterated = true + case _ => + } + + retStmt + } + } + + override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator + + override def reset(): Unit = { + interrupted = false + iterated = false + body.reset() + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala index 917b4d6f45ee0..78ef715e18982 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreter.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.scripting import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier -import org.apache.spark.sql.catalyst.parser.{CaseStatement, CompoundBody, CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement, RepeatStatement, SingleStatement, WhileStatement} +import org.apache.spark.sql.catalyst.parser.{CaseStatement, CompoundBody, CompoundPlanStatement, IfElseStatement, IterateStatement, LeaveStatement, LoopStatement, RepeatStatement, SingleStatement, WhileStatement} import 
org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, DropVariable, LogicalPlan} import org.apache.spark.sql.catalyst.trees.Origin @@ -120,6 +120,10 @@ case class SqlScriptingInterpreter() { transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] new RepeatStatementExec(conditionExec, bodyExec, label, session) + case LoopStatement(body, label) => + val bodyExec = transformTreeIntoExecutable(body, session).asInstanceOf[CompoundBodyExec] + new LoopStatementExec(bodyExec, label) + case leaveStatement: LeaveStatement => new LeaveStatementExec(leaveStatement.label) diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out index 7c694503056ab..d9d266e8a674a 100644 --- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out @@ -187,6 +187,7 @@ LOCK false LOCKS false LOGICAL false LONG false +LOOP false MACRO false MAP false MATCHED false diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out index 2c16d961b1313..cd93a811d64f5 100644 --- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out @@ -187,6 +187,7 @@ LOCK false LOCKS false LOGICAL false LONG false +LOOP false MACRO false MAP false MATCHED false diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala index 83d8191d01ec1..baad5702f4f22 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNodeSuite.scala @@ -97,6 +97,7 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi case TestLeafStatement(testVal) => testVal case TestIfElseCondition(_, description) => description case TestLoopCondition(_, _, description) => description + case loopStmt: LoopStatementExec => loopStmt.label.get case leaveStmt: LeaveStatementExec => leaveStmt.label case iterateStmt: IterateStatementExec => iterateStmt.label case _ => fail("Unexpected statement type") @@ -669,4 +670,20 @@ class SqlScriptingExecutionNodeSuite extends SparkFunSuite with SharedSparkSessi val statements = iter.map(extractStatementValue).toSeq assert(statements === Seq("con1", "con2")) } + + test("loop statement with leave") { + val iter = new CompoundBodyExec( + statements = Seq( + new LoopStatementExec( + body = new CompoundBodyExec(Seq( + TestLeafStatement("body1"), + new LeaveStatementExec("lbl")) + ), + label = Some("lbl") + ) + ) + ).getTreeIterator + val statements = iter.map(extractStatementValue).toSeq + assert(statements === Seq("body1", "lbl")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala index ac190eb48d1f9..3551608a1ee84 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/scripting/SqlScriptingInterpreterSuite.scala @@ -1383,4 +1383,156 @@ class SqlScriptingInterpreterSuite extends QueryTest with SharedSparkSession { ) verifySqlScriptResult(sqlScriptText, 
expected) } + + test("loop statement with leave") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: LOOP + | SET x = x + 1; + | SELECT x; + | IF x > 2 + | THEN + | LEAVE lbl; + | END IF; + | END LOOP; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare + Seq.empty[Row], // set x = 0 + Seq.empty[Row], // set x = 1 + Seq(Row(1)), // select x + Seq.empty[Row], // set x = 2 + Seq(Row(2)), // select x + Seq.empty[Row], // set x = 3 + Seq(Row(3)), // select x + Seq(Row(3)), // select x + Seq.empty[Row] // drop + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("nested loop statement with leave") { + val commands = + """ + |BEGIN + | DECLARE x = 0; + | DECLARE y = 0; + | lbl1: LOOP + | SET VAR y = 0; + | lbl2: LOOP + | SELECT x, y; + | SET VAR y = y + 1; + | IF y >= 2 THEN + | LEAVE lbl2; + | END IF; + | END LOOP; + | SET VAR x = x + 1; + | IF x >= 2 THEN + | LEAVE lbl1; + | END IF; + | END LOOP; + |END + |""".stripMargin + + val expected = Seq( + Seq.empty[Row], // declare x + Seq.empty[Row], // declare y + Seq.empty[Row], // set y to 0 + Seq(Row(0, 0)), // select x, y + Seq.empty[Row], // increase y + Seq(Row(0, 1)), // select x, y + Seq.empty[Row], // increase y + Seq.empty[Row], // increase x + Seq.empty[Row], // set y to 0 + Seq(Row(1, 0)), // select x, y + Seq.empty[Row], // increase y + Seq(Row(1, 1)), // select x, y + Seq.empty[Row], // increase y + Seq.empty[Row], // increase x + Seq.empty[Row], // drop y + Seq.empty[Row] // drop x + ) + verifySqlScriptResult(commands, expected) + } + + test("iterate loop statement") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: LOOP + | SET x = x + 1; + | IF x > 1 THEN + | LEAVE lbl; + | END IF; + | ITERATE lbl; + | SET x = x + 2; + | END LOOP; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare + Seq.empty[Row], // set x = 0 + Seq.empty[Row], // set x = 1 + Seq.empty[Row], // set x = 2 + Seq(Row(2)), // select x + Seq.empty[Row] // drop + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("leave outer loop from nested loop statement") { + val sqlScriptText = + """ + |BEGIN + | lbl: LOOP + | lbl2: LOOP + | SELECT 1; + | LEAVE lbl; + | END LOOP; + | END LOOP; + |END""".stripMargin + val expected = Seq( + Seq(Row(1)) // select 1 + ) + verifySqlScriptResult(sqlScriptText, expected) + } + + test("iterate outer loop from nested loop statement") { + val sqlScriptText = + """ + |BEGIN + | DECLARE x INT; + | SET x = 0; + | lbl: LOOP + | SET x = x + 1; + | IF x > 2 THEN + | LEAVE lbl; + | END IF; + | lbl2: LOOP + | SELECT 1; + | ITERATE lbl; + | SET x = 10; + | END LOOP; + | END LOOP; + | SELECT x; + |END""".stripMargin + val expected = Seq( + Seq.empty[Row], // declare + Seq.empty[Row], // set x = 0 + Seq.empty[Row], // set x = 1 + Seq(Row(1)), // select 1 + Seq.empty[Row], // set x = 2 + Seq(Row(1)), // select 1 + Seq.empty[Row], // set x = 3 + Seq(Row(3)), // select x + Seq.empty[Row] // drop + ) + verifySqlScriptResult(sqlScriptText, expected) + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala index dcf3bd8c71731..60c49619552e7 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala +++ 
b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala @@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer { val sessionHandle = client.openSession(user, "") val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS) // scalastyle:off line.size.limit - assert(infoValue.getStringValue == "ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") + assert(infoValue.getStringValue == 
"ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DATE,CURRENT_TIME,CURRENT_TIMESTAMP,CURRENT_USER,DATA,DATABASE,DATABASES,DATE,DATEADD,DATEDIFF,DATE_ADD,DATE_DIFF,DAY,DAYOFYEAR,DAYS,DBPROPERTIES,DEC,DECIMAL,DECLARE,DEFAULT,DEFINED,DEFINER,DELETE,DELIMITED,DESC,DESCRIBE,DETERMINISTIC,DFS,DIRECTORIES,DIRECTORY,DISTINCT,DISTRIBUTE,DIV,DO,DOUBLE,DROP,ELSE,END,ESCAPE,ESCAPED,EVOLUTION,EXCEPT,EXCHANGE,EXCLUDE,EXECUTE,EXISTS,EXPLAIN,EXPORT,EXTENDED,EXTERNAL,EXTRACT,FALSE,FETCH,FIELDS,FILEFORMAT,FILTER,FIRST,FLOAT,FOLLOWING,FOR,FOREIGN,FORMAT,FORMATTED,FROM,FULL,FUNCTION,FUNCTIONS,GENERATED,GLOBAL,GRANT,GROUP,GROUPING,HAVING,HOUR,HOURS,IDENTIFIER,IDENTITY,IF,IGNORE,ILIKE,IMMEDIATE,IMPORT,IN,INCLUDE,INCREMENT,INDEX,INDEXES,INNER,INPATH,INPUT,INPUTFORMAT,INSERT,INT,INTEGER,INTERSECT,INTERVAL,INTO,INVOKER,IS,ITEMS,ITERATE,JOIN,KEYS,LANGUAGE,LAST,LATERAL,LAZY,LEADING,LEAVE,LEFT,LIKE,LIMIT,LINES,LIST,LOAD,LOCAL,LOCATION,LOCK,LOCKS,LOGICAL,LONG,LOOP,MACRO,MAP,MATCHED,MERGE,MICROSECOND,MICROSECONDS,MILLISECOND,MILLISECONDS,MINUS,MINUTE,MINUTES,MODIFIES,MONTH,MONTHS,MSCK,NAME,NAMESPACE,NAMESPACES,NANOSECOND,NANOSECONDS,NATURAL,NO,NONE,NOT,NULL,NULLS,NUMERIC,OF,OFFSET,ON,ONLY,OPTION,OPTIONS,OR,ORDER,OUT,OUTER,OUTPUTFORMAT,OVER,OVERLAPS,OVERLAY,OVERWRITE,PARTITION,PARTITIONED,PARTITIONS,PERCENT,PIVOT,PLACING,POSITION,PRECEDING,PRIMARY,PRINCIPALS,PROPERTIES,PURGE,QUARTER,QUERY,RANGE,READS,REAL,RECORDREADER,RECORDWRITER,RECOVER,REDUCE,REFERENCES,REFRESH,RENAME,REPAIR,REPEAT,REPEATABLE,REPLACE,RESET,RESPECT,RESTRICT,RETURN,RETURNS,REVOKE,RIGHT,ROLE,ROLES,ROLLBACK,ROLLUP,ROW,ROWS,SCHEMA,SCHEMAS,SECOND,SECONDS,SECURITY,SELECT,SEMI,SEPARATED,SERDE,SERDEPROPERTIES,SESSION_USER,SET,SETS,SHORT,SHOW,SINGLE,SKEWED,SMALLINT,SOME,SORT,SORTED,SOURCE,SPECIFIC,SQL,START,STATISTICS,STORED,STRATIFY,STRING,STRUCT,SUBSTR,SUBSTRING,SYNC,SYSTEM_TIME,SYSTEM_VERSION,TABLE,TABLES,TABLESAMPLE,TARGET,TBLPROPERTIES,TERMINATED,THEN,TIME,TIMEDIFF,TIMESTAMP,TIMESTAMPADD,TIMESTAMPDIFF,TIMESTAMP_LTZ,TIMESTAMP_NTZ,TINYINT,TO,TOUCH,TRAILING,TRANSACTION,TRANSACTIONS,TRANSFORM,TRIM,TRUE,TRUNCATE,TRY_CAST,TYPE,UNARCHIVE,UNBOUNDED,UNCACHE,UNION,UNIQUE,UNKNOWN,UNLOCK,UNPIVOT,UNSET,UNTIL,UPDATE,USE,USER,USING,VALUES,VAR,VARCHAR,VARIABLE,VARIANT,VERSION,VIEW,VIEWS,VOID,WEEK,WEEKS,WHEN,WHERE,WHILE,WINDOW,WITH,WITHIN,X,YEAR,YEARS,ZONE") // scalastyle:on line.size.limit } } From c1ecab4d77b487c595028f8d33e3d1b41634b44e Mon Sep 17 00:00:00 2001 From: panbingkun Date: Thu, 3 Oct 2024 15:28:41 +0800 Subject: [PATCH 139/250] [SPARK-49541][BUILD] Upgrade log4j2 to 2.24.1 ### What changes were proposed in this pull request? The pr aims to upgrade log4j2 from `2.22.1` to `2.24.1`. ### Why are the changes needed? - The full release notes: https://github.com/apache/logging-log4j2/releases/tag/rel%2F2.24.1 https://github.com/apache/logging-log4j2/releases/tag/rel%2F2.24.0 https://github.com/apache/logging-log4j2/releases/tag/rel%2F2.23.1 https://github.com/apache/logging-log4j2/releases/tag/rel%2F2.23.0 - The new version contains some bug fixes: Fix regression in JdkMapAdapterStringMap performance. 
(https://github.com/apache/logging-log4j2/issues/2238) Fix NPE in PatternProcessor for a UNIX_MILLIS pattern (https://github.com/apache/logging-log4j2/issues/2346) Fix that parameterized message formatting throws an exception when there are insufficient number of parameters. It previously simply didn't replace the '{}' sequence. The behavior changed in 2.21.0 and should be restored for backward compatibility. (https://github.com/apache/logging-log4j2/issues/2380) Fix putAll() in the default thread context map implementation (https://github.com/apache/logging-log4j2/pull/2942) ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48029 from panbingkun/SPARK-49541. Authored-by: panbingkun Signed-off-by: yangjie01 --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 10 +++++----- pom.xml | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index f6ce3d25ebc8a..5cba1c687e5aa 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -189,11 +189,11 @@ leveldbjni-all/1.8//leveldbjni-all-1.8.jar libfb303/0.9.3//libfb303-0.9.3.jar libthrift/0.16.0//libthrift-0.16.0.jar listenablefuture/9999.0-empty-to-avoid-conflict-with-guava//listenablefuture-9999.0-empty-to-avoid-conflict-with-guava.jar -log4j-1.2-api/2.22.1//log4j-1.2-api-2.22.1.jar -log4j-api/2.22.1//log4j-api-2.22.1.jar -log4j-core/2.22.1//log4j-core-2.22.1.jar -log4j-layout-template-json/2.22.1//log4j-layout-template-json-2.22.1.jar -log4j-slf4j2-impl/2.22.1//log4j-slf4j2-impl-2.22.1.jar +log4j-1.2-api/2.24.1//log4j-1.2-api-2.24.1.jar +log4j-api/2.24.1//log4j-api-2.24.1.jar +log4j-core/2.24.1//log4j-core-2.24.1.jar +log4j-layout-template-json/2.24.1//log4j-layout-template-json-2.24.1.jar +log4j-slf4j2-impl/2.24.1//log4j-slf4j2-impl-2.24.1.jar logging-interceptor/3.12.12//logging-interceptor-3.12.12.jar lz4-java/1.8.0//lz4-java-1.8.0.jar metrics-core/4.2.27//metrics-core-4.2.27.jar diff --git a/pom.xml b/pom.xml index 6a77da703dbd2..31046e5a85f82 100644 --- a/pom.xml +++ b/pom.xml @@ -120,7 +120,7 @@ spark 9.7 2.0.16 - 2.22.1 + 2.24.1 3.4.0 From 036db74b97f8f5b447bc1d689e9f8081af47604c Mon Sep 17 00:00:00 2001 From: Xi Lyu Date: Thu, 3 Oct 2024 10:24:06 +0200 Subject: [PATCH 140/250] [SPARK-47341][SQL] Fix inaccurate documentation of RuntimeConfig#get MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? The existing documentation of `RuntimeConfig.get()` is misleading: * `get(key: String)` method will not throw any exception if the key is not set as long as the config entry has a default value, instead, it will just return the `defaultValue` of the `ConfigEntry`. An `NoSuchElementException` will only be thrown if there is no default value for the config entry. * `get(key: String, default: String)` method will ignore the `defaultValue` of its `ConfigEntry`, and return the given param `default` if unset.  * `getOption(key: String)` method will return the `defaultValue` of its `ConfigEntry` if the config is not set, instead of `None`.   An example to demonstrate the behaviour: image The first line makes sure the config is not set. ``` scala> spark.conf.unset("spark.sql.session.timeZone")  ``` The following code returns `Etc/UTC`, which doesn't throw any exception. 
``` scala> spark.conf.get("spark.sql.session.timeZone") res1: String = "Etc/UTC" ``` The following code returns `Some("Etc/UTC")`, instead of `None`. ``` scala> spark.conf.getOption("spark.sql.session.timeZone") res2: Option[String] = Some(value = "Etc/UTC") ``` The following code returns `Europe/Berlin`, ignoring the default value. However, the documentation only says it returns the value, without mentioning ignoring the default value of the entry when the config is not explicitly set. ``` scala> spark.conf.get("spark.sql.session.timeZone", "Europe/Berlin")  res3: String = "Europe/Berlin" ``` In this PR, the documentation is fixed and a new test case is added. ### Why are the changes needed? The incorrect documentation is likely to mislead users to weird behaviours if they rely on the documentation. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New test case in `RuntimeConfigSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48264 from xi-db/SPARK-49798-fix-runtimeconfig-doc. Lead-authored-by: Xi Lyu Co-authored-by: Xi Lyu <159039256+xi-db@users.noreply.github.com> Signed-off-by: Max Gekk --- .../sql/internal/ConnectRuntimeConfig.scala | 2 +- .../org/apache/spark/sql/RuntimeConfig.scala | 13 +++++++---- .../sql/internal/RuntimeConfigImpl.scala | 2 +- .../apache/spark/sql/RuntimeConfigSuite.scala | 22 ++++++++++++++++++- 4 files changed, 32 insertions(+), 7 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala index 7578e2424fb42..be1a13cb2fed2 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/ConnectRuntimeConfig.scala @@ -38,7 +38,7 @@ class ConnectRuntimeConfig private[sql] (client: SparkConnectClient) } /** @inheritdoc */ - @throws[NoSuchElementException]("if the key is not set") + @throws[NoSuchElementException]("if the key is not set and there is no default value") def get(key: String): String = getOption(key).getOrElse { throw new NoSuchElementException(key) } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/api/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala index 23a2774ebc3a5..9e6e0e97f0302 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala @@ -54,17 +54,21 @@ abstract class RuntimeConfig { } /** - * Returns the value of Spark runtime configuration property for the given key. + * Returns the value of Spark runtime configuration property for the given key. If the key is + * not set yet, return its default value if possible, otherwise `NoSuchElementException` will be + * thrown. * * @throws java.util.NoSuchElementException * if the key is not set and does not have a default value * @since 2.0.0 */ - @throws[NoSuchElementException]("if the key is not set") + @throws[NoSuchElementException]("if the key is not set and there is no default value") def get(key: String): String /** - * Returns the value of Spark runtime configuration property for the given key. + * Returns the value of Spark runtime configuration property for the given key. If the key is + * not set yet, return the user given `default`. 
This is useful when its default value defined + * by Apache Spark is not the desired one. * * @since 2.0.0 */ @@ -78,7 +82,8 @@ abstract class RuntimeConfig { def getAll: Map[String, String] /** - * Returns the value of Spark runtime configuration property for the given key. + * Returns the value of Spark runtime configuration property for the given key. If the key is + * not set yet, return its default value if possible, otherwise `None` will be returned. * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala index f25ca387db299..0ef879387727a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala @@ -42,7 +42,7 @@ class RuntimeConfigImpl private[sql](val sqlConf: SQLConf = new SQLConf) extends } /** @inheritdoc */ - @throws[NoSuchElementException]("if the key is not set") + @throws[NoSuchElementException]("if the key is not set and there is no default value") def get(key: String): String = { sqlConf.getConfString(key) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala index 352197f96acb6..009fe55664a2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeConfigSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite import org.apache.spark.internal.config -import org.apache.spark.sql.internal.RuntimeConfigImpl +import org.apache.spark.sql.internal.{RuntimeConfigImpl, SQLConf} import org.apache.spark.sql.internal.SQLConf.CHECKPOINT_LOCATION import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE @@ -81,4 +81,24 @@ class RuntimeConfigSuite extends SparkFunSuite { } assert(ex.getMessage.contains("Spark config")) } + + test("set and get a config with defaultValue") { + val conf = newConf() + val key = SQLConf.SESSION_LOCAL_TIMEZONE.key + // By default, the value when getting an unset config entry is its defaultValue. + assert(conf.get(key) == SQLConf.SESSION_LOCAL_TIMEZONE.defaultValue.get) + assert(conf.getOption(key).contains(SQLConf.SESSION_LOCAL_TIMEZONE.defaultValue.get)) + // Get the unset config entry with a different default value, which should return the given + // default parameter. + assert(conf.get(key, "Europe/Amsterdam") == "Europe/Amsterdam") + + // Set a config entry. + conf.set(key, "Europe/Berlin") + // Get the set config entry. + assert(conf.get(key) == "Europe/Berlin") + // Unset the config entry. + conf.unset(key) + // Get the unset config entry, which should return its defaultValue again. + assert(conf.get(key) == SQLConf.SESSION_LOCAL_TIMEZONE.defaultValue.get) + } } From 38f067dfcef9ae53330fdd73ea89ebba614c965b Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Thu, 3 Oct 2024 11:36:26 +0200 Subject: [PATCH 141/250] [SPARK-49358][SQL] Mode expression for map types with collated strings ### What changes were proposed in this pull request? Introduce support for collated string in map types for `mode` expression. ### Why are the changes needed? Complete complex type handling for `mode` expression. ### Does this PR introduce _any_ user-facing change? Yes, `mode` expression can now handle map types with collated strings. ### How was this patch tested? 
New tests in `CollationSQLExpressionsSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48326 from uros-db/mode-map. Authored-by: Uros Bojanic Signed-off-by: Max Gekk --- .../resources/error/error-conditions.json | 5 -- .../catalyst/expressions/aggregate/Mode.scala | 38 ++++++------- .../sql/CollationSQLExpressionsSuite.scala | 54 ++----------------- 3 files changed, 19 insertions(+), 78 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 3786643125a9f..12666fe4ff629 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1016,11 +1016,6 @@ "The input of can't be type data." ] }, - "UNSUPPORTED_MODE_DATA_TYPE" : { - "message" : [ - "The does not support the data type, because there is a \"MAP\" type with keys and/or values that have collated sub-fields." - ] - }, "UNSUPPORTED_UDF_INPUT_TYPE" : { "message" : [ "UDFs do not support '' as an input data type." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala index 8998348f0571b..97add0b8e45bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup} +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup} import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder} import org.apache.spark.sql.catalyst.expressions.Cast.toSQLExpr import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.types.PhysicalDataType -import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, UnsafeRowUtils} -import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType} +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, MapData, UnsafeRowUtils} +import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, MapType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -52,24 +52,6 @@ case class Mode( override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - override def checkInputDataTypes(): TypeCheckResult = { - // TODO: SPARK-49358: Mode expression for map type with collated fields - if (UnsafeRowUtils.isBinaryStable(child.dataType) || - !child.dataType.existsRecursively(f => f.isInstanceOf[MapType] && - !UnsafeRowUtils.isBinaryStable(f))) { - /* - * The Mode class uses collation awareness logic to handle string data. - * All complex types except MapType with collated fields are supported. 
- */ - super.checkInputDataTypes() - } else { - TypeCheckResult.DataTypeMismatch("UNSUPPORTED_MODE_DATA_TYPE", - messageParameters = - Map("child" -> toSQLType(child.dataType), - "mode" -> toSQLId(prettyName))) - } - } - override def prettyName: String = "mode" override def update( @@ -115,6 +97,7 @@ case class Mode( case st: StructType => processStructTypeWithBuffer(data.asInstanceOf[InternalRow].toSeq(st).zip(st.fields)) case at: ArrayType => processArrayTypeWithBuffer(at, data.asInstanceOf[ArrayData]) + case mt: MapType => processMapTypeWithBuffer(mt, data.asInstanceOf[MapData]) case st: StringType => CollationFactory.getCollationKey(data.asInstanceOf[UTF8String], st.collationId) case _ => @@ -140,6 +123,16 @@ case class Mode( collationAwareTransform(data.get(i, a.elementType), a.elementType)) } + private def processMapTypeWithBuffer(mt: MapType, data: MapData): Map[Any, Any] = { + val transformedKeys = (0 until data.numElements()).map { i => + collationAwareTransform(data.keyArray().get(i, mt.keyType), mt.keyType) + } + val transformedValues = (0 until data.numElements()).map { i => + collationAwareTransform(data.valueArray().get(i, mt.valueType), mt.valueType) + } + transformedKeys.zip(transformedValues).toMap + } + override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = { if (buffer.isEmpty) { return null @@ -157,8 +150,7 @@ case class Mode( * * The new map is then used in the rest of the Mode evaluation logic. * - * It is expected to work for all simple and complex types with - * collated fields, except for MapType (temporarily). + * It is expected to work for all simple and complex types with collated fields. */ val collationAwareBuffer = getCollationAwareBuffer(child.dataType, buffer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 9930709cd8bf3..851160d2fbb94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.Locale import scala.collection.immutable.Seq -import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException, SparkThrowable} +import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException} import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Mode @@ -1924,7 +1923,7 @@ class CollationSQLExpressionsSuite } } - test("Support mode expression with collated in recursively nested struct with map with keys") { + test("Support mode for string expression with collated complex type - nested map") { case class ModeTestCase(collationId: String, bufferValues: Map[String, Long], result: String) Seq( ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"), @@ -1932,22 +1931,6 @@ class CollationSQLExpressionsSuite ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}"), ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}") ).foreach { t1 => - def checkThisError(t: ModeTestCase, query: String): Any = { - val c = s"STRUCT>" - val c1 = s"\"${c}\"" - checkError( - exception = intercept[SparkThrowable] { - 
sql(query).collect() - }, - condition = "DATATYPE_MISMATCH.UNSUPPORTED_MODE_DATA_TYPE", - parameters = Map( - ("sqlExpr", "\"mode(i)\""), - ("child", c1), - ("mode", "`mode`")), - queryContext = Seq(ExpectedContext("mode(i)", 18, 24)).toArray - ) - } - def getValuesToAdd(t: ModeTestCase): String = { val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => @@ -1964,41 +1947,12 @@ class CollationSQLExpressionsSuite sql(s"INSERT INTO ${tableName} VALUES ${getValuesToAdd(t1)}") val query = "SELECT lower(cast(mode(i).m1 as string))" + s" FROM ${tableName}" - if (t1.collationId == "utf8_binary") { - checkAnswer(sql(query), Row(t1.result)) - } else { - checkThisError(t1, query) - } + val queryResult = sql(query) + checkAnswer(queryResult, Row(t1.result)) } } } - test("UDT with collation - Mode (throw exception)") { - case class ModeTestCase(collationId: String, bufferValues: Map[String, Long], result: String) - Seq( - ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), - ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), - ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") - ).foreach { t1 => - checkError( - exception = intercept[SparkIllegalArgumentException] { - Mode( - child = Literal.create(null, - MapType(StringType(t1.collationId), IntegerType)) - ).collationAwareTransform( - data = Map.empty[String, Any], - dataType = MapType(StringType(t1.collationId), IntegerType) - ) - }, - condition = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS", - parameters = Map( - "expression" -> "\"mode(NULL)\"", - "functionName" -> "\"MODE\"", - "dataType" -> s"\"MAP\"") - ) - } - } - test("SPARK-48430: Map value extraction with collations") { for { collateKey <- Seq(true, false) From 68fd17d09e83dc4a53c5cf1bf42c346e481098ca Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 3 Oct 2024 07:30:05 -0700 Subject: [PATCH 142/250] [SPARK-49861][INFRA] Add `Python 3.13` to Infra docker image ### What changes were proposed in this pull request? This PR aims to add `Python 3.13` to Infra docker image. Note that SPARK-49862 `Add BASIC_PIP_PKGS and CONNECT_PIP_PKGS to Python 3.13` will track the Python package readiness. ### Why are the changes needed? `Python 3.13` release is scheduled on next Monday, 2024-10-07. - https://peps.python.org/pep-0719/ This is a part of `Python 3.13 Support` preparation for Apache Spark 4.0.0 on 2025 February. ### Does this PR introduce _any_ user-facing change? No, this is an infra only change. ### How was this patch tested? Pass the CIs and verify manually like the following. ``` $ docker run -it --rm ghcr.io/dongjoon-hyun/apache-spark-ci-image:master-11152626644 python3.13 --version Python 3.13.0rc3 ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48330 from dongjoon-hyun/SPARK-49861. 
Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- dev/infra/Dockerfile | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index a40e43bb659f8..4125002cab0bb 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -24,7 +24,7 @@ LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image" # Overwrite this label to avoid exposing the underlying Ubuntu OS version label LABEL org.opencontainers.image.version="" -ENV FULL_REFRESH_DATE 20240903 +ENV FULL_REFRESH_DATE 20241002 ENV DEBIAN_FRONTEND noninteractive ENV DEBCONF_NONINTERACTIVE_SEEN true @@ -142,6 +142,16 @@ RUN python3.12 -m pip install $BASIC_PIP_PKGS $CONNECT_PIP_PKGS lxml && \ python3.12 -m pip install torcheval && \ python3.12 -m pip cache purge +# Install Python 3.13 at the last stage to avoid breaking the existing Python installations +RUN apt-get update && apt-get install -y \ + python3.13 \ + && rm -rf /var/lib/apt/lists/* +RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13 +# TODO(SPARK-49862) Add BASIC_PIP_PKGS and CONNECT_PIP_PKGS to Python 3.13 image when it supports Python 3.13 +RUN python3.13 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this +RUN python3.13 -m pip install lxml && \ + python3.13 -m pip cache purge + # Remove unused installation packages to free up disk space RUN apt-get remove --purge -y 'gfortran-11' 'humanity-icon-theme' 'nodejs-doc' || true RUN apt-get autoremove --purge -y From 901bb33c23b3f7b0417e1f2955fd1a8f6a2564de Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 3 Oct 2024 09:38:24 -0700 Subject: [PATCH 143/250] [SPARK-49860][PYTHON][INFRA] Add `Python 3.13` Daily Python Github Action job ### What changes were proposed in this pull request? This PR aims to add `Python 3.13` Daily Python GitHub Action job in advance. ### Why are the changes needed? `Python 3.13` is scheduled on next Monday, 2024-10-07. - https://peps.python.org/pep-0719/ This will help us track the readiness of `Python 3.13` and eventually achieve `Python 3.13 Support` in Apache Spark 4.0.0 before 2025 February. ### Does this PR introduce _any_ user-facing change? No, this is an infra PR. ### How was this patch tested? Manual review. This should be tested after merging. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48328 from dongjoon-hyun/SPARK-49860. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .github/workflows/build_python_3.13.yml | 45 +++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 .github/workflows/build_python_3.13.yml diff --git a/.github/workflows/build_python_3.13.yml b/.github/workflows/build_python_3.13.yml new file mode 100644 index 0000000000000..6f67cf383584f --- /dev/null +++ b/.github/workflows/build_python_3.13.yml @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +name: "Build / Python-only (master, Python 3.13)" + +on: + schedule: + - cron: '0 20 * * *' + +jobs: + run-build: + permissions: + packages: write + name: Run + uses: ./.github/workflows/build_and_test.yml + if: github.repository == 'apache/spark' + with: + java: 17 + branch: master + hadoop: hadoop3 + envs: >- + { + "PYTHON_TO_TEST": "python3.13" + } + jobs: >- + { + "pyspark": "true", + "pyspark-pandas": "true" + } From b9a327479422a56dc76173be1d86a1a9698039a6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 4 Oct 2024 08:21:31 +0900 Subject: [PATCH 144/250] [SPARK-46647][INFRA] Add `unittest-xml-reporting` into Python 3.12 image ### What changes were proposed in this pull request? This PR aims to use `unittest-xml-reporting` into Python 3.12. ### Why are the changes needed? It seems that we can install it with the latest Python 3.12.6. ``` $ python3 --version Python 3.12.6 ``` ``` $ pip3 install unittest-xml-reporting Looking in indexes: https://pypi.python.org/simple, https://pypi.apple.com/simple Collecting unittest-xml-reporting Using cached https://pypi.apple.com/packages/packages/39/88/f6e9b87428584a3c62cac768185c438ca6d561367a5d267b293259d76075/unittest_xml_reporting-3.2.0-py2.py3-none-any.whl (20 kB) Requirement already satisfied: lxml in /Users/dongjoon/.pyenv/versions/3.12.6/lib/python3.12/site-packages (from unittest-xml-reporting) (5.3.0) Installing collected packages: unittest-xml-reporting Successfully installed unittest-xml-reporting-3.2.0 ``` ``` $ python/run-tests.py --python-executables python3 --modules pyspark-core Running PySpark tests. 
Output is in /Users/dongjoon/APACHE/spark-merge/python/unit-tests.log Will test against the following Python executables: ['python3'] Will test the following Python modules: ['pyspark-core'] python3 python_implementation is CPython python3 version is: Python 3.12.6 Starting test(python3): pyspark.tests.test_appsubmit (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/7073e04c-e2ce-4d4b-b0a1-6f2aff30b612/python3__pyspark.tests.test_appsubmit__7odeq3cw.log) Starting test(python3): pyspark.tests.test_conf (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/461b38b4-f5fb-4165-b80f-b14756eb29bb/python3__pyspark.tests.test_conf__vexpdrlq.log) Starting test(python3): pyspark.tests.test_context (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/9135fed0-31f6-4b10-87fa-c0742038ba53/python3__pyspark.tests.test_context__gk6e0wnr.log) Starting test(python3): pyspark.tests.test_broadcast (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/d62e53d4-0252-4156-b6ca-b3c76244a6db/python3__pyspark.tests.test_broadcast__xcxa09t_.log) Finished test(python3): pyspark.tests.test_conf (10s) Starting test(python3): pyspark.tests.test_daemon (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/84088bb9-dcb6-4d5d-ba0e-e25479ffd9d2/python3__pyspark.tests.test_daemon__z4icjtza.log) Finished test(python3): pyspark.tests.test_daemon (5s) Starting test(python3): pyspark.tests.test_install_spark (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/02babefb-606a-4e34-b8cc-c56584924156/python3__pyspark.tests.test_install_spark__2jd6ytn9.log) Finished test(python3): pyspark.tests.test_broadcast (28s) Starting test(python3): pyspark.tests.test_join (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/9eab61a5-033b-42b1-a9c3-22908a6401a0/python3__pyspark.tests.test_join__wdgw71cw.log) Finished test(python3): pyspark.tests.test_appsubmit (33s) Starting test(python3): pyspark.tests.test_memory_profiler (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/c7fc91c2-7e6a-46c9-8924-eab74f60e6c9/python3__pyspark.tests.test_memory_profiler__ke2pdufb.log) Finished test(python3): pyspark.tests.test_install_spark (17s) Starting test(python3): pyspark.tests.test_profiler (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/148a2214-6c50-4baa-aba5-2e9511a0071d/python3__pyspark.tests.test_profiler__d7w8fl7g.log) Finished test(python3): pyspark.tests.test_join (5s) Starting test(python3): pyspark.tests.test_rdd (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/341893e3-ae96-421c-8314-c942fa65ae92/python3__pyspark.tests.test_rdd__1oqi_hpo.log) Finished test(python3): pyspark.tests.test_context (35s) Starting test(python3): pyspark.tests.test_rddbarrier (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/9ff06ab8-751a-4905-8ad5-f3cade17a3d9/python3__pyspark.tests.test_rddbarrier__t7nne4vg.log) Finished test(python3): pyspark.tests.test_rddbarrier (4s) Starting test(python3): pyspark.tests.test_rddsampler (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/05f2c30f-c85d-439c-909d-02f0bfdf9344/python3__pyspark.tests.test_rddsampler__2kjo59wq.log) Finished test(python3): pyspark.tests.test_profiler (7s) ... 
1 tests were skipped Starting test(python3): pyspark.tests.test_readwrite (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/e59723dc-d7fa-448d-8336-d10a75a5a3d6/python3__pyspark.tests.test_readwrite__6xyu1xjz.log) Finished test(python3): pyspark.tests.test_readwrite (2s) Starting test(python3): pyspark.tests.test_serializers (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/c75b82d3-b63e-4b33-95e5-22a07d200b72/python3__pyspark.tests.test_serializers__cspd98j2.log) Finished test(python3): pyspark.tests.test_rddsampler (5s) Starting test(python3): pyspark.tests.test_shuffle (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/d9348ebb-576c-4c90-8b29-8ff4d90ae89c/python3__pyspark.tests.test_shuffle__0ip92g_e.log) Finished test(python3): pyspark.tests.test_serializers (6s) Starting test(python3): pyspark.tests.test_stage_sched (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/f274b3a1-a579-4c90-8b7c-178eabd13e12/python3__pyspark.tests.test_stage_sched__tlkac1aj.log) Finished test(python3): pyspark.tests.test_shuffle (7s) Starting test(python3): pyspark.tests.test_statcounter (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/4bc8ceef-e1f2-4566-9653-7acf98440a8a/python3__pyspark.tests.test_statcounter__4fan6599.log) Finished test(python3): pyspark.tests.test_memory_profiler (20s) Starting test(python3): pyspark.tests.test_taskcontext (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/9ae70357-86d7-4115-b6fa-e50a52ce7588/python3__pyspark.tests.test_taskcontext__ve_ei2cb.log) Finished test(python3): pyspark.tests.test_statcounter (4s) Starting test(python3): pyspark.tests.test_util (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/8dcfade3-e1d4-4f03-ba69-10376ad229f6/python3__pyspark.tests.test_util__m87nldnp.log) Finished test(python3): pyspark.tests.test_util (6s) Starting test(python3): pyspark.tests.test_worker (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/0f82d2bd-a221-471b-8634-86c94a415962/python3__pyspark.tests.test_worker__q_4tfdgg.log) Finished test(python3): pyspark.tests.test_stage_sched (23s) Starting test(python3): pyspark.accumulators (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/4c9c01df-938d-4433-923e-c384d58dc5d7/python3__pyspark.accumulators__c6f1z1w2.log) Finished test(python3): pyspark.accumulators (3s) Starting test(python3): pyspark.conf (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/36f0848b-a831-47b1-a80f-fb215fcf54d7/python3__pyspark.conf__1_nmps1k.log) Finished test(python3): pyspark.conf (1s) Starting test(python3): pyspark.core.broadcast (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/38222bab-a54d-4141-8c9b-ec066cd9df81/python3__pyspark.core.broadcast__qz98z0c0.log) Finished test(python3): pyspark.tests.test_worker (19s) ... 
3 tests were skipped Starting test(python3): pyspark.core.context (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/2cfe1fd3-cdd2-412f-98de-a1759ef99557/python3__pyspark.core.context__jrfebewg.log) Finished test(python3): pyspark.core.broadcast (4s) Starting test(python3): pyspark.core.files (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/7e99dae7-399c-48ce-af54-3bd897d5ba13/python3__pyspark.core.files__yn6t4ze1.log) Finished test(python3): pyspark.core.files (3s) Starting test(python3): pyspark.core.rdd (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/f879d50d-4e9e-420a-becb-a1c3738cf90d/python3__pyspark.core.rdd__e9bhllsb.log) Finished test(python3): pyspark.core.rdd (19s) Starting test(python3): pyspark.profiler (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/8228af67-5834-4493-bd71-c5fcbce877ec/python3__pyspark.profiler__yzut311n.log) Finished test(python3): pyspark.core.context (26s) Starting test(python3): pyspark.serializers (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/9248860c-aa6d-4608-95d2-96b50501696b/python3__pyspark.serializers__yjb7i9f4.log) Finished test(python3): pyspark.profiler (3s) Starting test(python3): pyspark.shuffle (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/927866d3-940a-471c-a736-d68ec9fe082d/python3__pyspark.shuffle__vdg9m5rm.log) Finished test(python3): pyspark.shuffle (0s) Starting test(python3): pyspark.taskcontext (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/1e4ee756-c80e-4705-be6b-2cb72d73ce03/python3__pyspark.taskcontext__3ueryqa7.log) Finished test(python3): pyspark.tests.test_rdd (80s) Starting test(python3): pyspark.util (temp output: /Users/dongjoon/APACHE/spark-merge/python/target/5b0024fd-c92b-4ac0-8507-6d2981f34a86/python3__pyspark.util___86sm79x.log) Finished test(python3): pyspark.serializers (5s) Finished test(python3): pyspark.util (2s) Finished test(python3): pyspark.taskcontext (35s) Finished test(python3): pyspark.tests.test_taskcontext (140s) Tests passed in 193 seconds Skipped tests in pyspark.tests.test_profiler with python3: test_no_memory_profile_installed (pyspark.tests.test_profiler.ProfilerTests2.test_no_memory_profile_installed) ... skip (0.000s) Skipped tests in pyspark.tests.test_worker with python3: test_memory_limit (pyspark.tests.test_worker.WorkerMemoryTest.test_memory_limit) ... skip (0.001s) test_python_segfault (pyspark.tests.test_worker.WorkerSegfaultNonDaemonTest.test_python_segfault) ... skip (0.000s) test_python_segfault (pyspark.tests.test_worker.WorkerSegfaultTest.test_python_segfault) ... skip (0.000s) ``` ### Does this PR introduce _any_ user-facing change? No. This is an infra change. ### How was this patch tested? Pass the CIs. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48329 from dongjoon-hyun/SPARK-46647. 
Authored-by: Dongjoon Hyun Signed-off-by: Hyukjin Kwon --- dev/infra/Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 4125002cab0bb..24f858a234ddf 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -135,9 +135,8 @@ RUN apt-get update && apt-get install -y \ python3.12 \ && rm -rf /var/lib/apt/lists/* RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12 -# TODO(SPARK-46647) Add unittest-xml-reporting into Python 3.12 image when it supports Python 3.12 RUN python3.12 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this -RUN python3.12 -m pip install $BASIC_PIP_PKGS $CONNECT_PIP_PKGS lxml && \ +RUN python3.12 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS lxml && \ python3.12 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \ python3.12 -m pip install torcheval && \ python3.12 -m pip cache purge From 29312bc0c971c7729dc6dd73e641cd9fd369ed0f Mon Sep 17 00:00:00 2001 From: Wei Liu Date: Fri, 4 Oct 2024 08:55:33 +0900 Subject: [PATCH 145/250] [SPARK-49824][SS][CONNECT] Improve logging in SparkConnectStreamingQueryCache ### What changes were proposed in this pull request? The query key in the cache is but in the log only the id is logged. A query could have the same id but different runid, we need to log both id and runid to make it less confusing. ### Why are the changes needed? Debug improvement ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Manual check log ### Was this patch authored or co-authored using generative AI tooling? No Closes #48293 from WweiL/listener-cache-improvement. Authored-by: Wei Liu Signed-off-by: Hyukjin Kwon --- .../SparkConnectStreamingQueryCache.scala | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala index 8241672d5107b..48492bac62344 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala @@ -29,7 +29,7 @@ import scala.concurrent.duration.{Duration, DurationInt, FiniteDuration} import scala.util.control.NonFatal import org.apache.spark.internal.{Logging, MDC} -import org.apache.spark.internal.LogKeys.{DURATION, NEW_VALUE, OLD_VALUE, QUERY_CACHE_VALUE, QUERY_ID, SESSION_ID} +import org.apache.spark.internal.LogKeys.{DURATION, NEW_VALUE, OLD_VALUE, QUERY_CACHE_VALUE, QUERY_ID, QUERY_RUN_ID, SESSION_ID} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.streaming.StreamingQuery import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} @@ -158,7 +158,8 @@ private[connect] class SparkConnectStreamingQueryCache( if (v.userId.equals(sessionHolder.userId) && v.sessionId.equals(sessionHolder.sessionId)) { if (v.query.isActive && Option(v.session.streams.get(k.queryId)).nonEmpty) { logInfo( - log"Stopping the query with id ${MDC(QUERY_ID, k.queryId)} " + + log"Stopping the query with id: ${MDC(QUERY_ID, k.queryId)} " + + log"runId: ${MDC(QUERY_RUN_ID, k.runId)} " + log"since the session has timed out") try { if (blocking) { @@ -170,7 +171,8 @@ private[connect] class SparkConnectStreamingQueryCache( } catch { 
case NonFatal(ex) => logWarning( - log"Failed to stop the query ${MDC(QUERY_ID, k.queryId)}. " + + log"Failed to stop the with id: ${MDC(QUERY_ID, k.queryId)} " + + log"runId: ${MDC(QUERY_RUN_ID, k.runId)} " + log"Error is ignored.", ex) } @@ -238,17 +240,20 @@ private[connect] class SparkConnectStreamingQueryCache( for ((k, v) <- queryCache) { val id = k.queryId + val runId = k.runId v.expiresAtMs match { case Some(ts) if nowMs >= ts => // Expired. Drop references. logInfo( - log"Removing references for ${MDC(QUERY_ID, id)} in " + + log"Removing references for id: ${MDC(QUERY_ID, id)} " + + log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + log"session ${MDC(SESSION_ID, v.sessionId)} after expiry period") queryCache.remove(k) case Some(_) => // Inactive query waiting for expiration. Do nothing. logInfo( - log"Waiting for the expiration for ${MDC(QUERY_ID, id)} in " + + log"Waiting for the expiration for id: ${MDC(QUERY_ID, id)} " + + log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + log"session ${MDC(SESSION_ID, v.sessionId)}") case None => // Active query, check if it is stopped. Enable timeout if it is stopped. @@ -256,7 +261,8 @@ private[connect] class SparkConnectStreamingQueryCache( if (!isActive) { logInfo( - log"Marking query ${MDC(QUERY_ID, id)} in " + + log"Marking query id: ${MDC(QUERY_ID, id)} " + + log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + log"session ${MDC(SESSION_ID, v.sessionId)} inactive.") val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs))) From 96666d49feb3d4a6b5a76d05e48e898c0962653c Mon Sep 17 00:00:00 2001 From: Nemanja Boric Date: Fri, 4 Oct 2024 09:18:28 +0900 Subject: [PATCH 146/250] [SPARK-49859][CONNECT] Replace multiprocessing.ThreadPool with ThreadPoolExecutor ### What changes were proposed in this pull request? We change the reattachexecutor module to use concurrent.futures.ThreadPoolExecutor instead of multiprocessing.ThreadPool. ### Why are the changes needed? multiprocessing.ThreadPool doesn't work in environments where /dev/shm is not writtable by the python process. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48327 from nemanja-boric-databricks/sparkly. 
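A minimal standalone sketch of the API mapping this patch applies in `reattach.py` (the function name and arguments below are illustrative, not Spark code). Per the commit message, `multiprocessing.pool.ThreadPool` needs a writable `/dev/shm`, while `concurrent.futures.ThreadPoolExecutor` relies only on in-process threading primitives:
```python
import os
from multiprocessing.pool import ThreadPool
from concurrent.futures import ThreadPoolExecutor

def release(operation_id):
    # Illustrative stand-in for the fire-and-forget release callback.
    print(f"released {operation_id}")

# Before: multiprocessing-based pool, torn down with close()/join().
pool = ThreadPool(os.cpu_count() or 8)
pool.apply_async(release, ("op-1",))
pool.close()
pool.join()

# After: thread-based executor, same fire-and-forget via submit(),
# torn down with a single shutdown() call.
executor = ThreadPoolExecutor(max_workers=os.cpu_count() or 8)
executor.submit(release, "op-2")
executor.shutdown()
```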
Authored-by: Nemanja Boric Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/connect/client/reattach.py | 18 ++++++++---------- .../sql/tests/connect/client/test_client.py | 4 ++-- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index ea6788e858317..e0c7cc448933d 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -24,7 +24,7 @@ import uuid from collections.abc import Generator from typing import Optional, Any, Iterator, Iterable, Tuple, Callable, cast, Type, ClassVar -from multiprocessing.pool import ThreadPool +from concurrent.futures import ThreadPoolExecutor import os import grpc @@ -58,19 +58,18 @@ class ExecutePlanResponseReattachableIterator(Generator): # Lock to manage the pool _lock: ClassVar[RLock] = RLock() - _release_thread_pool_instance: Optional[ThreadPool] = None + _release_thread_pool_instance: Optional[ThreadPoolExecutor] = None @classmethod # type: ignore[misc] @property - def _release_thread_pool(cls) -> ThreadPool: + def _release_thread_pool(cls) -> ThreadPoolExecutor: # Perform a first check outside the critical path. if cls._release_thread_pool_instance is not None: return cls._release_thread_pool_instance with cls._lock: if cls._release_thread_pool_instance is None: - cls._release_thread_pool_instance = ThreadPool( - os.cpu_count() if os.cpu_count() else 8 - ) + max_workers = os.cpu_count() or 8 + cls._release_thread_pool_instance = ThreadPoolExecutor(max_workers=max_workers) return cls._release_thread_pool_instance @classmethod @@ -81,8 +80,7 @@ def shutdown(cls: Type["ExecutePlanResponseReattachableIterator"]) -> None: """ with cls._lock: if cls._release_thread_pool_instance is not None: - cls._release_thread_pool.close() # type: ignore[attr-defined] - cls._release_thread_pool.join() # type: ignore[attr-defined] + cls._release_thread_pool.shutdown() # type: ignore[attr-defined] cls._release_thread_pool_instance = None def __init__( @@ -212,7 +210,7 @@ def target() -> None: with self._lock: if self._release_thread_pool_instance is not None: - self._release_thread_pool.apply_async(target) + self._release_thread_pool.submit(target) def _release_all(self) -> None: """ @@ -237,7 +235,7 @@ def target() -> None: with self._lock: if self._release_thread_pool_instance is not None: - self._release_thread_pool.apply_async(target) + self._release_thread_pool.submit(target) self._result_complete = True def _call_iter(self, iter_fun: Callable) -> Any: diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 5deb73a0ccf90..741d6b9c1104e 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -408,8 +408,8 @@ def not_found(): def checks(): self.assertEqual(1, stub.execute_calls) self.assertEqual(1, stub.attach_calls) - self.assertEqual(0, stub.release_calls) - self.assertEqual(0, stub.release_until_calls) + self.assertEqual(1, stub.release_calls) + self.assertEqual(1, stub.release_until_calls) eventually(timeout=1, catch_assertions=True)(checks)() From a38505c9455f42f986f9d315eac30516d907135b Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 3 Oct 2024 17:39:56 -0700 Subject: [PATCH 147/250] [SPARK-49869][INFRA] Add NumPy in Python 3.13 image ### What changes were proposed in this pull request? This PR add NumPy in Python 3.13 image. 
Note that this is different from SPARK-49862 because NumPy is a required dependency for ML in Python. ### Why are the changes needed? To fix Python 3.13 (https://github.com/apache/spark/actions/runs/11168860784/job/31048343334). ### Does this PR introduce _any_ user-facing change? No, dev-only. ### How was this patch tested? Will monitor the build. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48342 from HyukjinKwon/SPARK-49869. Authored-by: Hyukjin Kwon Signed-off-by: Dongjoon Hyun --- dev/infra/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 24f858a234ddf..1619b009e9364 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -148,7 +148,7 @@ RUN apt-get update && apt-get install -y \ RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13 # TODO(SPARK-49862) Add BASIC_PIP_PKGS and CONNECT_PIP_PKGS to Python 3.13 image when it supports Python 3.13 RUN python3.13 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this -RUN python3.13 -m pip install lxml && \ +RUN python3.13 -m pip install lxml numpy>=2.1 && \ python3.13 -m pip cache purge # Remove unused installation packages to free up disk space From 98da5e1ab4ac94d3e870007bd06dd84ed27e9080 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 3 Oct 2024 17:41:20 -0700 Subject: [PATCH 148/250] [MINOR][PYTHON][TESTS] Skip test_artifact if grpc isn't installed ### What changes were proposed in this pull request? This PR proposes to fix test_artifact to skip test if grpc isn't installed ### Why are the changes needed? To fix Python 3.13 build (https://github.com/apache/spark/actions/runs/11168860784/job/31048343683) ``` Running PySpark tests. Output is in /__w/spark/spark/python/unit-tests.log Will test against the following Python executables: ['python3.13'] Will test the following Python modules: ['pyspark-connect'] python3.13 python_implementation is CPython python3.13 version is: Python 3.13.0rc3 Starting test(python3.13): pyspark.sql.tests.connect.client.test_artifact (temp output: /__w/spark/spark/python/target/176e36e6-3f4f-4ab4-9861-fa131061be94/python3.13__pyspark.sql.tests.connect.client.test_artifact__u6l46s5d.log) Traceback (most recent call last): File "", line 198, in _run_module_as_main File "", line 88, in _run_code File "/__w/spark/spark/python/pyspark/sql/tests/connect/client/test_artifact.py", line 24, in from pyspark.errors.exceptions.connect import SparkConnectGrpcException File "/__w/spark/spark/python/pyspark/errors/exceptions/connect.py", line 17, in import pyspark.sql.connect.proto as pb2 File "/__w/spark/spark/python/pyspark/sql/connect/proto/__init__.py", line 18, in from pyspark.sql.connect.proto.base_pb2_grpc import * File "/__w/spark/spark/python/pyspark/sql/connect/proto/base_pb2_grpc.py", line 19, in import grpc ModuleNotFoundError: No module named 'grpc' ``` ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? Manually. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48341 from HyukjinKwon/minor-python313. 
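A generic sketch of the guarded-import pattern this change relies on (illustrative only, not the Spark test code): deferring the import until after an availability check keeps module import from failing with `ModuleNotFoundError` when the optional dependency is absent.
```python
import importlib.util

# Only import code that needs the optional dependency after confirming it
# is installed; otherwise leave a sentinel so callers can skip gracefully.
if importlib.util.find_spec("grpc") is not None:
    import grpc  # noqa: F401
else:
    grpc = None
```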
Authored-by: Hyukjin Kwon Signed-off-by: Dongjoon Hyun --- python/pyspark/sql/tests/connect/client/test_artifact.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py b/python/pyspark/sql/tests/connect/client/test_artifact.py index c886ff36d776f..0857591c306ae 100644 --- a/python/pyspark/sql/tests/connect/client/test_artifact.py +++ b/python/pyspark/sql/tests/connect/client/test_artifact.py @@ -21,7 +21,6 @@ import os from pyspark.util import is_remote_only -from pyspark.errors.exceptions.connect import SparkConnectGrpcException from pyspark.sql import SparkSession from pyspark.testing.connectutils import ReusedConnectTestCase, should_test_connect from pyspark.testing.utils import SPARK_HOME @@ -30,6 +29,7 @@ if should_test_connect: from pyspark.sql.connect.client.artifact import ArtifactManager from pyspark.sql.connect.client import DefaultChannelBuilder + from pyspark.errors.exceptions.connect import SparkConnectGrpcException class ArtifactTestsMixin: From de9b9c85e287c39d6bd380b518f72d3a3690012d Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Thu, 3 Oct 2024 17:45:30 -0700 Subject: [PATCH 149/250] [SPARK-49870][PYTHON] Add Python 3.13 support in Spark Classic ### What changes were proposed in this pull request? This PR adds the note for Python 3.13 support in `setup.py` for Spark Classic. ### Why are the changes needed? Basic tests pass with Python 3.13 for Spark Classic (https://github.com/apache/spark/actions/runs/11168860784) ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Via CI. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48343 from HyukjinKwon/SPARK-49870. Authored-by: Hyukjin Kwon Signed-off-by: Dongjoon Hyun --- python/packaging/classic/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py index 17cca326d0241..76fd638c4aa03 100755 --- a/python/packaging/classic/setup.py +++ b/python/packaging/classic/setup.py @@ -374,6 +374,7 @@ def run(self): "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Typing :: Typed", From bd3e2eb02170c09efd0722410b17992cec311107 Mon Sep 17 00:00:00 2001 From: Cheng Pan Date: Fri, 4 Oct 2024 11:35:16 +0900 Subject: [PATCH 150/250] [SPARK-49751][CONNECT] Fix deserialization of SparkListenerConnectServiceStarted event ### What changes were proposed in this pull request? `SparkListenerConnectServiceStarted` is introduced in SPARK-47952, while the referenced field `SparkConf` is not serialized properly, then causes the SHS deserialization failure. According to the discussion, we can remove the `sparkConf` field. ### Why are the changes needed? Fix the event log deserialization and recover the SHS UI rendering. ### Does this PR introduce _any_ user-facing change? Fix an unreleased feature, recover the SHS UI rendering from event logs produced by the connect server. ### How was this patch tested? Start a connect server with event log enabled, and then open the UI in SHS. 4.0.0-preview2 image This PR. image ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48200 from pan3793/SPARK-49751. 
Authored-by: Cheng Pan Signed-off-by: Hyukjin Kwon --- .../sql/connect/service/SparkConnectService.scala | 15 ++++----------- .../SparkConnectServiceInternalServerSuite.scala | 7 ------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index d0c06e96047f2..0468a55e23027 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -31,7 +31,7 @@ import io.grpc.protobuf.services.ProtoReflectionService import io.grpc.stub.StreamObserver import org.apache.commons.lang3.StringUtils -import org.apache.spark.{SparkConf, SparkContext, SparkEnv} +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, SparkConnectServiceGrpc} import org.apache.spark.connect.proto.SparkConnectServiceGrpc.AsyncService @@ -420,7 +420,7 @@ object SparkConnectService extends Logging { started = true stopped = false - postSparkConnectServiceStarted(sc) + postSparkConnectServiceStarted() } def stop(timeout: Option[Long] = None, unit: Option[TimeUnit] = None): Unit = synchronized { @@ -456,13 +456,9 @@ object SparkConnectService extends Logging { * Post the event that the Spark Connect service has started. This is expected to be called only * once after the service is ready. */ - private def postSparkConnectServiceStarted(sc: SparkContext): Unit = { + private def postSparkConnectServiceStarted(): Unit = { postServiceEvent(isa => - SparkListenerConnectServiceStarted( - hostAddress, - isa.getPort, - sc.conf, - System.currentTimeMillis())) + SparkListenerConnectServiceStarted(hostAddress, isa.getPort, System.currentTimeMillis())) } /** @@ -521,15 +517,12 @@ object SparkConnectService extends Logging { * The host address of the started Spark Connect service. * @param bindingPort: * The binding port of the started Spark Connect service. - * @param sparkConf: - * The SparkConf of the active SparkContext that associated with the service. * @param eventTime: * The time in ms when the event was generated. 
*/ case class SparkListenerConnectServiceStarted( hostAddress: String, bindingPort: Int, - sparkConf: SparkConf, eventTime: Long) extends SparkListenerEvent diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceInternalServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceInternalServerSuite.scala index 3240b33f3f090..173dc5c672bc3 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceInternalServerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceInternalServerSuite.scala @@ -189,13 +189,6 @@ class SparkConnectServiceInternalServerSuite extends SparkFunSuite with LocalSpa // In the meanwhile, no any end event should be posted assert(listenerInstance.serviceEndEvents.size() == 0) - // The listener is able to get the SparkConf from the event - val event = listenerInstance.serviceStartedEvents.get(0) - assert(event.sparkConf != null) - val sparkConf = event.sparkConf - assert(sparkConf.contains("spark.driver.host")) - assert(sparkConf.contains("spark.app.id")) - // Try to start an already started SparkConnectService SparkConnectService.start(sc) // The listener should still receive only one started event From 3dfedf69cc3966083ac4fce9245dedab53c560a1 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Fri, 4 Oct 2024 11:21:19 +0800 Subject: [PATCH 151/250] [SPARK-49842][BUILD] Add `byte-buddy` dependency for modules that depend on `mockito-core` to ensure `sbt test` uses the correct `byte-buddy` with Java 21 ### What changes were proposed in this pull request? This pr add `byte-buddy` dependency for modules that depend on `mockito-core` to ensure `sbt test` uses the correct `byte-buddy` ### Why are the changes needed? To ensure that `sbt test` can use the correct version of `byte-buddy`, I have only observed this issue when using `sbt test` with Java 21. This issue has not been observed when using `sbt test` with Java 17(Perhaps it just didn't print the Warning Message) or `maven test` with Java 17/21(Maven can confirm that there is no such issue). Java 21 sbt daily test: - https://github.com/apache/spark/actions/runs/11099131939/job/30881044324 ``` WARNING: A Java agent has been loaded dynamically (/home/runner/.cache/coursier/v1/https/maven-central.storage-download.googleapis.com/maven2/net/bytebuddy/byte-buddy-agent/1.14.15/byte-buddy-agent-1.14.15.jar) ``` - https://github.com/apache/spark/actions/runs/11099131939/job/30881045891 ``` WARNING: A Java agent has been loaded dynamically (/home/runner/.cache/coursier/v1/https/maven-central.storage-download.googleapis.com/maven2/net/bytebuddy/byte-buddy-agent/1.14.15/byte-buddy-agent-1.14.15.jar) ``` - https://github.com/apache/spark/actions/runs/11099131939/job/30881047740 ``` WARNING: A Java agent has been loaded dynamically (/home/runner/.cache/coursier/v1/https/maven-central.storage-download.googleapis.com/maven2/net/bytebuddy/byte-buddy-agent/1.14.15/byte-buddy-agent-1.14.15.jar) ``` We can see that `byte-buddy-agent-1.14.15.jar` is being used in the sbt tests with Java 21, but the version defined in the `dependencyManagement` of the parent `pom.xml` is `1.14.17`: https://github.com/apache/spark/blob/3093ad68d2a3c6bab9c1605381d27e700766be22/pom.xml#L1221-L1231 Inconsistent dependency versions pose potential risks, such as the possibility that maven tests may pass while sbt tests may fail. 
Therefore, we should correct it.Meanwhile, for consistency, I have added a test dependency on `byte-buddy` for all modules that depend on `mockito-core`. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Actions - Manual inspection shows that `sbt test` with Java 21 is using version 1.14.17 of `byte-buddy-agent` - https://github.com/LuciferYang/spark/actions/runs/11072600528/job/30922654197 - https://github.com/LuciferYang/spark/actions/runs/11072600528/job/30922654969 - https://github.com/LuciferYang/spark/actions/runs/11072600528/job/30922655852 - https://github.com/LuciferYang/spark/actions/runs/11072600528/job/30922656840 ``` WARNING: A Java agent has been loaded dynamically (/home/runner/.cache/coursier/v1/https/maven-central.storage-download.googleapis.com/maven2/net/bytebuddy/byte-buddy-agent/1.14.17/byte-buddy-agent-1.14.17.jar) ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #48281 from LuciferYang/byte-buddy. Lead-authored-by: yangjie01 Co-authored-by: YangJie Signed-off-by: yangjie01 --- common/network-common/pom.xml | 10 ++++++++++ common/network-shuffle/pom.xml | 10 ++++++++++ common/unsafe/pom.xml | 10 ++++++++++ connector/kafka-0-10-sql/pom.xml | 10 ++++++++++ connector/kafka-0-10-token-provider/pom.xml | 10 ++++++++++ connector/kafka-0-10/pom.xml | 10 ++++++++++ connector/kinesis-asl/pom.xml | 10 ++++++++++ core/pom.xml | 10 ++++++++++ launcher/pom.xml | 10 ++++++++++ mllib-local/pom.xml | 10 ++++++++++ mllib/pom.xml | 10 ++++++++++ repl/pom.xml | 10 ++++++++++ resource-managers/yarn/pom.xml | 10 ++++++++++ sql/catalyst/pom.xml | 10 ++++++++++ sql/connect/server/pom.xml | 10 ++++++++++ sql/core/pom.xml | 10 ++++++++++ sql/hive-thriftserver/pom.xml | 10 ++++++++++ streaming/pom.xml | 10 ++++++++++ 18 files changed, 180 insertions(+) diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index cdb5bd72158a1..cbe4836b58da5 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -194,6 +194,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 0f7036ef746cc..49e6e08476151 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -113,6 +113,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + commons-io commons-io diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index a5ef9847859a7..cf15301273303 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -104,6 +104,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/connector/kafka-0-10-sql/pom.xml b/connector/kafka-0-10-sql/pom.xml index 35f58134f1a85..66e1c24e821c8 100644 --- a/connector/kafka-0-10-sql/pom.xml +++ b/connector/kafka-0-10-sql/pom.xml @@ -148,6 +148,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/connector/kafka-0-10-token-provider/pom.xml b/connector/kafka-0-10-token-provider/pom.xml index 2b2707b9da320..3cbfc34e7d806 100644 --- a/connector/kafka-0-10-token-provider/pom.xml +++ b/connector/kafka-0-10-token-provider/pom.xml @@ -64,6 +64,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + 
net.bytebuddy + byte-buddy-agent + test + org.apache.hadoop hadoop-client-runtime diff --git a/connector/kafka-0-10/pom.xml b/connector/kafka-0-10/pom.xml index 1b26839a371ce..a42410e6ce885 100644 --- a/connector/kafka-0-10/pom.xml +++ b/connector/kafka-0-10/pom.xml @@ -119,6 +119,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.apache.spark spark-tags_${scala.binary.version} diff --git a/connector/kinesis-asl/pom.xml b/connector/kinesis-asl/pom.xml index 9a7f40443bbc9..7eba26ffdff74 100644 --- a/connector/kinesis-asl/pom.xml +++ b/connector/kinesis-asl/pom.xml @@ -81,6 +81,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/core/pom.xml b/core/pom.xml index 19f58940ed942..7805a3f37ae53 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -393,6 +393,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.scalacheck scalacheck_${scala.binary.version} diff --git a/launcher/pom.xml b/launcher/pom.xml index c47244ff887a6..e8feb7b684555 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -57,6 +57,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.slf4j jul-to-slf4j diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index ecfe45f046f2b..3b35a481adb1b 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -52,6 +52,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.apache.spark spark-tags_${scala.binary.version} diff --git a/mllib/pom.xml b/mllib/pom.xml index 4f983a325a0c1..c342519ca428a 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -117,6 +117,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.apache.spark spark-streaming_${scala.binary.version} diff --git a/repl/pom.xml b/repl/pom.xml index 831379467a29e..1a1c6b92c9222 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -82,6 +82,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + org.apache.spark spark-tags_${scala.binary.version} diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 694d81b3c25e3..770a550030f51 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -156,6 +156,16 @@ mockito-core test + + net.bytebuddy + byte-buddy + test + + + net.bytebuddy + byte-buddy-agent + test + 10.16.1.1 - 1.14.2 + 1.14.3 2.0.2 shaded-protobuf 11.0.23 From 92e79e36c79ca56a637b78faa43f8e55263c2191 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Tue, 8 Oct 2024 09:42:21 -0600 Subject: [PATCH 183/250] [SPARK-49901][BUILD] Upgrade dropwizard metrics to 4.2.28 ### What changes were proposed in this pull request? This pr aims to upgrade `dropwizard metrics` from `4.2.27` to `4.2.28`. ### Why are the changes needed? v4.2.127 VS v.4.2.28 https://github.com/dropwizard/metrics/compare/v4.2.27...v4.2.28 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48377 from panbingkun/SPARK-49901. 
Authored-by: panbingkun Signed-off-by: Dongjoon Hyun --- dev/deps/spark-deps-hadoop-3-hive-2.3 | 10 +++++----- pom.xml | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index f1f44b6fcd5c8..8ba2f6c414cb9 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -196,11 +196,11 @@ log4j-layout-template-json/2.24.1//log4j-layout-template-json-2.24.1.jar log4j-slf4j2-impl/2.24.1//log4j-slf4j2-impl-2.24.1.jar logging-interceptor/3.12.12//logging-interceptor-3.12.12.jar lz4-java/1.8.0//lz4-java-1.8.0.jar -metrics-core/4.2.27//metrics-core-4.2.27.jar -metrics-graphite/4.2.27//metrics-graphite-4.2.27.jar -metrics-jmx/4.2.27//metrics-jmx-4.2.27.jar -metrics-json/4.2.27//metrics-json-4.2.27.jar -metrics-jvm/4.2.27//metrics-jvm-4.2.27.jar +metrics-core/4.2.28//metrics-core-4.2.28.jar +metrics-graphite/4.2.28//metrics-graphite-4.2.28.jar +metrics-jmx/4.2.28//metrics-jmx-4.2.28.jar +metrics-json/4.2.28//metrics-json-4.2.28.jar +metrics-jvm/4.2.28//metrics-jvm-4.2.28.jar minlog/1.3.0//minlog-1.3.0.jar netty-all/4.1.110.Final//netty-all-4.1.110.Final.jar netty-buffer/4.1.110.Final//netty-buffer-4.1.110.Final.jar diff --git a/pom.xml b/pom.xml index 02fb53ea7e2eb..585b329e642f5 100644 --- a/pom.xml +++ b/pom.xml @@ -151,7 +151,7 @@ If you change codahale.metrics.version, you also need to change the link to metrics.dropwizard.io in docs/monitoring.md. --> - 4.2.27 + 4.2.28 1.12.0 1.12.0 From 345a2be4132040c03920670a1f90624f9ad0f88d Mon Sep 17 00:00:00 2001 From: Vladimir Golubev Date: Wed, 9 Oct 2024 10:12:40 +0900 Subject: [PATCH 184/250] [SPARK-49906][SQL] Introduce and use CONFLICTING_DIRECTORY_STRUCTURES error for PartitioningUtils ### What changes were proposed in this pull request? Improve Spark user experience by introducing a new error type: `CONFLICTING_DIRECTORY_STRUCTURES` for `PartitioningUtils`. ### Why are the changes needed? `PartitioningUtils.parsePartitions(...)` uses an assertion to if partitions are misconfigured. We should use a proper error type for this case. ### Does this PR introduce _any_ user-facing change? Yes, the error will be nicer. ### How was this patch tested? Updated the existing tests. ### Was this patch authored or co-authored using generative AI tooling? `copilot.vim`. Closes #48383 from vladimirg-db/vladimirg-db/introduce-conflicting-directory-structures-error. 
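For context, a minimal PySpark sketch of the situation the new error condition targets; the paths are hypothetical and the snippet assumes a local SparkSession. Mixing a partition directory with an unrelated directory in one read makes partition discovery infer conflicting base paths, which now surfaces as `CONFLICTING_DIRECTORY_STRUCTURES` rather than a bare `AssertionError`; the error text points at the `basePath` option or separate reads as remedies.
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Hypothetical layout:
#   /data/tbl/a=10/b=20/part-0.parquet   (partitioned directory)
#   /data/other/part-0.parquet           (no partition structure)
# Partition discovery infers different base paths for the two inputs and
# fails with CONFLICTING_DIRECTORY_STRUCTURES.
df = spark.read.parquet("/data/tbl/a=10/b=20", "/data/other")

# Remedy suggested by the error message: pin the table root explicitly,
# or load the roots separately and union the results.
df = spark.read.option("basePath", "/data/tbl").parquet("/data/tbl/a=10/b=20")
```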
Authored-by: Vladimir Golubev Signed-off-by: Hyukjin Kwon --- .../src/main/resources/error/error-conditions.json | 10 ++++++++++ .../spark/sql/errors/QueryExecutionErrors.scala | 10 ++++++++++ .../sql/execution/datasources/PartitioningUtils.scala | 11 +++-------- .../sql/execution/datasources/FileIndexSuite.scala | 4 ++-- .../parquet/ParquetPartitionDiscoverySuite.scala | 8 ++++---- 5 files changed, 29 insertions(+), 14 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index e3bffea0b62eb..8100f0580b21f 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -669,6 +669,16 @@ ], "sqlState" : "40000" }, + "CONFLICTING_DIRECTORY_STRUCTURES" : { + "message" : [ + "Conflicting directory structures detected.", + "Suspicious paths:", + "", + "If provided paths are partition directories, please set \"basePath\" in the options of the data source to specify the root directory of the table.", + "If there are multiple root directories, please load them separately and then union them." + ], + "sqlState" : "KD009" + }, "CONFLICTING_PARTITION_COLUMN_NAMES" : { "message" : [ "Conflicting partition column names detected:", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index bc6c7681ea1a5..301880f1bfc61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2845,6 +2845,16 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE ) } + def conflictingDirectoryStructuresError( + discoveredBasePaths: Seq[String]): SparkRuntimeException = { + new SparkRuntimeException( + errorClass = "CONFLICTING_DIRECTORY_STRUCTURES", + messageParameters = Map( + "discoveredBasePaths" -> discoveredBasePaths.distinct.mkString("\n\t", "\n\t", "\n") + ) + ) + } + def conflictingPartitionColumnNamesError( distinctPartColLists: Seq[String], suspiciousPaths: Seq[Path]): SparkRuntimeException = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index ffdca65151052..402b70065d8e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -173,14 +173,9 @@ object PartitioningUtils extends SQLConfHelper { // "hdfs://host:9000/path" // TODO: Selective case sensitivity. val discoveredBasePaths = optDiscoveredBasePaths.flatten.map(_.toString.toLowerCase()) - assert( - ignoreInvalidPartitionPaths || discoveredBasePaths.distinct.size == 1, - "Conflicting directory structures detected. Suspicious paths:\b" + - discoveredBasePaths.distinct.mkString("\n\t", "\n\t", "\n\n") + - "If provided paths are partition directories, please set " + - "\"basePath\" in the options of the data source to specify the " + - "root directory of the table. 
If there are multiple root directories, " + - "please load them separately and then union them.") + if (!ignoreInvalidPartitionPaths && discoveredBasePaths.distinct.size != 1) { + throw QueryExecutionErrors.conflictingDirectoryStructuresError(discoveredBasePaths) + } val resolvedPartitionValues = resolvePartitions(pathsWithPartitionValues, caseSensitive) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index 31b7380889158..e9f78f9f598e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -566,7 +566,7 @@ class FileIndexSuite extends SharedSparkSession { new File(directoryPath, "part_col=1").renameTo(new File(directoryPath, "undefined")) // By default, we expect the invalid path assertion to trigger. - val ex = intercept[AssertionError] { + val ex = intercept[SparkRuntimeException] { spark.read .format("parquet") .load(directoryPath.getCanonicalPath) @@ -585,7 +585,7 @@ class FileIndexSuite extends SharedSparkSession { // Data source option override takes precedence. withSQLConf(SQLConf.IGNORE_INVALID_PARTITION_PATHS.key -> "true") { - val ex = intercept[AssertionError] { + val ex = intercept[SparkRuntimeException] { spark.read .format("parquet") .option(FileIndexOptions.IGNORE_INVALID_PARTITION_PATHS, "false") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 52d67a0954325..eb4618834504c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -111,7 +111,7 @@ abstract class ParquetPartitionDiscoverySuite "hdfs://host:9000/path/a=10/b=20", "hdfs://host:9000/path/a=10.5/b=hello") - var exception = intercept[AssertionError] { + var exception = intercept[SparkRuntimeException] { parsePartitions( paths.map(new Path(_)), true, Set.empty[Path], None, true, true, timeZoneId, false) } @@ -173,7 +173,7 @@ abstract class ParquetPartitionDiscoverySuite "hdfs://host:9000/path/a=10/b=20", "hdfs://host:9000/path/path1") - exception = intercept[AssertionError] { + exception = intercept[SparkRuntimeException] { parsePartitions( paths.map(new Path(_)), true, @@ -197,7 +197,7 @@ abstract class ParquetPartitionDiscoverySuite "hdfs://host:9000/tmp/tables/nonPartitionedTable1", "hdfs://host:9000/tmp/tables/nonPartitionedTable2") - exception = intercept[AssertionError] { + exception = intercept[SparkRuntimeException] { parsePartitions( paths.map(new Path(_)), true, @@ -878,7 +878,7 @@ abstract class ParquetPartitionDiscoverySuite checkAnswer(twoPartitionsDF, df.filter("b != 3")) - intercept[AssertionError] { + intercept[SparkRuntimeException] { spark .read .parquet( From 80ae411b178b2006d75b6850c4ba1dc2e0c057dd Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Tue, 8 Oct 2024 21:24:20 -0400 Subject: [PATCH 185/250] [SPARK-49569][CONNECT][SQL] Add shims to support SparkContext and RDD ### What changes were proposed in this pull request? This PR does two things: - It adds shims for SparkContext and RDD. These are in a separate module. 
This module is a compile time dependency for sql/api, and a regular dependency for connector/connect/client/jvm. We remove this dependency in catalyst and connect-server because those should use the actual implementation. - It adds RDD (and the one SparkContext) based method to the shared Scala API. For connect these methods throw an unsupported operation exception. ### Why are the changes needed? We are creating a shared Scala interface for Classic and Connect. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. I will add a couple on the connect side. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48065 from hvanhovell/SPARK-49569. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- connector/connect/client/jvm/pom.xml | 5 + .../apache/spark/sql/DataFrameReader.scala | 10 ++ .../scala/org/apache/spark/sql/Dataset.scala | 8 ++ .../org/apache/spark/sql/SparkSession.scala | 33 +++++- .../scala/org/apache/spark/sql/package.scala | 3 + pom.xml | 1 + project/SparkBuild.scala | 58 +++++++--- sql/api/pom.xml | 6 ++ .../spark/sql/api/DataFrameReader.scala | 34 ++++++ .../org/apache/spark/sql/api/Dataset.scala | 32 ++++++ .../apache/spark/sql/api/SQLImplicits.scala | 9 ++ .../apache/spark/sql/api/SparkSession.scala | 102 ++++++++++++++++++ sql/catalyst/pom.xml | 6 ++ sql/connect/server/pom.xml | 4 + sql/connect/shims/README.md | 1 + sql/connect/shims/pom.xml | 41 +++++++ .../org/apache/spark/api/java/shims.scala | 19 ++++ .../scala/org/apache/spark/rdd/shims.scala | 19 ++++ .../main/scala/org/apache/spark/shims.scala | 19 ++++ .../apache/spark/sql/DataFrameReader.scala | 23 +--- .../scala/org/apache/spark/sql/Dataset.scala | 20 +--- .../org/apache/spark/sql/SQLImplicits.scala | 12 --- .../org/apache/spark/sql/SparkSession.scala | 74 ++----------- 23 files changed, 406 insertions(+), 133 deletions(-) create mode 100644 sql/connect/shims/README.md create mode 100644 sql/connect/shims/pom.xml create mode 100644 sql/connect/shims/src/main/scala/org/apache/spark/api/java/shims.scala create mode 100644 sql/connect/shims/src/main/scala/org/apache/spark/rdd/shims.scala create mode 100644 sql/connect/shims/src/main/scala/org/apache/spark/shims.scala diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml index e117a0a7451cb..2fdb2d4bafe01 100644 --- a/connector/connect/client/jvm/pom.xml +++ b/connector/connect/client/jvm/pom.xml @@ -45,6 +45,11 @@ spark-sql-api_${scala.binary.version} ${project.version} + + org.apache.spark + spark-connect-shims_${scala.binary.version} + ${project.version} + org.apache.spark spark-sketch_${scala.binary.version} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 60bacd4e18ede..051d382c49773 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -22,7 +22,9 @@ import java.util.Properties import scala.jdk.CollectionConverters._ import org.apache.spark.annotation.Stable +import org.apache.spark.api.java.JavaRDD import org.apache.spark.connect.proto.Parse.ParseFormat +import org.apache.spark.rdd.RDD import org.apache.spark.sql.connect.ConnectConversions._ import org.apache.spark.sql.connect.common.DataTypeProtoConverter import 
org.apache.spark.sql.types.StructType @@ -140,6 +142,14 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends api.Data def json(jsonDataset: Dataset[String]): DataFrame = parse(jsonDataset, ParseFormat.PARSE_FORMAT_JSON) + /** @inheritdoc */ + override def json(jsonRDD: JavaRDD[String]): Dataset[Row] = + throwRddNotSupportedException() + + /** @inheritdoc */ + override def json(jsonRDD: RDD[String]): Dataset[Row] = + throwRddNotSupportedException() + /** @inheritdoc */ override def csv(path: String): DataFrame = super.csv(path) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index a368da2aaee60..966b5acebca23 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -26,8 +26,10 @@ import scala.util.control.NonFatal import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function._ import org.apache.spark.connect.proto +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._ @@ -1463,4 +1465,10 @@ class Dataset[T] private[sql] ( func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]] + + /** @inheritdoc */ + override def rdd: RDD[T] = throwRddNotSupportedException() + + /** @inheritdoc */ + override def toJavaRDD: JavaRDD[T] = throwRddNotSupportedException() } diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index 222b5ea79508e..ad10a22f833bf 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -29,10 +29,13 @@ import com.google.common.cache.{CacheBuilder, CacheLoader} import io.grpc.ClientInterceptor import org.apache.arrow.memory.RootAllocator +import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.connect.proto import org.apache.spark.connect.proto.ExecutePlanResponse import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalog.Catalog import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder} @@ -84,10 +87,14 @@ class SparkSession private[sql] ( private[sql] val observationRegistry = new ConcurrentHashMap[Long, Observation]() - private[sql] def hijackServerSideSessionIdForTesting(suffix: String) = { + private[sql] def hijackServerSideSessionIdForTesting(suffix: String): Unit = { client.hijackServerSideSessionIdForTesting(suffix) } + /** @inheritdoc */ + override def sparkContext: SparkContext = + throw new UnsupportedOperationException("sparkContext is not supported in Spark Connect.") + /** @inheritdoc */ val conf: RuntimeConfig = new ConnectRuntimeConfig(client) @@ -144,6 +151,30 @@ class SparkSession private[sql] ( createDataset(data.asScala.toSeq) 
} + /** @inheritdoc */ + override def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = + throwRddNotSupportedException() + + /** @inheritdoc */ + override def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = + throwRddNotSupportedException() + + /** @inheritdoc */ + override def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = + throwRddNotSupportedException() + + /** @inheritdoc */ + override def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = + throwRddNotSupportedException() + + /** @inheritdoc */ + override def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = + throwRddNotSupportedException() + + /** @inheritdoc */ + override def createDataset[T: Encoder](data: RDD[T]): Dataset[T] = + throwRddNotSupportedException() + /** @inheritdoc */ @Experimental def sql(sqlText: String, args: Array[_]): DataFrame = newDataFrame { builder => diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala index ada94b76fcbcd..5c61b9371f37c 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala @@ -19,4 +19,7 @@ package org.apache.spark package object sql { type DataFrame = Dataset[Row] + + private[sql] def throwRddNotSupportedException(): Nothing = + throw new UnsupportedOperationException("RDDs are not supported in Spark Connect.") } diff --git a/pom.xml b/pom.xml index 585b329e642f5..bfaee1be609c0 100644 --- a/pom.xml +++ b/pom.xml @@ -84,6 +84,7 @@ common/utils common/variant common/tags + sql/connect/shims core graphx mllib diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 6137984a53c0a..5882fcbf336b0 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -45,24 +45,24 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro, protobuf) = Seq( - "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", "sql-kafka-0-10", "avro", "protobuf" - ).map(ProjectRef(buildLocation, _)) + val sqlProjects@Seq(sqlApi, catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro, protobuf) = + Seq("sql-api", "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", + "sql-kafka-0-10", "avro", "protobuf").map(ProjectRef(buildLocation, _)) val streamingProjects@Seq(streaming, streamingKafka010) = Seq("streaming", "streaming-kafka-0-10").map(ProjectRef(buildLocation, _)) - val connectCommon = ProjectRef(buildLocation, "connect-common") - val connect = ProjectRef(buildLocation, "connect") - val connectClient = ProjectRef(buildLocation, "connect-client-jvm") + val connectProjects@Seq(connectCommon, connect, connectClient, connectShims) = + Seq("connect-common", "connect", "connect-client-jvm", "connect-shims") + .map(ProjectRef(buildLocation, _)) val allProjects@Seq( core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, - commonUtils, sqlApi, variant, _* + commonUtils, variant, _* ) = Seq( "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe", - "tags", "sketch", "kvstore", "common-utils", "sql-api", "variant" - ).map(ProjectRef(buildLocation, _)) ++ sqlProjects 
++ streamingProjects ++ Seq(connectCommon, connect, connectClient) + "tags", "sketch", "kvstore", "common-utils", "variant" + ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ connectProjects val optionallyEnabledProjects@Seq(kubernetes, yarn, sparkGangliaLgpl, streamingKinesisAsl, @@ -360,7 +360,7 @@ object SparkBuild extends PomBuild { /* Enable shared settings on all projects */ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ copyJarsProjects ++ Seq(spark, tools)) .foreach(enable(sharedSettings ++ DependencyOverrides.settings ++ - ExcludedDependencies.settings ++ Checkstyle.settings)) + ExcludedDependencies.settings ++ Checkstyle.settings ++ ExcludeShims.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) @@ -369,7 +369,7 @@ object SparkBuild extends PomBuild { Seq( spark, hive, hiveThriftServer, repl, networkCommon, networkShuffle, networkYarn, unsafe, tags, tokenProviderKafka010, sqlKafka010, connectCommon, connect, connectClient, - variant + variant, connectShims ).contains(x) } @@ -1087,6 +1087,36 @@ object ExcludedDependencies { ) } +/** + * This excludes the spark-connect-shims module from a module when it is not part of the connect + * client dependencies. + */ +object ExcludeShims { + val shimmedProjects = Set("spark-sql-api", "spark-connect-common", "spark-connect-client-jvm") + val classPathFilter = TaskKey[Classpath => Classpath]("filter for classpath") + lazy val settings = Seq( + classPathFilter := { + if (!shimmedProjects(moduleName.value)) { + cp => cp.filterNot(_.data.name.contains("spark-connect-shims")) + } else { + identity _ + } + }, + Compile / internalDependencyClasspath := + classPathFilter.value((Compile / internalDependencyClasspath).value), + Compile / internalDependencyAsJars := + classPathFilter.value((Compile / internalDependencyAsJars).value), + Runtime / internalDependencyClasspath := + classPathFilter.value((Runtime / internalDependencyClasspath).value), + Runtime / internalDependencyAsJars := + classPathFilter.value((Runtime / internalDependencyAsJars).value), + Test / internalDependencyClasspath := + classPathFilter.value((Test / internalDependencyClasspath).value), + Test / internalDependencyAsJars := + classPathFilter.value((Test / internalDependencyAsJars).value), + ) +} + /** * Project to pull previous artifacts of Spark for generating Mima excludes. 
*/ @@ -1456,10 +1486,12 @@ object SparkUnidoc extends SharedUnidocSettings { lazy val settings = baseSettings ++ Seq( (ScalaUnidoc / unidoc / unidocProjectFilter) := inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes, - yarn, tags, streamingKafka010, sqlKafka010, connectCommon, connect, connectClient, protobuf), + yarn, tags, streamingKafka010, sqlKafka010, connectCommon, connect, connectClient, + connectShims, protobuf), (JavaUnidoc / unidoc / unidocProjectFilter) := inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes, - yarn, tags, streamingKafka010, sqlKafka010, connectCommon, connect, connectClient, protobuf), + yarn, tags, streamingKafka010, sqlKafka010, connectCommon, connect, connectClient, + connectShims, protobuf), ) } diff --git a/sql/api/pom.xml b/sql/api/pom.xml index 54cdc96fc40a2..9c50a2567c5fe 100644 --- a/sql/api/pom.xml +++ b/sql/api/pom.xml @@ -58,6 +58,12 @@ spark-sketch_${scala.binary.version} ${project.version} + + org.apache.spark + spark-connect-shims_${scala.binary.version} + ${project.version} + compile + org.json4s json4s-jackson_${scala.binary.version} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala index c101c52fd0662..8c88387714228 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/DataFrameReader.scala @@ -21,6 +21,8 @@ import scala.jdk.CollectionConverters._ import _root_.java.util import org.apache.spark.annotation.Stable +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, SparkCharVarcharUtils} @@ -309,6 +311,38 @@ abstract class DataFrameReader { */ def json(jsonDataset: DS[String]): Dataset[Row] + /** + * Loads a `JavaRDD[String]` storing JSON objects (JSON Lines + * text format or newline-delimited JSON) and returns the result as a `DataFrame`. + * + * Unless the schema is specified using `schema` function, this function goes through the input + * once to determine the input schema. + * + * @note + * this method is not supported in Spark Connect. + * @param jsonRDD + * input RDD with one JSON object per record + * @since 1.4.0 + */ + @deprecated("Use json(Dataset[String]) instead.", "2.2.0") + def json(jsonRDD: JavaRDD[String]): DS[Row] + + /** + * Loads an `RDD[String]` storing JSON objects (JSON Lines text + * format or newline-delimited JSON) and returns the result as a `DataFrame`. + * + * Unless the schema is specified using `schema` function, this function goes through the input + * once to determine the input schema. + * + * @note + * this method is not supported in Spark Connect. + * @param jsonRDD + * input RDD with one JSON object per record + * @since 1.4.0 + */ + @deprecated("Use json(Dataset[String]) instead.", "2.2.0") + def json(jsonRDD: RDD[String]): DS[Row] + /** * Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other * overloaded `csv()` method for more details. 
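To make the intended split concrete, here is a minimal sketch of how the RDD-based overloads declared above behave, assuming a local Spark Classic session; the app name and JSON record are made up for illustration and are not part of the change.

```scala
// Minimal sketch, assuming a local Spark Classic build; names and data are illustrative.
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("rdd-shim-sketch").master("local[*]").getOrCreate()

// Spark Classic: the RDD-based overload is deprecated (since 2.2.0) but still works.
val jsonRDD: RDD[String] = spark.sparkContext.parallelize(Seq("""{"x": 0, "y": "abc"}"""))
val df = spark.read.json(jsonRDD)
println(df.rdd.count())  // Dataset.rdd is also still available on Classic.

// Spark Connect: the same code compiles against the shim SparkContext/RDD classes, but
// spark.sparkContext, read.json(jsonRDD) and Dataset.rdd fail at runtime with
// UnsupportedOperationException ("... not supported in Spark Connect."), matching the
// overrides added in this patch.
```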
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala index 06a6148a7c188..c277b4cab85c1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala @@ -22,7 +22,9 @@ import scala.reflect.runtime.universe.TypeTag import _root_.java.util import org.apache.spark.annotation.{DeveloperApi, Stable} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.function.{FilterFunction, FlatMapFunction, ForeachFunction, ForeachPartitionFunction, MapFunction, MapPartitionsFunction, ReduceFunction} +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions, AnalysisException, Column, DataFrameWriter, DataFrameWriterV2, Encoder, MergeIntoWriter, Observation, Row, TypedColumn} import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors} import org.apache.spark.sql.types.{Metadata, StructType} @@ -3098,4 +3100,34 @@ abstract class Dataset[T] extends Serializable { * @since 1.6.0 */ def write: DataFrameWriter[T] + + /** + * Represents the content of the Dataset as an `RDD` of `T`. + * + * @note + * this method is not supported in Spark Connect. + * @group basic + * @since 1.6.0 + */ + def rdd: RDD[T] + + /** + * Returns the content of the Dataset as a `JavaRDD` of `T`s. + * + * @note + * this method is not supported in Spark Connect. + * @group basic + * @since 1.6.0 + */ + def toJavaRDD: JavaRDD[T] + + /** + * Returns the content of the Dataset as a `JavaRDD` of `T`s. + * + * @note + * this method is not supported in Spark Connect. + * @group basic + * @since 1.6.0 + */ + def javaRDD: JavaRDD[T] = toJavaRDD } diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala index f6b44e168390a..5e022570d3ca7 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala @@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.TypeTag import _root_.java +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{ColumnName, DatasetHolder, Encoder, Encoders} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder @@ -278,6 +279,14 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits with Serializable { new DatasetHolder(session.createDataset(s).asInstanceOf[DS[T]]) } + /** + * Creates a [[Dataset]] from an RDD. + * + * @since 1.6.0 + */ + implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T, DS] = + new DatasetHolder(session.createDataset(rdd).asInstanceOf[DS[T]]) + /** * An implicit conversion that turns a Scala `Symbol` into a [[org.apache.spark.sql.Column]]. 
* @since 1.3.0 diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala index 4dfeb87a11d92..b2e61df5937bd 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala @@ -25,7 +25,10 @@ import _root_.java.lang import _root_.java.net.URI import _root_.java.util +import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Encoder, Row, RuntimeConfig} import org.apache.spark.sql.types.StructType import org.apache.spark.util.SparkClassUtils @@ -52,6 +55,14 @@ import org.apache.spark.util.SparkClassUtils */ abstract class SparkSession extends Serializable with Closeable { + /** + * The Spark context associated with this Spark session. + * + * @note + * this method is not supported in Spark Connect. + */ + def sparkContext: SparkContext + /** * The version of Spark on which this application is running. * @@ -155,6 +166,85 @@ abstract class SparkSession extends Serializable with Closeable { */ def createDataFrame(data: util.List[_], beanClass: Class[_]): Dataset[Row] + /** + * Creates a `DataFrame` from an RDD of Product (e.g. case classes, tuples). + * + * @note + * this method is not supported in Spark Connect. + * @since 2.0.0 + */ + def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): Dataset[Row] + + /** + * :: DeveloperApi :: Creates a `DataFrame` from an `RDD` containing + * [[org.apache.spark.sql.Row]]s using the given schema. It is important to make sure that the + * structure of every [[org.apache.spark.sql.Row]] of the provided RDD matches the provided + * schema. Otherwise, there will be runtime exception. Example: + * {{{ + * import org.apache.spark.sql._ + * import org.apache.spark.sql.types._ + * val sparkSession = new org.apache.spark.sql.SparkSession(sc) + * + * val schema = + * StructType( + * StructField("name", StringType, false) :: + * StructField("age", IntegerType, true) :: Nil) + * + * val people = + * sc.textFile("examples/src/main/resources/people.txt").map( + * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) + * val dataFrame = sparkSession.createDataFrame(people, schema) + * dataFrame.printSchema + * // root + * // |-- name: string (nullable = false) + * // |-- age: integer (nullable = true) + * + * dataFrame.createOrReplaceTempView("people") + * sparkSession.sql("select name from people").collect.foreach(println) + * }}} + * + * @note + * this method is not supported in Spark Connect. + * @since 2.0.0 + */ + @DeveloperApi + def createDataFrame(rowRDD: RDD[Row], schema: StructType): Dataset[Row] + + /** + * :: DeveloperApi :: Creates a `DataFrame` from a `JavaRDD` containing + * [[org.apache.spark.sql.Row]]s using the given schema. It is important to make sure that the + * structure of every [[org.apache.spark.sql.Row]] of the provided RDD matches the provided + * schema. Otherwise, there will be runtime exception. + * + * @note + * this method is not supported in Spark Connect. + * @since 2.0.0 + */ + @DeveloperApi + def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): Dataset[Row] + + /** + * Applies a schema to an RDD of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, SELECT * queries + * will return the columns in an undefined order. 
+ * + * @since 2.0.0 + */ + def createDataFrame(rdd: RDD[_], beanClass: Class[_]): Dataset[Row] + + /** + * Applies a schema to an RDD of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, SELECT * queries + * will return the columns in an undefined order. + * + * @note + * this method is not supported in Spark Connect. + * @since 2.0.0 + */ + def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): Dataset[Row] + /* ------------------------------- * | Methods for creating DataSets | * ------------------------------- */ @@ -212,6 +302,18 @@ abstract class SparkSession extends Serializable with Closeable { */ def createDataset[T: Encoder](data: util.List[T]): Dataset[T] + /** + * Creates a [[Dataset]] from an RDD of a given type. This method requires an encoder (to + * convert a JVM object of type `T` to and from the internal Spark SQL representation) that is + * generally created automatically through implicits from a `SparkSession`, or can be created + * explicitly by calling static methods on `Encoders`. + * + * @note + * this method is not supported in Spark Connect. + * @since 2.0.0 + */ + def createDataset[T: Encoder](data: RDD[T]): Dataset[T] + /** * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a * range from 0 to `end` (exclusive) with step value 1. diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 7ce4609de51f7..aa1aa5f67a2a9 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -44,6 +44,12 @@ org.apache.spark spark-sql-api_${scala.binary.version} ${project.version} + + + org.apache.spark + spark-connect-shims_${scala.binary.version} + + org.apache.spark diff --git a/sql/connect/server/pom.xml b/sql/connect/server/pom.xml index d0d982934d2c7..f2a7f1b1da9d9 100644 --- a/sql/connect/server/pom.xml +++ b/sql/connect/server/pom.xml @@ -52,6 +52,10 @@ spark-connect-common_${scala.binary.version} ${project.version} + + org.apache.spark + spark-connect-shims_${scala.binary.version} + com.google.guava guava diff --git a/sql/connect/shims/README.md b/sql/connect/shims/README.md new file mode 100644 index 0000000000000..07b593dd04b4b --- /dev/null +++ b/sql/connect/shims/README.md @@ -0,0 +1 @@ +This module defines shims used by the interface defined in sql/api. diff --git a/sql/connect/shims/pom.xml b/sql/connect/shims/pom.xml new file mode 100644 index 0000000000000..6bb12a927738c --- /dev/null +++ b/sql/connect/shims/pom.xml @@ -0,0 +1,41 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.13 + 4.0.0-SNAPSHOT + ../../../pom.xml + + + spark-connect-shims_2.13 + jar + Spark Project Connect Shims + https://spark.apache.org/ + + connect-shims + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/sql/connect/shims/src/main/scala/org/apache/spark/api/java/shims.scala b/sql/connect/shims/src/main/scala/org/apache/spark/api/java/shims.scala new file mode 100644 index 0000000000000..45fae00247485 --- /dev/null +++ b/sql/connect/shims/src/main/scala/org/apache/spark/api/java/shims.scala @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.api.java + +class JavaRDD[T] diff --git a/sql/connect/shims/src/main/scala/org/apache/spark/rdd/shims.scala b/sql/connect/shims/src/main/scala/org/apache/spark/rdd/shims.scala new file mode 100644 index 0000000000000..b23f83fa9185c --- /dev/null +++ b/sql/connect/shims/src/main/scala/org/apache/spark/rdd/shims.scala @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rdd + +class RDD[T] diff --git a/sql/connect/shims/src/main/scala/org/apache/spark/shims.scala b/sql/connect/shims/src/main/scala/org/apache/spark/shims.scala new file mode 100644 index 0000000000000..813b8e4859c28 --- /dev/null +++ b/sql/connect/shims/src/main/scala/org/apache/spark/shims.scala @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark + +class SparkContext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 78cc65bb7a298..ab3e939cee171 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -177,30 +177,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) @scala.annotation.varargs override def json(paths: String*): DataFrame = super.json(paths: _*) - /** - * Loads a `JavaRDD[String]` storing JSON objects (JSON - * Lines text format or newline-delimited JSON) and returns the result as - * a `DataFrame`. 
- * - * Unless the schema is specified using `schema` function, this function goes through the - * input once to determine the input schema. - * - * @param jsonRDD input RDD with one JSON object per record - * @since 1.4.0 - */ + /** @inheritdoc */ @deprecated("Use json(Dataset[String]) instead.", "2.2.0") def json(jsonRDD: JavaRDD[String]): DataFrame = json(jsonRDD.rdd) - /** - * Loads an `RDD[String]` storing JSON objects (JSON Lines - * text format or newline-delimited JSON) and returns the result as a `DataFrame`. - * - * Unless the schema is specified using `schema` function, this function goes through the - * input once to determine the input schema. - * - * @param jsonRDD input RDD with one JSON object per record - * @since 1.4.0 - */ + /** @inheritdoc */ @deprecated("Use json(Dataset[String]) instead.", "2.2.0") def json(jsonRDD: RDD[String]): DataFrame = { json(sparkSession.createDataset(jsonRDD)(Encoders.STRING)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 58006837a3a6d..1c5df1163eb78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1524,12 +1524,7 @@ class Dataset[T] private[sql]( sparkSession.sessionState.executePlan(deserialized) } - /** - * Represents the content of the Dataset as an `RDD` of `T`. - * - * @group basic - * @since 1.6.0 - */ + /** @inheritdoc */ lazy val rdd: RDD[T] = { val objectType = exprEnc.deserializer.dataType rddQueryExecution.toRdd.mapPartitions { rows => @@ -1537,20 +1532,9 @@ class Dataset[T] private[sql]( } } - /** - * Returns the content of the Dataset as a `JavaRDD` of `T`s. - * @group basic - * @since 1.6.0 - */ + /** @inheritdoc */ def toJavaRDD: JavaRDD[T] = rdd.toJavaRDD() - /** - * Returns the content of the Dataset as a `JavaRDD` of `T`s. - * @group basic - * @since 1.6.0 - */ - def javaRDD: JavaRDD[T] = toJavaRDD - protected def createTempView( viewName: String, replace: Boolean, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 1bc7e3ee98e76..b6ed50447109d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -17,21 +17,9 @@ package org.apache.spark.sql -import scala.language.implicitConversions - -import org.apache.spark.rdd.RDD - /** @inheritdoc */ abstract class SQLImplicits extends api.SQLImplicits { type DS[U] = Dataset[U] protected def session: SparkSession - - /** - * Creates a [[Dataset]] from an RDD. - * - * @since 1.6.0 - */ - implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T, Dataset] = - new DatasetHolder(session.createDataset(rdd)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index eeb46fbf145d7..2d485c4ef321d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -294,11 +294,7 @@ class SparkSession private( new Dataset(self, LocalRelation(encoder.schema), encoder) } - /** - * Creates a `DataFrame` from an RDD of Product (e.g. case classes, tuples). 
- * - * @since 2.0.0 - */ + /** @inheritdoc */ def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = withActive { val encoder = Encoders.product[A] Dataset.ofRows(self, ExternalRDD(rdd, self)(encoder)) @@ -311,37 +307,7 @@ class SparkSession private( Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data)) } - /** - * :: DeveloperApi :: - * Creates a `DataFrame` from an `RDD` containing [[Row]]s using the given schema. - * It is important to make sure that the structure of every [[Row]] of the provided RDD matches - * the provided schema. Otherwise, there will be runtime exception. - * Example: - * {{{ - * import org.apache.spark.sql._ - * import org.apache.spark.sql.types._ - * val sparkSession = new org.apache.spark.sql.SparkSession(sc) - * - * val schema = - * StructType( - * StructField("name", StringType, false) :: - * StructField("age", IntegerType, true) :: Nil) - * - * val people = - * sc.textFile("examples/src/main/resources/people.txt").map( - * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) - * val dataFrame = sparkSession.createDataFrame(people, schema) - * dataFrame.printSchema - * // root - * // |-- name: string (nullable = false) - * // |-- age: integer (nullable = true) - * - * dataFrame.createOrReplaceTempView("people") - * sparkSession.sql("select name from people").collect.foreach(println) - * }}} - * - * @since 2.0.0 - */ + /** @inheritdoc */ @DeveloperApi def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = withActive { val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] @@ -353,14 +319,7 @@ class SparkSession private( internalCreateDataFrame(catalystRows.setName(rowRDD.name), schema) } - /** - * :: DeveloperApi :: - * Creates a `DataFrame` from a `JavaRDD` containing [[Row]]s using the given schema. - * It is important to make sure that the structure of every [[Row]] of the provided RDD matches - * the provided schema. Otherwise, there will be runtime exception. - * - * @since 2.0.0 - */ + /** @inheritdoc */ @DeveloperApi def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] @@ -374,14 +333,7 @@ class SparkSession private( Dataset.ofRows(self, LocalRelation.fromExternalRows(toAttributes(replaced), rows.asScala.toSeq)) } - /** - * Applies a schema to an RDD of Java Beans. - * - * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, - * SELECT * queries will return the columns in an undefined order. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = withActive { val attributeSeq: Seq[AttributeReference] = getSchema(beanClass) val className = beanClass.getName @@ -392,14 +344,7 @@ class SparkSession private( Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd.setName(rdd.name))(self)) } - /** - * Applies a schema to an RDD of Java Beans. - * - * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, - * SELECT * queries will return the columns in an undefined order. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { createDataFrame(rdd.rdd, beanClass) } @@ -434,14 +379,7 @@ class SparkSession private( Dataset[T](self, plan) } - /** - * Creates a [[Dataset]] from an RDD of a given type. 
This method requires an - * encoder (to convert a JVM object of type `T` to and from the internal Spark SQL representation) - * that is generally created automatically through implicits from a `SparkSession`, or can be - * created explicitly by calling static methods on [[Encoders]]. - * - * @since 2.0.0 - */ + /** @inheritdoc */ def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { Dataset[T](self, ExternalRDD(data, self)) } From 5f64e80843ae47746b2999b4b277ecc622516cd2 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Wed, 9 Oct 2024 10:30:55 +0900 Subject: [PATCH 186/250] [SPARK-49895][SQL] Improve error when encountering trailing comma in SELECT clause ### What changes were proposed in this pull request? Introduced a specific error message for cases where a trailing comma appears at the end of the SELECT clause. ### Why are the changes needed? The previous error message was unclear and often pointed to an incorrect location in the query, leading to confusion. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48370 from stefankandic/fixTrailingComma. Lead-authored-by: Stefan Kandic Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../resources/error/error-conditions.json | 6 ++ .../sql/catalyst/analysis/Analyzer.scala | 13 ++++- .../sql/catalyst/analysis/CheckAnalysis.scala | 37 ++++++++++++ .../sql/errors/QueryCompilationErrors.scala | 8 +++ .../errors/QueryCompilationErrorsSuite.scala | 56 +++++++++++++++++++ 5 files changed, 118 insertions(+), 2 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 8100f0580b21f..1b7f42e105077 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -4497,6 +4497,12 @@ ], "sqlState" : "428EK" }, + "TRAILING_COMMA_IN_SELECT" : { + "message" : [ + "Trailing comma detected in SELECT clause. Remove the trailing comma before the FROM clause." + ], + "sqlState" : "42601" + }, "TRANSPOSE_EXCEED_ROW_LIMIT" : { "message" : [ "Number of rows exceeds the allowed limit of for TRANSPOSE. If this was intended, set to at least the current row count." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b2e9115dd512f..5d41c07b47842 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1591,7 +1591,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // If the projection list contains Stars, expand it. case p: Project if containsStar(p.projectList) => - p.copy(projectList = buildExpandedProjectList(p.projectList, p.child)) + val expanded = p.copy(projectList = buildExpandedProjectList(p.projectList, p.child)) + if (expanded.projectList.size < p.projectList.size) { + checkTrailingCommaInSelect(expanded, starRemoved = true) + } + expanded // If the filter list contains Stars, expand it. 
case p: Filter if containsStar(Seq(p.condition)) => p.copy(expandStarExpression(p.condition, p.child)) @@ -1600,7 +1604,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor if (a.groupingExpressions.exists(_.isInstanceOf[UnresolvedOrdinal])) { throw QueryCompilationErrors.starNotAllowedWhenGroupByOrdinalPositionUsedError() } else { - a.copy(aggregateExpressions = buildExpandedProjectList(a.aggregateExpressions, a.child)) + val expanded = a.copy(aggregateExpressions = + buildExpandedProjectList(a.aggregateExpressions, a.child)) + if (expanded.aggregateExpressions.size < a.aggregateExpressions.size) { + checkTrailingCommaInSelect(expanded, starRemoved = true) + } + expanded } case c: CollectMetrics if containsStar(c.metrics) => c.copy(metrics = buildExpandedProjectList(c.metrics, c.child)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index b600f455f16ac..a4f424ba4b421 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -173,6 +173,36 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB ) } + /** + * Checks for errors in a `SELECT` clause, such as a trailing comma or an empty select list. + * + * @param plan The logical plan of the query. + * @param starRemoved Whether a '*' (wildcard) was removed from the select list. + * @throws AnalysisException if the select list is empty or ends with a trailing comma. + */ + protected def checkTrailingCommaInSelect( + plan: LogicalPlan, + starRemoved: Boolean = false): Unit = { + val exprList = plan match { + case proj: Project if proj.projectList.nonEmpty => + proj.projectList + case agg: Aggregate if agg.aggregateExpressions.nonEmpty => + agg.aggregateExpressions + case _ => + Seq.empty + } + + exprList.lastOption match { + case Some(Alias(UnresolvedAttribute(Seq(name)), _)) => + if (name.equalsIgnoreCase("FROM") && plan.exists(_.isInstanceOf[OneRowRelation])) { + if (exprList.size > 1 || starRemoved) { + throw QueryCompilationErrors.trailingCommaInSelectError(exprList.last.origin) + } + } + case _ => + } + } + def checkAnalysis(plan: LogicalPlan): Unit = { // We should inline all CTE relations to restore the original plan shape, as the analysis check // may need to match certain plan shapes. 
For dangling CTE relations, they will still be kept @@ -210,6 +240,13 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB val tblName = write.table.asInstanceOf[UnresolvedRelation].multipartIdentifier write.table.tableNotFound(tblName) + // We should check for trailing comma errors first, since we would get less obvious + // unresolved column errors if we do it bottom up + case proj: Project => + checkTrailingCommaInSelect(proj) + case agg: Aggregate => + checkTrailingCommaInSelect(agg) + case _ => } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 22cc001c0c78e..1f43b3dfa4a16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -358,6 +358,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } + def trailingCommaInSelectError(origin: Origin): Throwable = { + new AnalysisException( + errorClass = "TRAILING_COMMA_IN_SELECT", + messageParameters = Map.empty, + origin = origin + ) + } + def unresolvedUsingColForJoinError( colName: String, suggestion: String, side: String): Throwable = { new AnalysisException( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index 61b3489083a06..b4fdf50447458 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -979,6 +979,62 @@ class QueryCompilationErrorsSuite ) } + test("SPARK-49895: trailing comma in select statement") { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 INT, c2 INT) USING PARQUET") + + val queries = Seq( + "SELECT *? FROM t1", + "SELECT c1? FROM t1", + "SELECT c1? FROM t1 WHERE c1 = 1", + "SELECT c1? FROM t1 GROUP BY c1", + "SELECT *, RANK() OVER (ORDER BY c1)? FROM t1", + "SELECT c1? FROM t1 ORDER BY c1", + "WITH cte AS (SELECT c1? FROM t1) SELECT * FROM cte", + "WITH cte AS (SELECT c1 FROM t1) SELECT *? FROM cte", + "SELECT * FROM (SELECT c1? 
FROM t1)") + + queries.foreach { query => + val queryWithoutTrailingComma = query.replaceAll("\\?", "") + val queryWithTrailingComma = query.replaceAll("\\?", ",") + + sql(queryWithoutTrailingComma) + print(queryWithTrailingComma) + val exception = intercept[AnalysisException] { + sql(queryWithTrailingComma) + } + assert(exception.getErrorClass === "TRAILING_COMMA_IN_SELECT") + } + + val unresolvedColumnErrors = Seq( + "SELECT c3 FROM t1", + "SELECT from FROM t1", + "SELECT from FROM (SELECT 'a' as c1)", + "SELECT from AS col FROM t1", + "SELECT from AS from FROM t1", + "SELECT from from FROM t1") + unresolvedColumnErrors.foreach { query => + val exception = intercept[AnalysisException] { + sql(query) + } + assert(exception.getErrorClass === "UNRESOLVED_COLUMN.WITH_SUGGESTION") + } + + // sanity checks + withTable("from") { + sql(s"CREATE TABLE from (from INT) USING PARQUET") + + sql(s"SELECT from FROM from") + sql(s"SELECT from as from FROM from") + sql(s"SELECT from from FROM from from") + sql(s"SELECT c1, from FROM VALUES(1, 2) AS T(c1, from)") + + intercept[ParseException] { + sql("SELECT 1,") + } + } + } + } } class MyCastToString extends SparkUserDefinedFunction( From c1f18a00bb889c0f5aa703525c48d6a9650c7b2d Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Wed, 9 Oct 2024 11:35:10 +0900 Subject: [PATCH 187/250] [SPARK-49022][CONNECT][SQL][FOLLOW-UP] Parse unresolved identifier to keep the behavior same ### What changes were proposed in this pull request? This PR is a followup of https://github.com/apache/spark/pull/47688 that keeps `Column.toString` as the same before. ### Why are the changes needed? To keep the same behaviour with Spark Classic and Connect. ### Does this PR introduce _any_ user-facing change? No, the main change has not been released out yet. ### How was this patch tested? Will be added separately. I manually tested: ```scala import org.apache.spark.sql.functions.col val name = "with`!#$%dot".replace("`", "``") col(s"`${name}`").toString.equals("with`!#$%dot") ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48376 from HyukjinKwon/SPARK-49022-followup. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../sql/internal/columnNodeSupport.scala | 5 +-- .../apache/spark/sql/ColumnTestSuite.scala | 4 +-- .../spark/sql/internal/columnNodes.scala | 33 +++++++++++++++++-- .../sql/internal/columnNodeSupport.scala | 20 ++++++++--- 4 files changed, 50 insertions(+), 12 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala index 45fa449b58ed7..34a8a91a0ddf8 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -52,9 +52,10 @@ object ColumnNodeToProtoConverter extends (ColumnNode => proto.Expression) { case Literal(value, Some(dataType), _) => builder.setLiteral(toLiteralProtoBuilder(value, dataType)) - case UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) => + case u @ UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) => + val escapedName = u.sql val b = builder.getUnresolvedAttributeBuilder - .setUnparsedIdentifier(unparsedIdentifier) + .setUnparsedIdentifier(escapedName) if (isMetadataColumn) { // We only set this field when it is needed. 
If we would always set it, // too many of the verbatims we use for testing would have to be regenerated. diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala index c37100b729029..86c7a20136851 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ColumnTestSuite.scala @@ -173,8 +173,8 @@ class ColumnTestSuite extends ConnectFunSuite { assert(explain1 != explain2) assert(explain1.strip() == "+(a, b)") assert(explain2.contains("UnresolvedFunction(+")) - assert(explain2.contains("UnresolvedAttribute(a")) - assert(explain2.contains("UnresolvedAttribute(b")) + assert(explain2.contains("UnresolvedAttribute(List(a")) + assert(explain2.contains("UnresolvedAttribute(List(b")) } private def testColName(dataType: DataType, f: ColumnName => StructField): Unit = { diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala index 51b26a1fa2435..979baf12be614 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/columnNodes.scala @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import ColumnNode._ import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} +import org.apache.spark.sql.catalyst.util.AttributeNameParser import org.apache.spark.sql.errors.DataTypeErrorsBase import org.apache.spark.sql.types.{DataType, IntegerType, LongType, Metadata} import org.apache.spark.util.SparkClassUtils @@ -122,7 +123,7 @@ private[sql] case class Literal( /** * Reference to an attribute produced by one of the underlying DataFrames. * - * @param unparsedIdentifier + * @param nameParts * name of the attribute. * @param planId * id of the plan (Dataframe) that produces the attribute. @@ -130,14 +131,40 @@ private[sql] case class Literal( * whether this is a metadata column. 
*/ private[sql] case class UnresolvedAttribute( - unparsedIdentifier: String, + nameParts: Seq[String], planId: Option[Long] = None, isMetadataColumn: Boolean = false, override val origin: Origin = CurrentOrigin.get) extends ColumnNode { + override private[internal] def normalize(): UnresolvedAttribute = copy(planId = None, origin = NO_ORIGIN) - override def sql: String = unparsedIdentifier + + override def sql: String = nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".") +} + +private[sql] object UnresolvedAttribute { + def apply( + unparsedIdentifier: String, + planId: Option[Long], + isMetadataColumn: Boolean, + origin: Origin): UnresolvedAttribute = UnresolvedAttribute( + AttributeNameParser.parseAttributeName(unparsedIdentifier), + planId = planId, + isMetadataColumn = isMetadataColumn, + origin = origin) + + def apply( + unparsedIdentifier: String, + planId: Option[Long], + isMetadataColumn: Boolean): UnresolvedAttribute = + apply(unparsedIdentifier, planId, isMetadataColumn, CurrentOrigin.get) + + def apply(unparsedIdentifier: String, planId: Option[Long]): UnresolvedAttribute = + apply(unparsedIdentifier, planId, false, CurrentOrigin.get) + + def apply(unparsedIdentifier: String): UnresolvedAttribute = + apply(unparsedIdentifier, None, false, CurrentOrigin.get) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala index 920c0371292c9..476956e58e8e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/columnNodeSupport.scala @@ -54,8 +54,8 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres case Literal(value, None, _) => expressions.Literal(value) - case UnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn, _) => - convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn) + case UnresolvedAttribute(nameParts, planId, isMetadataColumn, _) => + convertUnresolvedAttribute(nameParts, planId, isMetadataColumn) case UnresolvedStar(unparsedTarget, None, _) => val target = unparsedTarget.map { t => @@ -74,7 +74,7 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres analysis.UnresolvedRegex(columnNameRegex, Some(nameParts), conf.caseSensitiveAnalysis) case UnresolvedRegex(unparsedIdentifier, planId, _) => - convertUnresolvedAttribute(unparsedIdentifier, planId, isMetadataColumn = false) + convertUnresolvedRegex(unparsedIdentifier, planId) case UnresolvedFunction(functionName, arguments, isDistinct, isUDF, isInternal, _) => val nameParts = if (isUDF) { @@ -223,10 +223,10 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres } private def convertUnresolvedAttribute( - unparsedIdentifier: String, + nameParts: Seq[String], planId: Option[Long], isMetadataColumn: Boolean): analysis.UnresolvedAttribute = { - val attribute = analysis.UnresolvedAttribute.quotedString(unparsedIdentifier) + val attribute = analysis.UnresolvedAttribute(nameParts) if (planId.isDefined) { attribute.setTagValue(LogicalPlan.PLAN_ID_TAG, planId.get) } @@ -235,6 +235,16 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres } attribute } + + private def convertUnresolvedRegex( + unparsedIdentifier: String, + planId: Option[Long]): analysis.UnresolvedAttribute = { + val attribute = analysis.UnresolvedAttribute.quotedString(unparsedIdentifier) + 
if (planId.isDefined) { + attribute.setTagValue(LogicalPlan.PLAN_ID_TAG, planId.get) + } + attribute + } } private[sql] object ColumnNodeToExpressionConverter extends ColumnNodeToExpressionConverter { From 5e27eec3847533c037c902f88dd3b64d5226710b Mon Sep 17 00:00:00 2001 From: Nikhil Sheoran <125331115+nikhilsheoran-db@users.noreply.github.com> Date: Wed, 9 Oct 2024 11:28:34 +0800 Subject: [PATCH 188/250] [SPARK-49863][SQL] Fix NormalizeFloatingNumbers to preserve nullability of nested structs ### What changes were proposed in this pull request? - Fixes a bug in `NormalizeFloatingNumbers` to respect the `nullable` attribute of nested expressions when normalizing. ### Why are the changes needed? - Without the fix, there would be a degradation in the nullability of the expression post normalization. - For example, for an expression like: `namedStruct("struct", namedStruct("double", )) ` with the following data type: ``` StructType(StructField("struct", StructType(StructField("double", DoubleType, true, {})), false, {})) ``` after normalizing we would have ended up with the dataType: ``` StructType(StructField("struct", StructType(StructField("double", DoubleType, true, {})), true, {})) ``` Note, the change in the `nullable` attribute of the "double" StructField from `false` to `true`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Added unit test. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48331 from nikhilsheoran-db/SPARK-49863-fix. Authored-by: Nikhil Sheoran <125331115+nikhilsheoran-db@users.noreply.github.com> Signed-off-by: Wenchen Fan --- .../optimizer/NormalizeFloatingNumbers.scala | 12 +++++++++++- .../NormalizeFloatingPointNumbersSuite.scala | 8 ++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index 2fcc689b9df2b..776efbed273e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -134,7 +134,17 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case (name, i) => Seq(Literal(name), normalize(GetStructField(expr, i))) } val struct = CreateNamedStruct(fields.flatten.toImmutableArraySeq) - KnownFloatingPointNormalized(If(IsNull(expr), Literal(null, struct.dataType), struct)) + // For nested structs (and other complex types), this branch is called again with either a + // `GetStructField` or a `NamedLambdaVariable` expression. Even if the field for which this + // has been recursively called might have `nullable = false`, directly creating an `If` + // predicate would end up creating an expression with `nullable = true` (as the trueBranch is + // nullable). Hence, use the `expr.nullable` to create an `If` predicate only when the column + // is nullable. 
+ if (expr.nullable) { + KnownFloatingPointNormalized(If(IsNull(expr), Literal(null, struct.dataType), struct)) + } else { + KnownFloatingPointNormalized(struct) + } case _ if expr.dataType.isInstanceOf[ArrayType] => val ArrayType(et, containsNull) = expr.dataType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala index 454619a2133d9..21049ca3546dc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala @@ -124,5 +124,13 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest { comparePlans(doubleOptimized, correctAnswer) } + + test("SPARK-49863: NormalizeFloatingNumbers preserves nullability for nested struct") { + val relation = LocalRelation($"a".double, $"b".string) + val nestedExpr = namedStruct("struct", namedStruct("double", relation.output.head)) + .as("nestedExpr").toAttribute + val normalizedExpr = NormalizeFloatingNumbers.normalize(nestedExpr) + assert(nestedExpr.dataType == normalizedExpr.dataType) + } } From 135cbc6e89d75ae5141d1d57979d962e42f713e0 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Wed, 9 Oct 2024 15:36:58 +0800 Subject: [PATCH 189/250] [SPARK-49564][SQL] Add SQL pipe syntax for set operations ### What changes were proposed in this pull request? This PR adds SQL pipe syntax support for the set operations: UNION, INTERSECT, EXCEPT, DISTINCT. For example: ``` CREATE TABLE t(x INT, y STRING) USING CSV; INSERT INTO t VALUES (0, 'abc'), (1, 'def'); TABLE t |> UNION ALL (SELECT * FROM t); 0 abc 0 abc 1 def 1 def 1 NULL ``` ### Why are the changes needed? The SQL pipe operator syntax will let users compose queries in a more flexible fashion. ### Does this PR introduce _any_ user-facing change? Yes, see above. ### How was this patch tested? This PR adds a few unit test cases, but mostly relies on golden file test coverage. I did this to make sure the answers are correct as this feature is implemented and also so we can look at the analyzer output plans to ensure they look right as well. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48359 from dtenedor/pipe-union. Authored-by: Daniel Tenedorio Signed-off-by: Wenchen Fan --- .../sql/catalyst/parser/SqlBaseParser.g4 | 1 + .../sql/catalyst/parser/AstBuilder.scala | 14 +- .../analyzer-results/pipe-operators.sql.out | 202 ++++++++++++++++++ .../sql-tests/inputs/pipe-operators.sql | 67 ++++++ .../sql-tests/results/pipe-operators.sql.out | 189 ++++++++++++++++ .../sql/execution/SparkSqlParserSuite.scala | 12 ++ 6 files changed, 481 insertions(+), 4 deletions(-) diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index f0481a1a7073c..9d237f069132a 100644 --- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -1511,6 +1511,7 @@ operatorPipeRightSide | unpivotClause pivotClause? | sample | joinRelation + | operator=(UNION | EXCEPT | SETMINUS | INTERSECT) setQuantifier? right=queryTerm ; // When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 2b0443c01f6d5..c9150b8a26100 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1407,10 +1407,13 @@ class AstBuilder extends DataTypeAstBuilder * - INTERSECT [DISTINCT | ALL] */ override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { - val left = plan(ctx.left) - val right = plan(ctx.right) val all = Option(ctx.setQuantifier()).exists(_.ALL != null) - ctx.operator.getType match { + visitSetOperationImpl(plan(ctx.left), plan(ctx.right), all, ctx.operator.getType) + } + + private def visitSetOperationImpl( + left: LogicalPlan, right: LogicalPlan, all: Boolean, operatorType: Int): LogicalPlan = { + operatorType match { case SqlBaseParser.UNION if all => Union(left, right) case SqlBaseParser.UNION => @@ -5918,7 +5921,10 @@ class AstBuilder extends DataTypeAstBuilder withSample(c, left) }.getOrElse(Option(ctx.joinRelation()).map { c => withJoinRelation(c, left) - }.get))))) + }.getOrElse(Option(ctx.operator).map { c => + val all = Option(ctx.setQuantifier()).exists(_.ALL != null) + visitSetOperationImpl(left, plan(ctx.right), all, c.getType) + }.get)))))) } /** diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out index 4479c93f6e84e..7fa4ec0514ff0 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/pipe-operators.sql.out @@ -1820,6 +1820,208 @@ org.apache.spark.sql.catalyst.parser.ParseException } +-- !query +table t +|> union all table t +-- !query analysis +Union false, false +:- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> union table t +-- !query analysis +Distinct ++- Union false, false + :- SubqueryAlias spark_catalog.default.t + : +- Relation spark_catalog.default.t[x#x,y#x] csv + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select * from t) +|> union all table t +-- !query analysis +Union false, false +:- Project [x#x, y#x] +: +- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select * from t) +|> union table t +-- !query analysis +Distinct ++- Union false, false + :- Project [x#x, y#x] + : +- SubqueryAlias spark_catalog.default.t + : +- Relation spark_catalog.default.t[x#x,y#x] csv + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +values (0, 'abc') tab(x, y) +|> union all table t +-- !query analysis +Union false, false +:- SubqueryAlias tab +: +- LocalRelation [x#x, y#x] ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +values (0, 1) tab(x, y) +|> union table t +-- !query analysis +Distinct ++- Union false, false + :- Project [x#x, cast(y#x as string) AS y#x] + : +- SubqueryAlias tab + : +- LocalRelation [x#x, y#x] + +- 
SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +(select * from t) +|> union all (select * from t) +-- !query analysis +Union false, false +:- Project [x#x, y#x] +: +- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- Project [x#x, y#x] + +- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> except all table t +-- !query analysis +Except All true +:- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> except table t +-- !query analysis +Except false +:- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> intersect all table t +-- !query analysis +Intersect All true +:- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> intersect table t +-- !query analysis +Intersect false +:- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> minus all table t +-- !query analysis +Except All true +:- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> minus table t +-- !query analysis +Except false +:- SubqueryAlias spark_catalog.default.t +: +- Relation spark_catalog.default.t[x#x,y#x] csv ++- SubqueryAlias spark_catalog.default.t + +- Relation spark_catalog.default.t[x#x,y#x] csv + + +-- !query +table t +|> select x +|> union all table t +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "NUM_COLUMNS_MISMATCH", + "sqlState" : "42826", + "messageParameters" : { + "firstNumColumns" : "1", + "invalidNumColumns" : "2", + "invalidOrdinalNum" : "second", + "operator" : "UNION" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 40, + "fragment" : "table t\n|> select x\n|> union all table t" + } ] +} + + +-- !query +table t +|> union all table st +-- !query analysis +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INCOMPATIBLE_COLUMN_TYPE", + "sqlState" : "42825", + "messageParameters" : { + "columnOrdinalNumber" : "second", + "dataType1" : "\"STRUCT\"", + "dataType2" : "\"STRING\"", + "hint" : "", + "operator" : "UNION", + "tableOrdinalNumber" : "second" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 29, + "fragment" : "table t\n|> union all table st" + } ] +} + + -- !query drop table t -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql index 1f8450e3507cb..61890f5cb146d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pipe-operators.sql @@ -504,6 +504,73 @@ table join_test_t1 table 
join_test_t1 jt |> cross join (select * from jt); +-- Set operations: positive tests. +----------------------------------- + +-- Union all. +table t +|> union all table t; + +-- Union distinct. +table t +|> union table t; + +-- Union all with a table subquery. +(select * from t) +|> union all table t; + +-- Union distinct with a table subquery. +(select * from t) +|> union table t; + +-- Union all with a VALUES list. +values (0, 'abc') tab(x, y) +|> union all table t; + +-- Union distinct with a VALUES list. +values (0, 1) tab(x, y) +|> union table t; + +-- Union all with a table subquery on both the source and target sides. +(select * from t) +|> union all (select * from t); + +-- Except all. +table t +|> except all table t; + +-- Except distinct. +table t +|> except table t; + +-- Intersect all. +table t +|> intersect all table t; + +-- Intersect distinct. +table t +|> intersect table t; + +-- Minus all. +table t +|> minus all table t; + +-- Minus distinct. +table t +|> minus table t; + +-- Set operations: negative tests. +----------------------------------- + +-- The UNION operator requires the same number of columns in the input relations. +table t +|> select x +|> union all table t; + +-- The UNION operator requires the column types to be compatible. +table t +|> union all table st; + -- Cleanup. ----------- drop table t; diff --git a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out index a54e66e53f0f3..8cbc5357d78b6 100644 --- a/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pipe-operators.sql.out @@ -1484,6 +1484,195 @@ org.apache.spark.sql.catalyst.parser.ParseException } +-- !query +table t +|> union all table t +-- !query schema +struct +-- !query output +0 abc +0 abc +1 def +1 def + + +-- !query +table t +|> union table t +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +(select * from t) +|> union all table t +-- !query schema +struct +-- !query output +0 abc +0 abc +1 def +1 def + + +-- !query +(select * from t) +|> union table t +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +values (0, 'abc') tab(x, y) +|> union all table t +-- !query schema +struct +-- !query output +0 abc +0 abc +1 def + + +-- !query +values (0, 1) tab(x, y) +|> union table t +-- !query schema +struct +-- !query output +0 1 +0 abc +1 def + + +-- !query +(select * from t) +|> union all (select * from t) +-- !query schema +struct +-- !query output +0 abc +0 abc +1 def +1 def + + +-- !query +table t +|> except all table t +-- !query schema +struct +-- !query output + + + +-- !query +table t +|> except table t +-- !query schema +struct +-- !query output + + + +-- !query +table t +|> intersect all table t +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> intersect table t +-- !query schema +struct +-- !query output +0 abc +1 def + + +-- !query +table t +|> minus all table t +-- !query schema +struct +-- !query output + + + +-- !query +table t +|> minus table t +-- !query schema +struct +-- !query output + + + +-- !query +table t +|> select x +|> union all table t +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "NUM_COLUMNS_MISMATCH", + "sqlState" : "42826", + "messageParameters" : { + "firstNumColumns" : "1", + "invalidNumColumns" : "2", + "invalidOrdinalNum" : "second", + "operator" : 
"UNION" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 40, + "fragment" : "table t\n|> select x\n|> union all table t" + } ] +} + + +-- !query +table t +|> union all table st +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.catalyst.ExtendedAnalysisException +{ + "errorClass" : "INCOMPATIBLE_COLUMN_TYPE", + "sqlState" : "42825", + "messageParameters" : { + "columnOrdinalNumber" : "second", + "dataType1" : "\"STRUCT\"", + "dataType2" : "\"STRING\"", + "hint" : "", + "operator" : "UNION", + "tableOrdinalNumber" : "second" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 1, + "stopIndex" : 29, + "fragment" : "table t\n|> union all table st" + } ] +} + + -- !query drop table t -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 20b9c9caa7493..fc1c9c6755572 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -943,6 +943,18 @@ class SparkSqlParserSuite extends AnalysisTest with SharedSparkSession { "FULL OUTER", "ANTI", "LEFT ANTI", "CROSS").foreach { joinType => checkPipeJoin(s"TABLE t |> $joinType JOIN other ON (t.x = other.x)") } + // Set operations + def checkDistinct(query: String): Unit = check(query, Seq(DISTINCT_LIKE)) + def checkExcept(query: String): Unit = check(query, Seq(EXCEPT)) + def checkIntersect(query: String): Unit = check(query, Seq(INTERSECT)) + def checkUnion(query: String): Unit = check(query, Seq(UNION)) + checkDistinct("TABLE t |> UNION DISTINCT TABLE t") + checkExcept("TABLE t |> EXCEPT ALL TABLE t") + checkExcept("TABLE t |> EXCEPT DISTINCT TABLE t") + checkExcept("TABLE t |> MINUS ALL TABLE t") + checkExcept("TABLE t |> MINUS DISTINCT TABLE t") + checkIntersect("TABLE t |> INTERSECT ALL TABLE t") + checkUnion("TABLE t |> UNION ALL TABLE t") } } } From 52538f0d9bd1258dc2a0a2ab5bdb953f85d85da9 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Wed, 9 Oct 2024 10:08:06 +0200 Subject: [PATCH 190/250] [SPARK-49909][SQL] Fix the pretty name of some expressions ### What changes were proposed in this pull request? The pr aims to fix the `pretty name` of some `expressions`, includes: `random`, `to_varchar`, `current_database`, `curdate`, `dateadd` and `array_agg`. ### Why are the changes needed? The actual function name used does not match the displayed name, as shown below: - Before: image - After: image ### Does this PR introduce _any_ user-facing change? Yes, Make the header of the data seen by the end-user from `Spark SQL` consistent with the `actual function name` used. ### How was this patch tested? - Pass GA. - Update existed UT. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48385 from panbingkun/SPARK-49909. 
Authored-by: panbingkun Signed-off-by: Max Gekk --- python/pyspark/sql/functions/builtin.py | 80 +++++++++---------- .../expressions/aggregate/collect.scala | 5 +- .../expressions/datetimeExpressions.scala | 5 +- .../spark/sql/catalyst/expressions/misc.scala | 3 +- .../expressions/numberFormatExpressions.scala | 7 +- .../expressions/randomExpressions.scala | 8 +- .../function_array_agg.explain | 2 +- .../explain-results/function_curdate.explain | 2 +- .../function_current_database.explain | 2 +- .../explain-results/function_dateadd.explain | 2 +- .../function_random_with_seed.explain | 2 +- .../function_to_varchar.explain | 2 +- .../sql-functions/sql-expression-schema.md | 12 +-- .../analyzer-results/charvarchar.sql.out | 6 +- .../current_database_catalog.sql.out | 2 +- .../analyzer-results/group-by.sql.out | 4 +- .../sql-session-variables.sql.out | 2 +- .../sql-tests/results/charvarchar.sql.out | 6 +- .../results/current_database_catalog.sql.out | 2 +- .../sql-tests/results/group-by.sql.out | 4 +- .../results/subexp-elimination.sql.out | 6 +- 21 files changed, 87 insertions(+), 77 deletions(-) diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index beed832e36067..b75d1b2f59faf 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -4921,44 +4921,44 @@ def array_agg(col: "ColumnOrName") -> Column: >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([[1],[1],[2]], ["c"]) >>> df.agg(sf.sort_array(sf.array_agg('c'))).show() - +---------------------------------+ - |sort_array(collect_list(c), true)| - +---------------------------------+ - | [1, 1, 2]| - +---------------------------------+ + +------------------------------+ + |sort_array(array_agg(c), true)| + +------------------------------+ + | [1, 1, 2]| + +------------------------------+ Example 2: Using array_agg function on a string column >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([["apple"],["apple"],["banana"]], ["c"]) >>> df.agg(sf.sort_array(sf.array_agg('c'))).show(truncate=False) - +---------------------------------+ - |sort_array(collect_list(c), true)| - +---------------------------------+ - |[apple, apple, banana] | - +---------------------------------+ + +------------------------------+ + |sort_array(array_agg(c), true)| + +------------------------------+ + |[apple, apple, banana] | + +------------------------------+ Example 3: Using array_agg function on a column with null values >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([[1],[None],[2]], ["c"]) >>> df.agg(sf.sort_array(sf.array_agg('c'))).show() - +---------------------------------+ - |sort_array(collect_list(c), true)| - +---------------------------------+ - | [1, 2]| - +---------------------------------+ + +------------------------------+ + |sort_array(array_agg(c), true)| + +------------------------------+ + | [1, 2]| + +------------------------------+ Example 4: Using array_agg function on a column with different data types >>> from pyspark.sql import functions as sf >>> df = spark.createDataFrame([[1],["apple"],[2]], ["c"]) >>> df.agg(sf.sort_array(sf.array_agg('c'))).show() - +---------------------------------+ - |sort_array(collect_list(c), true)| - +---------------------------------+ - | [1, 2, apple]| - +---------------------------------+ + +------------------------------+ + |sort_array(array_agg(c), true)| + +------------------------------+ + | [1, 2, apple]| + 
+------------------------------+ """ return _invoke_function_over_columns("array_agg", col) @@ -8712,31 +8712,31 @@ def dateadd(start: "ColumnOrName", days: Union["ColumnOrName", int]) -> Column: >>> spark.createDataFrame( ... [('2015-04-08', 2,)], ['dt', 'add'] ... ).select(sf.dateadd("dt", 1)).show() - +---------------+ - |date_add(dt, 1)| - +---------------+ - | 2015-04-09| - +---------------+ + +--------------+ + |dateadd(dt, 1)| + +--------------+ + | 2015-04-09| + +--------------+ >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( ... [('2015-04-08', 2,)], ['dt', 'add'] ... ).select(sf.dateadd("dt", sf.lit(2))).show() - +---------------+ - |date_add(dt, 2)| - +---------------+ - | 2015-04-10| - +---------------+ + +--------------+ + |dateadd(dt, 2)| + +--------------+ + | 2015-04-10| + +--------------+ >>> import pyspark.sql.functions as sf >>> spark.createDataFrame( ... [('2015-04-08', 2,)], ['dt', 'add'] ... ).select(sf.dateadd("dt", -1)).show() - +----------------+ - |date_add(dt, -1)| - +----------------+ - | 2015-04-07| - +----------------+ + +---------------+ + |dateadd(dt, -1)| + +---------------+ + | 2015-04-07| + +---------------+ """ days = _enum_to_value(days) days = lit(days) if isinstance(days, int) else days @@ -10343,11 +10343,11 @@ def current_database() -> Column: Examples -------- >>> spark.range(1).select(current_database()).show() - +----------------+ - |current_schema()| - +----------------+ - | default| - +----------------+ + +------------------+ + |current_database()| + +------------------+ + | default| + +------------------+ """ return _invoke_function("current_database") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index c593c8bfb8341..0a4882bfada17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import scala.collection.mutable.Growable import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.UnaryLike @@ -118,7 +118,8 @@ case class CollectList( override def createAggregationBuffer(): mutable.ArrayBuffer[Any] = mutable.ArrayBuffer.empty - override def prettyName: String = "collect_list" + override def prettyName: String = + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("collect_list") override def eval(buffer: mutable.ArrayBuffer[Any]): Any = { new GenericArrayData(buffer.toArray) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index b166d235557fc..764637b97a100 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -150,7 +150,8 @@ case class CurrentDate(timeZoneId: Option[String] = None) override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = 
copy(timeZoneId = Option(timeZoneId)) - override def prettyName: String = "current_date" + override def prettyName: String = + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("current_date") } // scalastyle:off line.size.limit @@ -329,7 +330,7 @@ case class DateAdd(startDate: Expression, days: Expression) }) } - override def prettyName: String = "date_add" + override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("date_add") override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): DateAdd = copy(startDate = newLeft, days = newRight) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index cb846f606632b..0315c12b9bb8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -202,7 +202,8 @@ object AssertTrue { case class CurrentDatabase() extends LeafExpression with Unevaluable { override def dataType: DataType = SQLConf.get.defaultStringType override def nullable: Boolean = false - override def prettyName: String = "current_schema" + override def prettyName: String = + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("current_database") final override val nodePatterns: Seq[TreePattern] = Seq(CURRENT_LIKE) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala index 5bd2ab6035e10..eefd21b236b7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/numberFormatExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale -import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult} +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} @@ -307,7 +307,10 @@ case class ToCharacter(left: Expression, right: Expression) inputTypeCheck } } - override def prettyName: String = "to_char" + + override def prettyName: String = + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("to_char") + override def nullSafeEval(decimal: Any, format: Any): Any = { val input = decimal.asInstanceOf[Decimal] numberFormatter.format(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index ada0a73a67958..3cec83facd01d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedSeed} +import 
org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} @@ -128,8 +128,12 @@ case class Rand(child: Expression, hideSeed: Boolean = false) extends Nondetermi } override def flatArguments: Iterator[Any] = Iterator(child) + + override def prettyName: String = + getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("rand") + override def sql: String = { - s"rand(${if (hideSeed) "" else child.sql})" + s"$prettyName(${if (hideSeed) "" else child.sql})" } override protected def withNewChildInternal(newChild: Expression): Rand = copy(child = newChild) diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain index 102f736c62ef6..6668692f6cf1d 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_agg.explain @@ -1,2 +1,2 @@ -Aggregate [collect_list(a#0, 0, 0) AS collect_list(a)#0] +Aggregate [array_agg(a#0, 0, 0) AS array_agg(a)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_curdate.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_curdate.explain index 5305b346c4f2d..be039d62a5494 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_curdate.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_curdate.explain @@ -1,2 +1,2 @@ -Project [current_date(Some(America/Los_Angeles)) AS current_date()#0] +Project [curdate(Some(America/Los_Angeles)) AS curdate()#0] +- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_current_database.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_current_database.explain index 481c0a478c8df..93dfac524d9a1 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_current_database.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_current_database.explain @@ -1,2 +1,2 @@ -Project [current_schema() AS current_schema()#0] +Project [current_database() AS current_database()#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_dateadd.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_dateadd.explain index 66325085b9c14..319428541760d 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_dateadd.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_dateadd.explain @@ -1,2 +1,2 @@ -Project [date_add(d#0, 2) AS date_add(d, 2)#0] +Project [dateadd(d#0, 2) AS dateadd(d, 2)#0] +- LocalRelation , [d#0, t#0, s#0, x#0L, wt#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_random_with_seed.explain 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_random_with_seed.explain index 81c81e95c2bdd..5854d2c7fa6be 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_random_with_seed.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_random_with_seed.explain @@ -1,2 +1,2 @@ -Project [random(1) AS rand(1)#0] +Project [random(1) AS random(1)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_varchar.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_varchar.explain index f0d9cacc61ac5..cc5149bfed863 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_varchar.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_to_varchar.explain @@ -1,2 +1,2 @@ -Project [to_char(cast(b#0 as decimal(30,15)), $99.99) AS to_char(b, $99.99)#0] +Project [to_varchar(cast(b#0 as decimal(30,15)), $99.99) AS to_varchar(b, $99.99)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 5ad1380e1fb82..79fd25aa3eb14 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -99,9 +99,9 @@ | org.apache.spark.sql.catalyst.expressions.Csc | csc | SELECT csc(1) | struct | | org.apache.spark.sql.catalyst.expressions.CsvToStructs | from_csv | SELECT from_csv('1, 0.8', 'a INT, b DOUBLE') | struct> | | org.apache.spark.sql.catalyst.expressions.CumeDist | cume_dist | SELECT a, b, cume_dist() OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct | -| org.apache.spark.sql.catalyst.expressions.CurDateExpressionBuilder | curdate | SELECT curdate() | struct | +| org.apache.spark.sql.catalyst.expressions.CurDateExpressionBuilder | curdate | SELECT curdate() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentCatalog | current_catalog | SELECT current_catalog() | struct | -| org.apache.spark.sql.catalyst.expressions.CurrentDatabase | current_database | SELECT current_database() | struct | +| org.apache.spark.sql.catalyst.expressions.CurrentDatabase | current_database | SELECT current_database() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentDatabase | current_schema | SELECT current_schema() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentDate | current_date | SELECT current_date() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentTimeZone | current_timezone | SELECT current_timezone() | struct | @@ -110,7 +110,7 @@ | org.apache.spark.sql.catalyst.expressions.CurrentUser | session_user | SELECT session_user() | struct | | org.apache.spark.sql.catalyst.expressions.CurrentUser | user | SELECT user() | struct | | org.apache.spark.sql.catalyst.expressions.DateAdd | date_add | SELECT date_add('2016-07-30', 1) | struct | -| org.apache.spark.sql.catalyst.expressions.DateAdd | dateadd | SELECT dateadd('2016-07-30', 1) | struct | +| org.apache.spark.sql.catalyst.expressions.DateAdd | dateadd | SELECT dateadd('2016-07-30', 1) | struct | | org.apache.spark.sql.catalyst.expressions.DateDiff | date_diff | SELECT date_diff('2009-07-31', '2009-07-30') | struct | | 
org.apache.spark.sql.catalyst.expressions.DateDiff | datediff | SELECT datediff('2009-07-31', '2009-07-30') | struct | | org.apache.spark.sql.catalyst.expressions.DateFormatClass | date_format | SELECT date_format('2016-04-08', 'y') | struct | @@ -264,7 +264,7 @@ | org.apache.spark.sql.catalyst.expressions.RPadExpressionBuilder | rpad | SELECT rpad('hi', 5, '??') | struct | | org.apache.spark.sql.catalyst.expressions.RaiseErrorExpressionBuilder | raise_error | SELECT raise_error('custom error message') | struct | | org.apache.spark.sql.catalyst.expressions.Rand | rand | SELECT rand() | struct | -| org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | struct | +| org.apache.spark.sql.catalyst.expressions.Rand | random | SELECT random() | struct | | org.apache.spark.sql.catalyst.expressions.RandStr | randstr | SELECT randstr(3, 0) AS result | struct | | org.apache.spark.sql.catalyst.expressions.Randn | randn | SELECT randn() | struct | | org.apache.spark.sql.catalyst.expressions.Rank | rank | SELECT a, b, rank(b) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct | @@ -340,7 +340,7 @@ | org.apache.spark.sql.catalyst.expressions.TimeWindow | window | SELECT a, window.start, window.end, count(*) as cnt FROM VALUES ('A1', '2021-01-01 00:00:00'), ('A1', '2021-01-01 00:04:30'), ('A1', '2021-01-01 00:06:00'), ('A2', '2021-01-01 00:01:00') AS tab(a, b) GROUP by a, window(b, '5 minutes') ORDER BY a, start | struct | | org.apache.spark.sql.catalyst.expressions.ToBinary | to_binary | SELECT to_binary('abc', 'utf-8') | struct | | org.apache.spark.sql.catalyst.expressions.ToCharacterBuilder | to_char | SELECT to_char(454, '999') | struct | -| org.apache.spark.sql.catalyst.expressions.ToCharacterBuilder | to_varchar | SELECT to_varchar(454, '999') | struct | +| org.apache.spark.sql.catalyst.expressions.ToCharacterBuilder | to_varchar | SELECT to_varchar(454, '999') | struct | | org.apache.spark.sql.catalyst.expressions.ToDegrees | degrees | SELECT degrees(3.141592653589793) | struct | | org.apache.spark.sql.catalyst.expressions.ToNumber | to_number | SELECT to_number('454', '999') | struct | | org.apache.spark.sql.catalyst.expressions.ToRadians | radians | SELECT radians(180) | struct | @@ -402,7 +402,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.BoolOr | any | SELECT any(col) FROM VALUES (true), (false), (false) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.BoolOr | bool_or | SELECT bool_or(col) FROM VALUES (true), (false), (false) AS tab(col) | struct | | org.apache.spark.sql.catalyst.expressions.aggregate.BoolOr | some | SELECT some(col) FROM VALUES (true), (false), (false) AS tab(col) | struct | -| org.apache.spark.sql.catalyst.expressions.aggregate.CollectList | array_agg | SELECT array_agg(col) FROM VALUES (1), (2), (1) AS tab(col) | struct> | +| org.apache.spark.sql.catalyst.expressions.aggregate.CollectList | array_agg | SELECT array_agg(col) FROM VALUES (1), (2), (1) AS tab(col) | struct> | | org.apache.spark.sql.catalyst.expressions.aggregate.CollectList | collect_list | SELECT collect_list(col) FROM VALUES (1), (2), (1) AS tab(col) | struct> | | org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet | collect_set | SELECT collect_set(col) FROM VALUES (1), (2), (1) AS tab(col) | struct> | | org.apache.spark.sql.catalyst.expressions.aggregate.Corr | corr | SELECT corr(c1, c2) FROM VALUES (3, 2), (3, 3), (6, 4) as tab(c1, c2) | struct | diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out index 524797015a2f6..d4bcb8f2ed042 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/charvarchar.sql.out @@ -722,19 +722,19 @@ Project [chr(cast(167 as bigint)) AS chr(167)#x, chr(cast(247 as bigint)) AS chr -- !query SELECT to_varchar(78.12, '$99.99') -- !query analysis -Project [to_char(78.12, $99.99) AS to_char(78.12, $99.99)#x] +Project [to_varchar(78.12, $99.99) AS to_varchar(78.12, $99.99)#x] +- OneRowRelation -- !query SELECT to_varchar(111.11, '99.9') -- !query analysis -Project [to_char(111.11, 99.9) AS to_char(111.11, 99.9)#x] +Project [to_varchar(111.11, 99.9) AS to_varchar(111.11, 99.9)#x] +- OneRowRelation -- !query SELECT to_varchar(12454.8, '99,999.9S') -- !query analysis -Project [to_char(12454.8, 99,999.9S) AS to_char(12454.8, 99,999.9S)#x] +Project [to_varchar(12454.8, 99,999.9S) AS to_varchar(12454.8, 99,999.9S)#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/current_database_catalog.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/current_database_catalog.sql.out index 1a71594f84932..2759f5e67507b 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/current_database_catalog.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/current_database_catalog.sql.out @@ -2,5 +2,5 @@ -- !query select current_database(), current_schema(), current_catalog() -- !query analysis -Project [current_schema() AS current_schema()#x, current_schema() AS current_schema()#x, current_catalog() AS current_catalog()#x] +Project [current_database() AS current_database()#x, current_schema() AS current_schema()#x, current_catalog() AS current_catalog()#x] +- OneRowRelation diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out index 8849aa4452252..6996eb913a21e 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/group-by.sql.out @@ -1133,7 +1133,7 @@ SELECT FROM VALUES (1), (2), (1) AS tab(col) -- !query analysis -Aggregate [collect_list(col#x, 0, 0) AS collect_list(col)#x, collect_list(col#x, 0, 0) AS collect_list(col)#x] +Aggregate [collect_list(col#x, 0, 0) AS collect_list(col)#x, array_agg(col#x, 0, 0) AS array_agg(col)#x] +- SubqueryAlias tab +- LocalRelation [col#x] @@ -1147,7 +1147,7 @@ FROM VALUES (1,4),(2,3),(1,4),(2,4) AS v(a,b) GROUP BY a -- !query analysis -Aggregate [a#x], [a#x, collect_list(b#x, 0, 0) AS collect_list(b)#x, collect_list(b#x, 0, 0) AS collect_list(b)#x] +Aggregate [a#x], [a#x, collect_list(b#x, 0, 0) AS collect_list(b)#x, array_agg(b#x, 0, 0) AS array_agg(b)#x] +- SubqueryAlias v +- LocalRelation [a#x, b#x] diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out index 02e7c39ae83fd..8c10d78405751 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out @@ -776,7 +776,7 @@ Project [NULL AS Expected#x, variablereference(system.session.var1=CAST(NULL AS -- !query DECLARE OR REPLACE 
VARIABLE var1 STRING DEFAULT CURRENT_DATABASE() -- !query analysis -CreateVariable defaultvalueexpression(cast(current_schema() as string), CURRENT_DATABASE()), true +CreateVariable defaultvalueexpression(cast(current_database() as string), CURRENT_DATABASE()), true +- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var1 diff --git a/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out b/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out index 8aafa25c5caaf..2960c4ca4f4d4 100644 --- a/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/charvarchar.sql.out @@ -1235,7 +1235,7 @@ struct -- !query SELECT to_varchar(78.12, '$99.99') -- !query schema -struct +struct -- !query output $78.12 @@ -1243,7 +1243,7 @@ $78.12 -- !query SELECT to_varchar(111.11, '99.9') -- !query schema -struct +struct -- !query output ##.# @@ -1251,6 +1251,6 @@ struct -- !query SELECT to_varchar(12454.8, '99,999.9S') -- !query schema -struct +struct -- !query output 12,454.8+ diff --git a/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out b/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out index 67db0adee7f07..7fbe2dfff4db1 100644 --- a/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/current_database_catalog.sql.out @@ -2,6 +2,6 @@ -- !query select current_database(), current_schema(), current_catalog() -- !query schema -struct +struct -- !query output default default spark_catalog diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index d8a9f4c2e11f5..5d220fc12b78e 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1066,7 +1066,7 @@ SELECT FROM VALUES (1), (2), (1) AS tab(col) -- !query schema -struct,collect_list(col):array> +struct,array_agg(col):array> -- !query output [1,2,1] [1,2,1] @@ -1080,7 +1080,7 @@ FROM VALUES (1,4),(2,3),(1,4),(2,4) AS v(a,b) GROUP BY a -- !query schema -struct,collect_list(b):array> +struct,array_agg(b):array> -- !query output 1 [4,4] [4,4] 2 [3,4] [3,4] diff --git a/sql/core/src/test/resources/sql-tests/results/subexp-elimination.sql.out b/sql/core/src/test/resources/sql-tests/results/subexp-elimination.sql.out index 0f7ff3f107567..28457c0579e95 100644 --- a/sql/core/src/test/resources/sql-tests/results/subexp-elimination.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subexp-elimination.sql.out @@ -72,7 +72,7 @@ NULL -- !query SELECT from_json(a, 'struct').a + random() > 2, from_json(a, 'struct').b, from_json(b, 'array>')[0].a, from_json(b, 'array>')[0].b + + random() > 2 FROM testData -- !query schema -struct<((from_json(a).a + rand()) > 2):boolean,from_json(a).b:string,from_json(b)[0].a:int,((from_json(b)[0].b + (+ rand())) > 2):boolean> +struct<((from_json(a).a + random()) > 2):boolean,from_json(a).b:string,from_json(b)[0].a:int,((from_json(b)[0].b + (+ random())) > 2):boolean> -- !query output NULL NULL 1 true false 2 1 true @@ -84,7 +84,7 @@ true 6 6 true -- !query SELECT if(from_json(a, 'struct').a + random() > 5, from_json(b, 'array>')[0].a, from_json(b, 'array>')[0].a + 1) FROM testData -- !query schema -struct<(IF(((from_json(a).a + rand()) > 5), from_json(b)[0].a, (from_json(b)[0].a + 1))):int> 
+struct<(IF(((from_json(a).a + random()) > 5), from_json(b)[0].a, (from_json(b)[0].a + 1))):int> -- !query output 2 2 @@ -96,7 +96,7 @@ NULL -- !query SELECT case when from_json(a, 'struct').a > 5 then from_json(a, 'struct').b + random() > 5 when from_json(a, 'struct').a > 4 then from_json(a, 'struct').b + 1 + random() > 2 else from_json(a, 'struct').b + 2 + random() > 5 end FROM testData -- !query schema -struct 5) THEN ((from_json(a).b + rand()) > 5) WHEN (from_json(a).a > 4) THEN (((from_json(a).b + 1) + rand()) > 2) ELSE (((from_json(a).b + 2) + rand()) > 5) END:boolean> +struct 5) THEN ((from_json(a).b + random()) > 5) WHEN (from_json(a).a > 4) THEN (((from_json(a).b + 1) + random()) > 2) ELSE (((from_json(a).b + 2) + random()) > 5) END:boolean> -- !query output NULL false From 6cdcf5b001b23791504dcdb964474727527e563b Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Wed, 9 Oct 2024 11:06:42 +0200 Subject: [PATCH 191/250] fix scala style. --- .../sql/catalyst/util/CollationFactory.java | 30 +++++++++---------- .../unsafe/types/CollationFactorySuite.scala | 4 +-- .../sql/CollationSQLExpressionsSuite.scala | 2 +- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 85db5b02a6992..b1d11e96d9bbd 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -551,12 +551,12 @@ protected Collation buildCollation() { comparator = UTF8String::binaryCompare; hashFunction = s -> (long) s.hashCode(); equalsFunction = UTF8String::equals; - }else { + } else { comparator = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).binaryCompare( - applyTrimmingPolicy(s2, spaceTrimming)); + applyTrimmingPolicy(s2, spaceTrimming)); hashFunction = s -> (long) applyTrimmingPolicy(s, spaceTrimming).hashCode(); equalsFunction = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).equals( - applyTrimmingPolicy(s2, spaceTrimming)); + applyTrimmingPolicy(s2, spaceTrimming)); } return new Collation( @@ -575,16 +575,16 @@ protected Collation buildCollation() { Comparator comparator; ToLongFunction hashFunction; - if(spaceTrimming == SpaceTrimming.NONE ) { + if (spaceTrimming == SpaceTrimming.NONE ) { comparator = CollationAwareUTF8String::compareLowerCase; hashFunction = s -> - (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(); - }else{ + (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(); + } else { comparator = (s1, s2) -> CollationAwareUTF8String.compareLowerCase( - applyTrimmingPolicy(s1, spaceTrimming), - applyTrimmingPolicy(s2, spaceTrimming)); + applyTrimmingPolicy(s1, spaceTrimming), + applyTrimmingPolicy(s2, spaceTrimming)); hashFunction = s -> (long) CollationAwareUTF8String. 
- lowerCaseCodePoints(applyTrimmingPolicy(s, spaceTrimming)).hashCode(); + lowerCaseCodePoints(applyTrimmingPolicy(s, spaceTrimming)).hashCode(); } return new Collation( @@ -961,17 +961,17 @@ protected Collation buildCollation() { Comparator comparator; ToLongFunction hashFunction; - if(spaceTrimming == SpaceTrimming.NONE){ + if (spaceTrimming == SpaceTrimming.NONE){ hashFunction = s -> (long) collator.getCollationKey( - s.toValidString()).hashCode(); + s.toValidString()).hashCode(); comparator = (s1, s2) -> - collator.compare(s1.toValidString(), s2.toValidString()); + collator.compare(s1.toValidString(), s2.toValidString()); } else { comparator = (s1, s2) -> collator.compare( - applyTrimmingPolicy(s1, spaceTrimming).toValidString(), - applyTrimmingPolicy(s2, spaceTrimming).toValidString()); + applyTrimmingPolicy(s1, spaceTrimming).toValidString(), + applyTrimmingPolicy(s2, spaceTrimming).toValidString()); hashFunction = s -> (long) collator.getCollationKey( - applyTrimmingPolicy(s, spaceTrimming).toValidString()).hashCode(); + applyTrimmingPolicy(s, spaceTrimming).toValidString()).hashCode(); } return new Collation( diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 491abeab58e01..88ef9a3c2d83f 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -146,7 +146,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_CI", "aaa", "bbb", false), CollationTestCase("UNICODE_CI", "å", "a\u030A", true), CollationTestCase("UNICODE_CI", "Å", "a\u030A", true), - CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", true) ) checks.foreach(testCase => { @@ -190,7 +190,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_CI", "aaa", "bbb", -1), CollationTestCase("UNICODE_CI_RTRIM", "aaa", "aaa ", 0), CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", 0), - CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1) ) checks.foreach(testCase => { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index fd83408da7f74..ac8ad69dd55d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -503,7 +503,7 @@ class CollationSQLExpressionsSuite BinTestCase("13", "UNICODE", "1101"), BinTestCase("13", "UNICODE_RTRIM", "1101"), BinTestCase("13", "UNICODE_CI", "1101"), - BinTestCase("13", "UNICODE_CI_RTRIM", "1101"), + BinTestCase("13", "UNICODE_CI_RTRIM", "1101") ) testCases.foreach(t => { val query = From b565a8df9fef4344f18ad103df69178586ea099d Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 9 Oct 2024 08:58:06 -0400 Subject: [PATCH 192/250] [SPARK-49418][CONNECT][SQL] Shared Session Thread Locals ### What changes were proposed in this pull request? This PR adds interfaces for SparkSession Thread Locals. ### Why are the changes needed? We are creating a unified Spark SQL Scala interface. This is part of that effort. 
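The thread-local contract itself is unchanged by the consolidation; a minimal sketch of the semantics, assuming a Classic session started with a local master:

```
import org.apache.spark.sql.SparkSession

val session = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
// getOrCreate registers the session as both the global default and the
// thread-local active session.
assert(SparkSession.getDefaultSession.contains(session))
assert(SparkSession.getActiveSession.contains(session))

// Clearing the thread-local value only affects the current thread; `active`
// then falls back to the global default instead of throwing.
SparkSession.clearActiveSession()
assert(SparkSession.getActiveSession.isEmpty)
assert(SparkSession.active eq session)
```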
### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48374 from hvanhovell/SPARK-49418. Authored-by: Herman van Hovell Signed-off-by: Herman van Hovell --- .../org/apache/spark/sql/SparkSession.scala | 113 ++--------- .../apache/spark/sql/SparkSessionSuite.scala | 3 +- .../CheckConnectJvmClientCompatibility.scala | 2 + project/MimaExcludes.scala | 6 + .../apache/spark/sql/api/SparkSession.scala | 177 +++++++++++++++++- .../org/apache/spark/sql/SparkSession.scala | 137 +++----------- 6 files changed, 224 insertions(+), 214 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala index ad10a22f833bf..c0590fbd1728f 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import java.net.URI import java.nio.file.{Files, Paths} import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.{AtomicLong, AtomicReference} +import java.util.concurrent.atomic.AtomicLong import scala.jdk.CollectionConverters._ import scala.reflect.runtime.universe.TypeTag @@ -525,6 +525,8 @@ class SparkSession private[sql] ( } } + override private[sql] def isUsable: Boolean = client.isSessionValid + implicit class RichColumn(c: Column) { def expr: proto.Expression = toExpr(c) def typedExpr[T](e: Encoder[T]): proto.Expression = toTypedExpr(c, e) @@ -533,7 +535,9 @@ class SparkSession private[sql] ( // The minimal builder needed to create a spark session. // TODO: implements all methods mentioned in the scaladoc of [[SparkSession]] -object SparkSession extends api.SparkSessionCompanion with Logging { +object SparkSession extends api.BaseSparkSessionCompanion with Logging { + override private[sql] type Session = SparkSession + private val MAX_CACHED_SESSIONS = 100 private val planIdGenerator = new AtomicLong private var server: Option[Process] = None @@ -549,29 +553,6 @@ object SparkSession extends api.SparkSessionCompanion with Logging { override def load(c: Configuration): SparkSession = create(c) }) - /** The active SparkSession for the current thread. */ - private val activeThreadSession = new InheritableThreadLocal[SparkSession] - - /** Reference to the root SparkSession. */ - private val defaultSession = new AtomicReference[SparkSession] - - /** - * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when - * they are not set yet or the associated [[SparkConnectClient]] is unusable. - */ - private def setDefaultAndActiveSession(session: SparkSession): Unit = { - val currentDefault = defaultSession.getAcquire - if (currentDefault == null || !currentDefault.client.isSessionValid) { - // Update `defaultSession` if it is null or the contained session is not valid. There is a - // chance that the following `compareAndSet` fails if a new default session has just been set, - // but that does not matter since that event has happened after this method was invoked. - defaultSession.compareAndSet(currentDefault, session) - } - if (getActiveSession.isEmpty) { - setActiveSession(session) - } - } - /** * Create a new Spark Connect server to connect locally. 
*/ @@ -624,17 +605,6 @@ object SparkSession extends api.SparkSessionCompanion with Logging { new SparkSession(configuration.toSparkConnectClient, planIdGenerator) } - /** - * Hook called when a session is closed. - */ - private[sql] def onSessionClose(session: SparkSession): Unit = { - sessions.invalidate(session.client.configuration) - defaultSession.compareAndSet(session, null) - if (getActiveSession.contains(session)) { - clearActiveSession() - } - } - /** * Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]]. * @@ -781,71 +751,12 @@ object SparkSession extends api.SparkSessionCompanion with Logging { } } - /** - * Returns the default SparkSession. If the previously set default SparkSession becomes - * unusable, returns None. - * - * @since 3.5.0 - */ - def getDefaultSession: Option[SparkSession] = - Option(defaultSession.get()).filter(_.client.isSessionValid) - - /** - * Sets the default SparkSession. - * - * @since 3.5.0 - */ - def setDefaultSession(session: SparkSession): Unit = { - defaultSession.set(session) - } - - /** - * Clears the default SparkSession. - * - * @since 3.5.0 - */ - def clearDefaultSession(): Unit = { - defaultSession.set(null) - } - - /** - * Returns the active SparkSession for the current thread. If the previously set active - * SparkSession becomes unusable, returns None. - * - * @since 3.5.0 - */ - def getActiveSession: Option[SparkSession] = - Option(activeThreadSession.get()).filter(_.client.isSessionValid) - - /** - * Changes the SparkSession that will be returned in this thread and its children when - * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives - * an isolated SparkSession. - * - * @since 3.5.0 - */ - def setActiveSession(session: SparkSession): Unit = { - activeThreadSession.set(session) - } + /** @inheritdoc */ + override def getActiveSession: Option[SparkSession] = super.getActiveSession - /** - * Clears the active SparkSession for current thread. - * - * @since 3.5.0 - */ - def clearActiveSession(): Unit = { - activeThreadSession.remove() - } + /** @inheritdoc */ + override def getDefaultSession: Option[SparkSession] = super.getDefaultSession - /** - * Returns the currently active SparkSession, otherwise the default one. If there is no default - * SparkSession, throws an exception. 
- * - * @since 3.5.0 - */ - def active: SparkSession = { - getActiveSession - .orElse(getDefaultSession) - .getOrElse(throw new IllegalStateException("No active or default Spark session found")) - } + /** @inheritdoc */ + override def active: SparkSession = super.active } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala index 8abc41639fdd2..dec56554d143e 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/SparkSessionSuite.scala @@ -22,6 +22,7 @@ import scala.util.control.NonFatal import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, MethodDescriptor} +import org.apache.spark.SparkException import org.apache.spark.sql.test.ConnectFunSuite import org.apache.spark.util.SparkSerDeUtils @@ -113,7 +114,7 @@ class SparkSessionSuite extends ConnectFunSuite { SparkSession.clearActiveSession() assert(SparkSession.getDefaultSession.isEmpty) assert(SparkSession.getActiveSession.isEmpty) - intercept[IllegalStateException](SparkSession.active) + intercept[SparkException](SparkSession.active) // Create a session val session1 = SparkSession.builder().remote(connectionString1).getOrCreate() diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index c8776af18a14a..693c807ec71ea 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -227,6 +227,8 @@ object CheckConnectJvmClientCompatibility { "org.apache.spark.sql.SparkSession.baseRelationToDataFrame"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.createDataset"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.SparkSession.executeCommand"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "org.apache.spark.sql.SparkSession.canUseSession"), // SparkSession#implicits ProblemFilters.exclude[DirectMissingMethodProblem]( diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 2b3d76eb0c2c3..3ccb0bddfb0eb 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -189,6 +189,12 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.javalang.typed"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.typed"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.typed$"), + + // SPARK-49418: Consolidate thread local handling in sql/api + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.SparkSession.setActiveSession"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.SparkSession.setDefaultSession"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.api.SparkSessionCompanion.clearActiveSession"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.api.SparkSessionCompanion.clearDefaultSession"), ) ++ loggingExcludes("org.apache.spark.sql.DataFrameReader") ++ 
loggingExcludes("org.apache.spark.sql.streaming.DataStreamReader") ++ loggingExcludes("org.apache.spark.sql.SparkSession#Builder") diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala index b2e61df5937bd..31ceecb9e4ca5 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala @@ -24,8 +24,9 @@ import _root_.java.io.Closeable import _root_.java.lang import _root_.java.net.URI import _root_.java.util +import _root_.java.util.concurrent.atomic.AtomicReference -import org.apache.spark.SparkContext +import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable} import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD @@ -663,9 +664,19 @@ abstract class SparkSession extends Serializable with Closeable { * @since 2.0.0 */ def stop(): Unit = close() + + /** + * Check to see if the session is still usable. + * + * In Classic this means that the underlying `SparkContext` is still active. In Connect this + * means the connection to the server is usable. + */ + private[sql] def isUsable: Boolean } object SparkSession extends SparkSessionCompanion { + type Session = SparkSession + private[this] val companion: SparkSessionCompanion = { val cls = SparkClassUtils.classForName("org.apache.spark.sql.SparkSession") val mirror = scala.reflect.runtime.currentMirror @@ -675,12 +686,97 @@ object SparkSession extends SparkSessionCompanion { /** @inheritdoc */ override def builder(): SparkSessionBuilder = companion.builder() + + /** @inheritdoc */ + override def setActiveSession(session: SparkSession): Unit = + companion.setActiveSession(session.asInstanceOf[companion.Session]) + + /** @inheritdoc */ + override def clearActiveSession(): Unit = companion.clearActiveSession() + + /** @inheritdoc */ + override def setDefaultSession(session: SparkSession): Unit = + companion.setDefaultSession(session.asInstanceOf[companion.Session]) + + /** @inheritdoc */ + override def clearDefaultSession(): Unit = companion.clearDefaultSession() + + /** @inheritdoc */ + override def getActiveSession: Option[SparkSession] = companion.getActiveSession + + /** @inheritdoc */ + override def getDefaultSession: Option[SparkSession] = companion.getDefaultSession } /** - * Companion of a [[SparkSession]]. + * Interface for a [[SparkSession]] Companion. The companion is responsible for building the + * session, and managing the active (thread local) and default (global) SparkSessions. */ private[sql] abstract class SparkSessionCompanion { + private[sql] type Session <: SparkSession + + /** + * Changes the SparkSession that will be returned in this thread and its children when + * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives + * a SparkSession with an isolated session, instead of the global (first created) context. + * + * @since 2.0.0 + */ + def setActiveSession(session: Session): Unit + + /** + * Clears the active SparkSession for current thread. Subsequent calls to getOrCreate will + * return the first created context instead of a thread-local override. + * + * @since 2.0.0 + */ + def clearActiveSession(): Unit + + /** + * Sets the default SparkSession that is returned by the builder. 
+ * + * @since 2.0.0 + */ + def setDefaultSession(session: Session): Unit + + /** + * Clears the default SparkSession that is returned by the builder. + * + * @since 2.0.0 + */ + def clearDefaultSession(): Unit + + /** + * Returns the active SparkSession for the current thread, returned by the builder. + * + * @note + * Return None, when calling this function on executors + * + * @since 2.2.0 + */ + def getActiveSession: Option[Session] + + /** + * Returns the default SparkSession that is returned by the builder. + * + * @note + * Return None, when calling this function on executors + * + * @since 2.2.0 + */ + def getDefaultSession: Option[Session] + + /** + * Returns the currently active SparkSession, otherwise the default one. If there is no default + * SparkSession, throws an exception. + * + * @since 2.4.0 + */ + def active: Session = { + getActiveSession.getOrElse( + getDefaultSession.getOrElse( + throw SparkException.internalError("No active or default Spark session found"))) + } /** * Creates a [[SparkSessionBuilder]] for constructing a [[SparkSession]]. @@ -690,6 +786,83 @@ private[sql] abstract class SparkSessionCompanion { def builder(): SparkSessionBuilder } +/** + * Abstract class for [[SparkSession]] companions. This implements active and default session + * management. + */ +private[sql] abstract class BaseSparkSessionCompanion extends SparkSessionCompanion { + + /** The active SparkSession for the current thread. */ + private val activeThreadSession = new InheritableThreadLocal[Session] + + /** Reference to the root SparkSession. */ + private val defaultSession = new AtomicReference[Session] + + /** @inheritdoc */ + def setActiveSession(session: Session): Unit = { + activeThreadSession.set(session) + } + + /** @inheritdoc */ + def clearActiveSession(): Unit = { + activeThreadSession.remove() + } + + /** @inheritdoc */ + def setDefaultSession(session: Session): Unit = { + defaultSession.set(session) + } + + /** @inheritdoc */ + def clearDefaultSession(): Unit = { + defaultSession.set(null.asInstanceOf[Session]) + } + + /** @inheritdoc */ + def getActiveSession: Option[Session] = usableSession(activeThreadSession.get()) + + /** @inheritdoc */ + def getDefaultSession: Option[Session] = usableSession(defaultSession.get()) + + private def usableSession(session: Session): Option[Session] = { + if ((session ne null) && canUseSession(session)) { + Some(session) + } else { + None + } + } + + protected def canUseSession(session: Session): Boolean = session.isUsable + + /** + * Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when + * they are not set yet or they are not usable. + */ + protected def setDefaultAndActiveSession(session: Session): Unit = { + val currentDefault = defaultSession.getAcquire + if (currentDefault == null || !currentDefault.isUsable) { + // Update `defaultSession` if it is null or the contained session is not usable. There is a + // chance that the following `compareAndSet` fails if a new default session has just been set, + // but that does not matter since that event has happened after this method was invoked. + defaultSession.compareAndSet(currentDefault, session) + } + val active = getActiveSession + if (active.isEmpty || !active.get.isUsable) { + setActiveSession(session) + } + } + + /** + * When the session is closed remove it from active and default. 
+ */ + private[sql] def onSessionClose(session: Session): Unit = { + defaultSession.compareAndSet(session, null.asInstanceOf[Session]) + if (getActiveSession.contains(session)) { + clearActiveSession() + } + } +} + /** * Builder for [[SparkSession]]. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 2d485c4ef321d..55525380aee55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -21,7 +21,7 @@ import java.net.URI import java.nio.file.Paths import java.util.{ServiceLoader, UUID} import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} +import java.util.concurrent.atomic.AtomicBoolean import scala.concurrent.duration.DurationInt import scala.jdk.CollectionConverters._ @@ -743,7 +743,7 @@ class SparkSession private( // Use the active session thread local directly to make sure we get the session that is actually // set and not the default session. This to prevent that we promote the default session to the // active session once we are done. - val old = SparkSession.activeThreadSession.get() + val old = SparkSession.getActiveSession.orNull SparkSession.setActiveSession(this) try block finally { SparkSession.setActiveSession(old) @@ -774,11 +774,14 @@ class SparkSession private( } private[sql] lazy val observationManager = new ObservationManager(this) + + override private[sql] def isUsable: Boolean = !sparkContext.isStopped } @Stable -object SparkSession extends api.SparkSessionCompanion with Logging { +object SparkSession extends api.BaseSparkSessionCompanion with Logging { + override private[sql] type Session = SparkSession /** * Builder for [[SparkSession]]. @@ -862,28 +865,22 @@ object SparkSession extends api.SparkSessionCompanion with Logging { assertOnDriver() } - def clearSessionIfDead(session: SparkSession): SparkSession = { - if ((session ne null) && !session.sparkContext.isStopped) { - session - } else { - null - } - } - // Get the session from current thread's active session. - val active = clearSessionIfDead(activeThreadSession.get()) - if (!forceCreate && (active ne null)) { - applyModifiableSettings(active, new java.util.HashMap[String, String](options.asJava)) - return active + val active = getActiveSession + if (!forceCreate && active.isDefined) { + val session = active.get + applyModifiableSettings(session, new java.util.HashMap[String, String](options.asJava)) + return session } // Global synchronization so we will only set the default session once. SparkSession.synchronized { // If the current thread does not have an active session, get it from the global session. - val default = clearSessionIfDead(defaultSession.get()) - if (!forceCreate && (default ne null)) { - applyModifiableSettings(default, new java.util.HashMap[String, String](options.asJava)) - return default + val default = getDefaultSession + if (!forceCreate && default.isDefined) { + val session = default.get + applyModifiableSettings(session, new java.util.HashMap[String, String](options.asJava)) + return session } // No active nor global default session. Create a new one. 
@@ -906,12 +903,7 @@ object SparkSession extends api.SparkSessionCompanion with Logging { extensions, initialSessionOptions = options.toMap, parentManagedJobTags = Map.empty) - if (default eq null) { - setDefaultSession(session) - } - if (active eq null) { - setActiveSession(session) - } + setDefaultAndActiveSession(session) registerContextListener(sparkContext) session } @@ -931,87 +923,17 @@ object SparkSession extends api.SparkSessionCompanion with Logging { */ def builder(): Builder = new Builder - /** - * Changes the SparkSession that will be returned in this thread and its children when - * SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives - * a SparkSession with an isolated session, instead of the global (first created) context. - * - * @since 2.0.0 - */ - def setActiveSession(session: SparkSession): Unit = { - activeThreadSession.set(session) - } - - /** - * Clears the active SparkSession for current thread. Subsequent calls to getOrCreate will - * return the first created context instead of a thread-local override. - * - * @since 2.0.0 - */ - def clearActiveSession(): Unit = { - activeThreadSession.remove() - } - - /** - * Sets the default SparkSession that is returned by the builder. - * - * @since 2.0.0 - */ - def setDefaultSession(session: SparkSession): Unit = { - defaultSession.set(session) - } - - /** - * Clears the default SparkSession that is returned by the builder. - * - * @since 2.0.0 - */ - def clearDefaultSession(): Unit = { - defaultSession.set(null) - } + /** @inheritdoc */ + override def getActiveSession: Option[SparkSession] = super.getActiveSession - /** - * Returns the active SparkSession for the current thread, returned by the builder. - * - * @note Return None, when calling this function on executors - * - * @since 2.2.0 - */ - def getActiveSession: Option[SparkSession] = { - if (Utils.isInRunningSparkTask) { - // Return None when running on executors. - None - } else { - Option(activeThreadSession.get) - } - } + /** @inheritdoc */ + override def getDefaultSession: Option[SparkSession] = super.getDefaultSession - /** - * Returns the default SparkSession that is returned by the builder. - * - * @note Return None, when calling this function on executors - * - * @since 2.2.0 - */ - def getDefaultSession: Option[SparkSession] = { - if (Utils.isInRunningSparkTask) { - // Return None when running on executors. - None - } else { - Option(defaultSession.get) - } - } + /** @inheritdoc */ + override def active: SparkSession = super.active - /** - * Returns the currently active SparkSession, otherwise the default one. If there is no default - * SparkSession, throws an exception. - * - * @since 2.4.0 - */ - def active: SparkSession = { - getActiveSession.getOrElse(getDefaultSession.getOrElse( - throw SparkException.internalError("No active or default Spark session found"))) - } + override protected def canUseSession(session: SparkSession): Boolean = + session.isUsable && !Utils.isInRunningSparkTask /** * Apply modifiable settings to an existing [[SparkSession]]. 
This method are used @@ -1082,7 +1004,8 @@ object SparkSession extends api.SparkSessionCompanion with Logging { if (!listenerRegistered.get()) { sparkContext.addSparkListener(new SparkListener { override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { - defaultSession.set(null) + clearDefaultSession() + clearActiveSession() listenerRegistered.set(false) } }) @@ -1090,12 +1013,6 @@ object SparkSession extends api.SparkSessionCompanion with Logging { } } - /** The active SparkSession for the current thread. */ - private val activeThreadSession = new InheritableThreadLocal[SparkSession] - - /** Reference to the root SparkSession. */ - private val defaultSession = new AtomicReference[SparkSession] - private val HIVE_SESSION_STATE_BUILDER_CLASS_NAME = "org.apache.spark.sql.hive.HiveSessionStateBuilder"

From b1ff7672cba12750d41d803f0faeb3487d934601 Mon Sep 17 00:00:00 2001
From: Julek Sompolski
Date: Wed, 9 Oct 2024 21:34:16 +0800
Subject: [PATCH 193/250] [SPARK-49857][SQL] Add storageLevel to Dataset localCheckpoint API

### What changes were proposed in this pull request?

Currently, when running `Dataset.localCheckpoint(eager = true)`, it is impossible to specify a non-default StorageLevel for the checkpoint. On the other hand, it is possible for the Dataset cache by using `Dataset.persist(newLevel: StorageLevel)`. If one wants to specify a non-default StorageLevel for localCheckpoint, it currently requires accessing the plan, changing the level, and then triggering an action to materialize the checkpoint:
```
// start lazy
val checkpointDf = df.localCheckpoint(eager = false)
// fish out the RDD
val checkpointedPlan = checkpointDf.queryExecution.analyzed
val rdd = checkpointedPlan.asInstanceOf[LogicalRDD].rdd
// change the StorageLevel
rdd.persist(StorageLevel.DISK_ONLY)
// force materialization
checkpointDf
  .mapPartitions(_ => Iterator.empty.asInstanceOf[Iterator[Row]])
  .foreach((_: Row) => ())
```
There are several issues with this:
1. It won't work with Connect, because we don't have access to RDD internals there.
2. A lazy checkpoint is not in fact lazy when AQE is involved. To get the RDD of a lazy checkpoint, AQE will actually trigger execution of all the query stages except the result stage in order to obtain the final plan. So the `start lazy` phase already executes everything except the final stage, and then `force materialization` only executes the result stage. This is "unexpected" and makes it more difficult to debug, first showing a query with missing metrics for the final stage, and then another query that skipped everything and only ran the final stage.

Having an API to specify the storageLevel for localCheckpoint will help avoid such hacks. As a precedent, it is already possible to specify a StorageLevel for the Dataset cache by using `Dataset.persist(newLevel: StorageLevel)`. In this PR, I implement this API for Scala and Python, in both Classic and Connect.

### Why are the changes needed?

https://github.com/delta-io/delta/blob/master/spark/src/main/scala/org/apache/spark/sql/delta/commands/merge/MergeIntoMaterializeSource.scala in `prepareMergeSource` has to resort to the hack described above to use localCheckpoint with a non-default StorageLevel. It is fragile, and it is confusing that it records two separate executions as described above.

### Does this PR introduce _any_ user-facing change?

Yes. Adds an API to pass `storageLevel` to Dataset `localCheckpoint`; see the usage sketch below.

### How was this patch tested?

Tests added.

### Was this patch authored or co-authored using generative AI tooling?
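A minimal usage sketch of the new API, mirroring the test added in this patch; `df` here stands for any existing Dataset or DataFrame:

```scala
import org.apache.spark.storage.StorageLevel

// Eagerly materialize a local checkpoint, keeping the checkpointed data on disk only
// instead of the default storage level used for local checkpoints.
val checkpointed = df.localCheckpoint(eager = true, storageLevel = StorageLevel.DISK_ONLY)

// The same call as exercised by the new Connect client test:
spark.range(100).localCheckpoint(eager = true, storageLevel = StorageLevel.DISK_ONLY).collect()
```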
Generated-by: Github Copilot (trivial code completions) Closes #48324 from juliuszsompolski/SPARK-49857. Authored-by: Julek Sompolski Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/sql/Dataset.scala | 15 ++++++- .../apache/spark/sql/CheckpointSuite.scala | 15 +++++-- python/pyspark/sql/classic/dataframe.py | 9 +++- python/pyspark/sql/connect/dataframe.py | 6 ++- python/pyspark/sql/connect/plan.py | 22 ++++++---- .../pyspark/sql/connect/proto/commands_pb2.py | 14 +++---- .../sql/connect/proto/commands_pb2.pyi | 29 ++++++++++++- python/pyspark/sql/dataframe.py | 9 +++- python/pyspark/sql/tests/test_dataframe.py | 8 +++- .../org/apache/spark/sql/api/Dataset.scala | 42 ++++++++++++++++--- .../protobuf/spark/connect/commands.proto | 3 ++ .../connect/planner/SparkConnectPlanner.scala | 15 +++++-- .../scala/org/apache/spark/sql/Dataset.scala | 11 ++++- .../org/apache/spark/sql/DatasetSuite.scala | 20 +++++++++ 14 files changed, 181 insertions(+), 37 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala index 966b5acebca23..adbfda9691508 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1115,13 +1115,20 @@ class Dataset[T] private[sql] ( } /** @inheritdoc */ - protected def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = { + protected def checkpoint( + eager: Boolean, + reliableCheckpoint: Boolean, + storageLevel: Option[StorageLevel]): Dataset[T] = { sparkSession.newDataset(agnosticEncoder) { builder => val command = sparkSession.newCommand { builder => - builder.getCheckpointCommandBuilder + val checkpointBuilder = builder.getCheckpointCommandBuilder .setLocal(!reliableCheckpoint) .setEager(eager) .setRelation(this.plan.getRoot) + storageLevel.foreach { storageLevel => + checkpointBuilder.setStorageLevel( + StorageLevelProtoConverter.toConnectProtoType(storageLevel)) + } } val responseIter = sparkSession.execute(command) try { @@ -1304,6 +1311,10 @@ class Dataset[T] private[sql] ( /** @inheritdoc */ override def localCheckpoint(eager: Boolean): Dataset[T] = super.localCheckpoint(eager) + /** @inheritdoc */ + override def localCheckpoint(eager: Boolean, storageLevel: StorageLevel): Dataset[T] = + super.localCheckpoint(eager, storageLevel) + /** @inheritdoc */ override def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = super.joinWith(other, condition) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala index e57b051890f56..0d9685d9c710f 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CheckpointSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.sql.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper} +import org.apache.spark.storage.StorageLevel class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelper { @@ -50,12 +51,20 @@ class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHe checkFragments(captureStdOut(block), fragmentsToCheck) } - 
test("checkpoint") { + test("localCheckpoint") { val df = spark.range(100).localCheckpoint() testCapturedStdOut(df.explain(), "ExistingRDD") } - test("checkpoint gc") { + test("localCheckpoint with StorageLevel") { + // We don't have a way to reach into the server and assert the storage level server side, but + // this test should cover for unexpected errors in the API. + val df = + spark.range(100).localCheckpoint(eager = true, storageLevel = StorageLevel.DISK_ONLY) + df.collect() + } + + test("localCheckpoint gc") { val df = spark.range(100).localCheckpoint(eager = true) val encoder = df.agnosticEncoder val dfId = df.plan.getRoot.getCachedRemoteRelation.getRelationId @@ -77,7 +86,7 @@ class CheckpointSuite extends ConnectFunSuite with RemoteSparkSession with SQLHe // This test is flaky because cannot guarantee GC // You can locally run this to verify the behavior. - ignore("checkpoint gc derived DataFrame") { + ignore("localCheckpoint gc derived DataFrame") { var df1 = spark.range(100).localCheckpoint(eager = true) var derived = df1.repartition(10) val encoder = df1.agnosticEncoder diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py index e412b98c47de5..91dec609e522a 100644 --- a/python/pyspark/sql/classic/dataframe.py +++ b/python/pyspark/sql/classic/dataframe.py @@ -360,8 +360,13 @@ def checkpoint(self, eager: bool = True) -> ParentDataFrame: jdf = self._jdf.checkpoint(eager) return DataFrame(jdf, self.sparkSession) - def localCheckpoint(self, eager: bool = True) -> ParentDataFrame: - jdf = self._jdf.localCheckpoint(eager) + def localCheckpoint( + self, eager: bool = True, storageLevel: Optional[StorageLevel] = None + ) -> ParentDataFrame: + if storageLevel is None: + jdf = self._jdf.localCheckpoint(eager) + else: + jdf = self._jdf.localCheckpoint(eager, self._sc._getJavaStorageLevel(storageLevel)) return DataFrame(jdf, self.sparkSession) def withWatermark(self, eventTime: str, delayThreshold: str) -> ParentDataFrame: diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index bb4dcb38c9e58..3d5b845fcd24c 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -2134,8 +2134,10 @@ def checkpoint(self, eager: bool = True) -> ParentDataFrame: assert isinstance(checkpointed._plan, plan.CachedRemoteRelation) return checkpointed - def localCheckpoint(self, eager: bool = True) -> ParentDataFrame: - cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager) + def localCheckpoint( + self, eager: bool = True, storageLevel: Optional[StorageLevel] = None + ) -> ParentDataFrame: + cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager, storage_level=storageLevel) _, properties, self._execution_info = self._session.client.execute_command( cmd.command(self._session.client) ) diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index fbed0eabc684f..b74f863db1e83 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -1868,21 +1868,29 @@ def command(self, session: "SparkConnectClient") -> proto.Command: class Checkpoint(LogicalPlan): - def __init__(self, child: Optional["LogicalPlan"], local: bool, eager: bool) -> None: + def __init__( + self, + child: Optional["LogicalPlan"], + local: bool, + eager: bool, + storage_level: Optional[StorageLevel] = None, + ) -> None: super().__init__(child) self._local = local self._eager = eager + self._storage_level = storage_level def command(self, session: 
"SparkConnectClient") -> proto.Command: cmd = proto.Command() assert self._child is not None - cmd.checkpoint_command.CopyFrom( - proto.CheckpointCommand( - relation=self._child.plan(session), - local=self._local, - eager=self._eager, - ) + checkpoint_command = proto.CheckpointCommand( + relation=self._child.plan(session), + local=self._local, + eager=self._eager, ) + if self._storage_level is not None: + checkpoint_command.storage_level.CopyFrom(storage_level_to_proto(self._storage_level)) + cmd.checkpoint_command.CopyFrom(checkpoint_command) return cmd diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py index 43390ffa36d33..562e9d817f5fe 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.py +++ b/python/pyspark/sql/connect/proto/commands_pb2.py @@ -35,7 +35,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\x90\r\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12<\n\x0bsql_command\x18\x05 \x01(\x0b\x32\x19.spark.connect.SqlCommandH\x00R\nsqlCommand\x12k\n\x1cwrite_stream_operation_start\x18\x06 \x01(\x0b\x32(.spark.connect.WriteStreamOperationStartH\x00R\x19writeStreamOperationStart\x12^\n\x17streaming_query_command\x18\x07 \x01(\x0b\x32$.spark.connect.StreamingQueryCommandH\x00R\x15streamingQueryCommand\x12X\n\x15get_resources_command\x18\x08 \x01(\x0b\x32".spark.connect.GetResourcesCommandH\x00R\x13getResourcesCommand\x12t\n\x1fstreaming_query_manager_command\x18\t \x01(\x0b\x32+.spark.connect.StreamingQueryManagerCommandH\x00R\x1cstreamingQueryManagerCommand\x12m\n\x17register_table_function\x18\n \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R\x15registerTableFunction\x12\x81\x01\n$streaming_query_listener_bus_command\x18\x0b \x01(\x0b\x32/.spark.connect.StreamingQueryListenerBusCommandH\x00R streamingQueryListenerBusCommand\x12\x64\n\x14register_data_source\x18\x0c \x01(\x0b\x32\x30.spark.connect.CommonInlineUserDefinedDataSourceH\x00R\x12registerDataSource\x12t\n\x1f\x63reate_resource_profile_command\x18\r \x01(\x0b\x32+.spark.connect.CreateResourceProfileCommandH\x00R\x1c\x63reateResourceProfileCommand\x12Q\n\x12\x63heckpoint_command\x18\x0e \x01(\x0b\x32 .spark.connect.CheckpointCommandH\x00R\x11\x63heckpointCommand\x12\x84\x01\n%remove_cached_remote_relation_command\x18\x0f \x01(\x0b\x32\x30.spark.connect.RemoveCachedRemoteRelationCommandH\x00R!removeCachedRemoteRelationCommand\x12_\n\x18merge_into_table_command\x18\x10 \x01(\x0b\x32$.spark.connect.MergeIntoTableCommandH\x00R\x15mergeIntoTableCommand\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\xaa\x04\n\nSqlCommand\x12\x14\n\x03sql\x18\x01 \x01(\tB\x02\x18\x01R\x03sql\x12;\n\x04\x61rgs\x18\x02 \x03(\x0b\x32#.spark.connect.SqlCommand.ArgsEntryB\x02\x18\x01R\x04\x61rgs\x12@\n\x08pos_args\x18\x03 
\x03(\x0b\x32!.spark.connect.Expression.LiteralB\x02\x18\x01R\x07posArgs\x12Z\n\x0fnamed_arguments\x18\x04 \x03(\x0b\x32-.spark.connect.SqlCommand.NamedArgumentsEntryB\x02\x18\x01R\x0enamedArguments\x12\x42\n\rpos_arguments\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionB\x02\x18\x01R\x0cposArguments\x12-\n\x05input\x18\x06 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01\x1a\\\n\x13NamedArgumentsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value:\x02\x38\x01"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\xca\x08\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1b\n\x06source\x18\x02 \x01(\tH\x01R\x06source\x88\x01\x01\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12?\n\x05table\x18\x04 \x01(\x0b\x32\'.spark.connect.WriteOperation.SaveTableH\x00R\x05table\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x12-\n\x12\x63lustering_columns\x18\n \x03(\tR\x11\x63lusteringColumns\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x82\x02\n\tSaveTable\x12\x1d\n\ntable_name\x18\x01 \x01(\tR\ttableName\x12X\n\x0bsave_method\x18\x02 \x01(\x0e\x32\x37.spark.connect.WriteOperation.SaveTable.TableSaveMethodR\nsaveMethod"|\n\x0fTableSaveMethod\x12!\n\x1dTABLE_SAVE_METHOD_UNSPECIFIED\x10\x00\x12#\n\x1fTABLE_SAVE_METHOD_SAVE_AS_TABLE\x10\x01\x12!\n\x1dTABLE_SAVE_METHOD_INSERT_INTO\x10\x02\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_typeB\t\n\x07_source"\xdc\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1f\n\x08provider\x18\x03 \x01(\tH\x00R\x08provider\x88\x01\x01\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x12-\n\x12\x63lustering_columns\x18\t \x03(\tR\x11\x63lusteringColumns\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 
\x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42\x0b\n\t_provider"\xd8\x06\n\x19WriteStreamOperationStart\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06\x66ormat\x18\x02 \x01(\tR\x06\x66ormat\x12O\n\x07options\x18\x03 \x03(\x0b\x32\x35.spark.connect.WriteStreamOperationStart.OptionsEntryR\x07options\x12:\n\x19partitioning_column_names\x18\x04 \x03(\tR\x17partitioningColumnNames\x12:\n\x18processing_time_interval\x18\x05 \x01(\tH\x00R\x16processingTimeInterval\x12%\n\ravailable_now\x18\x06 \x01(\x08H\x00R\x0c\x61vailableNow\x12\x14\n\x04once\x18\x07 \x01(\x08H\x00R\x04once\x12\x46\n\x1e\x63ontinuous_checkpoint_interval\x18\x08 \x01(\tH\x00R\x1c\x63ontinuousCheckpointInterval\x12\x1f\n\x0boutput_mode\x18\t \x01(\tR\noutputMode\x12\x1d\n\nquery_name\x18\n \x01(\tR\tqueryName\x12\x14\n\x04path\x18\x0b \x01(\tH\x01R\x04path\x12\x1f\n\ntable_name\x18\x0c \x01(\tH\x01R\ttableName\x12N\n\x0e\x66oreach_writer\x18\r \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\rforeachWriter\x12L\n\rforeach_batch\x18\x0e \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\x0c\x66oreachBatch\x12\x36\n\x17\x63lustering_column_names\x18\x0f \x03(\tR\x15\x63lusteringColumnNames\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07triggerB\x12\n\x10sink_destination"\xb3\x01\n\x18StreamingForeachFunction\x12\x43\n\x0fpython_function\x18\x01 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x0epythonFunction\x12\x46\n\x0escala_function\x18\x02 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\rscalaFunctionB\n\n\x08\x66unction"\xd4\x01\n\x1fWriteStreamOperationStartResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12<\n\x18query_started_event_json\x18\x03 \x01(\tH\x00R\x15queryStartedEventJson\x88\x01\x01\x42\x1b\n\x19_query_started_event_json"A\n\x18StreamingQueryInstanceId\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x15\n\x06run_id\x18\x02 \x01(\tR\x05runId"\xf8\x04\n\x15StreamingQueryCommand\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x18\n\x06status\x18\x02 \x01(\x08H\x00R\x06status\x12%\n\rlast_progress\x18\x03 \x01(\x08H\x00R\x0clastProgress\x12)\n\x0frecent_progress\x18\x04 \x01(\x08H\x00R\x0erecentProgress\x12\x14\n\x04stop\x18\x05 \x01(\x08H\x00R\x04stop\x12\x34\n\x15process_all_available\x18\x06 \x01(\x08H\x00R\x13processAllAvailable\x12O\n\x07\x65xplain\x18\x07 \x01(\x0b\x32\x33.spark.connect.StreamingQueryCommand.ExplainCommandH\x00R\x07\x65xplain\x12\x1e\n\texception\x18\x08 \x01(\x08H\x00R\texception\x12k\n\x11\x61wait_termination\x18\t \x01(\x0b\x32<.spark.connect.StreamingQueryCommand.AwaitTerminationCommandH\x00R\x10\x61waitTermination\x1a,\n\x0e\x45xplainCommand\x12\x1a\n\x08\x65xtended\x18\x01 \x01(\x08R\x08\x65xtended\x1aL\n\x17\x41waitTerminationCommand\x12"\n\ntimeout_ms\x18\x02 
\x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_msB\t\n\x07\x63ommand"\xf5\x08\n\x1bStreamingQueryCommandResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12Q\n\x06status\x18\x02 \x01(\x0b\x32\x37.spark.connect.StreamingQueryCommandResult.StatusResultH\x00R\x06status\x12j\n\x0frecent_progress\x18\x03 \x01(\x0b\x32?.spark.connect.StreamingQueryCommandResult.RecentProgressResultH\x00R\x0erecentProgress\x12T\n\x07\x65xplain\x18\x04 \x01(\x0b\x32\x38.spark.connect.StreamingQueryCommandResult.ExplainResultH\x00R\x07\x65xplain\x12Z\n\texception\x18\x05 \x01(\x0b\x32:.spark.connect.StreamingQueryCommandResult.ExceptionResultH\x00R\texception\x12p\n\x11\x61wait_termination\x18\x06 \x01(\x0b\x32\x41.spark.connect.StreamingQueryCommandResult.AwaitTerminationResultH\x00R\x10\x61waitTermination\x1a\xaa\x01\n\x0cStatusResult\x12%\n\x0estatus_message\x18\x01 \x01(\tR\rstatusMessage\x12*\n\x11is_data_available\x18\x02 \x01(\x08R\x0fisDataAvailable\x12*\n\x11is_trigger_active\x18\x03 \x01(\x08R\x0fisTriggerActive\x12\x1b\n\tis_active\x18\x04 \x01(\x08R\x08isActive\x1aH\n\x14RecentProgressResult\x12\x30\n\x14recent_progress_json\x18\x05 \x03(\tR\x12recentProgressJson\x1a\'\n\rExplainResult\x12\x16\n\x06result\x18\x01 \x01(\tR\x06result\x1a\xc5\x01\n\x0f\x45xceptionResult\x12\x30\n\x11\x65xception_message\x18\x01 \x01(\tH\x00R\x10\x65xceptionMessage\x88\x01\x01\x12$\n\x0b\x65rror_class\x18\x02 \x01(\tH\x01R\nerrorClass\x88\x01\x01\x12$\n\x0bstack_trace\x18\x03 \x01(\tH\x02R\nstackTrace\x88\x01\x01\x42\x14\n\x12_exception_messageB\x0e\n\x0c_error_classB\x0e\n\x0c_stack_trace\x1a\x38\n\x16\x41waitTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminatedB\r\n\x0bresult_type"\xbd\x06\n\x1cStreamingQueryManagerCommand\x12\x18\n\x06\x61\x63tive\x18\x01 \x01(\x08H\x00R\x06\x61\x63tive\x12\x1d\n\tget_query\x18\x02 \x01(\tH\x00R\x08getQuery\x12|\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32\x46.spark.connect.StreamingQueryManagerCommand.AwaitAnyTerminationCommandH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12n\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0b\x61\x64\x64Listener\x12t\n\x0fremove_listener\x18\x06 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0eremoveListener\x12\'\n\x0elist_listeners\x18\x07 \x01(\x08H\x00R\rlistListeners\x1aO\n\x1a\x41waitAnyTerminationCommand\x12"\n\ntimeout_ms\x18\x01 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_ms\x1a\xcd\x01\n\x1dStreamingQueryListenerCommand\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x12U\n\x17python_listener_payload\x18\x02 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x15pythonListenerPayload\x88\x01\x01\x12\x0e\n\x02id\x18\x03 \x01(\tR\x02idB\x1a\n\x18_python_listener_payloadB\t\n\x07\x63ommand"\xb4\x08\n"StreamingQueryManagerCommandResult\x12X\n\x06\x61\x63tive\x18\x01 \x01(\x0b\x32>.spark.connect.StreamingQueryManagerCommandResult.ActiveResultH\x00R\x06\x61\x63tive\x12`\n\x05query\x18\x02 \x01(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceH\x00R\x05query\x12\x81\x01\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32K.spark.connect.StreamingQueryManagerCommandResult.AwaitAnyTerminationResultH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 
\x01(\x08H\x00R\x0fresetTerminated\x12#\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x08H\x00R\x0b\x61\x64\x64Listener\x12)\n\x0fremove_listener\x18\x06 \x01(\x08H\x00R\x0eremoveListener\x12{\n\x0elist_listeners\x18\x07 \x01(\x0b\x32R.spark.connect.StreamingQueryManagerCommandResult.ListStreamingQueryListenerResultH\x00R\rlistListeners\x1a\x7f\n\x0c\x41\x63tiveResult\x12o\n\x0e\x61\x63tive_queries\x18\x01 \x03(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceR\ractiveQueries\x1as\n\x16StreamingQueryInstance\x12\x37\n\x02id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x02id\x12\x17\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x88\x01\x01\x42\x07\n\x05_name\x1a;\n\x19\x41waitAnyTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminated\x1aK\n\x1eStreamingQueryListenerInstance\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x1a\x45\n ListStreamingQueryListenerResult\x12!\n\x0clistener_ids\x18\x01 \x03(\tR\x0blistenerIdsB\r\n\x0bresult_type"\xad\x01\n StreamingQueryListenerBusCommand\x12;\n\x19\x61\x64\x64_listener_bus_listener\x18\x01 \x01(\x08H\x00R\x16\x61\x64\x64ListenerBusListener\x12\x41\n\x1cremove_listener_bus_listener\x18\x02 \x01(\x08H\x00R\x19removeListenerBusListenerB\t\n\x07\x63ommand"\x83\x01\n\x1bStreamingQueryListenerEvent\x12\x1d\n\nevent_json\x18\x01 \x01(\tR\teventJson\x12\x45\n\nevent_type\x18\x02 \x01(\x0e\x32&.spark.connect.StreamingQueryEventTypeR\teventType"\xcc\x01\n"StreamingQueryListenerEventsResult\x12\x42\n\x06\x65vents\x18\x01 \x03(\x0b\x32*.spark.connect.StreamingQueryListenerEventR\x06\x65vents\x12\x42\n\x1blistener_bus_listener_added\x18\x02 \x01(\x08H\x00R\x18listenerBusListenerAdded\x88\x01\x01\x42\x1e\n\x1c_listener_bus_listener_added"\x15\n\x13GetResourcesCommand"\xd4\x01\n\x19GetResourcesCommandResult\x12U\n\tresources\x18\x01 \x03(\x0b\x32\x37.spark.connect.GetResourcesCommandResult.ResourcesEntryR\tresources\x1a`\n\x0eResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.ResourceInformationR\x05value:\x02\x38\x01"X\n\x1c\x43reateResourceProfileCommand\x12\x38\n\x07profile\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ResourceProfileR\x07profile"C\n"CreateResourceProfileCommandResult\x12\x1d\n\nprofile_id\x18\x01 \x01(\x05R\tprofileId"d\n!RemoveCachedRemoteRelationCommand\x12?\n\x08relation\x18\x01 \x01(\x0b\x32#.spark.connect.CachedRemoteRelationR\x08relation"t\n\x11\x43heckpointCommand\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x14\n\x05local\x18\x02 \x01(\x08R\x05local\x12\x14\n\x05\x65\x61ger\x18\x03 \x01(\x08R\x05\x65\x61ger"\xe8\x03\n\x15MergeIntoTableCommand\x12*\n\x11target_table_name\x18\x01 \x01(\tR\x0ftargetTableName\x12\x43\n\x11source_table_plan\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x0fsourceTablePlan\x12\x42\n\x0fmerge_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0emergeCondition\x12>\n\rmatch_actions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cmatchActions\x12I\n\x13not_matched_actions\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x11notMatchedActions\x12[\n\x1dnot_matched_by_source_actions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x19notMatchedBySourceActions\x12\x32\n\x15with_schema_evolution\x18\x07 
\x01(\x08R\x13withSchemaEvolution*\x85\x01\n\x17StreamingQueryEventType\x12\x1e\n\x1aQUERY_PROGRESS_UNSPECIFIED\x10\x00\x12\x18\n\x14QUERY_PROGRESS_EVENT\x10\x01\x12\x1a\n\x16QUERY_TERMINATED_EVENT\x10\x02\x12\x14\n\x10QUERY_IDLE_EVENT\x10\x03\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto"\x90\r\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFrameViewCommandH\x00R\x13\x63reateDataframeView\x12O\n\x12write_operation_v2\x18\x04 \x01(\x0b\x32\x1f.spark.connect.WriteOperationV2H\x00R\x10writeOperationV2\x12<\n\x0bsql_command\x18\x05 \x01(\x0b\x32\x19.spark.connect.SqlCommandH\x00R\nsqlCommand\x12k\n\x1cwrite_stream_operation_start\x18\x06 \x01(\x0b\x32(.spark.connect.WriteStreamOperationStartH\x00R\x19writeStreamOperationStart\x12^\n\x17streaming_query_command\x18\x07 \x01(\x0b\x32$.spark.connect.StreamingQueryCommandH\x00R\x15streamingQueryCommand\x12X\n\x15get_resources_command\x18\x08 \x01(\x0b\x32".spark.connect.GetResourcesCommandH\x00R\x13getResourcesCommand\x12t\n\x1fstreaming_query_manager_command\x18\t \x01(\x0b\x32+.spark.connect.StreamingQueryManagerCommandH\x00R\x1cstreamingQueryManagerCommand\x12m\n\x17register_table_function\x18\n \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R\x15registerTableFunction\x12\x81\x01\n$streaming_query_listener_bus_command\x18\x0b \x01(\x0b\x32/.spark.connect.StreamingQueryListenerBusCommandH\x00R streamingQueryListenerBusCommand\x12\x64\n\x14register_data_source\x18\x0c \x01(\x0b\x32\x30.spark.connect.CommonInlineUserDefinedDataSourceH\x00R\x12registerDataSource\x12t\n\x1f\x63reate_resource_profile_command\x18\r \x01(\x0b\x32+.spark.connect.CreateResourceProfileCommandH\x00R\x1c\x63reateResourceProfileCommand\x12Q\n\x12\x63heckpoint_command\x18\x0e \x01(\x0b\x32 .spark.connect.CheckpointCommandH\x00R\x11\x63heckpointCommand\x12\x84\x01\n%remove_cached_remote_relation_command\x18\x0f \x01(\x0b\x32\x30.spark.connect.RemoveCachedRemoteRelationCommandH\x00R!removeCachedRemoteRelationCommand\x12_\n\x18merge_into_table_command\x18\x10 \x01(\x0b\x32$.spark.connect.MergeIntoTableCommandH\x00R\x15mergeIntoTableCommand\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x0e\n\x0c\x63ommand_type"\xaa\x04\n\nSqlCommand\x12\x14\n\x03sql\x18\x01 \x01(\tB\x02\x18\x01R\x03sql\x12;\n\x04\x61rgs\x18\x02 \x03(\x0b\x32#.spark.connect.SqlCommand.ArgsEntryB\x02\x18\x01R\x04\x61rgs\x12@\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralB\x02\x18\x01R\x07posArgs\x12Z\n\x0fnamed_arguments\x18\x04 \x03(\x0b\x32-.spark.connect.SqlCommand.NamedArgumentsEntryB\x02\x18\x01R\x0enamedArguments\x12\x42\n\rpos_arguments\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionB\x02\x18\x01R\x0cposArguments\x12-\n\x05input\x18\x06 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01\x1a\\\n\x13NamedArgumentsEntry\x12\x10\n\x03key\x18\x01 
\x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value:\x02\x38\x01"\x96\x01\n\x1a\x43reateDataFrameViewCommand\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x1b\n\tis_global\x18\x03 \x01(\x08R\x08isGlobal\x12\x18\n\x07replace\x18\x04 \x01(\x08R\x07replace"\xca\x08\n\x0eWriteOperation\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1b\n\x06source\x18\x02 \x01(\tH\x01R\x06source\x88\x01\x01\x12\x14\n\x04path\x18\x03 \x01(\tH\x00R\x04path\x12?\n\x05table\x18\x04 \x01(\x0b\x32\'.spark.connect.WriteOperation.SaveTableH\x00R\x05table\x12:\n\x04mode\x18\x05 \x01(\x0e\x32&.spark.connect.WriteOperation.SaveModeR\x04mode\x12*\n\x11sort_column_names\x18\x06 \x03(\tR\x0fsortColumnNames\x12\x31\n\x14partitioning_columns\x18\x07 \x03(\tR\x13partitioningColumns\x12\x43\n\tbucket_by\x18\x08 \x01(\x0b\x32&.spark.connect.WriteOperation.BucketByR\x08\x62ucketBy\x12\x44\n\x07options\x18\t \x03(\x0b\x32*.spark.connect.WriteOperation.OptionsEntryR\x07options\x12-\n\x12\x63lustering_columns\x18\n \x03(\tR\x11\x63lusteringColumns\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x82\x02\n\tSaveTable\x12\x1d\n\ntable_name\x18\x01 \x01(\tR\ttableName\x12X\n\x0bsave_method\x18\x02 \x01(\x0e\x32\x37.spark.connect.WriteOperation.SaveTable.TableSaveMethodR\nsaveMethod"|\n\x0fTableSaveMethod\x12!\n\x1dTABLE_SAVE_METHOD_UNSPECIFIED\x10\x00\x12#\n\x1fTABLE_SAVE_METHOD_SAVE_AS_TABLE\x10\x01\x12!\n\x1dTABLE_SAVE_METHOD_INSERT_INTO\x10\x02\x1a[\n\x08\x42ucketBy\x12.\n\x13\x62ucket_column_names\x18\x01 \x03(\tR\x11\x62ucketColumnNames\x12\x1f\n\x0bnum_buckets\x18\x02 \x01(\x05R\nnumBuckets"\x89\x01\n\x08SaveMode\x12\x19\n\x15SAVE_MODE_UNSPECIFIED\x10\x00\x12\x14\n\x10SAVE_MODE_APPEND\x10\x01\x12\x17\n\x13SAVE_MODE_OVERWRITE\x10\x02\x12\x1d\n\x19SAVE_MODE_ERROR_IF_EXISTS\x10\x03\x12\x14\n\x10SAVE_MODE_IGNORE\x10\x04\x42\x0b\n\tsave_typeB\t\n\x07_source"\xdc\x06\n\x10WriteOperationV2\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\ntable_name\x18\x02 \x01(\tR\ttableName\x12\x1f\n\x08provider\x18\x03 \x01(\tH\x00R\x08provider\x88\x01\x01\x12L\n\x14partitioning_columns\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13partitioningColumns\x12\x46\n\x07options\x18\x05 \x03(\x0b\x32,.spark.connect.WriteOperationV2.OptionsEntryR\x07options\x12_\n\x10table_properties\x18\x06 \x03(\x0b\x32\x34.spark.connect.WriteOperationV2.TablePropertiesEntryR\x0ftableProperties\x12\x38\n\x04mode\x18\x07 \x01(\x0e\x32$.spark.connect.WriteOperationV2.ModeR\x04mode\x12J\n\x13overwrite_condition\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x12overwriteCondition\x12-\n\x12\x63lustering_columns\x18\t \x03(\tR\x11\x63lusteringColumns\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"\x9f\x01\n\x04Mode\x12\x14\n\x10MODE_UNSPECIFIED\x10\x00\x12\x0f\n\x0bMODE_CREATE\x10\x01\x12\x12\n\x0eMODE_OVERWRITE\x10\x02\x12\x1d\n\x19MODE_OVERWRITE_PARTITIONS\x10\x03\x12\x0f\n\x0bMODE_APPEND\x10\x04\x12\x10\n\x0cMODE_REPLACE\x10\x05\x12\x1a\n\x16MODE_CREATE_OR_REPLACE\x10\x06\x42\x0b\n\t_provider"\xd8\x06\n\x19WriteStreamOperationStart\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06\x66ormat\x18\x02 \x01(\tR\x06\x66ormat\x12O\n\x07options\x18\x03 \x03(\x0b\x32\x35.spark.connect.WriteStreamOperationStart.OptionsEntryR\x07options\x12:\n\x19partitioning_column_names\x18\x04 \x03(\tR\x17partitioningColumnNames\x12:\n\x18processing_time_interval\x18\x05 \x01(\tH\x00R\x16processingTimeInterval\x12%\n\ravailable_now\x18\x06 \x01(\x08H\x00R\x0c\x61vailableNow\x12\x14\n\x04once\x18\x07 \x01(\x08H\x00R\x04once\x12\x46\n\x1e\x63ontinuous_checkpoint_interval\x18\x08 \x01(\tH\x00R\x1c\x63ontinuousCheckpointInterval\x12\x1f\n\x0boutput_mode\x18\t \x01(\tR\noutputMode\x12\x1d\n\nquery_name\x18\n \x01(\tR\tqueryName\x12\x14\n\x04path\x18\x0b \x01(\tH\x01R\x04path\x12\x1f\n\ntable_name\x18\x0c \x01(\tH\x01R\ttableName\x12N\n\x0e\x66oreach_writer\x18\r \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\rforeachWriter\x12L\n\rforeach_batch\x18\x0e \x01(\x0b\x32\'.spark.connect.StreamingForeachFunctionR\x0c\x66oreachBatch\x12\x36\n\x17\x63lustering_column_names\x18\x0f \x03(\tR\x15\x63lusteringColumnNames\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07triggerB\x12\n\x10sink_destination"\xb3\x01\n\x18StreamingForeachFunction\x12\x43\n\x0fpython_function\x18\x01 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x0epythonFunction\x12\x46\n\x0escala_function\x18\x02 \x01(\x0b\x32\x1d.spark.connect.ScalarScalaUDFH\x00R\rscalaFunctionB\n\n\x08\x66unction"\xd4\x01\n\x1fWriteStreamOperationStartResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12<\n\x18query_started_event_json\x18\x03 \x01(\tH\x00R\x15queryStartedEventJson\x88\x01\x01\x42\x1b\n\x19_query_started_event_json"A\n\x18StreamingQueryInstanceId\x12\x0e\n\x02id\x18\x01 \x01(\tR\x02id\x12\x15\n\x06run_id\x18\x02 \x01(\tR\x05runId"\xf8\x04\n\x15StreamingQueryCommand\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12\x18\n\x06status\x18\x02 \x01(\x08H\x00R\x06status\x12%\n\rlast_progress\x18\x03 \x01(\x08H\x00R\x0clastProgress\x12)\n\x0frecent_progress\x18\x04 \x01(\x08H\x00R\x0erecentProgress\x12\x14\n\x04stop\x18\x05 \x01(\x08H\x00R\x04stop\x12\x34\n\x15process_all_available\x18\x06 \x01(\x08H\x00R\x13processAllAvailable\x12O\n\x07\x65xplain\x18\x07 \x01(\x0b\x32\x33.spark.connect.StreamingQueryCommand.ExplainCommandH\x00R\x07\x65xplain\x12\x1e\n\texception\x18\x08 \x01(\x08H\x00R\texception\x12k\n\x11\x61wait_termination\x18\t \x01(\x0b\x32<.spark.connect.StreamingQueryCommand.AwaitTerminationCommandH\x00R\x10\x61waitTermination\x1a,\n\x0e\x45xplainCommand\x12\x1a\n\x08\x65xtended\x18\x01 \x01(\x08R\x08\x65xtended\x1aL\n\x17\x41waitTerminationCommand\x12"\n\ntimeout_ms\x18\x02 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_msB\t\n\x07\x63ommand"\xf5\x08\n\x1bStreamingQueryCommandResult\x12\x42\n\x08query_id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x07queryId\x12Q\n\x06status\x18\x02 \x01(\x0b\x32\x37.spark.connect.StreamingQueryCommandResult.StatusResultH\x00R\x06status\x12j\n\x0frecent_progress\x18\x03 \x01(\x0b\x32?.spark.connect.StreamingQueryCommandResult.RecentProgressResultH\x00R\x0erecentProgress\x12T\n\x07\x65xplain\x18\x04 \x01(\x0b\x32\x38.spark.connect.StreamingQueryCommandResult.ExplainResultH\x00R\x07\x65xplain\x12Z\n\texception\x18\x05 
\x01(\x0b\x32:.spark.connect.StreamingQueryCommandResult.ExceptionResultH\x00R\texception\x12p\n\x11\x61wait_termination\x18\x06 \x01(\x0b\x32\x41.spark.connect.StreamingQueryCommandResult.AwaitTerminationResultH\x00R\x10\x61waitTermination\x1a\xaa\x01\n\x0cStatusResult\x12%\n\x0estatus_message\x18\x01 \x01(\tR\rstatusMessage\x12*\n\x11is_data_available\x18\x02 \x01(\x08R\x0fisDataAvailable\x12*\n\x11is_trigger_active\x18\x03 \x01(\x08R\x0fisTriggerActive\x12\x1b\n\tis_active\x18\x04 \x01(\x08R\x08isActive\x1aH\n\x14RecentProgressResult\x12\x30\n\x14recent_progress_json\x18\x05 \x03(\tR\x12recentProgressJson\x1a\'\n\rExplainResult\x12\x16\n\x06result\x18\x01 \x01(\tR\x06result\x1a\xc5\x01\n\x0f\x45xceptionResult\x12\x30\n\x11\x65xception_message\x18\x01 \x01(\tH\x00R\x10\x65xceptionMessage\x88\x01\x01\x12$\n\x0b\x65rror_class\x18\x02 \x01(\tH\x01R\nerrorClass\x88\x01\x01\x12$\n\x0bstack_trace\x18\x03 \x01(\tH\x02R\nstackTrace\x88\x01\x01\x42\x14\n\x12_exception_messageB\x0e\n\x0c_error_classB\x0e\n\x0c_stack_trace\x1a\x38\n\x16\x41waitTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminatedB\r\n\x0bresult_type"\xbd\x06\n\x1cStreamingQueryManagerCommand\x12\x18\n\x06\x61\x63tive\x18\x01 \x01(\x08H\x00R\x06\x61\x63tive\x12\x1d\n\tget_query\x18\x02 \x01(\tH\x00R\x08getQuery\x12|\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32\x46.spark.connect.StreamingQueryManagerCommand.AwaitAnyTerminationCommandH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12n\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0b\x61\x64\x64Listener\x12t\n\x0fremove_listener\x18\x06 \x01(\x0b\x32I.spark.connect.StreamingQueryManagerCommand.StreamingQueryListenerCommandH\x00R\x0eremoveListener\x12\'\n\x0elist_listeners\x18\x07 \x01(\x08H\x00R\rlistListeners\x1aO\n\x1a\x41waitAnyTerminationCommand\x12"\n\ntimeout_ms\x18\x01 \x01(\x03H\x00R\ttimeoutMs\x88\x01\x01\x42\r\n\x0b_timeout_ms\x1a\xcd\x01\n\x1dStreamingQueryListenerCommand\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x12U\n\x17python_listener_payload\x18\x02 \x01(\x0b\x32\x18.spark.connect.PythonUDFH\x00R\x15pythonListenerPayload\x88\x01\x01\x12\x0e\n\x02id\x18\x03 \x01(\tR\x02idB\x1a\n\x18_python_listener_payloadB\t\n\x07\x63ommand"\xb4\x08\n"StreamingQueryManagerCommandResult\x12X\n\x06\x61\x63tive\x18\x01 \x01(\x0b\x32>.spark.connect.StreamingQueryManagerCommandResult.ActiveResultH\x00R\x06\x61\x63tive\x12`\n\x05query\x18\x02 \x01(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceH\x00R\x05query\x12\x81\x01\n\x15\x61wait_any_termination\x18\x03 \x01(\x0b\x32K.spark.connect.StreamingQueryManagerCommandResult.AwaitAnyTerminationResultH\x00R\x13\x61waitAnyTermination\x12+\n\x10reset_terminated\x18\x04 \x01(\x08H\x00R\x0fresetTerminated\x12#\n\x0c\x61\x64\x64_listener\x18\x05 \x01(\x08H\x00R\x0b\x61\x64\x64Listener\x12)\n\x0fremove_listener\x18\x06 \x01(\x08H\x00R\x0eremoveListener\x12{\n\x0elist_listeners\x18\x07 \x01(\x0b\x32R.spark.connect.StreamingQueryManagerCommandResult.ListStreamingQueryListenerResultH\x00R\rlistListeners\x1a\x7f\n\x0c\x41\x63tiveResult\x12o\n\x0e\x61\x63tive_queries\x18\x01 \x03(\x0b\x32H.spark.connect.StreamingQueryManagerCommandResult.StreamingQueryInstanceR\ractiveQueries\x1as\n\x16StreamingQueryInstance\x12\x37\n\x02id\x18\x01 \x01(\x0b\x32\'.spark.connect.StreamingQueryInstanceIdR\x02id\x12\x17\n\x04name\x18\x02 
\x01(\tH\x00R\x04name\x88\x01\x01\x42\x07\n\x05_name\x1a;\n\x19\x41waitAnyTerminationResult\x12\x1e\n\nterminated\x18\x01 \x01(\x08R\nterminated\x1aK\n\x1eStreamingQueryListenerInstance\x12)\n\x10listener_payload\x18\x01 \x01(\x0cR\x0flistenerPayload\x1a\x45\n ListStreamingQueryListenerResult\x12!\n\x0clistener_ids\x18\x01 \x03(\tR\x0blistenerIdsB\r\n\x0bresult_type"\xad\x01\n StreamingQueryListenerBusCommand\x12;\n\x19\x61\x64\x64_listener_bus_listener\x18\x01 \x01(\x08H\x00R\x16\x61\x64\x64ListenerBusListener\x12\x41\n\x1cremove_listener_bus_listener\x18\x02 \x01(\x08H\x00R\x19removeListenerBusListenerB\t\n\x07\x63ommand"\x83\x01\n\x1bStreamingQueryListenerEvent\x12\x1d\n\nevent_json\x18\x01 \x01(\tR\teventJson\x12\x45\n\nevent_type\x18\x02 \x01(\x0e\x32&.spark.connect.StreamingQueryEventTypeR\teventType"\xcc\x01\n"StreamingQueryListenerEventsResult\x12\x42\n\x06\x65vents\x18\x01 \x03(\x0b\x32*.spark.connect.StreamingQueryListenerEventR\x06\x65vents\x12\x42\n\x1blistener_bus_listener_added\x18\x02 \x01(\x08H\x00R\x18listenerBusListenerAdded\x88\x01\x01\x42\x1e\n\x1c_listener_bus_listener_added"\x15\n\x13GetResourcesCommand"\xd4\x01\n\x19GetResourcesCommandResult\x12U\n\tresources\x18\x01 \x03(\x0b\x32\x37.spark.connect.GetResourcesCommandResult.ResourcesEntryR\tresources\x1a`\n\x0eResourcesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x38\n\x05value\x18\x02 \x01(\x0b\x32".spark.connect.ResourceInformationR\x05value:\x02\x38\x01"X\n\x1c\x43reateResourceProfileCommand\x12\x38\n\x07profile\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ResourceProfileR\x07profile"C\n"CreateResourceProfileCommandResult\x12\x1d\n\nprofile_id\x18\x01 \x01(\x05R\tprofileId"d\n!RemoveCachedRemoteRelationCommand\x12?\n\x08relation\x18\x01 \x01(\x0b\x32#.spark.connect.CachedRemoteRelationR\x08relation"\xcd\x01\n\x11\x43heckpointCommand\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x14\n\x05local\x18\x02 \x01(\x08R\x05local\x12\x14\n\x05\x65\x61ger\x18\x03 \x01(\x08R\x05\x65\x61ger\x12\x45\n\rstorage_level\x18\x04 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level"\xe8\x03\n\x15MergeIntoTableCommand\x12*\n\x11target_table_name\x18\x01 \x01(\tR\x0ftargetTableName\x12\x43\n\x11source_table_plan\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x0fsourceTablePlan\x12\x42\n\x0fmerge_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x0emergeCondition\x12>\n\rmatch_actions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cmatchActions\x12I\n\x13not_matched_actions\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x11notMatchedActions\x12[\n\x1dnot_matched_by_source_actions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x19notMatchedBySourceActions\x12\x32\n\x15with_schema_evolution\x18\x07 \x01(\x08R\x13withSchemaEvolution*\x85\x01\n\x17StreamingQueryEventType\x12\x1e\n\x1aQUERY_PROGRESS_UNSPECIFIED\x10\x00\x12\x18\n\x14QUERY_PROGRESS_EVENT\x10\x01\x12\x1a\n\x16QUERY_TERMINATED_EVENT\x10\x02\x12\x14\n\x10QUERY_IDLE_EVENT\x10\x03\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -71,8 +71,8 @@ _WRITESTREAMOPERATIONSTART_OPTIONSENTRY._serialized_options = b"8\001" _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._options = None _GETRESOURCESCOMMANDRESULT_RESOURCESENTRY._serialized_options = b"8\001" - _STREAMINGQUERYEVENTTYPE._serialized_start = 11162 - _STREAMINGQUERYEVENTTYPE._serialized_end = 11295 + 
_STREAMINGQUERYEVENTTYPE._serialized_start = 11252 + _STREAMINGQUERYEVENTTYPE._serialized_end = 11385 _COMMAND._serialized_start = 167 _COMMAND._serialized_end = 1847 _SQLCOMMAND._serialized_start = 1850 @@ -167,8 +167,8 @@ _CREATERESOURCEPROFILECOMMANDRESULT._serialized_end = 10448 _REMOVECACHEDREMOTERELATIONCOMMAND._serialized_start = 10450 _REMOVECACHEDREMOTERELATIONCOMMAND._serialized_end = 10550 - _CHECKPOINTCOMMAND._serialized_start = 10552 - _CHECKPOINTCOMMAND._serialized_end = 10668 - _MERGEINTOTABLECOMMAND._serialized_start = 10671 - _MERGEINTOTABLECOMMAND._serialized_end = 11159 + _CHECKPOINTCOMMAND._serialized_start = 10553 + _CHECKPOINTCOMMAND._serialized_end = 10758 + _MERGEINTOTABLECOMMAND._serialized_start = 10761 + _MERGEINTOTABLECOMMAND._serialized_end = 11249 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi index 2dedcdfc8e3e4..6192a29607cbf 100644 --- a/python/pyspark/sql/connect/proto/commands_pb2.pyi +++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi @@ -2188,6 +2188,7 @@ class CheckpointCommand(google.protobuf.message.Message): RELATION_FIELD_NUMBER: builtins.int LOCAL_FIELD_NUMBER: builtins.int EAGER_FIELD_NUMBER: builtins.int + STORAGE_LEVEL_FIELD_NUMBER: builtins.int @property def relation(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: """(Required) The logical plan to checkpoint.""" @@ -2197,22 +2198,46 @@ class CheckpointCommand(google.protobuf.message.Message): """ eager: builtins.bool """(Required) Whether to checkpoint this dataframe immediately.""" + @property + def storage_level(self) -> pyspark.sql.connect.proto.common_pb2.StorageLevel: + """(Optional) For local checkpoint, the storage level to use.""" def __init__( self, *, relation: pyspark.sql.connect.proto.relations_pb2.Relation | None = ..., local: builtins.bool = ..., eager: builtins.bool = ..., + storage_level: pyspark.sql.connect.proto.common_pb2.StorageLevel | None = ..., ) -> None: ... def HasField( - self, field_name: typing_extensions.Literal["relation", b"relation"] + self, + field_name: typing_extensions.Literal[ + "_storage_level", + b"_storage_level", + "relation", + b"relation", + "storage_level", + b"storage_level", + ], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "eager", b"eager", "local", b"local", "relation", b"relation" + "_storage_level", + b"_storage_level", + "eager", + b"eager", + "local", + b"local", + "relation", + b"relation", + "storage_level", + b"storage_level", ], ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_storage_level", b"_storage_level"] + ) -> typing_extensions.Literal["storage_level"] | None: ... global___CheckpointCommand = CheckpointCommand diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c21e2271a64ac..62f2129e5be62 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1017,7 +1017,9 @@ def checkpoint(self, eager: bool = True) -> "DataFrame": """ ... - def localCheckpoint(self, eager: bool = True) -> "DataFrame": + def localCheckpoint( + self, eager: bool = True, storageLevel: Optional[StorageLevel] = None + ) -> "DataFrame": """Returns a locally checkpointed version of this :class:`DataFrame`. Checkpointing can be used to truncate the logical plan of this :class:`DataFrame`, which is especially useful in iterative algorithms where the plan may grow exponentially. 
Local checkpoints @@ -1028,12 +1030,17 @@ def localCheckpoint(self, eager: bool = True) -> "DataFrame": .. versionchanged:: 4.0.0 Supports Spark Connect. + Added storageLevel parameter. Parameters ---------- eager : bool, optional, default True Whether to checkpoint this :class:`DataFrame` immediately. + storageLevel : :class:`StorageLevel`, optional, default None + The StorageLevel with which the checkpoint will be stored. + If not specified, default for RDD local checkpoints. + Returns ------- :class:`DataFrame` diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index b5af00a4e7b78..cc43804949e84 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -950,11 +950,17 @@ def test_union_classmethod_usage(self): def test_isinstance_dataframe(self): self.assertIsInstance(self.spark.range(1), DataFrame) - def test_checkpoint_dataframe(self): + def test_local_checkpoint_dataframe(self): with io.StringIO() as buf, redirect_stdout(buf): self.spark.range(1).localCheckpoint().explain() self.assertIn("ExistingRDD", buf.getvalue()) + def test_local_checkpoint_dataframe_with_storage_level(self): + # We don't have a way to reach into the server and assert the storage level server side, but + # this test should cover for unexpected errors in the API. + df = self.spark.range(10).localCheckpoint(eager=True, storageLevel=StorageLevel.DISK_ONLY) + df.collect() + def test_transpose(self): df = self.spark.createDataFrame([{"a": "x", "b": "y", "c": "z"}]) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala index c277b4cab85c1..d6442930d1c5c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala @@ -314,7 +314,8 @@ abstract class Dataset[T] extends Serializable { * @group basic * @since 2.1.0 */ - def checkpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = true) + def checkpoint(): Dataset[T] = + checkpoint(eager = true, reliableCheckpoint = true, storageLevel = None) /** * Returns a checkpointed version of this Dataset. Checkpointing can be used to truncate the @@ -334,7 +335,7 @@ abstract class Dataset[T] extends Serializable { * @since 2.1.0 */ def checkpoint(eager: Boolean): Dataset[T] = - checkpoint(eager = eager, reliableCheckpoint = true) + checkpoint(eager = eager, reliableCheckpoint = true, storageLevel = None) /** * Eagerly locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used @@ -345,7 +346,8 @@ abstract class Dataset[T] extends Serializable { * @group basic * @since 2.3.0 */ - def localCheckpoint(): Dataset[T] = checkpoint(eager = true, reliableCheckpoint = false) + def localCheckpoint(): Dataset[T] = + checkpoint(eager = true, reliableCheckpoint = false, storageLevel = None) /** * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to @@ -365,7 +367,29 @@ abstract class Dataset[T] extends Serializable { * @since 2.3.0 */ def localCheckpoint(eager: Boolean): Dataset[T] = - checkpoint(eager = eager, reliableCheckpoint = false) + checkpoint(eager = eager, reliableCheckpoint = false, storageLevel = None) + + /** + * Locally checkpoints a Dataset and return the new Dataset. Checkpointing can be used to + * truncate the logical plan of this Dataset, which is especially useful in iterative algorithms + * where the plan may grow exponentially. 
Local checkpoints are written to executor storage and + * despite potentially faster they are unreliable and may compromise job completion. + * + * @param eager + * Whether to checkpoint this dataframe immediately + * @param storageLevel + * StorageLevel with which to checkpoint the data. + * @note + * When checkpoint is used with eager = false, the final data that is checkpointed after the + * first action may be different from the data that was used during the job due to + * non-determinism of the underlying operation and retries. If checkpoint is used to achieve + * saving a deterministic snapshot of the data, eager = true should be used. Otherwise, it is + * only deterministic after the first execution, after the checkpoint was finalized. + * @group basic + * @since 4.0.0 + */ + def localCheckpoint(eager: Boolean, storageLevel: StorageLevel): Dataset[T] = + checkpoint(eager = eager, reliableCheckpoint = false, storageLevel = Some(storageLevel)) /** * Returns a checkpointed version of this Dataset. @@ -375,8 +399,14 @@ abstract class Dataset[T] extends Serializable { * @param reliableCheckpoint * Whether to create a reliable checkpoint saved to files inside the checkpoint directory. If * false creates a local checkpoint using the caching subsystem - */ - protected def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] + * @param storageLevel + * Option. If defined, StorageLevel with which to checkpoint the data. Only with + * reliableCheckpoint = false. + */ + protected def checkpoint( + eager: Boolean, + reliableCheckpoint: Boolean, + storageLevel: Option[StorageLevel]): Dataset[T] /** * Defines an event time watermark for this [[Dataset]]. A watermark tracks a point in time diff --git a/sql/connect/common/src/main/protobuf/spark/connect/commands.proto b/sql/connect/common/src/main/protobuf/spark/connect/commands.proto index 71189a3c43a19..a01d4369a7aed 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/commands.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/commands.proto @@ -507,6 +507,9 @@ message CheckpointCommand { // (Required) Whether to checkpoint this dataframe immediately. bool eager = 3; + + // (Optional) For local checkpoint, the storage level to use. 
+ optional StorageLevel storage_level = 4; } message MergeIntoTableCommand { diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 231e54ff77d29..25fd7d13b7d48 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -3354,9 +3354,18 @@ class SparkConnectPlanner( responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = { val target = Dataset .ofRows(session, transformRelation(checkpointCommand.getRelation)) - val checkpointed = target.checkpoint( - eager = checkpointCommand.getEager, - reliableCheckpoint = !checkpointCommand.getLocal) + val checkpointed = if (checkpointCommand.getLocal) { + if (checkpointCommand.hasStorageLevel) { + target.localCheckpoint( + eager = checkpointCommand.getEager, + storageLevel = + StorageLevelProtoConverter.toStorageLevel(checkpointCommand.getStorageLevel)) + } else { + target.localCheckpoint(eager = checkpointCommand.getEager) + } + } else { + target.checkpoint(eager = checkpointCommand.getEager) + } val dfId = UUID.randomUUID().toString logInfo(log"Caching DataFrame with id ${MDC(DATAFRAME_ID, dfId)}") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1c5df1163eb78..b7b96f0c98274 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -540,13 +540,18 @@ class Dataset[T] private[sql]( def isStreaming: Boolean = logicalPlan.isStreaming /** @inheritdoc */ - protected[sql] def checkpoint(eager: Boolean, reliableCheckpoint: Boolean): Dataset[T] = { + protected[sql] def checkpoint( + eager: Boolean, + reliableCheckpoint: Boolean, + storageLevel: Option[StorageLevel]): Dataset[T] = { val actionName = if (reliableCheckpoint) "checkpoint" else "localCheckpoint" withAction(actionName, queryExecution) { physicalPlan => val internalRdd = physicalPlan.execute().map(_.copy()) if (reliableCheckpoint) { + assert(storageLevel.isEmpty, "StorageLevel should not be defined for reliableCheckpoint") internalRdd.checkpoint() } else { + storageLevel.foreach(storageLevel => internalRdd.persist(storageLevel)) internalRdd.localCheckpoint() } @@ -1794,6 +1799,10 @@ class Dataset[T] private[sql]( /** @inheritdoc */ override def localCheckpoint(eager: Boolean): Dataset[T] = super.localCheckpoint(eager) + /** @inheritdoc */ + override def localCheckpoint(eager: Boolean, storageLevel: StorageLevel): Dataset[T] = + super.localCheckpoint(eager, storageLevel) + /** @inheritdoc */ override def joinWith[U](other: Dataset[U], condition: Column): Dataset[(T, U)] = super.joinWith(other, condition) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 089ce79201dd8..85f296665b6e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1849,6 +1849,26 @@ class DatasetSuite extends QueryTest } } + test("Dataset().localCheckpoint() lazy with StorageLevel") { + val df = spark.range(10).repartition($"id" % 2) + val checkpointedDf = df.localCheckpoint(eager = false, StorageLevel.DISK_ONLY) + val checkpointedPlan = 
checkpointedDf.queryExecution.analyzed + val rdd = checkpointedPlan.asInstanceOf[LogicalRDD].rdd + assert(rdd.getStorageLevel == StorageLevel.DISK_ONLY) + assert(!rdd.isCheckpointed) + checkpointedDf.collect() + assert(rdd.isCheckpointed) + } + + test("Dataset().localCheckpoint() eager with StorageLevel") { + val df = spark.range(10).repartition($"id" % 2) + val checkpointedDf = df.localCheckpoint(eager = true, StorageLevel.DISK_ONLY) + val checkpointedPlan = checkpointedDf.queryExecution.analyzed + val rdd = checkpointedPlan.asInstanceOf[LogicalRDD].rdd + assert(rdd.isCheckpointed) + assert(rdd.getStorageLevel == StorageLevel.DISK_ONLY) + } + test("identity map for primitive arrays") { val arrayByte = Array(1.toByte, 2.toByte, 3.toByte) val arrayInt = Array(1, 2, 3) From 97a5aa6ef7cab46564325dafed9f8266ac7f8999 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Wed, 9 Oct 2024 15:48:38 +0200 Subject: [PATCH 194/250] [SPARK-49873][SQL] Assign proper error class for _LEGACY_ERROR_TEMP_1325 ### What changes were proposed in this pull request? This PR proposes to assign proper error class for _LEGACY_ERROR_TEMP_1325 ### Why are the changes needed? To improve user facing error message by providing proper error condition and sql state ### Does this PR introduce _any_ user-facing change? Improve user-facing error message ### How was this patch tested? Updated the existing UT ### Was this patch authored or co-authored using generative AI tooling? No Closes #48346 from itholic/legacy_1325. Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../main/resources/error/error-conditions.json | 5 ----- .../sql/errors/QueryCompilationErrors.scala | 5 +++-- .../spark/sql/internal/SQLConfSuite.scala | 18 ++++++++++++------ .../thriftserver/HiveThriftServer2Suites.scala | 2 +- .../sql/hive/execution/SQLQuerySuite.scala | 11 ++++++++--- 5 files changed, 24 insertions(+), 17 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 1b7f42e105077..f6317d731c77b 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -6696,11 +6696,6 @@ "The pivot column has more than distinct values, this could indicate an error. If this was intended, set to at least the number of distinct values of the pivot column." ] }, - "_LEGACY_ERROR_TEMP_1325" : { - "message" : [ - "Cannot modify the value of a static config: ." - ] - }, "_LEGACY_ERROR_TEMP_1327" : { "message" : [ "Command execution is not supported in runner ." 
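As a minimal sketch of how the reassigned condition surfaces to users (assuming an active `SparkSession` named `spark`; this snippet is illustrative and simply mirrors the updated `SQLConfSuite` expectations shown further down, it is not part of the patch):

```scala
import org.apache.spark.sql.AnalysisException

// Modifying a static SQL conf (here spark.sql.globalTempDatabase, as in the updated test)
// still raises an AnalysisException, but after this change the reported error condition is
// CANNOT_MODIFY_CONFIG instead of the legacy _LEGACY_ERROR_TEMP_1325 class.
try {
  spark.sql("SET spark.sql.globalTempDatabase=10")
} catch {
  case e: AnalysisException =>
    println(e.getErrorClass) // expected after this patch: CANNOT_MODIFY_CONFIG
}
```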
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 1f43b3dfa4a16..0e02e4249addd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3388,8 +3388,9 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat def cannotModifyValueOfStaticConfigError(key: String): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1325", - messageParameters = Map("key" -> key)) + errorClass = "CANNOT_MODIFY_CONFIG", + messageParameters = Map("key" -> toSQLConf(key), "docroot" -> SPARK_DOC_ROOT) + ) } def cannotModifyValueOfSparkConfigError(key: String, docroot: String): Throwable = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 6611ecce0ad8e..1a6cdd1258cc3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -233,8 +233,8 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { // static sql configs checkError( exception = intercept[AnalysisException](sql(s"RESET ${StaticSQLConf.WAREHOUSE_PATH.key}")), - condition = "_LEGACY_ERROR_TEMP_1325", - parameters = Map("key" -> "spark.sql.warehouse.dir")) + condition = "CANNOT_MODIFY_CONFIG", + parameters = Map("key" -> "\"spark.sql.warehouse.dir\"", "docroot" -> SPARK_DOC_ROOT)) } @@ -315,10 +315,16 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { } test("cannot set/unset static SQL conf") { - val e1 = intercept[AnalysisException](sql(s"SET ${GLOBAL_TEMP_DATABASE.key}=10")) - assert(e1.message.contains("Cannot modify the value of a static config")) - val e2 = intercept[AnalysisException](spark.conf.unset(GLOBAL_TEMP_DATABASE.key)) - assert(e2.message.contains("Cannot modify the value of a static config")) + checkError( + exception = intercept[AnalysisException](sql(s"SET ${GLOBAL_TEMP_DATABASE.key}=10")), + condition = "CANNOT_MODIFY_CONFIG", + parameters = Map("key" -> "\"spark.sql.globalTempDatabase\"", "docroot" -> SPARK_DOC_ROOT) + ) + checkError( + exception = intercept[AnalysisException](spark.conf.unset(GLOBAL_TEMP_DATABASE.key)), + condition = "CANNOT_MODIFY_CONFIG", + parameters = Map("key" -> "\"spark.sql.globalTempDatabase\"", "docroot" -> SPARK_DOC_ROOT) + ) } test("SPARK-36643: Show migration guide when attempting SparkConf") { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index f1f0befcb0d30..43030f68e5dac 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -1062,7 +1062,7 @@ class SingleSessionSuite extends HiveThriftServer2TestBase { statement.executeQuery("SET spark.sql.hive.thriftServer.singleSession=false") }.getMessage assert(e.contains( - "Cannot modify the value of a static config: spark.sql.hive.thriftServer.singleSession")) + "CANNOT_MODIFY_CONFIG")) } } diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 1c45b02375b30..83d70b2e19109 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -26,7 +26,7 @@ import java.util.{Locale, Set} import com.google.common.io.{Files, FileWriteMode} import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SparkException, TestUtils} +import org.apache.spark.{SPARK_DOC_ROOT, SparkException, TestUtils} import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.catalog.{CatalogTableType, CatalogUtils, HiveTableRelation} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLConf import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.execution.{SparkPlanInfo, TestUncaughtExceptionHandler} import org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite} @@ -2461,8 +2462,12 @@ abstract class SQLQuerySuiteBase extends QueryTest with SQLTestUtils with TestHi "spark.sql.hive.metastore.jars", "spark.sql.hive.metastore.sharedPrefixes", "spark.sql.hive.metastore.barrierPrefixes").foreach { key => - val e = intercept[AnalysisException](sql(s"set $key=abc")) - assert(e.getMessage.contains("Cannot modify the value of a static config")) + checkError( + exception = intercept[AnalysisException](sql(s"set $key=abc")), + condition = "CANNOT_MODIFY_CONFIG", + parameters = Map( + "key" -> toSQLConf(key), "docroot" -> SPARK_DOC_ROOT) + ) } } From fef3a7167aff4ff9b9f561378a152503663ad00f Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Wed, 9 Oct 2024 19:37:26 +0200 Subject: [PATCH 195/250] fix bug. --- .../org/apache/spark/sql/catalyst/util/CollationFactory.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index b1d11e96d9bbd..0868fbf6da4b8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -1191,7 +1191,7 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) { if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } - if (collation.supportsSpaceTrimming) { + if (collation.supportsBinaryEquality) { return input; } else if (collation.supportsLowercaseEquality) { return CollationAwareUTF8String.lowerCaseCodePoints(input); From fed9a8da3d4187794161e0be325aa96be8487783 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Thu, 10 Oct 2024 01:49:24 +0800 Subject: [PATCH 196/250] [SPARK-49569][BUILD][FOLLOWUP] Adds `scala-library` maven dependency to the `spark-connect-shims` module to fix Maven build errors ### What changes were proposed in this pull request? 
This PR adds `scala-library` maven dependency to the `spark-connect-shims` module to fix Maven build errors. ### Why are the changes needed? Maven daily test pipeline build failed: - https://github.com/apache/spark/actions/runs/11255598249 - https://github.com/apache/spark/actions/runs/11256610976 ``` scaladoc error: fatal error: object scala in compiler mirror not found. Error: Failed to execute goal net.alchim31.maven:scala-maven-plugin:4.9.1:doc-jar (attach-scaladocs) on project spark-connect-shims_2.13: MavenReportException: Error while creating archive: wrap: Process exited with an error: 1 (Exit value: 1) -> [Help 1] Error: Error: To see the full stack trace of the errors, re-run Maven with the -e switch. Error: Re-run Maven using the -X switch to enable full debug logging. Error: Error: For more information about the errors and possible solutions, please read the following articles: Error: [Help 1] http://cwiki.apache.org/confluence/display/MAVEN/MojoExecutionException Error: Error: After correcting the problems, you can resume the build with the command Error: mvn -rf :spark-connect-shims_2.13 Error: Process completed with exit code 1. ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Actions - locally test: ``` build/mvn clean install -DskipTests -Phive ``` **Before** ``` [INFO] --- scala:4.9.1:doc-jar (attach-scaladocs) spark-connect-shims_2.13 --- scaladoc error: fatal error: object scala in compiler mirror not found. [INFO] ------------------------------------------------------------------------ [INFO] Reactor Summary for Spark Project Parent POM 4.0.0-SNAPSHOT: [INFO] [INFO] Spark Project Parent POM ........................... SUCCESS [ 2.833 s] [INFO] Spark Project Tags ................................. SUCCESS [ 5.292 s] [INFO] Spark Project Sketch ............................... SUCCESS [ 5.675 s] [INFO] Spark Project Common Utils ......................... SUCCESS [ 16.762 s] [INFO] Spark Project Local DB ............................. SUCCESS [ 7.735 s] [INFO] Spark Project Networking ........................... SUCCESS [ 11.389 s] [INFO] Spark Project Shuffle Streaming Service ............ SUCCESS [ 9.159 s] [INFO] Spark Project Variant .............................. SUCCESS [ 3.618 s] [INFO] Spark Project Unsafe ............................... SUCCESS [ 9.692 s] [INFO] Spark Project Connect Shims ........................ FAILURE [ 2.478 s] [INFO] Spark Project Launcher ............................. SKIPPED [INFO] Spark Project Core ................................. SKIPPED [INFO] Spark Project ML Local Library ..................... SKIPPED [INFO] Spark Project GraphX ............................... SKIPPED [INFO] Spark Project Streaming ............................ SKIPPED [INFO] Spark Project SQL API .............................. SKIPPED [INFO] Spark Project Catalyst ............................. SKIPPED [INFO] Spark Project SQL .................................. SKIPPED [INFO] Spark Project ML Library ........................... SKIPPED [INFO] Spark Project Tools ................................ SKIPPED [INFO] Spark Project Hive ................................. SKIPPED [INFO] Spark Project Connect Common ....................... SKIPPED [INFO] Spark Avro ......................................... SKIPPED [INFO] Spark Protobuf ..................................... SKIPPED [INFO] Spark Project REPL ................................. SKIPPED [INFO] Spark Project Connect Server ....................... 
SKIPPED [INFO] Spark Project Connect Client ....................... SKIPPED [INFO] Spark Project Assembly ............................. SKIPPED [INFO] Kafka 0.10+ Token Provider for Streaming ........... SKIPPED [INFO] Spark Integration for Kafka 0.10 ................... SKIPPED [INFO] Kafka 0.10+ Source for Structured Streaming ........ SKIPPED [INFO] Spark Project Examples ............................. SKIPPED [INFO] Spark Integration for Kafka 0.10 Assembly .......... SKIPPED [INFO] ------------------------------------------------------------------------ [INFO] BUILD FAILURE [INFO] ------------------------------------------------------------------------ [INFO] Total time: 01:15 min [INFO] Finished at: 2024-10-09T23:43:58+08:00 [INFO] ------------------------------------------------------------------------ [ERROR] Failed to execute goal net.alchim31.maven:scala-maven-plugin:4.9.1:doc-jar (attach-scaladocs) on project spark-connect-shims_2.13: MavenReportException: Error while creating archive: wrap: Process exited with an error: 1 (Exit value: 1) -> [Help 1] [ERROR] [ERROR] To see the full stack trace of the errors, re-run Maven with the -e switch. [ERROR] Re-run Maven using the -X switch to enable full debug logging. [ERROR] [ERROR] For more information about the errors and possible solutions, please read the following articles: [ERROR] [Help 1] http://cwiki.apache.org/confluence/display/MAVEN/MojoExecutionException [ERROR] [ERROR] After correcting the problems, you can resume the build with the command [ERROR] mvn -rf :spark-connect-shims_2.13 ``` **After** ``` [INFO] ------------------------------------------------------------------------ [INFO] Reactor Summary for Spark Project Parent POM 4.0.0-SNAPSHOT: [INFO] [INFO] Spark Project Parent POM ........................... SUCCESS [ 2.766 s] [INFO] Spark Project Tags ................................. SUCCESS [ 5.398 s] [INFO] Spark Project Sketch ............................... SUCCESS [ 6.361 s] [INFO] Spark Project Common Utils ......................... SUCCESS [ 16.919 s] [INFO] Spark Project Local DB ............................. SUCCESS [ 8.083 s] [INFO] Spark Project Networking ........................... SUCCESS [ 11.240 s] [INFO] Spark Project Shuffle Streaming Service ............ SUCCESS [ 9.438 s] [INFO] Spark Project Variant .............................. SUCCESS [ 3.697 s] [INFO] Spark Project Unsafe ............................... SUCCESS [ 9.939 s] [INFO] Spark Project Connect Shims ........................ SUCCESS [ 2.938 s] [INFO] Spark Project Launcher ............................. SUCCESS [ 6.502 s] [INFO] Spark Project Core ................................. SUCCESS [01:33 min] [INFO] Spark Project ML Local Library ..................... SUCCESS [ 18.220 s] [INFO] Spark Project GraphX ............................... SUCCESS [ 20.923 s] [INFO] Spark Project Streaming ............................ SUCCESS [ 29.949 s] [INFO] Spark Project SQL API .............................. SUCCESS [ 25.842 s] [INFO] Spark Project Catalyst ............................. SUCCESS [02:02 min] [INFO] Spark Project SQL .................................. SUCCESS [02:18 min] [INFO] Spark Project ML Library ........................... SUCCESS [01:38 min] [INFO] Spark Project Tools ................................ SUCCESS [ 3.365 s] [INFO] Spark Project Hive ................................. SUCCESS [ 45.357 s] [INFO] Spark Project Connect Common ....................... 
SUCCESS [ 33.636 s] [INFO] Spark Avro ......................................... SUCCESS [ 22.040 s] [INFO] Spark Protobuf ..................................... SUCCESS [ 24.557 s] [INFO] Spark Project REPL ................................. SUCCESS [ 13.843 s] [INFO] Spark Project Connect Server ....................... SUCCESS [ 35.587 s] [INFO] Spark Project Connect Client ....................... SUCCESS [ 33.929 s] [INFO] Spark Project Assembly ............................. SUCCESS [ 5.121 s] [INFO] Kafka 0.10+ Token Provider for Streaming ........... SUCCESS [ 12.623 s] [INFO] Spark Integration for Kafka 0.10 ................... SUCCESS [ 16.908 s] [INFO] Kafka 0.10+ Source for Structured Streaming ........ SUCCESS [ 23.664 s] [INFO] Spark Project Examples ............................. SUCCESS [ 30.777 s] [INFO] Spark Integration for Kafka 0.10 Assembly .......... SUCCESS [ 6.997 s] [INFO] ------------------------------------------------------------------------ [INFO] BUILD SUCCESS [INFO] ------------------------------------------------------------------------ [INFO] Total time: 15:40 min [INFO] Finished at: 2024-10-09T23:27:20+08:00 [INFO] ------------------------------------------------------------------------ ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #48399 from LuciferYang/SPARK-49569-FOLLOWUP. Authored-by: yangjie01 Signed-off-by: yangjie01 --- sql/connect/shims/pom.xml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/connect/shims/pom.xml b/sql/connect/shims/pom.xml index 6bb12a927738c..d177b4a9971f5 100644 --- a/sql/connect/shims/pom.xml +++ b/sql/connect/shims/pom.xml @@ -34,6 +34,13 @@ connect-shims + + + org.scala-lang + scala-library + + + target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes From f69d03e0cf45ae15fde770cba0340a9bd2e05c28 Mon Sep 17 00:00:00 2001 From: Mihailo Timotic Date: Thu, 10 Oct 2024 08:17:39 +0900 Subject: [PATCH 197/250] [SPARK-43838][SQL][FOLLOWUP] Replace `HashSet` with `HashMap` to improve performance of `DeduplicateRelations` ### What changes were proposed in this pull request? This PR replaces `HashSet` that is currently used with a `HashMap` to improve `DeduplicateRelations` performance. Additionally, this PR reverts #48053 as that change is no longer needed ### Why are the changes needed? Current implementation doesn't utilize `HashSet` properly, but instead performs multiple linear searches on the set creating a O(n^2) complexity ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? Existing tests ### Was this patch authored or co-authored using generative AI tooling? Closes #48392 from mihailotim-db/mihailotim-db/master. 
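A minimal sketch of the per-class bookkeeping this patch switches to (the object name `DedupSketch` is illustrative only; the real logic lives in `DeduplicateRelations`, as the diff below shows): the duplicate check becomes a hash lookup keyed by the plan class, rather than intersecting output-attribute id lists against every previously recorded relation.

```scala
import scala.collection.mutable

object DedupSketch {
  // Plan class -> set of output ExprIds already seen for that class.
  type ExprIdMap = mutable.HashMap[Class[_], mutable.HashSet[Long]]

  // True if any candidate ExprId was already recorded for this plan class:
  // a constant-time set lookup per id instead of a linear scan over all recorded relations.
  def existDuplicatedExprId(existing: ExprIdMap, planClass: Class[_], exprIds: Seq[Long]): Boolean = {
    val attrSet = existing.getOrElse(planClass, mutable.HashSet.empty[Long])
    exprIds.exists(attrSet.contains)
  }

  // Record a relation's output ExprIds so later plans of the same class get renewed instances.
  def record(existing: ExprIdMap, planClass: Class[_], exprIds: Seq[Long]): Unit = {
    val attrSet = existing.getOrElseUpdate(planClass, mutable.HashSet.empty[Long])
    exprIds.foreach(attrSet.add)
  }
}
```

Keying by plan class preserves the original intent of `RelationWrapper`: a CTE definition node and its reference node share output attribute IDs but are different classes, so they are still not treated as duplicates.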
Authored-by: Mihailo Timotic Signed-off-by: Hyukjin Kwon --- .../analysis/DeduplicateRelations.scala | 51 ++- .../q22a.sf100/explain.txt | 154 ++++----- .../q22a.sf100/simplified.txt | 2 +- .../approved-plans-v2_7/q22a/explain.txt | 154 ++++----- .../approved-plans-v2_7/q22a/simplified.txt | 2 +- .../q67a.sf100/explain.txt | 326 +++++++++--------- .../q67a.sf100/simplified.txt | 2 +- .../approved-plans-v2_7/q67a/explain.txt | 326 +++++++++--------- .../approved-plans-v2_7/q67a/simplified.txt | 2 +- 9 files changed, 504 insertions(+), 515 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala index e22a4b941b30c..8181078c519fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala @@ -24,20 +24,12 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern._ -/** - * A helper class used to detect duplicate relations fast in `DeduplicateRelations`. Two relations - * are duplicated if: - * 1. they are the same class. - * 2. they have the same output attribute IDs. - * - * The first condition is necessary because the CTE relation definition node and reference node have - * the same output attribute IDs but they are not duplicated. - */ -case class RelationWrapper(cls: Class[_], outputAttrIds: Seq[Long]) - object DeduplicateRelations extends Rule[LogicalPlan] { + + type ExprIdMap = mutable.HashMap[Class[_], mutable.HashSet[Long]] + override def apply(plan: LogicalPlan): LogicalPlan = { - val newPlan = renewDuplicatedRelations(mutable.HashSet.empty, plan)._1 + val newPlan = renewDuplicatedRelations(mutable.HashMap.empty, plan)._1 // Wait for `ResolveMissingReferences` to resolve missing attributes first def noMissingInput(p: LogicalPlan) = !p.exists(_.missingInput.nonEmpty) @@ -86,10 +78,10 @@ object DeduplicateRelations extends Rule[LogicalPlan] { } private def existDuplicatedExprId( - existingRelations: mutable.HashSet[RelationWrapper], - plan: RelationWrapper): Boolean = { - existingRelations.filter(_.cls == plan.cls) - .exists(_.outputAttrIds.intersect(plan.outputAttrIds).nonEmpty) + existingRelations: ExprIdMap, + planClass: Class[_], exprIds: Seq[Long]): Boolean = { + val attrSet = existingRelations.getOrElse(planClass, mutable.HashSet.empty) + exprIds.exists(attrSet.contains) } /** @@ -100,20 +92,16 @@ object DeduplicateRelations extends Rule[LogicalPlan] { * whether the plan is changed or not) */ private def renewDuplicatedRelations( - existingRelations: mutable.HashSet[RelationWrapper], + existingRelations: ExprIdMap, plan: LogicalPlan): (LogicalPlan, Boolean) = plan match { case p: LogicalPlan if p.isStreaming => (plan, false) case m: MultiInstanceRelation => - val planWrapper = RelationWrapper(m.getClass, m.output.map(_.exprId.id)) - if (existingRelations.contains(planWrapper)) { - val newNode = m.newInstance() - newNode.copyTagsFrom(m) - (newNode, true) - } else { - existingRelations.add(planWrapper) - (m, false) - } + deduplicateAndRenew[LogicalPlan with MultiInstanceRelation]( + existingRelations, + m, + _.output.map(_.exprId.id), + node => node.newInstance().asInstanceOf[LogicalPlan with MultiInstanceRelation]) case p: Project => deduplicateAndRenew[Project]( @@ -207,7 +195,7 
@@ object DeduplicateRelations extends Rule[LogicalPlan] { } private def deduplicate( - existingRelations: mutable.HashSet[RelationWrapper], + existingRelations: ExprIdMap, plan: LogicalPlan): (LogicalPlan, Boolean) = { var planChanged = false val newPlan = if (plan.children.nonEmpty) { @@ -291,20 +279,21 @@ object DeduplicateRelations extends Rule[LogicalPlan] { } private def deduplicateAndRenew[T <: LogicalPlan]( - existingRelations: mutable.HashSet[RelationWrapper], plan: T, + existingRelations: ExprIdMap, plan: T, getExprIds: T => Seq[Long], copyNewPlan: T => T): (LogicalPlan, Boolean) = { var (newPlan, planChanged) = deduplicate(existingRelations, plan) if (newPlan.resolved) { val exprIds = getExprIds(newPlan.asInstanceOf[T]) if (exprIds.nonEmpty) { - val planWrapper = RelationWrapper(newPlan.getClass, exprIds) - if (existDuplicatedExprId(existingRelations, planWrapper)) { + if (existDuplicatedExprId(existingRelations, newPlan.getClass, exprIds)) { newPlan = copyNewPlan(newPlan.asInstanceOf[T]) newPlan.copyTagsFrom(plan) (newPlan, true) } else { - existingRelations.add(planWrapper) + val attrSet = existingRelations.getOrElseUpdate(newPlan.getClass, mutable.HashSet.empty) + exprIds.foreach(attrSet.add) + existingRelations.put(newPlan.getClass, attrSet) (newPlan, planChanged) } } else { diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/explain.txt index 96bed479d2e06..4bf7de791b279 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/explain.txt @@ -175,125 +175,125 @@ Input [6]: [i_product_name#12, i_brand#9, i_class#10, i_category#11, sum#21, cou Keys [4]: [i_product_name#12, i_brand#9, i_class#10, i_category#11] Functions [1]: [avg(qoh#18)] Aggregate Attributes [1]: [avg(qoh#18)#23] -Results [5]: [i_product_name#12, i_brand#9, i_class#10, i_category#11, avg(qoh#18)#23 AS qoh#24] +Results [5]: [i_product_name#12 AS i_product_name#24, i_brand#9 AS i_brand#25, i_class#10 AS i_class#26, i_category#11 AS i_category#27, avg(qoh#18)#23 AS qoh#28] (27) ReusedExchange [Reuses operator id: 23] -Output [6]: [i_product_name#25, i_brand#26, i_class#27, i_category#28, sum#29, count#30] +Output [6]: [i_product_name#29, i_brand#30, i_class#31, i_category#32, sum#33, count#34] (28) HashAggregate [codegen id : 16] -Input [6]: [i_product_name#25, i_brand#26, i_class#27, i_category#28, sum#29, count#30] -Keys [4]: [i_product_name#25, i_brand#26, i_class#27, i_category#28] -Functions [1]: [avg(inv_quantity_on_hand#31)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#31)#17] -Results [4]: [i_product_name#25, i_brand#26, i_class#27, avg(inv_quantity_on_hand#31)#17 AS qoh#32] +Input [6]: [i_product_name#29, i_brand#30, i_class#31, i_category#32, sum#33, count#34] +Keys [4]: [i_product_name#29, i_brand#30, i_class#31, i_category#32] +Functions [1]: [avg(inv_quantity_on_hand#35)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#35)#17] +Results [4]: [i_product_name#29, i_brand#30, i_class#31, avg(inv_quantity_on_hand#35)#17 AS qoh#36] (29) HashAggregate [codegen id : 16] -Input [4]: [i_product_name#25, i_brand#26, i_class#27, qoh#32] -Keys [3]: [i_product_name#25, i_brand#26, i_class#27] -Functions [1]: [partial_avg(qoh#32)] -Aggregate Attributes [2]: [sum#33, count#34] -Results [5]: [i_product_name#25, i_brand#26, i_class#27, 
sum#35, count#36] +Input [4]: [i_product_name#29, i_brand#30, i_class#31, qoh#36] +Keys [3]: [i_product_name#29, i_brand#30, i_class#31] +Functions [1]: [partial_avg(qoh#36)] +Aggregate Attributes [2]: [sum#37, count#38] +Results [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] (30) Exchange -Input [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] -Arguments: hashpartitioning(i_product_name#25, i_brand#26, i_class#27, 5), ENSURE_REQUIREMENTS, [plan_id=5] +Input [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] +Arguments: hashpartitioning(i_product_name#29, i_brand#30, i_class#31, 5), ENSURE_REQUIREMENTS, [plan_id=5] (31) HashAggregate [codegen id : 17] -Input [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] -Keys [3]: [i_product_name#25, i_brand#26, i_class#27] -Functions [1]: [avg(qoh#32)] -Aggregate Attributes [1]: [avg(qoh#32)#37] -Results [5]: [i_product_name#25, i_brand#26, i_class#27, null AS i_category#38, avg(qoh#32)#37 AS qoh#39] +Input [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] +Keys [3]: [i_product_name#29, i_brand#30, i_class#31] +Functions [1]: [avg(qoh#36)] +Aggregate Attributes [1]: [avg(qoh#36)#41] +Results [5]: [i_product_name#29, i_brand#30, i_class#31, null AS i_category#42, avg(qoh#36)#41 AS qoh#43] (32) ReusedExchange [Reuses operator id: 23] -Output [6]: [i_product_name#40, i_brand#41, i_class#42, i_category#43, sum#44, count#45] +Output [6]: [i_product_name#44, i_brand#45, i_class#46, i_category#47, sum#48, count#49] (33) HashAggregate [codegen id : 25] -Input [6]: [i_product_name#40, i_brand#41, i_class#42, i_category#43, sum#44, count#45] -Keys [4]: [i_product_name#40, i_brand#41, i_class#42, i_category#43] -Functions [1]: [avg(inv_quantity_on_hand#46)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#46)#17] -Results [3]: [i_product_name#40, i_brand#41, avg(inv_quantity_on_hand#46)#17 AS qoh#47] +Input [6]: [i_product_name#44, i_brand#45, i_class#46, i_category#47, sum#48, count#49] +Keys [4]: [i_product_name#44, i_brand#45, i_class#46, i_category#47] +Functions [1]: [avg(inv_quantity_on_hand#50)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#50)#17] +Results [3]: [i_product_name#44, i_brand#45, avg(inv_quantity_on_hand#50)#17 AS qoh#51] (34) HashAggregate [codegen id : 25] -Input [3]: [i_product_name#40, i_brand#41, qoh#47] -Keys [2]: [i_product_name#40, i_brand#41] -Functions [1]: [partial_avg(qoh#47)] -Aggregate Attributes [2]: [sum#48, count#49] -Results [4]: [i_product_name#40, i_brand#41, sum#50, count#51] +Input [3]: [i_product_name#44, i_brand#45, qoh#51] +Keys [2]: [i_product_name#44, i_brand#45] +Functions [1]: [partial_avg(qoh#51)] +Aggregate Attributes [2]: [sum#52, count#53] +Results [4]: [i_product_name#44, i_brand#45, sum#54, count#55] (35) Exchange -Input [4]: [i_product_name#40, i_brand#41, sum#50, count#51] -Arguments: hashpartitioning(i_product_name#40, i_brand#41, 5), ENSURE_REQUIREMENTS, [plan_id=6] +Input [4]: [i_product_name#44, i_brand#45, sum#54, count#55] +Arguments: hashpartitioning(i_product_name#44, i_brand#45, 5), ENSURE_REQUIREMENTS, [plan_id=6] (36) HashAggregate [codegen id : 26] -Input [4]: [i_product_name#40, i_brand#41, sum#50, count#51] -Keys [2]: [i_product_name#40, i_brand#41] -Functions [1]: [avg(qoh#47)] -Aggregate Attributes [1]: [avg(qoh#47)#52] -Results [5]: [i_product_name#40, i_brand#41, null AS i_class#53, null AS i_category#54, avg(qoh#47)#52 AS qoh#55] +Input [4]: [i_product_name#44, i_brand#45, sum#54, 
count#55] +Keys [2]: [i_product_name#44, i_brand#45] +Functions [1]: [avg(qoh#51)] +Aggregate Attributes [1]: [avg(qoh#51)#56] +Results [5]: [i_product_name#44, i_brand#45, null AS i_class#57, null AS i_category#58, avg(qoh#51)#56 AS qoh#59] (37) ReusedExchange [Reuses operator id: 23] -Output [6]: [i_product_name#56, i_brand#57, i_class#58, i_category#59, sum#60, count#61] +Output [6]: [i_product_name#60, i_brand#61, i_class#62, i_category#63, sum#64, count#65] (38) HashAggregate [codegen id : 34] -Input [6]: [i_product_name#56, i_brand#57, i_class#58, i_category#59, sum#60, count#61] -Keys [4]: [i_product_name#56, i_brand#57, i_class#58, i_category#59] -Functions [1]: [avg(inv_quantity_on_hand#62)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#62)#17] -Results [2]: [i_product_name#56, avg(inv_quantity_on_hand#62)#17 AS qoh#63] +Input [6]: [i_product_name#60, i_brand#61, i_class#62, i_category#63, sum#64, count#65] +Keys [4]: [i_product_name#60, i_brand#61, i_class#62, i_category#63] +Functions [1]: [avg(inv_quantity_on_hand#66)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#66)#17] +Results [2]: [i_product_name#60, avg(inv_quantity_on_hand#66)#17 AS qoh#67] (39) HashAggregate [codegen id : 34] -Input [2]: [i_product_name#56, qoh#63] -Keys [1]: [i_product_name#56] -Functions [1]: [partial_avg(qoh#63)] -Aggregate Attributes [2]: [sum#64, count#65] -Results [3]: [i_product_name#56, sum#66, count#67] +Input [2]: [i_product_name#60, qoh#67] +Keys [1]: [i_product_name#60] +Functions [1]: [partial_avg(qoh#67)] +Aggregate Attributes [2]: [sum#68, count#69] +Results [3]: [i_product_name#60, sum#70, count#71] (40) Exchange -Input [3]: [i_product_name#56, sum#66, count#67] -Arguments: hashpartitioning(i_product_name#56, 5), ENSURE_REQUIREMENTS, [plan_id=7] +Input [3]: [i_product_name#60, sum#70, count#71] +Arguments: hashpartitioning(i_product_name#60, 5), ENSURE_REQUIREMENTS, [plan_id=7] (41) HashAggregate [codegen id : 35] -Input [3]: [i_product_name#56, sum#66, count#67] -Keys [1]: [i_product_name#56] -Functions [1]: [avg(qoh#63)] -Aggregate Attributes [1]: [avg(qoh#63)#68] -Results [5]: [i_product_name#56, null AS i_brand#69, null AS i_class#70, null AS i_category#71, avg(qoh#63)#68 AS qoh#72] +Input [3]: [i_product_name#60, sum#70, count#71] +Keys [1]: [i_product_name#60] +Functions [1]: [avg(qoh#67)] +Aggregate Attributes [1]: [avg(qoh#67)#72] +Results [5]: [i_product_name#60, null AS i_brand#73, null AS i_class#74, null AS i_category#75, avg(qoh#67)#72 AS qoh#76] (42) ReusedExchange [Reuses operator id: 23] -Output [6]: [i_product_name#73, i_brand#74, i_class#75, i_category#76, sum#77, count#78] +Output [6]: [i_product_name#77, i_brand#78, i_class#79, i_category#80, sum#81, count#82] (43) HashAggregate [codegen id : 43] -Input [6]: [i_product_name#73, i_brand#74, i_class#75, i_category#76, sum#77, count#78] -Keys [4]: [i_product_name#73, i_brand#74, i_class#75, i_category#76] -Functions [1]: [avg(inv_quantity_on_hand#79)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#79)#17] -Results [1]: [avg(inv_quantity_on_hand#79)#17 AS qoh#80] +Input [6]: [i_product_name#77, i_brand#78, i_class#79, i_category#80, sum#81, count#82] +Keys [4]: [i_product_name#77, i_brand#78, i_class#79, i_category#80] +Functions [1]: [avg(inv_quantity_on_hand#83)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#83)#17] +Results [1]: [avg(inv_quantity_on_hand#83)#17 AS qoh#84] (44) HashAggregate [codegen id : 43] -Input [1]: [qoh#80] +Input [1]: [qoh#84] Keys: [] -Functions [1]: 
[partial_avg(qoh#80)] -Aggregate Attributes [2]: [sum#81, count#82] -Results [2]: [sum#83, count#84] +Functions [1]: [partial_avg(qoh#84)] +Aggregate Attributes [2]: [sum#85, count#86] +Results [2]: [sum#87, count#88] (45) Exchange -Input [2]: [sum#83, count#84] +Input [2]: [sum#87, count#88] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=8] (46) HashAggregate [codegen id : 44] -Input [2]: [sum#83, count#84] +Input [2]: [sum#87, count#88] Keys: [] -Functions [1]: [avg(qoh#80)] -Aggregate Attributes [1]: [avg(qoh#80)#85] -Results [5]: [null AS i_product_name#86, null AS i_brand#87, null AS i_class#88, null AS i_category#89, avg(qoh#80)#85 AS qoh#90] +Functions [1]: [avg(qoh#84)] +Aggregate Attributes [1]: [avg(qoh#84)#89] +Results [5]: [null AS i_product_name#90, null AS i_brand#91, null AS i_class#92, null AS i_category#93, avg(qoh#84)#89 AS qoh#94] (47) Union (48) TakeOrderedAndProject -Input [5]: [i_product_name#12, i_brand#9, i_class#10, i_category#11, qoh#24] -Arguments: 100, [qoh#24 ASC NULLS FIRST, i_product_name#12 ASC NULLS FIRST, i_brand#9 ASC NULLS FIRST, i_class#10 ASC NULLS FIRST, i_category#11 ASC NULLS FIRST], [i_product_name#12, i_brand#9, i_class#10, i_category#11, qoh#24] +Input [5]: [i_product_name#24, i_brand#25, i_class#26, i_category#27, qoh#28] +Arguments: 100, [qoh#28 ASC NULLS FIRST, i_product_name#24 ASC NULLS FIRST, i_brand#25 ASC NULLS FIRST, i_class#26 ASC NULLS FIRST, i_category#27 ASC NULLS FIRST], [i_product_name#24, i_brand#25, i_class#26, i_category#27, qoh#28] ===== Subqueries ===== @@ -306,22 +306,22 @@ BroadcastExchange (53) (49) Scan parquet spark_catalog.default.date_dim -Output [2]: [d_date_sk#7, d_month_seq#91] +Output [2]: [d_date_sk#7, d_month_seq#95] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1212), LessThanOrEqual(d_month_seq,1223), IsNotNull(d_date_sk)] ReadSchema: struct (50) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#7, d_month_seq#91] +Input [2]: [d_date_sk#7, d_month_seq#95] (51) Filter [codegen id : 1] -Input [2]: [d_date_sk#7, d_month_seq#91] -Condition : (((isnotnull(d_month_seq#91) AND (d_month_seq#91 >= 1212)) AND (d_month_seq#91 <= 1223)) AND isnotnull(d_date_sk#7)) +Input [2]: [d_date_sk#7, d_month_seq#95] +Condition : (((isnotnull(d_month_seq#95) AND (d_month_seq#95 >= 1212)) AND (d_month_seq#95 <= 1223)) AND isnotnull(d_date_sk#7)) (52) Project [codegen id : 1] Output [1]: [d_date_sk#7] -Input [2]: [d_date_sk#7, d_month_seq#91] +Input [2]: [d_date_sk#7, d_month_seq#95] (53) BroadcastExchange Input [1]: [d_date_sk#7] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/simplified.txt index 0c4267b3ca513..042f946b8fca4 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a.sf100/simplified.txt @@ -1,7 +1,7 @@ TakeOrderedAndProject [qoh,i_product_name,i_brand,i_class,i_category] Union WholeStageCodegen (8) - HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(qoh),qoh,sum,count] + HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(qoh),i_product_name,i_brand,i_class,i_category,qoh,sum,count] HashAggregate [i_product_name,i_brand,i_class,i_category,qoh] [sum,count,sum,count] HashAggregate 
[i_product_name,i_brand,i_class,i_category,sum,count] [avg(inv_quantity_on_hand),qoh,sum,count] InputAdapter diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/explain.txt index 4b8993f370f4d..8aab8e91acfc8 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/explain.txt @@ -160,125 +160,125 @@ Input [6]: [i_product_name#11, i_brand#8, i_class#9, i_category#10, sum#21, coun Keys [4]: [i_product_name#11, i_brand#8, i_class#9, i_category#10] Functions [1]: [avg(qoh#18)] Aggregate Attributes [1]: [avg(qoh#18)#23] -Results [5]: [i_product_name#11, i_brand#8, i_class#9, i_category#10, avg(qoh#18)#23 AS qoh#24] +Results [5]: [i_product_name#11 AS i_product_name#24, i_brand#8 AS i_brand#25, i_class#9 AS i_class#26, i_category#10 AS i_category#27, avg(qoh#18)#23 AS qoh#28] (24) ReusedExchange [Reuses operator id: 20] -Output [6]: [i_product_name#25, i_brand#26, i_class#27, i_category#28, sum#29, count#30] +Output [6]: [i_product_name#29, i_brand#30, i_class#31, i_category#32, sum#33, count#34] (25) HashAggregate [codegen id : 10] -Input [6]: [i_product_name#25, i_brand#26, i_class#27, i_category#28, sum#29, count#30] -Keys [4]: [i_product_name#25, i_brand#26, i_class#27, i_category#28] -Functions [1]: [avg(inv_quantity_on_hand#31)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#31)#17] -Results [4]: [i_product_name#25, i_brand#26, i_class#27, avg(inv_quantity_on_hand#31)#17 AS qoh#32] +Input [6]: [i_product_name#29, i_brand#30, i_class#31, i_category#32, sum#33, count#34] +Keys [4]: [i_product_name#29, i_brand#30, i_class#31, i_category#32] +Functions [1]: [avg(inv_quantity_on_hand#35)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#35)#17] +Results [4]: [i_product_name#29, i_brand#30, i_class#31, avg(inv_quantity_on_hand#35)#17 AS qoh#36] (26) HashAggregate [codegen id : 10] -Input [4]: [i_product_name#25, i_brand#26, i_class#27, qoh#32] -Keys [3]: [i_product_name#25, i_brand#26, i_class#27] -Functions [1]: [partial_avg(qoh#32)] -Aggregate Attributes [2]: [sum#33, count#34] -Results [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] +Input [4]: [i_product_name#29, i_brand#30, i_class#31, qoh#36] +Keys [3]: [i_product_name#29, i_brand#30, i_class#31] +Functions [1]: [partial_avg(qoh#36)] +Aggregate Attributes [2]: [sum#37, count#38] +Results [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] (27) Exchange -Input [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] -Arguments: hashpartitioning(i_product_name#25, i_brand#26, i_class#27, 5), ENSURE_REQUIREMENTS, [plan_id=4] +Input [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] +Arguments: hashpartitioning(i_product_name#29, i_brand#30, i_class#31, 5), ENSURE_REQUIREMENTS, [plan_id=4] (28) HashAggregate [codegen id : 11] -Input [5]: [i_product_name#25, i_brand#26, i_class#27, sum#35, count#36] -Keys [3]: [i_product_name#25, i_brand#26, i_class#27] -Functions [1]: [avg(qoh#32)] -Aggregate Attributes [1]: [avg(qoh#32)#37] -Results [5]: [i_product_name#25, i_brand#26, i_class#27, null AS i_category#38, avg(qoh#32)#37 AS qoh#39] +Input [5]: [i_product_name#29, i_brand#30, i_class#31, sum#39, count#40] +Keys [3]: [i_product_name#29, i_brand#30, i_class#31] +Functions [1]: [avg(qoh#36)] +Aggregate Attributes [1]: [avg(qoh#36)#41] 
+Results [5]: [i_product_name#29, i_brand#30, i_class#31, null AS i_category#42, avg(qoh#36)#41 AS qoh#43] (29) ReusedExchange [Reuses operator id: 20] -Output [6]: [i_product_name#40, i_brand#41, i_class#42, i_category#43, sum#44, count#45] +Output [6]: [i_product_name#44, i_brand#45, i_class#46, i_category#47, sum#48, count#49] (30) HashAggregate [codegen id : 16] -Input [6]: [i_product_name#40, i_brand#41, i_class#42, i_category#43, sum#44, count#45] -Keys [4]: [i_product_name#40, i_brand#41, i_class#42, i_category#43] -Functions [1]: [avg(inv_quantity_on_hand#46)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#46)#17] -Results [3]: [i_product_name#40, i_brand#41, avg(inv_quantity_on_hand#46)#17 AS qoh#47] +Input [6]: [i_product_name#44, i_brand#45, i_class#46, i_category#47, sum#48, count#49] +Keys [4]: [i_product_name#44, i_brand#45, i_class#46, i_category#47] +Functions [1]: [avg(inv_quantity_on_hand#50)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#50)#17] +Results [3]: [i_product_name#44, i_brand#45, avg(inv_quantity_on_hand#50)#17 AS qoh#51] (31) HashAggregate [codegen id : 16] -Input [3]: [i_product_name#40, i_brand#41, qoh#47] -Keys [2]: [i_product_name#40, i_brand#41] -Functions [1]: [partial_avg(qoh#47)] -Aggregate Attributes [2]: [sum#48, count#49] -Results [4]: [i_product_name#40, i_brand#41, sum#50, count#51] +Input [3]: [i_product_name#44, i_brand#45, qoh#51] +Keys [2]: [i_product_name#44, i_brand#45] +Functions [1]: [partial_avg(qoh#51)] +Aggregate Attributes [2]: [sum#52, count#53] +Results [4]: [i_product_name#44, i_brand#45, sum#54, count#55] (32) Exchange -Input [4]: [i_product_name#40, i_brand#41, sum#50, count#51] -Arguments: hashpartitioning(i_product_name#40, i_brand#41, 5), ENSURE_REQUIREMENTS, [plan_id=5] +Input [4]: [i_product_name#44, i_brand#45, sum#54, count#55] +Arguments: hashpartitioning(i_product_name#44, i_brand#45, 5), ENSURE_REQUIREMENTS, [plan_id=5] (33) HashAggregate [codegen id : 17] -Input [4]: [i_product_name#40, i_brand#41, sum#50, count#51] -Keys [2]: [i_product_name#40, i_brand#41] -Functions [1]: [avg(qoh#47)] -Aggregate Attributes [1]: [avg(qoh#47)#52] -Results [5]: [i_product_name#40, i_brand#41, null AS i_class#53, null AS i_category#54, avg(qoh#47)#52 AS qoh#55] +Input [4]: [i_product_name#44, i_brand#45, sum#54, count#55] +Keys [2]: [i_product_name#44, i_brand#45] +Functions [1]: [avg(qoh#51)] +Aggregate Attributes [1]: [avg(qoh#51)#56] +Results [5]: [i_product_name#44, i_brand#45, null AS i_class#57, null AS i_category#58, avg(qoh#51)#56 AS qoh#59] (34) ReusedExchange [Reuses operator id: 20] -Output [6]: [i_product_name#56, i_brand#57, i_class#58, i_category#59, sum#60, count#61] +Output [6]: [i_product_name#60, i_brand#61, i_class#62, i_category#63, sum#64, count#65] (35) HashAggregate [codegen id : 22] -Input [6]: [i_product_name#56, i_brand#57, i_class#58, i_category#59, sum#60, count#61] -Keys [4]: [i_product_name#56, i_brand#57, i_class#58, i_category#59] -Functions [1]: [avg(inv_quantity_on_hand#62)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#62)#17] -Results [2]: [i_product_name#56, avg(inv_quantity_on_hand#62)#17 AS qoh#63] +Input [6]: [i_product_name#60, i_brand#61, i_class#62, i_category#63, sum#64, count#65] +Keys [4]: [i_product_name#60, i_brand#61, i_class#62, i_category#63] +Functions [1]: [avg(inv_quantity_on_hand#66)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#66)#17] +Results [2]: [i_product_name#60, avg(inv_quantity_on_hand#66)#17 AS qoh#67] (36) HashAggregate [codegen id : 22] -Input 
[2]: [i_product_name#56, qoh#63] -Keys [1]: [i_product_name#56] -Functions [1]: [partial_avg(qoh#63)] -Aggregate Attributes [2]: [sum#64, count#65] -Results [3]: [i_product_name#56, sum#66, count#67] +Input [2]: [i_product_name#60, qoh#67] +Keys [1]: [i_product_name#60] +Functions [1]: [partial_avg(qoh#67)] +Aggregate Attributes [2]: [sum#68, count#69] +Results [3]: [i_product_name#60, sum#70, count#71] (37) Exchange -Input [3]: [i_product_name#56, sum#66, count#67] -Arguments: hashpartitioning(i_product_name#56, 5), ENSURE_REQUIREMENTS, [plan_id=6] +Input [3]: [i_product_name#60, sum#70, count#71] +Arguments: hashpartitioning(i_product_name#60, 5), ENSURE_REQUIREMENTS, [plan_id=6] (38) HashAggregate [codegen id : 23] -Input [3]: [i_product_name#56, sum#66, count#67] -Keys [1]: [i_product_name#56] -Functions [1]: [avg(qoh#63)] -Aggregate Attributes [1]: [avg(qoh#63)#68] -Results [5]: [i_product_name#56, null AS i_brand#69, null AS i_class#70, null AS i_category#71, avg(qoh#63)#68 AS qoh#72] +Input [3]: [i_product_name#60, sum#70, count#71] +Keys [1]: [i_product_name#60] +Functions [1]: [avg(qoh#67)] +Aggregate Attributes [1]: [avg(qoh#67)#72] +Results [5]: [i_product_name#60, null AS i_brand#73, null AS i_class#74, null AS i_category#75, avg(qoh#67)#72 AS qoh#76] (39) ReusedExchange [Reuses operator id: 20] -Output [6]: [i_product_name#73, i_brand#74, i_class#75, i_category#76, sum#77, count#78] +Output [6]: [i_product_name#77, i_brand#78, i_class#79, i_category#80, sum#81, count#82] (40) HashAggregate [codegen id : 28] -Input [6]: [i_product_name#73, i_brand#74, i_class#75, i_category#76, sum#77, count#78] -Keys [4]: [i_product_name#73, i_brand#74, i_class#75, i_category#76] -Functions [1]: [avg(inv_quantity_on_hand#79)] -Aggregate Attributes [1]: [avg(inv_quantity_on_hand#79)#17] -Results [1]: [avg(inv_quantity_on_hand#79)#17 AS qoh#80] +Input [6]: [i_product_name#77, i_brand#78, i_class#79, i_category#80, sum#81, count#82] +Keys [4]: [i_product_name#77, i_brand#78, i_class#79, i_category#80] +Functions [1]: [avg(inv_quantity_on_hand#83)] +Aggregate Attributes [1]: [avg(inv_quantity_on_hand#83)#17] +Results [1]: [avg(inv_quantity_on_hand#83)#17 AS qoh#84] (41) HashAggregate [codegen id : 28] -Input [1]: [qoh#80] +Input [1]: [qoh#84] Keys: [] -Functions [1]: [partial_avg(qoh#80)] -Aggregate Attributes [2]: [sum#81, count#82] -Results [2]: [sum#83, count#84] +Functions [1]: [partial_avg(qoh#84)] +Aggregate Attributes [2]: [sum#85, count#86] +Results [2]: [sum#87, count#88] (42) Exchange -Input [2]: [sum#83, count#84] +Input [2]: [sum#87, count#88] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=7] (43) HashAggregate [codegen id : 29] -Input [2]: [sum#83, count#84] +Input [2]: [sum#87, count#88] Keys: [] -Functions [1]: [avg(qoh#80)] -Aggregate Attributes [1]: [avg(qoh#80)#85] -Results [5]: [null AS i_product_name#86, null AS i_brand#87, null AS i_class#88, null AS i_category#89, avg(qoh#80)#85 AS qoh#90] +Functions [1]: [avg(qoh#84)] +Aggregate Attributes [1]: [avg(qoh#84)#89] +Results [5]: [null AS i_product_name#90, null AS i_brand#91, null AS i_class#92, null AS i_category#93, avg(qoh#84)#89 AS qoh#94] (44) Union (45) TakeOrderedAndProject -Input [5]: [i_product_name#11, i_brand#8, i_class#9, i_category#10, qoh#24] -Arguments: 100, [qoh#24 ASC NULLS FIRST, i_product_name#11 ASC NULLS FIRST, i_brand#8 ASC NULLS FIRST, i_class#9 ASC NULLS FIRST, i_category#10 ASC NULLS FIRST], [i_product_name#11, i_brand#8, i_class#9, i_category#10, qoh#24] +Input [5]: [i_product_name#24, 
i_brand#25, i_class#26, i_category#27, qoh#28] +Arguments: 100, [qoh#28 ASC NULLS FIRST, i_product_name#24 ASC NULLS FIRST, i_brand#25 ASC NULLS FIRST, i_class#26 ASC NULLS FIRST, i_category#27 ASC NULLS FIRST], [i_product_name#24, i_brand#25, i_class#26, i_category#27, qoh#28] ===== Subqueries ===== @@ -291,22 +291,22 @@ BroadcastExchange (50) (46) Scan parquet spark_catalog.default.date_dim -Output [2]: [d_date_sk#6, d_month_seq#91] +Output [2]: [d_date_sk#6, d_month_seq#95] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1212), LessThanOrEqual(d_month_seq,1223), IsNotNull(d_date_sk)] ReadSchema: struct (47) ColumnarToRow [codegen id : 1] -Input [2]: [d_date_sk#6, d_month_seq#91] +Input [2]: [d_date_sk#6, d_month_seq#95] (48) Filter [codegen id : 1] -Input [2]: [d_date_sk#6, d_month_seq#91] -Condition : (((isnotnull(d_month_seq#91) AND (d_month_seq#91 >= 1212)) AND (d_month_seq#91 <= 1223)) AND isnotnull(d_date_sk#6)) +Input [2]: [d_date_sk#6, d_month_seq#95] +Condition : (((isnotnull(d_month_seq#95) AND (d_month_seq#95 >= 1212)) AND (d_month_seq#95 <= 1223)) AND isnotnull(d_date_sk#6)) (49) Project [codegen id : 1] Output [1]: [d_date_sk#6] -Input [2]: [d_date_sk#6, d_month_seq#91] +Input [2]: [d_date_sk#6, d_month_seq#95] (50) BroadcastExchange Input [1]: [d_date_sk#6] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/simplified.txt index 22f73cc9b9db5..d747066f5945b 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q22a/simplified.txt @@ -1,7 +1,7 @@ TakeOrderedAndProject [qoh,i_product_name,i_brand,i_class,i_category] Union WholeStageCodegen (5) - HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(qoh),qoh,sum,count] + HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(qoh),i_product_name,i_brand,i_class,i_category,qoh,sum,count] HashAggregate [i_product_name,i_brand,i_class,i_category,qoh] [sum,count,sum,count] HashAggregate [i_product_name,i_brand,i_class,i_category,sum,count] [avg(inv_quantity_on_hand),qoh,sum,count] InputAdapter diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt index 9c28ff9f351d8..a4c009f8219b4 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/explain.txt @@ -186,265 +186,265 @@ Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, Keys [8]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] Functions [1]: [sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))] Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22] -Results [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, cast(sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22 as decimal(38,2)) AS sumsales#23] +Results [9]: [i_category#16 AS i_category#23, i_class#15 AS i_class#24, 
i_brand#14 AS i_brand#25, i_product_name#17 AS i_product_name#26, d_year#8 AS d_year#27, d_qoy#10 AS d_qoy#28, d_moy#9 AS d_moy#29, s_store_id#12 AS s_store_id#30, cast(sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22 as decimal(38,2)) AS sumsales#31] (25) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31, sum#32, isEmpty#33] +Output [10]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39, sum#40, isEmpty#41] (26) HashAggregate [codegen id : 16] -Input [10]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31, sum#32, isEmpty#33] -Keys [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31] -Functions [1]: [sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))#22] -Results [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))#22 AS sumsales#36] +Input [10]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39, sum#40, isEmpty#41] +Keys [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39] +Functions [1]: [sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))#22] +Results [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))#22 AS sumsales#44] (27) HashAggregate [codegen id : 16] -Input [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sumsales#36] -Keys [7]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30] -Functions [1]: [partial_sum(sumsales#36)] -Aggregate Attributes [2]: [sum#37, isEmpty#38] -Results [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] +Input [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sumsales#44] +Keys [7]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38] +Functions [1]: [partial_sum(sumsales#44)] +Aggregate Attributes [2]: [sum#45, isEmpty#46] +Results [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] (28) Exchange -Input [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] -Arguments: hashpartitioning(i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, 5), ENSURE_REQUIREMENTS, [plan_id=5] +Input [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] +Arguments: hashpartitioning(i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, 5), ENSURE_REQUIREMENTS, [plan_id=5] (29) HashAggregate [codegen id : 17] -Input [9]: [i_category#24, i_class#25, i_brand#26, 
i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] -Keys [7]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30] -Functions [1]: [sum(sumsales#36)] -Aggregate Attributes [1]: [sum(sumsales#36)#41] -Results [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, null AS s_store_id#42, sum(sumsales#36)#41 AS sumsales#43] +Input [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] +Keys [7]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38] +Functions [1]: [sum(sumsales#44)] +Aggregate Attributes [1]: [sum(sumsales#44)#49] +Results [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, null AS s_store_id#50, sum(sumsales#44)#49 AS sumsales#51] (30) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51, sum#52, isEmpty#53] +Output [10]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59, sum#60, isEmpty#61] (31) HashAggregate [codegen id : 25] -Input [10]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51, sum#52, isEmpty#53] -Keys [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51] -Functions [1]: [sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))#22] -Results [7]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))#22 AS sumsales#56] +Input [10]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59, sum#60, isEmpty#61] +Keys [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59] +Functions [1]: [sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))#22] +Results [7]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))#22 AS sumsales#64] (32) HashAggregate [codegen id : 25] -Input [7]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sumsales#56] -Keys [6]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49] -Functions [1]: [partial_sum(sumsales#56)] -Aggregate Attributes [2]: [sum#57, isEmpty#58] -Results [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] +Input [7]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sumsales#64] +Keys [6]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57] +Functions [1]: [partial_sum(sumsales#64)] +Aggregate Attributes [2]: [sum#65, isEmpty#66] +Results [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] (33) Exchange -Input [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] -Arguments: 
hashpartitioning(i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, 5), ENSURE_REQUIREMENTS, [plan_id=6] +Input [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] +Arguments: hashpartitioning(i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, 5), ENSURE_REQUIREMENTS, [plan_id=6] (34) HashAggregate [codegen id : 26] -Input [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] -Keys [6]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49] -Functions [1]: [sum(sumsales#56)] -Aggregate Attributes [1]: [sum(sumsales#56)#61] -Results [9]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, null AS d_moy#62, null AS s_store_id#63, sum(sumsales#56)#61 AS sumsales#64] +Input [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] +Keys [6]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57] +Functions [1]: [sum(sumsales#64)] +Aggregate Attributes [1]: [sum(sumsales#64)#69] +Results [9]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, null AS d_moy#70, null AS s_store_id#71, sum(sumsales#64)#69 AS sumsales#72] (35) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72, sum#73, isEmpty#74] +Output [10]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80, sum#81, isEmpty#82] (36) HashAggregate [codegen id : 34] -Input [10]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72, sum#73, isEmpty#74] -Keys [8]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72] -Functions [1]: [sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))#22] -Results [6]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))#22 AS sumsales#77] +Input [10]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80, sum#81, isEmpty#82] +Keys [8]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80] +Functions [1]: [sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))#22] +Results [6]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))#22 AS sumsales#85] (37) HashAggregate [codegen id : 34] -Input [6]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sumsales#77] -Keys [5]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69] -Functions [1]: [partial_sum(sumsales#77)] -Aggregate Attributes [2]: [sum#78, isEmpty#79] -Results [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] +Input [6]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sumsales#85] +Keys [5]: [i_category#73, i_class#74, 
i_brand#75, i_product_name#76, d_year#77] +Functions [1]: [partial_sum(sumsales#85)] +Aggregate Attributes [2]: [sum#86, isEmpty#87] +Results [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] (38) Exchange -Input [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] -Arguments: hashpartitioning(i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, 5), ENSURE_REQUIREMENTS, [plan_id=7] +Input [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] +Arguments: hashpartitioning(i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, 5), ENSURE_REQUIREMENTS, [plan_id=7] (39) HashAggregate [codegen id : 35] -Input [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] -Keys [5]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69] -Functions [1]: [sum(sumsales#77)] -Aggregate Attributes [1]: [sum(sumsales#77)#82] -Results [9]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, null AS d_qoy#83, null AS d_moy#84, null AS s_store_id#85, sum(sumsales#77)#82 AS sumsales#86] +Input [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] +Keys [5]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77] +Functions [1]: [sum(sumsales#85)] +Aggregate Attributes [1]: [sum(sumsales#85)#90] +Results [9]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, null AS d_qoy#91, null AS d_moy#92, null AS s_store_id#93, sum(sumsales#85)#90 AS sumsales#94] (40) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94, sum#95, isEmpty#96] +Output [10]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102, sum#103, isEmpty#104] (41) HashAggregate [codegen id : 43] -Input [10]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94, sum#95, isEmpty#96] -Keys [8]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94] -Functions [1]: [sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))#22] -Results [5]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))#22 AS sumsales#99] +Input [10]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102, sum#103, isEmpty#104] +Keys [8]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102] +Functions [1]: [sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))#22] +Results [5]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))#22 AS sumsales#107] (42) HashAggregate [codegen id : 43] -Input [5]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sumsales#99] -Keys [4]: [i_category#87, i_class#88, i_brand#89, i_product_name#90] -Functions [1]: [partial_sum(sumsales#99)] 
-Aggregate Attributes [2]: [sum#100, isEmpty#101] -Results [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] +Input [5]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sumsales#107] +Keys [4]: [i_category#95, i_class#96, i_brand#97, i_product_name#98] +Functions [1]: [partial_sum(sumsales#107)] +Aggregate Attributes [2]: [sum#108, isEmpty#109] +Results [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] (43) Exchange -Input [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] -Arguments: hashpartitioning(i_category#87, i_class#88, i_brand#89, i_product_name#90, 5), ENSURE_REQUIREMENTS, [plan_id=8] +Input [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] +Arguments: hashpartitioning(i_category#95, i_class#96, i_brand#97, i_product_name#98, 5), ENSURE_REQUIREMENTS, [plan_id=8] (44) HashAggregate [codegen id : 44] -Input [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] -Keys [4]: [i_category#87, i_class#88, i_brand#89, i_product_name#90] -Functions [1]: [sum(sumsales#99)] -Aggregate Attributes [1]: [sum(sumsales#99)#104] -Results [9]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, null AS d_year#105, null AS d_qoy#106, null AS d_moy#107, null AS s_store_id#108, sum(sumsales#99)#104 AS sumsales#109] +Input [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] +Keys [4]: [i_category#95, i_class#96, i_brand#97, i_product_name#98] +Functions [1]: [sum(sumsales#107)] +Aggregate Attributes [1]: [sum(sumsales#107)#112] +Results [9]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, null AS d_year#113, null AS d_qoy#114, null AS d_moy#115, null AS s_store_id#116, sum(sumsales#107)#112 AS sumsales#117] (45) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117, sum#118, isEmpty#119] +Output [10]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125, sum#126, isEmpty#127] (46) HashAggregate [codegen id : 52] -Input [10]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117, sum#118, isEmpty#119] -Keys [8]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117] -Functions [1]: [sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))#22] -Results [4]: [i_category#110, i_class#111, i_brand#112, sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))#22 AS sumsales#122] +Input [10]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125, sum#126, isEmpty#127] +Keys [8]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125] +Functions [1]: [sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))#22] +Results [4]: [i_category#118, i_class#119, i_brand#120, sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))#22 AS sumsales#130] (47) HashAggregate 
[codegen id : 52] -Input [4]: [i_category#110, i_class#111, i_brand#112, sumsales#122] -Keys [3]: [i_category#110, i_class#111, i_brand#112] -Functions [1]: [partial_sum(sumsales#122)] -Aggregate Attributes [2]: [sum#123, isEmpty#124] -Results [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] +Input [4]: [i_category#118, i_class#119, i_brand#120, sumsales#130] +Keys [3]: [i_category#118, i_class#119, i_brand#120] +Functions [1]: [partial_sum(sumsales#130)] +Aggregate Attributes [2]: [sum#131, isEmpty#132] +Results [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] (48) Exchange -Input [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] -Arguments: hashpartitioning(i_category#110, i_class#111, i_brand#112, 5), ENSURE_REQUIREMENTS, [plan_id=9] +Input [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] +Arguments: hashpartitioning(i_category#118, i_class#119, i_brand#120, 5), ENSURE_REQUIREMENTS, [plan_id=9] (49) HashAggregate [codegen id : 53] -Input [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] -Keys [3]: [i_category#110, i_class#111, i_brand#112] -Functions [1]: [sum(sumsales#122)] -Aggregate Attributes [1]: [sum(sumsales#122)#127] -Results [9]: [i_category#110, i_class#111, i_brand#112, null AS i_product_name#128, null AS d_year#129, null AS d_qoy#130, null AS d_moy#131, null AS s_store_id#132, sum(sumsales#122)#127 AS sumsales#133] +Input [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] +Keys [3]: [i_category#118, i_class#119, i_brand#120] +Functions [1]: [sum(sumsales#130)] +Aggregate Attributes [1]: [sum(sumsales#130)#135] +Results [9]: [i_category#118, i_class#119, i_brand#120, null AS i_product_name#136, null AS d_year#137, null AS d_qoy#138, null AS d_moy#139, null AS s_store_id#140, sum(sumsales#130)#135 AS sumsales#141] (50) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141, sum#142, isEmpty#143] +Output [10]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149, sum#150, isEmpty#151] (51) HashAggregate [codegen id : 61] -Input [10]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141, sum#142, isEmpty#143] -Keys [8]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141] -Functions [1]: [sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))#22] -Results [3]: [i_category#134, i_class#135, sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))#22 AS sumsales#146] +Input [10]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149, sum#150, isEmpty#151] +Keys [8]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149] +Functions [1]: [sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))#22] +Results [3]: [i_category#142, i_class#143, sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))#22 AS sumsales#154] (52) HashAggregate 
[codegen id : 61] -Input [3]: [i_category#134, i_class#135, sumsales#146] -Keys [2]: [i_category#134, i_class#135] -Functions [1]: [partial_sum(sumsales#146)] -Aggregate Attributes [2]: [sum#147, isEmpty#148] -Results [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] +Input [3]: [i_category#142, i_class#143, sumsales#154] +Keys [2]: [i_category#142, i_class#143] +Functions [1]: [partial_sum(sumsales#154)] +Aggregate Attributes [2]: [sum#155, isEmpty#156] +Results [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] (53) Exchange -Input [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] -Arguments: hashpartitioning(i_category#134, i_class#135, 5), ENSURE_REQUIREMENTS, [plan_id=10] +Input [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] +Arguments: hashpartitioning(i_category#142, i_class#143, 5), ENSURE_REQUIREMENTS, [plan_id=10] (54) HashAggregate [codegen id : 62] -Input [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] -Keys [2]: [i_category#134, i_class#135] -Functions [1]: [sum(sumsales#146)] -Aggregate Attributes [1]: [sum(sumsales#146)#151] -Results [9]: [i_category#134, i_class#135, null AS i_brand#152, null AS i_product_name#153, null AS d_year#154, null AS d_qoy#155, null AS d_moy#156, null AS s_store_id#157, sum(sumsales#146)#151 AS sumsales#158] +Input [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] +Keys [2]: [i_category#142, i_class#143] +Functions [1]: [sum(sumsales#154)] +Aggregate Attributes [1]: [sum(sumsales#154)#159] +Results [9]: [i_category#142, i_class#143, null AS i_brand#160, null AS i_product_name#161, null AS d_year#162, null AS d_qoy#163, null AS d_moy#164, null AS s_store_id#165, sum(sumsales#154)#159 AS sumsales#166] (55) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166, sum#167, isEmpty#168] +Output [10]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174, sum#175, isEmpty#176] (56) HashAggregate [codegen id : 70] -Input [10]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166, sum#167, isEmpty#168] -Keys [8]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166] -Functions [1]: [sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))#22] -Results [2]: [i_category#159, sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))#22 AS sumsales#171] +Input [10]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174, sum#175, isEmpty#176] +Keys [8]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174] +Functions [1]: [sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))#22] +Results [2]: [i_category#167, sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))#22 AS sumsales#179] (57) HashAggregate [codegen id : 70] -Input [2]: [i_category#159, sumsales#171] -Keys [1]: [i_category#159] -Functions [1]: [partial_sum(sumsales#171)] -Aggregate Attributes [2]: [sum#172, isEmpty#173] -Results [3]: 
[i_category#159, sum#174, isEmpty#175] +Input [2]: [i_category#167, sumsales#179] +Keys [1]: [i_category#167] +Functions [1]: [partial_sum(sumsales#179)] +Aggregate Attributes [2]: [sum#180, isEmpty#181] +Results [3]: [i_category#167, sum#182, isEmpty#183] (58) Exchange -Input [3]: [i_category#159, sum#174, isEmpty#175] -Arguments: hashpartitioning(i_category#159, 5), ENSURE_REQUIREMENTS, [plan_id=11] +Input [3]: [i_category#167, sum#182, isEmpty#183] +Arguments: hashpartitioning(i_category#167, 5), ENSURE_REQUIREMENTS, [plan_id=11] (59) HashAggregate [codegen id : 71] -Input [3]: [i_category#159, sum#174, isEmpty#175] -Keys [1]: [i_category#159] -Functions [1]: [sum(sumsales#171)] -Aggregate Attributes [1]: [sum(sumsales#171)#176] -Results [9]: [i_category#159, null AS i_class#177, null AS i_brand#178, null AS i_product_name#179, null AS d_year#180, null AS d_qoy#181, null AS d_moy#182, null AS s_store_id#183, sum(sumsales#171)#176 AS sumsales#184] +Input [3]: [i_category#167, sum#182, isEmpty#183] +Keys [1]: [i_category#167] +Functions [1]: [sum(sumsales#179)] +Aggregate Attributes [1]: [sum(sumsales#179)#184] +Results [9]: [i_category#167, null AS i_class#185, null AS i_brand#186, null AS i_product_name#187, null AS d_year#188, null AS d_qoy#189, null AS d_moy#190, null AS s_store_id#191, sum(sumsales#179)#184 AS sumsales#192] (60) ReusedExchange [Reuses operator id: 23] -Output [10]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192, sum#193, isEmpty#194] +Output [10]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200, sum#201, isEmpty#202] (61) HashAggregate [codegen id : 79] -Input [10]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192, sum#193, isEmpty#194] -Keys [8]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192] -Functions [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))#22] -Results [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))#22 AS sumsales#197] +Input [10]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200, sum#201, isEmpty#202] +Keys [8]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200] +Functions [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))#22] +Results [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))#22 AS sumsales#205] (62) HashAggregate [codegen id : 79] -Input [1]: [sumsales#197] +Input [1]: [sumsales#205] Keys: [] -Functions [1]: [partial_sum(sumsales#197)] -Aggregate Attributes [2]: [sum#198, isEmpty#199] -Results [2]: [sum#200, isEmpty#201] +Functions [1]: [partial_sum(sumsales#205)] +Aggregate Attributes [2]: [sum#206, isEmpty#207] +Results [2]: [sum#208, isEmpty#209] (63) Exchange -Input [2]: [sum#200, isEmpty#201] +Input [2]: [sum#208, isEmpty#209] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=12] (64) HashAggregate [codegen id : 80] -Input [2]: [sum#200, isEmpty#201] +Input [2]: [sum#208, 
isEmpty#209] Keys: [] -Functions [1]: [sum(sumsales#197)] -Aggregate Attributes [1]: [sum(sumsales#197)#202] -Results [9]: [null AS i_category#203, null AS i_class#204, null AS i_brand#205, null AS i_product_name#206, null AS d_year#207, null AS d_qoy#208, null AS d_moy#209, null AS s_store_id#210, sum(sumsales#197)#202 AS sumsales#211] +Functions [1]: [sum(sumsales#205)] +Aggregate Attributes [1]: [sum(sumsales#205)#210] +Results [9]: [null AS i_category#211, null AS i_class#212, null AS i_brand#213, null AS i_product_name#214, null AS d_year#215, null AS d_qoy#216, null AS d_moy#217, null AS s_store_id#218, sum(sumsales#205)#210 AS sumsales#219] (65) Union (66) Sort [codegen id : 81] -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16 ASC NULLS FIRST, sumsales#23 DESC NULLS LAST], false, 0 +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23 ASC NULLS FIRST, sumsales#31 DESC NULLS LAST], false, 0 (67) WindowGroupLimit -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16], [sumsales#23 DESC NULLS LAST], rank(sumsales#23), 100, Partial +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23], [sumsales#31 DESC NULLS LAST], rank(sumsales#31), 100, Partial (68) Exchange -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: hashpartitioning(i_category#16, 5), ENSURE_REQUIREMENTS, [plan_id=13] +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: hashpartitioning(i_category#23, 5), ENSURE_REQUIREMENTS, [plan_id=13] (69) Sort [codegen id : 82] -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16 ASC NULLS FIRST, sumsales#23 DESC NULLS LAST], false, 0 +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23 ASC NULLS FIRST, sumsales#31 DESC NULLS LAST], false, 0 (70) WindowGroupLimit -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16], [sumsales#23 DESC NULLS LAST], rank(sumsales#23), 100, Final +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23], [sumsales#31 DESC NULLS LAST], rank(sumsales#31), 100, Final (71) Window -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [rank(sumsales#23) windowspecdefinition(i_category#16, sumsales#23 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#212], [i_category#16], [sumsales#23 DESC NULLS LAST] +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [rank(sumsales#31) windowspecdefinition(i_category#23, sumsales#31 DESC NULLS LAST, specifiedwindowframe(RowFrame, 
unboundedpreceding$(), currentrow$())) AS rk#220], [i_category#23], [sumsales#31 DESC NULLS LAST] (72) Filter [codegen id : 83] -Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] -Condition : (rk#212 <= 100) +Input [10]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] +Condition : (rk#220 <= 100) (73) TakeOrderedAndProject -Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] -Arguments: 100, [i_category#16 ASC NULLS FIRST, i_class#15 ASC NULLS FIRST, i_brand#14 ASC NULLS FIRST, i_product_name#17 ASC NULLS FIRST, d_year#8 ASC NULLS FIRST, d_qoy#10 ASC NULLS FIRST, d_moy#9 ASC NULLS FIRST, s_store_id#12 ASC NULLS FIRST, sumsales#23 ASC NULLS FIRST, rk#212 ASC NULLS FIRST], [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] +Input [10]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] +Arguments: 100, [i_category#23 ASC NULLS FIRST, i_class#24 ASC NULLS FIRST, i_brand#25 ASC NULLS FIRST, i_product_name#26 ASC NULLS FIRST, d_year#27 ASC NULLS FIRST, d_qoy#28 ASC NULLS FIRST, d_moy#29 ASC NULLS FIRST, s_store_id#30 ASC NULLS FIRST, sumsales#31 ASC NULLS FIRST, rk#220 ASC NULLS FIRST], [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] ===== Subqueries ===== @@ -457,22 +457,22 @@ BroadcastExchange (78) (74) Scan parquet spark_catalog.default.date_dim -Output [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] +Output [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1212), LessThanOrEqual(d_month_seq,1223), IsNotNull(d_date_sk)] ReadSchema: struct (75) ColumnarToRow [codegen id : 1] -Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] +Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] (76) Filter [codegen id : 1] -Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] -Condition : (((isnotnull(d_month_seq#213) AND (d_month_seq#213 >= 1212)) AND (d_month_seq#213 <= 1223)) AND isnotnull(d_date_sk#7)) +Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] +Condition : (((isnotnull(d_month_seq#221) AND (d_month_seq#221 >= 1212)) AND (d_month_seq#221 <= 1223)) AND isnotnull(d_date_sk#7)) (77) Project [codegen id : 1] Output [4]: [d_date_sk#7, d_year#8, d_moy#9, d_qoy#10] -Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] +Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] (78) BroadcastExchange Input [4]: [d_date_sk#7, d_year#8, d_moy#9, d_qoy#10] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt index 795fa297b9bad..b6a4358c4d43b 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a.sf100/simplified.txt @@ -14,7 +14,7 @@ TakeOrderedAndProject 
[i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ InputAdapter Union WholeStageCodegen (8) - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce((ss_sales_price * cast(ss_quantity as decimal(10,0))), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce((ss_sales_price * cast(ss_quantity as decimal(10,0))), 0.00)),i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sumsales,sum,isEmpty] InputAdapter Exchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id] #2 WholeStageCodegen (7) diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt index 75d526da4ba71..417af4fe924ee 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/explain.txt @@ -171,265 +171,265 @@ Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, Keys [8]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12] Functions [1]: [sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))] Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22] -Results [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, cast(sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22 as decimal(38,2)) AS sumsales#23] +Results [9]: [i_category#16 AS i_category#23, i_class#15 AS i_class#24, i_brand#14 AS i_brand#25, i_product_name#17 AS i_product_name#26, d_year#8 AS d_year#27, d_qoy#10 AS d_qoy#28, d_moy#9 AS d_moy#29, s_store_id#12 AS s_store_id#30, cast(sum(coalesce((ss_sales_price#4 * cast(ss_quantity#3 as decimal(10,0))), 0.00))#22 as decimal(38,2)) AS sumsales#31] (22) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31, sum#32, isEmpty#33] +Output [10]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39, sum#40, isEmpty#41] (23) HashAggregate [codegen id : 10] -Input [10]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31, sum#32, isEmpty#33] -Keys [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, s_store_id#31] -Functions [1]: [sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))#22] -Results [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum(coalesce((ss_sales_price#34 * cast(ss_quantity#35 as decimal(10,0))), 0.00))#22 AS sumsales#36] +Input [10]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39, sum#40, isEmpty#41] +Keys [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, s_store_id#39] +Functions [1]: [sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: 
[sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))#22] +Results [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum(coalesce((ss_sales_price#42 * cast(ss_quantity#43 as decimal(10,0))), 0.00))#22 AS sumsales#44] (24) HashAggregate [codegen id : 10] -Input [8]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sumsales#36] -Keys [7]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30] -Functions [1]: [partial_sum(sumsales#36)] -Aggregate Attributes [2]: [sum#37, isEmpty#38] -Results [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] +Input [8]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sumsales#44] +Keys [7]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38] +Functions [1]: [partial_sum(sumsales#44)] +Aggregate Attributes [2]: [sum#45, isEmpty#46] +Results [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] (25) Exchange -Input [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] -Arguments: hashpartitioning(i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, 5), ENSURE_REQUIREMENTS, [plan_id=4] +Input [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] +Arguments: hashpartitioning(i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, 5), ENSURE_REQUIREMENTS, [plan_id=4] (26) HashAggregate [codegen id : 11] -Input [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, sum#39, isEmpty#40] -Keys [7]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30] -Functions [1]: [sum(sumsales#36)] -Aggregate Attributes [1]: [sum(sumsales#36)#41] -Results [9]: [i_category#24, i_class#25, i_brand#26, i_product_name#27, d_year#28, d_qoy#29, d_moy#30, null AS s_store_id#42, sum(sumsales#36)#41 AS sumsales#43] +Input [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, sum#47, isEmpty#48] +Keys [7]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38] +Functions [1]: [sum(sumsales#44)] +Aggregate Attributes [1]: [sum(sumsales#44)#49] +Results [9]: [i_category#32, i_class#33, i_brand#34, i_product_name#35, d_year#36, d_qoy#37, d_moy#38, null AS s_store_id#50, sum(sumsales#44)#49 AS sumsales#51] (27) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51, sum#52, isEmpty#53] +Output [10]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59, sum#60, isEmpty#61] (28) HashAggregate [codegen id : 16] -Input [10]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51, sum#52, isEmpty#53] -Keys [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, d_moy#50, s_store_id#51] -Functions [1]: [sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 
0.00))#22] -Results [7]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum(coalesce((ss_sales_price#54 * cast(ss_quantity#55 as decimal(10,0))), 0.00))#22 AS sumsales#56] +Input [10]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59, sum#60, isEmpty#61] +Keys [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, d_moy#58, s_store_id#59] +Functions [1]: [sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))#22] +Results [7]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum(coalesce((ss_sales_price#62 * cast(ss_quantity#63 as decimal(10,0))), 0.00))#22 AS sumsales#64] (29) HashAggregate [codegen id : 16] -Input [7]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sumsales#56] -Keys [6]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49] -Functions [1]: [partial_sum(sumsales#56)] -Aggregate Attributes [2]: [sum#57, isEmpty#58] -Results [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] +Input [7]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sumsales#64] +Keys [6]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57] +Functions [1]: [partial_sum(sumsales#64)] +Aggregate Attributes [2]: [sum#65, isEmpty#66] +Results [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] (30) Exchange -Input [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] -Arguments: hashpartitioning(i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, 5), ENSURE_REQUIREMENTS, [plan_id=5] +Input [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] +Arguments: hashpartitioning(i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, 5), ENSURE_REQUIREMENTS, [plan_id=5] (31) HashAggregate [codegen id : 17] -Input [8]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, sum#59, isEmpty#60] -Keys [6]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49] -Functions [1]: [sum(sumsales#56)] -Aggregate Attributes [1]: [sum(sumsales#56)#61] -Results [9]: [i_category#44, i_class#45, i_brand#46, i_product_name#47, d_year#48, d_qoy#49, null AS d_moy#62, null AS s_store_id#63, sum(sumsales#56)#61 AS sumsales#64] +Input [8]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, sum#67, isEmpty#68] +Keys [6]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57] +Functions [1]: [sum(sumsales#64)] +Aggregate Attributes [1]: [sum(sumsales#64)#69] +Results [9]: [i_category#52, i_class#53, i_brand#54, i_product_name#55, d_year#56, d_qoy#57, null AS d_moy#70, null AS s_store_id#71, sum(sumsales#64)#69 AS sumsales#72] (32) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72, sum#73, isEmpty#74] +Output [10]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80, sum#81, isEmpty#82] (33) HashAggregate [codegen id : 22] -Input 
[10]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72, sum#73, isEmpty#74] -Keys [8]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, d_qoy#70, d_moy#71, s_store_id#72] -Functions [1]: [sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))#22] -Results [6]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum(coalesce((ss_sales_price#75 * cast(ss_quantity#76 as decimal(10,0))), 0.00))#22 AS sumsales#77] +Input [10]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80, sum#81, isEmpty#82] +Keys [8]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, d_qoy#78, d_moy#79, s_store_id#80] +Functions [1]: [sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))#22] +Results [6]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum(coalesce((ss_sales_price#83 * cast(ss_quantity#84 as decimal(10,0))), 0.00))#22 AS sumsales#85] (34) HashAggregate [codegen id : 22] -Input [6]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sumsales#77] -Keys [5]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69] -Functions [1]: [partial_sum(sumsales#77)] -Aggregate Attributes [2]: [sum#78, isEmpty#79] -Results [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] +Input [6]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sumsales#85] +Keys [5]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77] +Functions [1]: [partial_sum(sumsales#85)] +Aggregate Attributes [2]: [sum#86, isEmpty#87] +Results [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] (35) Exchange -Input [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] -Arguments: hashpartitioning(i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, 5), ENSURE_REQUIREMENTS, [plan_id=6] +Input [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] +Arguments: hashpartitioning(i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, 5), ENSURE_REQUIREMENTS, [plan_id=6] (36) HashAggregate [codegen id : 23] -Input [7]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, sum#80, isEmpty#81] -Keys [5]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69] -Functions [1]: [sum(sumsales#77)] -Aggregate Attributes [1]: [sum(sumsales#77)#82] -Results [9]: [i_category#65, i_class#66, i_brand#67, i_product_name#68, d_year#69, null AS d_qoy#83, null AS d_moy#84, null AS s_store_id#85, sum(sumsales#77)#82 AS sumsales#86] +Input [7]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, sum#88, isEmpty#89] +Keys [5]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77] +Functions [1]: [sum(sumsales#85)] +Aggregate Attributes [1]: [sum(sumsales#85)#90] +Results [9]: [i_category#73, i_class#74, i_brand#75, i_product_name#76, d_year#77, null AS d_qoy#91, null AS d_moy#92, null AS s_store_id#93, sum(sumsales#85)#90 AS sumsales#94] (37) ReusedExchange [Reuses operator id: 20] -Output [10]: 
[i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94, sum#95, isEmpty#96] +Output [10]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102, sum#103, isEmpty#104] (38) HashAggregate [codegen id : 28] -Input [10]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94, sum#95, isEmpty#96] -Keys [8]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, d_year#91, d_qoy#92, d_moy#93, s_store_id#94] -Functions [1]: [sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))#22] -Results [5]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum(coalesce((ss_sales_price#97 * cast(ss_quantity#98 as decimal(10,0))), 0.00))#22 AS sumsales#99] +Input [10]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102, sum#103, isEmpty#104] +Keys [8]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, d_year#99, d_qoy#100, d_moy#101, s_store_id#102] +Functions [1]: [sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))#22] +Results [5]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum(coalesce((ss_sales_price#105 * cast(ss_quantity#106 as decimal(10,0))), 0.00))#22 AS sumsales#107] (39) HashAggregate [codegen id : 28] -Input [5]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sumsales#99] -Keys [4]: [i_category#87, i_class#88, i_brand#89, i_product_name#90] -Functions [1]: [partial_sum(sumsales#99)] -Aggregate Attributes [2]: [sum#100, isEmpty#101] -Results [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] +Input [5]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sumsales#107] +Keys [4]: [i_category#95, i_class#96, i_brand#97, i_product_name#98] +Functions [1]: [partial_sum(sumsales#107)] +Aggregate Attributes [2]: [sum#108, isEmpty#109] +Results [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] (40) Exchange -Input [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] -Arguments: hashpartitioning(i_category#87, i_class#88, i_brand#89, i_product_name#90, 5), ENSURE_REQUIREMENTS, [plan_id=7] +Input [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] +Arguments: hashpartitioning(i_category#95, i_class#96, i_brand#97, i_product_name#98, 5), ENSURE_REQUIREMENTS, [plan_id=7] (41) HashAggregate [codegen id : 29] -Input [6]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, sum#102, isEmpty#103] -Keys [4]: [i_category#87, i_class#88, i_brand#89, i_product_name#90] -Functions [1]: [sum(sumsales#99)] -Aggregate Attributes [1]: [sum(sumsales#99)#104] -Results [9]: [i_category#87, i_class#88, i_brand#89, i_product_name#90, null AS d_year#105, null AS d_qoy#106, null AS d_moy#107, null AS s_store_id#108, sum(sumsales#99)#104 AS sumsales#109] +Input [6]: [i_category#95, i_class#96, i_brand#97, i_product_name#98, sum#110, isEmpty#111] +Keys [4]: [i_category#95, i_class#96, i_brand#97, i_product_name#98] +Functions [1]: [sum(sumsales#107)] +Aggregate Attributes [1]: [sum(sumsales#107)#112] +Results [9]: [i_category#95, i_class#96, i_brand#97, 
i_product_name#98, null AS d_year#113, null AS d_qoy#114, null AS d_moy#115, null AS s_store_id#116, sum(sumsales#107)#112 AS sumsales#117] (42) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117, sum#118, isEmpty#119] +Output [10]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125, sum#126, isEmpty#127] (43) HashAggregate [codegen id : 34] -Input [10]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117, sum#118, isEmpty#119] -Keys [8]: [i_category#110, i_class#111, i_brand#112, i_product_name#113, d_year#114, d_qoy#115, d_moy#116, s_store_id#117] -Functions [1]: [sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))#22] -Results [4]: [i_category#110, i_class#111, i_brand#112, sum(coalesce((ss_sales_price#120 * cast(ss_quantity#121 as decimal(10,0))), 0.00))#22 AS sumsales#122] +Input [10]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125, sum#126, isEmpty#127] +Keys [8]: [i_category#118, i_class#119, i_brand#120, i_product_name#121, d_year#122, d_qoy#123, d_moy#124, s_store_id#125] +Functions [1]: [sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))#22] +Results [4]: [i_category#118, i_class#119, i_brand#120, sum(coalesce((ss_sales_price#128 * cast(ss_quantity#129 as decimal(10,0))), 0.00))#22 AS sumsales#130] (44) HashAggregate [codegen id : 34] -Input [4]: [i_category#110, i_class#111, i_brand#112, sumsales#122] -Keys [3]: [i_category#110, i_class#111, i_brand#112] -Functions [1]: [partial_sum(sumsales#122)] -Aggregate Attributes [2]: [sum#123, isEmpty#124] -Results [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] +Input [4]: [i_category#118, i_class#119, i_brand#120, sumsales#130] +Keys [3]: [i_category#118, i_class#119, i_brand#120] +Functions [1]: [partial_sum(sumsales#130)] +Aggregate Attributes [2]: [sum#131, isEmpty#132] +Results [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] (45) Exchange -Input [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] -Arguments: hashpartitioning(i_category#110, i_class#111, i_brand#112, 5), ENSURE_REQUIREMENTS, [plan_id=8] +Input [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] +Arguments: hashpartitioning(i_category#118, i_class#119, i_brand#120, 5), ENSURE_REQUIREMENTS, [plan_id=8] (46) HashAggregate [codegen id : 35] -Input [5]: [i_category#110, i_class#111, i_brand#112, sum#125, isEmpty#126] -Keys [3]: [i_category#110, i_class#111, i_brand#112] -Functions [1]: [sum(sumsales#122)] -Aggregate Attributes [1]: [sum(sumsales#122)#127] -Results [9]: [i_category#110, i_class#111, i_brand#112, null AS i_product_name#128, null AS d_year#129, null AS d_qoy#130, null AS d_moy#131, null AS s_store_id#132, sum(sumsales#122)#127 AS sumsales#133] +Input [5]: [i_category#118, i_class#119, i_brand#120, sum#133, isEmpty#134] +Keys [3]: [i_category#118, i_class#119, i_brand#120] +Functions [1]: [sum(sumsales#130)] +Aggregate Attributes [1]: [sum(sumsales#130)#135] +Results [9]: [i_category#118, i_class#119, 
i_brand#120, null AS i_product_name#136, null AS d_year#137, null AS d_qoy#138, null AS d_moy#139, null AS s_store_id#140, sum(sumsales#130)#135 AS sumsales#141] (47) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141, sum#142, isEmpty#143] +Output [10]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149, sum#150, isEmpty#151] (48) HashAggregate [codegen id : 40] -Input [10]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141, sum#142, isEmpty#143] -Keys [8]: [i_category#134, i_class#135, i_brand#136, i_product_name#137, d_year#138, d_qoy#139, d_moy#140, s_store_id#141] -Functions [1]: [sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))#22] -Results [3]: [i_category#134, i_class#135, sum(coalesce((ss_sales_price#144 * cast(ss_quantity#145 as decimal(10,0))), 0.00))#22 AS sumsales#146] +Input [10]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149, sum#150, isEmpty#151] +Keys [8]: [i_category#142, i_class#143, i_brand#144, i_product_name#145, d_year#146, d_qoy#147, d_moy#148, s_store_id#149] +Functions [1]: [sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))#22] +Results [3]: [i_category#142, i_class#143, sum(coalesce((ss_sales_price#152 * cast(ss_quantity#153 as decimal(10,0))), 0.00))#22 AS sumsales#154] (49) HashAggregate [codegen id : 40] -Input [3]: [i_category#134, i_class#135, sumsales#146] -Keys [2]: [i_category#134, i_class#135] -Functions [1]: [partial_sum(sumsales#146)] -Aggregate Attributes [2]: [sum#147, isEmpty#148] -Results [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] +Input [3]: [i_category#142, i_class#143, sumsales#154] +Keys [2]: [i_category#142, i_class#143] +Functions [1]: [partial_sum(sumsales#154)] +Aggregate Attributes [2]: [sum#155, isEmpty#156] +Results [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] (50) Exchange -Input [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] -Arguments: hashpartitioning(i_category#134, i_class#135, 5), ENSURE_REQUIREMENTS, [plan_id=9] +Input [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] +Arguments: hashpartitioning(i_category#142, i_class#143, 5), ENSURE_REQUIREMENTS, [plan_id=9] (51) HashAggregate [codegen id : 41] -Input [4]: [i_category#134, i_class#135, sum#149, isEmpty#150] -Keys [2]: [i_category#134, i_class#135] -Functions [1]: [sum(sumsales#146)] -Aggregate Attributes [1]: [sum(sumsales#146)#151] -Results [9]: [i_category#134, i_class#135, null AS i_brand#152, null AS i_product_name#153, null AS d_year#154, null AS d_qoy#155, null AS d_moy#156, null AS s_store_id#157, sum(sumsales#146)#151 AS sumsales#158] +Input [4]: [i_category#142, i_class#143, sum#157, isEmpty#158] +Keys [2]: [i_category#142, i_class#143] +Functions [1]: [sum(sumsales#154)] +Aggregate Attributes [1]: [sum(sumsales#154)#159] +Results [9]: [i_category#142, i_class#143, null AS i_brand#160, null AS i_product_name#161, null AS d_year#162, null AS d_qoy#163, null AS d_moy#164, null AS s_store_id#165, sum(sumsales#154)#159 AS sumsales#166] (52) 
ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166, sum#167, isEmpty#168] +Output [10]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174, sum#175, isEmpty#176] (53) HashAggregate [codegen id : 46] -Input [10]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166, sum#167, isEmpty#168] -Keys [8]: [i_category#159, i_class#160, i_brand#161, i_product_name#162, d_year#163, d_qoy#164, d_moy#165, s_store_id#166] -Functions [1]: [sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))#22] -Results [2]: [i_category#159, sum(coalesce((ss_sales_price#169 * cast(ss_quantity#170 as decimal(10,0))), 0.00))#22 AS sumsales#171] +Input [10]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174, sum#175, isEmpty#176] +Keys [8]: [i_category#167, i_class#168, i_brand#169, i_product_name#170, d_year#171, d_qoy#172, d_moy#173, s_store_id#174] +Functions [1]: [sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))#22] +Results [2]: [i_category#167, sum(coalesce((ss_sales_price#177 * cast(ss_quantity#178 as decimal(10,0))), 0.00))#22 AS sumsales#179] (54) HashAggregate [codegen id : 46] -Input [2]: [i_category#159, sumsales#171] -Keys [1]: [i_category#159] -Functions [1]: [partial_sum(sumsales#171)] -Aggregate Attributes [2]: [sum#172, isEmpty#173] -Results [3]: [i_category#159, sum#174, isEmpty#175] +Input [2]: [i_category#167, sumsales#179] +Keys [1]: [i_category#167] +Functions [1]: [partial_sum(sumsales#179)] +Aggregate Attributes [2]: [sum#180, isEmpty#181] +Results [3]: [i_category#167, sum#182, isEmpty#183] (55) Exchange -Input [3]: [i_category#159, sum#174, isEmpty#175] -Arguments: hashpartitioning(i_category#159, 5), ENSURE_REQUIREMENTS, [plan_id=10] +Input [3]: [i_category#167, sum#182, isEmpty#183] +Arguments: hashpartitioning(i_category#167, 5), ENSURE_REQUIREMENTS, [plan_id=10] (56) HashAggregate [codegen id : 47] -Input [3]: [i_category#159, sum#174, isEmpty#175] -Keys [1]: [i_category#159] -Functions [1]: [sum(sumsales#171)] -Aggregate Attributes [1]: [sum(sumsales#171)#176] -Results [9]: [i_category#159, null AS i_class#177, null AS i_brand#178, null AS i_product_name#179, null AS d_year#180, null AS d_qoy#181, null AS d_moy#182, null AS s_store_id#183, sum(sumsales#171)#176 AS sumsales#184] +Input [3]: [i_category#167, sum#182, isEmpty#183] +Keys [1]: [i_category#167] +Functions [1]: [sum(sumsales#179)] +Aggregate Attributes [1]: [sum(sumsales#179)#184] +Results [9]: [i_category#167, null AS i_class#185, null AS i_brand#186, null AS i_product_name#187, null AS d_year#188, null AS d_qoy#189, null AS d_moy#190, null AS s_store_id#191, sum(sumsales#179)#184 AS sumsales#192] (57) ReusedExchange [Reuses operator id: 20] -Output [10]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192, sum#193, isEmpty#194] +Output [10]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200, sum#201, isEmpty#202] (58) HashAggregate 
[codegen id : 52] -Input [10]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192, sum#193, isEmpty#194] -Keys [8]: [i_category#185, i_class#186, i_brand#187, i_product_name#188, d_year#189, d_qoy#190, d_moy#191, s_store_id#192] -Functions [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))] -Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))#22] -Results [1]: [sum(coalesce((ss_sales_price#195 * cast(ss_quantity#196 as decimal(10,0))), 0.00))#22 AS sumsales#197] +Input [10]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200, sum#201, isEmpty#202] +Keys [8]: [i_category#193, i_class#194, i_brand#195, i_product_name#196, d_year#197, d_qoy#198, d_moy#199, s_store_id#200] +Functions [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))] +Aggregate Attributes [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))#22] +Results [1]: [sum(coalesce((ss_sales_price#203 * cast(ss_quantity#204 as decimal(10,0))), 0.00))#22 AS sumsales#205] (59) HashAggregate [codegen id : 52] -Input [1]: [sumsales#197] +Input [1]: [sumsales#205] Keys: [] -Functions [1]: [partial_sum(sumsales#197)] -Aggregate Attributes [2]: [sum#198, isEmpty#199] -Results [2]: [sum#200, isEmpty#201] +Functions [1]: [partial_sum(sumsales#205)] +Aggregate Attributes [2]: [sum#206, isEmpty#207] +Results [2]: [sum#208, isEmpty#209] (60) Exchange -Input [2]: [sum#200, isEmpty#201] +Input [2]: [sum#208, isEmpty#209] Arguments: SinglePartition, ENSURE_REQUIREMENTS, [plan_id=11] (61) HashAggregate [codegen id : 53] -Input [2]: [sum#200, isEmpty#201] +Input [2]: [sum#208, isEmpty#209] Keys: [] -Functions [1]: [sum(sumsales#197)] -Aggregate Attributes [1]: [sum(sumsales#197)#202] -Results [9]: [null AS i_category#203, null AS i_class#204, null AS i_brand#205, null AS i_product_name#206, null AS d_year#207, null AS d_qoy#208, null AS d_moy#209, null AS s_store_id#210, sum(sumsales#197)#202 AS sumsales#211] +Functions [1]: [sum(sumsales#205)] +Aggregate Attributes [1]: [sum(sumsales#205)#210] +Results [9]: [null AS i_category#211, null AS i_class#212, null AS i_brand#213, null AS i_product_name#214, null AS d_year#215, null AS d_qoy#216, null AS d_moy#217, null AS s_store_id#218, sum(sumsales#205)#210 AS sumsales#219] (62) Union (63) Sort [codegen id : 54] -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16 ASC NULLS FIRST, sumsales#23 DESC NULLS LAST], false, 0 +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23 ASC NULLS FIRST, sumsales#31 DESC NULLS LAST], false, 0 (64) WindowGroupLimit -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16], [sumsales#23 DESC NULLS LAST], rank(sumsales#23), 100, Partial +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23], [sumsales#31 DESC NULLS LAST], rank(sumsales#31), 100, Partial (65) Exchange -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, 
sumsales#23] -Arguments: hashpartitioning(i_category#16, 5), ENSURE_REQUIREMENTS, [plan_id=12] +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: hashpartitioning(i_category#23, 5), ENSURE_REQUIREMENTS, [plan_id=12] (66) Sort [codegen id : 55] -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16 ASC NULLS FIRST, sumsales#23 DESC NULLS LAST], false, 0 +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23 ASC NULLS FIRST, sumsales#31 DESC NULLS LAST], false, 0 (67) WindowGroupLimit -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [i_category#16], [sumsales#23 DESC NULLS LAST], rank(sumsales#23), 100, Final +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [i_category#23], [sumsales#31 DESC NULLS LAST], rank(sumsales#31), 100, Final (68) Window -Input [9]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23] -Arguments: [rank(sumsales#23) windowspecdefinition(i_category#16, sumsales#23 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#212], [i_category#16], [sumsales#23 DESC NULLS LAST] +Input [9]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31] +Arguments: [rank(sumsales#31) windowspecdefinition(i_category#23, sumsales#31 DESC NULLS LAST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rk#220], [i_category#23], [sumsales#31 DESC NULLS LAST] (69) Filter [codegen id : 56] -Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] -Condition : (rk#212 <= 100) +Input [10]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] +Condition : (rk#220 <= 100) (70) TakeOrderedAndProject -Input [10]: [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] -Arguments: 100, [i_category#16 ASC NULLS FIRST, i_class#15 ASC NULLS FIRST, i_brand#14 ASC NULLS FIRST, i_product_name#17 ASC NULLS FIRST, d_year#8 ASC NULLS FIRST, d_qoy#10 ASC NULLS FIRST, d_moy#9 ASC NULLS FIRST, s_store_id#12 ASC NULLS FIRST, sumsales#23 ASC NULLS FIRST, rk#212 ASC NULLS FIRST], [i_category#16, i_class#15, i_brand#14, i_product_name#17, d_year#8, d_qoy#10, d_moy#9, s_store_id#12, sumsales#23, rk#212] +Input [10]: [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] +Arguments: 100, [i_category#23 ASC NULLS FIRST, i_class#24 ASC NULLS FIRST, i_brand#25 ASC NULLS FIRST, i_product_name#26 ASC NULLS FIRST, d_year#27 ASC NULLS FIRST, d_qoy#28 ASC NULLS FIRST, d_moy#29 ASC NULLS FIRST, s_store_id#30 ASC NULLS FIRST, sumsales#31 ASC NULLS FIRST, rk#220 ASC NULLS FIRST], [i_category#23, i_class#24, i_brand#25, i_product_name#26, d_year#27, d_qoy#28, d_moy#29, s_store_id#30, sumsales#31, rk#220] ===== Subqueries ===== @@ -442,22 +442,22 @@ BroadcastExchange (75) (71) Scan parquet 
spark_catalog.default.date_dim -Output [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] +Output [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] Batched: true Location [not included in comparison]/{warehouse_dir}/date_dim] PushedFilters: [IsNotNull(d_month_seq), GreaterThanOrEqual(d_month_seq,1212), LessThanOrEqual(d_month_seq,1223), IsNotNull(d_date_sk)] ReadSchema: struct (72) ColumnarToRow [codegen id : 1] -Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] +Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] (73) Filter [codegen id : 1] -Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] -Condition : (((isnotnull(d_month_seq#213) AND (d_month_seq#213 >= 1212)) AND (d_month_seq#213 <= 1223)) AND isnotnull(d_date_sk#7)) +Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] +Condition : (((isnotnull(d_month_seq#221) AND (d_month_seq#221 >= 1212)) AND (d_month_seq#221 <= 1223)) AND isnotnull(d_date_sk#7)) (74) Project [codegen id : 1] Output [4]: [d_date_sk#7, d_year#8, d_moy#9, d_qoy#10] -Input [5]: [d_date_sk#7, d_month_seq#213, d_year#8, d_moy#9, d_qoy#10] +Input [5]: [d_date_sk#7, d_month_seq#221, d_year#8, d_moy#9, d_qoy#10] (75) BroadcastExchange Input [4]: [d_date_sk#7, d_year#8, d_moy#9, d_qoy#10] diff --git a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt index 89393f265a49f..5a43dced056bd 100644 --- a/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt +++ b/sql/core/src/test/resources/tpcds-plan-stability/approved-plans-v2_7/q67a/simplified.txt @@ -14,7 +14,7 @@ TakeOrderedAndProject [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_ InputAdapter Union WholeStageCodegen (5) - HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce((ss_sales_price * cast(ss_quantity as decimal(10,0))), 0.00)),sumsales,sum,isEmpty] + HashAggregate [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sum,isEmpty] [sum(coalesce((ss_sales_price * cast(ss_quantity as decimal(10,0))), 0.00)),i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id,sumsales,sum,isEmpty] InputAdapter Exchange [i_category,i_class,i_brand,i_product_name,d_year,d_qoy,d_moy,s_store_id] #2 WholeStageCodegen (4) From 6ed4bdf7fdbaa4699ce75de95dc94afe61483b77 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Thu, 10 Oct 2024 08:23:55 +0900 Subject: [PATCH 198/250] [MINOR][PYTHON][TESTS] Reduce the python worker error log of `test_toDF_with_schema_string` ### What changes were proposed in this pull request? Reduce the python worker error log of `test_toDF_with_schema_string` ### Why are the changes needed? When I run the test locally ```python python/run-tests -k --python-executables python3 --testnames 'pyspark.sql.tests.test_dataframe' ``` Two assertions in `test_toDF_with_schema_string` generate too many python worker error logs (~1k lines), which easily exceed the limitation of terminal and make it hard to debug. So I want to reduce the number of python workers in the two assertions. ### Does this PR introduce _any_ user-facing change? no, test only ### How was this patch tested? manually test, the logs will be reduced to ~200 lines ### Was this patch authored or co-authored using generative AI tooling? 
no Closes #48388 from zhengruifeng/test_to_df_error. Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/tests/test_dataframe.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index cc43804949e84..4fb3e7a9192c2 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -506,14 +506,16 @@ def test_toDF_with_schema_string(self): # number of fields must match. self.assertRaisesRegex( - Exception, "FIELD_STRUCT_LENGTH_MISMATCH", lambda: rdd.toDF("key: int").collect() + Exception, + "FIELD_STRUCT_LENGTH_MISMATCH", + lambda: rdd.coalesce(1).toDF("key: int").collect(), ) # field types mismatch will cause exception at runtime. self.assertRaisesRegex( Exception, "FIELD_DATA_TYPE_UNACCEPTABLE", - lambda: rdd.toDF("key: float, value: string").collect(), + lambda: rdd.coalesce(1).toDF("key: float, value: string").collect(), ) # flat schema values will be wrapped into row. From 7e82e290fb051afd47a542eb6c3509d9bb654ab9 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Thu, 10 Oct 2024 09:52:02 +0900 Subject: [PATCH 199/250] [SPARK-49905][SQL][SS] Use different ShuffleOrigin for the shuffle required from stateful operators ### What changes were proposed in this pull request? This PR proposes to use different ShuffleOrigin for the shuffle required from stateful operators. Spark has been using ENSURE_REQUIREMENTS as ShuffleOrigin which is open for optimization e.g. AQE can adjust the shuffle spec. Quoting the code of ENSURE_REQUIREMENTS: ``` // Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule. It // means that the shuffle operator is used to ensure internal data partitioning requirements and // Spark is free to optimize it as long as the requirements are still ensured. case object ENSURE_REQUIREMENTS extends ShuffleOrigin ``` But the distribution requirement for stateful operators is lot more strict - it has to use the all expressions to calculate the hash (for partitioning) and the number of shuffle partitions must be the same with the spec. This is because stateful operator assumes that there is 1:1 mapping between the partition for the operator and the "physical" partition for checkpointed state. That said, it is fragile if we allow any optimization to be made against shuffle for stateful operator. To prevent this, this PR introduces a new ShuffleOrigin with note that the shuffle is not expected to be "modified". ### Why are the changes needed? This exposes a possibility of broken state based on the contract. We introduced StatefulOpClusteredDistribution in similar reason. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New UT added. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48382 from HeartSaVioR/SPARK-49905. 
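For illustration only (this is a sketch, not part of the patch, and mirrors the unit test added below): after a streaming aggregation has run a batch, the shuffle feeding the stateful operator should now report the new origin instead of `ENSURE_REQUIREMENTS`. The snippet assumes a `StreamTest`-style suite where the running query's `lastExecution` is available as `query`:

```scala
// Sketch only: inspect the shuffle origin of the executed plan of a streaming aggregation.
import org.apache.spark.sql.execution.exchange.{REQUIRED_BY_STATEFUL_OPERATOR, ShuffleExchangeExec}

val shuffles = query.lastExecution.executedPlan.collect {
  case s: ShuffleExchangeExec => s
}
assert(shuffles.nonEmpty, "no shuffle exchange found in the executed plan")
assert(shuffles.head.shuffleOrigin == REQUIRED_BY_STATEFUL_OPERATOR)
```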
Authored-by: Jungtaek Lim Signed-off-by: Jungtaek Lim --- .../exchange/EnsureRequirements.scala | 11 ++++++++- .../exchange/ShuffleExchangeExec.scala | 5 ++++ .../sql/streaming/StreamingQuerySuite.scala | 24 ++++++++++++++++++- 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index e669165f4f2f8..8ec903f8e61da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -70,7 +70,16 @@ case class EnsureRequirements( case (child, distribution) => val numPartitions = distribution.requiredNumPartitions .getOrElse(conf.numShufflePartitions) - ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child, shuffleOrigin) + distribution match { + case _: StatefulOpClusteredDistribution => + ShuffleExchangeExec( + distribution.createPartitioning(numPartitions), child, + REQUIRED_BY_STATEFUL_OPERATOR) + + case _ => + ShuffleExchangeExec( + distribution.createPartitioning(numPartitions), child, shuffleOrigin) + } } // Get the indexes of children which have specified distribution requirements and need to be diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index ae11229cd516e..31a3f53eb7191 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -177,6 +177,11 @@ case object REBALANCE_PARTITIONS_BY_NONE extends ShuffleOrigin // the output needs to be partitioned by the given columns. case object REBALANCE_PARTITIONS_BY_COL extends ShuffleOrigin +// Indicates that the shuffle operator was added by the internal `EnsureRequirements` rule, but +// was required by a stateful operator. The physical partitioning is static and Spark shouldn't +// change it. +case object REQUIRED_BY_STATEFUL_OPERATOR extends ShuffleOrigin + /** * Performs a shuffle that will result in the desired partitioning. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 8471995cb1e50..c12846d7512d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2, ReadLimit} -import org.apache.spark.sql.execution.exchange.ReusedExchangeExec +import org.apache.spark.sql.execution.exchange.{REQUIRED_BY_STATEFUL_OPERATOR, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.{MemorySink, TestForeachWriter} import org.apache.spark.sql.functions._ @@ -1448,6 +1448,28 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("SPARK-49905 shuffle added by stateful operator should use the shuffle origin " + + "`REQUIRED_BY_STATEFUL_OPERATOR`") { + val inputData = MemoryStream[Int] + + // Use the streaming aggregation as an example - all stateful operators are using the same + // distribution, named `StatefulOpClusteredDistribution`. + val df = inputData.toDF().groupBy("value").count() + + testStream(df, OutputMode.Update())( + AddData(inputData, 1, 2, 3, 1, 2, 3), + CheckAnswer((1, 2), (2, 2), (3, 2)), + Execute { qe => + val shuffleOpt = qe.lastExecution.executedPlan.collect { + case s: ShuffleExchangeExec => s + } + + assert(shuffleOpt.nonEmpty, "No shuffle exchange found in the query plan") + assert(shuffleOpt.head.shuffleOrigin === REQUIRED_BY_STATEFUL_OPERATOR) + } + ) + } + private def checkAppendOutputModeException(df: DataFrame): Unit = { withTempDir { outputDir => withTempDir { checkpointDir => From ea60e935fb11cb17acd3c1883ed57c52da0a1f7c Mon Sep 17 00:00:00 2001 From: Prashanth Menon Date: Thu, 10 Oct 2024 09:28:20 +0800 Subject: [PATCH 200/250] [SPARK-49918][CORE] Use read-only access to conf in `SparkContext` where appropriate ### What changes were proposed in this pull request? This PR switches all calls to `SparkContext.getConf` that are read-only to use `SparkContext.conf` instead. The former method clones the conf, which is unnecessary when the caller only reads the conf. `SparkContext.conf` provides read-only access to the conf. ### Why are the changes needed? Cloning the entire conf adds unnecessary CPU overhead due to copying, and GC overhead due to cleanup, and both affect tail latencies on certain workloads. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #48402 from pmenon/getconf-optimizations. 
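As a rough illustration of the pattern this patch applies (a sketch only, not part of the change; the object name is hypothetical, and it assumes the code is compiled inside Spark's own `org.apache.spark` package, since `SparkContext.conf` and typed `SparkConf.get` are `private[spark]`):

```scala
package org.apache.spark

import org.apache.spark.internal.config

object ReadOnlyConfDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("conf-demo"))
    try {
      // getConf clones the entire SparkConf before returning it ...
      val cloned = sc.getConf
      // ... while conf returns the live object for read-only access, avoiding the copy.
      val readOnly = sc.conf
      assert(cloned.get(config.STAGE_MAX_ATTEMPTS) == readOnly.get(config.STAGE_MAX_ATTEMPTS))
    } finally {
      sc.stop()
    }
  }
}
```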
Authored-by: Prashanth Menon Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/Dependency.scala | 2 +- .../spark/input/PortableDataStream.scala | 4 +- .../apache/spark/scheduler/DAGScheduler.scala | 43 +++++++++---------- .../apache/spark/ui/ConsoleProgressBar.scala | 2 +- .../connect/service/SparkConnectService.scala | 2 +- .../org/apache/spark/sql/SparkSession.scala | 2 +- .../execution/datasources/v2/FileScan.scala | 2 +- 7 files changed, 28 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 3b7c7778e26ce..573608c4327e0 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -173,7 +173,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( } private def canShuffleMergeBeEnabled(): Boolean = { - val isPushShuffleEnabled = Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf, + val isPushShuffleEnabled = Utils.isPushBasedShuffleEnabled(rdd.sparkContext.conf, // invoked at driver isDriver = true) if (isPushShuffleEnabled && rdd.isBarrier()) { diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index f0d6cba6ae734..3c3017a9a64c1 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -45,8 +45,8 @@ private[spark] abstract class StreamFileInputFormat[T] * which is set through setMaxSplitSize */ def setMinPartitions(sc: SparkContext, context: JobContext, minPartitions: Int): Unit = { - val defaultMaxSplitBytes = sc.getConf.get(config.FILES_MAX_PARTITION_BYTES) - val openCostInBytes = sc.getConf.get(config.FILES_OPEN_COST_IN_BYTES) + val defaultMaxSplitBytes = sc.conf.get(config.FILES_MAX_PARTITION_BYTES) + val openCostInBytes = sc.conf.get(config.FILES_OPEN_COST_IN_BYTES) val defaultParallelism = Math.max(sc.defaultParallelism, minPartitions) val files = listStatus(context).asScala val totalBytes = files.filterNot(_.isDirectory).map(_.getLen + openCostInBytes).sum diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 2c89fe7885d08..f41888320a29b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -174,7 +174,7 @@ private[spark] class DAGScheduler( // `NUM_CANCELLED_JOB_GROUPS_TO_TRACK` stored. On a new job submission, if its job group is in // this set, the job will be immediately cancelled. private[scheduler] val cancelledJobGroups = - new LimitedSizeFIFOSet[String](sc.getConf.get(config.NUM_CANCELLED_JOB_GROUPS_TO_TRACK)) + new LimitedSizeFIFOSet[String](sc.conf.get(config.NUM_CANCELLED_JOB_GROUPS_TO_TRACK)) /** * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids @@ -224,9 +224,9 @@ private[spark] class DAGScheduler( private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. 
*/ - private val disallowStageRetryForTest = sc.getConf.get(TEST_NO_STAGE_RETRY) + private val disallowStageRetryForTest = sc.conf.get(TEST_NO_STAGE_RETRY) - private val shouldMergeResourceProfiles = sc.getConf.get(config.RESOURCE_PROFILE_MERGE_CONFLICTS) + private val shouldMergeResourceProfiles = sc.conf.get(config.RESOURCE_PROFILE_MERGE_CONFLICTS) /** * Whether to unregister all the outputs on the host in condition that we receive a FetchFailure, @@ -234,19 +234,19 @@ private[spark] class DAGScheduler( * executor(instead of the host) on a FetchFailure. */ private[scheduler] val unRegisterOutputOnHostOnFetchFailure = - sc.getConf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE) + sc.conf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE) /** * Number of consecutive stage attempts allowed before a stage is aborted. */ private[scheduler] val maxConsecutiveStageAttempts = - sc.getConf.get(config.STAGE_MAX_CONSECUTIVE_ATTEMPTS) + sc.conf.get(config.STAGE_MAX_CONSECUTIVE_ATTEMPTS) /** * Max stage attempts allowed before a stage is aborted. */ private[scheduler] val maxStageAttempts: Int = { - Math.max(maxConsecutiveStageAttempts, sc.getConf.get(config.STAGE_MAX_ATTEMPTS)) + Math.max(maxConsecutiveStageAttempts, sc.conf.get(config.STAGE_MAX_ATTEMPTS)) } /** @@ -254,7 +254,7 @@ private[spark] class DAGScheduler( * count spark.stage.maxConsecutiveAttempts */ private[scheduler] val ignoreDecommissionFetchFailure = - sc.getConf.get(config.STAGE_IGNORE_DECOMMISSION_FETCH_FAILURE) + sc.conf.get(config.STAGE_IGNORE_DECOMMISSION_FETCH_FAILURE) /** * Number of max concurrent tasks check failures for each barrier job. @@ -264,14 +264,14 @@ private[spark] class DAGScheduler( /** * Time in seconds to wait between a max concurrent tasks check failure and the next check. */ - private val timeIntervalNumTasksCheck = sc.getConf + private val timeIntervalNumTasksCheck = sc.conf .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_INTERVAL) /** * Max number of max concurrent tasks check failures allowed for a job before fail the job * submission. 
*/ - private val maxFailureNumTasksCheck = sc.getConf + private val maxFailureNumTasksCheck = sc.conf .get(config.BARRIER_MAX_CONCURRENT_TASKS_CHECK_MAX_FAILURES) private val messageScheduler = @@ -286,26 +286,26 @@ private[spark] class DAGScheduler( taskScheduler.setDAGScheduler(this) - private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf, isDriver = true) + private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.conf, isDriver = true) private val blockManagerMasterDriverHeartbeatTimeout = - sc.getConf.get(config.STORAGE_BLOCKMANAGER_MASTER_DRIVER_HEARTBEAT_TIMEOUT).millis + sc.conf.get(config.STORAGE_BLOCKMANAGER_MASTER_DRIVER_HEARTBEAT_TIMEOUT).millis private val shuffleMergeResultsTimeoutSec = - sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT) + sc.conf.get(config.PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT) private val shuffleMergeFinalizeWaitSec = - sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT) + sc.conf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT) private val shuffleMergeWaitMinSizeThreshold = - sc.getConf.get(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT) + sc.conf.get(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT) - private val shufflePushMinRatio = sc.getConf.get(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO) + private val shufflePushMinRatio = sc.conf.get(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO) private val shuffleMergeFinalizeNumThreads = - sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_THREADS) + sc.conf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_THREADS) - private val shuffleFinalizeRpcThreads = sc.getConf.get(config.PUSH_SHUFFLE_FINALIZE_RPC_THREADS) + private val shuffleFinalizeRpcThreads = sc.conf.get(config.PUSH_SHUFFLE_FINALIZE_RPC_THREADS) // Since SparkEnv gets initialized after DAGScheduler, externalShuffleClient needs to be // initialized lazily @@ -328,11 +328,10 @@ private[spark] class DAGScheduler( ThreadUtils.newDaemonFixedThreadPool(shuffleFinalizeRpcThreads, "shuffle-merge-finalize-rpc") /** Whether rdd cache visibility tracking is enabled. */ - private val trackingCacheVisibility: Boolean = - sc.getConf.get(RDD_CACHE_VISIBILITY_TRACKING_ENABLED) + private val trackingCacheVisibility: Boolean = sc.conf.get(RDD_CACHE_VISIBILITY_TRACKING_ENABLED) /** Whether to abort a stage after canceling all of its tasks. */ - private val legacyAbortStageAfterKillTasks = sc.getConf.get(LEGACY_ABORT_STAGE_AFTER_KILL_TASKS) + private val legacyAbortStageAfterKillTasks = sc.conf.get(LEGACY_ABORT_STAGE_AFTER_KILL_TASKS) /** * Called by the TaskSetManager to report task's starting. @@ -557,7 +556,7 @@ private[spark] class DAGScheduler( * TODO SPARK-24942 Improve cluster resource management with jobs containing barrier stage */ private def checkBarrierStageWithDynamicAllocation(rdd: RDD[_]): Unit = { - if (rdd.isBarrier() && Utils.isDynamicAllocationEnabled(sc.getConf)) { + if (rdd.isBarrier() && Utils.isDynamicAllocationEnabled(sc.conf)) { throw SparkCoreErrors.barrierStageWithDynamicAllocationError() } } @@ -2163,7 +2162,7 @@ private[spark] class DAGScheduler( case mapStage: ShuffleMapStage => val numMissingPartitions = mapStage.findMissingPartitions().length if (numMissingPartitions < mapStage.numTasks) { - if (sc.getConf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) { + if (sc.conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) { val reason = "A shuffle map stage with indeterminate output was failed " + "and retried. 
However, Spark can only do this while using the new " + "shuffle block fetching protocol. Please check the config " + diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala index 7a2b7d9caec42..fc7a4675429aa 100644 --- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -35,7 +35,7 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { // Carriage return private val CR = '\r' // Update period of progress bar, in milliseconds - private val updatePeriodMSec = sc.getConf.get(UI_CONSOLE_PROGRESS_UPDATE_INTERVAL) + private val updatePeriodMSec = sc.conf.get(UI_CONSOLE_PROGRESS_UPDATE_INTERVAL) // Delay to show up a progress bar, in milliseconds private val firstDelayMSec = 500L diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 0468a55e23027..e62c19b66c8e5 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -345,7 +345,7 @@ object SparkConnectService extends Logging { val kvStore = sc.statusStore.store.asInstanceOf[ElementTrackingStore] listener = new SparkConnectServerListener(kvStore, sc.conf) sc.listenerBus.addToStatusQueue(listener) - uiTab = if (sc.getConf.get(UI_ENABLED)) { + uiTab = if (sc.conf.get(UI_ENABLED)) { Some( new SparkConnectServerTab( new SparkConnectServerAppStatusStore(kvStore), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 55525380aee55..99ab3ca69fb20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -1100,7 +1100,7 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging { private def applyExtensions( sparkContext: SparkContext, extensions: SparkSessionExtensions): SparkSessionExtensions = { - val extensionConfClassNames = sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) + val extensionConfClassNames = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) .getOrElse(Seq.empty) extensionConfClassNames.foreach { extensionConfClassName => try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index d890107277d6c..5c0f8c0a4afd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -164,7 +164,7 @@ trait FileScan extends Scan if (splitFiles.length == 1) { val path = splitFiles(0).toPath if (!isSplitable(path) && splitFiles(0).length > - sparkSession.sparkContext.getConf.get(IO_WARNING_LARGEFILETHRESHOLD)) { + sparkSession.sparkContext.conf.get(IO_WARNING_LARGEFILETHRESHOLD)) { logWarning(log"Loading one large unsplittable file ${MDC(PATH, path.toString)} with only " + log"one partition, the reason is: ${MDC(REASON, getFileUnSplittableReason(path))}") } From 0912217b38f474e758f5d0161a24d732a4d58db5 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: 
Thu, 10 Oct 2024 11:28:21 +0900 Subject: [PATCH 201/250] [SPARK-48714][PYTHON][FOLLOW-UP] Skip tests if test class is not available ### What changes were proposed in this pull request? This PR is a followup of https://github.com/apache/spark/pull/47086 that makes PySpark tests skipped if tests classes are unavailable. ### Why are the changes needed? `./build/sbt package` should be able to run PySpark tests according to https://spark.apache.org/developer-tools.html and https://spark.apache.org/docs/latest/api/python/development/testing.html ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? Manually tested: ```bash build/sbt -Phive clean package python/run-tests -k --python-executables python3 --testnames 'pyspark.sql.tests.test_dataframe' ``` ``` ... Starting test(python3): pyspark.sql.tests.test_dataframe (temp output: /.../spark/python/target/33eca5b9-23e8-4e95-9eca-9f09ce333336/python3__pyspark.sql.tests.test_dataframe__3saz2ymf.log) Finished test(python3): pyspark.sql.tests.test_dataframe (21s) ... 1 tests were skipped Tests passed in 21 seconds Skipped tests in pyspark.sql.tests.test_dataframe with python3: test_df_merge_into (pyspark.sql.tests.test_dataframe.DataFrameTests.test_df_merge_into) ... skip (0.001s) ``` ```bash build/sbt -Phive clean test:package python/run-tests -k --python-executables python3 --testnames 'pyspark.sql.tests.test_dataframe' ``` ``` Starting test(python3): pyspark.sql.tests.test_dataframe (temp output: /.../spark/python/target/710cf488-f39d-49b4-8b04-70044318ea02/python3__pyspark.sql.tests.test_dataframe___95hp4wt.log) Finished test(python3): pyspark.sql.tests.test_dataframe (23s) Tests passed in 23 seconds ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48404 from HyukjinKwon/SPARK-48714-followup. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/tests/test_dataframe.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index 4fb3e7a9192c2..2f53ca38743c1 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -15,6 +15,8 @@ # limitations under the License. # +import glob +import os import pydoc import shutil import tempfile @@ -47,6 +49,7 @@ pandas_requirement_message, pyarrow_requirement_message, ) +from pyspark.testing.utils import SPARK_HOME class DataFrameTestsMixin: @@ -779,6 +782,16 @@ def test_df_show(self): ) def test_df_merge_into(self): + filename_pattern = ( + "sql/catalyst/target/scala-*/test-classes/org/apache/spark/sql/connector/catalog/" + "InMemoryRowLevelOperationTableCatalog.class" + ) + if not bool(glob.glob(os.path.join(SPARK_HOME, filename_pattern))): + raise unittest.SkipTest( + "org.apache.spark.sql.connector.catalog.InMemoryRowLevelOperationTableCatalog' " + "is not available. Will skip the related tests" + ) + try: # InMemoryRowLevelOperationTableCatalog is a test catalog that is included in the # catalyst-test package. If Spark complains that it can't find this class, make sure From d7772f27b8a851c5007d7d9c891bb862233bf7fa Mon Sep 17 00:00:00 2001 From: LantaoJin Date: Thu, 10 Oct 2024 12:20:13 +0800 Subject: [PATCH 202/250] [SPARK-49782][SQL] ResolveDataFrameDropColumns rule resolves UnresolvedAttribute with child output ### What changes were proposed in this pull request? When the drop list of `DataFrameDropColumns` contains an UnresolvedAttribute. 
Current rule mistakenly resolve the column with its grand-children's output attributes. In dataframe/dataset API application, issue cannot be encountered since the `dropList` are all AttributeReferences. But when we use Spark LogicalPlan, the bug will be encountered, the UnresolvedAttribute in dropList cannot work. ### Why are the changes needed? In `ResolveDataFrameDropColumns` ```scala val dropped = d.dropList.map { case u: UnresolvedAttribute => resolveExpressionByPlanChildren(u, d.child) //mistakenly resolve the column with its grand-children's output attributes case e => e } ``` To fix it, change to `resolveExpressionByPlanChildren(u, d)` or `resolveExpressionByPlanOutput(u, d.child)` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT added. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48240 from LantaoJin/SPARK-49782. Authored-by: LantaoJin Signed-off-by: Wenchen Fan --- .../analysis/ResolveDataFrameDropColumns.scala | 2 +- .../spark/sql/catalyst/analysis/AnalysisSuite.scala | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala index 2642b4a1c5daa..0f9b93cc2986d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala @@ -36,7 +36,7 @@ class ResolveDataFrameDropColumns(val catalogManager: CatalogManager) // df.drop(col("non-existing-column")) val dropped = d.dropList.map { case u: UnresolvedAttribute => - resolveExpressionByPlanChildren(u, d.child) + resolveExpressionByPlanChildren(u, d) case e => e } val remaining = d.child.output.filterNot(attr => dropped.exists(_.semanticEquals(attr))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e23a753dafe8c..8409f454bfb88 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -1832,4 +1832,14 @@ class AnalysisSuite extends AnalysisTest with Matchers { preemptedError.clear() assert(preemptedError.getErrorOpt().isEmpty) } + + test("SPARK-49782: ResolveDataFrameDropColumns rule resolves complex UnresolvedAttribute") { + val function = UnresolvedFunction("trim", Seq(UnresolvedAttribute("i")), isDistinct = false) + val addColumnF = Project(Seq(UnresolvedAttribute("i"), Alias(function, "f")()), testRelation5) + // Drop column "f" via ResolveDataFrameDropColumns rule. + val inputPlan = DataFrameDropColumns(Seq(UnresolvedAttribute("f")), addColumnF) + // The expected Project (root node) should only have column "i". + val expectedPlan = Project(Seq(UnresolvedAttribute("i")), addColumnF).analyze + checkAnalysis(inputPlan, expectedPlan) + } } From c9c33d9f9c65ec731fe26946c3bdac14a53fe965 Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Thu, 10 Oct 2024 08:22:55 +0200 Subject: [PATCH 203/250] fix style. 
--- .../org/apache/spark/sql/catalyst/util/CollationFactory.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 0868fbf6da4b8..6c6594c0b94af 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -583,8 +583,8 @@ protected Collation buildCollation() { comparator = (s1, s2) -> CollationAwareUTF8String.compareLowerCase( applyTrimmingPolicy(s1, spaceTrimming), applyTrimmingPolicy(s2, spaceTrimming)); - hashFunction = s -> (long) CollationAwareUTF8String. - lowerCaseCodePoints(applyTrimmingPolicy(s, spaceTrimming)).hashCode(); + hashFunction = s -> (long) CollationAwareUTF8String.lowerCaseCodePoints( + applyTrimmingPolicy(s, spaceTrimming)).hashCode(); } return new Collation( From f0498f056ddfcf82012ccc27b2b86e4ca9af79ba Mon Sep 17 00:00:00 2001 From: Mark Andreev Date: Thu, 10 Oct 2024 08:59:24 +0200 Subject: [PATCH 204/250] [SPARK-49549][SQL] Assign a name to the error conditions _LEGACY_ERROR_TEMP_3055, 3146 ### What changes were proposed in this pull request? Choose a proper name for the error conditions _LEGACY_ERROR_TEMP_3055 and _LEGACY_ERROR_TEMP_3146 defined in core/src/main/resources/error/error-conditions.json. The name should be short but complete (look at the example in error-conditions.json). Add a test which triggers the error from user code if such test still doesn't exist. Check exception fields by using checkError(). The last function checks valuable error fields only, and avoids dependencies from error text message. In this way, tech editors can modify error format in error-conditions.json, and don't worry of Spark's internal tests. Migrate other tests that might trigger the error onto checkError(). ### Why are the changes needed? This changes needed because spark 4 introduce new approach with user friendly error messages. ### Does this PR introduce _any_ user-facing change? Yes, if user's code depends on error condition names. ### How was this patch tested? Unit tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48288 from mrk-andreev/SPARK-49549. Authored-by: Mark Andreev Signed-off-by: Max Gekk --- .../resources/error/error-conditions.json | 22 ++++++++-------- .../main/resources/error/error-states.json | 12 +++++++++ .../catalog/functions/ScalarFunction.java | 8 +++++- .../expressions/V2ExpressionUtils.scala | 5 ++-- .../connector/DataSourceV2FunctionSuite.scala | 25 ++++++++++++++++--- 5 files changed, 55 insertions(+), 17 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index f6317d731c77b..99403e12e62c6 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3996,6 +3996,18 @@ ], "sqlState" : "22023" }, + "SCALAR_FUNCTION_NOT_COMPATIBLE" : { + "message" : [ + "ScalarFunction not overrides method 'produceResult(InternalRow)' with custom implementation." + ], + "sqlState" : "42K0O" + }, + "SCALAR_FUNCTION_NOT_FULLY_IMPLEMENTED" : { + "message" : [ + "ScalarFunction not implements or overrides method 'produceResult(InternalRow)'." 
+ ], + "sqlState" : "42K0P" + }, "SCALAR_SUBQUERY_IS_IN_GROUP_BY_OR_AGGREGATE_FUNCTION" : { "message" : [ "The correlated scalar subquery '' is neither present in GROUP BY, nor in an aggregate function.", @@ -7946,11 +7958,6 @@ " is not currently supported" ] }, - "_LEGACY_ERROR_TEMP_3055" : { - "message" : [ - "ScalarFunction neither implement magic method nor override 'produceResult'" - ] - }, "_LEGACY_ERROR_TEMP_3056" : { "message" : [ "Unexpected row-level read relations (allow multiple = ): " @@ -8309,11 +8316,6 @@ "Partitions truncate is not supported" ] }, - "_LEGACY_ERROR_TEMP_3146" : { - "message" : [ - "Cannot find a compatible ScalarFunction#produceResult" - ] - }, "_LEGACY_ERROR_TEMP_3147" : { "message" : [ ": Batch scan are not supported" diff --git a/common/utils/src/main/resources/error/error-states.json b/common/utils/src/main/resources/error/error-states.json index 87811fef9836e..9be97556c1076 100644 --- a/common/utils/src/main/resources/error/error-states.json +++ b/common/utils/src/main/resources/error/error-states.json @@ -4631,6 +4631,18 @@ "standard": "N", "usedBy": ["Spark"] }, + "42K0O": { + "description": "ScalarFunction not overrides method 'produceResult(InternalRow)' with custom implementation.", + "origin": "Spark", + "standard": "N", + "usedBy": ["Spark"] + }, + "42K0P": { + "description": "ScalarFunction not implements or overrides method 'produceResult(InternalRow)'.", + "origin": "Spark", + "standard": "N", + "usedBy": ["Spark"] + }, "42KD0": { "description": "Ambiguous name reference.", "origin": "Databricks", diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java index ca4ea5114c26b..c0078872bd843 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java @@ -20,8 +20,11 @@ import org.apache.spark.SparkUnsupportedOperationException; import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.QuotingUtils; import org.apache.spark.sql.types.DataType; +import java.util.Map; + /** * Interface for a function that produces a result value for each input row. *

    @@ -149,7 +152,10 @@ public interface ScalarFunction extends BoundFunction { * @return a result value */ default R produceResult(InternalRow input) { - throw new SparkUnsupportedOperationException("_LEGACY_ERROR_TEMP_3146"); + throw new SparkUnsupportedOperationException( + "SCALAR_FUNCTION_NOT_COMPATIBLE", + Map.of("scalarFunc", QuotingUtils.quoteIdentifier(name())) + ); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala index 220920a5a3198..d14c8cb675387 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.connector.catalog.{FunctionCatalog, Identifier} import org.apache.spark.sql.connector.catalog.functions._ import org.apache.spark.sql.connector.catalog.functions.ScalarFunction.MAGIC_METHOD_NAME import org.apache.spark.sql.connector.expressions.{BucketTransform, Expression => V2Expression, FieldReference, IdentityTransform, Literal => V2Literal, NamedReference, NamedTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortOrder => V2SortOrder, SortValue, Transform} +import org.apache.spark.sql.errors.DataTypeErrors.toSQLId import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ @@ -182,8 +183,8 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { ApplyFunctionExpression(scalarFunc, arguments) case _ => throw new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3055", - messageParameters = Map("scalarFunc" -> scalarFunc.name())) + errorClass = "SCALAR_FUNCTION_NOT_FULLY_IMPLEMENTED", + messageParameters = Map("scalarFunc" -> toSQLId(scalarFunc.name()))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index d6599debd3b11..6b0fd6084099c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -414,8 +414,8 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { new JavaStrLen(new JavaStrLenNoImpl)) checkError( exception = intercept[AnalysisException](sql("SELECT testcat.ns.strlen('abc')").collect()), - condition = "_LEGACY_ERROR_TEMP_3055", - parameters = Map("scalarFunc" -> "strlen"), + condition = "SCALAR_FUNCTION_NOT_FULLY_IMPLEMENTED", + parameters = Map("scalarFunc" -> "`strlen`"), context = ExpectedContext( fragment = "testcat.ns.strlen('abc')", start = 7, @@ -448,8 +448,8 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(new JavaLongAddMismatchMagic)) checkError( exception = intercept[AnalysisException](sql("SELECT testcat.ns.add(1L, 2L)").collect()), - condition = "_LEGACY_ERROR_TEMP_3055", - parameters = Map("scalarFunc" -> "long_add_mismatch_magic"), + condition = "SCALAR_FUNCTION_NOT_FULLY_IMPLEMENTED", + parameters = Map("scalarFunc" -> "`long_add_mismatch_magic`"), context = ExpectedContext( fragment = "testcat.ns.add(1L, 2L)", start = 7, @@ -458,6 +458,23 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { ) } + 
test("SPARK-49549: scalar function w/ mismatch a compatible ScalarFunction#produceResult") { + case object CharLength extends ScalarFunction[Int] { + override def inputTypes(): Array[DataType] = Array(StringType) + override def resultType(): DataType = IntegerType + override def name(): String = "CHAR_LENGTH" + } + + catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) + addFunction(Identifier.of(Array("ns"), "my_strlen"), StrLen(CharLength)) + checkError( + exception = intercept[SparkUnsupportedOperationException] + (sql("SELECT testcat.ns.my_strlen('abc')").collect()), + condition = "SCALAR_FUNCTION_NOT_COMPATIBLE", + parameters = Map("scalarFunc" -> "`CHAR_LENGTH`") + ) + } + test("SPARK-35390: scalar function w/ type coercion") { catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps) addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(new JavaLongAddDefault(false))) From f003638f10157191b2501a547af62ab4e7d8aa43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Du=C5=A1an=20Ti=C5=A1ma?= Date: Thu, 10 Oct 2024 09:11:46 +0200 Subject: [PATCH 205/250] [SPARK-49542][SQL] Partition transform exception evaluate error MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? Add user facing error for improper use of partition transform expressions. ### Why are the changes needed? Replace internal error with user facing one. ### Does this PR introduce _any_ user-facing change? Yes, new error condition. ### How was this patch tested? Added tests to QueryExecutionErrorsSuite ### Was this patch authored or co-authored using generative AI tooling? No Closes #48387 from dusantism-db/partition-transform-exception-evaluate-error. Authored-by: Dušan Tišma Signed-off-by: Max Gekk --- .../resources/error/error-conditions.json | 6 ++++ .../main/resources/error/error-states.json | 6 ++++ .../tests/connect/test_parity_readwriter.py | 1 + python/pyspark/sql/tests/test_readwriter.py | 30 +++++++++++++++++++ .../sql/catalyst/expressions/Expression.scala | 4 +-- .../expressions/PartitionTransforms.scala | 19 +++++++++++- .../errors/QueryExecutionErrorsSuite.scala | 14 ++++++++- 7 files changed, 76 insertions(+), 4 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 99403e12e62c6..14b228daf3c1a 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -3824,6 +3824,12 @@ ], "sqlState" : "42000" }, + "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY" : { + "message" : [ + "The expression must be inside 'partitionedBy'." + ], + "sqlState" : "42S23" + }, "PATH_ALREADY_EXISTS" : { "message" : [ "Path already exists. Set mode as \"overwrite\" to overwrite the existing path." 
diff --git a/common/utils/src/main/resources/error/error-states.json b/common/utils/src/main/resources/error/error-states.json index 9be97556c1076..fb899e4eb207e 100644 --- a/common/utils/src/main/resources/error/error-states.json +++ b/common/utils/src/main/resources/error/error-states.json @@ -4913,6 +4913,12 @@ "standard": "N", "usedBy": ["SQL Server"] }, + "42S23": { + "description": "Partition transform expression not in 'partitionedBy'", + "origin": "Spark", + "standard": "N", + "usedBy": ["Spark"] + }, "44000": { "description": "with check option violation", "origin": "SQL/Foundation", diff --git a/python/pyspark/sql/tests/connect/test_parity_readwriter.py b/python/pyspark/sql/tests/connect/test_parity_readwriter.py index 46333b555c351..f83f3edbfa787 100644 --- a/python/pyspark/sql/tests/connect/test_parity_readwriter.py +++ b/python/pyspark/sql/tests/connect/test_parity_readwriter.py @@ -33,6 +33,7 @@ def test_api(self): def test_partitioning_functions(self): self.check_partitioning_functions(DataFrameWriterV2) + self.partitioning_functions_user_error() if __name__ == "__main__": diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py index f4f32dea9060a..2fca6b57decf9 100644 --- a/python/pyspark/sql/tests/test_readwriter.py +++ b/python/pyspark/sql/tests/test_readwriter.py @@ -255,6 +255,7 @@ def check_api(self, tpe): def test_partitioning_functions(self): self.check_partitioning_functions(DataFrameWriterV2) + self.partitioning_functions_user_error() def check_partitioning_functions(self, tpe): import datetime @@ -274,6 +275,35 @@ def check_partitioning_functions(self, tpe): self.assertIsInstance(writer.partitionedBy(bucket(11, col("id"))), tpe) self.assertIsInstance(writer.partitionedBy(bucket(3, "id"), hours(col("ts"))), tpe) + def partitioning_functions_user_error(self): + import datetime + from pyspark.sql.functions.partitioning import years, months, days, hours, bucket + + df = self.spark.createDataFrame( + [(1, datetime.datetime(2000, 1, 1), "foo")], ("id", "ts", "value") + ) + + with self.assertRaisesRegex( + Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY" + ): + df.select(years("ts")).collect() + with self.assertRaisesRegex( + Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY" + ): + df.select(months("ts")).collect() + with self.assertRaisesRegex( + Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY" + ): + df.select(days("ts")).collect() + with self.assertRaisesRegex( + Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY" + ): + df.select(hours("ts")).collect() + with self.assertRaisesRegex( + Exception, "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY" + ): + df.select(bucket(2, "ts")).collect() + def test_create(self): df = self.df with self.table("test_table"): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index de15ec43c4f31..6a57ba2aaa569 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -383,10 +383,10 @@ abstract class Expression extends TreeNode[Expression] { trait FoldableUnevaluable extends Expression { override def foldable: Boolean = true - final override def eval(input: InternalRow = null): Any = + override def eval(input: InternalRow = null): Any = throw 
QueryExecutionErrors.cannotEvaluateExpressionError(this) - final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala index 433f8500fab1f..04d31b5797819 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/PartitionTransforms.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.trees.UnaryLike +import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLExpr import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{DataType, IntegerType} @@ -37,8 +41,21 @@ import org.apache.spark.sql.types.{DataType, IntegerType} abstract class PartitionTransformExpression extends Expression with Unevaluable with UnaryLike[Expression] { override def nullable: Boolean = true -} + override def eval(input: InternalRow): Any = + throw new SparkException( + errorClass = "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY", + messageParameters = Map("expression" -> toSQLExpr(this)), + cause = null + ) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw new SparkException( + errorClass = "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY", + messageParameters = Map("expression" -> toSQLExpr(this)), + cause = null + ) +} /** * Expression for the v2 partition transform years. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index 9d1448d0ac09d..86c1f17b4dbb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -35,11 +35,12 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, Encoder, Kry import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{NamedParameter, UnresolvedGenerator} import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Concat, CreateArray, EmptyRow, Expression, Flatten, Grouping, Literal, RowNumber, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Concat, CreateArray, EmptyRow, Expression, Flatten, Grouping, Literal, RowNumber, UnaryExpression, Years} import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.expressions.objects.InitializeJavaBean import org.apache.spark.sql.catalyst.rules.RuleIdCollection +import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLExpr import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions} import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider import org.apache.spark.sql.execution.datasources.orc.OrcTest @@ -1006,6 +1007,17 @@ class QueryExecutionErrorsSuite sqlState = "XX000") } + test("PartitionTransformExpression error on eval") { + val expr = Years(Literal("foo")) + val e = intercept[SparkException] { + expr.eval() + } + checkError( + exception = e, + condition = "PARTITION_TRANSFORM_EXPRESSION_NOT_IN_PARTITIONED_BY", + parameters = Map("expression" -> toSQLExpr(expr))) + } + test("INTERNAL_ERROR: Calling doGenCode on unresolved") { val e = intercept[SparkException] { val ctx = new CodegenContext From 38d66fde2eec7e0d7bc9dda4d87cdee03b1ee9e8 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Thu, 10 Oct 2024 09:15:50 +0200 Subject: [PATCH 206/250] [SPARK-49908][SQL] Assign proper error condition for _LEGACY_ERROR_TEMP_0044 ### What changes were proposed in this pull request? This PR proposes to assign proper error condition & sqlstate for _LEGACY_ERROR_TEMP_0044 ### Why are the changes needed? To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48384 from itholic/legacy_0044. 
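A hedged sketch of the user-facing effect (illustrative only, not output captured from this patch; it assumes an active SparkSession `spark` and ScalaTest's intercept, and mirrors the updated SQLConfSuite test):

```
import org.apache.spark.sql.catalyst.parser.ParseException

// Previously reported as _LEGACY_ERROR_TEMP_0044 with no message parameters.
val e = intercept[ParseException] {
  spark.sql("SET TIME ZONE INTERVAL 19 HOURS")
}
assert(e.getErrorClass == "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE")
assert(e.getMessageParameters.get("input") == "19")
```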
Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../resources/error/error-conditions.json | 10 +++---- .../spark/sql/errors/QueryParsingErrors.scala | 7 +++-- .../spark/sql/execution/SparkSqlParser.scala | 27 ++++++++++++++++--- .../analyzer-results/timezone.sql.out | 24 ++++++++++++++--- .../sql-tests/results/timezone.sql.out | 24 ++++++++++++++--- .../spark/sql/internal/SQLConfSuite.scala | 4 +-- 6 files changed, 75 insertions(+), 21 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 14b228daf3c1a..4ceef4b2d8b92 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2537,6 +2537,11 @@ "Interval string does not match second-nano format of ss.nnnnnnnnn." ] }, + "TIMEZONE_INTERVAL_OUT_OF_RANGE" : { + "message" : [ + "The interval value must be in the range of [-18, +18] hours with second precision." + ] + }, "UNKNOWN_PARSING_ERROR" : { "message" : [ "Unknown error when parsing ." @@ -5703,11 +5708,6 @@ "Expected format is 'RESET' or 'RESET key'. If you want to include special characters in key, please use quotes, e.g., RESET `key`." ] }, - "_LEGACY_ERROR_TEMP_0044" : { - "message" : [ - "The interval value must be in the range of [-18, +18] hours with second precision." - ] - }, "_LEGACY_ERROR_TEMP_0045" : { "message" : [ "Invalid time zone displacement value." diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala index b0743d6de4772..53cbf086c96e3 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala @@ -516,8 +516,11 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase { new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0043", ctx) } - def intervalValueOutOfRangeError(ctx: IntervalContext): Throwable = { - new ParseException(errorClass = "_LEGACY_ERROR_TEMP_0044", ctx) + def intervalValueOutOfRangeError(input: String, ctx: IntervalContext): Throwable = { + new ParseException( + errorClass = "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + messageParameters = Map("input" -> input), + ctx) } def invalidTimeZoneDisplacementValueError(ctx: SetTimeZoneContext): Throwable = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 8fc860c503c96..9fbe400a555fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -189,10 +189,29 @@ class SparkSqlAstBuilder extends AstBuilder { val key = SQLConf.SESSION_LOCAL_TIMEZONE.key if (ctx.interval != null) { val interval = parseIntervalLiteral(ctx.interval) - if (interval.months != 0 || interval.days != 0 || - math.abs(interval.microseconds) > 18 * DateTimeConstants.MICROS_PER_HOUR || - interval.microseconds % DateTimeConstants.MICROS_PER_SECOND != 0) { - throw QueryParsingErrors.intervalValueOutOfRangeError(ctx.interval()) + if (interval.months != 0) { + throw QueryParsingErrors.intervalValueOutOfRangeError( + toSQLValue(interval.months), + ctx.interval() + ) + } + else if (interval.days != 0) { + throw QueryParsingErrors.intervalValueOutOfRangeError( + 
toSQLValue(interval.days), + ctx.interval() + ) + } + else if (math.abs(interval.microseconds) > 18 * DateTimeConstants.MICROS_PER_HOUR) { + throw QueryParsingErrors.intervalValueOutOfRangeError( + toSQLValue((math.abs(interval.microseconds) / DateTimeConstants.MICROS_PER_HOUR).toInt), + ctx.interval() + ) + } + else if (interval.microseconds % DateTimeConstants.MICROS_PER_SECOND != 0) { + throw QueryParsingErrors.intervalValueOutOfRangeError( + toSQLValue((interval.microseconds / DateTimeConstants.MICROS_PER_SECOND).toInt), + ctx.interval() + ) } else { val seconds = (interval.microseconds / DateTimeConstants.MICROS_PER_SECOND).toInt SetCommand(Some(key -> Some(ZoneOffset.ofTotalSeconds(seconds).toString))) diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/timezone.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/timezone.sql.out index 9059f37f3607b..5b55a0c218934 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/timezone.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/timezone.sql.out @@ -64,7 +64,11 @@ SET TIME ZONE INTERVAL 3 DAYS -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "3" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", @@ -80,7 +84,11 @@ SET TIME ZONE INTERVAL 24 HOURS -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "24" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", @@ -96,7 +104,11 @@ SET TIME ZONE INTERVAL '19:40:32' HOUR TO SECOND -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "19" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", @@ -128,7 +140,11 @@ SET TIME ZONE INTERVAL 10 HOURS 1 MILLISECOND -- !query analysis org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "36000" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", diff --git a/sql/core/src/test/resources/sql-tests/results/timezone.sql.out b/sql/core/src/test/resources/sql-tests/results/timezone.sql.out index d34599a49c5ff..5f0fdef50e3db 100644 --- a/sql/core/src/test/resources/sql-tests/results/timezone.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/timezone.sql.out @@ -80,7 +80,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "3" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", @@ -98,7 +102,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : 
"24" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", @@ -116,7 +124,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "19" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", @@ -152,7 +164,11 @@ struct<> -- !query output org.apache.spark.sql.catalyst.parser.ParseException { - "errorClass" : "_LEGACY_ERROR_TEMP_0044", + "errorClass" : "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + "sqlState" : "22006", + "messageParameters" : { + "input" : "36000" + }, "queryContext" : [ { "objectType" : "", "objectName" : "", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 1a6cdd1258cc3..2b58440baf852 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -492,8 +492,8 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { val sqlText = "set time zone interval 19 hours" checkError( exception = intercept[ParseException](sql(sqlText)), - condition = "_LEGACY_ERROR_TEMP_0044", - parameters = Map.empty, + condition = "INVALID_INTERVAL_FORMAT.TIMEZONE_INTERVAL_OUT_OF_RANGE", + parameters = Map("input" -> "19"), context = ExpectedContext(sqlText, 0, 30)) } From e693af0a0adafe845c90579c02810b4c7d0c2b7d Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Thu, 10 Oct 2024 10:55:35 +0200 Subject: [PATCH 207/250] [SPARK-49748][CORE] Add `getCondition` and deprecate `getErrorClass` in `SparkThrowable` ### What changes were proposed in this pull request? 1. Deprecate the `getErrorClass` method in `SparkThrowable` 2. Add new method `getCondition` as the replacement of `getErrorClass` to the `SparkThrowable` interface 3. Use `getCondition` instead of `getErrorClass` in implementations of `SparkThrowable` to avoid warnings. ### Why are the changes needed? To follow new naming convention proposed by SPARK-46810. ### Does this PR introduce _any_ user-facing change? Yes, it extends existing APIs. ### How was this patch tested? By running the existing tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48196 from MaxGekk/deprecate-getErrorClass. 
Authored-by: Max Gekk Signed-off-by: Max Gekk --- .../unsafe/types/CollationFactorySuite.scala | 2 +- .../java/org/apache/spark/SparkThrowable.java | 20 ++++++++--- .../org/apache/spark/SparkException.scala | 34 +++++++++---------- .../apache/spark/SparkThrowableHelper.scala | 4 +-- .../streaming/StreamingQueryException.scala | 2 +- .../spark/sql/avro/AvroLogicalTypeSuite.scala | 2 +- .../org/apache/spark/sql/avro/AvroSuite.scala | 14 ++++---- .../org/apache/spark/sql/CatalogSuite.scala | 2 +- .../spark/sql/ClientDataFrameStatSuite.scala | 6 ++-- .../apache/spark/sql/ClientE2ETestSuite.scala | 20 +++++------ .../client/SparkConnectClientSuite.scala | 2 +- .../client/arrow/ArrowEncoderSuite.scala | 2 +- .../streaming/ClientStreamingQuerySuite.scala | 4 +-- .../spark/sql/kafka010/KafkaExceptions.scala | 2 +- .../spark/memory/SparkOutOfMemoryError.java | 2 +- .../SparkFileAlreadyExistsException.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 4 +-- .../apache/spark/JobCancellationSuite.scala | 2 +- .../org/apache/spark/SparkFunSuite.scala | 2 +- .../apache/spark/SparkThrowableSuite.scala | 12 +++---- .../spark/broadcast/BroadcastSuite.scala | 8 ++--- .../apache/spark/ml/feature/Binarizer.scala | 2 +- .../spark/ml/feature/StringIndexer.scala | 2 +- project/MimaExcludes.scala | 3 ++ .../apache/spark/sql/AnalysisException.scala | 2 +- .../spark/sql/catalyst/parser/parsers.scala | 6 ++-- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../catalyst/encoders/ExpressionEncoder.scala | 4 +-- .../sql/catalyst/parser/AstBuilder.scala | 4 +-- .../sql/catalyst/util/GeneratedColumn.scala | 4 +-- .../exceptions/SqlScriptingException.scala | 2 +- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../catalog/CatalogLoadingSuite.java | 2 +- .../sql/catalyst/analysis/AnalysisTest.scala | 2 +- .../analysis/UnsupportedOperationsSuite.scala | 2 +- .../sql/catalyst/csv/CSVExprUtilsSuite.scala | 2 +- .../encoders/EncoderResolutionSuite.scala | 2 +- .../catalyst/encoders/RowEncoderSuite.scala | 2 +- .../BufferHolderSparkSubmitSuite.scala | 4 +-- .../sql/catalyst/parser/DDLParserSuite.scala | 2 +- .../parser/SqlScriptingParserSuite.scala | 8 ++--- .../connect/planner/SparkConnectPlanner.scala | 2 +- .../spark/sql/connect/utils/ErrorUtils.scala | 6 ++-- .../SparkConnectSessionManagerSuite.scala | 10 +++--- .../spark/sql/DataSourceRegistration.scala | 2 +- .../sql/execution/datasources/rules.scala | 8 ++--- .../datasources/v2/FileDataSourceV2.scala | 2 +- .../execution/streaming/StreamExecution.scala | 2 +- .../state/HDFSBackedStateStoreProvider.scala | 2 +- .../state/RocksDBStateStoreProvider.scala | 4 +-- .../spark/sql/internal/CatalogImpl.scala | 8 ++--- .../spark/sql/JavaColumnExpressionSuite.java | 2 +- .../sql/CollationExpressionWalkerSuite.scala | 2 +- .../org/apache/spark/sql/DatasetSuite.scala | 8 ++--- .../spark/sql/LateralColumnAliasSuite.scala | 8 ++--- .../spark/sql/RuntimeNullChecksV2Writes.scala | 4 +-- .../apache/spark/sql/SQLQueryTestHelper.scala | 4 +-- .../org/apache/spark/sql/SubquerySuite.scala | 8 ++--- .../scala/org/apache/spark/sql/UDFSuite.scala | 10 +++--- .../sql/connector/DataSourceV2SQLSuite.scala | 2 +- .../connector/MergeIntoTableSuiteBase.scala | 6 ++-- .../errors/QueryExecutionErrorsSuite.scala | 10 +++--- .../adaptive/AdaptiveQueryExecSuite.scala | 2 +- .../binaryfile/BinaryFileFormatSuite.scala | 2 +- .../execution/datasources/csv/CSVSuite.scala | 2 +- .../datasources/orc/OrcFilterSuite.scala | 2 +- .../datasources/orc/OrcQuerySuite.scala | 6 ++-- 
.../datasources/orc/OrcSourceSuite.scala | 4 +-- .../parquet/ParquetFilterSuite.scala | 2 +- .../datasources/parquet/ParquetIOSuite.scala | 6 ++-- .../ParquetPartitionDiscoverySuite.scala | 2 +- .../parquet/ParquetQuerySuite.scala | 2 +- .../parquet/ParquetRebaseDatetimeSuite.scala | 4 +-- .../parquet/ParquetRowIndexSuite.scala | 2 +- .../StateDataSourceChangeDataReadSuite.scala | 8 ++--- .../v2/state/StateDataSourceReadSuite.scala | 2 +- .../python/PythonDataSourceSuite.scala | 8 ++--- .../PythonStreamingDataSourceSuite.scala | 2 +- ...StateSchemaCompatibilityCheckerSuite.scala | 8 ++--- .../streaming/state/StateStoreSuite.scala | 2 +- .../spark/sql/sources/InsertSuite.scala | 2 +- .../StreamingQueryListenerSuite.scala | 2 +- .../thriftserver/HiveThriftServerErrors.scala | 2 +- 83 files changed, 201 insertions(+), 188 deletions(-) diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index ff40f16e5a052..66ff551193101 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -456,7 +456,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig val e = intercept[SparkException] { fetchCollation(collationName) } - assert(e.getErrorClass === "COLLATION_INVALID_NAME") + assert(e.getCondition === "COLLATION_INVALID_NAME") assert(e.getMessageParameters.asScala === Map( "collationName" -> collationName, "proposals" -> proposals)) } diff --git a/common/utils/src/main/java/org/apache/spark/SparkThrowable.java b/common/utils/src/main/java/org/apache/spark/SparkThrowable.java index e1235b2982ba0..39808f58b08ae 100644 --- a/common/utils/src/main/java/org/apache/spark/SparkThrowable.java +++ b/common/utils/src/main/java/org/apache/spark/SparkThrowable.java @@ -35,19 +35,29 @@ */ @Evolving public interface SparkThrowable { - // Succinct, human-readable, unique, and consistent representation of the error category - // If null, error class is not set - String getErrorClass(); + /** + * Succinct, human-readable, unique, and consistent representation of the error condition. + * If null, error condition is not set. + */ + String getCondition(); + + /** + * Succinct, human-readable, unique, and consistent representation of the error category. + * If null, error class is not set. + * @deprecated Use {@link #getCondition()} instead. + */ + @Deprecated + default String getErrorClass() { return getCondition(); } // Portable error identifier across SQL engines // If null, error class or SQLSTATE is not set default String getSqlState() { - return SparkThrowableHelper.getSqlState(this.getErrorClass()); + return SparkThrowableHelper.getSqlState(this.getCondition()); } // True if this error is an internal error. 
default boolean isInternalError() { - return SparkThrowableHelper.isInternalError(this.getErrorClass()); + return SparkThrowableHelper.isInternalError(this.getCondition()); } default Map getMessageParameters() { diff --git a/common/utils/src/main/scala/org/apache/spark/SparkException.scala b/common/utils/src/main/scala/org/apache/spark/SparkException.scala index fcaee787fd8d3..0c0a1902ee2a1 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkException.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkException.scala @@ -69,7 +69,7 @@ class SparkException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -179,7 +179,7 @@ private[spark] class SparkUpgradeException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull } /** @@ -212,7 +212,7 @@ private[spark] class SparkArithmeticException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -250,7 +250,7 @@ private[spark] class SparkUnsupportedOperationException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull } private[spark] object SparkUnsupportedOperationException { @@ -280,7 +280,7 @@ private[spark] class SparkClassNotFoundException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -296,7 +296,7 @@ private[spark] class SparkConcurrentModificationException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -346,7 +346,7 @@ private[spark] class SparkDateTimeException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -362,7 +362,7 @@ private[spark] class SparkFileNotFoundException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -396,7 +396,7 @@ private[spark] class SparkNumberFormatException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -448,7 +448,7 @@ private[spark] class SparkIllegalArgumentException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + 
override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -477,7 +477,7 @@ private[spark] class SparkRuntimeException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -506,7 +506,7 @@ private[spark] class SparkPythonException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -524,7 +524,7 @@ private[spark] class SparkNoSuchElementException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass override def getQueryContext: Array[QueryContext] = context } @@ -541,7 +541,7 @@ private[spark] class SparkSecurityException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -575,7 +575,7 @@ private[spark] class SparkArrayIndexOutOfBoundsException private( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context } @@ -591,7 +591,7 @@ private[spark] class SparkSQLException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } /** @@ -606,5 +606,5 @@ private[spark] class SparkSQLFeatureNotSupportedException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } diff --git a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala index 428c9d2a49351..b6c2b176de62b 100644 --- a/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala +++ b/common/utils/src/main/scala/org/apache/spark/SparkThrowableHelper.scala @@ -81,7 +81,7 @@ private[spark] object SparkThrowableHelper { import ErrorMessageFormat._ format match { case PRETTY => e.getMessage - case MINIMAL | STANDARD if e.getErrorClass == null => + case MINIMAL | STANDARD if e.getCondition == null => toJsonString { generator => val g = generator.useDefaultPrettyPrinter() g.writeStartObject() @@ -92,7 +92,7 @@ private[spark] object SparkThrowableHelper { g.writeEndObject() } case MINIMAL | STANDARD => - val errorClass = e.getErrorClass + val errorClass = e.getCondition toJsonString { generator => val g = generator.useDefaultPrettyPrinter() g.writeStartObject() diff --git a/common/utils/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala b/common/utils/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala index 259f4330224c9..1972ef05d8759 100644 --- 
a/common/utils/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala +++ b/common/utils/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryException.scala @@ -84,7 +84,7 @@ class StreamingQueryException private[sql]( s"""${classOf[StreamingQueryException].getName}: ${cause.getMessage} |$queryDebugString""".stripMargin - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava } diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala index 751ac275e048a..bb0858decdf8f 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroLogicalTypeSuite.scala @@ -436,7 +436,7 @@ abstract class AvroLogicalTypeSuite extends QueryTest with SharedSparkSession { val ex = intercept[SparkException] { spark.read.format("avro").load(s"$dir.avro").collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[SparkArithmeticException], condition = "NUMERIC_VALUE_OUT_OF_RANGE.WITH_SUGGESTION", diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index be887bd5237b0..e9d6c2458df81 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -891,7 +891,7 @@ abstract class AvroSuite val ex = intercept[SparkException] { spark.read.schema("a DECIMAL(4, 3)").format("avro").load(path.toString).collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], condition = "AVRO_INCOMPATIBLE_READ_TYPE", @@ -969,7 +969,7 @@ abstract class AvroSuite val ex = intercept[SparkException] { spark.read.schema(s"a $sqlType").format("avro").load(path.toString).collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], condition = "AVRO_INCOMPATIBLE_READ_TYPE", @@ -1006,7 +1006,7 @@ abstract class AvroSuite val ex = intercept[SparkException] { spark.read.schema(s"a $sqlType").format("avro").load(path.toString).collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[AnalysisException], condition = "AVRO_INCOMPATIBLE_READ_TYPE", @@ -1515,7 +1515,7 @@ abstract class AvroSuite .write.format("avro").option("avroSchema", avroSchema) .save(s"$tempDir/${UUID.randomUUID()}") } - assert(ex.getErrorClass == "TASK_WRITE_FAILED") + assert(ex.getCondition == "TASK_WRITE_FAILED") assert(ex.getCause.isInstanceOf[java.lang.NullPointerException]) assert(ex.getCause.getMessage.contains( "null value for (non-nullable) string at test_schema.Name")) @@ -2629,7 +2629,7 @@ abstract class AvroSuite val e = intercept[SparkException] { df.write.format("avro").option("avroSchema", avroSchema).save(path3_x) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + 
assert(e.getCondition == "TASK_WRITE_FAILED") assert(e.getCause.isInstanceOf[SparkUpgradeException]) } checkDefaultLegacyRead(oldPath) @@ -2884,7 +2884,7 @@ abstract class AvroSuite val e = intercept[SparkException] { df.write.format("avro").option("avroSchema", avroSchema).save(dir.getCanonicalPath) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + assert(e.getCondition == "TASK_WRITE_FAILED") val errMsg = e.getCause.asInstanceOf[SparkUpgradeException].getMessage assert(errMsg.contains("You may get a different result due to the upgrading")) } @@ -2895,7 +2895,7 @@ abstract class AvroSuite val e = intercept[SparkException] { df.write.format("avro").save(dir.getCanonicalPath) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + assert(e.getCondition == "TASK_WRITE_FAILED") val errMsg = e.getCause.asInstanceOf[SparkUpgradeException].getMessage assert(errMsg.contains("You may get a different result due to the upgrading")) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala index 0e3a683d2701d..ce552bdd4f0f0 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/CatalogSuite.scala @@ -69,7 +69,7 @@ class CatalogSuite extends ConnectFunSuite with RemoteSparkSession with SQLHelpe val exception = intercept[SparkException] { spark.catalog.setCurrentCatalog("notExists") } - assert(exception.getErrorClass == "CATALOG_NOT_FOUND") + assert(exception.getCondition == "CATALOG_NOT_FOUND") spark.catalog.setCurrentCatalog("testcat") assert(spark.catalog.currentCatalog().equals("testcat")) val catalogsAfterChange = spark.catalog.listCatalogs().collect() diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala index 88281352f2479..84ed624a95214 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala @@ -251,16 +251,16 @@ class ClientDataFrameStatSuite extends ConnectFunSuite with RemoteSparkSession { val error1 = intercept[AnalysisException] { df.stat.bloomFilter("id", -1000, 100) } - assert(error1.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") + assert(error1.getCondition === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") val error2 = intercept[AnalysisException] { df.stat.bloomFilter("id", 1000, -100) } - assert(error2.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") + assert(error2.getCondition === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") val error3 = intercept[AnalysisException] { df.stat.bloomFilter("id", 1000, -1.0) } - assert(error3.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") + assert(error3.getCondition === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE") } } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index b47231948dc98..0371981b728d1 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -95,7 +95,7 @@ class ClientE2ETestSuite .collect() } 
assert( - ex.getErrorClass === + ex.getCondition === "INCONSISTENT_BEHAVIOR_CROSS_VERSION.PARSE_DATETIME_BY_NEW_PARSER") assert( ex.getMessageParameters.asScala == Map( @@ -122,12 +122,12 @@ class ClientE2ETestSuite Seq("1").toDS().withColumn("udf_val", throwException($"value")).collect() } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) assert(!ex.getMessageParameters.isEmpty) assert(ex.getCause.isInstanceOf[SparkException]) val cause = ex.getCause.asInstanceOf[SparkException] - assert(cause.getErrorClass == null) + assert(cause.getCondition == null) assert(cause.getMessageParameters.isEmpty) assert(cause.getMessage.contains("test" * 10000)) } @@ -141,7 +141,7 @@ class ClientE2ETestSuite val ex = intercept[AnalysisException] { spark.sql("select x").collect() } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) assert(!ex.messageParameters.isEmpty) assert(ex.getSqlState != null) assert(!ex.isInternalError) @@ -169,14 +169,14 @@ class ClientE2ETestSuite val ex = intercept[NoSuchNamespaceException] { spark.sql("use database123") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } test("table not found for spark.catalog.getTable") { val ex = intercept[AnalysisException] { spark.catalog.getTable("test_table") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } test("throw NamespaceAlreadyExistsException") { @@ -185,7 +185,7 @@ class ClientE2ETestSuite val ex = intercept[NamespaceAlreadyExistsException] { spark.sql("create database test_db") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } finally { spark.sql("drop database test_db") } @@ -197,7 +197,7 @@ class ClientE2ETestSuite val ex = intercept[TempTableAlreadyExistsException] { spark.sql("create temporary view test_view as select 1") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } finally { spark.sql("drop view test_view") } @@ -209,7 +209,7 @@ class ClientE2ETestSuite val ex = intercept[TableAlreadyExistsException] { spark.sql(s"create table testcat.test_table (id int)") } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) } } @@ -217,7 +217,7 @@ class ClientE2ETestSuite val ex = intercept[ParseException] { spark.sql("selet 1").collect() } - assert(ex.getErrorClass != null) + assert(ex.getCondition != null) assert(!ex.messageParameters.isEmpty) assert(ex.getSqlState != null) assert(!ex.isInternalError) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 46aeaeff43d2f..ac56600392aa3 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -224,7 +224,7 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach { val error = constructor(testParams).asInstanceOf[Throwable with SparkThrowable] assert(error.getMessage.contains(testParams.message)) assert(error.getCause == null) - assert(error.getErrorClass == testParams.errorClass.get) + assert(error.getCondition == testParams.errorClass.get) assert(error.getMessageParameters.asScala == testParams.messageParameters) assert(error.getQueryContext.isEmpty) } diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index 7176c582d0bbc..10e4c11c406fe 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -783,7 +783,7 @@ class ArrowEncoderSuite extends ConnectFunSuite with BeforeAndAfterAll { Iterator.tabulate(10)(i => (i, "itr_" + i)) } } - assert(e.getErrorClass == "CANNOT_USE_KRYO") + assert(e.getCondition == "CANNOT_USE_KRYO") } test("transforming encoder") { diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala index 27b1ee014a719..b1a7d81916e92 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala @@ -331,7 +331,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L query.awaitTermination() } - assert(exception.getErrorClass != null) + assert(exception.getCondition != null) assert(exception.getMessageParameters().get("id") == query.id.toString) assert(exception.getMessageParameters().get("runId") == query.runId.toString) assert(exception.getCause.isInstanceOf[SparkException]) @@ -369,7 +369,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L spark.streams.awaitAnyTermination() } - assert(exception.getErrorClass != null) + assert(exception.getCondition != null) assert(exception.getMessageParameters().get("id") == query.id.toString) assert(exception.getMessageParameters().get("runId") == query.runId.toString) assert(exception.getCause.isInstanceOf[SparkException]) diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala index 13a68e72269f0..c4adb6b3f26e1 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaExceptions.scala @@ -184,5 +184,5 @@ private[kafka010] class KafkaIllegalStateException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } diff --git a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java index 8ec5c2221b6e9..fa71eb066ff89 100644 --- a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java +++ b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java @@ -52,7 +52,7 @@ public Map getMessageParameters() { } @Override - public String getErrorClass() { + public String getCondition() { return errorClass; } } diff --git a/core/src/main/scala/org/apache/spark/SparkFileAlreadyExistsException.scala b/core/src/main/scala/org/apache/spark/SparkFileAlreadyExistsException.scala index 0e578f045452e..82a0261f32ae7 100644 --- 
a/core/src/main/scala/org/apache/spark/SparkFileAlreadyExistsException.scala +++ b/core/src/main/scala/org/apache/spark/SparkFileAlreadyExistsException.scala @@ -33,5 +33,5 @@ private[spark] class SparkFileAlreadyExistsException( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f41888320a29b..4f7338f74e298 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -2892,8 +2892,8 @@ private[spark] class DAGScheduler( val finalException = exception.collect { // If the error is user-facing (defines error class and is not internal error), we don't // wrap it with "Job aborted" and expose this error to the end users directly. - case st: Exception with SparkThrowable if st.getErrorClass != null && - !SparkThrowableHelper.isInternalError(st.getErrorClass) => + case st: Exception with SparkThrowable if st.getCondition != null && + !SparkThrowableHelper.isInternalError(st.getCondition) => st }.getOrElse { new SparkException(s"Job aborted due to stage failure: $reason", cause = exception.orNull) diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 380231ce97c0b..ca51e61f5ed44 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -288,7 +288,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft sem.acquire(1) sc.cancelJobGroupAndFutureJobs(s"job-group-$idx") ThreadUtils.awaitReady(job, Duration.Inf).failed.foreach { case e: SparkException => - assert(e.getErrorClass == "SPARK_JOB_CANCELLED") + assert(e.getCondition == "SPARK_JOB_CANCELLED") } } // submit a job with the 0 job group that was evicted from cancelledJobGroups set, it should run diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 9f310c06ac5ae..e38efc27b78f9 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -343,7 +343,7 @@ abstract class SparkFunSuite parameters: Map[String, String] = Map.empty, matchPVals: Boolean = false, queryContext: Array[ExpectedContext] = Array.empty): Unit = { - assert(exception.getErrorClass === condition) + assert(exception.getCondition === condition) sqlState.foreach(state => assert(exception.getSqlState === state)) val expectedParameters = exception.getMessageParameters.asScala if (matchPVals) { diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index 946ea75686e32..9f005e5757193 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -199,7 +199,7 @@ class SparkThrowableSuite extends SparkFunSuite { val e = intercept[SparkException] { getMessage("UNRESOLVED_COLUMN.WITHOUT_SUGGESTION", Map.empty[String, String]) } - assert(e.getErrorClass === "INTERNAL_ERROR") + assert(e.getCondition === "INTERNAL_ERROR") 
assert(e.getMessageParameters().get("message").contains("Undefined error message parameter")) } @@ -245,7 +245,7 @@ class SparkThrowableSuite extends SparkFunSuite { throw new SparkException("Arbitrary legacy message") } catch { case e: SparkThrowable => - assert(e.getErrorClass == null) + assert(e.getCondition == null) assert(!e.isInternalError) assert(e.getSqlState == null) case _: Throwable => @@ -262,7 +262,7 @@ class SparkThrowableSuite extends SparkFunSuite { cause = null) } catch { case e: SparkThrowable => - assert(e.getErrorClass == "CANNOT_PARSE_DECIMAL") + assert(e.getCondition == "CANNOT_PARSE_DECIMAL") assert(!e.isInternalError) assert(e.getSqlState == "22018") case _: Throwable => @@ -357,7 +357,7 @@ class SparkThrowableSuite extends SparkFunSuite { |}""".stripMargin) // Legacy mode when an exception does not have any error class class LegacyException extends Throwable with SparkThrowable { - override def getErrorClass: String = null + override def getCondition: String = null override def getMessage: String = "Test message" } val e3 = new LegacyException @@ -452,7 +452,7 @@ class SparkThrowableSuite extends SparkFunSuite { val e = intercept[SparkException] { new ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL, json.toURI.toURL)) } - assert(e.getErrorClass === "INTERNAL_ERROR") + assert(e.getCondition === "INTERNAL_ERROR") assert(e.getMessage.contains("DIVIDE.BY_ZERO")) } @@ -478,7 +478,7 @@ class SparkThrowableSuite extends SparkFunSuite { val e = intercept[SparkException] { new ErrorClassesJsonReader(Seq(errorJsonFilePath.toUri.toURL, json.toURI.toURL)) } - assert(e.getErrorClass === "INTERNAL_ERROR") + assert(e.getCondition === "INTERNAL_ERROR") assert(e.getMessage.contains("BY.ZERO")) } } diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index 1efef3383b821..b0f36b9744fa8 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -317,13 +317,13 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext with Encryptio // Instead, crash the driver by directly accessing the broadcast value. 
val e1 = intercept[SparkException] { broadcast.value } assert(e1.isInternalError) - assert(e1.getErrorClass == "INTERNAL_ERROR_BROADCAST") + assert(e1.getCondition == "INTERNAL_ERROR_BROADCAST") val e2 = intercept[SparkException] { broadcast.unpersist(blocking = true) } assert(e2.isInternalError) - assert(e2.getErrorClass == "INTERNAL_ERROR_BROADCAST") + assert(e2.getCondition == "INTERNAL_ERROR_BROADCAST") val e3 = intercept[SparkException] { broadcast.destroy(blocking = true) } assert(e3.isInternalError) - assert(e3.getErrorClass == "INTERNAL_ERROR_BROADCAST") + assert(e3.getCondition == "INTERNAL_ERROR_BROADCAST") } else { val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum)) assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) @@ -339,7 +339,7 @@ package object testPackage extends Assertions { val thrown = intercept[SparkException] { broadcast.value } assert(thrown.getMessage.contains("BroadcastSuite.scala")) assert(thrown.isInternalError) - assert(thrown.getErrorClass == "INTERNAL_ERROR_BROADCAST") + assert(thrown.getCondition == "INTERNAL_ERROR_BROADCAST") } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 30f3e4c4af021..5486c39034fd3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -204,7 +204,7 @@ final class Binarizer @Since("1.4.0") (@Since("1.4.0") override val uid: String) val inputType = try { SchemaUtils.getSchemaFieldType(schema, inputColName) } catch { - case e: SparkIllegalArgumentException if e.getErrorClass == "FIELD_NOT_FOUND" => + case e: SparkIllegalArgumentException if e.getCondition == "FIELD_NOT_FOUND" => throw new SparkException(s"Input column $inputColName does not exist.") case e: Exception => throw e diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 8e64f60427d90..20b03edf23c4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -127,7 +127,7 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi validateAndTransformField(schema, inputColName, dtype, outputColName) ) } catch { - case e: SparkIllegalArgumentException if e.getErrorClass == "FIELD_NOT_FOUND" => + case e: SparkIllegalArgumentException if e.getCondition == "FIELD_NOT_FOUND" => if (skipNonExistsCol) { None } else { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3ccb0bddfb0eb..f31a29788aafe 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -195,6 +195,9 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.SparkSession.setDefaultSession"), ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.api.SparkSessionCompanion.clearActiveSession"), ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.api.SparkSessionCompanion.clearDefaultSession"), + + // SPARK-49748: Add getCondition and deprecate getErrorClass in SparkThrowable + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkThrowable.getCondition"), ) ++ loggingExcludes("org.apache.spark.sql.DataFrameReader") ++ loggingExcludes("org.apache.spark.sql.streaming.DataStreamReader") 
++ loggingExcludes("org.apache.spark.sql.SparkSession#Builder") diff --git a/sql/api/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/api/src/main/scala/org/apache/spark/sql/AnalysisException.scala index a2c1f2cc41f8f..51825ee1a5bed 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/AnalysisException.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/AnalysisException.scala @@ -139,7 +139,7 @@ class AnalysisException protected ( override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava - override def getErrorClass: String = errorClass.orNull + override def getCondition: String = errorClass.orNull override def getQueryContext: Array[QueryContext] = context diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala index 10da24567545b..f2c7dd533af3a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala @@ -100,7 +100,7 @@ abstract class AbstractParser extends DataTypeParserInterface with Logging { command = Option(command), start = e.origin, stop = e.origin, - errorClass = e.getErrorClass, + errorClass = e.getCondition, messageParameters = e.getMessageParameters.asScala.toMap, queryContext = e.getQueryContext) } @@ -275,7 +275,7 @@ class ParseException private ( } def withCommand(cmd: String): ParseException = { - val cl = getErrorClass + val cl = getCondition val (newCl, params) = if (cl == "PARSE_SYNTAX_ERROR" && cmd.trim().isEmpty) { // PARSE_EMPTY_STATEMENT error class overrides the PARSE_SYNTAX_ERROR when cmd is empty ("PARSE_EMPTY_STATEMENT", Map.empty[String, String]) @@ -287,7 +287,7 @@ class ParseException private ( override def getQueryContext: Array[QueryContext] = queryContext - override def getErrorClass: String = errorClass.getOrElse { + override def getCondition: String = errorClass.getOrElse { throw SparkException.internalError("ParseException shall have an error class.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a4f424ba4b421..4720b9dcdfa13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -1621,7 +1621,7 @@ class PreemptedError() { // errors have the lowest priority. 
def set(error: Exception with SparkThrowable, priority: Option[Int] = None): Unit = { val calculatedPriority = priority.getOrElse { - error.getErrorClass match { + error.getCondition match { case c if c.startsWith("INTERNAL_ERROR") => 1 case _ => 2 } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index d7d53230470d9..f2f86a90d5172 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -87,7 +87,7 @@ object ExpressionEncoder { } constructProjection(row).get(0, anyObjectType).asInstanceOf[T] } catch { - case e: SparkRuntimeException if e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION" => + case e: SparkRuntimeException if e.getCondition == "NOT_NULL_ASSERT_VIOLATION" => throw e case e: Exception => throw QueryExecutionErrors.expressionDecodingError(e, expressions) @@ -115,7 +115,7 @@ object ExpressionEncoder { inputRow(0) = t extractProjection(inputRow) } catch { - case e: SparkRuntimeException if e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION" => + case e: SparkRuntimeException if e.getCondition == "NOT_NULL_ASSERT_VIOLATION" => throw e case e: Exception => throw QueryExecutionErrors.expressionEncodingError(e, expressions) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c9150b8a26100..3ecb680cf6427 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -3256,7 +3256,7 @@ class AstBuilder extends DataTypeAstBuilder } catch { case e: SparkArithmeticException => throw new ParseException( - errorClass = e.getErrorClass, + errorClass = e.getCondition, messageParameters = e.getMessageParameters.asScala.toMap, ctx) } @@ -3552,7 +3552,7 @@ class AstBuilder extends DataTypeAstBuilder // Keep error class of SparkIllegalArgumentExceptions and enrich it with query context case se: SparkIllegalArgumentException => val pe = new ParseException( - errorClass = se.getErrorClass, + errorClass = se.getCondition, messageParameters = se.getMessageParameters.asScala.toMap, ctx) pe.setStackTrace(se.getStackTrace) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala index 46f14876be363..8d88b05546ed2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala @@ -127,7 +127,7 @@ object GeneratedColumn { } catch { case ex: AnalysisException => // Improve error message if possible - if (ex.getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") { + if (ex.getCondition == "UNRESOLVED_COLUMN.WITH_SUGGESTION") { ex.messageParameters.get("objectName").foreach { unresolvedCol => val resolver = SQLConf.get.resolver // Whether `col` = `unresolvedCol` taking into account case-sensitivity @@ -144,7 +144,7 @@ object GeneratedColumn { } } } - if (ex.getErrorClass == "UNRESOLVED_ROUTINE") { + if (ex.getCondition == "UNRESOLVED_ROUTINE") { // Cannot resolve function using built-in catalog 
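The ExpressionEncoder and GeneratedColumn hunks in this area share one idiom: a pattern guard on `getCondition` lets a single user-facing condition propagate untouched while every other failure is rewrapped. A hedged sketch of that idiom, with a hypothetical `decode` callback and wrapper message:

```
import org.apache.spark.{SparkException, SparkRuntimeException}

object ConditionAwareCatch {
  // Hedged sketch, not part of this patch: rethrow one specific user-facing
  // condition untouched and wrap everything else. `decode` and the wrapper
  // message are illustrative placeholders.
  def decodeOrWrap[T](decode: () => T): T =
    try {
      decode()
    } catch {
      case e: SparkRuntimeException if e.getCondition == "NOT_NULL_ASSERT_VIOLATION" =>
        throw e
      case e: Exception =>
        throw new SparkException("Expression decoding failed", e)
    }
}
```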
ex.messageParameters.get("routineName").foreach { fnName => throw unsupportedExpressionError(s"failed to resolve $fnName to a built-in function") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala index f0c28c95046eb..7602366c71a65 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/exceptions/SqlScriptingException.scala @@ -33,7 +33,7 @@ class SqlScriptingException ( cause) with SparkThrowable { - override def getErrorClass: String = errorClass + override def getCondition: String = errorClass override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 969eee4d912e4..08002887135ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -782,7 +782,7 @@ object SQLConf { CollationFactory.fetchCollation(collationName) true } catch { - case e: SparkException if e.getErrorClass == "COLLATION_INVALID_NAME" => false + case e: SparkException if e.getCondition == "COLLATION_INVALID_NAME" => false } }, "DEFAULT_COLLATION", diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/connector/catalog/CatalogLoadingSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/connector/catalog/CatalogLoadingSuite.java index 0db155e88aea5..339f16407ae60 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/connector/catalog/CatalogLoadingSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/connector/catalog/CatalogLoadingSuite.java @@ -80,7 +80,7 @@ public void testLoadWithoutConfig() { SparkException exc = Assertions.assertThrows(CatalogNotFoundException.class, () -> Catalogs.load("missing", conf)); - Assertions.assertEquals(exc.getErrorClass(), "CATALOG_NOT_FOUND"); + Assertions.assertEquals(exc.getCondition(), "CATALOG_NOT_FOUND"); Assertions.assertEquals(exc.getMessageParameters().get("catalogName"), "`missing`"); } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 33b9fb488c94f..71744f4d15105 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -205,7 +205,7 @@ trait AnalysisTest extends PlanTest { assert(e.message.contains(message)) } if (condition.isDefined) { - assert(e.getErrorClass == condition.get) + assert(e.getCondition == condition.get) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 3e9a93dc743df..6ee19bab5180a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -1133,7 +1133,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper { } } if (!condition.isEmpty) { - assert(e.getErrorClass == 
condition) + assert(e.getCondition == condition) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala index e8239c7523948..f3817e4dd1a8b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/csv/CSVExprUtilsSuite.scala @@ -106,7 +106,7 @@ class CSVExprUtilsSuite extends SparkFunSuite { } catch { case e: SparkIllegalArgumentException => assert(separatorStr.isEmpty) - assert(e.getErrorClass === expectedErrorClass.get) + assert(e.getCondition === expectedErrorClass.get) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 35a27f41da80a..6bd5b457ea24e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -173,7 +173,7 @@ class EncoderResolutionSuite extends PlanTest { val exception = intercept[SparkRuntimeException] { fromRow(InternalRow(new GenericArrayData(Array(1, null)))) } - assert(exception.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(exception.getCondition == "NOT_NULL_ASSERT_VIOLATION") } test("the real number of fields doesn't match encoder schema: tuple encoder") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index f73911d344d96..79c6d07d6d218 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -279,7 +279,7 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest { // Check the error class only since the parameters may change depending on how we are running // this test case. 
val exception = intercept[SparkRuntimeException](toRow(encoder, null)) - assert(exception.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(exception.getCondition == "NOT_NULL_ASSERT_VIOLATION") } test("RowEncoder should validate external type") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala index 3aeb0c882ac3c..891e2d048b7a8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala @@ -64,7 +64,7 @@ object BufferHolderSparkSubmitSuite extends Assertions { val e1 = intercept[SparkIllegalArgumentException] { holder.grow(-1) } - assert(e1.getErrorClass === "_LEGACY_ERROR_TEMP_3198") + assert(e1.getCondition === "_LEGACY_ERROR_TEMP_3198") // while to reuse a buffer may happen, this test checks whether the buffer can be grown holder.grow(ARRAY_MAX / 2) @@ -82,6 +82,6 @@ object BufferHolderSparkSubmitSuite extends Assertions { val e2 = intercept[SparkIllegalArgumentException] { holder.grow(ARRAY_MAX + 1 - holder.totalSize()) } - assert(e2.getErrorClass === "_LEGACY_ERROR_TEMP_3199") + assert(e2.getCondition === "_LEGACY_ERROR_TEMP_3199") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index b7e2490b552cc..926beacc592a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -3065,7 +3065,7 @@ class DDLParserSuite extends AnalysisTest { s"(id BIGINT GENERATED ALWAYS AS IDENTITY $identitySpecStr, val INT) USING foo" ) } - assert(exception.getErrorClass === "IDENTITY_COLUMNS_DUPLICATED_SEQUENCE_GENERATOR_OPTION") + assert(exception.getCondition === "IDENTITY_COLUMNS_DUPLICATED_SEQUENCE_GENERATOR_OPTION") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala index 2972ba2db21de..2e702e5642a92 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala @@ -50,7 +50,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { val e = intercept[ParseException] { parseScript(sqlScriptText) } - assert(e.getErrorClass === "PARSE_SYNTAX_ERROR") + assert(e.getCondition === "PARSE_SYNTAX_ERROR") assert(e.getMessage.contains("Syntax error")) assert(e.getMessage.contains("SELECT")) } @@ -90,7 +90,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { val e = intercept[ParseException] { parseScript(sqlScriptText) } - assert(e.getErrorClass === "PARSE_SYNTAX_ERROR") + assert(e.getCondition === "PARSE_SYNTAX_ERROR") assert(e.getMessage.contains("Syntax error")) assert(e.getMessage.contains("at or near ';'")) } @@ -105,7 +105,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { val e = intercept[ParseException] { parseScript(sqlScriptText) } - assert(e.getErrorClass === "PARSE_SYNTAX_ERROR") + 
assert(e.getCondition === "PARSE_SYNTAX_ERROR") assert(e.getMessage.contains("Syntax error")) assert(e.getMessage.contains("at or near end of input")) } @@ -367,7 +367,7 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper { val e = intercept[ParseException] { parseScript(sqlScriptText) } - assert(e.getErrorClass === "PARSE_SYNTAX_ERROR") + assert(e.getCondition === "PARSE_SYNTAX_ERROR") assert(e.getMessage.contains("Syntax error")) } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 25fd7d13b7d48..4e6994f9c2f8b 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -3118,7 +3118,7 @@ class SparkConnectPlanner( .newBuilder() exception_builder .setExceptionMessage(e.toString()) - .setErrorClass(e.getErrorClass) + .setErrorClass(e.getCondition) val stackTrace = Option(ExceptionUtils.getStackTrace(e)) stackTrace.foreach { s => diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index f1636ed1ef092..837d4a4d3ee78 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -114,8 +114,8 @@ private[connect] object ErrorUtils extends Logging { case sparkThrowable: SparkThrowable => val sparkThrowableBuilder = FetchErrorDetailsResponse.SparkThrowable .newBuilder() - if (sparkThrowable.getErrorClass != null) { - sparkThrowableBuilder.setErrorClass(sparkThrowable.getErrorClass) + if (sparkThrowable.getCondition != null) { + sparkThrowableBuilder.setErrorClass(sparkThrowable.getCondition) } for (queryCtx <- sparkThrowable.getQueryContext) { val builder = FetchErrorDetailsResponse.QueryContext @@ -193,7 +193,7 @@ private[connect] object ErrorUtils extends Logging { if (state != null && state.nonEmpty) { errorInfo.putMetadata("sqlState", state) } - val errorClass = e.getErrorClass + val errorClass = e.getCondition if (errorClass != null && errorClass.nonEmpty) { val messageParameters = JsonMethods.compact( JsonMethods.render(map2jvalue(e.getMessageParameters.asScala.toMap))) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala index 42bb93de05e26..1f522ea28b761 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala @@ -37,7 +37,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA val exGetOrCreate = intercept[SparkSQLException] { SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None) } - assert(exGetOrCreate.getErrorClass == "INVALID_HANDLE.FORMAT") + assert(exGetOrCreate.getCondition == "INVALID_HANDLE.FORMAT") } test( @@ -72,7 +72,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA key, Some(sessionHolder.session.sessionUUID + "invalid")) } - 
assert(exGet.getErrorClass == "INVALID_HANDLE.SESSION_CHANGED") + assert(exGet.getCondition == "INVALID_HANDLE.SESSION_CHANGED") } test( @@ -85,12 +85,12 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA val exGetOrCreate = intercept[SparkSQLException] { SparkConnectService.sessionManager.getOrCreateIsolatedSession(key, None) } - assert(exGetOrCreate.getErrorClass == "INVALID_HANDLE.SESSION_CLOSED") + assert(exGetOrCreate.getCondition == "INVALID_HANDLE.SESSION_CLOSED") val exGet = intercept[SparkSQLException] { SparkConnectService.sessionManager.getIsolatedSession(key, None) } - assert(exGet.getErrorClass == "INVALID_HANDLE.SESSION_CLOSED") + assert(exGet.getCondition == "INVALID_HANDLE.SESSION_CLOSED") val sessionGetIfPresent = SparkConnectService.sessionManager.getIsolatedSessionIfPresent(key) assert(sessionGetIfPresent.isEmpty) @@ -102,7 +102,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA val exGet = intercept[SparkSQLException] { SparkConnectService.sessionManager.getIsolatedSession(key, None) } - assert(exGet.getErrorClass == "INVALID_HANDLE.SESSION_NOT_FOUND") + assert(exGet.getCondition == "INVALID_HANDLE.SESSION_NOT_FOUND") val sessionGetIfPresent = SparkConnectService.sessionManager.getIsolatedSessionIfPresent(key) assert(sessionGetIfPresent.isEmpty) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala index 8ffdbb952b082..3b64cb97e10b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala @@ -68,7 +68,7 @@ class DataSourceRegistration private[sql] (dataSourceManager: DataSourceManager) DataSource.lookupDataSource(name, SQLConf.get) throw QueryCompilationErrors.dataSourceAlreadyExists(name) } catch { - case e: SparkClassNotFoundException if e.getErrorClass == "DATA_SOURCE_NOT_FOUND" => // OK + case e: SparkClassNotFoundException if e.getCondition == "DATA_SOURCE_NOT_FOUND" => // OK case _: Throwable => // If there are other errors when resolving the data source, it's unclear whether // it's safe to proceed. 
To prevent potential lookup errors, treat it as an existing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 29385904a7525..cbbf9f88f89d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -89,9 +89,9 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] { LogicalRelation(ds.resolveRelation()) } catch { case _: ClassNotFoundException => u - case e: SparkIllegalArgumentException if e.getErrorClass != null => + case e: SparkIllegalArgumentException if e.getCondition != null => u.failAnalysis( - errorClass = e.getErrorClass, + errorClass = e.getCondition, messageParameters = e.getMessageParameters.asScala.toMap, cause = e) case e: Exception if !e.isInstanceOf[AnalysisException] => @@ -469,8 +469,8 @@ object PreprocessTableInsertion extends ResolveInsertionBase { supportColDefaultValue = true) } catch { case e: AnalysisException if staticPartCols.nonEmpty && - (e.getErrorClass == "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS" || - e.getErrorClass == "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS") => + (e.getCondition == "INSERT_COLUMN_ARITY_MISMATCH.NOT_ENOUGH_DATA_COLUMNS" || + e.getCondition == "INSERT_COLUMN_ARITY_MISMATCH.TOO_MANY_DATA_COLUMNS") => val newException = e.copy( errorClass = Some("INSERT_PARTITION_COLUMN_ARITY_MISMATCH"), messageParameters = e.messageParameters ++ Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala index 168aea5b041f8..4242fc5d8510a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala @@ -131,7 +131,7 @@ object FileDataSourceV2 { // The error is already FAILED_READ_FILE, throw it directly. To be consistent, schema // inference code path throws `FAILED_READ_FILE`, but the file reading code path can reach // that code path as well and we should not double-wrap the error. 
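Call sites in this patch match conditions in two ways: exact equality against a fully qualified condition (as in the FileDataSourceV2 hunk here) and prefix or substring checks such as `startsWith("FAILED_READ_FILE")` or `contains("CANNOT_LOAD_STATE_STORE")` where only the parent condition matters. A small illustrative sketch with hypothetical helper names:

```
import org.apache.spark.SparkException

object ConditionMatching {
  // Hedged sketch, not part of this patch: exact equality for a fully
  // qualified condition vs. prefix matching on the parent condition.
  def isFooterReadFailure(e: SparkException): Boolean =
    e.getCondition == "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER"

  def isAnyReadFailure(e: SparkException): Boolean =
    e.getCondition != null && e.getCondition.startsWith("FAILED_READ_FILE")
}
```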
- case e: SparkException if e.getErrorClass == "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER" => + case e: SparkException if e.getCondition == "FAILED_READ_FILE.CANNOT_READ_FILE_FOOTER" => throw e case e: SchemaColumnConvertNotSupportedException => throw QueryExecutionErrors.parquetColumnDataTypeMismatchError( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 8f030884ad33b..14adf951f07e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -374,7 +374,7 @@ abstract class StreamExecution( "message" -> message)) errorClassOpt = e match { - case t: SparkThrowable => Option(t.getErrorClass) + case t: SparkThrowable => Option(t.getCondition) case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 884b8aa3853cb..3df63c41dbf97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -282,7 +282,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with newMap } catch { - case e: SparkException if e.getErrorClass.contains("CANNOT_LOAD_STATE_STORE") => + case e: SparkException if e.getCondition.contains("CANNOT_LOAD_STATE_STORE") => throw e case e: OutOfMemoryError => throw QueryExecutionErrors.notEnoughMemoryToLoadStore( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 6ab634668bc2a..870ed79ec1747 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -389,7 +389,7 @@ private[sql] class RocksDBStateStoreProvider new RocksDBStateStore(version) } catch { - case e: SparkException if e.getErrorClass.contains("CANNOT_LOAD_STATE_STORE") => + case e: SparkException if e.getCondition.contains("CANNOT_LOAD_STATE_STORE") => throw e case e: OutOfMemoryError => throw QueryExecutionErrors.notEnoughMemoryToLoadStore( @@ -409,7 +409,7 @@ private[sql] class RocksDBStateStoreProvider new RocksDBStateStore(version) } catch { - case e: SparkException if e.getErrorClass.contains("CANNOT_LOAD_STATE_STORE") => + case e: SparkException if e.getCondition.contains("CANNOT_LOAD_STATE_STORE") => throw e case e: OutOfMemoryError => throw QueryExecutionErrors.notEnoughMemoryToLoadStore( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 52b8d35e2fbf8..64689e75e2e5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -177,7 +177,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { try { Some(makeTable(catalogName +: ns :+ tableName)) } catch { - case e: AnalysisException if 
e.getErrorClass == "UNSUPPORTED_FEATURE.HIVE_TABLE_TYPE" => + case e: AnalysisException if e.getCondition == "UNSUPPORTED_FEATURE.HIVE_TABLE_TYPE" => Some(new Table( name = tableName, catalog = catalogName, @@ -189,7 +189,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { } } } catch { - case e: AnalysisException if e.getErrorClass == "TABLE_OR_VIEW_NOT_FOUND" => None + case e: AnalysisException if e.getCondition == "TABLE_OR_VIEW_NOT_FOUND" => None } } @@ -203,7 +203,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { case _ => false } } catch { - case e: AnalysisException if e.getErrorClass == "TABLE_OR_VIEW_NOT_FOUND" => false + case e: AnalysisException if e.getCondition == "TABLE_OR_VIEW_NOT_FOUND" => false } } @@ -323,7 +323,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { case _ => false } } catch { - case e: AnalysisException if e.getErrorClass == "UNRESOLVED_ROUTINE" => false + case e: AnalysisException if e.getCondition == "UNRESOLVED_ROUTINE" => false } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java index 9fbd1919a2668..9988d04220f0f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java @@ -85,7 +85,7 @@ public void isInCollectionCheckExceptionMessage() { Dataset df = spark.createDataFrame(rows, schema); AnalysisException e = Assertions.assertThrows(AnalysisException.class, () -> df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b"))))); - Assertions.assertTrue(e.getErrorClass().equals("DATATYPE_MISMATCH.DATA_DIFF_TYPES")); + Assertions.assertTrue(e.getCondition().equals("DATATYPE_MISMATCH.DATA_DIFF_TYPES")); Map messageParameters = new HashMap<>(); messageParameters.put("functionName", "`in`"); messageParameters.put("dataType", "[\"INT\", \"ARRAY\"]"); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala index 879c0c480943d..8600ec4f8787f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationExpressionWalkerSuite.scala @@ -741,7 +741,7 @@ class CollationExpressionWalkerSuite extends SparkFunSuite with SharedSparkSessi assert(resultUTF8.collect() === resultUTF8Lcase.collect()) } } catch { - case e: SparkRuntimeException => assert(e.getErrorClass == "USER_RAISED_EXCEPTION") + case e: SparkRuntimeException => assert(e.getCondition == "USER_RAISED_EXCEPTION") case other: Throwable => throw other } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 85f296665b6e0..45c34d9c73367 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1273,7 +1273,7 @@ class DatasetSuite extends QueryTest // Just check the error class here to avoid flakiness due to different parameters. 
assert(intercept[SparkRuntimeException] { buildDataset(Row(Row("hello", null))).collect() - }.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + }.getCondition == "NOT_NULL_ASSERT_VIOLATION") } test("SPARK-12478: top level null field") { @@ -1416,7 +1416,7 @@ class DatasetSuite extends QueryTest val ex = intercept[SparkRuntimeException] { spark.createDataFrame(rdd, schema).collect() } - assert(ex.getErrorClass == "EXPRESSION_ENCODING_FAILED") + assert(ex.getCondition == "EXPRESSION_ENCODING_FAILED") assert(ex.getCause.getMessage.contains("The 1th field 'b' of input row cannot be null")) } @@ -1612,7 +1612,7 @@ class DatasetSuite extends QueryTest test("Dataset should throw RuntimeException if top-level product input object is null") { val e = intercept[SparkRuntimeException](Seq(ClassData("a", 1), null).toDS()) - assert(e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(e.getCondition == "NOT_NULL_ASSERT_VIOLATION") } test("dropDuplicates") { @@ -2121,7 +2121,7 @@ class DatasetSuite extends QueryTest test("SPARK-23835: null primitive data type should throw NullPointerException") { val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS() val exception = intercept[SparkRuntimeException](ds.as[(Int, Int)].collect()) - assert(exception.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(exception.getCondition == "NOT_NULL_ASSERT_VIOLATION") } test("SPARK-24569: Option of primitive types are mistakenly mapped to struct type") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala index a892cd4db02b0..3f921618297d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/LateralColumnAliasSuite.scala @@ -205,7 +205,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { withLCAOn { checkAnswer(sql(query), expectedAnswerLCAOn) } withLCAOff { assert(intercept[AnalysisException]{ sql(query) } - .getErrorClass == "UNRESOLVED_COLUMN.WITH_SUGGESTION") + .getCondition == "UNRESOLVED_COLUMN.WITH_SUGGESTION") } } @@ -216,8 +216,8 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { errorParams: Map[String, String]): Unit = { val e1 = intercept[AnalysisException] { sql(q1) } val e2 = intercept[AnalysisException] { sql(q2) } - assert(e1.getErrorClass == condition) - assert(e2.getErrorClass == condition) + assert(e1.getCondition == condition) + assert(e2.getCondition == condition) errorParams.foreach { case (k, v) => assert(e1.messageParameters.get(k).exists(_ == v)) assert(e2.messageParameters.get(k).exists(_ == v)) @@ -1187,7 +1187,7 @@ class LateralColumnAliasSuite extends LateralColumnAliasSuiteBase { "sum_avg * 1.0 as sum_avg1, sum_avg1 + dept " + s"from $testTable group by dept, properties.joinYear $havingSuffix" ).foreach { query => - assert(intercept[AnalysisException](sql(query)).getErrorClass == + assert(intercept[AnalysisException](sql(query)).getCondition == "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_AGGREGATE_WITH_WINDOW_AND_HAVING") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala index 754c46cc5cd3e..b48ff7121c767 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RuntimeNullChecksV2Writes.scala @@ -64,7 +64,7 @@ class 
RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS sql("INSERT INTO t VALUES ('txt', null)") } } - assert(e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(e.getCondition == "NOT_NULL_ASSERT_VIOLATION") } } @@ -404,7 +404,7 @@ class RuntimeNullChecksV2Writes extends QueryTest with SQLTestUtils with SharedS private def assertNotNullException(e: SparkRuntimeException, colPath: Seq[String]): Unit = { e.getCause match { - case _ if e.getErrorClass == "NOT_NULL_ASSERT_VIOLATION" => + case _ if e.getCondition == "NOT_NULL_ASSERT_VIOLATION" => case other => fail(s"Unexpected exception cause: $other") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala index 38e004e0b7209..4bd20bc245613 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala @@ -148,7 +148,7 @@ trait SQLQueryTestHelper extends Logging { try { result } catch { - case e: SparkThrowable with Throwable if e.getErrorClass != null => + case e: SparkThrowable with Throwable if e.getCondition != null => (emptySchema, Seq(e.getClass.getName, getMessage(e, format))) case a: AnalysisException => // Do not output the logical plan tree which contains expression IDs. @@ -160,7 +160,7 @@ trait SQLQueryTestHelper extends Logging { // information of stage, task ID, etc. // To make result matching simpler, here we match the cause of the exception if it exists. s.getCause match { - case e: SparkThrowable with Throwable if e.getErrorClass != null => + case e: SparkThrowable with Throwable if e.getCondition != null => (emptySchema, Seq(e.getClass.getName, getMessage(e, format))) case cause => (emptySchema, Seq(cause.getClass.getName, cause.getMessage)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index f17cf25565145..f8f7fd246832f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -925,12 +925,12 @@ class SubquerySuite extends QueryTest withSQLConf(SQLConf.DECORRELATE_INNER_QUERY_ENABLED.key -> "false") { val error = intercept[AnalysisException] { sql(query) } - assert(error.getErrorClass == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + assert(error.getCondition == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED") } withSQLConf(SQLConf.DECORRELATE_SET_OPS_ENABLED.key -> "false") { val error = intercept[AnalysisException] { sql(query) } - assert(error.getErrorClass == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + assert(error.getCondition == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED") } @@ -1004,12 +1004,12 @@ class SubquerySuite extends QueryTest withSQLConf(SQLConf.DECORRELATE_INNER_QUERY_ENABLED.key -> "false") { val error = intercept[AnalysisException] { sql(query) } - assert(error.getErrorClass == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + + assert(error.getCondition == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED") } withSQLConf(SQLConf.DECORRELATE_SET_OPS_ENABLED.key -> "false") { val error = intercept[AnalysisException] { sql(query) } - assert(error.getErrorClass == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." 
+ + assert(error.getCondition == "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY." + "ACCESSING_OUTER_QUERY_COLUMN_IS_NOT_ALLOWED") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 2e072e5afc926..d550d0f94f236 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -821,14 +821,14 @@ class UDFSuite extends QueryTest with SharedSparkSession { val e1 = intercept[SparkException] { Seq("20").toDF("col").select(udf(f1).apply(Column("col"))).collect() } - assert(e1.getErrorClass == "FAILED_EXECUTE_UDF") + assert(e1.getCondition == "FAILED_EXECUTE_UDF") assert(e1.getCause.getStackTrace.head.toString.contains( "UDFSuite$MalformedClassObject$MalformedNonPrimitiveFunction")) val e2 = intercept[SparkException] { Seq(20).toDF("col").select(udf(f2).apply(Column("col"))).collect() } - assert(e2.getErrorClass == "FAILED_EXECUTE_UDF") + assert(e2.getCondition == "FAILED_EXECUTE_UDF") assert(e2.getCause.getStackTrace.head.toString.contains( "UDFSuite$MalformedClassObject$MalformedPrimitiveFunction")) } @@ -938,7 +938,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkException] { input.select(overflowFunc($"dateTime")).collect() } - assert(e.getErrorClass == "FAILED_EXECUTE_UDF") + assert(e.getCondition == "FAILED_EXECUTE_UDF") assert(e.getCause.isInstanceOf[java.lang.ArithmeticException]) } @@ -1053,7 +1053,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkException] { input.select(overflowFunc($"d")).collect() } - assert(e.getErrorClass == "FAILED_EXECUTE_UDF") + assert(e.getCondition == "FAILED_EXECUTE_UDF") assert(e.getCause.isInstanceOf[java.lang.ArithmeticException]) } @@ -1101,7 +1101,7 @@ class UDFSuite extends QueryTest with SharedSparkSession { val e = intercept[SparkException] { input.select(overflowFunc($"p")).collect() } - assert(e.getErrorClass == "FAILED_EXECUTE_UDF") + assert(e.getCondition == "FAILED_EXECUTE_UDF") assert(e.getCause.isInstanceOf[java.lang.ArithmeticException]) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 6b58d23e92603..52ae1bf5d9d3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -840,7 +840,7 @@ class DataSourceV2SQLSuiteV1Filter val exception = intercept[SparkRuntimeException] { insertNullValueAndCheck() } - assert(exception.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(exception.getCondition == "NOT_NULL_ASSERT_VIOLATION") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala index 9d4e4fc016722..053616c88d638 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoTableSuiteBase.scala @@ -1326,7 +1326,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase { | UPDATE SET s = named_struct('n_i', null, 'n_l', -1L) |""".stripMargin) } - assert(e1.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(e1.getCondition == "NOT_NULL_ASSERT_VIOLATION") val e2 = 
intercept[SparkRuntimeException] { sql( @@ -1337,7 +1337,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase { | UPDATE SET s = named_struct('n_i', null, 'n_l', -1L) |""".stripMargin) } - assert(e2.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(e2.getCondition == "NOT_NULL_ASSERT_VIOLATION") val e3 = intercept[SparkRuntimeException] { sql( @@ -1348,7 +1348,7 @@ abstract class MergeIntoTableSuiteBase extends RowLevelOperationSuiteBase { | INSERT (pk, s, dep) VALUES (s.pk, named_struct('n_i', null, 'n_l', -1L), 'invalid') |""".stripMargin) } - assert(e3.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(e3.getCondition == "NOT_NULL_ASSERT_VIOLATION") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala index 86c1f17b4dbb9..1adb1fdf05032 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryExecutionErrorsSuite.scala @@ -293,7 +293,7 @@ class QueryExecutionErrorsSuite val e = intercept[SparkException] { df.write.parquet(dir.getCanonicalPath) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + assert(e.getCondition == "TASK_WRITE_FAILED") val format = "Parquet" val config = "\"" + SQLConf.PARQUET_REBASE_MODE_IN_WRITE.key + "\"" @@ -312,7 +312,7 @@ class QueryExecutionErrorsSuite val ex = intercept[SparkException] { spark.read.schema("time timestamp_ntz").orc(file.getCanonicalPath).collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[SparkUnsupportedOperationException], condition = "UNSUPPORTED_FEATURE.ORC_TYPE_CAST", @@ -334,7 +334,7 @@ class QueryExecutionErrorsSuite val ex = intercept[SparkException] { spark.read.schema("time timestamp_ltz").orc(file.getCanonicalPath).collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) checkError( exception = ex.getCause.asInstanceOf[SparkUnsupportedOperationException], condition = "UNSUPPORTED_FEATURE.ORC_TYPE_CAST", @@ -382,7 +382,7 @@ class QueryExecutionErrorsSuite } val e2 = e1.getCause.asInstanceOf[SparkException] - assert(e2.getErrorClass == "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION") + assert(e2.getCondition == "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION") checkError( exception = e2.getCause.asInstanceOf[SparkRuntimeException], @@ -921,7 +921,7 @@ class QueryExecutionErrorsSuite val e = intercept[StreamingQueryException] { query.awaitTermination() } - assert(e.getErrorClass === "STREAM_FAILED") + assert(e.getCondition === "STREAM_FAILED") assert(e.getCause.isInstanceOf[NullPointerException]) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 75f016d050de9..c5e64c96b2c8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -904,7 +904,7 @@ class AdaptiveQueryExecSuite val error = intercept[SparkException] { aggregated.count() } - assert(error.getErrorClass === "INVALID_BUCKET_FILE") + assert(error.getCondition === "INVALID_BUCKET_FILE") assert(error.getMessage 
contains "Invalid bucket file") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala index deb62eb3ac234..387a2baa256bf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormatSuite.scala @@ -368,7 +368,7 @@ class BinaryFileFormatSuite extends QueryTest with SharedSparkSession { checkAnswer(readContent(), expected) } } - assert(caught.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(caught.getCondition.startsWith("FAILED_READ_FILE")) assert(caught.getCause.getMessage.contains("exceeds the max length allowed")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 023f401516dc3..422ae02a18322 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -392,7 +392,7 @@ abstract class CSVSuite condition = "FAILED_READ_FILE.NO_HINT", parameters = Map("path" -> s".*$carsFile.*")) val e2 = e1.getCause.asInstanceOf[SparkException] - assert(e2.getErrorClass == "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION") + assert(e2.getCondition == "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION") checkError( exception = e2.getCause.asInstanceOf[SparkRuntimeException], condition = "MALFORMED_CSV_RECORD", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index f13d66b76838f..500c0647bcb2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -708,7 +708,7 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { val ex = intercept[SparkException] { sql(s"select A from $tableName where A < 0").collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) assert(ex.getCause.isInstanceOf[SparkRuntimeException]) assert(ex.getCause.getMessage.contains( """Found duplicate field(s) "A": [A, a] in case-insensitive mode""")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala index 2e6413d998d12..ab0d4d9bc53b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala @@ -604,14 +604,14 @@ abstract class OrcQueryTest extends OrcTest { val e1 = intercept[SparkException] { testIgnoreCorruptFiles() } - assert(e1.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(e1.getCondition.startsWith("FAILED_READ_FILE")) assert(e1.getCause.getMessage.contains("Malformed ORC file") || // Hive ORC table scan uses a different code path and has one more error stack e1.getCause.getCause.getMessage.contains("Malformed ORC file")) val e2 = intercept[SparkException] { 
testIgnoreCorruptFilesWithoutSchemaInfer() } - assert(e2.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(e2.getCondition.startsWith("FAILED_READ_FILE")) assert(e2.getCause.getMessage.contains("Malformed ORC file") || // Hive ORC table scan uses a different code path and has one more error stack e2.getCause.getCause.getMessage.contains("Malformed ORC file")) @@ -625,7 +625,7 @@ abstract class OrcQueryTest extends OrcTest { val e4 = intercept[SparkException] { testAllCorruptFilesWithoutSchemaInfer() } - assert(e4.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(e4.getCondition.startsWith("FAILED_READ_FILE")) assert(e4.getCause.getMessage.contains("Malformed ORC file") || // Hive ORC table scan uses a different code path and has one more error stack e4.getCause.getCause.getMessage.contains("Malformed ORC file")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 9348d10711b35..040999476ece1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -450,8 +450,8 @@ abstract class OrcSuite val ex = intercept[SparkException] { spark.read.orc(basePath).columns.length } - assert(ex.getErrorClass == "CANNOT_MERGE_SCHEMAS") - assert(ex.getCause.asInstanceOf[SparkException].getErrorClass === + assert(ex.getCondition == "CANNOT_MERGE_SCHEMAS") + assert(ex.getCause.asInstanceOf[SparkException].getCondition === "CANNOT_MERGE_INCOMPATIBLE_DATA_TYPE") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 5c382b1858716..903dda7f41c0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -1958,7 +1958,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared val ex = intercept[SparkException] { sql(s"select a from $tableName where b > 0").collect() } - assert(ex.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(ex.getCondition.startsWith("FAILED_READ_FILE")) assert(ex.getCause.isInstanceOf[SparkRuntimeException]) assert(ex.getCause.getMessage.contains( """Found duplicate field(s) "B": [B, b] in case-insensitive mode""")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 0afa545595c77..95fb178154929 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -1223,7 +1223,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession val m1 = intercept[SparkException] { spark.range(1).coalesce(1).write.options(extraOptions).parquet(dir.getCanonicalPath) } - assert(m1.getErrorClass == "TASK_WRITE_FAILED") + assert(m1.getCondition == "TASK_WRITE_FAILED") assert(m1.getCause.getMessage.contains("Intentional exception for testing purposes")) } @@ -1233,8 +1233,8 @@ class ParquetIOSuite extends QueryTest with 
ParquetTest with SharedSparkSession .coalesce(1) df.write.partitionBy("a").options(extraOptions).parquet(dir.getCanonicalPath) } - if (m2.getErrorClass != null) { - assert(m2.getErrorClass == "TASK_WRITE_FAILED") + if (m2.getCondition != null) { + assert(m2.getCondition == "TASK_WRITE_FAILED") assert(m2.getCause.getMessage.contains("Intentional exception for testing purposes")) } else { assert(m2.getMessage.contains("TASK_WRITE_FAILED")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index eb4618834504c..87a2843f34de1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -1181,7 +1181,7 @@ abstract class ParquetPartitionDiscoverySuite spark.read.parquet(dir.toString) } val msg = exception.getMessage - assert(exception.getErrorClass === "CONFLICTING_PARTITION_COLUMN_NAMES") + assert(exception.getCondition === "CONFLICTING_PARTITION_COLUMN_NAMES") // Partitions inside the error message can be presented in any order assert("Partition column name list #[0-1]: col1".r.findFirstIn(msg).isDefined) assert("Partition column name list #[0-1]: col1, col2".r.findFirstIn(msg).isDefined) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 4d413efe50430..22a02447e720f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -1075,7 +1075,7 @@ abstract class ParquetQuerySuite extends QueryTest with ParquetTest with SharedS val e = intercept[SparkException] { readParquet("d DECIMAL(3, 2)", path).collect() } - assert(e.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(e.getCondition.startsWith("FAILED_READ_FILE")) assert(e.getCause.getMessage.contains("Please read this column/field as Spark BINARY type")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala index 6d9092391a98e..30503af0fab6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRebaseDatetimeSuite.scala @@ -414,7 +414,7 @@ abstract class ParquetRebaseDatetimeSuite val e = intercept[SparkException] { df.write.parquet(dir.getCanonicalPath) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + assert(e.getCondition == "TASK_WRITE_FAILED") val errMsg = e.getCause.asInstanceOf[SparkUpgradeException].getMessage assert(errMsg.contains("You may get a different result due to the upgrading")) } @@ -431,7 +431,7 @@ abstract class ParquetRebaseDatetimeSuite val e = intercept[SparkException] { df.write.parquet(dir.getCanonicalPath) } - assert(e.getErrorClass == "TASK_WRITE_FAILED") + assert(e.getCondition == "TASK_WRITE_FAILED") val errMsg = e.getCause.asInstanceOf[SparkUpgradeException].getMessage 
assert(errMsg.contains("You may get a different result due to the upgrading")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala index 95378d9467478..08fd8a9ecb53e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowIndexSuite.scala @@ -319,7 +319,7 @@ class ParquetRowIndexSuite extends QueryTest with SharedSparkSession { .load(path.getAbsolutePath) val exception = intercept[SparkException](dfRead.collect()) - assert(exception.getErrorClass.startsWith("FAILED_READ_FILE")) + assert(exception.getCondition.startsWith("FAILED_READ_FILE")) assert(exception.getCause.getMessage.contains( ParquetFileFormat.ROW_INDEX_TEMPORARY_COLUMN_NAME)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala index 4833b8630134c..59c0af8afd198 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceChangeDataReadSuite.scala @@ -90,7 +90,7 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB .option(StateSourceOptions.CHANGE_END_BATCH_ID, 2) .load(tempDir.getAbsolutePath) } - assert(exc.getErrorClass === "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE") + assert(exc.getCondition === "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE") } } @@ -103,7 +103,7 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB .option(StateSourceOptions.CHANGE_END_BATCH_ID, 0) .load(tempDir.getAbsolutePath) } - assert(exc.getErrorClass === "STDS_INVALID_OPTION_VALUE.IS_NEGATIVE") + assert(exc.getCondition === "STDS_INVALID_OPTION_VALUE.IS_NEGATIVE") } } @@ -116,7 +116,7 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB .option(StateSourceOptions.CHANGE_END_BATCH_ID, 0) .load(tempDir.getAbsolutePath) } - assert(exc.getErrorClass === "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE") + assert(exc.getCondition === "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE") } } @@ -130,7 +130,7 @@ abstract class StateDataSourceChangeDataReaderSuite extends StateDataSourceTestB .option(StateSourceOptions.CHANGE_END_BATCH_ID, 0) .load(tempDir.getAbsolutePath) } - assert(exc.getErrorClass === "STDS_CONFLICT_OPTIONS") + assert(exc.getCondition === "STDS_CONFLICT_OPTIONS") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala index 5f55848d540df..300da03f73e1f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala @@ -1137,7 +1137,7 @@ abstract class StateDataSourceReadSuite extends StateDataSourceTestBase with Ass val exc = intercept[StateStoreSnapshotPartitionNotFound] { stateDfError.show() } - assert(exc.getErrorClass === 
"CANNOT_LOAD_STATE_STORE.SNAPSHOT_PARTITION_ID_NOT_FOUND") + assert(exc.getCondition === "CANNOT_LOAD_STATE_STORE.SNAPSHOT_PARTITION_ID_NOT_FOUND") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index dcebece29037f..1f2be12058eb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -330,7 +330,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { val err = intercept[AnalysisException] { spark.read.format(dataSourceName).schema(schema).load().collect() } - assert(err.getErrorClass == "PYTHON_DATA_SOURCE_ERROR") + assert(err.getCondition == "PYTHON_DATA_SOURCE_ERROR") assert(err.getMessage.contains("PySparkNotImplementedError")) } @@ -350,7 +350,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { val err = intercept[AnalysisException] { spark.read.format(dataSourceName).schema(schema).load().collect() } - assert(err.getErrorClass == "PYTHON_DATA_SOURCE_ERROR") + assert(err.getCondition == "PYTHON_DATA_SOURCE_ERROR") assert(err.getMessage.contains("error creating reader")) } @@ -369,7 +369,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { val err = intercept[AnalysisException] { spark.read.format(dataSourceName).schema(schema).load().collect() } - assert(err.getErrorClass == "PYTHON_DATA_SOURCE_ERROR") + assert(err.getCondition == "PYTHON_DATA_SOURCE_ERROR") assert(err.getMessage.contains("DATA_SOURCE_TYPE_MISMATCH")) assert(err.getMessage.contains("PySparkAssertionError")) } @@ -480,7 +480,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase { spark.dataSource.registerPython(dataSourceName, dataSource) val err = intercept[AnalysisException]( spark.read.format(dataSourceName).load().collect()) - assert(err.getErrorClass == "PYTHON_DATA_SOURCE_ERROR") + assert(err.getCondition == "PYTHON_DATA_SOURCE_ERROR") assert(err.getMessage.contains("partitions")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala index 8d0e1c5f578fa..3d91a045907fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala @@ -574,7 +574,7 @@ class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase { val q = spark.readStream.format(dataSourceName).load().writeStream.format("console").start() q.awaitTermination() } - assert(err.getErrorClass == "STREAM_FAILED") + assert(err.getCondition == "STREAM_FAILED") assert(err.getMessage.contains("error creating stream reader")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala index 38533825ece90..99483bc0ee8dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityCheckerSuite.scala @@ -423,14 +423,14 @@ class 
StateSchemaCompatibilityCheckerSuite extends SharedSparkSession { // collation checks are also performed in this path. so we need to check for them explicitly. if (keyCollationChecks) { assert(ex.getMessage.contains("Binary inequality column is not supported")) - assert(ex.getErrorClass === "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY") + assert(ex.getCondition === "STATE_STORE_UNSUPPORTED_OPERATION_BINARY_INEQUALITY") } else { if (ignoreValueSchema) { // if value schema is ignored, the mismatch has to be on the key schema - assert(ex.getErrorClass === "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE") + assert(ex.getCondition === "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE") } else { - assert(ex.getErrorClass === "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE" || - ex.getErrorClass === "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE") + assert(ex.getCondition === "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE" || + ex.getCondition === "STATE_STORE_VALUE_SCHEMA_NOT_COMPATIBLE") } assert(ex.getMessage.contains("does not match existing")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 8bbc7a31760d9..2a9944a81cb2a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -1373,7 +1373,7 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] put(store, "a", 0, 0) val e = intercept[SparkException](quietly { store.commit() } ) - assert(e.getErrorClass == "CANNOT_WRITE_STATE_STORE.CANNOT_COMMIT") + assert(e.getCondition == "CANNOT_WRITE_STATE_STORE.CANNOT_COMMIT") if (store.getClass.getName contains ROCKSDB_STATE_STORE) { assert(e.getMessage contains "RocksDBStateStore[id=(op=0,part=0)") } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 41447d8af5740..baf99798965da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -956,7 +956,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { val msg = intercept[SparkRuntimeException] { sql("INSERT INTO TABLE test_table SELECT 2, null") } - assert(msg.getErrorClass == "NOT_NULL_ASSERT_VIOLATION") + assert(msg.getCondition == "NOT_NULL_ASSERT_VIOLATION") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index d9ce8002d285b..a0eea14e54eed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -296,7 +296,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { val exception = SparkException.internalError("testpurpose") testSerialization( new QueryTerminatedEvent(UUID.randomUUID, UUID.randomUUID, - Some(exception.getMessage), Some(exception.getErrorClass))) + Some(exception.getMessage), Some(exception.getCondition))) } test("only one progress event per interval when no data") { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServerErrors.scala 
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServerErrors.scala index 8a8bdd4d38ee3..59d1b61f2f8e7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServerErrors.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServerErrors.scala @@ -38,7 +38,7 @@ object HiveThriftServerErrors { def runningQueryError(e: Throwable, format: ErrorMessageFormat.Value): Throwable = e match { case st: SparkThrowable if format == ErrorMessageFormat.PRETTY => - val errorClassPrefix = Option(st.getErrorClass).map(e => s"[$e] ").getOrElse("") + val errorClassPrefix = Option(st.getCondition).map(e => s"[$e] ").getOrElse("") new HiveSQLException( s"Error running query: $errorClassPrefix${st.toString}", st.getSqlState, st) case st: SparkThrowable with Throwable => From ab1315b95b8e4058d3e97946afc772f05dc519bb Mon Sep 17 00:00:00 2001 From: beliefer Date: Thu, 10 Oct 2024 19:57:18 +0800 Subject: [PATCH 208/250] [SPARK-49756][SQL] Postgres dialect supports pushdown datetime functions ### What changes were proposed in this pull request? This PR propose to make Postgres dialect supports pushdown datetime functions. ### Why are the changes needed? Currently, DS V2 pushdown framework pushed the datetime functions with in a common way. But Postgres doesn't support some datetime functions. ### Does this PR introduce _any_ user-facing change? 'No'. This is a new feature for Postgres dialect. ### How was this patch tested? GA. ### Was this patch authored or co-authored using generative AI tooling? 'No'. Closes #48210 from beliefer/SPARK-49756. Authored-by: beliefer Signed-off-by: Wenchen Fan --- .../jdbc/v2/PostgresIntegrationSuite.scala | 84 +++++++++++++++++++ .../spark/sql/jdbc/PostgresDialect.scala | 25 +++++- 2 files changed, 108 insertions(+), 1 deletion(-) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 6bb415a928837..05f02a402353b 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -65,6 +65,17 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT |) """.stripMargin ).executeUpdate() + connection.prepareStatement( + "CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)") + .executeUpdate() + } + + override def dataPreparation(connection: Connection): Unit = { + super.dataPreparation(connection) + connection.prepareStatement("INSERT INTO datetime VALUES " + + "('amy', '2022-05-19', '2022-05-19 00:00:00')").executeUpdate() + connection.prepareStatement("INSERT INTO datetime VALUES " + + "('alex', '2022-05-18', '2022-05-18 00:00:00')").executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { @@ -123,4 +134,77 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT ) } } + + override def testDatetime(tbl: String): Unit = { + val df1 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 AND dayofmonth(date1) > 10 ") + checkFilterPushed(df1) + val rows1 = df1.collect() + assert(rows1.length === 2) + assert(rows1(0).getString(0) === "amy") + assert(rows1(1).getString(0) === "alex") + + val df2 = sql(s"SELECT 
name FROM $tbl WHERE year(date1) = 2022 AND quarter(date1) = 2") + checkFilterPushed(df2) + val rows2 = df2.collect() + assert(rows2.length === 2) + assert(rows2(0).getString(0) === "amy") + assert(rows2(1).getString(0) === "alex") + + val df3 = sql(s"SELECT name FROM $tbl WHERE second(time1) = 0 AND month(date1) = 5") + checkFilterPushed(df3) + val rows3 = df3.collect() + assert(rows3.length === 2) + assert(rows3(0).getString(0) === "amy") + assert(rows3(1).getString(0) === "alex") + + val df4 = sql(s"SELECT name FROM $tbl WHERE hour(time1) = 0 AND minute(time1) = 0") + checkFilterPushed(df4) + val rows4 = df4.collect() + assert(rows4.length === 2) + assert(rows4(0).getString(0) === "amy") + assert(rows4(1).getString(0) === "alex") + + val df5 = sql(s"SELECT name FROM $tbl WHERE " + + "extract(WEEk from date1) > 10 AND extract(YEAROFWEEK from date1) = 2022") + checkFilterPushed(df5) + val rows5 = df5.collect() + assert(rows5.length === 2) + assert(rows5(0).getString(0) === "amy") + assert(rows5(1).getString(0) === "alex") + + val df6 = sql(s"SELECT name FROM $tbl WHERE date_add(date1, 1) = date'2022-05-20' " + + "AND datediff(date1, '2022-05-10') > 0") + checkFilterPushed(df6, false) + val rows6 = df6.collect() + assert(rows6.length === 1) + assert(rows6(0).getString(0) === "amy") + + val df7 = sql(s"SELECT name FROM $tbl WHERE weekday(date1) = 2") + checkFilterPushed(df7) + val rows7 = df7.collect() + assert(rows7.length === 1) + assert(rows7(0).getString(0) === "alex") + + val df8 = sql(s"SELECT name FROM $tbl WHERE dayofweek(date1) = 4") + checkFilterPushed(df8) + val rows8 = df8.collect() + assert(rows8.length === 1) + assert(rows8(0).getString(0) === "alex") + + val df9 = sql(s"SELECT name FROM $tbl WHERE " + + "dayofyear(date1) > 100 order by dayofyear(date1) limit 1") + checkFilterPushed(df9) + val rows9 = df9.collect() + assert(rows9.length === 1) + assert(rows9(0).getString(0) === "alex") + + // Postgres does not support + val df10 = sql(s"SELECT name FROM $tbl WHERE trunc(date1, 'week') = date'2022-05-16'") + checkFilterPushed(df10, false) + val rows10 = df10.collect() + assert(rows10.length === 2) + assert(rows10(0).getString(0) === "amy") + assert(rows10(1).getString(0) === "alex") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 60258ecbb0d61..8341063e09890 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -23,6 +23,7 @@ import java.util import java.util.Locale import scala.util.Using +import scala.util.control.NonFatal import org.apache.spark.SparkThrowable import org.apache.spark.internal.LogKeys.COLUMN_NAME @@ -30,7 +31,7 @@ import org.apache.spark.internal.MDC import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NonEmptyNamespaceException, NoSuchIndexException} import org.apache.spark.sql.connector.catalog.Identifier -import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.{Expression, NamedReference} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo @@ -300,6 +301,28 @@ private case class PostgresDialect() } } + class PostgresSQLBuilder extends 
JDBCSQLBuilder { + override def visitExtract(field: String, source: String): String = { + field match { + case "DAY_OF_YEAR" => s"EXTRACT(DOY FROM $source)" + case "YEAR_OF_WEEK" => s"EXTRACT(YEAR FROM $source)" + case "DAY_OF_WEEK" => s"EXTRACT(DOW FROM $source)" + case _ => super.visitExtract(field, source) + } + } + } + + override def compileExpression(expr: Expression): Option[String] = { + val postgresSQLBuilder = new PostgresSQLBuilder() + try { + Some(postgresSQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } + } + override def supportsLimit: Boolean = true override def supportsOffset: Boolean = true From e589ccd9c18e8159ea05471434834246f494a476 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Thu, 10 Oct 2024 20:55:34 +0800 Subject: [PATCH 209/250] [SPARK-49920][INFRA] Install `R` for `ubuntu 24.04` when GA run `k8s-integration-tests` ### What changes were proposed in this pull request? The pr aims to install `R` for `ubuntu 24.04` when GA run `k8s-integration-tests`. ### Why are the changes needed? - As the GitHub community switches the default version of `ubuntu-latest` from `ubuntu-22.04` to `ubuntu-24.04`. https://github.com/actions/runner-images/issues/10636 - In `ubuntu-24.04`, `R` is `not installed` by default A.`ubuntu-24.04`(`R` is `not installed` by default) https://github.com/actions/runner-images/blob/main/images/ubuntu/Ubuntu2404-Readme.md#tools image B.`ubuntu-22.04`(`R` is `installed` by default) https://github.com/actions/runner-images/blob/main/images/ubuntu/Ubuntu2204-Readme.md#tools image - Fix the failure issue of GA https://github.com/LuciferYang/spark/actions/runs/11268158324/job/31334445659 image ### Does this PR introduce _any_ user-facing change? No, only for tests. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48406 from panbingkun/install_R. Authored-by: panbingkun Signed-off-by: yangjie01 --- .github/workflows/build_and_test.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 2b459e4c73bbb..43ac6b50052ae 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -1112,6 +1112,10 @@ jobs: with: distribution: zulu java-version: ${{ inputs.java }} + - name: Install R + run: | + sudo apt update + sudo apt-get install r-base - name: Start Minikube uses: medyagh/setup-minikube@v0.0.18 with: From b056e0b12786f0b85675cdf73748bdf506e3619f Mon Sep 17 00:00:00 2001 From: YangJie Date: Thu, 10 Oct 2024 23:32:37 +0800 Subject: [PATCH 210/250] [SPARK-49569][BUILD][FOLLOWUP] Exclude `spark-connect-shims` from `sql/core` module ### What changes were proposed in this pull request? This pr exclude `spark-connect-shims` from `sql/core` module for further fix maven daily test. ### Why are the changes needed? 
For fix maven daily test: After https://github.com/apache/spark/pull/48399, although the Maven build was successful in my local environment, the Maven daily test pipeline still failed to build: - https://github.com/apache/spark/actions/runs/11255598249/job/31311358712 ``` Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala:121: value makeRDD is not a member of org.apache.spark.SparkContext Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:75: value id is not a member of org.apache.spark.rdd.RDD[org.apache.spark.sql.columnar.CachedBatch] Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:82: value env is not a member of org.apache.spark.SparkContext Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:88: value env is not a member of org.apache.spark.SparkContext Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:185: value parallelize is not a member of org.apache.spark.SparkContext Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:481: value cleaner is not a member of org.apache.spark.SparkContext Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:500: value parallelize is not a member of org.apache.spark.SparkContext Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:940: value addSparkListener is not a member of org.apache.spark.SparkContext Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:943: value listenerBus is not a member of org.apache.spark.SparkContext Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:947: value removeSparkListener is not a member of org.apache.spark.SparkContext Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:1667: value listenerBus is not a member of org.apache.spark.SparkContext Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:1668: value addSparkListener is not a member of org.apache.spark.SparkContext Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:1673: value partitions is not a member of org.apache.spark.rdd.RDD[org.apache.spark.sql.Row] Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:1674: value listenerBus is not a member of org.apache.spark.SparkContext Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:1682: value partitions is not a member of org.apache.spark.rdd.RDD[org.apache.spark.sql.Row] Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:1683: value listenerBus is not a member of org.apache.spark.SparkContext Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:1687: value removeSparkListener is not a member of org.apache.spark.SparkContext Error: ] /home/runner/work/spark/spark/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala:1708: value partitions is not a member of 
org.apache.spark.rdd.RDD[org.apache.spark.sql.Row] ... ``` After using the `mvn dependency:tree` command to check, I found that `sql/core` cascadingly introduced `org.apache.spark:spark-connect-shims_2.13:jar:4.0.0-SNAPSHOT:test` through `org.apache.spark:spark-sql-api_2.13:test-jar:tests:4.0.0-SNAPSHOT:test`. ``` [INFO] ------------------< org.apache.spark:spark-sql_2.13 >------------------- [INFO] Building Spark Project SQL 4.0.0-SNAPSHOT [18/42] [INFO] from sql/core/pom.xml [INFO] --------------------------------[ jar ]--------------------------------- [INFO] [INFO] --- dependency:3.6.1:tree (default-cli) spark-sql_2.13 --- [INFO] org.apache.spark:spark-sql_2.13:jar:4.0.0-SNAPSHOT ... [INFO] +- org.apache.spark:spark-catalyst_2.13:test-jar:tests:4.0.0-SNAPSHOT:test [INFO] +- org.apache.spark:spark-sql-api_2.13:test-jar:tests:4.0.0-SNAPSHOT:test [INFO] | +- org.scala-lang.modules:scala-parser-combinators_2.13:jar:2.4.0:compile [INFO] | +- org.apache.spark:spark-connect-shims_2.13:jar:4.0.0-SNAPSHOT:test [INFO] | +- org.antlr:antlr4-runtime:jar:4.13.1:compile [INFO] | +- org.apache.arrow:arrow-vector:jar:17.0.0:compile [INFO] | | +- org.apache.arrow:arrow-format:jar:17.0.0:compile [INFO] | | +- org.apache.arrow:arrow-memory-core:jar:17.0.0:compile [INFO] | | +- com.fasterxml.jackson.datatype:jackson-datatype-jsr310:jar:2.18.0:compile [INFO] | | \- com.google.flatbuffers:flatbuffers-java:jar:24.3.25:compile [INFO] | \- org.apache.arrow:arrow-memory-netty:jar:17.0.0:compile [INFO] | \- org.apache.arrow:arrow-memory-netty-buffer-patch:jar:17.0.0:compile ``` This should be unexpected. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Actions - Pass maven test on GitHub Actions: https://github.com/LuciferYang/spark/runs/31314342332 image All maven test passed ### Was this patch authored or co-authored using generative AI tooling? No Closes #48403 from LuciferYang/test-maven-build. Lead-authored-by: YangJie Co-authored-by: yangjie01 Signed-off-by: yangjie01 --- sql/core/pom.xml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 972cf76d27535..16236940fe072 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -79,6 +79,12 @@ ${project.version} test-jar test + + + org.apache.spark + spark-connect-shims_${scala.binary.version} + + org.apache.spark From 2af653688c20dde87eebaa6bd4dc21123fab74cc Mon Sep 17 00:00:00 2001 From: Siying Dong Date: Fri, 11 Oct 2024 10:50:51 +0900 Subject: [PATCH 211/250] [SPARK-49927][SS] pyspark.sql.tests.streaming.test_streaming_listener to wait longer ### What changes were proposed in this pull request? In test pyspark.sql.tests.streaming.test_streaming_listener, instead of waiting for fixed 10 seconds, we wait for progress made. ### Why are the changes needed? In some environment, the test fails with progress_event appears to be None. Likely the wait time is not sufficient. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Run the patch and make sure it still passes ### Was this patch authored or co-authored using generative AI tooling? No Closes #48414 from siying/python_test. 
Lead-authored-by: Siying Dong Co-authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/tests/streaming/test_streaming_listener.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index c3ae62e64cc30..1f5b0f573807a 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -381,7 +381,8 @@ def verify(test_listener): .start() ) self.assertTrue(q.isActive) - q.awaitTermination(10) + while progress_event is None or progress_event.batchId == 0: + q.awaitTermination(0.5) q.stop() # Make sure all events are empty From 4666972a31907393390a6bc353e87dfedc6d6445 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Fri, 11 Oct 2024 12:29:45 +0900 Subject: [PATCH 212/250] [SPARK-48567][PYTHON][TESTS][FOLLOW-UP] Make the query scope higher so finally can access to it ### What changes were proposed in this pull request? This PR is a followup that fixes the test to recover the build. ### Why are the changes needed? To fix up the build. ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? CI in this PR. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48417 from HyukjinKwon/SPARK-48567-followup. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- .../sql/tests/streaming/test_streaming_listener.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 1f5b0f573807a..30d6eee93879d 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -216,6 +216,7 @@ def onQueryIdle(self, event): def onQueryTerminated(self, event): pass + q = None try: error_listener = MyErrorListener() self.spark.streams.addListener(error_listener) @@ -238,10 +239,12 @@ def onQueryTerminated(self, event): self.assertTrue(error_listener.num_error_rows > 0) finally: - q.stop() + if q is not None: + q.stop() self.spark.streams.removeListener(error_listener) def test_streaming_progress(self): + q = None try: # Test a fancier query with stateful operation and observed metrics df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() @@ -265,7 +268,8 @@ def test_streaming_progress(self): self.check_streaming_query_progress(p, True) finally: - q.stop() + if q is not None: + q.stop() class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase): From cacd261e825a253aca98f7019fc484f0d94fad6e Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Fri, 11 Oct 2024 15:47:59 +0900 Subject: [PATCH 213/250] [SPARK-49927][SS][PYTHON][TESTS][FOLLOW-UP] Fixes `q.lastProgress.batchId` to `q.lastProgress.progress.batchId` ### What changes were proposed in this pull request? This PR is a followup of https://github.com/apache/spark/pull/48414 that fixes `q.lastProgress.batchId` -> `q.lastProgress.progress.batchId` to fix the test. ### Why are the changes needed? `q.lastProgress` does not have `progress`. We should fix it to fix up the broken build. ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? Manually ### Was this patch authored or co-authored using generative AI tooling? No. 
Closes #48419 from HyukjinKwon/SPARK-49927-followup. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/tests/streaming/test_streaming_listener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 30d6eee93879d..d28fb57a0da23 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -230,7 +230,7 @@ def onQueryTerminated(self, event): q = observed_ds.writeStream.format("noop").start() - while q.lastProgress is None or q.lastProgress.batchId == 0: + while q.lastProgress is None or q.lastProgress.progress.batchId == 0: q.awaitTermination(0.5) time.sleep(5) From c0d396f862f42b12cfb9b29809ffd651432a7cfe Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Fri, 11 Oct 2024 16:16:53 +0900 Subject: [PATCH 214/250] Revert "[SPARK-49927][SS][PYTHON][TESTS][FOLLOW-UP] Fixes `q.lastProgress.batchId` to `q.lastProgress.progress.batchId`" This reverts commit cacd261e825a253aca98f7019fc484f0d94fad6e. --- python/pyspark/sql/tests/streaming/test_streaming_listener.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index d28fb57a0da23..30d6eee93879d 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -230,7 +230,7 @@ def onQueryTerminated(self, event): q = observed_ds.writeStream.format("noop").start() - while q.lastProgress is None or q.lastProgress.progress.batchId == 0: + while q.lastProgress is None or q.lastProgress.batchId == 0: q.awaitTermination(0.5) time.sleep(5) From b93281ed0ed95f30c13c0ae14a9694ac9246c3ce Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Fri, 11 Oct 2024 16:16:58 +0900 Subject: [PATCH 215/250] Revert "[SPARK-48567][PYTHON][TESTS][FOLLOW-UP] Make the query scope higher so finally can access to it" This reverts commit 4666972a31907393390a6bc353e87dfedc6d6445. 
--- .../sql/tests/streaming/test_streaming_listener.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 30d6eee93879d..1f5b0f573807a 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -216,7 +216,6 @@ def onQueryIdle(self, event): def onQueryTerminated(self, event): pass - q = None try: error_listener = MyErrorListener() self.spark.streams.addListener(error_listener) @@ -239,12 +238,10 @@ def onQueryTerminated(self, event): self.assertTrue(error_listener.num_error_rows > 0) finally: - if q is not None: - q.stop() + q.stop() self.spark.streams.removeListener(error_listener) def test_streaming_progress(self): - q = None try: # Test a fancier query with stateful operation and observed metrics df = self.spark.readStream.format("rate").option("rowsPerSecond", 10).load() @@ -268,8 +265,7 @@ def test_streaming_progress(self): self.check_streaming_query_progress(p, True) finally: - if q is not None: - q.stop() + q.stop() class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase): From ed0a63f93f57b05bc6d4988bb1591cdfaf21131d Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Fri, 11 Oct 2024 16:17:03 +0900 Subject: [PATCH 216/250] Revert "[SPARK-49927][SS] pyspark.sql.tests.streaming.test_streaming_listener to wait longer" This reverts commit 2af653688c20dde87eebaa6bd4dc21123fab74cc. --- python/pyspark/sql/tests/streaming/test_streaming_listener.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index 1f5b0f573807a..c3ae62e64cc30 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -381,8 +381,7 @@ def verify(test_listener): .start() ) self.assertTrue(q.isActive) - while progress_event is None or progress_event.batchId == 0: - q.awaitTermination(0.5) + q.awaitTermination(10) q.stop() # Make sure all events are empty From 5104d1d434c1bdd58f06ae7a9c0d0d53f7b21f47 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Fri, 11 Oct 2024 15:52:26 +0800 Subject: [PATCH 217/250] [SPARK-49915][SQL] Handle zeros and ones in ReorderAssociativeOperator ### What changes were proposed in this pull request? For additions, we omit the `Add` operation if the foldable ones finally result in 0, e.g. `-3 + a + 3` is simplified to `a` instead of `a + 0`. For multiplication, - we simplify the expression to `Literal(0, dt)` if the foldable ones finally result in 0 && the expression itself isn't nullable - we omit the `Multiply` operation if the foldable ones finally result in 1 ### Why are the changes needed? Improve the simplicity of expression evaluation and the opportunities for predicates to be pushed down to data sources ### Does this PR introduce _any_ user-facing change? no, the result shall be identical ### How was this patch tested? new tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48395 from yaooqinn/SPARK-49915. 
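A minimal sketch of the new constant handling for integer addition, kept outside of Catalyst on purpose: the helper name `rebuildAdd` and the string form of the expressions are invented for illustration, and only the decision table mirrors the rule described above.

```scala
// Hypothetical helper, not code from the patch: rebuild an integer addition
// from its non-foldable terms plus the already-folded constant.
def rebuildAdd(nonFoldable: Seq[String], foldedConstant: Long): String = {
  val others = nonFoldable.reduceOption((x, y) => s"($x + $y)")
  (others, foldedConstant) match {
    case (None, c)    => c.toString   // only literals: fold the whole expression to a constant
    case (Some(e), 0) => e            // adding zero: omit the Add entirely
    case (Some(e), c) => s"($e + $c)" // otherwise keep a single trailing constant
  }
}

rebuildAdd(Seq("a"), -3 + 3)  // "a"        (was ((-3 + a) + 3))
rebuildAdd(Seq("a", "b"), 0)  // "(a + b)"
rebuildAdd(Seq("a"), 4)       // "(a + 4)"
```

Multiplication follows the same shape, with the extra guard described above: the product can only be short-circuited to a literal 0 when the expression is not nullable, because multiplying a NULL by zero must still yield NULL.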
Authored-by: Kent Yao Signed-off-by: Kent Yao --- .../sql/catalyst/optimizer/expressions.scala | 25 +++++++++++---- .../ReorderAssociativeOperatorSuite.scala | 32 +++++++++++++++++++ .../analyzer-results/null-handling.sql.out | 18 +++++++++++ .../sql-tests/inputs/null-handling.sql | 2 ++ .../sql-tests/results/null-handling.sql.out | 16 ++++++++++ .../apache/spark/sql/jdbc/JDBCV2Suite.scala | 2 +- 6 files changed, 88 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 1601d798283c9..c0cd976b9e9b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -260,19 +260,32 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { q.transformExpressionsDownWithPruning(_.containsPattern(BINARY_ARITHMETIC)) { case a @ Add(_, _, f) if a.deterministic && a.dataType.isInstanceOf[IntegralType] => val (foldables, others) = flattenAdd(a, groupingExpressionSet).partition(_.foldable) - if (foldables.size > 1) { + if (foldables.nonEmpty) { val foldableExpr = foldables.reduce((x, y) => Add(x, y, f)) - val c = Literal.create(foldableExpr.eval(EmptyRow), a.dataType) - if (others.isEmpty) c else Add(others.reduce((x, y) => Add(x, y, f)), c, f) + val foldableValue = foldableExpr.eval(EmptyRow) + if (others.isEmpty) { + Literal.create(foldableValue, a.dataType) + } else if (foldableValue == 0) { + others.reduce((x, y) => Add(x, y, f)) + } else { + Add(others.reduce((x, y) => Add(x, y, f)), Literal.create(foldableValue, a.dataType), f) + } } else { a } case m @ Multiply(_, _, f) if m.deterministic && m.dataType.isInstanceOf[IntegralType] => val (foldables, others) = flattenMultiply(m, groupingExpressionSet).partition(_.foldable) - if (foldables.size > 1) { + if (foldables.nonEmpty) { val foldableExpr = foldables.reduce((x, y) => Multiply(x, y, f)) - val c = Literal.create(foldableExpr.eval(EmptyRow), m.dataType) - if (others.isEmpty) c else Multiply(others.reduce((x, y) => Multiply(x, y, f)), c, f) + val foldableValue = foldableExpr.eval(EmptyRow) + if (others.isEmpty || (foldableValue == 0 && !m.nullable)) { + Literal.create(foldableValue, m.dataType) + } else if (foldableValue == 1) { + others.reduce((x, y) => Multiply(x, y, f)) + } else { + Multiply(others.reduce((x, y) => Multiply(x, y, f)), + Literal.create(foldableValue, m.dataType), f) + } } else { m } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala index f4b2fce74dc49..9090e0c7fc104 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReorderAssociativeOperatorSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import 
org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -74,4 +75,35 @@ class ReorderAssociativeOperatorSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("SPARK-49915: Handle zero and one in associative operators") { + val originalQuery = + testRelation.select( + $"a" + 0, + Literal(-3) + $"a" + 3, + $"b" * 0 * 1 * 2 * 3, + Count($"b") * 0, + $"b" * 1 * 1, + ($"b" + 0) * 1 * 2 * 3 * 4, + $"a" + 0 + $"b" + 0 + $"c" + 0, + $"a" + 0 + $"b" * 1 + $"c" + 0 + ) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + $"a".as("(a + 0)"), + $"a".as("((-3 + a) + 3)"), + ($"b" * 0).as("((((b * 0) * 1) * 2) * 3)"), + Literal(0L).as("(count(b) * 0)"), + $"b".as("((b * 1) * 1)"), + ($"b" * 24).as("(((((b + 0) * 1) * 2) * 3) * 4)"), + ($"a" + $"b" + $"c").as("""(((((a + 0) + b) + 0) + c) + 0)"""), + ($"a" + $"b" + $"c").as("((((a + 0) + (b * 1)) + c) + 0)") + ).analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/null-handling.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/null-handling.sql.out index 26e9394932a17..37d84f6c5fc00 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/null-handling.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/null-handling.sql.out @@ -69,6 +69,24 @@ Project [a#x, (b#x + c#x) AS (b + c)#x] +- Relation spark_catalog.default.t1[a#x,b#x,c#x] parquet +-- !query +select b + 0 from t1 where a = 5 +-- !query analysis +Project [(b#x + 0) AS (b + 0)#x] ++- Filter (a#x = 5) + +- SubqueryAlias spark_catalog.default.t1 + +- Relation spark_catalog.default.t1[a#x,b#x,c#x] parquet + + +-- !query +select -100 + b + 100 from t1 where a = 5 +-- !query analysis +Project [((-100 + b#x) + 100) AS ((-100 + b) + 100)#x] ++- Filter (a#x = 5) + +- SubqueryAlias spark_catalog.default.t1 + +- Relation spark_catalog.default.t1[a#x,b#x,c#x] parquet + + -- !query select a+10, b*0 from t1 -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/null-handling.sql b/sql/core/src/test/resources/sql-tests/inputs/null-handling.sql index 040be00503442..dcdf241df73d9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/null-handling.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/null-handling.sql @@ -10,6 +10,8 @@ insert into t1 values(7,null,null); -- Adding anything to null gives null select a, b+c from t1; +select b + 0 from t1 where a = 5; +select -100 + b + 100 from t1 where a = 5; -- Multiplying null by zero gives null select a+10, b*0 from t1; diff --git a/sql/core/src/test/resources/sql-tests/results/null-handling.sql.out b/sql/core/src/test/resources/sql-tests/results/null-handling.sql.out index ece6dbef1605d..fb96be8317a5b 100644 --- a/sql/core/src/test/resources/sql-tests/results/null-handling.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/null-handling.sql.out @@ -77,6 +77,22 @@ struct 7 NULL +-- !query +select b + 0 from t1 where a = 5 +-- !query schema +struct<(b + 0):int> +-- !query output +NULL + + +-- !query +select -100 + b + 100 from t1 where a = 5 +-- !query schema +struct<((-100 + b) + 100):int> +-- !query output +NULL + + -- !query select a+10, b*0 from t1 -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 054c7e644ff55..0550fae3805d4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -2688,7 +2688,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df = sql("SELECT SUM(2147483647 + DEPT) FROM h2.test.employee") checkAggregateRemoved(df, ansiMode) val expectedPlanFragment = if (ansiMode) { - "PushedAggregates: [SUM(2147483647 + DEPT)], " + + "PushedAggregates: [SUM(DEPT + 2147483647)], " + "PushedFilters: [], " + "PushedGroupByExpressions: []" } else { From 8e1d317307d9e0daa8f9a48d6b686942a9079b6f Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Fri, 11 Oct 2024 20:59:39 +0800 Subject: [PATCH 218/250] [SPARK-49615] Bugfix: Make ML column schema validation conforms with spark config `spark.sql.caseSensitive` ### What changes were proposed in this pull request? Bugfix: Make ML column schema validation conforms with spark config `spark.sql.caseSensitive`. ### Why are the changes needed? Bugfix: Make ML column schema validation conforms with spark config `spark.sql.caseSensitive`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? N/A ### Was this patch authored or co-authored using generative AI tooling? N/A Closes #48398 from WeichenXu123/SPARK-49615. Authored-by: Weichen Xu Signed-off-by: Weichen Xu --- .../apache/spark/ml/util/SchemaUtils.scala | 22 ++++++++++++++----- .../apache/spark/sql/types/StructType.scala | 2 +- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 3b306eff99689..ff132e2a29a89 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,12 +17,13 @@ package org.apache.spark.ml.util +import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.VectorUDT -import org.apache.spark.sql.catalyst.util.AttributeNameParser +import org.apache.spark.sql.catalyst.util.{AttributeNameParser, QuotingUtils} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ - /** * Utils for handling schemas. */ @@ -206,6 +207,10 @@ private[spark] object SchemaUtils { checkColumnTypes(schema, colName, typeCandidates) } + def toSQLId(parts: String): String = { + AttributeNameParser.parseAttributeName(parts).map(QuotingUtils.quoteIdentifier).mkString(".") + } + /** * Get schema field. 
* @param schema input schema @@ -213,11 +218,16 @@ private[spark] object SchemaUtils { */ def getSchemaField(schema: StructType, colName: String): StructField = { val colSplits = AttributeNameParser.parseAttributeName(colName) - var field = schema(colSplits(0)) - for (colSplit <- colSplits.slice(1, colSplits.length)) { - field = field.dataType.asInstanceOf[StructType](colSplit) + val fieldOpt = schema.findNestedField(colSplits, resolver = SQLConf.get.resolver) + if (fieldOpt.isEmpty) { + throw new SparkIllegalArgumentException( + errorClass = "FIELD_NOT_FOUND", + messageParameters = Map( + "fieldName" -> toSQLId(colName), + "fields" -> schema.fields.map(f => toSQLId(f.name)).mkString(", ")) + ) } - field + fieldOpt.get._2 } /** diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala index 4ef1cf400b80e..07f6b50bd4a7a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -321,7 +321,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru * * If includeCollections is true, this will return fields that are nested in maps and arrays. */ - private[sql] def findNestedField( + private[spark] def findNestedField( fieldNames: Seq[String], includeCollections: Boolean = false, resolver: SqlApiAnalysis.Resolver = _ == _, From 6d0b8389a20ed523201096355835fe849708e6bf Mon Sep 17 00:00:00 2001 From: panbingkun Date: Fri, 11 Oct 2024 15:21:21 +0200 Subject: [PATCH 219/250] [SPARK-49748][CORE][FOLLOWUP] Add `getCondition` and deprecate `getErrorClass` in `QueryCompilationErrorsSuite` ### What changes were proposed in this pull request? The pr is following up https://github.com/apache/spark/pull/48196 ### Why are the changes needed? Revise the logic of the newly added logic. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48416 from panbingkun/SPARK-49748_FOLLOWUP. 
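The rename exercised here is the same one driving the test churn earlier in this series: `SparkThrowable.getCondition` is the new accessor and `getErrorClass` is the deprecated one. A minimal consuming-side sketch, assuming only the `SparkThrowable` surface visible in these diffs (the `prettyMessage` helper itself is hypothetical):

```scala
import org.apache.spark.SparkThrowable

// Hypothetical helper: prefix an error string with its condition when one is
// defined, mirroring the PRETTY formatting in HiveThriftServerErrors earlier
// in this series. getCondition can return null for uncategorized errors.
def prettyMessage(t: Throwable): String = t match {
  case st: SparkThrowable =>
    Option(st.getCondition).map(c => s"[$c] ").getOrElse("") + t.toString
  case _ => t.toString
}
```

In test code the pattern stays the same as before the rename: intercept the exception and assert on `getCondition` where `getErrorClass` used to be checked.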
Authored-by: panbingkun Signed-off-by: Max Gekk --- .../apache/spark/sql/errors/QueryCompilationErrorsSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala index b4fdf50447458..92c175fe2f94a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryCompilationErrorsSuite.scala @@ -1003,7 +1003,7 @@ class QueryCompilationErrorsSuite val exception = intercept[AnalysisException] { sql(queryWithTrailingComma) } - assert(exception.getErrorClass === "TRAILING_COMMA_IN_SELECT") + assert(exception.getCondition === "TRAILING_COMMA_IN_SELECT") } val unresolvedColumnErrors = Seq( @@ -1017,7 +1017,7 @@ class QueryCompilationErrorsSuite val exception = intercept[AnalysisException] { sql(query) } - assert(exception.getErrorClass === "UNRESOLVED_COLUMN.WITH_SUGGESTION") + assert(exception.getCondition === "UNRESOLVED_COLUMN.WITH_SUGGESTION") } // sanity checks From c79e2d6370cc7a31e65342bfb69b60523aae1b30 Mon Sep 17 00:00:00 2001 From: Marko Date: Fri, 11 Oct 2024 22:24:05 +0800 Subject: [PATCH 220/250] [SPARK-49925][SQL] Add tests for order by with collated strings ### What changes were proposed in this pull request? Tests added for order by clause with collated strings. ### Why are the changes needed? Better testing. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Tests added to `CollationSuite`. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48412 from ilicmarkodb/add_tests_for_complex_types_with_collations_order_by. 
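For readers unfamiliar with collated ordering, a compact spark-shell style illustration of the UTF8_LCASE case the new tests exercise; the table name `t` and the sample rows are invented for the example, and `spark` is assumed to be an active SparkSession.

```scala
// Illustrative only, not taken from the patch. Under UTF8_LCASE the sort key
// compares strings case-insensitively, so 'a' and 'A' are ties.
spark.sql("CREATE TABLE t (c1 INT, c2 STRING COLLATE UTF8_LCASE)")
spark.sql("INSERT INTO t VALUES (1, 'a'), (1, 'A'), (2, 'BbB'), (2, 'bbb'), (3, 'cc')")
spark.sql("SELECT c1, c2 FROM t ORDER BY c2").show()
// c1 comes back as 1, 1, 2, 2, 3: both spellings of 'a' sort together, as do
// 'BbB' and 'bbb', matching the expected results asserted in the new CollationSuite test.
```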
Authored-by: Marko Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/CollationSuite.scala | 206 ++++++++++++++++++ 1 file changed, 206 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index ef01f71c68bf9..b19af542dabf2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1101,6 +1101,212 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } + test("Check order by on table with collated string column") { + val tableName = "t" + Seq( + // (collationName, data, expResult) + ( + "", // non-collated + Seq((5, "bbb"), (3, "a"), (1, "A"), (4, "aaaa"), (6, "cc"), (2, "BbB")), + Seq(1, 2, 3, 4, 5, 6) + ), + ( + "UTF8_BINARY", + Seq((5, "bbb"), (3, "a"), (1, "A"), (4, "aaaa"), (6, "cc"), (2, "BbB")), + Seq(1, 2, 3, 4, 5, 6) + ), + ( + "UTF8_LCASE", + Seq((2, "bbb"), (1, "a"), (1, "A"), (1, "aaaa"), (3, "cc"), (2, "BbB")), + Seq(1, 1, 1, 2, 2, 3) + ), + ( + "UNICODE", + Seq((4, "bbb"), (1, "a"), (2, "A"), (3, "aaaa"), (6, "cc"), (5, "BbB")), + Seq(1, 2, 3, 4, 5, 6) + ), + ( + "UNICODE_CI", + Seq((2, "bbb"), (1, "a"), (1, "A"), (1, "aaaa"), (3, "cc"), (2, "BbB")), + Seq(1, 1, 1, 2, 2, 3) + ) + ).foreach { + case (collationName, data, expResult) => + val collationSetup = if (collationName.isEmpty) "" else "collate " + collationName + withTable(tableName) { + sql(s"create table $tableName (c1 integer, c2 string $collationSetup)") + data.foreach { + case (c1, c2) => + sql(s"insert into $tableName values ($c1, '$c2')") + } + checkAnswer(sql(s"select c1 from $tableName order by c2"), expResult.map(Row(_))) + } + } + } + + test("Check order by on StructType") { + Seq( + // (collationName, data, expResult) + ( + "", // non-collated + Seq((5, "b", "A"), (3, "aa", "A"), (6, "b", "B"), (2, "A", "c"), (1, "A", "D"), + (4, "aa", "B")), + Seq(1, 2, 3, 4, 5, 6) + ), + ( + "UTF8_BINARY", + Seq((5, "b", "A"), (3, "aa", "A"), (6, "b", "B"), (2, "A", "c"), (1, "A", "D"), + (4, "aa", "B")), + Seq(1, 2, 3, 4, 5, 6) + ), + ( + "UTF8_LCASE", + Seq((3, "A", "C"), (2, "A", "b"), (2, "a", "b"), (4, "B", "c"), (1, "a", "a"), + (5, "b", "d")), + Seq(1, 2, 2, 3, 4, 5) + ), + ( + "UNICODE", + Seq((4, "A", "C"), (3, "A", "b"), (2, "a", "b"), (5, "b", "c"), (1, "a", "a"), + (6, "b", "d")), + Seq(1, 2, 3, 4, 5, 6) + ), + ( + "UNICODE_CI", + Seq((3, "A", "C"), (2, "A", "b"), (2, "a", "b"), (4, "B", "c"), (1, "a", "a"), + (5, "b", "d")), + Seq(1, 2, 2, 3, 4, 5) + ) + ).foreach { + case (collationName, data, expResult) => + val collationSetup = if (collationName.isEmpty) "" else "collate " + collationName + val tableName = "t" + withTable(tableName) { + sql(s"create table $tableName (c1 integer, c2 struct<" + + s"s1: string $collationSetup," + + s"s2: string $collationSetup>)") + data.foreach { + case (c1, s1, s2) => + sql(s"insert into $tableName values ($c1, struct('$s1', '$s2'))") + } + checkAnswer(sql(s"select c1 from $tableName order by c2"), expResult.map(Row(_))) + } + } + } + + test("Check order by on StructType with few collated fields") { + val data = Seq( + (2, "b", "a", "a", "a", "a"), + (4, "b", "b", "B", "a", "a"), + (1, "a", "a", "a", "a", "a"), + (6, "b", "b", "b", "B", "B"), + (3, "b", "b", "a", "a", "a"), + (5, "b", "b", "b", "B", "a")) + val tableName = "t" + withTable(tableName) { + sql(s"create table $tableName (c1 integer, c2 struct<" + + s"s1: string, " + + s"s2: 
string collate UTF8_BINARY, " + + s"s3: string collate UTF8_LCASE, " + + s"s4: string collate UNICODE, " + + s"s5: string collate UNICODE_CI>)") + data.foreach { + case (order, s1, s2, s3, s4, s5) => + sql(s"insert into $tableName values ($order, struct('$s1', '$s2', '$s3', '$s4', '$s5'))") + } + val expResult = Seq(1, 2, 3, 4, 5, 6) + checkAnswer(sql(s"select c1 from $tableName order by c2"), expResult.map(Row(_))) + } + } + + test("Check order by on ArrayType with collated strings") { + Seq( + // (collationName, order, data) + ( + "", + Seq((3, Seq("b", "Aa", "c")), (2, Seq("A", "b")), (1, Seq("A")), (2, Seq("A", "b"))), + Seq(1, 2, 2, 3) + ), + ( + "UTF8_BINARY", + Seq((3, Seq("b", "Aa", "c")), (2, Seq("A", "b")), (1, Seq("A")), (2, Seq("A", "b"))), + Seq(1, 2, 2, 3) + ), + ( + "UTF8_LCASE", + Seq((4, Seq("B", "a")), (4, Seq("b", "A")), (2, Seq("aa")), (1, Seq("A")), + (5, Seq("b", "e")), (3, Seq("b"))), + Seq(1, 2, 3, 4, 4, 5) + ), + ( + "UNICODE", + Seq((5, Seq("b", "C")), (4, Seq("b", "AA")), (1, Seq("a")), (4, Seq("b", "AA")), + (3, Seq("b")), (2, Seq("A", "a"))), + Seq(1, 2, 3, 4, 4, 5) + ), + ( + "UNICODE_CI", + Seq((4, Seq("B", "a")), (4, Seq("b", "A")), (2, Seq("aa")), (1, Seq("A")), + (5, Seq("b", "e")), (3, Seq("b"))), + Seq(1, 2, 3, 4, 4, 5) + ) + ).foreach { + case (collationName, dataWithOrder, expResult) => + val collationSetup = if (collationName.isEmpty) "" else "collate " + collationName + val tableName1 = "t1" + val tableName2 = "t2" + withTable(tableName1, tableName2) { + sql(s"create table $tableName1 (c1 integer, c2 array)") + sql(s"create table $tableName2 (c1 integer," + + s" c2 struct>)") + dataWithOrder.foreach { + case (order, data) => + val arrayData = data.map(d => s"'$d'").mkString(", ") + sql(s"insert into $tableName1 values ($order, array($arrayData))") + sql(s"insert into $tableName2 values ($order, struct(array($arrayData)))") + } + checkAnswer(sql(s"select c1 from $tableName1 order by c2"), expResult.map(Row(_))) + checkAnswer(sql(s"select c1 from $tableName2 order by c2"), expResult.map(Row(_))) + } + } + } + + test("Check order by on StructType with different types containing collated strings") { + val data = Seq( + (5, ("b", Seq(("b", "B", "a"), ("a", "a", "a")), "a")), + (2, ("b", Seq(("a", "a", "a")), "a")), + (2, ("b", Seq(("a", "a", "a")), "a")), + (4, ("b", Seq(("b", "a", "a")), "a")), + (3, ("b", Seq(("a", "a", "a"), ("a", "a", "a")), "a")), + (5, ("b", Seq(("b", "B", "a")), "a")), + (4, ("b", Seq(("b", "a", "a")), "a")), + (6, ("b", Seq(("b", "b", "B")), "A")), + (5, ("b", Seq(("b", "b", "a")), "a")), + (1, ("a", Seq(("a", "a", "a")), "a")), + (7, ("b", Seq(("b", "b", "B")), "b")), + (6, ("b", Seq(("b", "b", "B")), "a")), + (5, ("b", Seq(("b", "b", "a")), "a")) + ) + val tableName = "t" + withTable(tableName) { + sql(s"create table $tableName " + + s"(c1 integer," + + s"c2 string," + + s"c3 array>," + + s"c4 string collate UNICODE_CI)") + data.foreach { + case (c1, (c2, c3, c4)) => + val c3String = c3.map { case (f1, f2, f3) => s"struct('$f1', '$f2', '$f3')"} + .mkString(", ") + sql(s"insert into $tableName values ($c1, '$c2', array($c3String), '$c4')") + } + val expResult = Seq(1, 2, 2, 3, 4, 4, 5, 5, 5, 5, 6, 6, 7) + checkAnswer(sql(s"select c1 from $tableName order by c2, c3, c4"), expResult.map(Row(_))) + } + } + for (collation <- Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI", "UNICODE_CI_RTRIM", "")) { for (codeGen <- Seq("NO_CODEGEN", "CODEGEN_ONLY")) { From 04ec55ed8106783a1e23abfdcc58e1a5eba30169 Mon Sep 17 00:00:00 2001 
From: Ruifeng Zheng Date: Sat, 12 Oct 2024 08:57:25 +0900 Subject: [PATCH 221/250] [MINOR][PYTHON] Minor refine `LiteralExpression` ### What changes were proposed in this pull request? 1, combine if branches; 2, refine the error message ### Why are the changes needed? Minor refine `LiteralExpression` ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? ci ### Was this patch authored or co-authored using generative AI tooling? no Closes #48422 from zhengruifeng/expr_lit_nit. Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- python/pyspark/errors/error-conditions.json | 4 ++-- python/pyspark/sql/connect/expressions.py | 20 +++++--------------- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index ed62ea117d369..6ca21d55555d2 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -94,9 +94,9 @@ "Could not get batch id from ." ] }, - "CANNOT_INFER_ARRAY_TYPE": { + "CANNOT_INFER_ARRAY_ELEMENT_TYPE": { "message": [ - "Can not infer Array Type from a list with None as the first element." + "Can not infer the element data type, an non-empty list starting with an non-None value is required." ] }, "CANNOT_INFER_EMPTY_SCHEMA": { diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 85f1b3565c696..203b6ce371a5c 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -301,7 +301,7 @@ def _infer_type(cls, value: Any) -> DataType: return NullType() elif isinstance(value, (bytes, bytearray)): return BinaryType() - elif isinstance(value, bool): + elif isinstance(value, (bool, np.bool_)): return BooleanType() elif isinstance(value, int): if JVM_INT_MIN <= value <= JVM_INT_MAX: @@ -323,10 +323,8 @@ def _infer_type(cls, value: Any) -> DataType: return StringType() elif isinstance(value, decimal.Decimal): return DecimalType() - elif isinstance(value, datetime.datetime) and is_timestamp_ntz_preferred(): - return TimestampNTZType() elif isinstance(value, datetime.datetime): - return TimestampType() + return TimestampNTZType() if is_timestamp_ntz_preferred() else TimestampType() elif isinstance(value, datetime.date): return DateType() elif isinstance(value, datetime.timedelta): @@ -335,23 +333,15 @@ def _infer_type(cls, value: Any) -> DataType: dt = _from_numpy_type(value.dtype) if dt is not None: return dt - elif isinstance(value, np.bool_): - return BooleanType() elif isinstance(value, list): # follow the 'infer_array_from_first_element' strategy in 'sql.types._infer_type' # right now, it's dedicated for pyspark.ml params like array<...>, array> - if len(value) == 0: - raise PySparkValueError( - errorClass="CANNOT_BE_EMPTY", - messageParameters={"item": "value"}, - ) - first = value[0] - if first is None: + if len(value) == 0 or value[0] is None: raise PySparkTypeError( - errorClass="CANNOT_INFER_ARRAY_TYPE", + errorClass="CANNOT_INFER_ARRAY_ELEMENT_TYPE", messageParameters={}, ) - return ArrayType(LiteralExpression._infer_type(first), True) + return ArrayType(LiteralExpression._infer_type(value[0]), True) raise PySparkTypeError( errorClass="UNSUPPORTED_DATA_TYPE", From 3ecfe8ef7b95cb132ad10eb6cc2fa564864e4952 Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Sat, 12 Oct 2024 09:12:10 +0900 Subject: [PATCH 222/250] [SPARK-49930][SS] Ensure that socket updates are flushed on exception from the python worker 
### What changes were proposed in this pull request? Ensure that socket updates are flushed on exception from the python worker ### Why are the changes needed? Without this, updates to the socket from the python worker are not delivered to the jvm side ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests were failing on a different Python version and pass after ``` Run completed in 1 minute, 13 seconds. Total number of tests run: 8 Suites: completed 1, aborted 0 Tests: succeeded 8, failed 0, canceled 0, ignored 0, pending 0 All tests passed. ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #48418 from anishshri-db/task/SPARK-49930. Authored-by: Anish Shrigondekar Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/streaming/python_streaming_source_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py index c50bd3915784b..a7349779dc626 100644 --- a/python/pyspark/sql/streaming/python_streaming_source_runner.py +++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py @@ -193,6 +193,8 @@ def main(infile: IO, outfile: IO) -> None: reader.stop() except BaseException as e: handle_worker_exception(e, outfile) + # ensure that the updates to the socket are flushed + outfile.flush() sys.exit(-1) send_accumulator_updates(outfile) From cf657e5e2e49388b89cee5c5113426b316e4892d Mon Sep 17 00:00:00 2001 From: Siying Dong Date: Sat, 12 Oct 2024 09:43:04 +0900 Subject: [PATCH 223/250] [SPARK-49927][SS][PYTHON][TESTS] pyspark.sql.tests.streaming.test_streaming_listener to wait for longer ### What changes were proposed in this pull request? In test pyspark.sql.tests.streaming.test_streaming_listener, instead of waiting for fixed 10 seconds, we wait for progress made. ### Why are the changes needed? In some environment, the test fails with progress_event appears to be None. Likely the wait time is not sufficient. ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? Run the patch and make sure it still passes ### Was this patch authored or co-authored using generative AI tooling? No Closes #48424 from siying/python_test2. 
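The listener test change that follows replaces the fixed `q.awaitTermination(10)` sleep with a bounded polling loop over the received progress event. As a rough sketch of the same wait-for-condition idea in Scala (the helper name, default timings and error text are illustrative only, not part of the patch):

```
object WaitSketch {
  // Hypothetical helper (not part of this patch): poll a condition every
  // stepMs until it holds, and fail only after maxWaitMs has elapsed,
  // instead of sleeping for one fixed interval.
  def waitUntil(condition: () => Boolean, maxWaitMs: Long = 50000L, stepMs: Long = 500L): Unit = {
    val deadline = System.currentTimeMillis() + maxWaitMs
    while (!condition()) {
      if (System.currentTimeMillis() > deadline) {
        throw new AssertionError(s"Condition not met after ${maxWaitMs / 1000} seconds")
      }
      Thread.sleep(stepMs)
    }
  }
}
```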
Authored-by: Siying Dong Signed-off-by: Hyukjin Kwon --- .../pyspark/sql/tests/streaming/test_streaming_listener.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py index c3ae62e64cc30..51f62f56a7c54 100644 --- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py +++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py @@ -381,7 +381,12 @@ def verify(test_listener): .start() ) self.assertTrue(q.isActive) - q.awaitTermination(10) + wait_count = 0 + while progress_event is None or progress_event.progress.batchId == 0: + q.awaitTermination(0.5) + wait_count = wait_count + 1 + if wait_count > 100: + self.fail("Not getting progress event after 50 seconds") q.stop() # Make sure all events are empty From 1fb3d57c1083125bd565cdc083b221db1ed3e0f4 Mon Sep 17 00:00:00 2001 From: yangjie01 Date: Sat, 12 Oct 2024 12:18:34 +0900 Subject: [PATCH 224/250] [SPARK-49935][BUILD] Exclude `spark-connect-shims` from `assembly` module MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes were proposed in this pull request? This pr exclude `spark-connect-shims` from `assembly` module to avoid it from being included in the distribution when executing `dev/make-distribution.sh`. ### Why are the changes needed? `spark-connect-shims` is only used to resolve compilation issues, and it should not be included in the `jars` directory of the distribution, otherwise, it may disrupt REPL-related functionalities. For examples: 1. spark-shell will fail to start ``` bin/spark-shell --master local WARNING: Using incubator modules: jdk.incubator.vector Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties {"ts":"2024-10-11T11:54:03.437Z","level":"WARN","msg":"Your hostname, MacBook-Pro.local, resolves to a loopback address: 127.0.0.1; using 172.22.200.181 instead (on interface en0)","context":{"host":"MacBook-Pro.local","host_port":"127.0.0.1","host_port2":"172.22.200.181","network_if":"en0"},"logger":"Utils"} {"ts":"2024-10-11T11:54:03.439Z","level":"WARN","msg":"Set SPARK_LOCAL_IP if you need to bind to another address","logger":"Utils"} Using Spark's default log4j profile: org/apache/spark/log4j2-pattern-layout-defaults.properties Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 4.0.0-SNAPSHOT /_/ Using Scala version 2.13.15 (OpenJDK 64-Bit Server VM, Java 17.0.12) Type in expressions to have them evaluated. Type :help for more information. 
if (_sc.getConf.getBoolean("spark.ui.reverseProxy", false)) { ^ On line 9: error: value getConf is not a member of org.apache.spark.SparkContext val proxyUrl = _sc.getConf.get("spark.ui.reverseProxyUrl", null) ^ On line 10: error: value getConf is not a member of org.apache.spark.SparkContext s"Spark Context Web UI is available at ${proxyUrl}/proxy/${_sc.applicationId}") ^ On line 13: error: value applicationId is not a member of org.apache.spark.SparkContext _sc.uiWebUrl.foreach { ^ On line 18: error: value uiWebUrl is not a member of org.apache.spark.SparkContext s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") ^ On line 23: error: value master is not a member of org.apache.spark.SparkContext s"(master = ${_sc.master}, app id = ${_sc.applicationId}).") ^ On line 23: error: value applicationId is not a member of org.apache.spark.SparkContext ^ error: object SparkContext is not a member of package org.apache.spark note: class SparkContext exists, but it has no companion object. ^ error: object implicits is not a member of package spark ^ error: object sql is not a member of package spark ``` 2. SparkR tests on Windows may also fail due to `spark-connect-shims` being in the classpath. https://github.com/apache/spark/actions/runs/11259624487/job/31309026637 ``` ══ Failed tests ════════════════════════════════════════���═══════════════════════ ── Error ('test_basic.R:25:3'): create DataFrame from list or data.frame ─────── Error in `handleErrors(returnStatus, conn)`: java.lang.NoSuchMethodError: 'void org.apache.spark.SparkContext.(org.apache.spark.SparkConf)' at org.apache.spark.SparkContext$.getOrCreate(SparkContext.scala:3050) at org.apache.spark.api.r.RRDD$.createSparkContext(RRDD.scala:141) at org.apache.spark.api.r.RRDD.createSparkContext(RRDD.scala) ``` ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - Pass GitHub Actions - All Maven test passed on GA: https://github.com/LuciferYang/spark/runs/31405720205 image - Sparkr on windws test passed on GA: https://github.com/LuciferYang/spark/actions/runs/11291559675/job/31434704406 image - Manual check: ``` dev/make-distribution.sh --tgz -Phive ``` `spark-connect-shims` is not in either directory `jars` or directory `jars/connect-repl`, and both spark-shell and connect-shell can be used normally **Spark shell** ``` bin/spark-shell --master local WARNING: Using incubator modules: jdk.incubator.vector Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties {"ts":"2024-10-12T02:56:47.637Z","level":"WARN","msg":"Your hostname, MacBook-Pro.local, resolves to a loopback address: 127.0.0.1; using 172.22.200.218 instead (on interface en0)","context":{"host":"MacBook-Pro.local","host_port":"127.0.0.1","host_port2":"172.22.200.218","network_if":"en0"},"logger":"Utils"} {"ts":"2024-10-12T02:56:47.639Z","level":"WARN","msg":"Set SPARK_LOCAL_IP if you need to bind to another address","logger":"Utils"} Using Spark's default log4j profile: org/apache/spark/log4j2-pattern-layout-defaults.properties Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 4.0.0-SNAPSHOT /_/ Using Scala version 2.13.15 (OpenJDK 64-Bit Server VM, Java 17.0.12) Type in expressions to have them evaluated. Type :help for more information. 24/10/12 10:56:49 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... 
using builtin-java classes where applicable Spark context Web UI available at http://172.22.200.218:4040 Spark context available as 'sc' (master = local, app id = local-1728701810131). Spark session available as 'spark'. scala> spark.range(10).show() +---+ | id| +---+ | 0| | 1| | 2| | 3| | 4| | 5| | 6| | 7| | 8| | 9| +---+ ``` **Connect shell** ``` bin/spark-shell --remote local WARNING: Using incubator modules: jdk.incubator.vector Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties {"ts":"2024-10-12T02:58:17.326Z","level":"WARN","msg":"Your hostname, MacBook-Pro.local, resolves to a loopback address: 127.0.0.1; using 172.22.200.218 instead (on interface en0)","context":{"host":"MacBook-Pro.local","host_port":"127.0.0.1","host_port2":"172.22.200.218","network_if":"en0"},"logger":"Utils"} {"ts":"2024-10-12T02:58:17.328Z","level":"WARN","msg":"Set SPARK_LOCAL_IP if you need to bind to another address","logger":"Utils"} 24/10/12 10:58:19 INFO BaseAllocator: Debug mode disabled. Enable with the VM option -Darrow.memory.debug.allocator=true. 24/10/12 10:58:19 INFO DefaultAllocationManagerOption: allocation manager type not specified, using netty as the default type 24/10/12 10:58:19 INFO CheckAllocator: Using DefaultAllocationManager at memory/netty/DefaultAllocationManagerFactory.class Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ /___/ .__/\_,_/_/ /_/\_\ version 4.0.0-SNAPSHOT /_/ Type in expressions to have them evaluated. Spark connect server version 4.0.0-SNAPSHOT. Spark session available as 'spark'. scala> spark.range(10).show Using Spark's default log4j profile: org/apache/spark/log4j2-pattern-layout-defaults.properties +---+ | id| +---+ | 0| | 1| | 2| | 3| | 4| | 5| | 6| | 7| | 8| | 9| +---+ ``` ### Was this patch authored or co-authored using generative AI tooling? No Closes #48421 from LuciferYang/fix-distribution. Authored-by: yangjie01 Signed-off-by: Hyukjin Kwon --- assembly/pom.xml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/assembly/pom.xml b/assembly/pom.xml index 01bd324efc118..17bb81fa023ba 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -117,6 +117,12 @@ org.apache.spark spark-connect-client-jvm_${scala.binary.version} ${project.version} + + + org.apache.spark + spark-connect-shims_${scala.binary.version} + + provided From 6734d4883e76b82249df5c151d42bc83173f4122 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Sat, 12 Oct 2024 13:15:39 +0800 Subject: [PATCH 225/250] [SPARK-49932][CORE] Use `tryWithResource` release `JsonUtils#toJsonString` resources to avoid memory leaks ### What changes were proposed in this pull request? The pr aims to use `tryWithResource` release `JsonUtils#toJsonString` resources to avoid memory leaks. ### Why are the changes needed? Avoiding potential memory leaks. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48420 from panbingkun/SPARK-49932. 
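The JsonUtils change that follows replaces manual `close()` calls with nested `tryWithResource` blocks, so the byte stream and the JSON generator are closed even when `block(generator)` throws. A minimal sketch of that loan pattern (the real helper lives in `SparkErrorUtils` and may differ in detail):

```
object TryWithResourceSketch {
  // Minimal sketch of the pattern the patch switches to: close() runs in a
  // finally block, so the resource is released on both the success path and
  // the path where f(resource) throws.
  def tryWithResource[R <: AutoCloseable, T](createResource: => R)(f: R => T): T = {
    val resource = createResource
    try {
      f(resource)
    } finally {
      resource.close()
    }
  }
}
```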
Authored-by: panbingkun Signed-off-by: yangjie01 --- .../scala/org/apache/spark/util/JsonUtils.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/common/utils/src/main/scala/org/apache/spark/util/JsonUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/JsonUtils.scala index 4d729adfbb7eb..f88f267727c11 100644 --- a/common/utils/src/main/scala/org/apache/spark/util/JsonUtils.scala +++ b/common/utils/src/main/scala/org/apache/spark/util/JsonUtils.scala @@ -24,6 +24,7 @@ import com.fasterxml.jackson.core.{JsonEncoding, JsonGenerator} import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} import com.fasterxml.jackson.module.scala.DefaultScalaModule +import org.apache.spark.util.SparkErrorUtils.tryWithResource private[spark] trait JsonUtils { @@ -31,12 +32,12 @@ private[spark] trait JsonUtils { .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) def toJsonString(block: JsonGenerator => Unit): String = { - val baos = new ByteArrayOutputStream() - val generator = mapper.createGenerator(baos, JsonEncoding.UTF8) - block(generator) - generator.close() - baos.close() - new String(baos.toByteArray, StandardCharsets.UTF_8) + tryWithResource(new ByteArrayOutputStream()) { baos => + tryWithResource(mapper.createGenerator(baos, JsonEncoding.UTF8)) { generator => + block(generator) + } + new String(baos.toByteArray, StandardCharsets.UTF_8) + } } } From 1244c5a0f548f3e2da75863880779bcfc9eee5c0 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Sat, 12 Oct 2024 10:49:55 +0200 Subject: [PATCH 226/250] [SPARK-49766][SQL] Codegen Support for `json_array_length` (by `Invoke` & `RuntimeReplaceable`) ### What changes were proposed in this pull request? The pr aims to add `Codegen` Support for `json_array_length`. ### Why are the changes needed? - improve codegen coverage. - simplified code. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA & Existed UT (eg: JsonFunctionsSuite#`json_array_length function`) ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48224 from panbingkun/SPARK-49766. Authored-by: panbingkun Signed-off-by: Max Gekk --- .../expressions/json/JsonExpressionUtils.java | 58 +++++++++++++++++++ .../expressions/jsonExpressions.scala | 54 +++++------------ .../function_json_array_length.explain | 2 +- 3 files changed, 73 insertions(+), 41 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java new file mode 100644 index 0000000000000..ca2ae80042df7 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.json; + +import java.io.IOException; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; + +import org.apache.spark.sql.catalyst.expressions.SharedFactory; +import org.apache.spark.sql.catalyst.json.CreateJacksonParser; +import org.apache.spark.unsafe.types.UTF8String; + +public class JsonExpressionUtils { + + public static Integer lengthOfJsonArray(UTF8String json) { + // return null for null input + if (json == null) { + return null; + } + try (JsonParser jsonParser = + CreateJacksonParser.utf8String(SharedFactory.jsonFactory(), json)) { + if (jsonParser.nextToken() == null) { + return null; + } + // Only JSON array are supported for this function. + if (jsonParser.currentToken() != JsonToken.START_ARRAY) { + return null; + } + // Parse the array to compute its length. + int length = 0; + // Keep traversing until the end of JSON array + while (jsonParser.nextToken() != JsonToken.END_ARRAY) { + length += 1; + // skip all the child of inner object or array + jsonParser.skipChildren(); + } + return length; + } catch (IOException e) { + return null; + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index bdcf3f0c1eeab..e1f2b1c1df62a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper +import org.apache.spark.sql.catalyst.expressions.json.JsonExpressionUtils +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePattern} @@ -967,54 +969,26 @@ case class SchemaOfJson( group = "json_funcs", since = "3.1.0" ) -case class LengthOfJsonArray(child: Expression) extends UnaryExpression - with CodegenFallback with ExpectsInputTypes { +case class LengthOfJsonArray(child: Expression) + extends UnaryExpression + with ExpectsInputTypes + with RuntimeReplaceable { override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType = IntegerType override def nullable: Boolean = true override def prettyName: String = "json_array_length" - override def eval(input: InternalRow): Any = { - val json = child.eval(input).asInstanceOf[UTF8String] - // return null for null input - if (json == null) { - return null - } - - try { - 
Utils.tryWithResource(CreateJacksonParser.utf8String(SharedFactory.jsonFactory, json)) { - parser => { - // return null if null array is encountered. - if (parser.nextToken() == null) { - return null - } - // Parse the array to compute its length. - parseCounter(parser, input) - } - } - } catch { - case _: JsonProcessingException | _: IOException => null - } - } - - private def parseCounter(parser: JsonParser, input: InternalRow): Any = { - var length = 0 - // Only JSON array are supported for this function. - if (parser.currentToken != JsonToken.START_ARRAY) { - return null - } - // Keep traversing until the end of JSON array - while(parser.nextToken() != JsonToken.END_ARRAY) { - length += 1 - // skip all the child of inner object or array - parser.skipChildren() - } - length - } - override protected def withNewChildInternal(newChild: Expression): LengthOfJsonArray = copy(child = newChild) + + override def replacement: Expression = StaticInvoke( + classOf[JsonExpressionUtils], + dataType, + "lengthOfJsonArray", + Seq(child), + inputTypes + ) } /** diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_array_length.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_array_length.explain index 50ab91560e64a..d70e2eb60aba5 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_array_length.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_array_length.explain @@ -1,2 +1,2 @@ -Project [json_array_length(g#0) AS json_array_length(g)#0] +Project [static_invoke(JsonExpressionUtils.lengthOfJsonArray(g#0)) AS json_array_length(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] From ed4847ffe07e60d8cf9a01c2855ef32626ae081b Mon Sep 17 00:00:00 2001 From: panbingkun Date: Sat, 12 Oct 2024 20:31:05 +0900 Subject: [PATCH 227/250] [SPARK-49937][INFRA] Ban call the method `SparkThrowable#getErrorClass` ### What changes were proposed in this pull request? The pr aims to ban call the method `SparkThrowable#getErrorClass`. ### Why are the changes needed? After PR https://github.com/apache/spark/pull/48196, `SparkThrowable#getErrorClass` has been marked as `Deprecated`. In order to prevent future developers from calling `SparkThrowable#getErrorClass` again, which may require continuous fix and migration, calling `SparkThrowable#getErrorClass` is strictly prohibited as it will fail at the compilation level. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Pass GA. - Manually test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48426 from panbingkun/SPARK-49937. Authored-by: panbingkun Signed-off-by: Hyukjin Kwon --- pom.xml | 4 ++++ project/SparkBuild.scala | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index bfaee1be609c0..3da8b9ef68b90 100644 --- a/pom.xml +++ b/pom.xml @@ -3075,6 +3075,10 @@ reduce the cost of migration in subsequent versions. --> -Wconf:cat=deprecation&msg=it will become a keyword in Scala 3:e + + -Wconf:cat=deprecation&msg=method getErrorClass in trait SparkThrowable is deprecated:e -Xss128m diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5882fcbf336b0..737efa8f7846b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -254,7 +254,9 @@ object SparkBuild extends PomBuild { // reduce the cost of migration in subsequent versions. 
"-Wconf:cat=deprecation&msg=it will become a keyword in Scala 3:e", // SPARK-46938 to prevent enum scan on pmml-model, under spark-mllib module. - "-Wconf:cat=other&site=org.dmg.pmml.*:w" + "-Wconf:cat=other&site=org.dmg.pmml.*:w", + // SPARK-49937 ban call the method `SparkThrowable#getErrorClass` + "-Wconf:cat=deprecation&msg=method getErrorClass in trait SparkThrowable is deprecated:e" ) } ) From 62ade5f3f42bdad200bfd9ca9e8110594f7c12e4 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sat, 12 Oct 2024 20:33:33 +0900 Subject: [PATCH 228/250] [SPARK-49924][SQL] Keep `containsNull` after `ArrayCompact` replacement ### What changes were proposed in this pull request? Fix `containsNull` of `ArrayCompact`, by adding a new expression `KnownNotContainsNull` ### Why are the changes needed? https://github.com/apache/spark/pull/47430 attempted to set `containsNull = false` for `ArrayCompact` for further optimization, but in an incomplete way: The `ArrayCompact` is a runtime replaceable expression, so will be replaced in optimizer, and cause the `containsNull` be reverted, e.g. ```sql select array_compact(array(1, null)) ``` Rule `ReplaceExpressions` changed `containsNull: false -> true` ``` old schema: StructField(array_compact(array(1, NULL)),ArrayType(IntegerType,false),false) new schema StructField(array_compact(array(1, NULL)),ArrayType(IntegerType,true),false) ``` ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? added tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48410 from zhengruifeng/fix_array_compact_null. Authored-by: Ruifeng Zheng Signed-off-by: Hyukjin Kwon --- .../expressions/collectionOperations.scala | 6 ++--- .../expressions/constraintExpressions.scala | 13 +++++++++- .../catalyst/optimizer/OptimizerSuite.scala | 25 +++++++++++++++++-- .../function_array_compact.explain | 2 +- 4 files changed, 38 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index c091d51fc177f..bb54749126860 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -27,6 +27,7 @@ import org.apache.spark.SparkException.internalError import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch +import org.apache.spark.sql.catalyst.expressions.KnownNotContainsNull import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke @@ -5330,15 +5331,12 @@ case class ArrayCompact(child: Expression) child.dataType.asInstanceOf[ArrayType].elementType, true) lazy val lambda = LambdaFunction(isNotNull(lv), Seq(lv)) - override lazy val replacement: Expression = ArrayFilter(child, lambda) + override lazy val replacement: Expression = KnownNotContainsNull(ArrayFilter(child, lambda)) override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) override def prettyName: String = "array_compact" - override def dataType: ArrayType = - child.dataType.asInstanceOf[ArrayType].copy(containsNull = false) 
- override protected def withNewChildInternal(newChild: Expression): ArrayCompact = copy(child = newChild) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala index 75d912633a0fc..f05db0b090c90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{ArrayType, DataType} trait TaggingExpression extends UnaryExpression { override def nullable: Boolean = child.nullable @@ -52,6 +52,17 @@ case class KnownNotNull(child: Expression) extends TaggingExpression { copy(child = newChild) } +case class KnownNotContainsNull(child: Expression) extends TaggingExpression { + override def dataType: DataType = + child.dataType.asInstanceOf[ArrayType].copy(containsNull = false) + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + child.genCode(ctx) + + override protected def withNewChildInternal(newChild: Expression): KnownNotContainsNull = + copy(child = newChild) +} + case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression { override protected def withNewChildInternal(newChild: Expression): KnownFloatingPointNormalized = copy(child = newChild) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala index 48cdbbe7be539..70a2ae94109fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala @@ -21,13 +21,13 @@ import org.apache.spark.SparkException import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Add, Alias, AttributeReference, IntegerLiteral, Literal, Multiply, NamedExpression, Remainder} +import org.apache.spark.sql.catalyst.expressions.{Add, Alias, ArrayCompact, AttributeReference, CreateArray, CreateStruct, IntegerLiteral, Literal, MapFromEntries, Multiply, NamedExpression, Remainder} import org.apache.spark.sql.catalyst.expressions.aggregate.Sum import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, OneRowRelation, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{ArrayType, IntegerType, MapType, StructField, StructType} /** * A dummy optimizer rule for testing that decrements integer literals until 0. 
@@ -313,4 +313,25 @@ class OptimizerSuite extends PlanTest { assert(message1.contains("not a valid aggregate expression")) } } + + test("SPARK-49924: Keep containsNull after ArrayCompact replacement") { + val optimizer = new SimpleTestOptimizer() { + override def defaultBatches: Seq[Batch] = + Batch("test", fixedPoint, + ReplaceExpressions) :: Nil + } + + val array1 = ArrayCompact(CreateArray(Literal(1) :: Literal.apply(null) :: Nil, false)) + val plan1 = Project(Alias(array1, "arr")() :: Nil, OneRowRelation()).analyze + val optimized1 = optimizer.execute(plan1) + assert(optimized1.schema === + StructType(StructField("arr", ArrayType(IntegerType, false), false) :: Nil)) + + val struct = CreateStruct(Literal(1) :: Literal(2) :: Nil) + val array2 = ArrayCompact(CreateArray(struct :: Literal.apply(null) :: Nil, false)) + val plan2 = Project(Alias(MapFromEntries(array2), "map")() :: Nil, OneRowRelation()).analyze + val optimized2 = optimizer.execute(plan2) + assert(optimized2.schema === + StructType(StructField("map", MapType(IntegerType, IntegerType, false), false) :: Nil)) + } } diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_compact.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_compact.explain index a78195c4ae295..d42d0fd0a46ee 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_compact.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_array_compact.explain @@ -1,2 +1,2 @@ -Project [filter(e#0, lambdafunction(isnotnull(lambda arg#0), lambda arg#0, false)) AS array_compact(e)#0] +Project [knownnotcontainsnull(filter(e#0, lambdafunction(isnotnull(lambda arg#0), lambda arg#0, false))) AS array_compact(e)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] From 083f44d5e6c8baffc38ee1403fdd3e46cada35f4 Mon Sep 17 00:00:00 2001 From: huangxiaoping <1754789345@qq.com> Date: Sun, 13 Oct 2024 11:09:19 +0200 Subject: [PATCH 229/250] [MINOR][SQL] Improved broadcast timeout message prompt ### What changes were proposed in this pull request? Improved broadcast timeout message prompt ### Why are the changes needed? Help users debug better ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? ### Was this patch authored or co-authored using generative AI tooling? No Closes #48389 from huangxiaopingRD/broadcast-message. Authored-by: huangxiaoping <1754789345@qq.com> Signed-off-by: Max Gekk --- common/utils/src/main/resources/error/error-conditions.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 4ceef4b2d8b92..1eaedd9f345a3 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -7063,7 +7063,7 @@ }, "_LEGACY_ERROR_TEMP_2097" : { "message" : [ - "Could not execute broadcast in secs. You can increase the timeout for broadcasts via or disable broadcast join by setting to -1." + "Could not execute broadcast in secs. You can increase the timeout for broadcasts via or disable broadcast join by setting to -1 or remove the broadcast hint if it exists in your code." 
] }, "_LEGACY_ERROR_TEMP_2098" : { From 54fd408316bed4b02e53a7df7077fcdd43f5abe7 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Sun, 13 Oct 2024 11:12:37 +0200 Subject: [PATCH 230/250] [SPARK-49939][SQL] Codegen Support for json_object_keys (by Invoke & RuntimeReplaceable) ### What changes were proposed in this pull request? The pr aims to add `Codegen` Support for `json_object_keys`. ### Why are the changes needed? - improve codegen coverage. - simplified code. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Pass GA & Existed UT (eg: JsonFunctionsSuite#`json_object_keys function`) ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48428 from panbingkun/SPARK-49939. Authored-by: panbingkun Signed-off-by: Max Gekk --- .../expressions/json/JsonExpressionUtils.java | 31 ++++++++++++ .../expressions/jsonExpressions.scala | 50 ++++--------------- .../function_json_object_keys.explain | 2 +- 3 files changed, 43 insertions(+), 40 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java index ca2ae80042df7..07e13610aa950 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionUtils.java @@ -18,12 +18,15 @@ package org.apache.spark.sql.catalyst.expressions.json; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonToken; import org.apache.spark.sql.catalyst.expressions.SharedFactory; import org.apache.spark.sql.catalyst.json.CreateJacksonParser; +import org.apache.spark.sql.catalyst.util.GenericArrayData; import org.apache.spark.unsafe.types.UTF8String; public class JsonExpressionUtils { @@ -55,4 +58,32 @@ public static Integer lengthOfJsonArray(UTF8String json) { return null; } } + + public static GenericArrayData jsonObjectKeys(UTF8String json) { + // return null for `NULL` input + if (json == null) { + return null; + } + try (JsonParser jsonParser = + CreateJacksonParser.utf8String(SharedFactory.jsonFactory(), json)) { + // return null if an empty string or any other valid JSON string is encountered + if (jsonParser.nextToken() == null || jsonParser.currentToken() != JsonToken.START_OBJECT) { + return null; + } + // Parse the JSON string to get all the keys of outermost JSON object + List arrayBufferOfKeys = new ArrayList<>(); + + // traverse until the end of input and ensure it returns valid key + while (jsonParser.nextValue() != null && jsonParser.currentName() != null) { + // add current fieldName to the ArrayBuffer + arrayBufferOfKeys.add(UTF8String.fromString(jsonParser.currentName())); + + // skip all the children of inner object or array + jsonParser.skipChildren(); + } + return new GenericArrayData(arrayBufferOfKeys.toArray()); + } catch (IOException e) { + return null; + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index e1f2b1c1df62a..e01531cc821c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.io._ -import scala.collection.mutable.ArrayBuffer import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ @@ -1014,50 +1013,23 @@ case class LengthOfJsonArray(child: Expression) group = "json_funcs", since = "3.1.0" ) -case class JsonObjectKeys(child: Expression) extends UnaryExpression with CodegenFallback - with ExpectsInputTypes { +case class JsonObjectKeys(child: Expression) + extends UnaryExpression + with ExpectsInputTypes + with RuntimeReplaceable { override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeWithCaseAccentSensitivity) override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType) override def nullable: Boolean = true override def prettyName: String = "json_object_keys" - override def eval(input: InternalRow): Any = { - val json = child.eval(input).asInstanceOf[UTF8String] - // return null for `NULL` input - if(json == null) { - return null - } - - try { - Utils.tryWithResource(CreateJacksonParser.utf8String(SharedFactory.jsonFactory, json)) { - parser => { - // return null if an empty string or any other valid JSON string is encountered - if (parser.nextToken() == null || parser.currentToken() != JsonToken.START_OBJECT) { - return null - } - // Parse the JSON string to get all the keys of outermost JSON object - getJsonKeys(parser, input) - } - } - } catch { - case _: JsonProcessingException | _: IOException => null - } - } - - private def getJsonKeys(parser: JsonParser, input: InternalRow): GenericArrayData = { - val arrayBufferOfKeys = ArrayBuffer.empty[UTF8String] - - // traverse until the end of input and ensure it returns valid key - while(parser.nextValue() != null && parser.currentName() != null) { - // add current fieldName to the ArrayBuffer - arrayBufferOfKeys += UTF8String.fromString(parser.currentName) - - // skip all the children of inner object or array - parser.skipChildren() - } - new GenericArrayData(arrayBufferOfKeys.toArray[Any]) - } + override def replacement: Expression = StaticInvoke( + classOf[JsonExpressionUtils], + dataType, + "jsonObjectKeys", + Seq(child), + inputTypes + ) override protected def withNewChildInternal(newChild: Expression): JsonObjectKeys = copy(child = newChild) diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_object_keys.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_object_keys.explain index 30153bb192e55..8a2ea5336c160 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_object_keys.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_json_object_keys.explain @@ -1,2 +1,2 @@ -Project [json_object_keys(g#0) AS json_object_keys(g)#0] +Project [static_invoke(JsonExpressionUtils.jsonObjectKeys(g#0)) AS json_object_keys(g)#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] From 1abfd490d072850ae40c46c1a0f1791a8aaa5698 Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Sun, 13 Oct 2024 19:01:42 +0800 Subject: [PATCH 231/250] [SPARK-49943][PS] Remove `timestamp_ntz_to_long` from `PythonSQLUtils` ### What changes were proposed in this pull request? Remove `timestamp_ntz_to_long` from `PythonSQLUtils` ### Why are the changes needed? 
we no longer need to add internal functions in `PythonSQLUtils` for PySpark Classic ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48437 from zhengruifeng/fun_cat_nzt. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/pandas/data_type_ops/datetime_ops.py | 6 ++---- python/pyspark/pandas/spark/functions.py | 4 ++++ .../org/apache/spark/sql/api/python/PythonSQLUtils.scala | 3 --- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python/pyspark/pandas/data_type_ops/datetime_ops.py b/python/pyspark/pandas/data_type_ops/datetime_ops.py index 9b4cc72fa2e45..dc2f68232e730 100644 --- a/python/pyspark/pandas/data_type_ops/datetime_ops.py +++ b/python/pyspark/pandas/data_type_ops/datetime_ops.py @@ -34,6 +34,7 @@ ) from pyspark.sql.utils import pyspark_column_op from pyspark.pandas._typing import Dtype, IndexOpsLike, SeriesOrIndex +from pyspark.pandas.spark import functions as SF from pyspark.pandas.base import IndexOpsMixin from pyspark.pandas.data_type_ops.base import ( DataTypeOps, @@ -150,10 +151,7 @@ class DatetimeNTZOps(DatetimeOps): """ def _cast_spark_column_timestamp_to_long(self, scol: Column) -> Column: - from pyspark import SparkContext - - jvm = SparkContext._active_spark_context._jvm - return Column(jvm.PythonSQLUtils.castTimestampNTZToLong(scol._jc)) + return SF.timestamp_ntz_to_long(scol) def astype(self, index_ops: IndexOpsLike, dtype: Union[str, type, Dtype]) -> IndexOpsLike: dtype, spark_type = pandas_on_spark_type(dtype) diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py index 4d95466a98e12..bdd11559df3b6 100644 --- a/python/pyspark/pandas/spark/functions.py +++ b/python/pyspark/pandas/spark/functions.py @@ -39,6 +39,10 @@ def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName") -> return Column(sc._jvm.PythonSQLUtils.internalFn(name, _to_seq(sc, cols, _to_java_column))) +def timestamp_ntz_to_long(col: Column) -> Column: + return _invoke_internal_function_over_columns("timestamp_ntz_to_long", col) + + def product(col: Column, dropna: bool) -> Column: return _invoke_internal_function_over_columns("pandas_product", col, F.lit(dropna)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 3504f6e76f79d..08395ef4c347c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -143,9 +143,6 @@ private[sql] object PythonSQLUtils extends Logging { } } - def castTimestampNTZToLong(c: Column): Column = - Column.internalFn("timestamp_ntz_to_long", c) - def unresolvedNamedLambdaVariable(name: String): Column = Column(internal.UnresolvedNamedLambdaVariable.apply(name)) From 1aae1608960155885c039c389a6aaa53362103be Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Mon, 14 Oct 2024 11:11:04 +0800 Subject: [PATCH 232/250] [SPARK-49928][PYTHON][TESTS] Refactor plot-related unit tests ### What changes were proposed in this pull request? Refactor plot-related unit tests. ### Why are the changes needed? Different plots have different key attributes of the resulting figure to test against. The refactor makes the comparison easier. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? 
Test changes. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48415 from xinrong-meng/plot_test. Authored-by: Xinrong Meng Signed-off-by: Xinrong Meng --- .../sql/tests/plot/test_frame_plot_plotly.py | 242 ++++++++++++++---- 1 file changed, 192 insertions(+), 50 deletions(-) diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py index 70a1b336f734a..b92b5a91cb766 100644 --- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -48,79 +48,174 @@ def sdf3(self): columns = ["sales", "signups", "visits", "date"] return self.spark.createDataFrame(data, columns) - def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=""): - if kind == "line": - self.assertEqual(fig_data["mode"], "lines") - self.assertEqual(fig_data["type"], "scatter") - elif kind == "bar": - self.assertEqual(fig_data["type"], "bar") - elif kind == "barh": - self.assertEqual(fig_data["type"], "bar") - self.assertEqual(fig_data["orientation"], "h") - elif kind == "scatter": - self.assertEqual(fig_data["type"], "scatter") - self.assertEqual(fig_data["orientation"], "v") - self.assertEqual(fig_data["mode"], "markers") - elif kind == "area": - self.assertEqual(fig_data["type"], "scatter") - self.assertEqual(fig_data["orientation"], "v") - self.assertEqual(fig_data["mode"], "lines") - elif kind == "pie": - self.assertEqual(fig_data["type"], "pie") - self.assertEqual(list(fig_data["labels"]), expected_x) - self.assertEqual(list(fig_data["values"]), expected_y) - return - - self.assertEqual(fig_data["xaxis"], "x") - self.assertEqual(list(fig_data["x"]), expected_x) - self.assertEqual(fig_data["yaxis"], "y") - self.assertEqual(list(fig_data["y"]), expected_y) - self.assertEqual(fig_data["name"], expected_name) + def _check_fig_data(self, fig_data, **kwargs): + for key, expected_value in kwargs.items(): + if key in ["x", "y", "labels", "values"]: + converted_values = [v.item() if hasattr(v, "item") else v for v in fig_data[key]] + self.assertEqual(converted_values, expected_value) + else: + self.assertEqual(fig_data[key], expected_value) def test_line_plot(self): # single column as vertical axis fig = self.sdf.plot(kind="line", x="category", y="int_val") - self._check_fig_data("line", fig["data"][0], ["A", "B", "C"], [10, 30, 20]) + expected_fig_data = { + "mode": "lines", + "name": "", + "orientation": "v", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [10, 30, 20], + "yaxis": "y", + "type": "scatter", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) # multiple columns as vertical axis fig = self.sdf.plot.line(x="category", y=["int_val", "float_val"]) - self._check_fig_data("line", fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val") - self._check_fig_data("line", fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val") + expected_fig_data = { + "mode": "lines", + "name": "int_val", + "orientation": "v", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [10, 30, 20], + "yaxis": "y", + "type": "scatter", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "mode": "lines", + "name": "float_val", + "orientation": "v", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [1.5, 2.5, 3.5], + "yaxis": "y", + "type": "scatter", + } + self._check_fig_data(fig["data"][1], **expected_fig_data) def test_bar_plot(self): # single column as vertical axis fig = self.sdf.plot(kind="bar", 
x="category", y="int_val") - self._check_fig_data("bar", fig["data"][0], ["A", "B", "C"], [10, 30, 20]) + expected_fig_data = { + "name": "", + "orientation": "v", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [10, 30, 20], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) # multiple columns as vertical axis fig = self.sdf.plot.bar(x="category", y=["int_val", "float_val"]) - self._check_fig_data("bar", fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val") - self._check_fig_data("bar", fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val") + expected_fig_data = { + "name": "int_val", + "orientation": "v", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [10, 30, 20], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "name": "float_val", + "orientation": "v", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [1.5, 2.5, 3.5], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][1], **expected_fig_data) def test_barh_plot(self): # single column as vertical axis fig = self.sdf.plot(kind="barh", x="category", y="int_val") - self._check_fig_data("barh", fig["data"][0], ["A", "B", "C"], [10, 30, 20]) + expected_fig_data = { + "name": "", + "orientation": "h", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [10, 30, 20], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) # multiple columns as vertical axis fig = self.sdf.plot.barh(x="category", y=["int_val", "float_val"]) - self._check_fig_data("barh", fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val") - self._check_fig_data("barh", fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val") + expected_fig_data = { + "name": "int_val", + "orientation": "h", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [10, 30, 20], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "name": "float_val", + "orientation": "h", + "x": ["A", "B", "C"], + "xaxis": "x", + "y": [1.5, 2.5, 3.5], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][1], **expected_fig_data) # multiple columns as horizontal axis fig = self.sdf.plot.barh(x=["int_val", "float_val"], y="category") - self._check_fig_data("barh", fig["data"][0], [10, 30, 20], ["A", "B", "C"], "int_val") - self._check_fig_data("barh", fig["data"][1], [1.5, 2.5, 3.5], ["A", "B", "C"], "float_val") + expected_fig_data = { + "name": "int_val", + "orientation": "h", + "y": ["A", "B", "C"], + "xaxis": "x", + "x": [10, 30, 20], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "name": "float_val", + "orientation": "h", + "y": ["A", "B", "C"], + "xaxis": "x", + "x": [1.5, 2.5, 3.5], + "yaxis": "y", + "type": "bar", + } + self._check_fig_data(fig["data"][1], **expected_fig_data) def test_scatter_plot(self): fig = self.sdf2.plot(kind="scatter", x="length", y="width") - self._check_fig_data( - "scatter", fig["data"][0], [5.1, 4.9, 7.0, 6.4, 5.9], [3.5, 3.0, 3.2, 3.2, 3.0] - ) + expected_fig_data = { + "name": "", + "orientation": "v", + "x": [5.1, 4.9, 7.0, 6.4, 5.9], + "xaxis": "x", + "y": [3.5, 3.0, 3.2, 3.2, 3.0], + "yaxis": "y", + "type": "scatter", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "name": "", + "orientation": "v", + "y": [5.1, 4.9, 7.0, 6.4, 5.9], + "xaxis": "x", + "x": [3.5, 3.0, 3.2, 3.2, 3.0], + 
"yaxis": "y", + "type": "scatter", + } fig = self.sdf2.plot.scatter(x="width", y="length") - self._check_fig_data( - "scatter", fig["data"][0], [3.5, 3.0, 3.2, 3.2, 3.0], [5.1, 4.9, 7.0, 6.4, 5.9] - ) + self._check_fig_data(fig["data"][0], **expected_fig_data) def test_area_plot(self): # single column as vertical axis @@ -131,13 +226,53 @@ def test_area_plot(self): datetime(2018, 3, 31, 0, 0), datetime(2018, 4, 30, 0, 0), ] - self._check_fig_data("area", fig["data"][0], expected_x, [3, 2, 3, 9]) + expected_fig_data = { + "name": "", + "orientation": "v", + "x": expected_x, + "xaxis": "x", + "y": [3, 2, 3, 9], + "yaxis": "y", + "mode": "lines", + "type": "scatter", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) # multiple columns as vertical axis fig = self.sdf3.plot.area(x="date", y=["sales", "signups", "visits"]) - self._check_fig_data("area", fig["data"][0], expected_x, [3, 2, 3, 9], "sales") - self._check_fig_data("area", fig["data"][1], expected_x, [5, 5, 6, 12], "signups") - self._check_fig_data("area", fig["data"][2], expected_x, [20, 42, 28, 62], "visits") + expected_fig_data = { + "name": "sales", + "orientation": "v", + "x": expected_x, + "xaxis": "x", + "y": [3, 2, 3, 9], + "yaxis": "y", + "mode": "lines", + "type": "scatter", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) + expected_fig_data = { + "name": "signups", + "orientation": "v", + "x": expected_x, + "xaxis": "x", + "y": [5, 5, 6, 12], + "yaxis": "y", + "mode": "lines", + "type": "scatter", + } + self._check_fig_data(fig["data"][1], **expected_fig_data) + expected_fig_data = { + "name": "visits", + "orientation": "v", + "x": expected_x, + "xaxis": "x", + "y": [20, 42, 28, 62], + "yaxis": "y", + "mode": "lines", + "type": "scatter", + } + self._check_fig_data(fig["data"][2], **expected_fig_data) def test_pie_plot(self): fig = self.sdf3.plot(kind="pie", x="date", y="sales") @@ -147,11 +282,18 @@ def test_pie_plot(self): datetime(2018, 3, 31, 0, 0), datetime(2018, 4, 30, 0, 0), ] - self._check_fig_data("pie", fig["data"][0], expected_x, [3, 2, 3, 9]) + expected_fig_data = { + "name": "", + "labels": expected_x, + "values": [3, 2, 3, 9], + "type": "pie", + } + self._check_fig_data(fig["data"][0], **expected_fig_data) # y is not a numerical column with self.assertRaises(PySparkTypeError) as pe: self.sdf.plot.pie(x="int_val", y="category") + self.check_error( exception=pe.exception, errorClass="PLOT_NOT_NUMERIC_COLUMN", From 36b2a4e3eab79cf9e38dd8174682d82b4e15958c Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Mon, 14 Oct 2024 09:20:19 +0200 Subject: [PATCH 233/250] [SPARK-49891][SQL] Assign proper error condition for _LEGACY_ERROR_TEMP_2271 ### What changes were proposed in this pull request? This PR proposes to assign proper error condition for _LEGACY_ERROR_TEMP_2271 ### Why are the changes needed? `UpdateColumnNullability` is not supported for MySQL DB, but the error message has no proper error conditions ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48366 from itholic/SPARK-49891. 
Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../src/main/resources/error/error-conditions.json | 10 +++++----- .../sql/jdbc/v2/MsSqlServerIntegrationSuite.scala | 2 +- .../spark/sql/jdbc/v2/MySQLIntegrationSuite.scala | 2 +- .../apache/spark/sql/errors/QueryExecutionErrors.scala | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 1eaedd9f345a3..6e27561c0f97b 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5180,6 +5180,11 @@ "message" : [ "TRIM specifier in the collation." ] + }, + "UPDATE_COLUMN_NULLABILITY" : { + "message" : [ + "Update column nullability for MySQL and MS SQL Server." + ] } }, "sqlState" : "0A000" @@ -7695,11 +7700,6 @@ "comment on table is not supported." ] }, - "_LEGACY_ERROR_TEMP_2271" : { - "message" : [ - "UpdateColumnNullability is not supported." - ] - }, "_LEGACY_ERROR_TEMP_2272" : { "message" : [ "Rename column is only supported for MySQL version 8.0 and above." diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index e5fd453cb057c..aaaaa28558342 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -115,7 +115,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD exception = intercept[SparkSQLFeatureNotSupportedException] { sql(s"ALTER TABLE $tbl ALTER COLUMN ID DROP NOT NULL") }, - condition = "_LEGACY_ERROR_TEMP_2271") + condition = "UNSUPPORTED_FEATURE.UPDATE_COLUMN_NULLABILITY") } test("SPARK-47440: SQLServer does not support boolean expression in binary comparison") { diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 700c05b54a256..a895739254373 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -142,7 +142,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest exception = intercept[SparkSQLFeatureNotSupportedException] { sql(s"ALTER TABLE $tbl ALTER COLUMN ID DROP NOT NULL") }, - condition = "_LEGACY_ERROR_TEMP_2271") + condition = "UNSUPPORTED_FEATURE.UPDATE_COLUMN_NULLABILITY") } override def testCreateTableWithProperty(tbl: String): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 301880f1bfc61..6e64e7e9e39bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2275,7 +2275,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def unsupportedUpdateColumnNullabilityError(): 
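The new condition fires when a Scala field name is not a legal Java identifier, for example a backtick-quoted name that collides with a Java keyword. A minimal, illustrative repro (the class name here is made up; the real coverage is in the `ScalaReflectionRelationSuite` changes that follow):

```
// Illustration only: `abstract` is a valid backtick-quoted Scala identifier but a
// reserved word in Java, so the encoder rejects it and, after this change, reports
// INVALID_JAVA_IDENTIFIER_AS_FIELD_NAME instead of _LEGACY_ERROR_TEMP_2140.
case class KeywordField(`abstract`: Int)

// import spark.implicits._
// Seq(KeywordField(1)).toDS()   // expected: SparkUnsupportedOperationException
```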
SparkSQLFeatureNotSupportedException = { new SparkSQLFeatureNotSupportedException( - errorClass = "_LEGACY_ERROR_TEMP_2271", + errorClass = "UNSUPPORTED_FEATURE.UPDATE_COLUMN_NULLABILITY", messageParameters = Map.empty) } From 5b9b8da3ddec04497b07e1ac0b526a7850548191 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Mon, 14 Oct 2024 09:22:38 +0200 Subject: [PATCH 234/250] [SPARK-49904][SQL] Assign proper error condition for _LEGACY_ERROR_TEMP_2140 ### What changes were proposed in this pull request? This PR proposes to assign proper error condition & sqlstate for _LEGACY_ERROR_TEMP_2140 ### Why are the changes needed? To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48381 from itholic/legacy_2140. Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../src/main/resources/error/error-conditions.json | 13 +++++++------ .../apache/spark/sql/errors/ExecutionErrors.scala | 4 ++-- .../spark/sql/ScalaReflectionRelationSuite.scala | 8 ++++---- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 6e27561c0f97b..5ffd69c3b1584 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2593,6 +2593,13 @@ }, "sqlState" : "42K0K" }, + "INVALID_JAVA_IDENTIFIER_AS_FIELD_NAME" : { + "message" : [ + " is not a valid identifier of Java and cannot be used as field name", + "." + ], + "sqlState" : "46121" + }, "INVALID_JOIN_TYPE_FOR_JOINWITH" : { "message" : [ "Invalid join type in joinWith: ." @@ -7206,12 +7213,6 @@ "cannot have circular references in class, but got the circular reference of class ." ] }, - "_LEGACY_ERROR_TEMP_2140" : { - "message" : [ - "`` is not a valid identifier of Java and cannot be used as field name", - "." - ] - }, "_LEGACY_ERROR_TEMP_2144" : { "message" : [ "Unable to find constructor for . This could happen if is an interface, or a trait without companion object constructor." 
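For illustration only (not taken from this patch): a minimal, self-contained sketch of the case the updated suite below covers. A field name that is legal in Scala only with backticks is not a valid Java identifier, so encoder derivation fails with the new condition; the object and app names are made up.

```
import org.apache.spark.SparkThrowable
import org.apache.spark.sql.SparkSession

// `abstract` is a valid Scala name only when backquoted, but not a valid Java identifier.
case class InvalidInJava(`abstract`: Int)

object InvalidFieldNameDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("invalid-field-name").getOrCreate()
    import spark.implicits._
    try {
      Seq(InvalidInJava(1)).toDS()
    } catch {
      // Expected after this patch: INVALID_JAVA_IDENTIFIER_AS_FIELD_NAME, with the field
      // reported as a quoted SQL identifier (`abstract`).
      case t: SparkThrowable => println(t.getErrorClass)
    }
    spark.stop()
  }
}
```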
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala index 3527a10496862..907c46f583cf1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala @@ -195,9 +195,9 @@ private[sql] trait ExecutionErrors extends DataTypeErrorsBase { fieldName: String, walkedTypePath: WalkedTypePath): SparkUnsupportedOperationException = { new SparkUnsupportedOperationException( - errorClass = "_LEGACY_ERROR_TEMP_2140", + errorClass = "INVALID_JAVA_IDENTIFIER_AS_FIELD_NAME", messageParameters = - Map("fieldName" -> fieldName, "walkedTypePath" -> walkedTypePath.toString)) + Map("fieldName" -> toSQLId(fieldName), "walkedTypePath" -> walkedTypePath.toString)) } def primaryConstructorNotFoundError(cls: Class[_]): SparkRuntimeException = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 16118526f2fe4..76919d6583106 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -163,9 +163,9 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSparkSession exception = intercept[SparkUnsupportedOperationException] { Seq(InvalidInJava(1)).toDS() }, - condition = "_LEGACY_ERROR_TEMP_2140", + condition = "INVALID_JAVA_IDENTIFIER_AS_FIELD_NAME", parameters = Map( - "fieldName" -> "abstract", + "fieldName" -> "`abstract`", "walkedTypePath" -> "- root class: \"org.apache.spark.sql.InvalidInJava\"")) } @@ -174,9 +174,9 @@ class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSparkSession exception = intercept[SparkUnsupportedOperationException] { Seq(InvalidInJava2(1)).toDS() }, - condition = "_LEGACY_ERROR_TEMP_2140", + condition = "INVALID_JAVA_IDENTIFIER_AS_FIELD_NAME", parameters = Map( - "fieldName" -> "0", + "fieldName" -> "`0`", "walkedTypePath" -> "- root class: \"org.apache.spark.sql.ScalaReflectionRelationSuite.InvalidInJava2\"")) } From a2ad4d4db6f183ebb17cd7d1027525a824a86cb5 Mon Sep 17 00:00:00 2001 From: Max Gekk Date: Mon, 14 Oct 2024 09:24:11 +0200 Subject: [PATCH 235/250] [MINOR][CORE] Fix the regenerate command in `SparkThrowableSuite` ### What changes were proposed in this pull request? Fix the command in `SparkThrowableSuite` which regenerates/re-formats `error-conditions.json`. ### Why are the changes needed? To don't confuses other devs. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? By running the changes command: ``` $ SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "core/testOnly *SparkThrowableSuite -- -t \"Error conditions are correctly formatted\"" ``` ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48432 from MaxGekk/fix-regenerateCommand. 
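As background only, a hedged sketch (not taken from this patch) of the golden-file toggle the documented command drives. The helper name and file handling below are made up, but suites of this style typically branch on the `SPARK_GENERATE_GOLDEN_FILES` environment variable to either regenerate the checked-in file or compare against it.

```
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}

object GoldenFileToggle {
  // Regenerate when SPARK_GENERATE_GOLDEN_FILES=1, otherwise compare against the checked-in file.
  private val regenerate: Boolean = sys.env.get("SPARK_GENERATE_GOLDEN_FILES").contains("1")

  def checkOrRegenerate(goldenPath: String, actual: String): Unit = {
    val path = Paths.get(goldenPath)
    if (regenerate) {
      Files.write(path, actual.getBytes(StandardCharsets.UTF_8))
    } else {
      val expected = new String(Files.readAllBytes(path), StandardCharsets.UTF_8)
      require(expected == actual, s"$goldenPath is stale; rerun with SPARK_GENERATE_GOLDEN_FILES=1")
    }
  }
}
```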
Authored-by: Max Gekk Signed-off-by: Max Gekk --- core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala index 9f005e5757193..ea845c0f93a4b 100644 --- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala @@ -47,7 +47,7 @@ class SparkThrowableSuite extends SparkFunSuite { }}} */ private val regenerateCommand = "SPARK_GENERATE_GOLDEN_FILES=1 build/sbt " + - "\"core/testOnly *SparkThrowableSuite -- -t \\\"Error classes match with document\\\"\"" + "\"core/testOnly *SparkThrowableSuite -- -t \\\"Error conditions are correctly formatted\\\"\"" private val errorJsonFilePath = getWorkspaceFilePath( "common", "utils", "src", "main", "resources", "error", "error-conditions.json") From 560748c2b1482836f1fbc117ebb88cee1371554c Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Mon, 14 Oct 2024 09:33:22 +0200 Subject: [PATCH 236/250] [SPARK-49892][SQL] Assign proper error class for _LEGACY_ERROR_TEMP_1136 ### What changes were proposed in this pull request? This PR proposes to assign proper error condition for _LEGACY_ERROR_TEMP_1136 ### Why are the changes needed? Currently we don't have proper error condition and SQLSTATE when user try saving interval data type into external storage ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48367 from itholic/LEGACY_1136. Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../main/resources/error/error-conditions.json | 5 ----- .../org/apache/spark/sql/avro/AvroSuite.scala | 8 ++++++-- .../sql/errors/QueryCompilationErrors.scala | 6 ------ .../sql/execution/datasources/DataSource.scala | 18 +++++++++++------- .../spark/sql/FileBasedDataSourceSuite.scala | 13 +++++++++++-- 5 files changed, 28 insertions(+), 22 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 5ffd69c3b1584..0a7f0b2845c1f 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -6137,11 +6137,6 @@ " is not a valid Spark SQL Data Source." ] }, - "_LEGACY_ERROR_TEMP_1136" : { - "message" : [ - "Cannot save interval data type into external storage." - ] - }, "_LEGACY_ERROR_TEMP_1137" : { "message" : [ "Unable to resolve given []." 
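For illustration only (not taken from this patch), mirroring the suites updated below: with legacy interval handling enabled, `interval 1 days` is a CalendarInterval, and writing it to a file source now reports `UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE`. The session setup and temp path are made up.

```
import java.nio.file.Files
import org.apache.spark.sql.{AnalysisException, SparkSession}

object IntervalWriteDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("interval-write-demo").getOrCreate()
    // Legacy interval handling makes `interval 1 days` a CalendarInterval, as in the updated tests.
    spark.conf.set("spark.sql.legacy.interval.enabled", "true")
    val target = Files.createTempDirectory("interval-write-demo").toString
    try {
      spark.sql("SELECT interval 1 days").write.mode("overwrite").parquet(target)
    } catch {
      // Expected after this patch: UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE with the format,
      // column name, and column type as message parameters.
      case e: AnalysisException => println(e.getErrorClass)
    }
    spark.stop()
  }
}
```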
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index e9d6c2458df81..0df6a7c4bc90e 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -1673,8 +1673,12 @@ abstract class AvroSuite exception = intercept[AnalysisException] { sql("select interval 1 days").write.format("avro").mode("overwrite").save(tempDir) }, - condition = "_LEGACY_ERROR_TEMP_1136", - parameters = Map.empty + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + parameters = Map( + "format" -> "Avro", + "columnName" -> "`INTERVAL '1 days'`", + "columnType" -> "\"INTERVAL\"" + ) ) checkError( exception = intercept[AnalysisException] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 0e02e4249addd..3d3d9cb70bcf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1691,12 +1691,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat messageParameters = Map("className" -> className)) } - def cannotSaveIntervalIntoExternalStorageError(): Throwable = { - new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1136", - messageParameters = Map.empty) - } - def cannotResolveAttributeError(name: String, outputStr: String): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1137", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 968c204841e46..e4870c9821f64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -514,7 +514,8 @@ case class DataSource( dataSource.createRelation( sparkSession.sqlContext, mode, caseInsensitiveOptions, Dataset.ofRows(sparkSession, data)) case format: FileFormat => - disallowWritingIntervals(outputColumns.map(_.dataType), forbidAnsiIntervals = false) + disallowWritingIntervals( + outputColumns.toStructType.asNullable, format.toString, forbidAnsiIntervals = false) val cmd = planForWritingFileFormat(format, mode, data) val qe = sparkSession.sessionState.executePlan(cmd) qe.assertCommandExecuted() @@ -539,7 +540,7 @@ case class DataSource( } SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => - disallowWritingIntervals(data.schema.map(_.dataType), forbidAnsiIntervals = false) + disallowWritingIntervals(data.schema, format.toString, forbidAnsiIntervals = false) DataSource.validateSchema(data.schema, sparkSession.sessionState.conf) planForWritingFileFormat(format, mode, data) case _ => throw SparkException.internalError( @@ -566,12 +567,15 @@ case class DataSource( } private def disallowWritingIntervals( - dataTypes: Seq[DataType], + outputColumns: Seq[StructField], + format: String, forbidAnsiIntervals: Boolean): Unit = { - dataTypes.foreach( - TypeUtils.invokeOnceForInterval(_, forbidAnsiIntervals) { - throw QueryCompilationErrors.cannotSaveIntervalIntoExternalStorageError() - }) + outputColumns.foreach { field => + 
TypeUtils.invokeOnceForInterval(field.dataType, forbidAnsiIntervals) { + throw QueryCompilationErrors.dataTypeUnsupportedByDataSourceError( + format, field + )} + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index e44bd5de4f4c4..6661c4473c7b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -506,14 +506,23 @@ class FileBasedDataSourceSuite extends QueryTest withSQLConf( SQLConf.USE_V1_SOURCE_LIST.key -> useV1List, SQLConf.LEGACY_INTERVAL_ENABLED.key -> "true") { + val formatMapping = Map( + "csv" -> "CSV", + "json" -> "JSON", + "parquet" -> "Parquet", + "orc" -> "ORC" + ) // write path Seq("csv", "json", "parquet", "orc").foreach { format => checkError( exception = intercept[AnalysisException] { sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) }, - condition = "_LEGACY_ERROR_TEMP_1136", - parameters = Map.empty + condition = "UNSUPPORTED_DATA_TYPE_FOR_DATASOURCE", + parameters = Map( + "format" -> formatMapping(format), + "columnName" -> "`INTERVAL '1 days'`", + "columnType" -> "\"INTERVAL\"") ) } From eeb044ea4a5ded35561e3907b77cb88dd36e278f Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 14 Oct 2024 17:21:28 +0800 Subject: [PATCH 237/250] [SPARK-49949][PS] Avoid unnecessary analyze task in `attach_sequence_column` ### What changes were proposed in this pull request? Avoid unnecessary analyze task in `attach_sequence_column` ### Why are the changes needed? In Connect mode, if the input `sdf` hasn't cache its schema, `attach_sequence_column` will trigger an analyze task for it. However, in this case, the column names are not needed. ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? existing tests ### Was this patch authored or co-authored using generative AI tooling? no Closes #48448 from zhengruifeng/attach_sequence_column. Authored-by: Ruifeng Zheng Signed-off-by: Ruifeng Zheng --- python/pyspark/pandas/internal.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py index 4be345201ba65..6063641e22e3b 100644 --- a/python/pyspark/pandas/internal.py +++ b/python/pyspark/pandas/internal.py @@ -902,11 +902,10 @@ def attach_default_index( @staticmethod def attach_sequence_column(sdf: PySparkDataFrame, column_name: str) -> PySparkDataFrame: - scols = [scol_for(sdf, column) for column in sdf.columns] sequential_index = ( F.row_number().over(Window.orderBy(F.monotonically_increasing_id())).cast("long") - 1 ) - return sdf.select(sequential_index.alias(column_name), *scols) + return sdf.select(sequential_index.alias(column_name), "*") @staticmethod def attach_distributed_column(sdf: PySparkDataFrame, column_name: str) -> PySparkDataFrame: From d77b2932eb4b07e6274f64adcdfccbf8bbb5f565 Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Mon, 14 Oct 2024 11:58:10 +0200 Subject: [PATCH 238/250] add more tests. 
--- .../sql/catalyst/util/CollationFactory.java | 4 +- .../unsafe/types/CollationFactorySuite.scala | 38 +++++++++++++++---- 2 files changed, 33 insertions(+), 9 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 6c6594c0b94af..636fe09b0f3b3 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -547,7 +547,7 @@ protected Collation buildCollation() { BiFunction equalsFunction; boolean supportsSpaceTrimming = spaceTrimming != SpaceTrimming.NONE; - if(spaceTrimming == SpaceTrimming.NONE) { + if (spaceTrimming == SpaceTrimming.NONE) { comparator = UTF8String::binaryCompare; hashFunction = s -> (long) s.hashCode(); equalsFunction = UTF8String::equals; @@ -575,7 +575,7 @@ protected Collation buildCollation() { Comparator comparator; ToLongFunction hashFunction; - if (spaceTrimming == SpaceTrimming.NONE ) { + if (spaceTrimming == SpaceTrimming.NONE) { comparator = CollationAwareUTF8String::compareLowerCase; hashFunction = s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(); diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 88ef9a3c2d83f..f6ac21d951f83 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -127,7 +127,10 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY", "aaa", "AAA", false), CollationTestCase("UTF8_BINARY", "aaa", "bbb", false), CollationTestCase("UTF8_BINARY", "å", "a\u030A", false), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa", true), CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa ", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", " aaa ", false), CollationTestCase("UTF8_LCASE", "aaa", "aaa", true), CollationTestCase("UTF8_LCASE", "aaa", "AAA", true), CollationTestCase("UTF8_LCASE", "aaa", "AaA", true), @@ -135,18 +138,27 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_LCASE", "aaa", "aa", false), CollationTestCase("UTF8_LCASE", "aaa", "bbb", false), CollationTestCase("UTF8_LCASE", "å", "a\u030A", false), - CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AaA", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA ", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AaA ", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", " AaA ", false), CollationTestCase("UNICODE", "aaa", "aaa", true), CollationTestCase("UNICODE", "aaa", "AAA", false), CollationTestCase("UNICODE", "aaa", "bbb", false), CollationTestCase("UNICODE", "å", "a\u030A", true), + CollationTestCase("UNICODE_RTRIM", "aaa", "aaa", true), CollationTestCase("UNICODE_RTRIM", "aaa", "aaa ", true), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", true), + CollationTestCase("UNICODE_RTRIM", "aaa", " aaa ", false), CollationTestCase("UNICODE_CI", "aaa", "aaa", true), CollationTestCase("UNICODE_CI", "aaa", "AAA", true), CollationTestCase("UNICODE_CI", "aaa", 
"bbb", false), CollationTestCase("UNICODE_CI", "å", "a\u030A", true), CollationTestCase("UNICODE_CI", "Å", "a\u030A", true), - CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", true) + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA ", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AaA ", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", " AaA ", false) ) checks.foreach(testCase => { @@ -167,8 +179,11 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY", "aaa", "bbb", -1), CollationTestCase("UTF8_BINARY", "aaa", "BBB", 1), CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa", 0), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", 0), CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb", -1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb ", -1), CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "BBB" , 1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "BBB " , 1), CollationTestCase("UTF8_LCASE", "aaa", "aaa", 0), CollationTestCase("UTF8_LCASE", "aaa", "AAA", 0), CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0), @@ -176,21 +191,30 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_LCASE", "aaa", "aa", 1), CollationTestCase("UTF8_LCASE", "aaa", "bbb", -1), CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA", 0), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA ", 0), CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "bbb ", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "bbb ", -1), CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa", 1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa ", 1), CollationTestCase("UNICODE", "aaa", "aaa", 0), CollationTestCase("UNICODE", "aaa", "AAA", -1), CollationTestCase("UNICODE", "aaa", "bbb", -1), CollationTestCase("UNICODE", "aaa", "BBB", -1), - CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", 0), - CollationTestCase("UNICODE_RTRIM", "aaa", "AAA ", -1), - CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa", 0), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", 0), + CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb", -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_RTRIM", "aaa", "BBB" , -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "BBB " , -1), CollationTestCase("UNICODE_CI", "aaa", "aaa", 0), CollationTestCase("UNICODE_CI", "aaa", "AAA", 0), CollationTestCase("UNICODE_CI", "aaa", "bbb", -1), - CollationTestCase("UNICODE_CI_RTRIM", "aaa", "aaa ", 0), CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", 0), - CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1) + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA ", 0), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "bbb ", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa", 1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa ", 1) ) checks.foreach(testCase => { From 7e781775dd284bf3af91c8b621117cead681011e Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Mon, 14 Oct 2024 12:08:33 +0200 Subject: [PATCH 239/250] nit fixes. 
--- .../org/apache/spark/sql/types/StringType.scala | 3 +++ .../spark/sql/CollationSQLExpressionsSuite.scala | 12 ++---------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index e07471a15b6a3..1c93c2ad550e9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -41,6 +41,9 @@ class StringType private (val collationId: Int) extends AtomicType with Serializ private[sql] def supportsBinaryEquality: Boolean = CollationFactory.fetchCollation(collationId).supportsBinaryEquality + private[sql] def supportsLowercaseEquality: Boolean = + CollationFactory.fetchCollation(collationId).supportsLowercaseEquality + private[sql] def isNonCSAI: Boolean = !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index ac8ad69dd55d6..ce6818652d2b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -497,13 +497,9 @@ class CollationSQLExpressionsSuite val testCases = Seq( BinTestCase("13", "UTF8_BINARY", "1101"), - BinTestCase("13", "UTF8_BINARY_RTRIM", "1101"), BinTestCase("13", "UTF8_LCASE", "1101"), - BinTestCase("13", "UTF8_LCASE_RTRIM", "1101"), BinTestCase("13", "UNICODE", "1101"), - BinTestCase("13", "UNICODE_RTRIM", "1101"), - BinTestCase("13", "UNICODE_CI", "1101"), - BinTestCase("13", "UNICODE_CI_RTRIM", "1101") + BinTestCase("13", "UNICODE_CI", "1101") ) testCases.foreach(t => { val query = @@ -526,13 +522,9 @@ class CollationSQLExpressionsSuite val testCases = Seq( HexTestCase("13", "UTF8_BINARY", "D"), - HexTestCase("13", "UTF8_BINARY_RTRIM", "D"), HexTestCase("13", "UTF8_LCASE", "D"), - HexTestCase("13", "UTF8_LCASE_RTRIM", "D"), HexTestCase("13", "UNICODE", "D"), - HexTestCase("13", "UNICODE_RTRIM", "D"), - HexTestCase("13", "UNICODE_CI", "D"), - HexTestCase("13", "UNICODE_CI_RTRIM", "D") + HexTestCase("13", "UNICODE_CI", "D") ) testCases.foreach(t => { val query = From af27d43000f1f40b494c47155127fefda4dc03de Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Mon, 14 Oct 2024 15:10:06 +0200 Subject: [PATCH 240/250] nit fixes. 
--- .../spark/sql/catalyst/util/CollationFactory.java | 2 +- .../spark/unsafe/types/CollationFactorySuite.scala | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 636fe09b0f3b3..01f6c7e0331b0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -961,7 +961,7 @@ protected Collation buildCollation() { Comparator comparator; ToLongFunction hashFunction; - if (spaceTrimming == SpaceTrimming.NONE){ + if (spaceTrimming == SpaceTrimming.NONE) { hashFunction = s -> (long) collator.getCollationKey( s.toValidString()).hashCode(); comparator = (s1, s2) -> diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index f6ac21d951f83..039babcbb01c3 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -131,6 +131,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa ", true), CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", true), CollationTestCase("UTF8_BINARY_RTRIM", "aaa", " aaa ", false), + CollationTestCase("UTF8_BINARY_RTRIM", " ", " ", true), CollationTestCase("UTF8_LCASE", "aaa", "aaa", true), CollationTestCase("UTF8_LCASE", "aaa", "AAA", true), CollationTestCase("UTF8_LCASE", "aaa", "AaA", true), @@ -142,6 +143,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA ", true), CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AaA ", true), CollationTestCase("UTF8_LCASE_RTRIM", "aaa", " AaA ", false), + CollationTestCase("UTF8_LCASE_RTRIM", " ", " ", true), CollationTestCase("UNICODE", "aaa", "aaa", true), CollationTestCase("UNICODE", "aaa", "AAA", false), CollationTestCase("UNICODE", "aaa", "bbb", false), @@ -150,6 +152,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_RTRIM", "aaa", "aaa ", true), CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", true), CollationTestCase("UNICODE_RTRIM", "aaa", " aaa ", false), + CollationTestCase("UNICODE_RTRIM", " ", " ", true), CollationTestCase("UNICODE_CI", "aaa", "aaa", true), CollationTestCase("UNICODE_CI", "aaa", "AAA", true), CollationTestCase("UNICODE_CI", "aaa", "bbb", false), @@ -158,7 +161,8 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA", true), CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA ", true), CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AaA ", true), - CollationTestCase("UNICODE_CI_RTRIM", "aaa", " AaA ", false) + CollationTestCase("UNICODE_CI_RTRIM", "aaa", " AaA ", false), + CollationTestCase("UNICODE_RTRIM", " ", " ", true) ) checks.foreach(testCase => { @@ -184,6 +188,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb ", -1), CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "BBB" , 1), CollationTestCase("UTF8_BINARY_RTRIM", 
"aaa ", "BBB " , 1), + CollationTestCase("UTF8_BINARY_RTRIM", " ", " " , 0), CollationTestCase("UTF8_LCASE", "aaa", "aaa", 0), CollationTestCase("UTF8_LCASE", "aaa", "AAA", 0), CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0), @@ -196,6 +201,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "bbb ", -1), CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa", 1), CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa ", 1), + CollationTestCase("UTF8_LCASE_RTRIM", " ", " ", 0), CollationTestCase("UNICODE", "aaa", "aaa", 0), CollationTestCase("UNICODE", "aaa", "AAA", -1), CollationTestCase("UNICODE", "aaa", "bbb", -1), @@ -206,6 +212,7 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb ", -1), CollationTestCase("UNICODE_RTRIM", "aaa", "BBB" , -1), CollationTestCase("UNICODE_RTRIM", "aaa ", "BBB " , -1), + CollationTestCase("UNICODE_RTRIM", " ", " ", 0), CollationTestCase("UNICODE_CI", "aaa", "aaa", 0), CollationTestCase("UNICODE_CI", "aaa", "AAA", 0), CollationTestCase("UNICODE_CI", "aaa", "bbb", -1), @@ -214,7 +221,8 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UNICODE_CI_RTRIM", "aaa", "bbb ", -1), CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1), CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa", 1), - CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa ", 1) + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa ", 1), + CollationTestCase("UNICODE_CI_RTRIM", " ", " ", 0) ) checks.foreach(testCase => { From d53fa23e4a1e7a680ed9f845540b180045d0f452 Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Mon, 14 Oct 2024 15:26:25 +0200 Subject: [PATCH 241/250] [SPARK-49952][SQL] Assign proper error condition for _LEGACY_ERROR_TEMP_1142 ### What changes were proposed in this pull request? This PR proposes to assign proper error condition & sqlstate for _LEGACY_ERROR_TEMP_1142 ### Why are the changes needed? To improve the error message by assigning proper error condition and SQLSTATE ### Does this PR introduce _any_ user-facing change? No, only user-facing error message improved ### How was this patch tested? Updated the existing tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #48451 from itholic/SPARK-49952. Authored-by: Haejoon Lee Signed-off-by: Max Gekk --- .../main/resources/error/error-conditions.json | 11 ++++++----- .../spark/sql/errors/QueryCompilationErrors.scala | 6 +++--- .../sql/execution/datasources/DataSource.scala | 6 +++--- .../sql/execution/datasources/v2/FileWrite.scala | 2 +- .../spark/sql/FileBasedDataSourceSuite.scala | 15 +++++++++++---- 5 files changed, 24 insertions(+), 16 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 0a7f0b2845c1f..8b2a57d6da3dd 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -1212,6 +1212,12 @@ ], "sqlState" : "42604" }, + "EMPTY_SCHEMA_NOT_SUPPORTED_FOR_DATASOURCE" : { + "message" : [ + "The datasource does not support writing empty or nested empty schemas. Please make sure the data schema has at least one or more column(s)." 
+ ], + "sqlState" : "0A000" + }, "ENCODER_NOT_FOUND" : { "message" : [ "Not found an encoder of the type to Spark SQL internal representation.", @@ -6162,11 +6168,6 @@ "Multiple sources found for (), please specify the fully qualified class name." ] }, - "_LEGACY_ERROR_TEMP_1142" : { - "message" : [ - "Datasource does not support writing empty or nested empty schemas. Please make sure the data schema has at least one or more column(s)." - ] - }, "_LEGACY_ERROR_TEMP_1143" : { "message" : [ "The data to be inserted needs to have the same number of columns as the target table: target table has column(s) but the inserted data has column(s), which contain partition column(s) having assigned constant values." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 3d3d9cb70bcf3..9dc15c4a1b78d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1723,10 +1723,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "sourceNames" -> sourceNames.mkString(", "))) } - def writeEmptySchemasUnsupportedByDataSourceError(): Throwable = { + def writeEmptySchemasUnsupportedByDataSourceError(format: String): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1142", - messageParameters = Map.empty) + errorClass = "EMPTY_SCHEMA_NOT_SUPPORTED_FOR_DATASOURCE", + messageParameters = Map("format" -> format)) } def insertMismatchedColumnNumberError( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index e4870c9821f64..3698dc2f0808e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -541,7 +541,7 @@ case class DataSource( SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode) case format: FileFormat => disallowWritingIntervals(data.schema, format.toString, forbidAnsiIntervals = false) - DataSource.validateSchema(data.schema, sparkSession.sessionState.conf) + DataSource.validateSchema(format.toString, data.schema, sparkSession.sessionState.conf) planForWritingFileFormat(format, mode, data) case _ => throw SparkException.internalError( s"${providingClass.getCanonicalName} does not allow create table as select.") @@ -842,7 +842,7 @@ object DataSource extends Logging { * @param schema * @param conf */ - def validateSchema(schema: StructType, conf: SQLConf): Unit = { + def validateSchema(formatName: String, schema: StructType, conf: SQLConf): Unit = { val shouldAllowEmptySchema = conf.getConf(SQLConf.ALLOW_EMPTY_SCHEMAS_FOR_WRITES) def hasEmptySchema(schema: StructType): Boolean = { schema.size == 0 || schema.exists { @@ -853,7 +853,7 @@ object DataSource extends Logging { if (!shouldAllowEmptySchema && hasEmptySchema(schema)) { - throw QueryCompilationErrors.writeEmptySchemasUnsupportedByDataSourceError() + throw QueryCompilationErrors.writeEmptySchemasUnsupportedByDataSourceError(formatName) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala index 
cdcf6f21fd008..f4cabcb69d08c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala @@ -96,7 +96,7 @@ trait FileWrite extends Write { SchemaUtils.checkColumnNameDuplication( schema.fields.map(_.name).toImmutableArraySeq, caseSensitiveAnalysis) } - DataSource.validateSchema(schema, sqlConf) + DataSource.validateSchema(formatName, schema, sqlConf) // TODO: [SPARK-36340] Unify check schema filed of DataSource V2 Insert. schema.foreach { field => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 6661c4473c7b9..9c529d1422119 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -128,13 +128,20 @@ class FileBasedDataSourceSuite extends QueryTest allFileBasedDataSources.foreach { format => test(s"SPARK-23372 error while writing empty schema files using $format") { + val formatMapping = Map( + "csv" -> "CSV", + "json" -> "JSON", + "parquet" -> "Parquet", + "orc" -> "ORC", + "text" -> "Text" + ) withTempPath { outputPath => checkError( exception = intercept[AnalysisException] { spark.emptyDataFrame.write.format(format).save(outputPath.toString) }, - condition = "_LEGACY_ERROR_TEMP_1142", - parameters = Map.empty + condition = "EMPTY_SCHEMA_NOT_SUPPORTED_FOR_DATASOURCE", + parameters = Map("format" -> formatMapping(format)) ) } @@ -150,8 +157,8 @@ class FileBasedDataSourceSuite extends QueryTest exception = intercept[AnalysisException] { df.write.format(format).save(outputPath.toString) }, - condition = "_LEGACY_ERROR_TEMP_1142", - parameters = Map.empty + condition = "EMPTY_SCHEMA_NOT_SUPPORTED_FOR_DATASOURCE", + parameters = Map("format" -> formatMapping(format)) ) } } From 0606512065b26311ced2de8037366adfc80e2b1d Mon Sep 17 00:00:00 2001 From: Mihailo Milosevic Date: Mon, 14 Oct 2024 16:12:46 +0200 Subject: [PATCH 242/250] [SPARK-49864][SQL][FOLLOW-UP] Fix default suggestion for binary arithmetic overflow ### What changes were proposed in this pull request? Improvement on default branch for try suggestion. ### Why are the changes needed? When we hit default branch in codeGen, we need to return a default value that would specify that we do not know the function, and not just a blank string. ### Does this PR introduce _any_ user-facing change? Yes. ### How was this patch tested? No branch hits this behaviour so far, but we need to safeguard from the possible errors. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48450 from mihailom-db/binaryArithmeticOverflowFollowup. 
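For illustration only (not taken from this patch): a minimal sketch of the suggestion this hint feeds into. Under ANSI mode an overflowing `+` on TINYINT values reports `try_add` as the alternative, and with this change any operator outside `+`, `-`, `*` falls back to `unknown_function` rather than an empty string. The local session setup is assumed.

```
import org.apache.spark.SparkThrowable
import org.apache.spark.sql.SparkSession

object OverflowHintDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("overflow-hint-demo").getOrCreate()
    spark.conf.set("spark.sql.ansi.enabled", "true")
    try {
      spark.sql("SELECT 127Y + 1Y").collect()  // TINYINT overflow under ANSI mode
    } catch {
      // The error message suggests the matching try_* function ("+" maps to try_add).
      case t: SparkThrowable => println(t.getMessage)
    }
    // The suggested alternative tolerates the overflow and returns NULL.
    spark.sql("SELECT try_add(127Y, 1Y)").show()
    spark.stop()
  }
}
```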
Authored-by: Mihailo Milosevic Signed-off-by: Max Gekk --- .../apache/spark/sql/catalyst/expressions/arithmetic.scala | 2 +- .../main/scala/org/apache/spark/sql/types/numerics.scala | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d8ba1fe840bd0..497fdc0936267 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -298,7 +298,7 @@ abstract class BinaryArithmetic extends BinaryOperator case "+" => "try_add" case "-" => "try_subtract" case "*" => "try_multiply" - case _ => "" + case _ => "unknown_function" } val overflowCheck = if (failOnError) { val javaType = CodeGenerator.boxedType(dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala index 1c860e61973c6..ccd9ed209f92a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala @@ -55,7 +55,12 @@ private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOr private[sql] object ShortExactNumeric extends ShortIsIntegral with Ordering.ShortOrdering { - private def checkOverflow(res: Int, x: Short, y: Short, op: String, hint: String): Unit = { + private def checkOverflow( + res: Int, + x: Short, + y: Short, + op: String, + hint: String = "unknown_function"): Unit = { if (res > Short.MaxValue || res < Short.MinValue) { throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(x, op, y, hint) } From 96c4953680988739e26b860b24d3966e3cc1cb1f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 14 Oct 2024 16:40:18 +0200 Subject: [PATCH 243/250] [SPARK-49955][SQL] null value does not mean corrupted file when parsing JSON string RDD ### What changes were proposed in this pull request? This is a followup of https://github.com/apache/spark/pull/42979 , to fix a regression. For the `spark.read.json(rdd)` API, there is never corrupted file, and we should not fail if the string value is null with non-failfast parsing mode. This PR is a partial revert of https://github.com/apache/spark/pull/42979 , to not treat `RuntimeException` as corrupted file when we are not reading from files. ### Why are the changes needed? A query used to work in 3.5 should still work in 4.0 ### Does this PR introduce _any_ user-facing change? no as this regression is not released yet. ### How was this patch tested? new test ### Was this patch authored or co-authored using generative AI tooling? no Closes #48453 from cloud-fan/json. 
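For illustration only (not taken from this patch), mirroring the new test below: a null element in a `Dataset[String]` is a malformed record rather than a corrupted file, so the default PERMISSIVE mode keeps working and only FAILFAST rejects it. The local session setup is assumed.

```
import org.apache.spark.sql.SparkSession

object JsonNullElementDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("json-null-element-demo").getOrCreate()
    import spark.implicits._

    val ds = Seq("""{"name": "someone"}""", null).toDS()

    // PERMISSIVE (default): the null element becomes a row of nulls instead of failing the query.
    spark.read.json(ds).show()

    // FAILFAST still rejects the malformed (null) record during schema inference.
    try spark.read.option("mode", "failfast").json(ds).show()
    catch { case e: Exception => println(s"FAILFAST rejected it: ${e.getClass.getSimpleName}") }

    spark.stop()
  }
}
```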
Lead-authored-by: Wenchen Fan Co-authored-by: Wenchen Fan Signed-off-by: Max Gekk --- .../sql/catalyst/json/JsonInferSchema.scala | 6 +++- .../datasources/json/JsonDataSource.scala | 3 +- .../csv/CSVParsingOptionsSuite.scala | 35 +++++++++++++++++++ .../json/JsonParsingOptionsSuite.scala | 11 ++++++ 4 files changed, 53 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParsingOptionsSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala index d982e1f19da0c..9c291634401ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala @@ -81,7 +81,8 @@ class JsonInferSchema(options: JSONOptions) extends Serializable with Logging { */ def infer[T]( json: RDD[T], - createParser: (JsonFactory, T) => JsonParser): StructType = { + createParser: (JsonFactory, T) => JsonParser, + isReadFile: Boolean = false): StructType = { val parseMode = options.parseMode val columnNameOfCorruptRecord = options.columnNameOfCorruptRecord @@ -96,6 +97,9 @@ class JsonInferSchema(options: JSONOptions) extends Serializable with Logging { Some(inferField(parser)) } } catch { + // If we are not reading from files but hit `RuntimeException`, it means corrupted record. + case e: RuntimeException if !isReadFile => + handleJsonErrorsByParseMode(parseMode, columnNameOfCorruptRecord, e) case e @ (_: JsonProcessingException | _: MalformedInputException) => handleJsonErrorsByParseMode(parseMode, columnNameOfCorruptRecord, e) case e: CharConversionException if options.encoding.isEmpty => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala index 7c98c31bba220..cb4c4f5290880 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala @@ -164,7 +164,8 @@ object MultiLineJsonDataSource extends JsonDataSource { .getOrElse(createParser(_: JsonFactory, _: PortableDataStream)) SQLExecution.withSQLConfPropagated(sparkSession) { - new JsonInferSchema(parsedOptions).infer[PortableDataStream](sampled, parser) + new JsonInferSchema(parsedOptions) + .infer[PortableDataStream](sampled, parser, isReadFile = true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParsingOptionsSuite.scala new file mode 100644 index 0000000000000..8c8304503cef8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParsingOptionsSuite.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.test.SharedSparkSession + +class CSVParsingOptionsSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + test("SPARK-49955: null string value does not mean corrupted file") { + val str = "abc" + val stringDataset = Seq(str, null).toDS() + val df = spark.read.csv(stringDataset) + // `spark.read.csv(rdd)` removes all null values at the beginning. + checkAnswer(df, Seq(Row("abc"))) + val df2 = spark.read.option("mode", "failfast").csv(stringDataset) + checkAnswer(df2, Seq(Row("abc"))) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala index 703085dca66f1..11cc0b99bbde7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.json +import org.apache.spark.SparkException import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{StringType, StructType} @@ -185,4 +186,14 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSparkSession { assert(df.first().getString(0) == "Cazen Lee") assert(df.first().getString(1) == "$10") } + + test("SPARK-49955: null string value does not mean corrupted file") { + val str = "{\"name\": \"someone\"}" + val stringDataset = Seq(str, null).toDS() + val df = spark.read.json(stringDataset) + checkAnswer(df, Seq(Row(null, "someone"), Row(null, null))) + + val e = intercept[SparkException](spark.read.option("mode", "failfast").json(stringDataset)) + assert(e.getCause.isInstanceOf[NullPointerException]) + } } From 1d73ad6bc92a4811b16292d263a6fe9c1ad7b68e Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Mon, 14 Oct 2024 18:46:49 +0200 Subject: [PATCH 244/250] init commit. 
--- .../util/CollationAwareUTF8String.java | 4 +- .../sql/catalyst/util/CollationFactory.java | 54 ++++---- .../sql/catalyst/util/CollationSupport.java | 122 +++++++++--------- .../unsafe/types/CollationFactorySuite.scala | 8 +- .../spark/sql/catalyst/expressions/hash.scala | 6 +- .../sql/catalyst/util/UnsafeRowUtils.scala | 2 +- .../expressions/HashExpressionsSuite.scala | 2 +- .../aggregate/HashMapGenerator.scala | 4 +- .../org/apache/spark/sql/CollationSuite.scala | 12 +- 9 files changed, 111 insertions(+), 103 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index fb610a5d96f17..d67697eaea38b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -1363,9 +1363,9 @@ public static UTF8String trimRight( public static UTF8String[] splitSQL(final UTF8String input, final UTF8String delim, final int limit, final int collationId) { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + if (CollationFactory.fetchCollation(collationId).isUtf8BinaryType) { return input.split(delim, limit); - } else if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { + } else if (CollationFactory.fetchCollation(collationId).isUtf8LcaseType) { return lowercaseSplitSQL(input, delim, limit); } else { return icuSplitSQL(input, delim, limit, collationId); diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 01f6c7e0331b0..b9dce20cf0345 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -154,12 +154,25 @@ public static class Collation { */ public final boolean supportsLowercaseEquality; + /** * Support for Space Trimming implies that that based on specifier (for now only right trim) * leading, trailing or both spaces are removed from the input string before comparison. */ public final boolean supportsSpaceTrimming; + /** + * Is Utf8 binary type as indicator if collation base type is UTF8 binary. Note currently only + * collations Utf8_Binary and Utf8_Binary_RTRIM are considered as Utf8 binary type. + */ + public final boolean isUtf8BinaryType; + + /** + * Is Utf8 lcase type as indicator if collation base type is UTF8 lcase. Note currently only + * collations Utf8_Lcase and Utf8_Lcase_RTRIM are considered as Utf8 Lcase type. 
+ */ + public final boolean isUtf8LcaseType; + public Collation( String collationName, String provider, @@ -168,9 +181,8 @@ public Collation( String version, ToLongFunction hashFunction, BiFunction equalsFunction, - boolean supportsBinaryEquality, - boolean supportsBinaryOrdering, - boolean supportsLowercaseEquality, + boolean isUtf8BinaryType, + boolean isUtf8LcaseType, boolean supportsSpaceTrimming) { this.collationName = collationName; this.provider = provider; @@ -178,16 +190,15 @@ public Collation( this.comparator = comparator; this.version = version; this.hashFunction = hashFunction; - this.supportsBinaryEquality = supportsBinaryEquality; - this.supportsBinaryOrdering = supportsBinaryOrdering; - this.supportsLowercaseEquality = supportsLowercaseEquality; + this.isUtf8BinaryType = isUtf8BinaryType; + this.isUtf8LcaseType = isUtf8LcaseType; this.equalsFunction = equalsFunction; this.supportsSpaceTrimming = supportsSpaceTrimming; - - // De Morgan's Law to check supportsBinaryOrdering => supportsBinaryEquality - assert(!supportsBinaryOrdering || supportsBinaryEquality); + this.supportsBinaryEquality = !supportsSpaceTrimming && isUtf8BinaryType; + this.supportsBinaryOrdering = !supportsSpaceTrimming && isUtf8BinaryType; + this.supportsLowercaseEquality = !supportsSpaceTrimming && isUtf8LcaseType; // No Collation can simultaneously support binary equality and lowercase equality - assert(!supportsBinaryEquality || !supportsLowercaseEquality); + assert(!isUtf8BinaryType || !isUtf8LcaseType); assert(SUPPORTED_PROVIDERS.contains(provider)); } @@ -567,9 +578,8 @@ protected Collation buildCollation() { "1.0", hashFunction, equalsFunction, - /* supportsBinaryEquality = */ true, - /* supportsBinaryOrdering = */ true, - /* supportsLowercaseEquality = */ false, + /* isUtf8BinaryType = */ true, + /* isUtf8LcaseType = */ false, spaceTrimming != SpaceTrimming.NONE); } else { Comparator comparator; @@ -595,9 +605,8 @@ protected Collation buildCollation() { "1.0", hashFunction, (s1, s2) -> comparator.compare(s1, s2) == 0, - /* supportsBinaryEquality = */ false, - /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ true, + /* isUtf8BinaryType = */ false, + /* isUtf8LcaseType = */ true, spaceTrimming != SpaceTrimming.NONE); } } @@ -982,9 +991,8 @@ protected Collation buildCollation() { ICU_COLLATOR_VERSION, hashFunction, (s1, s2) -> comparator.compare(s1, s2) == 0, - /* supportsBinaryEquality = */ false, - /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ false, + /* isUtf8BinaryType = */ false, + /* isUtf8LcaseType = */ false, spaceTrimming != SpaceTrimming.NONE); } @@ -1191,9 +1199,9 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) { if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return input; - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return CollationAwareUTF8String.lowerCaseCodePoints(input); } else { CollationKey collationKey = collation.collator.getCollationKey( @@ -1207,9 +1215,9 @@ public static byte[] getCollationKeyBytes(UTF8String input, int collationId) { if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return input.getBytes(); - } else if (collation.supportsLowercaseEquality) { + } else 
if (collation.isUtf8LcaseType) { return CollationAwareUTF8String.lowerCaseCodePoints(input).getBytes(); } else { return collation.collator.getCollationKey( diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index f05d9e512568f..978b663cc25c9 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -37,9 +37,9 @@ public final class CollationSupport { public static class StringSplitSQL { public static UTF8String[] exec(final UTF8String s, final UTF8String d, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(s, d); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(s, d); } else { return execICU(s, d, collationId); @@ -48,9 +48,9 @@ public static UTF8String[] exec(final UTF8String s, final UTF8String d, final in public static String genCode(final String s, final String d, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringSplitSQL.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", s, d); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", s, d); } else { return String.format(expr + "ICU(%s, %s, %d)", s, d, collationId); @@ -71,9 +71,9 @@ public static UTF8String[] execICU(final UTF8String string, final UTF8String del public static class Contains { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ -82,9 +82,9 @@ public static boolean exec(final UTF8String l, final UTF8String r, final int col public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Contains.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -109,9 +109,9 @@ public static class StartsWith { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ -120,9 
+120,9 @@ public static boolean exec(final UTF8String l, final UTF8String r, public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StartsWith.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -146,9 +146,9 @@ public static boolean execICU(final UTF8String l, final UTF8String r, public static class EndsWith { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ -157,9 +157,9 @@ public static boolean exec(final UTF8String l, final UTF8String r, final int col public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.EndsWith.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -184,9 +184,9 @@ public static boolean execICU(final UTF8String l, final UTF8String r, public static class Upper { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -195,10 +195,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Upper.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? 
"BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -221,9 +221,9 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class Lower { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -232,10 +232,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Lower.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? "BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -258,9 +258,9 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class InitCap { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -270,10 +270,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.InitCap.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? 
"BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -296,7 +296,7 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class FindInSet { public static int exec(final UTF8String word, final UTF8String set, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(word, set); } else { return execCollationAware(word, set, collationId); @@ -305,7 +305,7 @@ public static int exec(final UTF8String word, final UTF8String set, final int co public static String genCode(final String word, final String set, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.FindInSet.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", word, set); } else { return String.format(expr + "CollationAware(%s, %s, %d)", word, set, collationId); @@ -324,9 +324,9 @@ public static class StringInstr { public static int exec(final UTF8String string, final UTF8String substring, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, substring); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, substring); } else { return execICU(string, substring, collationId); @@ -336,9 +336,9 @@ public static String genCode(final String string, final String substring, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringInstr.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", string, substring); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", string, substring); } else { return String.format(expr + "ICU(%s, %s, %d)", string, substring, collationId); @@ -360,9 +360,9 @@ public static class StringReplace { public static UTF8String exec(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(src, search, replace); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(src, search, replace); } else { return execICU(src, search, replace, collationId); @@ -372,9 +372,9 @@ public static String genCode(final String src, final String search, final String final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringReplace.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %s)", src, search, replace); - } else if 
(collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %s)", src, search, replace); } else { return String.format(expr + "ICU(%s, %s, %s, %d)", src, search, replace, collationId); @@ -398,9 +398,9 @@ public static class StringLocate { public static int exec(final UTF8String string, final UTF8String substring, final int start, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, substring, start); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, substring, start); } else { return execICU(string, substring, start, collationId); @@ -410,9 +410,9 @@ public static String genCode(final String string, final String substring, final final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringLocate.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %d)", string, substring, start); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %d)", string, substring, start); } else { return String.format(expr + "ICU(%s, %s, %d, %d)", string, substring, start, collationId); @@ -436,9 +436,9 @@ public static class SubstringIndex { public static UTF8String exec(final UTF8String string, final UTF8String delimiter, final int count, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, delimiter, count); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, delimiter, count); } else { return execICU(string, delimiter, count, collationId); @@ -448,9 +448,9 @@ public static String genCode(final String string, final String delimiter, final String count, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.SubstringIndex.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %s)", string, delimiter, count); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %s)", string, delimiter, count); } else { return String.format(expr + "ICU(%s, %s, %s, %d)", string, delimiter, count, collationId); @@ -474,9 +474,9 @@ public static class StringTranslate { public static UTF8String exec(final UTF8String source, Map dict, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(source, dict); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(source, dict); } else { return execICU(source, dict, collationId); @@ -503,9 +503,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - 
if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -520,9 +520,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrim.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); @@ -559,9 +559,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -576,9 +576,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrimLeft.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); @@ -614,9 +614,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -631,9 +631,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrimRight.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); @@ -669,7 +669,7 @@ public static UTF8String execICU( public static boolean supportsLowercaseRegex(final int collationId) { // for regex, only Unicode case-insensitive matching is possible, // so UTF8_LCASE is treated as UNICODE_CI in this context - return CollationFactory.fetchCollation(collationId).supportsLowercaseEquality; + return 
CollationFactory.fetchCollation(collationId).isUtf8LcaseType; } static final int lowercaseRegexFlags = Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE; diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 039babcbb01c3..4672c39d9be8a 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -38,22 +38,22 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig assert(UTF8_BINARY_COLLATION_ID == 0) val utf8Binary = fetchCollation(UTF8_BINARY_COLLATION_ID) assert(utf8Binary.collationName == "UTF8_BINARY") - assert(utf8Binary.supportsBinaryEquality) + assert(utf8Binary.isUtf8BinaryType) assert(UTF8_LCASE_COLLATION_ID == 1) val utf8Lcase = fetchCollation(UTF8_LCASE_COLLATION_ID) assert(utf8Lcase.collationName == "UTF8_LCASE") - assert(!utf8Lcase.supportsBinaryEquality) + assert(!utf8Lcase.isUtf8BinaryType) assert(UNICODE_COLLATION_ID == (1 << 29)) val unicode = fetchCollation(UNICODE_COLLATION_ID) assert(unicode.collationName == "UNICODE") - assert(!unicode.supportsBinaryEquality) + assert(!unicode.isUtf8BinaryType) assert(UNICODE_CI_COLLATION_ID == ((1 << 29) | (1 << 17))) val unicodeCi = fetchCollation(UNICODE_CI_COLLATION_ID) assert(unicodeCi.collationName == "UNICODE_CI") - assert(!unicodeCi.supportsBinaryEquality) + assert(!unicodeCi.isUtf8BinaryType) } test("UTF8_BINARY and ICU root locale collation names") { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 7128190902550..3a667f370428e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -415,7 +415,7 @@ abstract class HashExpression[E] extends Expression { protected def genHashString( ctx: CodegenContext, stringType: StringType, input: String, result: String): String = { - if (stringType.supportsBinaryEquality && !stringType.usesTrimCollation) { + if (stringType.supportsBinaryEquality) { val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" @@ -566,7 +566,7 @@ abstract class InterpretedHashFunction { hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) case s: UTF8String => val st = dataType.asInstanceOf[StringType] - if (st.supportsBinaryEquality && !st.usesTrimCollation) { + if (st.supportsBinaryEquality) { hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) } else { val stringHash = CollationFactory @@ -817,7 +817,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def genHashString( ctx: CodegenContext, stringType: StringType, input: String, result: String): String = { - if (stringType.supportsBinaryEquality && !stringType.usesTrimCollation) { + if (stringType.supportsBinaryEquality) { val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala index 40b8bccafaad2..118dd92c3ed54 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala @@ -207,7 +207,7 @@ object UnsafeRowUtils { def isBinaryStable(dataType: DataType): Boolean = !dataType.existsRecursively { case st: StringType => val collation = CollationFactory.fetchCollation(st.collationId) - (!collation.supportsBinaryEquality || collation.supportsSpaceTrimming) + (!collation.supportsBinaryEquality) case _ => false } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 6f3890cafd2ac..92ef24bb8ec63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -636,7 +636,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(murmur3Hash1, interpretedHash1) checkEvaluation(murmur3Hash2, interpretedHash2) - if (CollationFactory.fetchCollation(collation).supportsBinaryEquality) { + if (CollationFactory.fetchCollation(collation).isUtf8BinaryType) { assert(interpretedHash1 != interpretedHash2) } else { assert(interpretedHash1 == interpretedHash2) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 3b1f349520f39..19a36483abe6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -173,9 +173,9 @@ abstract class HashMapGenerator( ${hashBytes(bytes)} """ } - case st: StringType if st.supportsBinaryEquality && !st.usesTrimCollation => + case st: StringType if st.supportsBinaryEquality => hashBytes(s"$input.getBytes()") - case st: StringType if !st.supportsBinaryEquality || st.usesTrimCollation => + case st: StringType if !st.supportsBinaryEquality => hashLong(s"CollationFactory.fetchCollation(${st.collationId})" + s".hashFunction.applyAsLong($input)") case CalendarIntervalType => hashInt(s"$input.hashCode()") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index e12c2838b88ab..25e1197bea4a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1127,7 +1127,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { for (codeGen <- Seq("NO_CODEGEN", "CODEGEN_ONLY")) { val collationSetup = if (collation.isEmpty) "" else " COLLATE " + collation val supportsBinaryEquality = collation.isEmpty || collation == "UNICODE" || - CollationFactory.fetchCollation(collation).supportsBinaryEquality + CollationFactory.fetchCollation(collation).isUtf8BinaryType test(s"Group by on map containing$collationSetup strings ($codeGen)") { val tableName = "t" @@ -1352,7 +1352,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. 
- if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: HashJoin => b.leftKeys.head }.head.isInstanceOf[CollationKey]) @@ -1409,7 +1409,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: BroadcastHashJoinExec => b.leftKeys.head }.head.asInstanceOf[ArrayTransform].function.asInstanceOf[LambdaFunction]. @@ -1470,7 +1470,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: BroadcastHashJoinExec => b.leftKeys.head }.head.asInstanceOf[ArrayTransform].function. @@ -1529,7 +1529,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) @@ -1588,7 +1588,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) From 74aed77853ca7b4744d6e62fdba05a7fe3a161ff Mon Sep 17 00:00:00 2001 From: Jovan Pavlovic Date: Mon, 14 Oct 2024 20:35:11 +0200 Subject: [PATCH 245/250] [SPARK-49661][SQL] Implement trim collation hashing and comparison ### What changes were proposed in this pull request? Implement support for hashing and comparison for trim collations. ### Why are the changes needed? To have full support for trim collations. ### How was this patch tested? Add tests in CollationFactorySuite and CollationSQLExpressionsSuite. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48386 from jovanpavl-db/implement_hashing.
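For illustration only, here is a minimal usage sketch (not part of this patch) of the behavior the new trim collations are expected to exhibit. It reuses only names that appear in the diff below (`fetchCollation`, `comparator`, `hashFunction`, `equalsFunction`, the `UTF8_BINARY_RTRIM` collation); the exact imports and signatures are assumptions based on that diff and the existing CollationFactorySuite:

```scala
// Sketch only: with an RTRIM collation, trailing spaces should be ignored for
// comparison, equality and hashing, while leading spaces stay significant.
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.unsafe.types.UTF8String

object TrimCollationSketch extends App {
  val rtrim = CollationFactory.fetchCollation("UTF8_BINARY_RTRIM")

  val plain    = UTF8String.fromString("aaa")
  val trailing = UTF8String.fromString("aaa ")
  val leading  = UTF8String.fromString(" aaa")

  // Trailing spaces are trimmed before comparing and hashing.
  assert(rtrim.comparator.compare(plain, trailing) == 0)
  assert(rtrim.equalsFunction.apply(plain, trailing))
  assert(rtrim.hashFunction.applyAsLong(plain) ==
    rtrim.hashFunction.applyAsLong(trailing))

  // Leading spaces are not trimmed, so these remain distinct.
  assert(rtrim.comparator.compare(plain, leading) != 0)
}
```

This mirrors the trailing-space cases added to CollationFactorySuite below ("aaa" vs "aaa " compares equal, " aaa" does not) and the RTRIM hash cases added to CollationSQLExpressionsSuite.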
Authored-by: Jovan Pavlovic Signed-off-by: Max Gekk --- .../sql/catalyst/util/CollationFactory.java | 103 +++++++++++++----- .../unsafe/types/CollationFactorySuite.scala | 53 ++++++++- .../apache/spark/sql/types/StringType.scala | 2 +- .../spark/sql/catalyst/expressions/hash.scala | 6 +- .../sql/catalyst/util/UnsafeRowUtils.scala | 4 +- .../aggregate/HashMapGenerator.scala | 5 +- .../sql/CollationSQLExpressionsSuite.scala | 38 ++++++- .../org/apache/spark/sql/CollationSuite.scala | 27 ++++- 8 files changed, 195 insertions(+), 43 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 113c5f866fd88..01f6c7e0331b0 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -154,6 +154,12 @@ public static class Collation { */ public final boolean supportsLowercaseEquality; + /** + * Support for Space Trimming implies that that based on specifier (for now only right trim) + * leading, trailing or both spaces are removed from the input string before comparison. + */ + public final boolean supportsSpaceTrimming; + public Collation( String collationName, String provider, @@ -161,9 +167,11 @@ public Collation( Comparator comparator, String version, ToLongFunction hashFunction, + BiFunction equalsFunction, boolean supportsBinaryEquality, boolean supportsBinaryOrdering, - boolean supportsLowercaseEquality) { + boolean supportsLowercaseEquality, + boolean supportsSpaceTrimming) { this.collationName = collationName; this.provider = provider; this.collator = collator; @@ -173,6 +181,8 @@ public Collation( this.supportsBinaryEquality = supportsBinaryEquality; this.supportsBinaryOrdering = supportsBinaryOrdering; this.supportsLowercaseEquality = supportsLowercaseEquality; + this.equalsFunction = equalsFunction; + this.supportsSpaceTrimming = supportsSpaceTrimming; // De Morgan's Law to check supportsBinaryOrdering => supportsBinaryEquality assert(!supportsBinaryOrdering || supportsBinaryEquality); @@ -180,12 +190,6 @@ public Collation( assert(!supportsBinaryEquality || !supportsLowercaseEquality); assert(SUPPORTED_PROVIDERS.contains(provider)); - - if (supportsBinaryEquality) { - this.equalsFunction = UTF8String::equals; - } else { - this.equalsFunction = (s1, s2) -> this.comparator.compare(s1, s2) == 0; - } } /** @@ -538,27 +542,63 @@ private static boolean isValidCollationId(int collationId) { @Override protected Collation buildCollation() { if (caseSensitivity == CaseSensitivity.UNSPECIFIED) { + Comparator comparator; + ToLongFunction hashFunction; + BiFunction equalsFunction; + boolean supportsSpaceTrimming = spaceTrimming != SpaceTrimming.NONE; + + if (spaceTrimming == SpaceTrimming.NONE) { + comparator = UTF8String::binaryCompare; + hashFunction = s -> (long) s.hashCode(); + equalsFunction = UTF8String::equals; + } else { + comparator = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).binaryCompare( + applyTrimmingPolicy(s2, spaceTrimming)); + hashFunction = s -> (long) applyTrimmingPolicy(s, spaceTrimming).hashCode(); + equalsFunction = (s1, s2) -> applyTrimmingPolicy(s1, spaceTrimming).equals( + applyTrimmingPolicy(s2, spaceTrimming)); + } + return new Collation( normalizedCollationName(), PROVIDER_SPARK, null, - UTF8String::binaryCompare, + comparator, "1.0", - s -> (long) s.hashCode(), + hashFunction, + 
equalsFunction, /* supportsBinaryEquality = */ true, /* supportsBinaryOrdering = */ true, - /* supportsLowercaseEquality = */ false); + /* supportsLowercaseEquality = */ false, + spaceTrimming != SpaceTrimming.NONE); } else { + Comparator comparator; + ToLongFunction hashFunction; + + if (spaceTrimming == SpaceTrimming.NONE) { + comparator = CollationAwareUTF8String::compareLowerCase; + hashFunction = s -> + (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(); + } else { + comparator = (s1, s2) -> CollationAwareUTF8String.compareLowerCase( + applyTrimmingPolicy(s1, spaceTrimming), + applyTrimmingPolicy(s2, spaceTrimming)); + hashFunction = s -> (long) CollationAwareUTF8String.lowerCaseCodePoints( + applyTrimmingPolicy(s, spaceTrimming)).hashCode(); + } + return new Collation( normalizedCollationName(), PROVIDER_SPARK, null, - CollationAwareUTF8String::compareLowerCase, + comparator, "1.0", - s -> (long) CollationAwareUTF8String.lowerCaseCodePoints(s).hashCode(), + hashFunction, + (s1, s2) -> comparator.compare(s1, s2) == 0, /* supportsBinaryEquality = */ false, /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ true); + /* supportsLowercaseEquality = */ true, + spaceTrimming != SpaceTrimming.NONE); } } @@ -917,16 +957,35 @@ protected Collation buildCollation() { Collator collator = Collator.getInstance(resultLocale); // Freeze ICU collator to ensure thread safety. collator.freeze(); + + Comparator comparator; + ToLongFunction hashFunction; + + if (spaceTrimming == SpaceTrimming.NONE) { + hashFunction = s -> (long) collator.getCollationKey( + s.toValidString()).hashCode(); + comparator = (s1, s2) -> + collator.compare(s1.toValidString(), s2.toValidString()); + } else { + comparator = (s1, s2) -> collator.compare( + applyTrimmingPolicy(s1, spaceTrimming).toValidString(), + applyTrimmingPolicy(s2, spaceTrimming).toValidString()); + hashFunction = s -> (long) collator.getCollationKey( + applyTrimmingPolicy(s, spaceTrimming).toValidString()).hashCode(); + } + return new Collation( normalizedCollationName(), PROVIDER_ICU, collator, - (s1, s2) -> collator.compare(s1.toValidString(), s2.toValidString()), + comparator, ICU_COLLATOR_VERSION, - s -> (long) collator.getCollationKey(s.toValidString()).hashCode(), + hashFunction, + (s1, s2) -> comparator.compare(s1, s2) == 0, /* supportsBinaryEquality = */ false, /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ false); + /* supportsLowercaseEquality = */ false, + spaceTrimming != SpaceTrimming.NONE); } @Override @@ -1103,14 +1162,6 @@ public static boolean isCaseSensitiveAndAccentInsensitive(int collationId) { Collation.CollationSpecICU.AccentSensitivity.AI; } - /** - * Returns whether the collation uses trim collation for the given collation id. 
- */ - public static boolean usesTrimCollation(int collationId) { - return Collation.CollationSpec.getSpaceTrimming(collationId) != - Collation.CollationSpec.SpaceTrimming.NONE; - } - public static void assertValidProvider(String provider) throws SparkException { if (!SUPPORTED_PROVIDERS.contains(provider.toLowerCase())) { Map params = Map.of( @@ -1137,7 +1188,7 @@ public static String[] getICULocaleNames() { public static UTF8String getCollationKey(UTF8String input, int collationId) { Collation collation = fetchCollation(collationId); - if (usesTrimCollation(collationId)) { + if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } if (collation.supportsBinaryEquality) { @@ -1153,7 +1204,7 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) { public static byte[] getCollationKeyBytes(UTF8String input, int collationId) { Collation collation = fetchCollation(collationId); - if (usesTrimCollation(collationId)) { + if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } if (collation.supportsBinaryEquality) { diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index 66ff551193101..a565d2d347636 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -127,6 +127,11 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY", "aaa", "AAA", false), CollationTestCase("UTF8_BINARY", "aaa", "bbb", false), CollationTestCase("UTF8_BINARY", "å", "a\u030A", false), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "aaa ", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", true), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", " aaa ", false), + CollationTestCase("UTF8_BINARY_RTRIM", " ", " ", true), CollationTestCase("UTF8_LCASE", "aaa", "aaa", true), CollationTestCase("UTF8_LCASE", "aaa", "AAA", true), CollationTestCase("UTF8_LCASE", "aaa", "AaA", true), @@ -134,15 +139,30 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_LCASE", "aaa", "aa", false), CollationTestCase("UTF8_LCASE", "aaa", "bbb", false), CollationTestCase("UTF8_LCASE", "å", "a\u030A", false), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "AaA ", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AaA ", true), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", " AaA ", false), + CollationTestCase("UTF8_LCASE_RTRIM", " ", " ", true), CollationTestCase("UNICODE", "aaa", "aaa", true), CollationTestCase("UNICODE", "aaa", "AAA", false), CollationTestCase("UNICODE", "aaa", "bbb", false), CollationTestCase("UNICODE", "å", "a\u030A", true), + CollationTestCase("UNICODE_RTRIM", "aaa", "aaa", true), + CollationTestCase("UNICODE_RTRIM", "aaa", "aaa ", true), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", true), + CollationTestCase("UNICODE_RTRIM", "aaa", " aaa ", false), + CollationTestCase("UNICODE_RTRIM", " ", " ", true), CollationTestCase("UNICODE_CI", "aaa", "aaa", true), CollationTestCase("UNICODE_CI", "aaa", "AAA", true), CollationTestCase("UNICODE_CI", "aaa", "bbb", false), 
CollationTestCase("UNICODE_CI", "å", "a\u030A", true), - CollationTestCase("UNICODE_CI", "Å", "a\u030A", true) + CollationTestCase("UNICODE_CI", "Å", "a\u030A", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "AaA ", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AaA ", true), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", " AaA ", false), + CollationTestCase("UNICODE_RTRIM", " ", " ", true) ) checks.foreach(testCase => { @@ -162,19 +182,48 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig CollationTestCase("UTF8_BINARY", "aaa", "AAA", 1), CollationTestCase("UTF8_BINARY", "aaa", "bbb", -1), CollationTestCase("UTF8_BINARY", "aaa", "BBB", 1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa", 0), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "aaa ", 0), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb", -1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa", "BBB" , 1), + CollationTestCase("UTF8_BINARY_RTRIM", "aaa ", "BBB " , 1), + CollationTestCase("UTF8_BINARY_RTRIM", " ", " " , 0), CollationTestCase("UTF8_LCASE", "aaa", "aaa", 0), CollationTestCase("UTF8_LCASE", "aaa", "AAA", 0), CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0), CollationTestCase("UTF8_LCASE", "aaa", "AaA", 0), CollationTestCase("UTF8_LCASE", "aaa", "aa", 1), CollationTestCase("UTF8_LCASE", "aaa", "bbb", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA", 0), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "AAA ", 0), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa", "bbb ", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa", 1), + CollationTestCase("UTF8_LCASE_RTRIM", "aaa ", "aa ", 1), + CollationTestCase("UTF8_LCASE_RTRIM", " ", " ", 0), CollationTestCase("UNICODE", "aaa", "aaa", 0), CollationTestCase("UNICODE", "aaa", "AAA", -1), CollationTestCase("UNICODE", "aaa", "bbb", -1), CollationTestCase("UNICODE", "aaa", "BBB", -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa", 0), + CollationTestCase("UNICODE_RTRIM", "aaa ", "aaa ", 0), + CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb", -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_RTRIM", "aaa", "BBB" , -1), + CollationTestCase("UNICODE_RTRIM", "aaa ", "BBB " , -1), + CollationTestCase("UNICODE_RTRIM", " ", " ", 0), CollationTestCase("UNICODE_CI", "aaa", "aaa", 0), CollationTestCase("UNICODE_CI", "aaa", "AAA", 0), - CollationTestCase("UNICODE_CI", "aaa", "bbb", -1)) + CollationTestCase("UNICODE_CI", "aaa", "bbb", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA", 0), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "AAA ", 0), + CollationTestCase("UNICODE_CI_RTRIM", "aaa", "bbb ", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "bbb ", -1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa", 1), + CollationTestCase("UNICODE_CI_RTRIM", "aaa ", "aa ", 1), + CollationTestCase("UNICODE_CI_RTRIM", " ", " ", 0) + ) checks.foreach(testCase => { val collation = fetchCollation(testCase.collationName) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 29d48e3d1f47f..1c93c2ad550e9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -48,7 +48,7 @@ class StringType 
private (val collationId: Int) extends AtomicType with Serializ !CollationFactory.isCaseSensitiveAndAccentInsensitive(collationId) private[sql] def usesTrimCollation: Boolean = - CollationFactory.usesTrimCollation(collationId) + CollationFactory.fetchCollation(collationId).supportsSpaceTrimming private[sql] def isUTF8BinaryCollation: Boolean = collationId == CollationFactory.UTF8_BINARY_COLLATION_ID diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 3a667f370428e..7128190902550 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -415,7 +415,7 @@ abstract class HashExpression[E] extends Expression { protected def genHashString( ctx: CodegenContext, stringType: StringType, input: String, result: String): String = { - if (stringType.supportsBinaryEquality) { + if (stringType.supportsBinaryEquality && !stringType.usesTrimCollation) { val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" @@ -566,7 +566,7 @@ abstract class InterpretedHashFunction { hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed) case s: UTF8String => val st = dataType.asInstanceOf[StringType] - if (st.supportsBinaryEquality) { + if (st.supportsBinaryEquality && !st.usesTrimCollation) { hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed) } else { val stringHash = CollationFactory @@ -817,7 +817,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def genHashString( ctx: CodegenContext, stringType: StringType, input: String, result: String): String = { - if (stringType.supportsBinaryEquality) { + if (stringType.supportsBinaryEquality && !stringType.usesTrimCollation) { val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala index e296b5be6134b..40b8bccafaad2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala @@ -205,7 +205,9 @@ object UnsafeRowUtils { * can lead to rows being semantically equal even though their binary representations differ). 
*/ def isBinaryStable(dataType: DataType): Boolean = !dataType.existsRecursively { - case st: StringType => !CollationFactory.fetchCollation(st.collationId).supportsBinaryEquality + case st: StringType => + val collation = CollationFactory.fetchCollation(st.collationId) + (!collation.supportsBinaryEquality || collation.supportsSpaceTrimming) case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 45a71b4da7287..3b1f349520f39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -173,8 +173,9 @@ abstract class HashMapGenerator( ${hashBytes(bytes)} """ } - case st: StringType if st.supportsBinaryEquality => hashBytes(s"$input.getBytes()") - case st: StringType if !st.supportsBinaryEquality => + case st: StringType if st.supportsBinaryEquality && !st.usesTrimCollation => + hashBytes(s"$input.getBytes()") + case st: StringType if !st.supportsBinaryEquality || st.usesTrimCollation => hashLong(s"CollationFactory.fetchCollation(${st.collationId})" + s".hashFunction.applyAsLong($input)") case CalendarIntervalType => hashInt(s"$input.hashCode()") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 4c3cd93873bd4..ce6818652d2b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -49,9 +49,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Md5TestCase("Spark", "UTF8_BINARY", "8cde774d6f7333752ed72cacddb05126"), + Md5TestCase("Spark", "UTF8_BINARY_RTRIM", "8cde774d6f7333752ed72cacddb05126"), Md5TestCase("Spark", "UTF8_LCASE", "8cde774d6f7333752ed72cacddb05126"), + Md5TestCase("Spark", "UTF8_LCASE_RTRIM", "8cde774d6f7333752ed72cacddb05126"), Md5TestCase("SQL", "UNICODE", "9778840a0100cb30c982876741b0b5a2"), - Md5TestCase("SQL", "UNICODE_CI", "9778840a0100cb30c982876741b0b5a2") + Md5TestCase("SQL", "UNICODE_RTRIM", "9778840a0100cb30c982876741b0b5a2"), + Md5TestCase("SQL", "UNICODE_CI", "9778840a0100cb30c982876741b0b5a2"), + Md5TestCase("SQL", "UNICODE_CI_RTRIM", "9778840a0100cb30c982876741b0b5a2") ) // Supported collations @@ -81,11 +85,19 @@ class CollationSQLExpressionsSuite val testCases = Seq( Sha2TestCase("Spark", "UTF8_BINARY", 256, "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), + Sha2TestCase("Spark", "UTF8_BINARY_RTRIM", 256, + "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), Sha2TestCase("Spark", "UTF8_LCASE", 256, "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), + Sha2TestCase("Spark", "UTF8_LCASE_RTRIM", 256, + "529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b"), Sha2TestCase("SQL", "UNICODE", 256, "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35"), + Sha2TestCase("SQL", "UNICODE_RTRIM", 256, + "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35"), Sha2TestCase("SQL", "UNICODE_CI", 256, + "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35"), + Sha2TestCase("SQL", "UNICODE_CI_RTRIM", 256, "a7056a455639d1c7deec82ee787db24a0c1878e2792b4597709f0facf7cc7b35") ) @@ -114,9 +126,13 @@ class 
CollationSQLExpressionsSuite val testCases = Seq( Sha1TestCase("Spark", "UTF8_BINARY", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), + Sha1TestCase("Spark", "UTF8_BINARY_RTRIM", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), Sha1TestCase("Spark", "UTF8_LCASE", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), + Sha1TestCase("Spark", "UTF8_LCASE_RTRIM", "85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c"), Sha1TestCase("SQL", "UNICODE", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), - Sha1TestCase("SQL", "UNICODE_CI", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d") + Sha1TestCase("SQL", "UNICODE_RTRIM", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), + Sha1TestCase("SQL", "UNICODE_CI", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d"), + Sha1TestCase("SQL", "UNICODE_CI_RTRIM", "2064cb643caa8d9e1de12eea7f3e143ca9f8680d") ) // Supported collations @@ -144,9 +160,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Crc321TestCase("Spark", "UTF8_BINARY", 1557323817), + Crc321TestCase("Spark", "UTF8_BINARY_RTRIM", 1557323817), Crc321TestCase("Spark", "UTF8_LCASE", 1557323817), + Crc321TestCase("Spark", "UTF8_LCASE_RTRIM", 1557323817), Crc321TestCase("SQL", "UNICODE", 1299261525), - Crc321TestCase("SQL", "UNICODE_CI", 1299261525) + Crc321TestCase("SQL", "UNICODE_RTRIM", 1299261525), + Crc321TestCase("SQL", "UNICODE_CI", 1299261525), + Crc321TestCase("SQL", "UNICODE_CI_RTRIM", 1299261525) ) // Supported collations @@ -172,9 +192,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( Murmur3HashTestCase("Spark", "UTF8_BINARY", 228093765), + Murmur3HashTestCase("Spark ", "UTF8_BINARY_RTRIM", 1779328737), Murmur3HashTestCase("Spark", "UTF8_LCASE", -1928694360), + Murmur3HashTestCase("Spark ", "UTF8_LCASE_RTRIM", -1928694360), Murmur3HashTestCase("SQL", "UNICODE", -1923567940), - Murmur3HashTestCase("SQL", "UNICODE_CI", 1029527950) + Murmur3HashTestCase("SQL ", "UNICODE_RTRIM", -1923567940), + Murmur3HashTestCase("SQL", "UNICODE_CI", 1029527950), + Murmur3HashTestCase("SQL ", "UNICODE_CI_RTRIM", 1029527950) ) // Supported collations @@ -200,9 +224,13 @@ class CollationSQLExpressionsSuite val testCases = Seq( XxHash64TestCase("Spark", "UTF8_BINARY", -4294468057691064905L), + XxHash64TestCase("Spark ", "UTF8_BINARY_RTRIM", 6480371823304753502L), XxHash64TestCase("Spark", "UTF8_LCASE", -3142112654825786434L), + XxHash64TestCase("Spark ", "UTF8_LCASE_RTRIM", -3142112654825786434L), XxHash64TestCase("SQL", "UNICODE", 5964849564945649886L), - XxHash64TestCase("SQL", "UNICODE_CI", 3732497619779520590L) + XxHash64TestCase("SQL ", "UNICODE_RTRIM", 5964849564945649886L), + XxHash64TestCase("SQL", "UNICODE_CI", 3732497619779520590L), + XxHash64TestCase("SQL ", "UNICODE_CI_RTRIM", 3732497619779520590L) ) // Supported collations diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index b19af542dabf2..4234d73c1794d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -101,8 +101,12 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("collate function syntax") { assert(sql(s"select collate('aaa', 'utf8_binary')").schema(0).dataType == StringType("UTF8_BINARY")) + assert(sql(s"select collate('aaa', 'utf8_binary_rtrim')").schema(0).dataType == + StringType("UTF8_BINARY_RTRIM")) assert(sql(s"select collate('aaa', 'utf8_lcase')").schema(0).dataType == StringType("UTF8_LCASE")) + assert(sql(s"select 
collate('aaa', 'utf8_lcase_rtrim')").schema(0).dataType == + StringType("UTF8_LCASE_RTRIM")) } test("collate function syntax with default collation set") { @@ -260,14 +264,23 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { Seq( ("utf8_binary", "aaa", "AAA", false), ("utf8_binary", "aaa", "aaa", true), + ("utf8_binary_rtrim", "aaa", "AAA", false), + ("utf8_binary_rtrim", "aaa", "aaa ", true), ("utf8_lcase", "aaa", "aaa", true), ("utf8_lcase", "aaa", "AAA", true), ("utf8_lcase", "aaa", "bbb", false), + ("utf8_lcase_rtrim", "aaa", "AAA ", true), + ("utf8_lcase_rtrim", "aaa", "bbb", false), ("unicode", "aaa", "aaa", true), ("unicode", "aaa", "AAA", false), + ("unicode_rtrim", "aaa ", "aaa ", true), + ("unicode_rtrim", "aaa", "AAA", false), ("unicode_CI", "aaa", "aaa", true), ("unicode_CI", "aaa", "AAA", true), - ("unicode_CI", "aaa", "bbb", false) + ("unicode_CI", "aaa", "bbb", false), + ("unicode_CI_rtrim", "aaa", "aaa", true), + ("unicode_CI_rtrim", "aaa ", "AAA ", true), + ("unicode_CI_rtrim", "aaa", "bbb", false) ).foreach { case (collationName, left, right, expected) => checkAnswer( @@ -284,15 +297,19 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ("utf8_binary", "AAA", "aaa", true), ("utf8_binary", "aaa", "aaa", false), ("utf8_binary", "aaa", "BBB", false), + ("utf8_binary_rtrim", "aaa ", "aaa ", false), ("utf8_lcase", "aaa", "aaa", false), ("utf8_lcase", "AAA", "aaa", false), ("utf8_lcase", "aaa", "bbb", true), + ("utf8_lcase_rtrim", "AAA ", "aaa", false), ("unicode", "aaa", "aaa", false), ("unicode", "aaa", "AAA", true), ("unicode", "aaa", "BBB", true), + ("unicode_rtrim", "aaa ", "aaa", false), ("unicode_CI", "aaa", "aaa", false), ("unicode_CI", "aaa", "AAA", false), - ("unicode_CI", "aaa", "bbb", true) + ("unicode_CI", "aaa", "bbb", true), + ("unicode_CI_rtrim", "aaa ", "aaa", false) ).foreach { case (collationName, left, right, expected) => checkAnswer( @@ -355,18 +372,22 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("aggregates count respects collation") { Seq( + ("utf8_binary_rtrim", Seq("aaa", "aaa "), Seq(Row(2, "aaa"))), ("utf8_binary", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), ("utf8_binary", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("utf8_binary", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), ("utf8_lcase", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("utf8_lcase", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))), ("utf8_lcase", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("utf8_lcase_rtrim", Seq("aaa", "AAA "), Seq(Row(2, "aaa"))), ("unicode", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), ("unicode", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("unicode", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("unicode_rtrim", Seq("aaa", "aaa "), Seq(Row(2, "aaa"))), ("unicode_CI", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), ("unicode_CI", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))), - ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))) + ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("unicode_CI_rtrim", Seq("aaa", "AAA "), Seq(Row(2, "aaa"))) ).foreach { case (collationName: String, input: Seq[String], expected: Seq[Row]) => checkAnswer(sql( From 488f68090b228b30ba4a3b75596c9904eef1f584 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Tue, 15 Oct 2024 08:31:33 +0800 Subject: [PATCH 246/250] [SPARK-49929][PYTHON][CONNECT] Support box plots ### What changes were proposed in this pull request? 
Support box plots with plotly backend on both Spark Connect and Spark classic. ### Why are the changes needed? While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments. See more at [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) in progress. Part of https://issues.apache.org/jira/browse/SPARK-49530. ### Does this PR introduce _any_ user-facing change? Yes. Box plots are supported as shown below. ```py >>> data = [ ... ("A", 50, 55), ... ("B", 55, 60), ... ("C", 60, 65), ... ("D", 65, 70), ... ("E", 70, 75), ... # outliers ... ("F", 10, 15), ... ("G", 85, 90), ... ("H", 5, 150), ... ] >>> columns = ["student", "math_score", "english_score"] >>> sdf = spark.createDataFrame(data, columns) >>> fig1 = sdf.plot.box(column=["math_score", "english_score"]) >>> fig1.show() # see below >>> fig2 = sdf.plot(kind="box", column="math_score") >>> fig2.show() # see below ``` fig1: ![newplot (17)](https://github.com/user-attachments/assets/8c36c344-f6de-47e3-bd63-c0f3b57efc43) fig2: ![newplot (18)](https://github.com/user-attachments/assets/9b7b60f6-58ec-4eff-9544-d5ab88a88631) ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48447 from xinrong-meng/box. Authored-by: Xinrong Meng Signed-off-by: Xinrong Meng --- python/pyspark/errors/error-conditions.json | 5 + python/pyspark/sql/plot/core.py | 153 +++++++++++++++++- python/pyspark/sql/plot/plotly.py | 77 ++++++++- .../sql/tests/plot/test_frame_plot_plotly.py | 77 ++++++++- 4 files changed, 307 insertions(+), 5 deletions(-) diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 6ca21d55555d2..ab01d386645b2 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -1103,6 +1103,11 @@ "`` is not supported, it should be one of the values from " ] }, + "UNSUPPORTED_PLOT_BACKEND_PARAM": { + "message": [ + "`` does not support `` set to , it should be one of the values from " + ] + }, "UNSUPPORTED_SIGNATURE": { "message": [ "Unsupported signature: ." diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index f9667ee2c0d69..4bf75474d92c3 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -15,15 +15,17 @@ # limitations under the License. 
# -from typing import Any, TYPE_CHECKING, Optional, Union +from typing import Any, TYPE_CHECKING, List, Optional, Union from types import ModuleType from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError +from pyspark.sql import Column, functions as F from pyspark.sql.types import NumericType -from pyspark.sql.utils import require_minimum_plotly_version +from pyspark.sql.utils import is_remote, require_minimum_plotly_version if TYPE_CHECKING: - from pyspark.sql import DataFrame + from pyspark.sql import DataFrame, Row + from pyspark.sql._typing import ColumnOrName import pandas as pd from plotly.graph_objs import Figure @@ -338,3 +340,148 @@ def pie(self, x: str, y: str, **kwargs: Any) -> "Figure": }, ) return self(kind="pie", x=x, y=y, **kwargs) + + def box( + self, column: Union[str, List[str]], precision: float = 0.01, **kwargs: Any + ) -> "Figure": + """ + Make a box plot of the DataFrame columns. + + Make a box-and-whisker plot from DataFrame columns, optionally grouped by some + other columns. A box plot is a method for graphically depicting groups of numerical + data through their quartiles. The box extends from the Q1 to Q3 quartile values of + the data, with a line at the median (Q2). The whiskers extend from the edges of box + to show the range of the data. By default, they extend no more than + 1.5 * IQR (IQR = Q3 - Q1) from the edges of the box, ending at the farthest data point + within that interval. Outliers are plotted as separate dots. + + Parameters + ---------- + column: str or list of str + Column name or list of names to be used for creating the boxplot. + precision: float, default = 0.01 + This argument is used by pyspark to compute approximate statistics + for building a boxplot. + **kwargs + Additional keyword arguments. + + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> data = [ + ... ("A", 50, 55), + ... ("B", 55, 60), + ... ("C", 60, 65), + ... ("D", 65, 70), + ... ("E", 70, 75), + ... ("F", 10, 15), + ... ("G", 85, 90), + ... ("H", 5, 150), + ... 
] + >>> columns = ["student", "math_score", "english_score"] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.box(column="math_score") # doctest: +SKIP + >>> df.plot.box(column=["math_score", "english_score"]) # doctest: +SKIP + """ + return self(kind="box", column=column, precision=precision, **kwargs) + + +class PySparkBoxPlotBase: + @staticmethod + def compute_box( + sdf: "DataFrame", colnames: List[str], whis: float, precision: float, showfliers: bool + ) -> Optional["Row"]: + assert len(colnames) > 0 + formatted_colnames = ["`{}`".format(colname) for colname in colnames] + + stats_scols = [] + for i, colname in enumerate(formatted_colnames): + percentiles = F.percentile_approx(colname, [0.25, 0.50, 0.75], int(1.0 / precision)) + q1 = F.get(percentiles, 0) + med = F.get(percentiles, 1) + q3 = F.get(percentiles, 2) + iqr = q3 - q1 + lfence = q1 - F.lit(whis) * iqr + ufence = q3 + F.lit(whis) * iqr + + stats_scols.append( + F.struct( + F.mean(colname).alias("mean"), + med.alias("med"), + q1.alias("q1"), + q3.alias("q3"), + lfence.alias("lfence"), + ufence.alias("ufence"), + ).alias(f"_box_plot_stats_{i}") + ) + + sdf_stats = sdf.select(*stats_scols) + + result_scols = [] + for i, colname in enumerate(formatted_colnames): + value = F.col(colname) + + lfence = F.col(f"_box_plot_stats_{i}.lfence") + ufence = F.col(f"_box_plot_stats_{i}.ufence") + mean = F.col(f"_box_plot_stats_{i}.mean") + med = F.col(f"_box_plot_stats_{i}.med") + q1 = F.col(f"_box_plot_stats_{i}.q1") + q3 = F.col(f"_box_plot_stats_{i}.q3") + + outlier = ~value.between(lfence, ufence) + + # Computes min and max values of non-outliers - the whiskers + upper_whisker = F.max(F.when(~outlier, value).otherwise(F.lit(None))) + lower_whisker = F.min(F.when(~outlier, value).otherwise(F.lit(None))) + + # If it shows fliers, take the top 1k with the highest absolute values + # Here we normalize the values by subtracting the median. 
+ if showfliers: + pair = F.when( + outlier, + F.struct(F.abs(value - med), value.alias("val")), + ).otherwise(F.lit(None)) + topk = collect_top_k(pair, 1001, False) + fliers = F.when(F.size(topk) > 0, topk["val"]).otherwise(F.lit(None)) + else: + fliers = F.lit(None) + + result_scols.append( + F.struct( + F.first(mean).alias("mean"), + F.first(med).alias("med"), + F.first(q1).alias("q1"), + F.first(q3).alias("q3"), + upper_whisker.alias("upper_whisker"), + lower_whisker.alias("lower_whisker"), + fliers.alias("fliers"), + ).alias(f"_box_plot_results_{i}") + ) + + sdf_result = sdf.join(sdf_stats.hint("broadcast")).select(*result_scols) + return sdf_result.first() + + +def _invoke_internal_function_over_columns(name: str, *cols: "ColumnOrName") -> Column: + if is_remote(): + from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns + + return _invoke_function_over_columns(name, *cols) + + else: + from pyspark.sql.classic.column import _to_seq, _to_java_column + from pyspark import SparkContext + + sc = SparkContext._active_spark_context + return Column( + sc._jvm.PythonSQLUtils.internalFn( # type: ignore + name, _to_seq(sc, cols, _to_java_column) # type: ignore + ) + ) + + +def collect_top_k(col: Column, num: int, reverse: bool) -> Column: + return _invoke_internal_function_over_columns("collect_top_k", col, F.lit(num), F.lit(reverse)) diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py index 91f5363464717..71d40720e874d 100644 --- a/python/pyspark/sql/plot/plotly.py +++ b/python/pyspark/sql/plot/plotly.py @@ -17,7 +17,8 @@ from typing import TYPE_CHECKING, Any -from pyspark.sql.plot import PySparkPlotAccessor +from pyspark.errors import PySparkValueError +from pyspark.sql.plot import PySparkPlotAccessor, PySparkBoxPlotBase if TYPE_CHECKING: from pyspark.sql import DataFrame @@ -29,6 +30,8 @@ def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure": if kind == "pie": return plot_pie(data, **kwargs) + if kind == "box": + return plot_box(data, **kwargs) return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs) @@ -43,3 +46,75 @@ def plot_pie(data: "DataFrame", **kwargs: Any) -> "Figure": fig = express.pie(pdf, values=y, names=x, **kwargs) return fig + + +def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure": + import plotly.graph_objs as go + + # 'whis' isn't actually an argument in plotly (but in matplotlib). But seems like + # plotly doesn't expose the reach of the whiskers to the beyond the first and + # third quartiles (?). Looks they use default 1.5. 
diff --git a/python/pyspark/sql/plot/plotly.py b/python/pyspark/sql/plot/plotly.py
index 91f5363464717..71d40720e874d 100644
--- a/python/pyspark/sql/plot/plotly.py
+++ b/python/pyspark/sql/plot/plotly.py
@@ -17,7 +17,8 @@
 
 from typing import TYPE_CHECKING, Any
 
-from pyspark.sql.plot import PySparkPlotAccessor
+from pyspark.errors import PySparkValueError
+from pyspark.sql.plot import PySparkPlotAccessor, PySparkBoxPlotBase
 
 if TYPE_CHECKING:
     from pyspark.sql import DataFrame
@@ -29,6 +30,8 @@ def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
 
     if kind == "pie":
         return plot_pie(data, **kwargs)
+    if kind == "box":
+        return plot_box(data, **kwargs)
     return plotly.plot(PySparkPlotAccessor.plot_data_map[kind](data), kind, **kwargs)
 
 
@@ -43,3 +46,75 @@ def plot_pie(data: "DataFrame", **kwargs: Any) -> "Figure":
     fig = express.pie(pdf, values=y, names=x, **kwargs)
 
     return fig
+
+
+def plot_box(data: "DataFrame", **kwargs: Any) -> "Figure":
+    import plotly.graph_objs as go
+
+    # 'whis' isn't actually a plotly argument (it comes from matplotlib). Plotly does
+    # not appear to expose how far the whiskers reach beyond the first and third
+    # quartiles; it seems to use a fixed factor of 1.5.
+    whis = kwargs.pop("whis", 1.5)
+    # 'precision' is PySpark-specific and controls the accuracy of approx_percentile
+    precision = kwargs.pop("precision", 0.01)
+    colnames = kwargs.pop("column", None)
+    if isinstance(colnames, str):
+        colnames = [colnames]
+
+    # Plotly options
+    boxpoints = kwargs.pop("boxpoints", "suspectedoutliers")
+    notched = kwargs.pop("notched", False)
+    if boxpoints not in ["suspectedoutliers", False]:
+        raise PySparkValueError(
+            errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM",
+            messageParameters={
+                "backend": "plotly",
+                "param": "boxpoints",
+                "value": str(boxpoints),
+                "supported_values": ", ".join(["suspectedoutliers", "False"]),
+            },
+        )
+    if notched:
+        raise PySparkValueError(
+            errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM",
+            messageParameters={
+                "backend": "plotly",
+                "param": "notched",
+                "value": str(notched),
+                "supported_values": ", ".join(["False"]),
+            },
+        )
+
+    fig = go.Figure()
+
+    results = PySparkBoxPlotBase.compute_box(
+        data,
+        colnames,
+        whis,
+        precision,
+        boxpoints is not None,
+    )
+    assert len(results) == len(colnames)  # type: ignore
+
+    for i, colname in enumerate(colnames):
+        result = results[i]  # type: ignore
+
+        fig.add_trace(
+            go.Box(
+                x=[i],
+                name=colname,
+                q1=[result["q1"]],
+                median=[result["med"]],
+                q3=[result["q3"]],
+                mean=[result["mean"]],
+                lowerfence=[result["lower_whisker"]],
+                upperfence=[result["upper_whisker"]],
+                y=[result["fliers"]] if result["fliers"] else None,
+                boxpoints=boxpoints,
+                notched=notched,
+                **kwargs,
+            )
+        )
+
+    fig["layout"]["yaxis"]["title"] = "value"
+    return fig
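Before the tests, a brief usage sketch of the plotly code path added above. This is illustrative only: it assumes a PySpark build that includes this patch, a plotly installation, and the default plotly plotting backend; the data mirrors the `sdf4` fixture defined in the tests that follow.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
data = [
    ("A", 50, 55), ("B", 55, 60), ("C", 60, 65), ("D", 65, 70),
    ("E", 70, 75), ("F", 10, 15), ("G", 85, 90), ("H", 5, 150),
]
df = spark.createDataFrame(data, ["student", "math_score", "english_score"])

# Single column with the defaults: suspected outliers only, no notches.
fig = df.plot.box(column="math_score")

# Multiple columns; 'precision' tunes percentile_approx accuracy and 'whis'
# controls how far the fences reach beyond the quartiles (both handled in PySpark).
fig = df.plot.box(column=["math_score", "english_score"], precision=0.01, whis=1.5)
fig.show()  # renders the plotly Figure

# Unsupported plotly options fail fast: boxpoints=True or notched=True raise
# PySparkValueError with error class UNSUPPORTED_PLOT_BACKEND_PARAM.
```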
diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
index b92b5a91cb766..d870cdbf9959b 100644
--- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
+++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
@@ -19,7 +19,7 @@
 from datetime import datetime
 
 import pyspark.sql.plot  # noqa: F401
-from pyspark.errors import PySparkTypeError
+from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message
@@ -48,6 +48,22 @@ def sdf3(self):
         columns = ["sales", "signups", "visits", "date"]
         return self.spark.createDataFrame(data, columns)
 
+    @property
+    def sdf4(self):
+        data = [
+            ("A", 50, 55),
+            ("B", 55, 60),
+            ("C", 60, 65),
+            ("D", 65, 70),
+            ("E", 70, 75),
+            # outliers
+            ("F", 10, 15),
+            ("G", 85, 90),
+            ("H", 5, 150),
+        ]
+        columns = ["student", "math_score", "english_score"]
+        return self.spark.createDataFrame(data, columns)
+
     def _check_fig_data(self, fig_data, **kwargs):
         for key, expected_value in kwargs.items():
             if key in ["x", "y", "labels", "values"]:
@@ -300,6 +316,65 @@ def test_pie_plot(self):
             messageParameters={"arg_name": "y", "arg_type": "StringType()"},
         )
 
+    def test_box_plot(self):
+        fig = self.sdf4.plot.box(column="math_score")
+        expected_fig_data = {
+            "boxpoints": "suspectedoutliers",
+            "lowerfence": (5,),
+            "mean": (50.0,),
+            "median": (55,),
+            "name": "math_score",
+            "notched": False,
+            "q1": (10,),
+            "q3": (65,),
+            "upperfence": (85,),
+            "x": [0],
+            "type": "box",
+        }
+        self._check_fig_data(fig["data"][0], **expected_fig_data)
+
+        fig = self.sdf4.plot(kind="box", column=["math_score", "english_score"])
+        self._check_fig_data(fig["data"][0], **expected_fig_data)
+        expected_fig_data = {
+            "boxpoints": "suspectedoutliers",
+            "lowerfence": (55,),
+            "mean": (72.5,),
+            "median": (65,),
+            "name": "english_score",
+            "notched": False,
+            "q1": (55,),
+            "q3": (75,),
+            "upperfence": (90,),
+            "x": [1],
+            "y": [[150, 15]],
+            "type": "box",
+        }
+        self._check_fig_data(fig["data"][1], **expected_fig_data)
+        with self.assertRaises(PySparkValueError) as pe:
+            self.sdf4.plot.box(column="math_score", boxpoints=True)
+        self.check_error(
+            exception=pe.exception,
+            errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM",
+            messageParameters={
+                "backend": "plotly",
+                "param": "boxpoints",
+                "value": "True",
+                "supported_values": ", ".join(["suspectedoutliers", "False"]),
+            },
+        )
+        with self.assertRaises(PySparkValueError) as pe:
+            self.sdf4.plot.box(column="math_score", notched=True)
+        self.check_error(
+            exception=pe.exception,
+            errorClass="UNSUPPORTED_PLOT_BACKEND_PARAM",
+            messageParameters={
+                "backend": "plotly",
+                "param": "notched",
+                "value": "True",
+                "supported_values": ", ".join(["False"]),
+            },
+        )
+
 
 class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
     pass

From 217e0da917c7200ff36aa1b9edc90927a45c5a94 Mon Sep 17 00:00:00 2001
From: panbingkun
Date: Tue, 15 Oct 2024 10:59:41 +0800
Subject: [PATCH 247/250] [SPARK-49965][BUILD] Upgrade ASM to 9.7.1

### What changes were proposed in this pull request?

This PR aims to upgrade ASM from `9.7` to `9.7.1`.

### Why are the changes needed?

- xbean-asm9-shaded 4.26 upgrades to `ASM 9.7.1`, and `ASM 9.7.1` adds support for `Java 24`.
  https://github.com/apache/geronimo-xbean/pull/41
- https://asm.ow2.io/versions.html

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Pass GA.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48465 from panbingkun/SPARK-49965.

Authored-by: panbingkun
Signed-off-by: yangjie01
---
 dev/deps/spark-deps-hadoop-3-hive-2.3 | 2 +-
 pom.xml                               | 4 ++--
 project/plugins.sbt                   | 4 ++--
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index 8ba2f6c414cb9..91e84b0780798 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -274,7 +274,7 @@ tink/1.15.0//tink-1.15.0.jar
 transaction-api/1.1//transaction-api-1.1.jar
 univocity-parsers/2.9.1//univocity-parsers-2.9.1.jar
 wildfly-openssl/1.1.3.Final//wildfly-openssl-1.1.3.Final.jar
-xbean-asm9-shaded/4.25//xbean-asm9-shaded-4.25.jar
+xbean-asm9-shaded/4.26//xbean-asm9-shaded-4.26.jar
 xmlschema-core/2.3.1//xmlschema-core-2.3.1.jar
 xz/1.10//xz-1.10.jar
 zjsonpatch/0.3.0//zjsonpatch-0.3.0.jar
diff --git a/pom.xml b/pom.xml
index 3da8b9ef68b90..2b89454873782 100644
--- a/pom.xml
+++ b/pom.xml
@@ -119,7 +119,7 @@
     3.9.9
     3.2.0
     spark
-    9.7
+    9.7.1
     2.0.16
     2.24.1
@@ -491,7 +491,7 @@
      org.apache.xbean
      xbean-asm9-shaded
-      4.25
+      4.26